Add CI and production grade improvements #3
@@ -110,7 +110,7 @@ func main() {
|
|||||||
Issuer: cfg.Auth.Issuer,
|
Issuer: cfg.Auth.Issuer,
|
||||||
Audience: cfg.Auth.Audience,
|
Audience: cfg.Auth.Audience,
|
||||||
}
|
}
|
||||||
authMiddleware, err := auth.New(authConfig)
|
authMiddleware, err := auth.New(authConfig, logger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("failed to initialize auth", slog.String("error", err.Error()))
|
logger.Error("failed to initialize auth", slog.String("error", err.Error()))
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -28,12 +29,13 @@ type Middleware struct {
|
|||||||
keys map[string]*rsa.PublicKey
|
keys map[string]*rsa.PublicKey
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
client *http.Client
|
client *http.Client
|
||||||
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an authentication middleware.
|
// New creates an authentication middleware.
|
||||||
func New(cfg Config) (*Middleware, error) {
|
func New(cfg Config, logger *slog.Logger) (*Middleware, error) {
|
||||||
if !cfg.Enabled {
|
if !cfg.Enabled {
|
||||||
return &Middleware{cfg: cfg}, nil
|
return &Middleware{cfg: cfg, logger: logger}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Issuer == "" {
|
if cfg.Issuer == "" {
|
||||||
@@ -44,6 +46,7 @@ func New(cfg Config) (*Middleware, error) {
|
|||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
keys: make(map[string]*rsa.PublicKey),
|
keys: make(map[string]*rsa.PublicKey),
|
||||||
client: &http.Client{Timeout: 10 * time.Second},
|
client: &http.Client{Timeout: 10 * time.Second},
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch JWKS on startup
|
// Fetch JWKS on startup
|
||||||
@@ -255,6 +258,15 @@ func (m *Middleware) periodicRefresh() {
|
|||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
_ = m.refreshJWKS()
|
if err := m.refreshJWKS(); err != nil {
|
||||||
|
m.logger.Error("failed to refresh JWKS",
|
||||||
|
slog.String("issuer", m.cfg.Issuer),
|
||||||
|
slog.String("error", err.Error()),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
m.logger.Debug("successfully refreshed JWKS",
|
||||||
|
slog.String("issuer", m.cfg.Issuer),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -213,7 +214,7 @@ func TestNew(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := New(tt.config)
|
m, err := New(tt.config, slog.Default())
|
||||||
|
|
||||||
if tt.expectError {
|
if tt.expectError {
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
@@ -239,7 +240,7 @@ func TestMiddleware_Handler(t *testing.T) {
|
|||||||
Issuer: server.server.URL,
|
Issuer: server.server.URL,
|
||||||
Audience: testAudience,
|
Audience: testAudience,
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Create a test handler that echoes back claims
|
// Create a test handler that echoes back claims
|
||||||
@@ -415,7 +416,7 @@ func TestMiddleware_Handler_DisabledAuth(t *testing.T) {
|
|||||||
cfg := Config{
|
cfg := Config{
|
||||||
Enabled: false,
|
Enabled: false,
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -442,7 +443,7 @@ func TestValidateToken(t *testing.T) {
|
|||||||
Issuer: server.server.URL,
|
Issuer: server.server.URL,
|
||||||
Audience: testAudience,
|
Audience: testAudience,
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -665,7 +666,7 @@ func TestValidateToken_NoAudienceConfigured(t *testing.T) {
|
|||||||
Issuer: server.server.URL,
|
Issuer: server.server.URL,
|
||||||
Audience: "", // No audience required
|
Audience: "", // No audience required
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Token without audience should be valid
|
// Token without audience should be valid
|
||||||
@@ -897,7 +898,7 @@ func TestRefreshJWKS_Concurrency(t *testing.T) {
|
|||||||
Issuer: server.server.URL,
|
Issuer: server.server.URL,
|
||||||
Audience: testAudience,
|
Audience: testAudience,
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Trigger concurrent refreshes
|
// Trigger concurrent refreshes
|
||||||
@@ -982,7 +983,7 @@ func TestMiddleware_IssuerWithTrailingSlash(t *testing.T) {
|
|||||||
Issuer: server.server.URL + "/", // Trailing slash
|
Issuer: server.server.URL + "/", // Trailing slash
|
||||||
Audience: testAudience,
|
Audience: testAudience,
|
||||||
}
|
}
|
||||||
m, err := New(cfg)
|
m, err := New(cfg, slog.Default())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, m)
|
require.NotNil(t, m)
|
||||||
assert.Len(t, m.keys, 1)
|
assert.Len(t, m.keys, 1)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conversation
|
package conversation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,10 +10,10 @@ import (
|
|||||||
|
|
||||||
// Store defines the interface for conversation storage backends.
|
// Store defines the interface for conversation storage backends.
|
||||||
type Store interface {
|
type Store interface {
|
||||||
Get(id string) (*Conversation, error)
|
Get(ctx context.Context, id string) (*Conversation, error)
|
||||||
Create(id string, model string, messages []api.Message) (*Conversation, error)
|
Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error)
|
||||||
Append(id string, messages ...api.Message) (*Conversation, error)
|
Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error)
|
||||||
Delete(id string) error
|
Delete(ctx context.Context, id string) error
|
||||||
Size() int
|
Size() int
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
@@ -51,7 +52,7 @@ func NewMemoryStore(ttl time.Duration) *MemoryStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
// Get retrieves a conversation by ID. Returns a deep copy to prevent data races.
|
||||||
func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
func (s *MemoryStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ func (s *MemoryStore) Get(id string) (*Conversation, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new conversation with the given messages.
|
// Create creates a new conversation with the given messages.
|
||||||
func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
func (s *MemoryStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -105,7 +106,7 @@ func (s *MemoryStore) Create(id string, model string, messages []api.Message) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append adds new messages to an existing conversation.
|
// Append adds new messages to an existing conversation.
|
||||||
func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
func (s *MemoryStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
@@ -131,7 +132,7 @@ func (s *MemoryStore) Append(id string, messages ...api.Message) (*Conversation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a conversation from the store.
|
// Delete removes a conversation from the store.
|
||||||
func (s *MemoryStore) Delete(id string) error {
|
func (s *MemoryStore) Delete(ctx context.Context, id string) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conversation
|
package conversation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conv)
|
require.NotNil(t, conv)
|
||||||
assert.Equal(t, "test-id", conv.ID)
|
assert.Equal(t, "test-id", conv.ID)
|
||||||
@@ -29,7 +30,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
|||||||
assert.Len(t, conv.Messages, 1)
|
assert.Len(t, conv.Messages, 1)
|
||||||
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
|
assert.Equal(t, "Hello", conv.Messages[0].Content[0].Text)
|
||||||
|
|
||||||
retrieved, err := store.Get("test-id")
|
retrieved, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, retrieved)
|
require.NotNil(t, retrieved)
|
||||||
assert.Equal(t, conv.ID, retrieved.ID)
|
assert.Equal(t, conv.ID, retrieved.ID)
|
||||||
@@ -40,7 +41,7 @@ func TestMemoryStore_CreateAndGet(t *testing.T) {
|
|||||||
func TestMemoryStore_GetNonExistent(t *testing.T) {
|
func TestMemoryStore_GetNonExistent(t *testing.T) {
|
||||||
store := NewMemoryStore(1 * time.Hour)
|
store := NewMemoryStore(1 * time.Hour)
|
||||||
|
|
||||||
conv, err := store.Get("nonexistent")
|
conv, err := store.Get(context.Background(),"nonexistent")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, conv, "should return nil for nonexistent conversation")
|
assert.Nil(t, conv, "should return nil for nonexistent conversation")
|
||||||
}
|
}
|
||||||
@@ -57,7 +58,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("test-id", "gpt-4", initialMessages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", initialMessages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
newMessages := []api.Message{
|
newMessages := []api.Message{
|
||||||
@@ -75,7 +76,7 @@ func TestMemoryStore_Append(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := store.Append("test-id", newMessages...)
|
conv, err := store.Append(context.Background(),"test-id", newMessages...)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conv)
|
require.NotNil(t, conv)
|
||||||
assert.Len(t, conv.Messages, 3, "should have all messages")
|
assert.Len(t, conv.Messages, 3, "should have all messages")
|
||||||
@@ -94,7 +95,7 @@ func TestMemoryStore_AppendNonExistent(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := store.Append("nonexistent", newMessage)
|
conv, err := store.Append(context.Background(),"nonexistent", newMessage)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
|
assert.Nil(t, conv, "should return nil when appending to nonexistent conversation")
|
||||||
}
|
}
|
||||||
@@ -111,20 +112,20 @@ func TestMemoryStore_Delete(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("test-id", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify it exists
|
// Verify it exists
|
||||||
conv, err := store.Get("test-id")
|
conv, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, conv)
|
assert.NotNil(t, conv)
|
||||||
|
|
||||||
// Delete it
|
// Delete it
|
||||||
err = store.Delete("test-id")
|
err = store.Delete(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify it's gone
|
// Verify it's gone
|
||||||
conv, err = store.Get("test-id")
|
conv, err = store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, conv, "conversation should be deleted")
|
assert.Nil(t, conv, "conversation should be deleted")
|
||||||
}
|
}
|
||||||
@@ -138,15 +139,15 @@ func TestMemoryStore_Size(t *testing.T) {
|
|||||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("conv-1", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"conv-1", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, store.Size())
|
assert.Equal(t, 1, store.Size())
|
||||||
|
|
||||||
_, err = store.Create("conv-2", "gpt-4", messages)
|
_, err = store.Create(context.Background(),"conv-2", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 2, store.Size())
|
assert.Equal(t, 2, store.Size())
|
||||||
|
|
||||||
err = store.Delete("conv-1")
|
err = store.Delete(context.Background(),"conv-1")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, store.Size())
|
assert.Equal(t, 1, store.Size())
|
||||||
}
|
}
|
||||||
@@ -159,14 +160,14 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create initial conversation
|
// Create initial conversation
|
||||||
_, err := store.Create("test-id", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Simulate concurrent reads and writes
|
// Simulate concurrent reads and writes
|
||||||
done := make(chan bool, 10)
|
done := make(chan bool, 10)
|
||||||
for i := 0; i < 5; i++ {
|
for i := 0; i < 5; i++ {
|
||||||
go func() {
|
go func() {
|
||||||
_, _ = store.Get("test-id")
|
_, _ = store.Get(context.Background(),"test-id")
|
||||||
done <- true
|
done <- true
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -176,7 +177,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
|||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||||
}
|
}
|
||||||
_, _ = store.Append("test-id", newMsg)
|
_, _ = store.Append(context.Background(),"test-id", newMsg)
|
||||||
done <- true
|
done <- true
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -187,7 +188,7 @@ func TestMemoryStore_ConcurrentAccess(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify final state
|
// Verify final state
|
||||||
conv, err := store.Get("test-id")
|
conv, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, conv)
|
assert.NotNil(t, conv)
|
||||||
assert.GreaterOrEqual(t, len(conv.Messages), 1)
|
assert.GreaterOrEqual(t, len(conv.Messages), 1)
|
||||||
@@ -205,11 +206,11 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("test-id", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get conversation
|
// Get conversation
|
||||||
conv1, err := store.Get("test-id")
|
conv1, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Note: Current implementation copies the Messages slice but not the Content blocks
|
// Note: Current implementation copies the Messages slice but not the Content blocks
|
||||||
@@ -225,7 +226,7 @@ func TestMemoryStore_DeepCopy(t *testing.T) {
|
|||||||
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
|
assert.Equal(t, originalLen+1, len(conv1.Messages), "can modify returned message slice")
|
||||||
|
|
||||||
// Verify original is unchanged
|
// Verify original is unchanged
|
||||||
conv2, err := store.Get("test-id")
|
conv2, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
|
assert.Equal(t, originalLen, len(conv2.Messages), "original conversation unaffected by slice modification")
|
||||||
}
|
}
|
||||||
@@ -238,11 +239,11 @@ func TestMemoryStore_TTLCleanup(t *testing.T) {
|
|||||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("test-id", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify it exists
|
// Verify it exists
|
||||||
conv, err := store.Get("test-id")
|
conv, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, conv)
|
assert.NotNil(t, conv)
|
||||||
assert.Equal(t, 1, store.Size())
|
assert.Equal(t, 1, store.Size())
|
||||||
@@ -265,12 +266,12 @@ func TestMemoryStore_NoTTL(t *testing.T) {
|
|||||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := store.Create("test-id", "gpt-4", messages)
|
_, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, 1, store.Size())
|
assert.Equal(t, 1, store.Size())
|
||||||
|
|
||||||
// Without TTL, conversation should persist indefinitely
|
// Without TTL, conversation should persist indefinitely
|
||||||
conv, err := store.Get("test-id")
|
conv, err := store.Get(context.Background(),"test-id")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, conv)
|
assert.NotNil(t, conv)
|
||||||
}
|
}
|
||||||
@@ -282,7 +283,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
|||||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello"}}},
|
||||||
}
|
}
|
||||||
|
|
||||||
conv, err := store.Create("test-id", "gpt-4", messages)
|
conv, err := store.Create(context.Background(),"test-id", "gpt-4", messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
createdAt := conv.CreatedAt
|
createdAt := conv.CreatedAt
|
||||||
updatedAt := conv.UpdatedAt
|
updatedAt := conv.UpdatedAt
|
||||||
@@ -296,7 +297,7 @@ func TestMemoryStore_UpdatedAtTracking(t *testing.T) {
|
|||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
Content: []api.ContentBlock{{Type: "output_text", Text: "Response"}},
|
||||||
}
|
}
|
||||||
conv, err = store.Append("test-id", newMsg)
|
conv, err = store.Append(context.Background(),"test-id", newMsg)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
|
assert.Equal(t, createdAt, conv.CreatedAt, "created time should not change")
|
||||||
@@ -313,7 +314,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
|||||||
messages := []api.Message{
|
messages := []api.Message{
|
||||||
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
|
{Role: "user", Content: []api.ContentBlock{{Type: "input_text", Text: "Hello " + id}}},
|
||||||
}
|
}
|
||||||
_, err := store.Create(id, model, messages)
|
_, err := store.Create(context.Background(),id, model, messages)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -322,7 +323,7 @@ func TestMemoryStore_MultipleConversations(t *testing.T) {
|
|||||||
// Verify each conversation is independent
|
// Verify each conversation is independent
|
||||||
for i := 0; i < 10; i++ {
|
for i := 0; i < 10; i++ {
|
||||||
id := "conv-" + string(rune('0'+i))
|
id := "conv-" + string(rune('0'+i))
|
||||||
conv, err := store.Get(id)
|
conv, err := store.Get(context.Background(),id)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, conv)
|
require.NotNil(t, conv)
|
||||||
assert.Equal(t, id, conv.ID)
|
assert.Equal(t, id, conv.ID)
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
type RedisStore struct {
|
type RedisStore struct {
|
||||||
client *redis.Client
|
client *redis.Client
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
ctx context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRedisStore creates a Redis-backed conversation store.
|
// NewRedisStore creates a Redis-backed conversation store.
|
||||||
@@ -21,7 +20,6 @@ func NewRedisStore(client *redis.Client, ttl time.Duration) *RedisStore {
|
|||||||
return &RedisStore{
|
return &RedisStore{
|
||||||
client: client,
|
client: client,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
ctx: context.Background(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,8 +29,8 @@ func (s *RedisStore) key(id string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get retrieves a conversation by ID from Redis.
|
// Get retrieves a conversation by ID from Redis.
|
||||||
func (s *RedisStore) Get(id string) (*Conversation, error) {
|
func (s *RedisStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||||
data, err := s.client.Get(s.ctx, s.key(id)).Bytes()
|
data, err := s.client.Get(ctx, s.key(id)).Bytes()
|
||||||
if err == redis.Nil {
|
if err == redis.Nil {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@@ -49,7 +47,7 @@ func (s *RedisStore) Get(id string) (*Conversation, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create creates a new conversation with the given messages.
|
// Create creates a new conversation with the given messages.
|
||||||
func (s *RedisStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
func (s *RedisStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
conv := &Conversation{
|
conv := &Conversation{
|
||||||
ID: id,
|
ID: id,
|
||||||
@@ -64,7 +62,7 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,8 +70,8 @@ func (s *RedisStore) Create(id string, model string, messages []api.Message) (*C
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append adds new messages to an existing conversation.
|
// Append adds new messages to an existing conversation.
|
||||||
func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
func (s *RedisStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||||
conv, err := s.Get(id)
|
conv, err := s.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -89,7 +87,7 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.client.Set(s.ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
if err := s.client.Set(ctx, s.key(id), data, s.ttl).Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -97,17 +95,18 @@ func (s *RedisStore) Append(id string, messages ...api.Message) (*Conversation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete removes a conversation from Redis.
|
// Delete removes a conversation from Redis.
|
||||||
func (s *RedisStore) Delete(id string) error {
|
func (s *RedisStore) Delete(ctx context.Context, id string) error {
|
||||||
return s.client.Del(s.ctx, s.key(id)).Err()
|
return s.client.Del(ctx, s.key(id)).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size returns the number of active conversations in Redis.
|
// Size returns the number of active conversations in Redis.
|
||||||
func (s *RedisStore) Size() int {
|
func (s *RedisStore) Size() int {
|
||||||
var count int
|
var count int
|
||||||
var cursor uint64
|
var cursor uint64
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
keys, nextCursor, err := s.client.Scan(s.ctx, cursor, "conv:*", 100).Result()
|
keys, nextCursor, err := s.client.Scan(ctx, cursor, "conv:*", 100).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package conversation
|
package conversation
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
@@ -71,8 +72,8 @@ func NewSQLStore(db *sql.DB, driver string, ttl time.Duration) (*SQLStore, error
|
|||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Get(id string) (*Conversation, error) {
|
func (s *SQLStore) Get(ctx context.Context, id string) (*Conversation, error) {
|
||||||
row := s.db.QueryRow(s.dialect.getByID, id)
|
row := s.db.QueryRowContext(ctx, s.dialect.getByID, id)
|
||||||
|
|
||||||
var conv Conversation
|
var conv Conversation
|
||||||
var msgJSON string
|
var msgJSON string
|
||||||
@@ -91,14 +92,14 @@ func (s *SQLStore) Get(id string) (*Conversation, error) {
|
|||||||
return &conv, nil
|
return &conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Conversation, error) {
|
func (s *SQLStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*Conversation, error) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
msgJSON, err := json.Marshal(messages)
|
msgJSON, err := json.Marshal(messages)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.db.Exec(s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
if _, err := s.db.ExecContext(ctx, s.dialect.upsert, id, model, string(msgJSON), now, now); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,8 +112,8 @@ func (s *SQLStore) Create(id string, model string, messages []api.Message) (*Con
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, error) {
|
func (s *SQLStore) Append(ctx context.Context, id string, messages ...api.Message) (*Conversation, error) {
|
||||||
conv, err := s.Get(id)
|
conv, err := s.Get(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -128,15 +129,15 @@ func (s *SQLStore) Append(id string, messages ...api.Message) (*Conversation, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := s.db.Exec(s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
if _, err := s.db.ExecContext(ctx, s.dialect.update, string(msgJSON), conv.UpdatedAt, id); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SQLStore) Delete(id string) error {
|
func (s *SQLStore) Delete(ctx context.Context, id string) error {
|
||||||
_, err := s.db.Exec(s.dialect.deleteByID, id)
|
_, err := s.db.ExecContext(ctx, s.dialect.deleteByID, id)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -42,9 +42,7 @@ func NewInstrumentedStore(s conversation.Store, backend string, registry *promet
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get wraps the store's Get method with metrics and tracing.
|
// Get wraps the store's Get method with metrics and tracing.
|
||||||
func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
func (s *InstrumentedStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Start span if tracing is enabled
|
// Start span if tracing is enabled
|
||||||
if s.tracer != nil {
|
if s.tracer != nil {
|
||||||
var span trace.Span
|
var span trace.Span
|
||||||
@@ -61,7 +59,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// Call underlying store
|
// Call underlying store
|
||||||
conv, err := s.base.Get(id)
|
conv, err := s.base.Get(ctx, id)
|
||||||
|
|
||||||
// Record metrics
|
// Record metrics
|
||||||
duration := time.Since(start).Seconds()
|
duration := time.Since(start).Seconds()
|
||||||
@@ -95,9 +93,7 @@ func (s *InstrumentedStore) Get(id string) (*conversation.Conversation, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create wraps the store's Create method with metrics and tracing.
|
// Create wraps the store's Create method with metrics and tracing.
|
||||||
func (s *InstrumentedStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
func (s *InstrumentedStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Start span if tracing is enabled
|
// Start span if tracing is enabled
|
||||||
if s.tracer != nil {
|
if s.tracer != nil {
|
||||||
var span trace.Span
|
var span trace.Span
|
||||||
@@ -116,7 +112,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// Call underlying store
|
// Call underlying store
|
||||||
conv, err := s.base.Create(id, model, messages)
|
conv, err := s.base.Create(ctx, id, model, messages)
|
||||||
|
|
||||||
// Record metrics
|
// Record metrics
|
||||||
duration := time.Since(start).Seconds()
|
duration := time.Since(start).Seconds()
|
||||||
@@ -146,9 +142,7 @@ func (s *InstrumentedStore) Create(id string, model string, messages []api.Messa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Append wraps the store's Append method with metrics and tracing.
|
// Append wraps the store's Append method with metrics and tracing.
|
||||||
func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
func (s *InstrumentedStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Start span if tracing is enabled
|
// Start span if tracing is enabled
|
||||||
if s.tracer != nil {
|
if s.tracer != nil {
|
||||||
var span trace.Span
|
var span trace.Span
|
||||||
@@ -166,7 +160,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// Call underlying store
|
// Call underlying store
|
||||||
conv, err := s.base.Append(id, messages...)
|
conv, err := s.base.Append(ctx, id, messages...)
|
||||||
|
|
||||||
// Record metrics
|
// Record metrics
|
||||||
duration := time.Since(start).Seconds()
|
duration := time.Since(start).Seconds()
|
||||||
@@ -199,9 +193,7 @@ func (s *InstrumentedStore) Append(id string, messages ...api.Message) (*convers
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Delete wraps the store's Delete method with metrics and tracing.
|
// Delete wraps the store's Delete method with metrics and tracing.
|
||||||
func (s *InstrumentedStore) Delete(id string) error {
|
func (s *InstrumentedStore) Delete(ctx context.Context, id string) error {
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Start span if tracing is enabled
|
// Start span if tracing is enabled
|
||||||
if s.tracer != nil {
|
if s.tracer != nil {
|
||||||
var span trace.Span
|
var span trace.Span
|
||||||
@@ -218,7 +210,7 @@ func (s *InstrumentedStore) Delete(id string) error {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
|
|
||||||
// Call underlying store
|
// Call underlying store
|
||||||
err := s.base.Delete(id)
|
err := s.base.Delete(ctx, id)
|
||||||
|
|
||||||
// Record metrics
|
// Record metrics
|
||||||
duration := time.Since(start).Seconds()
|
duration := time.Since(start).Seconds()
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ func (s *GatewayServer) handleReady(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
// Test conversation store by attempting a simple operation
|
// Test conversation store by attempting a simple operation
|
||||||
testID := "health_check_test"
|
testID := "health_check_test"
|
||||||
_, err := s.convs.Get(testID)
|
_, err := s.convs.Get(ctx, testID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
checks["conversation_store"] = "unhealthy: " + err.Error()
|
checks["conversation_store"] = "unhealthy: " + err.Error()
|
||||||
allHealthy = false
|
allHealthy = false
|
||||||
|
|||||||
@@ -156,7 +156,7 @@ func newMockConversationStore() *mockConversationStore {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConversationStore) Get(id string) (*conversation.Conversation, error) {
|
func (m *mockConversationStore) Get(ctx context.Context, id string) (*conversation.Conversation, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -170,7 +170,7 @@ func (m *mockConversationStore) Get(id string) (*conversation.Conversation, erro
|
|||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConversationStore) Create(id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
func (m *mockConversationStore) Create(ctx context.Context, id string, model string, messages []api.Message) (*conversation.Conversation, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ func (m *mockConversationStore) Create(id string, model string, messages []api.M
|
|||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConversationStore) Append(id string, messages ...api.Message) (*conversation.Conversation, error) {
|
func (m *mockConversationStore) Append(ctx context.Context, id string, messages ...api.Message) (*conversation.Conversation, error) {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
@@ -203,7 +203,7 @@ func (m *mockConversationStore) Append(id string, messages ...api.Message) (*con
|
|||||||
return conv, nil
|
return conv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockConversationStore) Delete(id string) error {
|
func (m *mockConversationStore) Delete(ctx context.Context, id string) error {
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ func (s *GatewayServer) handleResponses(w http.ResponseWriter, r *http.Request)
|
|||||||
// Build full message history from previous conversation
|
// Build full message history from previous conversation
|
||||||
var historyMsgs []api.Message
|
var historyMsgs []api.Message
|
||||||
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
if req.PreviousResponseID != nil && *req.PreviousResponseID != "" {
|
||||||
conv, err := s.convs.Get(*req.PreviousResponseID)
|
conv, err := s.convs.Get(r.Context(), *req.PreviousResponseID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.ErrorContext(r.Context(), "failed to retrieve conversation",
|
s.logger.ErrorContext(r.Context(), "failed to retrieve conversation",
|
||||||
logger.LogAttrsWithTrace(r.Context(),
|
logger.LogAttrsWithTrace(r.Context(),
|
||||||
@@ -186,7 +186,7 @@ func (s *GatewayServer) handleSyncResponse(w http.ResponseWriter, r *http.Reques
|
|||||||
ToolCalls: result.ToolCalls,
|
ToolCalls: result.ToolCalls,
|
||||||
}
|
}
|
||||||
allMsgs := append(storeMsgs, assistantMsg)
|
allMsgs := append(storeMsgs, assistantMsg)
|
||||||
if _, err := s.convs.Create(responseID, result.Model, allMsgs); err != nil {
|
if _, err := s.convs.Create(r.Context(), responseID, result.Model, allMsgs); err != nil {
|
||||||
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
||||||
logger.LogAttrsWithTrace(r.Context(),
|
logger.LogAttrsWithTrace(r.Context(),
|
||||||
slog.String("request_id", logger.FromContext(r.Context())),
|
slog.String("request_id", logger.FromContext(r.Context())),
|
||||||
@@ -543,7 +543,7 @@ loop:
|
|||||||
ToolCalls: toolCalls,
|
ToolCalls: toolCalls,
|
||||||
}
|
}
|
||||||
allMsgs := append(storeMsgs, assistantMsg)
|
allMsgs := append(storeMsgs, assistantMsg)
|
||||||
if _, err := s.convs.Create(responseID, model, allMsgs); err != nil {
|
if _, err := s.convs.Create(r.Context(), responseID, model, allMsgs); err != nil {
|
||||||
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
s.logger.ErrorContext(r.Context(), "failed to store conversation",
|
||||||
slog.String("request_id", logger.FromContext(r.Context())),
|
slog.String("request_id", logger.FromContext(r.Context())),
|
||||||
slog.String("response_id", responseID),
|
slog.String("response_id", responseID),
|
||||||
|
|||||||
Reference in New Issue
Block a user