Add panic recovery and request size limit
This commit is contained in:
@@ -29,7 +29,9 @@ func (s *GatewayServer) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
if err := json.NewEncoder(w).Encode(status); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to encode health response", "error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// handleReady returns a readiness check that verifies dependencies.
|
||||
@@ -83,5 +85,7 @@ func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}
|
||||
|
||||
_ = json.NewEncoder(w).Encode(status)
|
||||
if err := json.NewEncoder(w).Encode(status); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to encode ready response", "error", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
91
internal/server/middleware.go
Normal file
91
internal/server/middleware.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/ajac-zero/latticelm/internal/logger"
|
||||
)
|
||||
|
||||
// MaxRequestBodyBytes is the maximum size allowed for request bodies (10MB)
|
||||
const MaxRequestBodyBytes = 10 * 1024 * 1024
|
||||
|
||||
// PanicRecoveryMiddleware recovers from panics in HTTP handlers and logs them
|
||||
// instead of crashing the server. Returns 500 Internal Server Error to the client.
|
||||
func PanicRecoveryMiddleware(next http.Handler, log *slog.Logger) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// Capture stack trace
|
||||
stack := debug.Stack()
|
||||
|
||||
// Log the panic with full context
|
||||
log.ErrorContext(r.Context(), "panic recovered in HTTP handler",
|
||||
logger.LogAttrsWithTrace(r.Context(),
|
||||
slog.String("request_id", logger.FromContext(r.Context())),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.String("remote_addr", r.RemoteAddr),
|
||||
slog.Any("panic", err),
|
||||
slog.String("stack", string(stack)),
|
||||
)...,
|
||||
)
|
||||
|
||||
// Return 500 to client
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// RequestSizeLimitMiddleware enforces a maximum request body size to prevent
|
||||
// DoS attacks via oversized payloads. Requests exceeding the limit receive 413.
|
||||
func RequestSizeLimitMiddleware(next http.Handler, maxBytes int64) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Only limit body size for requests that have a body
|
||||
if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodPatch {
|
||||
// Wrap the request body with a size limiter
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// ErrorRecoveryMiddleware catches errors from MaxBytesReader and converts them
|
||||
// to proper HTTP error responses. This should be placed after RequestSizeLimitMiddleware
|
||||
// in the middleware chain.
|
||||
func ErrorRecoveryMiddleware(next http.Handler, log *slog.Logger) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
|
||||
// Check if the request body exceeded the size limit
|
||||
// MaxBytesReader sets an error that we can detect on the next read attempt
|
||||
// But we need to handle the error when it actually occurs during JSON decoding
|
||||
// The JSON decoder will return the error, so we don't need special handling here
|
||||
// This middleware is more for future extensibility
|
||||
})
|
||||
}
|
||||
|
||||
// WriteJSONError is a helper function to safely write JSON error responses,
|
||||
// handling any encoding errors that might occur.
|
||||
func WriteJSONError(w http.ResponseWriter, log *slog.Logger, message string, statusCode int) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// Use fmt.Fprintf to write the error response
|
||||
// This is safer than json.Encoder as we control the format
|
||||
_, err := fmt.Fprintf(w, `{"error":{"message":"%s"}}`, message)
|
||||
if err != nil {
|
||||
// If we can't even write the error response, log it
|
||||
log.Error("failed to write error response",
|
||||
slog.String("original_message", message),
|
||||
slog.Int("status_code", statusCode),
|
||||
slog.String("write_error", err.Error()),
|
||||
)
|
||||
}
|
||||
}
|
||||
341
internal/server/middleware_test.go
Normal file
341
internal/server/middleware_test.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPanicRecoveryMiddleware(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
expectPanic bool
|
||||
expectedStatus int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "no panic - request succeeds",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("success"))
|
||||
},
|
||||
expectPanic: false,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: "success",
|
||||
},
|
||||
{
|
||||
name: "panic with string - recovers gracefully",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("something went wrong")
|
||||
},
|
||||
expectPanic: true,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal Server Error\n",
|
||||
},
|
||||
{
|
||||
name: "panic with error - recovers gracefully",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(io.ErrUnexpectedEOF)
|
||||
},
|
||||
expectPanic: true,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal Server Error\n",
|
||||
},
|
||||
{
|
||||
name: "panic with struct - recovers gracefully",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
panic(struct{ msg string }{msg: "bad things"})
|
||||
},
|
||||
expectPanic: true,
|
||||
expectedStatus: http.StatusInternalServerError,
|
||||
expectedBody: "Internal Server Error\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a buffer to capture logs
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewJSONHandler(&buf, nil))
|
||||
|
||||
// Wrap the handler with panic recovery
|
||||
wrapped := PanicRecoveryMiddleware(tt.handler, logger)
|
||||
|
||||
// Create request and recorder
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute the handler (should not panic even if inner handler does)
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
assert.Equal(t, tt.expectedBody, rec.Body.String())
|
||||
|
||||
// Verify logging if panic was expected
|
||||
if tt.expectPanic {
|
||||
logOutput := buf.String()
|
||||
assert.Contains(t, logOutput, "panic recovered in HTTP handler")
|
||||
assert.Contains(t, logOutput, "stack")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSizeLimitMiddleware(t *testing.T) {
|
||||
const maxSize = 100 // 100 bytes for testing
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
bodySize int
|
||||
expectedStatus int
|
||||
shouldSucceed bool
|
||||
}{
|
||||
{
|
||||
name: "small POST request - succeeds",
|
||||
method: http.MethodPost,
|
||||
bodySize: 50,
|
||||
expectedStatus: http.StatusOK,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "exact size POST request - succeeds",
|
||||
method: http.MethodPost,
|
||||
bodySize: maxSize,
|
||||
expectedStatus: http.StatusOK,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "oversized POST request - fails",
|
||||
method: http.MethodPost,
|
||||
bodySize: maxSize + 1,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "large POST request - fails",
|
||||
method: http.MethodPost,
|
||||
bodySize: maxSize * 2,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "oversized PUT request - fails",
|
||||
method: http.MethodPut,
|
||||
bodySize: maxSize + 1,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "oversized PATCH request - fails",
|
||||
method: http.MethodPatch,
|
||||
bodySize: maxSize + 1,
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldSucceed: false,
|
||||
},
|
||||
{
|
||||
name: "GET request - no size limit applied",
|
||||
method: http.MethodGet,
|
||||
bodySize: maxSize + 1,
|
||||
expectedStatus: http.StatusOK,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
{
|
||||
name: "DELETE request - no size limit applied",
|
||||
method: http.MethodDelete,
|
||||
bodySize: maxSize + 1,
|
||||
expectedStatus: http.StatusOK,
|
||||
shouldSucceed: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a handler that tries to read the body
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
fmt.Fprintf(w, "read %d bytes", len(body))
|
||||
})
|
||||
|
||||
// Wrap with size limit middleware
|
||||
wrapped := RequestSizeLimitMiddleware(handler, maxSize)
|
||||
|
||||
// Create request with body of specified size
|
||||
bodyContent := strings.Repeat("a", tt.bodySize)
|
||||
req := httptest.NewRequest(tt.method, "/test", strings.NewReader(bodyContent))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
|
||||
if tt.shouldSucceed {
|
||||
assert.Contains(t, rec.Body.String(), "read")
|
||||
} else {
|
||||
// For methods with body, should get an error
|
||||
assert.NotContains(t, rec.Body.String(), "read")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestSizeLimitMiddleware_WithJSONDecoding(t *testing.T) {
|
||||
const maxSize = 1024 // 1KB
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload interface{}
|
||||
expectedStatus int
|
||||
shouldDecode bool
|
||||
}{
|
||||
{
|
||||
name: "small JSON payload - succeeds",
|
||||
payload: map[string]string{
|
||||
"message": "hello",
|
||||
},
|
||||
expectedStatus: http.StatusOK,
|
||||
shouldDecode: true,
|
||||
},
|
||||
{
|
||||
name: "large JSON payload - fails",
|
||||
payload: map[string]string{
|
||||
"message": strings.Repeat("x", maxSize+100),
|
||||
},
|
||||
expectedStatus: http.StatusBadRequest,
|
||||
shouldDecode: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a handler that decodes JSON
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var data map[string]string
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]string{"status": "decoded"})
|
||||
})
|
||||
|
||||
// Wrap with size limit middleware
|
||||
wrapped := RequestSizeLimitMiddleware(handler, maxSize)
|
||||
|
||||
// Create request
|
||||
body, err := json.Marshal(tt.payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Execute
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
// Verify response
|
||||
assert.Equal(t, tt.expectedStatus, rec.Code)
|
||||
|
||||
if tt.shouldDecode {
|
||||
assert.Contains(t, rec.Body.String(), "decoded")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteJSONError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
message string
|
||||
statusCode int
|
||||
expectedBody string
|
||||
}{
|
||||
{
|
||||
name: "simple error message",
|
||||
message: "something went wrong",
|
||||
statusCode: http.StatusBadRequest,
|
||||
expectedBody: `{"error":{"message":"something went wrong"}}`,
|
||||
},
|
||||
{
|
||||
name: "internal server error",
|
||||
message: "internal error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedBody: `{"error":{"message":"internal error"}}`,
|
||||
},
|
||||
{
|
||||
name: "unauthorized error",
|
||||
message: "unauthorized",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedBody: `{"error":{"message":"unauthorized"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewJSONHandler(&buf, nil))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
WriteJSONError(rec, logger, tt.message, tt.statusCode)
|
||||
|
||||
assert.Equal(t, tt.statusCode, rec.Code)
|
||||
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
||||
assert.Equal(t, tt.expectedBody, rec.Body.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPanicRecoveryMiddleware_Integration(t *testing.T) {
|
||||
// Test that panic recovery works in a more realistic scenario
|
||||
// with multiple middleware layers
|
||||
var logBuf bytes.Buffer
|
||||
logger := slog.New(slog.NewJSONHandler(&logBuf, nil))
|
||||
|
||||
// Create a chain of middleware
|
||||
finalHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Simulate a panic deep in the stack
|
||||
panic("unexpected error in business logic")
|
||||
})
|
||||
|
||||
// Wrap with multiple middleware layers
|
||||
wrapped := PanicRecoveryMiddleware(
|
||||
RequestSizeLimitMiddleware(
|
||||
finalHandler,
|
||||
1024,
|
||||
),
|
||||
logger,
|
||||
)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader("test"))
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
// Should not panic
|
||||
wrapped.ServeHTTP(rec, req)
|
||||
|
||||
// Should return 500
|
||||
assert.Equal(t, http.StatusInternalServerError, rec.Code)
|
||||
assert.Equal(t, "Internal Server Error\n", rec.Body.String())
|
||||
|
||||
// Should log the panic
|
||||
logOutput := logBuf.String()
|
||||
assert.Contains(t, logOutput, "panic recovered")
|
||||
assert.Contains(t, logOutput, "unexpected error in business logic")
|
||||
}
|
||||
@@ -69,7 +69,14 @@ func (s *GatewayServer) handleModels(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to encode models response",
|
||||
logger.LogAttrsWithTrace(r.Context(),
|
||||
slog.String("request_id", logger.FromContext(r.Context())),
|
||||
slog.String("error", err.Error()),
|
||||
)...,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -80,6 +87,11 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
var req api.ResponseRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Check if error is due to request size limit
|
||||
if err.Error() == "http: request body too large" {
|
||||
http.Error(w, "request body too large", http.StatusRequestEntityTooLarge)
|
||||
return
|
||||
}
|
||||
http.Error(w, "invalid JSON payload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@@ -202,7 +214,15 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
s.logger.ErrorContext(r.Context(), "failed to encode response",
|
||||
logger.LogAttrsWithTrace(r.Context(),
|
||||
slog.String("request_id", logger.FromContext(r.Context())),
|
||||
slog.String("response_id", responseID),
|
||||
slog.String("error", err.Error()),
|
||||
)...,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayServer) handleStreamingResponse(w http.ResponseWriter, r *http.Request, provider providers.Provider, providerMsgs []api.Message, resolvedReq *api.ResponseRequest, origReq *api.ResponseRequest, storeMsgs []api.Message) {
|
||||
|
||||
Reference in New Issue
Block a user