Lean MCP implementation
@@ -1 +0,0 @@
"""RAG evaluation agent package."""
@@ -1,92 +0,0 @@
"""Pydantic AI agent with RAG tool for vector search."""

import time

import structlog
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel

from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch

logger = structlog.get_logger(__name__)


class Deps(BaseModel):
    """Dependencies injected into the agent at runtime."""

    vector_search: GoogleCloudVectorSearch
    embedder: Embedder

    model_config = {"arbitrary_types_allowed": True}


model = GoogleModel(
    settings.agent_language_model,
    provider=settings.provider,
)
agent = Agent(
    model,
    deps_type=Deps,
    system_prompt=settings.agent_instructions,
)


@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
    """Search the vector index for the given query.

    Args:
        ctx: The run context containing dependencies.
        query: The query to search for.

    Returns:
        A formatted string containing the search results.

    """
    t0 = time.perf_counter()
    min_sim = 0.6

    query_embedding = await ctx.deps.embedder.embed_query(query)
    t_embed = time.perf_counter()

    search_results = await ctx.deps.vector_search.async_run_query(
        deployed_index_id=settings.index_deployed_id,
        query=list(query_embedding.embeddings[0]),
        limit=5,
    )
    t_search = time.perf_counter()

    if search_results:
        max_sim = max(r["distance"] for r in search_results)
        cutoff = max_sim * 0.9
        search_results = [
            s
            for s in search_results
            if s["distance"] > cutoff and s["distance"] > min_sim
        ]

    logger.info(
        "conocimiento.timing",
        embedding_ms=round((t_embed - t0) * 1000, 1),
        vector_search_ms=round((t_search - t_embed) * 1000, 1),
        total_ms=round((t_search - t0) * 1000, 1),
        chunks=[s["id"] for s in search_results],
    )

    formatted_results = [
        f"<document {i} name={result['id']}>\n"
        f"{result['content']}\n"
        f"</document {i}>"
        for i, result in enumerate(search_results, start=1)
    ]
    return "\n".join(formatted_results)


if __name__ == "__main__":
    deps = Deps(
        vector_search=settings.vector_search,
        embedder=settings.embedder,
    )
    agent.to_cli_sync(deps=deps)
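For intuition on the relevance cutoff in the deleted tool above, here is a small self-contained sketch of the same filtering rule with illustrative numbers only (not taken from the repo):

# Illustrative example of the cutoff logic used in `conocimiento`.
results = [
    {"id": "a", "distance": 0.82},
    {"id": "b", "distance": 0.75},
    {"id": "c", "distance": 0.55},
]
min_sim = 0.6
max_sim = max(r["distance"] for r in results)  # 0.82
cutoff = max_sim * 0.9                         # 0.738
kept = [r for r in results if r["distance"] > cutoff and r["distance"] > min_sim]
print([r["id"] for r in kept])  # ['a', 'b']: 'c' falls below both thresholds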
@@ -1,92 +0,0 @@
"""Application settings loaded from YAML and environment variables."""

import os
from functools import cached_property

from pydantic_ai import Embedder
from pydantic_ai.providers.google import GoogleProvider
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    YamlConfigSettingsSource,
)

from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch

CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")


class Settings(BaseSettings):
    """Application settings loaded from config.yaml and env vars."""

    project_id: str
    location: str
    bucket: str

    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str
    agent_thinking: int

    index_name: str
    index_deployed_id: str
    index_endpoint: str
    index_dimensions: int
    index_machine_type: str = "e2-standard-16"
    index_origin: str
    index_destination: str
    index_chunk_limit: int

    model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources."""
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )

    @cached_property
    def provider(self) -> GoogleProvider:
        """Return a Google provider configured for Vertex AI."""
        return GoogleProvider(
            project=self.project_id,
            location=self.location,
        )

    @cached_property
    def vector_search(self) -> GoogleCloudVectorSearch:
        """Return a configured vector search client."""
        vs = GoogleCloudVectorSearch(
            project_id=self.project_id,
            location=self.location,
            bucket=self.bucket,
            index_name=self.index_name,
        )
        vs.load_index_endpoint(self.index_endpoint)
        return vs

    @cached_property
    def embedder(self) -> Embedder:
        """Return an embedder configured for the agent's embedding model."""
        from pydantic_ai.embeddings.google import GoogleEmbeddingModel  # noqa: PLC0415

        model = GoogleEmbeddingModel(
            self.agent_embedding_model,
            provider=self.provider,
        )
        return Embedder(model)


settings = Settings.model_validate({})
@@ -1 +0,0 @@
"""File storage provider implementations."""
@@ -1,56 +0,0 @@
"""Abstract base class for file storage providers."""

from abc import ABC, abstractmethod
from typing import BinaryIO


class BaseFileStorage(ABC):
    """Abstract base class for a remote file processor.

    Defines the interface for listing and processing files from
    a remote source.
    """

    @abstractmethod
    def upload_file(
        self,
        file_path: str,
        destination_blob_name: str,
        content_type: str | None = None,
    ) -> None:
        """Upload a file to the remote source.

        Args:
            file_path: The local path to the file to upload.
            destination_blob_name: Name of the file in remote storage.
            content_type: The content type of the file.

        """
        ...

    @abstractmethod
    def list_files(self, path: str | None = None) -> list[str]:
        """List files from a remote location.

        Args:
            path: Path to a specific file or directory. If None,
                recursively lists all files in the bucket.

        Returns:
            A list of file paths.

        """
        ...

    @abstractmethod
    def get_file_stream(self, file_name: str) -> BinaryIO:
        """Get a file from the remote source as a file-like object.

        Args:
            file_name: The name of the file to retrieve.

        Returns:
            A file-like object containing the file data.

        """
        ...
@@ -1,188 +0,0 @@
|
||||
"""Google Cloud Storage file storage implementation."""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
from typing import BinaryIO
|
||||
|
||||
import aiohttp
|
||||
from gcloud.aio.storage import Storage
|
||||
from google.cloud import storage
|
||||
|
||||
from rag_eval.file_storage.base import BaseFileStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HTTP_TOO_MANY_REQUESTS = 429
|
||||
HTTP_SERVER_ERROR = 500
|
||||
|
||||
|
||||
class GoogleCloudFileStorage(BaseFileStorage):
|
||||
"""File storage backed by Google Cloud Storage."""
|
||||
|
||||
def __init__(self, bucket: str) -> None: # noqa: D107
|
||||
self.bucket_name = bucket
|
||||
|
||||
self.storage_client = storage.Client()
|
||||
self.bucket_client = self.storage_client.bucket(self.bucket_name)
|
||||
self._aio_session: aiohttp.ClientSession | None = None
|
||||
self._aio_storage: Storage | None = None
|
||||
self._cache: dict[str, bytes] = {}
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
destination_blob_name: str,
|
||||
content_type: str | None = None,
|
||||
) -> None:
|
||||
"""Upload a file to Cloud Storage.
|
||||
|
||||
Args:
|
||||
file_path: The local path to the file to upload.
|
||||
destination_blob_name: Name of the blob in the bucket.
|
||||
content_type: The content type of the file.
|
||||
|
||||
"""
|
||||
blob = self.bucket_client.blob(destination_blob_name)
|
||||
blob.upload_from_filename(
|
||||
file_path,
|
||||
content_type=content_type,
|
||||
if_generation_match=0,
|
||||
)
|
||||
self._cache.pop(destination_blob_name, None)
|
||||
|
||||
def list_files(self, path: str | None = None) -> list[str]:
|
||||
"""List all files at the given path in the bucket.
|
||||
|
||||
If path is None, recursively lists all files.
|
||||
|
||||
Args:
|
||||
path: Prefix to filter files by.
|
||||
|
||||
Returns:
|
||||
A list of blob names.
|
||||
|
||||
"""
|
||||
blobs = self.storage_client.list_blobs(
|
||||
self.bucket_name, prefix=path,
|
||||
)
|
||||
return [blob.name for blob in blobs]
|
||||
|
||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
||||
"""Get a file as a file-like object, using cache.
|
||||
|
||||
Args:
|
||||
file_name: The blob name to retrieve.
|
||||
|
||||
Returns:
|
||||
A BytesIO stream with the file contents.
|
||||
|
||||
"""
|
||||
if file_name not in self._cache:
|
||||
blob = self.bucket_client.blob(file_name)
|
||||
self._cache[file_name] = blob.download_as_bytes()
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
return file_stream
|
||||
|
||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
||||
if self._aio_session is None or self._aio_session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=300, limit_per_host=50,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
self._aio_session = aiohttp.ClientSession(
|
||||
timeout=timeout, connector=connector,
|
||||
)
|
||||
return self._aio_session
|
||||
|
||||
def _get_aio_storage(self) -> Storage:
|
||||
if self._aio_storage is None:
|
||||
self._aio_storage = Storage(
|
||||
session=self._get_aio_session(),
|
||||
)
|
||||
return self._aio_storage
|
||||
|
||||
async def async_get_file_stream(
|
||||
self, file_name: str, max_retries: int = 3,
|
||||
) -> BinaryIO:
|
||||
"""Get a file asynchronously with retry on transient errors.
|
||||
|
||||
Args:
|
||||
file_name: The blob name to retrieve.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
|
||||
Returns:
|
||||
A BytesIO stream with the file contents.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If all retry attempts fail.
|
||||
|
||||
"""
|
||||
if file_name in self._cache:
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
return file_stream
|
||||
|
||||
storage_client = self._get_aio_storage()
|
||||
last_exception: Exception | None = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
self._cache[file_name] = await storage_client.download(
|
||||
self.bucket_name, file_name,
|
||||
)
|
||||
file_stream = io.BytesIO(self._cache[file_name])
|
||||
file_stream.name = file_name
|
||||
except TimeoutError as exc:
|
||||
last_exception = exc
|
||||
logger.warning(
|
||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
except aiohttp.ClientResponseError as exc:
|
||||
last_exception = exc
|
||||
if (
|
||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
||||
or exc.status >= HTTP_SERVER_ERROR
|
||||
):
|
||||
logger.warning(
|
||||
"HTTP %d downloading gs://%s/%s "
|
||||
"(attempt %d/%d)",
|
||||
exc.status,
|
||||
self.bucket_name,
|
||||
file_name,
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
return file_stream
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
delay = 0.5 * (2**attempt)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
msg = (
|
||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
||||
f"after {max_retries} attempts"
|
||||
)
|
||||
raise TimeoutError(msg) from last_exception
|
||||
|
||||
def delete_files(self, path: str) -> None:
|
||||
"""Delete all files at the given path in the bucket.
|
||||
|
||||
Args:
|
||||
path: Prefix of blobs to delete.
|
||||
|
||||
"""
|
||||
blobs = self.storage_client.list_blobs(
|
||||
self.bucket_name, prefix=path,
|
||||
)
|
||||
for blob in blobs:
|
||||
blob.delete()
|
||||
self._cache.pop(blob.name, None)
|
||||
@@ -1,47 +0,0 @@
"""Structured logging configuration using structlog."""

import logging
import sys

import structlog


def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
    """Configure structlog with JSON or console output."""
    shared_processors: list[structlog.types.Processor] = [
        structlog.contextvars.merge_contextvars,
        structlog.stdlib.add_log_level,
        structlog.stdlib.add_logger_name,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
    ]

    if json:
        formatter = structlog.stdlib.ProcessorFormatter(
            processors=[
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                structlog.processors.JSONRenderer(),
            ],
        )
    else:
        formatter = structlog.stdlib.ProcessorFormatter(
            processors=[
                structlog.stdlib.ProcessorFormatter.remove_processors_meta,
                structlog.dev.ConsoleRenderer(),
            ],
        )

    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)

    root = logging.getLogger()
    root.handlers.clear()
    root.addHandler(handler)
    root.setLevel(level)

    structlog.configure(
        processors=shared_processors,
        logger_factory=structlog.stdlib.LoggerFactory(),
        wrapper_class=structlog.stdlib.BoundLogger,
        cache_logger_on_first_use=True,
    )
@@ -1,61 +0,0 @@
"""FastAPI server exposing the RAG agent endpoint."""

import time
from typing import Literal
from uuid import uuid4

import structlog
from fastapi import FastAPI
from pydantic import BaseModel

from rag_eval.agent import Deps, agent
from rag_eval.config import settings
from rag_eval.logging import setup_logging

logger = structlog.get_logger(__name__)

setup_logging()

app = FastAPI(title="RAG Agent")


class Message(BaseModel):
    """A single chat message."""

    role: Literal["system", "user", "assistant"]
    content: str


class AgentRequest(BaseModel):
    """Request body for the agent endpoint."""

    messages: list[Message]


class AgentResponse(BaseModel):
    """Response body from the agent endpoint."""

    response: str


@app.post("/agent")
async def run_agent(request: AgentRequest) -> AgentResponse:
    """Run the RAG agent with the provided messages."""
    request_id = uuid4().hex[:8]
    structlog.contextvars.clear_contextvars()
    structlog.contextvars.bind_contextvars(request_id=request_id)

    prompt = request.messages[-1].content
    logger.info("request.start", prompt_length=len(prompt))
    t0 = time.perf_counter()

    deps = Deps(
        vector_search=settings.vector_search,
        embedder=settings.embedder,
    )
    result = await agent.run(prompt, deps=deps)

    elapsed = round((time.perf_counter() - t0) * 1000, 1)
    logger.info("request.end", elapsed_ms=elapsed)

    return AgentResponse(response=result.output)
@@ -1 +0,0 @@
"""Vector search provider implementations."""
@@ -1,68 +0,0 @@
"""Abstract base class for vector search providers."""

from abc import ABC, abstractmethod
from typing import Any, TypedDict


class SearchResult(TypedDict):
    """A single vector search result."""

    id: str
    distance: float
    content: str


class BaseVectorSearch(ABC):
    """Abstract base class for a vector search provider.

    This class defines the standard interface for creating a vector search
    index and running queries against it.
    """

    @abstractmethod
    def create_index(
        self, name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Create a new vector search index with the provided content.

        Args:
            name: The desired name for the new index.
            content_path: Path to the data used to populate the index.
            **kwargs: Additional provider-specific arguments.

        """
        ...

    @abstractmethod
    def update_index(
        self, index_name: str, content_path: str, **kwargs: Any  # noqa: ANN401
    ) -> None:
        """Update an existing vector search index with new content.

        Args:
            index_name: The name of the index to update.
            content_path: Path to the data used to populate the index.
            **kwargs: Additional provider-specific arguments.

        """
        ...

    @abstractmethod
    def run_query(
        self,
        deployed_index_id: str,
        query: list[float],
        limit: int,
    ) -> list[SearchResult]:
        """Run a similarity search query against the index.

        Args:
            deployed_index_id: The ID of the deployed index.
            query: The embedding vector for the search query.
            limit: Maximum number of nearest neighbors to return.

        Returns:
            A list of matched items with id, distance, and content.

        """
        ...
@@ -1,310 +0,0 @@
|
||||
"""Google Cloud Vertex AI Vector Search implementation."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import google.auth
|
||||
import google.auth.credentials
|
||||
import google.auth.transport.requests
|
||||
from gcloud.aio.auth import Token
|
||||
from google.cloud import aiplatform
|
||||
|
||||
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
|
||||
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
|
||||
|
||||
|
||||
class GoogleCloudVectorSearch(BaseVectorSearch):
|
||||
"""A vector search provider using Vertex AI Vector Search."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
project_id: str,
|
||||
location: str,
|
||||
bucket: str,
|
||||
index_name: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the GoogleCloudVectorSearch client.
|
||||
|
||||
Args:
|
||||
project_id: The Google Cloud project ID.
|
||||
location: The Google Cloud location (e.g., 'us-central1').
|
||||
bucket: The GCS bucket to use for file storage.
|
||||
index_name: The name of the index.
|
||||
|
||||
"""
|
||||
aiplatform.init(project=project_id, location=location)
|
||||
self.project_id = project_id
|
||||
self.location = location
|
||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
||||
self.index_name = index_name
|
||||
self._credentials: google.auth.credentials.Credentials | None = None
|
||||
self._aio_session: aiohttp.ClientSession | None = None
|
||||
self._async_token: Token | None = None
|
||||
|
||||
def _get_auth_headers(self) -> dict[str, str]:
|
||||
if self._credentials is None:
|
||||
self._credentials, _ = google.auth.default(
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
if not self._credentials.token or self._credentials.expired:
|
||||
self._credentials.refresh(
|
||||
google.auth.transport.requests.Request(),
|
||||
)
|
||||
return {
|
||||
"Authorization": f"Bearer {self._credentials.token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
||||
if self._async_token is None:
|
||||
self._async_token = Token(
|
||||
session=self._get_aio_session(),
|
||||
scopes=[
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
],
|
||||
)
|
||||
access_token = await self._async_token.get()
|
||||
return {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
||||
if self._aio_session is None or self._aio_session.closed:
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=300, limit_per_host=50,
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=60)
|
||||
self._aio_session = aiohttp.ClientSession(
|
||||
timeout=timeout, connector=connector,
|
||||
)
|
||||
return self._aio_session
|
||||
|
||||
def create_index(
|
||||
self,
|
||||
name: str,
|
||||
content_path: str,
|
||||
*,
|
||||
dimensions: int = 3072,
|
||||
approximate_neighbors_count: int = 150,
|
||||
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
|
||||
**kwargs: Any, # noqa: ANN401, ARG002
|
||||
) -> None:
|
||||
"""Create a new Vertex AI Vector Search index.
|
||||
|
||||
Args:
|
||||
name: The display name for the new index.
|
||||
content_path: GCS URI to the embeddings JSON file.
|
||||
dimensions: Number of dimensions in embedding vectors.
|
||||
approximate_neighbors_count: Neighbors to find per vector.
|
||||
distance_measure_type: The distance measure to use.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
"""
|
||||
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
|
||||
display_name=name,
|
||||
contents_delta_uri=content_path,
|
||||
dimensions=dimensions,
|
||||
approximate_neighbors_count=approximate_neighbors_count,
|
||||
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
|
||||
leaf_node_embedding_count=1000,
|
||||
leaf_nodes_to_search_percent=10,
|
||||
)
|
||||
self.index = index
|
||||
|
||||
def update_index(
|
||||
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
|
||||
) -> None:
|
||||
"""Update an existing Vertex AI Vector Search index.
|
||||
|
||||
Args:
|
||||
index_name: The resource name of the index to update.
|
||||
content_path: GCS URI to the new embeddings JSON file.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
"""
|
||||
index = aiplatform.MatchingEngineIndex(index_name=index_name)
|
||||
index.update_embeddings(
|
||||
contents_delta_uri=content_path,
|
||||
)
|
||||
self.index = index
|
||||
|
||||
def deploy_index(
|
||||
self,
|
||||
index_name: str,
|
||||
machine_type: str = "e2-standard-2",
|
||||
) -> None:
|
||||
"""Deploy a Vertex AI Vector Search index to an endpoint.
|
||||
|
||||
Args:
|
||||
index_name: The name of the index to deploy.
|
||||
machine_type: The machine type for the endpoint.
|
||||
|
||||
"""
|
||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
|
||||
display_name=f"{index_name}-endpoint",
|
||||
public_endpoint_enabled=True,
|
||||
)
|
||||
index_endpoint.deploy_index(
|
||||
index=self.index,
|
||||
deployed_index_id=(
|
||||
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
|
||||
),
|
||||
machine_type=machine_type,
|
||||
)
|
||||
self.index_endpoint = index_endpoint
|
||||
|
||||
def load_index_endpoint(self, endpoint_name: str) -> None:
|
||||
"""Load an existing Vertex AI Vector Search index endpoint.
|
||||
|
||||
Args:
|
||||
endpoint_name: The resource name of the index endpoint.
|
||||
|
||||
"""
|
||||
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
||||
endpoint_name,
|
||||
)
|
||||
if not self.index_endpoint.public_endpoint_domain_name:
|
||||
msg = (
|
||||
"The index endpoint does not have a public endpoint. "
|
||||
"Ensure the endpoint is configured for public access."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
def run_query(
|
||||
self,
|
||||
deployed_index_id: str,
|
||||
query: list[float],
|
||||
limit: int,
|
||||
) -> list[SearchResult]:
|
||||
"""Run a similarity search query against the deployed index.
|
||||
|
||||
Args:
|
||||
deployed_index_id: The ID of the deployed index.
|
||||
query: The embedding vector for the search query.
|
||||
limit: Maximum number of nearest neighbors to return.
|
||||
|
||||
Returns:
|
||||
A list of matched items with id, distance, and content.
|
||||
|
||||
"""
|
||||
response = self.index_endpoint.find_neighbors(
|
||||
deployed_index_id=deployed_index_id,
|
||||
queries=[query],
|
||||
num_neighbors=limit,
|
||||
)
|
||||
results = []
|
||||
for neighbor in response[0]:
|
||||
file_path = (
|
||||
f"{self.index_name}/contents/{neighbor.id}.md"
|
||||
)
|
||||
content = (
|
||||
self.storage.get_file_stream(file_path)
|
||||
.read()
|
||||
.decode("utf-8")
|
||||
)
|
||||
results.append(
|
||||
SearchResult(
|
||||
id=neighbor.id,
|
||||
distance=float(neighbor.distance or 0),
|
||||
content=content,
|
||||
),
|
||||
)
|
||||
return results
|
||||
|
||||
async def async_run_query(
|
||||
self,
|
||||
deployed_index_id: str,
|
||||
query: Sequence[float],
|
||||
limit: int,
|
||||
) -> list[SearchResult]:
|
||||
"""Run an async similarity search via the REST API.
|
||||
|
||||
Args:
|
||||
deployed_index_id: The ID of the deployed index.
|
||||
query: The embedding vector for the search query.
|
||||
limit: Maximum number of nearest neighbors to return.
|
||||
|
||||
Returns:
|
||||
A list of matched items with id, distance, and content.
|
||||
|
||||
"""
|
||||
domain = self.index_endpoint.public_endpoint_domain_name
|
||||
endpoint_id = self.index_endpoint.name.split("/")[-1]
|
||||
url = (
|
||||
f"https://{domain}/v1/projects/{self.project_id}"
|
||||
f"/locations/{self.location}"
|
||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
||||
)
|
||||
payload = {
|
||||
"deployed_index_id": deployed_index_id,
|
||||
"queries": [
|
||||
{
|
||||
"datapoint": {"feature_vector": list(query)},
|
||||
"neighbor_count": limit,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
headers = await self._async_get_auth_headers()
|
||||
session = self._get_aio_session()
|
||||
async with session.post(
|
||||
url, json=payload, headers=headers,
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
data = await response.json()
|
||||
|
||||
neighbors = (
|
||||
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
||||
)
|
||||
content_tasks = []
|
||||
for neighbor in neighbors:
|
||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
||||
file_path = (
|
||||
f"{self.index_name}/contents/{datapoint_id}.md"
|
||||
)
|
||||
content_tasks.append(
|
||||
self.storage.async_get_file_stream(file_path),
|
||||
)
|
||||
|
||||
file_streams = await asyncio.gather(*content_tasks)
|
||||
results: list[SearchResult] = []
|
||||
for neighbor, stream in zip(
|
||||
neighbors, file_streams, strict=True,
|
||||
):
|
||||
results.append(
|
||||
SearchResult(
|
||||
id=neighbor["datapoint"]["datapointId"],
|
||||
distance=neighbor["distance"],
|
||||
content=stream.read().decode("utf-8"),
|
||||
),
|
||||
)
|
||||
return results
|
||||
|
||||
def delete_index(self, index_name: str) -> None:
|
||||
"""Delete a Vertex AI Vector Search index.
|
||||
|
||||
Args:
|
||||
index_name: The resource name of the index.
|
||||
|
||||
"""
|
||||
index = aiplatform.MatchingEngineIndex(index_name)
|
||||
index.delete()
|
||||
|
||||
def delete_index_endpoint(
|
||||
self, index_endpoint_name: str,
|
||||
) -> None:
|
||||
"""Delete a Vertex AI Vector Search index endpoint.
|
||||
|
||||
Args:
|
||||
index_endpoint_name: The resource name of the endpoint.
|
||||
|
||||
"""
|
||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
||||
index_endpoint_name,
|
||||
)
|
||||
index_endpoint.undeploy_all()
|
||||
index_endpoint.delete(force=True)
|
||||
6 src/va_agent/__init__.py Normal file
@@ -0,0 +1,6 @@
"""Package export for the ADK root agent."""

import os

# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")
29 src/va_agent/agent.py Normal file
@@ -0,0 +1,29 @@
"""ADK agent with vector search RAG tool."""

from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.cloud.firestore_v1.async_client import AsyncClient

from va_agent.config import settings
from va_agent.session import FirestoreSessionService

connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)

agent = Agent(
    model=settings.agent_model,
    name=settings.agent_name,
    instruction=settings.agent_instructions,
    tools=[toolset],
)

session_service = FirestoreSessionService(
    db=AsyncClient(database=settings.firestore_db),
    compaction_token_threshold=10_000,
    genai_client=genai.Client(),
)

runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)
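A minimal sketch of how this wiring might be driven end to end, assuming the usual ADK create_session / run_async call shapes; the prompt text and IDs below are hypothetical and not part of this commit:

import asyncio

from google.genai.types import Content, Part

from va_agent.agent import runner, session_service


async def ask(question: str) -> None:
    # Hypothetical user and session identifiers, for illustration only.
    session = await session_service.create_session(
        app_name="va_agent", user_id="demo-user"
    )
    message = Content(role="user", parts=[Part(text=question)])
    # Stream events from the agent; tools are resolved via the remote MCP SSE server.
    async for event in runner.run_async(
        user_id="demo-user", session_id=session.id, new_message=message
    ):
        if event.content and event.content.parts:
            print("".join(p.text or "" for p in event.content.parts))


asyncio.run(ask("What documents mention vector search?"))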
53 src/va_agent/config.py Normal file
@@ -0,0 +1,53 @@
"""Configuration helper for ADK agent."""

import os

from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    YamlConfigSettingsSource,
)

CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")


class AgentSettings(BaseSettings):
    """Settings for ADK agent with vector search."""

    google_cloud_project: str
    google_cloud_location: str

    # Agent configuration
    agent_name: str
    agent_instructions: str
    agent_model: str

    # Firestore configuration
    firestore_db: str

    # MCP configuration
    mcp_remote_url: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Ignore extra fields from config.yaml
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Use env vars and YAML as settings sources."""
        return (
            env_settings,
            YamlConfigSettingsSource(settings_cls),
        )


settings = AgentSettings.model_validate({})
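For reference, one hedged way to supply these settings: the field names map to config.yaml keys or to environment variables of the same name, with env vars taking precedence per settings_customise_sources. All values below are hypothetical placeholders:

import os

from va_agent.config import AgentSettings

# Hypothetical values; environment variables override config.yaml entries.
os.environ.update(
    {
        "GOOGLE_CLOUD_PROJECT": "my-project",
        "GOOGLE_CLOUD_LOCATION": "us-central1",
        "AGENT_NAME": "va_agent",
        "AGENT_INSTRUCTIONS": "Answer using the knowledge tools.",
        "AGENT_MODEL": "gemini-2.5-flash",
        "FIRESTORE_DB": "(default)",
        "MCP_REMOTE_URL": "https://example.com/mcp/sse",
    }
)
settings = AgentSettings.model_validate({})
print(settings.agent_model)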
10 src/va_agent/server.py Normal file
@@ -0,0 +1,10 @@
"""FastAPI server exposing the RAG agent endpoint.

NOTE: This file is a stub. The rag_eval module was removed in the
lean MCP implementation. This file is kept for reference but is not
functional.
"""

from fastapi import FastAPI

app = FastAPI(title="RAG Agent")
582 src/va_agent/session.py Normal file
@@ -0,0 +1,582 @@
|
||||
"""Firestore-backed session service for Google ADK."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, override
|
||||
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.sessions import _session_util
|
||||
from google.adk.sessions.base_session_service import (
|
||||
BaseSessionService,
|
||||
GetSessionConfig,
|
||||
ListSessionsResponse,
|
||||
)
|
||||
from google.adk.sessions.session import Session
|
||||
from google.adk.sessions.state import State
|
||||
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||
from google.cloud.firestore_v1.field_path import FieldPath
|
||||
from google.genai.types import Content, Part
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai
|
||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||
|
||||
logger = logging.getLogger("google_adk." + __name__)
|
||||
|
||||
_COMPACTION_LOCK_TTL = 300 # seconds
|
||||
|
||||
|
||||
@async_transactional
|
||||
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
|
||||
"""Atomically claim the compaction lock if it is free or stale."""
|
||||
snapshot = await session_ref.get(transaction=transaction)
|
||||
if not snapshot.exists:
|
||||
return False
|
||||
data = snapshot.to_dict() or {}
|
||||
lock_time = data.get("compaction_lock")
|
||||
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
|
||||
return False
|
||||
transaction.update(session_ref, {"compaction_lock": time.time()})
|
||||
return True
|
||||
|
||||
|
||||
class FirestoreSessionService(BaseSessionService):
|
||||
"""A Firestore-backed implementation of BaseSessionService.
|
||||
|
||||
Firestore document layout (given ``collection_prefix="adk"``)::
|
||||
|
||||
adk_app_states/{app_name}
|
||||
→ app-scoped state key/values
|
||||
|
||||
adk_user_states/{app_name}__{user_id}
|
||||
→ user-scoped state key/values
|
||||
|
||||
adk_sessions/{app_name}__{user_id}__{session_id}
|
||||
→ {app_name, user_id, session_id, state: {…}, last_update_time}
|
||||
└─ events/{event_id} → serialised Event
|
||||
"""
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
self,
|
||||
*,
|
||||
db: AsyncClient,
|
||||
collection_prefix: str = "adk",
|
||||
compaction_token_threshold: int | None = None,
|
||||
compaction_model: str = "gemini-2.5-flash",
|
||||
compaction_keep_recent: int = 10,
|
||||
genai_client: genai.Client | None = None,
|
||||
) -> None:
|
||||
"""Initialize FirestoreSessionService.
|
||||
|
||||
Args:
|
||||
db: Firestore async client
|
||||
collection_prefix: Prefix for Firestore collections
|
||||
compaction_token_threshold: Token count threshold for compaction
|
||||
compaction_model: Model to use for summarization
|
||||
compaction_keep_recent: Number of recent events to keep
|
||||
genai_client: GenAI client for compaction summaries
|
||||
|
||||
"""
|
||||
if compaction_token_threshold is not None and genai_client is None:
|
||||
msg = "genai_client is required when compaction_token_threshold is set."
|
||||
raise ValueError(msg)
|
||||
self._db = db
|
||||
self._prefix = collection_prefix
|
||||
self._compaction_threshold = compaction_token_threshold
|
||||
self._compaction_model = compaction_model
|
||||
self._compaction_keep_recent = compaction_keep_recent
|
||||
self._genai_client = genai_client
|
||||
self._compaction_locks: dict[str, asyncio.Lock] = {}
|
||||
self._active_tasks: set[asyncio.Task] = set()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Document-reference helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _app_state_ref(self, app_name: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
|
||||
|
||||
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_user_states").document(
|
||||
f"{app_name}__{user_id}"
|
||||
)
|
||||
|
||||
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||
return self._db.collection(f"{self._prefix}_sessions").document(
|
||||
f"{app_name}__{user_id}__{session_id}"
|
||||
)
|
||||
|
||||
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||
return self._session_ref(app_name, user_id, session_id).collection("events")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
|
||||
snap = await self._app_state_ref(app_name).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
|
||||
snap = await self._user_state_ref(app_name, user_id).get()
|
||||
return snap.to_dict() or {} if snap.exists else {}
|
||||
|
||||
@staticmethod
|
||||
def _merge_state(
|
||||
app_state: dict[str, Any],
|
||||
user_state: dict[str, Any],
|
||||
session_state: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
merged = dict(session_state)
|
||||
for key, value in app_state.items():
|
||||
merged[State.APP_PREFIX + key] = value
|
||||
for key, value in user_state.items():
|
||||
merged[State.USER_PREFIX + key] = value
|
||||
return merged
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Compaction helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _events_to_text(events: list[Event]) -> str:
|
||||
lines: list[str] = []
|
||||
for event in events:
|
||||
if event.content and event.content.parts:
|
||||
text = "".join(p.text or "" for p in event.content.parts)
|
||||
if text:
|
||||
role = "User" if event.author == "user" else "Assistant"
|
||||
lines.append(f"{role}: {text}")
|
||||
return "\n\n".join(lines)
|
||||
|
||||
async def _generate_summary(
|
||||
self, existing_summary: str, events: list[Event]
|
||||
) -> str:
|
||||
conversation_text = self._events_to_text(events)
|
||||
previous = (
|
||||
f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
|
||||
if existing_summary
|
||||
else ""
|
||||
)
|
||||
prompt = (
|
||||
"Summarize the following conversation between a user and an "
|
||||
"assistant. Preserve:\n"
|
||||
"- Key decisions and conclusions\n"
|
||||
"- User preferences and requirements\n"
|
||||
"- Important facts, names, and numbers\n"
|
||||
"- The overall topic and direction of the conversation\n"
|
||||
"- Any pending tasks or open questions\n\n"
|
||||
f"{previous}"
|
||||
f"Conversation:\n{conversation_text}\n\n"
|
||||
"Provide a clear, comprehensive summary."
|
||||
)
|
||||
if self._genai_client is None:
|
||||
msg = "genai_client is required for compaction"
|
||||
raise RuntimeError(msg)
|
||||
response = await self._genai_client.aio.models.generate_content(
|
||||
model=self._compaction_model,
|
||||
contents=prompt,
|
||||
)
|
||||
return response.text or ""
|
||||
|
||||
async def _compact_session(self, session: Session) -> None:
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
|
||||
events_ref = self._events_col(app_name, user_id, session_id)
|
||||
query = events_ref.order_by("timestamp")
|
||||
event_docs = await query.get()
|
||||
|
||||
if len(event_docs) <= self._compaction_keep_recent:
|
||||
return
|
||||
|
||||
all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||
events_to_summarize = all_events[: -self._compaction_keep_recent]
|
||||
|
||||
session_snap = await self._session_ref(app_name, user_id, session_id).get()
|
||||
existing_summary = (session_snap.to_dict() or {}).get(
|
||||
"conversation_summary", ""
|
||||
)
|
||||
|
||||
try:
|
||||
summary = await self._generate_summary(
|
||||
existing_summary, events_to_summarize
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Compaction summary generation failed; skipping.")
|
||||
return
|
||||
|
||||
# Write summary BEFORE deleting events so a crash between the two
|
||||
# steps leaves safe duplication rather than data loss.
|
||||
await self._session_ref(app_name, user_id, session_id).update(
|
||||
{"conversation_summary": summary}
|
||||
)
|
||||
|
||||
docs_to_delete = event_docs[: -self._compaction_keep_recent]
|
||||
for i in range(0, len(docs_to_delete), 500):
|
||||
batch = self._db.batch()
|
||||
for doc in docs_to_delete[i : i + 500]:
|
||||
batch.delete(doc.reference)
|
||||
await batch.commit()
|
||||
|
||||
logger.info(
|
||||
"Compacted session %s: summarised %d events, kept %d.",
|
||||
session_id,
|
||||
len(docs_to_delete),
|
||||
self._compaction_keep_recent,
|
||||
)
|
||||
|
||||
async def _guarded_compact(self, session: Session) -> None:
|
||||
"""Run compaction in the background with per-session locking."""
|
||||
key = f"{session.app_name}__{session.user_id}__{session.id}"
|
||||
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
|
||||
|
||||
if lock.locked():
|
||||
logger.debug("Compaction already running locally for %s; skipping.", key)
|
||||
return
|
||||
|
||||
async with lock:
|
||||
session_ref = self._session_ref(
|
||||
session.app_name, session.user_id, session.id
|
||||
)
|
||||
try:
|
||||
transaction = self._db.transaction()
|
||||
claimed = await _try_claim_compaction_txn(transaction, session_ref)
|
||||
except Exception:
|
||||
logger.exception("Failed to claim compaction lock for %s", key)
|
||||
return
|
||||
|
||||
if not claimed:
|
||||
logger.debug(
|
||||
"Compaction lock held by another instance for %s; skipping.",
|
||||
key,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await self._compact_session(session)
|
||||
except Exception:
|
||||
logger.exception("Background compaction failed for %s", key)
|
||||
finally:
|
||||
try:
|
||||
await session_ref.update({"compaction_lock": None})
|
||||
except Exception:
|
||||
logger.exception("Failed to release compaction lock for %s", key)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Await all in-flight compaction tasks. Call before shutdown."""
|
||||
if self._active_tasks:
|
||||
await asyncio.gather(*self._active_tasks, return_exceptions=True)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# BaseSessionService implementation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@override
|
||||
async def create_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
state: dict[str, Any] | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> Session:
|
||||
if session_id and session_id.strip():
|
||||
session_id = session_id.strip()
|
||||
existing = await self._session_ref(app_name, user_id, session_id).get()
|
||||
if existing.exists:
|
||||
msg = f"Session with id {session_id} already exists."
|
||||
raise AlreadyExistsError(msg)
|
||||
else:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
state_deltas = _session_util.extract_state_delta(state) # type: ignore[attr-defined]
|
||||
app_state_delta = state_deltas["app"]
|
||||
user_state_delta = state_deltas["user"]
|
||||
session_state = state_deltas["session"]
|
||||
|
||||
write_coros: list = []
|
||||
if app_state_delta:
|
||||
write_coros.append(
|
||||
self._app_state_ref(app_name).set(app_state_delta, merge=True)
|
||||
)
|
||||
if user_state_delta:
|
||||
write_coros.append(
|
||||
self._user_state_ref(app_name, user_id).set(
|
||||
user_state_delta, merge=True
|
||||
)
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
write_coros.append(
|
||||
self._session_ref(app_name, user_id, session_id).set(
|
||||
{
|
||||
"app_name": app_name,
|
||||
"user_id": user_id,
|
||||
"session_id": session_id,
|
||||
"state": session_state or {},
|
||||
"last_update_time": now,
|
||||
}
|
||||
)
|
||||
)
|
||||
await asyncio.gather(*write_coros)
|
||||
|
||||
app_state, user_state = await asyncio.gather(
|
||||
self._get_app_state(app_name),
|
||||
self._get_user_state(app_name, user_id),
|
||||
)
|
||||
merged = self._merge_state(app_state, user_state, session_state or {})
|
||||
|
||||
return Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged,
|
||||
last_update_time=now,
|
||||
)
|
||||
|
||||
@override
|
||||
async def get_session(
|
||||
self,
|
||||
*,
|
||||
app_name: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
config: GetSessionConfig | None = None,
|
||||
) -> Session | None:
|
||||
snap = await self._session_ref(app_name, user_id, session_id).get()
|
||||
if not snap.exists:
|
||||
return None
|
||||
|
||||
session_data = snap.to_dict()
|
||||
|
||||
# Build events query
|
||||
events_ref = self._events_col(app_name, user_id, session_id)
|
||||
query = events_ref
|
||||
if config and config.after_timestamp:
|
||||
query = query.where(
|
||||
filter=FieldFilter("timestamp", ">=", config.after_timestamp)
|
||||
)
|
||||
query = query.order_by("timestamp")
|
||||
|
||||
event_docs, app_state, user_state = await asyncio.gather(
|
||||
query.get(),
|
||||
self._get_app_state(app_name),
|
||||
self._get_user_state(app_name, user_id),
|
||||
)
|
||||
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
|
||||
|
||||
if config and config.num_recent_events:
|
||||
events = events[-config.num_recent_events :]
|
||||
|
||||
# Prepend conversation summary as synthetic context events
|
||||
conversation_summary = session_data.get("conversation_summary")
|
||||
if conversation_summary:
|
||||
summary_event = Event(
|
||||
id="summary-context",
|
||||
author="user",
|
||||
content=Content(
|
||||
role="user",
|
||||
parts=[
|
||||
Part(
|
||||
text=(
|
||||
"[Conversation context from previous"
|
||||
" messages]\n"
|
||||
f"{conversation_summary}"
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
timestamp=0.0,
|
||||
invocation_id="compaction-summary",
|
||||
)
|
||||
ack_event = Event(
|
||||
id="summary-ack",
|
||||
author=app_name,
|
||||
content=Content(
|
||||
role="model",
|
||||
parts=[
|
||||
Part(
|
||||
text=(
|
||||
"Understood, I have the context from our"
|
||||
" previous conversation and will continue"
|
||||
" accordingly."
|
||||
)
|
||||
)
|
||||
],
|
||||
),
|
||||
timestamp=0.001,
|
||||
invocation_id="compaction-summary",
|
||||
)
|
||||
events = [summary_event, ack_event, *events]
|
||||
|
||||
# Merge scoped state
|
||||
merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
|
||||
|
||||
return Session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
id=session_id,
|
||||
state=merged,
|
||||
events=events,
|
||||
last_update_time=session_data.get("last_update_time", 0.0),
|
||||
)
|
||||
|
||||
@override
|
||||
async def list_sessions(
|
||||
self, *, app_name: str, user_id: str | None = None
|
||||
) -> ListSessionsResponse:
|
||||
query = self._db.collection(f"{self._prefix}_sessions").where(
|
||||
filter=FieldFilter("app_name", "==", app_name)
|
||||
)
|
||||
if user_id is not None:
|
||||
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
||||
|
||||
docs = await query.get()
|
||||
if not docs:
|
||||
return ListSessionsResponse()
|
||||
|
||||
doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]
|
||||
|
||||
# Pre-fetch app state and all distinct user states in parallel
|
||||
unique_user_ids = list({d["user_id"] for d in doc_dicts})
|
||||
app_state, *user_states = await asyncio.gather(
|
||||
self._get_app_state(app_name),
|
||||
*(self._get_user_state(app_name, uid) for uid in unique_user_ids),
|
||||
)
|
||||
user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))
|
||||
|
||||
sessions: list[Session] = []
|
||||
for data in doc_dicts:
|
||||
s_user_id = data["user_id"]
|
||||
merged = self._merge_state(
|
||||
app_state,
|
||||
user_state_cache[s_user_id],
|
||||
data.get("state", {}),
|
||||
)
|
||||
|
||||
sessions.append(
|
||||
Session(
|
||||
app_name=app_name,
|
||||
user_id=s_user_id,
|
||||
id=data["session_id"],
|
||||
state=merged,
|
||||
events=[],
|
||||
last_update_time=data.get("last_update_time", 0.0),
|
||||
)
|
||||
)
|
||||
|
||||
return ListSessionsResponse(sessions=sessions)
|
||||
|
||||
@override
|
||||
async def delete_session(
|
||||
self, *, app_name: str, user_id: str, session_id: str
|
||||
) -> None:
|
||||
ref = self._session_ref(app_name, user_id, session_id)
|
||||
await self._db.recursive_delete(ref)
|
||||
|
||||
@override
|
||||
async def append_event(self, session: Session, event: Event) -> Event:
|
||||
if event.partial:
|
||||
return event
|
||||
|
||||
t0 = time.monotonic()
|
||||
|
||||
app_name = session.app_name
|
||||
user_id = session.user_id
|
||||
session_id = session.id
|
||||
|
||||
# Base class: strips temp state, applies delta to in-memory session,
|
||||
# appends event to session.events
|
||||
event = await super().append_event(session=session, event=event)
|
||||
session.last_update_time = event.timestamp
|
||||
|
||||
# Persist event document
|
||||
event_data = event.model_dump(mode="json", exclude_none=True)
|
||||
await (
|
||||
self._events_col(app_name, user_id, session_id)
|
||||
.document(event.id)
|
||||
.set(event_data)
|
||||
)
|
||||
|
||||
# Persist state deltas
|
||||
session_ref = self._session_ref(app_name, user_id, session_id)
|
||||
|
||||
if event.actions and event.actions.state_delta:
|
||||
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
|
||||
|
||||
write_coros: list = []
|
||||
if state_deltas["app"]:
|
||||
write_coros.append(
|
||||
self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
|
||||
)
|
||||
if state_deltas["user"]:
|
||||
write_coros.append(
|
||||
self._user_state_ref(app_name, user_id).set(
|
||||
state_deltas["user"], merge=True
|
||||
)
|
||||
)
|
||||
|
||||
if state_deltas["session"]:
|
||||
field_updates: dict[str, Any] = {
|
||||
FieldPath("state", k).to_api_repr(): v
|
||||
for k, v in state_deltas["session"].items()
|
||||
}
|
||||
field_updates["last_update_time"] = event.timestamp
|
||||
write_coros.append(session_ref.update(field_updates))
|
||||
else:
|
||||
write_coros.append(
|
||||
session_ref.update({"last_update_time": event.timestamp})
|
||||
)
|
||||
|
||||
await asyncio.gather(*write_coros)
|
||||
else:
|
||||
await session_ref.update({"last_update_time": event.timestamp})
|
||||
|
||||
# Log token usage
|
||||
if event.usage_metadata:
|
||||
meta = event.usage_metadata
|
||||
logger.info(
|
||||
"Token usage for session %s event %s: "
|
||||
"prompt=%s, candidates=%s, total=%s",
|
||||
session_id,
|
||||
event.id,
|
||||
meta.prompt_token_count,
|
||||
meta.candidates_token_count,
|
||||
meta.total_token_count,
|
||||
)
|
||||
|
||||
# Trigger compaction if total token count exceeds threshold
|
||||
if (
|
||||
self._compaction_threshold is not None
|
||||
and event.usage_metadata
|
||||
and event.usage_metadata.total_token_count
|
||||
and event.usage_metadata.total_token_count >= self._compaction_threshold
|
||||
):
|
||||
logger.info(
|
||||
"Compaction triggered for session %s: "
|
||||
"total_token_count=%d >= threshold=%d",
|
||||
session_id,
|
||||
event.usage_metadata.total_token_count,
|
||||
self._compaction_threshold,
|
||||
)
|
||||
task = asyncio.create_task(self._guarded_compact(session))
|
||||
self._active_tasks.add(task)
|
||||
task.add_done_callback(self._active_tasks.discard)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
logger.info(
|
||||
"append_event completed for session %s event %s in %.3fs",
|
||||
session_id,
|
||||
event.id,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
return event
|
||||
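Because compaction runs as fire-and-forget asyncio tasks, a deployment would plausibly want to drain them before the process exits via the close() method above. A minimal sketch; the FastAPI lifespan wiring here is an assumption and not part of this commit:

from contextlib import asynccontextmanager

from fastapi import FastAPI

from va_agent.agent import session_service


@asynccontextmanager
async def lifespan(app: FastAPI):
    yield
    # Await any in-flight compaction tasks before shutdown.
    await session_service.close()


app = FastAPI(title="VA Agent", lifespan=lifespan)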