Compare commits

...

1 Commits

Author SHA1 Message Date
1c6d942177 Lean MCP implementation 2026-02-23 04:37:39 +00:00
37 changed files with 2380 additions and 3541 deletions

View File

@@ -1,2 +1,3 @@
Use `uv` for project management.
Use `uv run ruff check` for linting, and `uv run ty check` for type checking
Use `uv run pytest` for testing.

View File

@@ -1,5 +1,5 @@
[project]
name = "rag-eval"
name = "va-agent"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
@@ -9,28 +9,20 @@ authors = [
]
requires-python = "~=3.12.0"
dependencies = [
"aiohttp>=3.13.3",
"gcloud-aio-auth>=5.4.2",
"gcloud-aio-storage>=9.6.1",
"google-adk>=1.14.1",
"google-cloud-aiplatform>=1.126.1",
"google-cloud-storage>=2.19.0",
"google-cloud-firestore>=2.23.0",
"pydantic-settings[yaml]>=2.13.1",
"structlog>=25.5.0",
]
[project.scripts]
ragops = "rag_eval.cli:app"
[build-system]
requires = ["uv_build>=0.8.3,<0.9.0"]
build-backend = "uv_build"
[dependency-groups]
dev = [
"clai>=1.62.0",
"marimo>=0.20.1",
"pytest>=8.4.1",
"pytest-asyncio>=1.3.0",
"pytest-sugar>=1.1.1",
"ruff>=0.12.10",
"ty>=0.0.1a19",
]
@@ -43,4 +35,10 @@ exclude = ["scripts"]
[tool.ruff.lint]
select = ['ALL']
ignore = ['D203', 'D213', 'COM812']
ignore = [
'D203', # one-blank-line-before-class
'D213', # multi-line-summary-second-line
'COM812', # missing-trailing-comma
'ANN401', # dynamically-typed-any
'ERA001', # commented-out-code
]

View File

@@ -1,59 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ADK agent with vector search RAG tool."""
from __future__ import annotations
import os
from google.adk.agents.llm_agent import Agent
from .config_helper import settings
from .vector_search_tool import VectorSearchTool
# Set environment variables for Google GenAI Client to use Vertex AI
os.environ["GOOGLE_CLOUD_PROJECT"] = settings.project_id
os.environ["GOOGLE_CLOUD_LOCATION"] = settings.location
# Create vector search tool with configuration
vector_search_tool = VectorSearchTool(
name='conocimiento',
description='Search the vector index for company products and services information',
embedder=settings.embedder,
project_id=settings.project_id,
location=settings.location,
bucket=settings.bucket,
index_name=settings.index_name,
index_endpoint=settings.index_endpoint,
index_deployed_id=settings.index_deployed_id,
similarity_top_k=5,
min_similarity_threshold=0.6,
relative_threshold_factor=0.9,
)
# Create agent with vector search tool
# Configure model with Vertex AI fully qualified path
model_path = (
f'projects/{settings.project_id}/locations/{settings.location}/'
f'publishers/google/models/{settings.agent_language_model}'
)
root_agent = Agent(
model=model_path,
name=settings.agent_name,
description='A helpful assistant for user questions.',
instruction=settings.agent_instructions,
tools=[vector_search_tool],
)

View File

@@ -1,68 +0,0 @@
"""Abstract base class for vector search providers."""
from abc import ABC, abstractmethod
from typing import Any, TypedDict
class SearchResult(TypedDict):
"""A single vector search result."""
id: str
distance: float
content: str
class BaseVectorSearch(ABC):
"""Abstract base class for a vector search provider.
This class defines the standard interface for creating a vector search
index and running queries against it.
"""
@abstractmethod
def create_index(
self, name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Create a new vector search index with the provided content.
Args:
name: The desired name for the new index.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def update_index(
self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Update an existing vector search index with new content.
Args:
index_name: The name of the index to update.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
...

View File

@@ -1,120 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration helper for ADK agent with vector search."""
from __future__ import annotations
import os
from dataclasses import dataclass
from functools import cached_property
import vertexai
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from vertexai.language_models import TextEmbeddingModel
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
@dataclass
class EmbeddingResult:
"""Result from embedding a query."""
embeddings: list[list[float]]
class VertexAIEmbedder:
"""Embedder using Vertex AI TextEmbeddingModel."""
def __init__(self, model_name: str, project_id: str, location: str) -> None:
"""Initialize the embedder.
Args:
model_name: Name of the embedding model (e.g., 'text-embedding-004')
project_id: GCP project ID
location: GCP location
"""
vertexai.init(project=project_id, location=location)
self.model = TextEmbeddingModel.from_pretrained(model_name)
async def embed_query(self, query: str) -> EmbeddingResult:
"""Embed a single query string.
Args:
query: Text to embed
Returns:
EmbeddingResult with embeddings list
"""
embeddings = self.model.get_embeddings([query])
return EmbeddingResult(embeddings=[list(embeddings[0].values)])
class AgentSettings(BaseSettings):
"""Settings for ADK agent with vector search."""
# Google Cloud settings
project_id: str
location: str
bucket: str
# Agent configuration
agent_name: str
agent_instructions: str
agent_language_model: str
agent_embedding_model: str
# Vector index configuration
index_name: str
index_deployed_id: str
index_endpoint: str
model_config = SettingsConfigDict(
yaml_file=CONFIG_FILE_PATH,
extra="ignore", # Ignore extra fields from config.yaml
)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
@cached_property
def embedder(self) -> VertexAIEmbedder:
"""Return an embedder configured for the agent's embedding model."""
return VertexAIEmbedder(
model_name=self.agent_embedding_model,
project_id=self.project_id,
location=self.location,
)
settings = AgentSettings.model_validate({})

View File

@@ -1 +0,0 @@
"""File storage provider implementations."""

View File

@@ -1,56 +0,0 @@
"""Abstract base class for file storage providers."""
from abc import ABC, abstractmethod
from typing import BinaryIO
class BaseFileStorage(ABC):
"""Abstract base class for a remote file processor.
Defines the interface for listing and processing files from
a remote source.
"""
@abstractmethod
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to the remote source.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the file in remote storage.
content_type: The content type of the file.
"""
...
@abstractmethod
def list_files(self, path: str | None = None) -> list[str]:
"""List files from a remote location.
Args:
path: Path to a specific file or directory. If None,
recursively lists all files in the bucket.
Returns:
A list of file paths.
"""
...
@abstractmethod
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file from the remote source as a file-like object.
Args:
file_name: The name of the file to retrieve.
Returns:
A file-like object containing the file data.
"""
...

View File

@@ -1,188 +0,0 @@
"""Google Cloud Storage file storage implementation."""
import asyncio
import io
import logging
from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage
from .base import BaseFileStorage
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
class GoogleCloudFileStorage(BaseFileStorage):
"""File storage backed by Google Cloud Storage."""
def __init__(self, bucket: str) -> None: # noqa: D107
self.bucket_name = bucket
self.storage_client = storage.Client()
self.bucket_client = self.storage_client.bucket(self.bucket_name)
self._aio_session: aiohttp.ClientSession | None = None
self._aio_storage: Storage | None = None
self._cache: dict[str, bytes] = {}
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to Cloud Storage.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the blob in the bucket.
content_type: The content type of the file.
"""
blob = self.bucket_client.blob(destination_blob_name)
blob.upload_from_filename(
file_path,
content_type=content_type,
if_generation_match=0,
)
self._cache.pop(destination_blob_name, None)
def list_files(self, path: str | None = None) -> list[str]:
"""List all files at the given path in the bucket.
If path is None, recursively lists all files.
Args:
path: Prefix to filter files by.
Returns:
A list of blob names.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
return [blob.name for blob in blobs]
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file as a file-like object, using cache.
Args:
file_name: The blob name to retrieve.
Returns:
A BytesIO stream with the file contents.
"""
if file_name not in self._cache:
blob = self.bucket_client.blob(file_name)
self._cache[file_name] = blob.download_as_bytes()
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def _get_aio_storage(self) -> Storage:
if self._aio_storage is None:
self._aio_storage = Storage(
session=self._get_aio_session(),
)
return self._aio_storage
async def async_get_file_stream(
self, file_name: str, max_retries: int = 3,
) -> BinaryIO:
"""Get a file asynchronously with retry on transient errors.
Args:
file_name: The blob name to retrieve.
max_retries: Maximum number of retry attempts.
Returns:
A BytesIO stream with the file contents.
Raises:
TimeoutError: If all retry attempts fail.
"""
if file_name in self._cache:
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
storage_client = self._get_aio_storage()
last_exception: Exception | None = None
for attempt in range(max_retries):
try:
self._cache[file_name] = await storage_client.download(
self.bucket_name, file_name,
)
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
except TimeoutError as exc:
last_exception = exc
logger.warning(
"Timeout downloading gs://%s/%s (attempt %d/%d)",
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
if (
exc.status == HTTP_TOO_MANY_REQUESTS
or exc.status >= HTTP_SERVER_ERROR
):
logger.warning(
"HTTP %d downloading gs://%s/%s "
"(attempt %d/%d)",
exc.status,
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
else:
raise
else:
return file_stream
if attempt < max_retries - 1:
delay = 0.5 * (2**attempt)
await asyncio.sleep(delay)
msg = (
f"Failed to download gs://{self.bucket_name}/{file_name} "
f"after {max_retries} attempts"
)
raise TimeoutError(msg) from last_exception
def delete_files(self, path: str) -> None:
"""Delete all files at the given path in the bucket.
Args:
path: Prefix of blobs to delete.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
for blob in blobs:
blob.delete()
self._cache.pop(blob.name, None)

View File

@@ -1,176 +0,0 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
from __future__ import annotations
import logging
from typing import Any
from typing import TYPE_CHECKING
from google.adk.tools.tool_context import ToolContext
from typing_extensions import override
from .vertex_ai import GoogleCloudVectorSearch
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
if TYPE_CHECKING:
from .config_helper import VertexAIEmbedder
logger = logging.getLogger('google_adk.' + __name__)
class VectorSearchTool(BaseRetrievalTool):
"""A retrieval tool using Vertex AI Vector Search (not RAG Engine).
This tool uses GoogleCloudVectorSearch to query a vector index directly,
which is useful when Vertex AI RAG Engine is not available in your GCP project.
"""
def __init__(
self,
*,
name: str,
description: str,
embedder: VertexAIEmbedder,
project_id: str,
location: str,
bucket: str,
index_name: str,
index_endpoint: str,
index_deployed_id: str,
similarity_top_k: int = 5,
min_similarity_threshold: float = 0.6,
relative_threshold_factor: float = 0.9,
):
"""Initialize the VectorSearchTool.
Args:
name: Tool name for function declaration
description: Tool description for LLM
embedder: Embedder instance for query embedding
project_id: GCP project ID
location: GCP location (e.g., 'us-central1')
bucket: GCS bucket for content storage
index_name: Vector search index name
index_endpoint: Resource name of index endpoint
index_deployed_id: Deployed index ID
similarity_top_k: Number of results to retrieve (default: 5)
min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6)
relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9)
"""
super().__init__(name=name, description=description)
self.embedder = embedder
self.index_endpoint = index_endpoint
self.index_deployed_id = index_deployed_id
self.similarity_top_k = similarity_top_k
self.min_similarity_threshold = min_similarity_threshold
self.relative_threshold_factor = relative_threshold_factor
# Initialize vector search (endpoint loaded lazily on first use)
self.vector_search = GoogleCloudVectorSearch(
project_id=project_id,
location=location,
bucket=bucket,
index_name=index_name,
)
self._endpoint_loaded = False
logger.info(
'VectorSearchTool initialized with index=%s, deployed_id=%s',
index_name,
index_deployed_id,
)
@override
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:
"""Execute vector search with the user's query.
Args:
args: Dictionary containing 'query' key
tool_context: Tool execution context
Returns:
Formatted search results as XML-like documents or error message
"""
query = args['query']
logger.debug('VectorSearchTool query: %s', query)
try:
# Load index endpoint on first use (lazy loading)
if not self._endpoint_loaded:
self.vector_search.load_index_endpoint(self.index_endpoint)
self._endpoint_loaded = True
logger.info('Index endpoint loaded successfully')
# Embed the query using the configured embedder
embedding_result = await self.embedder.embed_query(query)
query_embedding = list(embedding_result.embeddings[0])
# Run vector search
search_results = await self.vector_search.async_run_query(
deployed_index_id=self.index_deployed_id,
query=query_embedding,
limit=self.similarity_top_k,
)
# Apply similarity filtering (dual threshold approach)
if search_results:
# Dynamic threshold based on max similarity
max_similarity = max(r['distance'] for r in search_results)
dynamic_cutoff = max_similarity * self.relative_threshold_factor
# Filter by both absolute and relative thresholds
search_results = [
result
for result in search_results
if (
result['distance'] > dynamic_cutoff
and result['distance'] > self.min_similarity_threshold
)
]
logger.debug(
'VectorSearchTool results: %d documents after filtering',
len(search_results),
)
# Format results
if not search_results:
return (
f"No matching documents found for query: '{query}' "
f'(min_threshold={self.min_similarity_threshold})'
)
# Format as XML-like documents (matching pydantic_ai pattern)
formatted_results = [
f'<document {i} name={result["id"]}>\n'
f'{result["content"]}\n'
f'</document {i}>'
for i, result in enumerate(search_results, start=1)
]
return '\n'.join(formatted_results)
except Exception as e:
logger.error('VectorSearchTool error: %s', e, exc_info=True)
return f'Error during vector search: {str(e)}'

View File

@@ -1,310 +0,0 @@
"""Google Cloud Vertex AI Vector Search implementation."""
import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform
from .file_storage.google_cloud import GoogleCloudFileStorage
from .base import BaseVectorSearch, SearchResult
class GoogleCloudVectorSearch(BaseVectorSearch):
"""A vector search provider using Vertex AI Vector Search."""
def __init__(
self,
project_id: str,
location: str,
bucket: str,
index_name: str | None = None,
) -> None:
"""Initialize the GoogleCloudVectorSearch client.
Args:
project_id: The Google Cloud project ID.
location: The Google Cloud location (e.g., 'us-central1').
bucket: The GCS bucket to use for file storage.
index_name: The name of the index.
"""
aiplatform.init(project=project_id, location=location)
self.project_id = project_id
self.location = location
self.storage = GoogleCloudFileStorage(bucket=bucket)
self.index_name = index_name
self._credentials: google.auth.credentials.Credentials | None = None
self._aio_session: aiohttp.ClientSession | None = None
self._async_token: Token | None = None
def _get_auth_headers(self) -> dict[str, str]:
if self._credentials is None:
self._credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
if not self._credentials.token or self._credentials.expired:
self._credentials.refresh(
google.auth.transport.requests.Request(),
)
return {
"Authorization": f"Bearer {self._credentials.token}",
"Content-Type": "application/json",
}
async def _async_get_auth_headers(self) -> dict[str, str]:
if self._async_token is None:
self._async_token = Token(
session=self._get_aio_session(),
scopes=[
"https://www.googleapis.com/auth/cloud-platform",
],
)
access_token = await self._async_token.get()
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def create_index(
self,
name: str,
content_path: str,
*,
dimensions: int = 3072,
approximate_neighbors_count: int = 150,
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
**kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Create a new Vertex AI Vector Search index.
Args:
name: The display name for the new index.
content_path: GCS URI to the embeddings JSON file.
dimensions: Number of dimensions in embedding vectors.
approximate_neighbors_count: Neighbors to find per vector.
distance_measure_type: The distance measure to use.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
display_name=name,
contents_delta_uri=content_path,
dimensions=dimensions,
approximate_neighbors_count=approximate_neighbors_count,
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
leaf_node_embedding_count=1000,
leaf_nodes_to_search_percent=10,
)
self.index = index
def update_index(
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Update an existing Vertex AI Vector Search index.
Args:
index_name: The resource name of the index to update.
content_path: GCS URI to the new embeddings JSON file.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex(index_name=index_name)
index.update_embeddings(
contents_delta_uri=content_path,
)
self.index = index
def deploy_index(
self,
index_name: str,
machine_type: str = "e2-standard-2",
) -> None:
"""Deploy a Vertex AI Vector Search index to an endpoint.
Args:
index_name: The name of the index to deploy.
machine_type: The machine type for the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=f"{index_name}-endpoint",
public_endpoint_enabled=True,
)
index_endpoint.deploy_index(
index=self.index,
deployed_index_id=(
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
),
machine_type=machine_type,
)
self.index_endpoint = index_endpoint
def load_index_endpoint(self, endpoint_name: str) -> None:
"""Load an existing Vertex AI Vector Search index endpoint.
Args:
endpoint_name: The resource name of the index endpoint.
"""
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
endpoint_name,
)
if not self.index_endpoint.public_endpoint_domain_name:
msg = (
"The index endpoint does not have a public endpoint. "
"Ensure the endpoint is configured for public access."
)
raise ValueError(msg)
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the deployed index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
response = self.index_endpoint.find_neighbors(
deployed_index_id=deployed_index_id,
queries=[query],
num_neighbors=limit,
)
results = []
for neighbor in response[0]:
file_path = (
f"{self.index_name}/contents/{neighbor.id}.md"
)
content = (
self.storage.get_file_stream(file_path)
.read()
.decode("utf-8")
)
results.append(
SearchResult(
id=neighbor.id,
distance=float(neighbor.distance or 0),
content=content,
),
)
return results
async def async_run_query(
self,
deployed_index_id: str,
query: Sequence[float],
limit: int,
) -> list[SearchResult]:
"""Run an async similarity search via the REST API.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
domain = self.index_endpoint.public_endpoint_domain_name
endpoint_id = self.index_endpoint.name.split("/")[-1]
url = (
f"https://{domain}/v1/projects/{self.project_id}"
f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors"
)
payload = {
"deployed_index_id": deployed_index_id,
"queries": [
{
"datapoint": {"feature_vector": list(query)},
"neighbor_count": limit,
},
],
}
headers = await self._async_get_auth_headers()
session = self._get_aio_session()
async with session.post(
url, json=payload, headers=headers,
) as response:
response.raise_for_status()
data = await response.json()
neighbors = (
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
)
content_tasks = []
for neighbor in neighbors:
datapoint_id = neighbor["datapoint"]["datapointId"]
file_path = (
f"{self.index_name}/contents/{datapoint_id}.md"
)
content_tasks.append(
self.storage.async_get_file_stream(file_path),
)
file_streams = await asyncio.gather(*content_tasks)
results: list[SearchResult] = []
for neighbor, stream in zip(
neighbors, file_streams, strict=True,
):
results.append(
SearchResult(
id=neighbor["datapoint"]["datapointId"],
distance=neighbor["distance"],
content=stream.read().decode("utf-8"),
),
)
return results
def delete_index(self, index_name: str) -> None:
"""Delete a Vertex AI Vector Search index.
Args:
index_name: The resource name of the index.
"""
index = aiplatform.MatchingEngineIndex(index_name)
index.delete()
def delete_index_endpoint(
self, index_endpoint_name: str,
) -> None:
"""Delete a Vertex AI Vector Search index endpoint.
Args:
index_endpoint_name: The resource name of the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name,
)
index_endpoint.undeploy_all()
index_endpoint.delete(force=True)

View File

@@ -1,79 +0,0 @@
import asyncio
import logging
import os
import random
import typer
from dotenv import load_dotenv
from embedder.vertex_ai import VertexAIEmbedder
load_dotenv()
project = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION")
MODEL_NAME = "gemini-embedding-001"
CONTENT_LIST = [
"¿Cuáles son los beneficios de una tarjeta de crédito?",
"¿Cómo puedo abrir una cuenta de ahorros?",
"¿Qué es una hipoteca y cómo funciona?",
"¿Cuáles son las tasas de interés para un préstamo personal?",
"¿Cómo puedo solicitar un préstamo para un coche?",
"¿Qué es la banca en línea y cómo me registro?",
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
"¿Qué es el phishing y cómo puedo protegerme?",
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
"¿Cómo puedo transferir dinero a una cuenta internacional?",
]
TASK_TYPE = "RETRIEVAL_DOCUMENT"
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
app = typer.Typer()
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
async def embed_content_task():
"""A single task to send one embedding request using the global client."""
content_to_embed = random.choice(CONTENT_LIST)
await embedder.async_generate_embedding(content_to_embed)
async def run_test(concurrency: int):
"""Continuously calls the embedding API and tracks requests."""
total_requests = 0
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
logger.info("Press Ctrl+C to stop.")
while True:
# Create tasks, passing project_id and location
tasks = [embed_content_task() for _ in range(concurrency)]
try:
await asyncio.gather(*tasks)
total_requests += concurrency
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
except Exception as e:
logger.error("Caught an error. Stopping test.")
print("\n--- STATS ---")
print(f"Total successful requests: {total_requests}")
print(f"Concurrent requests during failure: {concurrency}")
print(f"Error Type: {e.__class__.__name__}")
print(f"Error Details: {e}")
print("-------------")
break
@app.command()
def main(
concurrency: int = typer.Option(
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
),
):
try:
asyncio.run(run_test(concurrency))
except KeyboardInterrupt:
logger.info("\nKeyboard interrupt received. Exiting.")
if __name__ == "__main__":
app()

View File

@@ -1,99 +0,0 @@
import asyncio
import logging
import random
import httpx
import typer
CONTENT_LIST = [
"¿Cuáles son los beneficios de una tarjeta de crédito?",
"¿Cómo puedo abrir una cuenta de ahorros?",
"¿Qué es una hipoteca y cómo funciona?",
"¿Cuáles son las tasas de interés para un préstamo personal?",
"¿Cómo puedo solicitar un préstamo para un coche?",
"¿Qué es la banca en línea y cómo me registro?",
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
"¿Qué es el phishing y cómo puedo protegerme?",
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
"¿Cómo puedo transferir dinero a una cuenta internacional?",
]
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
app = typer.Typer()
async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
"""A single task to send one request to the RAG endpoint."""
question = random.choice(CONTENT_LIST)
json_payload = {
"sessionInfo": {
"parameters": {
"query": question
}
}
}
response = await client.post(url, json=json_payload)
response.raise_for_status() # Raise an exception for bad status codes
response_data = response.json()
response_text = response_data["sessionInfo"]["parameters"]["response"]
logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")
async def run_test(concurrency: int, url: str, timeout_seconds: float):
"""Continuously calls the RAG endpoint and tracks requests."""
total_requests = 0
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
logger.info("Press Ctrl+C to stop.")
timeout = httpx.Timeout(timeout_seconds)
async with httpx.AsyncClient(timeout=timeout) as client:
while True:
tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]
try:
await asyncio.gather(*tasks)
total_requests += concurrency
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
except httpx.TimeoutException as e:
logger.error(f"A request timed out: {e.request.method} {e.request.url}")
logger.error("Consider increasing the timeout with the --timeout option.")
break
except httpx.HTTPStatusError as e:
logger.error(f"An HTTP error occurred: {e.response.status_code} - {e.request.method} {e.request.url}")
logger.error(f"Response body: {e.response.text}")
break
except httpx.RequestError as e:
logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
logger.error(f"Error details: {e}")
break
except Exception as e:
logger.error("Caught an unexpected error. Stopping test.")
print("\n--- STATS ---")
print(f"Total successful requests: {total_requests}")
print(f"Concurrent requests during failure: {concurrency}")
print(f"Error Type: {e.__class__.__name__}")
print(f"Error Details: {e}")
print("-------------")
break
@app.command()
def main(
concurrency: int = typer.Option(
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
),
url: str = typer.Option(
"http://127.0.0.1:8000/sigma-rag", "--url", "-u", help="The URL of the RAG endpoint to test."
),
timeout_seconds: float = typer.Option(
30.0, "--timeout", "-t", help="Request timeout in seconds."
)
):
try:
asyncio.run(run_test(concurrency, url, timeout_seconds))
except KeyboardInterrupt:
logger.info("\nKeyboard interrupt received. Exiting.")
if __name__ == "__main__":
app()

View File

@@ -1,91 +0,0 @@
import concurrent.futures
import random
import threading
import requests
# URL for the endpoint
url = "http://localhost:8000/sigma-rag"
# List of Spanish banking questions
spanish_questions = [
"¿Cuáles son los beneficios de una tarjeta de crédito?",
"¿Cómo puedo abrir una cuenta de ahorros?",
"¿Qué es una hipoteca y cómo funciona?",
"¿Cuáles son las tasas de interés para un préstamo personal?",
"¿Cómo puedo solicitar un préstamo para un coche?",
"¿Qué es la banca en línea y cómo me registro?",
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
"¿Qué es el phishing y cómo puedo protegerme?",
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
"¿Cómo puedo transferir dinero a una cuenta internacional?",
]
# A threading Event to signal all threads to stop
stop_event = threading.Event()
def send_request(question, request_id):
"""Sends a single request and handles the response."""
if stop_event.is_set():
return
data = {"sessionInfo": {"parameters": {"query": question}}}
try:
response = requests.post(url, json=data)
if stop_event.is_set():
return
if response.status_code == 500:
print(f"Request {request_id}: Received 500 error with question: '{question}'.")
print("Stopping stress test.")
stop_event.set()
else:
print(f"Request {request_id}: Successful with status code {response.status_code}.")
except requests.exceptions.RequestException as e:
if not stop_event.is_set():
print(f"Request {request_id}: An error occurred: {e}")
stop_event.set()
def main():
"""Runs the stress test with parallel requests."""
num_workers = 30 # Number of parallel requests
print(f"Starting stress test with {num_workers} parallel workers. Press Ctrl+C to stop.")
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(send_request, random.choice(spanish_questions), i)
for i in range(1, num_workers + 1)
}
request_id_counter = num_workers + 1
try:
while not stop_event.is_set():
# Wait for any future to complete
done, _ = concurrent.futures.wait(
futures, return_when=concurrent.futures.FIRST_COMPLETED
)
for future in done:
# Remove the completed future
futures.remove(future)
# If we are not stopping, submit a new one
if not stop_event.is_set():
futures.add(
executor.submit(
send_request,
random.choice(spanish_questions),
request_id_counter,
)
)
request_id_counter += 1
except KeyboardInterrupt:
print("\nKeyboard interrupt received. Stopping threads.")
stop_event.set()
print("Stress test finished.")
if __name__ == "__main__":
main()

View File

@@ -1,84 +0,0 @@
from typing import Annotated
import typer
from google.cloud import aiplatform
from rag_eval.config import settings
app = typer.Typer()
@app.command()
def main(
pipeline_spec_path: Annotated[
str,
typer.Option(
"--pipeline-spec-path",
"-p",
help="Path to the compiled pipeline YAML file.",
),
],
input_table: Annotated[
str,
typer.Option(
"--input-table",
"-i",
help="Full BigQuery table name for input (e.g., 'project.dataset.table')",
),
],
output_table: Annotated[
str,
typer.Option(
"--output-table",
"-o",
help="Full BigQuery table name for output (e.g., 'project.dataset.table')",
),
],
project_id: Annotated[
str,
typer.Option(
"--project-id",
help="Google Cloud project ID.",
),
] = settings.project_id,
location: Annotated[
str,
typer.Option(
"--location",
help="Google Cloud location for the pipeline job.",
),
] = settings.location,
display_name: Annotated[
str,
typer.Option(
"--display-name",
help="Display name for the pipeline job.",
),
] = "search-eval-pipeline-job",
):
"""Submits a Vertex AI pipeline job."""
parameter_values = {
"project_id": project_id,
"location": location,
"input_table": input_table,
"output_table": output_table,
}
job = aiplatform.PipelineJob(
display_name=display_name,
template_path=pipeline_spec_path,
pipeline_root=f"gs://{settings.bucket}/pipeline_root",
parameter_values=parameter_values,
project=project_id,
location=location,
)
print(f"Submitting pipeline job with parameters: {parameter_values}")
job.submit(
service_account="sa-cicd-gitlab@bnt-orquestador-cognitivo-dev.iam.gserviceaccount.com"
)
print(f"Pipeline job submitted. You can view it at: {job._dashboard_uri()}")
if __name__ == "__main__":
app()

View File

@@ -1,42 +0,0 @@
from google.cloud import discoveryengine_v1 as discoveryengine
# TODO(developer): Uncomment these variables before running the sample.
project_id = "bnt-orquestador-cognitivo-dev"
client = discoveryengine.RankServiceClient()
# The full resource name of the ranking config.
# Format: projects/{project_id}/locations/{location}/rankingConfigs/default_ranking_config
ranking_config = client.ranking_config_path(
project=project_id,
location="global",
ranking_config="default_ranking_config",
)
request = discoveryengine.RankRequest(
ranking_config=ranking_config,
model="semantic-ranker-default@latest",
top_n=10,
query="What is Google Gemini?",
records=[
discoveryengine.RankingRecord(
id="1",
title="Gemini",
content="The Gemini zodiac symbol often depicts two figures standing side-by-side.",
),
discoveryengine.RankingRecord(
id="2",
title="Gemini",
content="Gemini is a cutting edge large language model created by Google.",
),
discoveryengine.RankingRecord(
id="3",
title="Gemini Constellation",
content="Gemini is a constellation that can be seen in the night sky.",
),
],
)
response = client.rank(request=request)
# Handle the response
print(response)

View File

@@ -1,12 +0,0 @@
import requests
# Test the /sigma-rag endpoint
url = "http://localhost:8000/sigma-rag"
data = {
"sessionInfo": {"parameters": {"query": "What are the benefits of a credit card?"}}
}
response = requests.post(url, json=data)
print("Response from /sigma-rag:")
print(response.json())

View File

@@ -1 +0,0 @@
"""RAG evaluation agent package."""

View File

@@ -1,92 +0,0 @@
"""Pydantic AI agent with RAG tool for vector search."""
import time
import structlog
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel
from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
logger = structlog.get_logger(__name__)
class Deps(BaseModel):
"""Dependencies injected into the agent at runtime."""
vector_search: GoogleCloudVectorSearch
embedder: Embedder
model_config = {"arbitrary_types_allowed": True}
model = GoogleModel(
settings.agent_language_model,
provider=settings.provider,
)
agent = Agent(
model,
deps_type=Deps,
system_prompt=settings.agent_instructions,
)
@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
"""Search the vector index for the given query.
Args:
ctx: The run context containing dependencies.
query: The query to search for.
Returns:
A formatted string containing the search results.
"""
t0 = time.perf_counter()
min_sim = 0.6
query_embedding = await ctx.deps.embedder.embed_query(query)
t_embed = time.perf_counter()
search_results = await ctx.deps.vector_search.async_run_query(
deployed_index_id=settings.index_deployed_id,
query=list(query_embedding.embeddings[0]),
limit=5,
)
t_search = time.perf_counter()
if search_results:
max_sim = max(r["distance"] for r in search_results)
cutoff = max_sim * 0.9
search_results = [
s
for s in search_results
if s["distance"] > cutoff and s["distance"] > min_sim
]
logger.info(
"conocimiento.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
formatted_results = [
f"<document {i} name={result['id']}>\n"
f"{result['content']}\n"
f"</document {i}>"
for i, result in enumerate(search_results, start=1)
]
return "\n".join(formatted_results)
if __name__ == "__main__":
deps = Deps(
vector_search=settings.vector_search,
embedder=settings.embedder,
)
agent.to_cli_sync(deps=deps)

View File

@@ -1,92 +0,0 @@
"""Application settings loaded from YAML and environment variables."""
import os
from functools import cached_property
from pydantic_ai import Embedder
from pydantic_ai.providers.google import GoogleProvider
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
class Settings(BaseSettings):
"""Application settings loaded from config.yaml and env vars."""
project_id: str
location: str
bucket: str
agent_name: str
agent_instructions: str
agent_language_model: str
agent_embedding_model: str
agent_thinking: int
index_name: str
index_deployed_id: str
index_endpoint: str
index_dimensions: int
index_machine_type: str = "e2-standard-16"
index_origin: str
index_destination: str
index_chunk_limit: int
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
@cached_property
def provider(self) -> GoogleProvider:
"""Return a Google provider configured for Vertex AI."""
return GoogleProvider(
project=self.project_id,
location=self.location,
)
@cached_property
def vector_search(self) -> GoogleCloudVectorSearch:
"""Return a configured vector search client."""
vs = GoogleCloudVectorSearch(
project_id=self.project_id,
location=self.location,
bucket=self.bucket,
index_name=self.index_name,
)
vs.load_index_endpoint(self.index_endpoint)
return vs
@cached_property
def embedder(self) -> Embedder:
"""Return an embedder configured for the agent's embedding model."""
from pydantic_ai.embeddings.google import GoogleEmbeddingModel # noqa: PLC0415
model = GoogleEmbeddingModel(
self.agent_embedding_model,
provider=self.provider,
)
return Embedder(model)
settings = Settings.model_validate({})

View File

@@ -1 +0,0 @@
"""File storage provider implementations."""

View File

@@ -1,56 +0,0 @@
"""Abstract base class for file storage providers."""
from abc import ABC, abstractmethod
from typing import BinaryIO
class BaseFileStorage(ABC):
"""Abstract base class for a remote file processor.
Defines the interface for listing and processing files from
a remote source.
"""
@abstractmethod
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to the remote source.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the file in remote storage.
content_type: The content type of the file.
"""
...
@abstractmethod
def list_files(self, path: str | None = None) -> list[str]:
"""List files from a remote location.
Args:
path: Path to a specific file or directory. If None,
recursively lists all files in the bucket.
Returns:
A list of file paths.
"""
...
@abstractmethod
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file from the remote source as a file-like object.
Args:
file_name: The name of the file to retrieve.
Returns:
A file-like object containing the file data.
"""
...

View File

@@ -1,188 +0,0 @@
"""Google Cloud Storage file storage implementation."""
import asyncio
import io
import logging
from typing import BinaryIO
import aiohttp
from gcloud.aio.storage import Storage
from google.cloud import storage
from rag_eval.file_storage.base import BaseFileStorage
logger = logging.getLogger(__name__)
HTTP_TOO_MANY_REQUESTS = 429
HTTP_SERVER_ERROR = 500
class GoogleCloudFileStorage(BaseFileStorage):
"""File storage backed by Google Cloud Storage."""
def __init__(self, bucket: str) -> None: # noqa: D107
self.bucket_name = bucket
self.storage_client = storage.Client()
self.bucket_client = self.storage_client.bucket(self.bucket_name)
self._aio_session: aiohttp.ClientSession | None = None
self._aio_storage: Storage | None = None
self._cache: dict[str, bytes] = {}
def upload_file(
self,
file_path: str,
destination_blob_name: str,
content_type: str | None = None,
) -> None:
"""Upload a file to Cloud Storage.
Args:
file_path: The local path to the file to upload.
destination_blob_name: Name of the blob in the bucket.
content_type: The content type of the file.
"""
blob = self.bucket_client.blob(destination_blob_name)
blob.upload_from_filename(
file_path,
content_type=content_type,
if_generation_match=0,
)
self._cache.pop(destination_blob_name, None)
def list_files(self, path: str | None = None) -> list[str]:
"""List all files at the given path in the bucket.
If path is None, recursively lists all files.
Args:
path: Prefix to filter files by.
Returns:
A list of blob names.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
return [blob.name for blob in blobs]
def get_file_stream(self, file_name: str) -> BinaryIO:
"""Get a file as a file-like object, using cache.
Args:
file_name: The blob name to retrieve.
Returns:
A BytesIO stream with the file contents.
"""
if file_name not in self._cache:
blob = self.bucket_client.blob(file_name)
self._cache[file_name] = blob.download_as_bytes()
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def _get_aio_storage(self) -> Storage:
if self._aio_storage is None:
self._aio_storage = Storage(
session=self._get_aio_session(),
)
return self._aio_storage
async def async_get_file_stream(
self, file_name: str, max_retries: int = 3,
) -> BinaryIO:
"""Get a file asynchronously with retry on transient errors.
Args:
file_name: The blob name to retrieve.
max_retries: Maximum number of retry attempts.
Returns:
A BytesIO stream with the file contents.
Raises:
TimeoutError: If all retry attempts fail.
"""
if file_name in self._cache:
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
return file_stream
storage_client = self._get_aio_storage()
last_exception: Exception | None = None
for attempt in range(max_retries):
try:
self._cache[file_name] = await storage_client.download(
self.bucket_name, file_name,
)
file_stream = io.BytesIO(self._cache[file_name])
file_stream.name = file_name
except TimeoutError as exc:
last_exception = exc
logger.warning(
"Timeout downloading gs://%s/%s (attempt %d/%d)",
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
except aiohttp.ClientResponseError as exc:
last_exception = exc
if (
exc.status == HTTP_TOO_MANY_REQUESTS
or exc.status >= HTTP_SERVER_ERROR
):
logger.warning(
"HTTP %d downloading gs://%s/%s "
"(attempt %d/%d)",
exc.status,
self.bucket_name,
file_name,
attempt + 1,
max_retries,
)
else:
raise
else:
return file_stream
if attempt < max_retries - 1:
delay = 0.5 * (2**attempt)
await asyncio.sleep(delay)
msg = (
f"Failed to download gs://{self.bucket_name}/{file_name} "
f"after {max_retries} attempts"
)
raise TimeoutError(msg) from last_exception
def delete_files(self, path: str) -> None:
"""Delete all files at the given path in the bucket.
Args:
path: Prefix of blobs to delete.
"""
blobs = self.storage_client.list_blobs(
self.bucket_name, prefix=path,
)
for blob in blobs:
blob.delete()
self._cache.pop(blob.name, None)

View File

@@ -1,47 +0,0 @@
"""Structured logging configuration using structlog."""
import logging
import sys
import structlog
def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
"""Configure structlog with JSON or console output."""
shared_processors: list[structlog.types.Processor] = [
structlog.contextvars.merge_contextvars,
structlog.stdlib.add_log_level,
structlog.stdlib.add_logger_name,
structlog.processors.TimeStamper(fmt="iso"),
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
]
if json:
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
structlog.processors.JSONRenderer(),
],
)
else:
formatter = structlog.stdlib.ProcessorFormatter(
processors=[
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
structlog.dev.ConsoleRenderer(),
],
)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
root = logging.getLogger()
root.handlers.clear()
root.addHandler(handler)
root.setLevel(level)
structlog.configure(
processors=shared_processors,
logger_factory=structlog.stdlib.LoggerFactory(),
wrapper_class=structlog.stdlib.BoundLogger,
cache_logger_on_first_use=True,
)

View File

@@ -1,61 +0,0 @@
"""FastAPI server exposing the RAG agent endpoint."""
import time
from typing import Literal
from uuid import uuid4
import structlog
from fastapi import FastAPI
from pydantic import BaseModel
from rag_eval.agent import Deps, agent
from rag_eval.config import settings
from rag_eval.logging import setup_logging
logger = structlog.get_logger(__name__)
setup_logging()
app = FastAPI(title="RAG Agent")
class Message(BaseModel):
"""A single chat message."""
role: Literal["system", "user", "assistant"]
content: str
class AgentRequest(BaseModel):
"""Request body for the agent endpoint."""
messages: list[Message]
class AgentResponse(BaseModel):
"""Response body from the agent endpoint."""
response: str
@app.post("/agent")
async def run_agent(request: AgentRequest) -> AgentResponse:
"""Run the RAG agent with the provided messages."""
request_id = uuid4().hex[:8]
structlog.contextvars.clear_contextvars()
structlog.contextvars.bind_contextvars(request_id=request_id)
prompt = request.messages[-1].content
logger.info("request.start", prompt_length=len(prompt))
t0 = time.perf_counter()
deps = Deps(
vector_search=settings.vector_search,
embedder=settings.embedder,
)
result = await agent.run(prompt, deps=deps)
elapsed = round((time.perf_counter() - t0) * 1000, 1)
logger.info("request.end", elapsed_ms=elapsed)
return AgentResponse(response=result.output)

View File

@@ -1 +0,0 @@
"""Vector search provider implementations."""

View File

@@ -1,68 +0,0 @@
"""Abstract base class for vector search providers."""
from abc import ABC, abstractmethod
from typing import Any, TypedDict
class SearchResult(TypedDict):
"""A single vector search result."""
id: str
distance: float
content: str
class BaseVectorSearch(ABC):
"""Abstract base class for a vector search provider.
This class defines the standard interface for creating a vector search
index and running queries against it.
"""
@abstractmethod
def create_index(
self, name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Create a new vector search index with the provided content.
Args:
name: The desired name for the new index.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def update_index(
self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401
) -> None:
"""Update an existing vector search index with new content.
Args:
index_name: The name of the index to update.
content_path: Path to the data used to populate the index.
**kwargs: Additional provider-specific arguments.
"""
...
@abstractmethod
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
...

View File

@@ -1,310 +0,0 @@
"""Google Cloud Vertex AI Vector Search implementation."""
import asyncio
from collections.abc import Sequence
from typing import Any
from uuid import uuid4
import aiohttp
import google.auth
import google.auth.credentials
import google.auth.transport.requests
from gcloud.aio.auth import Token
from google.cloud import aiplatform
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
class GoogleCloudVectorSearch(BaseVectorSearch):
"""A vector search provider using Vertex AI Vector Search."""
def __init__(
self,
project_id: str,
location: str,
bucket: str,
index_name: str | None = None,
) -> None:
"""Initialize the GoogleCloudVectorSearch client.
Args:
project_id: The Google Cloud project ID.
location: The Google Cloud location (e.g., 'us-central1').
bucket: The GCS bucket to use for file storage.
index_name: The name of the index.
"""
aiplatform.init(project=project_id, location=location)
self.project_id = project_id
self.location = location
self.storage = GoogleCloudFileStorage(bucket=bucket)
self.index_name = index_name
self._credentials: google.auth.credentials.Credentials | None = None
self._aio_session: aiohttp.ClientSession | None = None
self._async_token: Token | None = None
def _get_auth_headers(self) -> dict[str, str]:
if self._credentials is None:
self._credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
if not self._credentials.token or self._credentials.expired:
self._credentials.refresh(
google.auth.transport.requests.Request(),
)
return {
"Authorization": f"Bearer {self._credentials.token}",
"Content-Type": "application/json",
}
async def _async_get_auth_headers(self) -> dict[str, str]:
if self._async_token is None:
self._async_token = Token(
session=self._get_aio_session(),
scopes=[
"https://www.googleapis.com/auth/cloud-platform",
],
)
access_token = await self._async_token.get()
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def _get_aio_session(self) -> aiohttp.ClientSession:
if self._aio_session is None or self._aio_session.closed:
connector = aiohttp.TCPConnector(
limit=300, limit_per_host=50,
)
timeout = aiohttp.ClientTimeout(total=60)
self._aio_session = aiohttp.ClientSession(
timeout=timeout, connector=connector,
)
return self._aio_session
def create_index(
self,
name: str,
content_path: str,
*,
dimensions: int = 3072,
approximate_neighbors_count: int = 150,
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
**kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Create a new Vertex AI Vector Search index.
Args:
name: The display name for the new index.
content_path: GCS URI to the embeddings JSON file.
dimensions: Number of dimensions in embedding vectors.
approximate_neighbors_count: Neighbors to find per vector.
distance_measure_type: The distance measure to use.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
display_name=name,
contents_delta_uri=content_path,
dimensions=dimensions,
approximate_neighbors_count=approximate_neighbors_count,
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
leaf_node_embedding_count=1000,
leaf_nodes_to_search_percent=10,
)
self.index = index
def update_index(
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
) -> None:
"""Update an existing Vertex AI Vector Search index.
Args:
index_name: The resource name of the index to update.
content_path: GCS URI to the new embeddings JSON file.
**kwargs: Additional arguments.
"""
index = aiplatform.MatchingEngineIndex(index_name=index_name)
index.update_embeddings(
contents_delta_uri=content_path,
)
self.index = index
def deploy_index(
self,
index_name: str,
machine_type: str = "e2-standard-2",
) -> None:
"""Deploy a Vertex AI Vector Search index to an endpoint.
Args:
index_name: The name of the index to deploy.
machine_type: The machine type for the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=f"{index_name}-endpoint",
public_endpoint_enabled=True,
)
index_endpoint.deploy_index(
index=self.index,
deployed_index_id=(
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
),
machine_type=machine_type,
)
self.index_endpoint = index_endpoint
def load_index_endpoint(self, endpoint_name: str) -> None:
"""Load an existing Vertex AI Vector Search index endpoint.
Args:
endpoint_name: The resource name of the index endpoint.
"""
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
endpoint_name,
)
if not self.index_endpoint.public_endpoint_domain_name:
msg = (
"The index endpoint does not have a public endpoint. "
"Ensure the endpoint is configured for public access."
)
raise ValueError(msg)
def run_query(
self,
deployed_index_id: str,
query: list[float],
limit: int,
) -> list[SearchResult]:
"""Run a similarity search query against the deployed index.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
response = self.index_endpoint.find_neighbors(
deployed_index_id=deployed_index_id,
queries=[query],
num_neighbors=limit,
)
results = []
for neighbor in response[0]:
file_path = (
f"{self.index_name}/contents/{neighbor.id}.md"
)
content = (
self.storage.get_file_stream(file_path)
.read()
.decode("utf-8")
)
results.append(
SearchResult(
id=neighbor.id,
distance=float(neighbor.distance or 0),
content=content,
),
)
return results
async def async_run_query(
self,
deployed_index_id: str,
query: Sequence[float],
limit: int,
) -> list[SearchResult]:
"""Run an async similarity search via the REST API.
Args:
deployed_index_id: The ID of the deployed index.
query: The embedding vector for the search query.
limit: Maximum number of nearest neighbors to return.
Returns:
A list of matched items with id, distance, and content.
"""
domain = self.index_endpoint.public_endpoint_domain_name
endpoint_id = self.index_endpoint.name.split("/")[-1]
url = (
f"https://{domain}/v1/projects/{self.project_id}"
f"/locations/{self.location}"
f"/indexEndpoints/{endpoint_id}:findNeighbors"
)
payload = {
"deployed_index_id": deployed_index_id,
"queries": [
{
"datapoint": {"feature_vector": list(query)},
"neighbor_count": limit,
},
],
}
headers = await self._async_get_auth_headers()
session = self._get_aio_session()
async with session.post(
url, json=payload, headers=headers,
) as response:
response.raise_for_status()
data = await response.json()
neighbors = (
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
)
content_tasks = []
for neighbor in neighbors:
datapoint_id = neighbor["datapoint"]["datapointId"]
file_path = (
f"{self.index_name}/contents/{datapoint_id}.md"
)
content_tasks.append(
self.storage.async_get_file_stream(file_path),
)
file_streams = await asyncio.gather(*content_tasks)
results: list[SearchResult] = []
for neighbor, stream in zip(
neighbors, file_streams, strict=True,
):
results.append(
SearchResult(
id=neighbor["datapoint"]["datapointId"],
distance=neighbor["distance"],
content=stream.read().decode("utf-8"),
),
)
return results
def delete_index(self, index_name: str) -> None:
"""Delete a Vertex AI Vector Search index.
Args:
index_name: The resource name of the index.
"""
index = aiplatform.MatchingEngineIndex(index_name)
index.delete()
def delete_index_endpoint(
self, index_endpoint_name: str,
) -> None:
"""Delete a Vertex AI Vector Search index endpoint.
Args:
index_endpoint_name: The resource name of the endpoint.
"""
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name,
)
index_endpoint.undeploy_all()
index_endpoint.delete(force=True)

View File

@@ -4,7 +4,3 @@ import os
# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")
from .agent import root_agent
__all__ = ["root_agent"]

29
src/va_agent/agent.py Normal file
View File

@@ -0,0 +1,29 @@
"""ADK agent with vector search RAG tool."""
from google import genai
from google.adk.agents.llm_agent import Agent
from google.adk.runners import Runner
from google.adk.tools.mcp_tool import McpToolset
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.config import settings
from va_agent.session import FirestoreSessionService
connection_params = SseConnectionParams(url=settings.mcp_remote_url)
toolset = McpToolset(connection_params=connection_params)
agent = Agent(
model=settings.agent_model,
name=settings.agent_name,
instruction=settings.agent_instructions,
tools=[toolset],
)
session_service = FirestoreSessionService(
db=AsyncClient(database=settings.firestore_db),
compaction_token_threshold=10_000,
genai_client=genai.Client(),
)
runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)

53
src/va_agent/config.py Normal file
View File

@@ -0,0 +1,53 @@
"""Configuration helper for ADK agent."""
import os
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
YamlConfigSettingsSource,
)
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
class AgentSettings(BaseSettings):
"""Settings for ADK agent with vector search."""
google_cloud_project: str
google_cloud_location: str
# Agent configuration
agent_name: str
agent_instructions: str
agent_model: str
# Firestore configuration
firestore_db: str
# MCP configuration
mcp_remote_url: str
model_config = SettingsConfigDict(
yaml_file=CONFIG_FILE_PATH,
extra="ignore", # Ignore extra fields from config.yaml
)
@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
) -> tuple[PydanticBaseSettingsSource, ...]:
"""Use env vars and YAML as settings sources."""
return (
env_settings,
YamlConfigSettingsSource(settings_cls),
)
settings = AgentSettings.model_validate({})

10
src/va_agent/server.py Normal file
View File

@@ -0,0 +1,10 @@
"""FastAPI server exposing the RAG agent endpoint.
NOTE: This file is a stub. The rag_eval module was removed in the
lean MCP implementation. This file is kept for reference but is not
functional.
"""
from fastapi import FastAPI
app = FastAPI(title="RAG Agent")

582
src/va_agent/session.py Normal file
View File

@@ -0,0 +1,582 @@
"""Firestore-backed session service for Google ADK."""
from __future__ import annotations
import asyncio
import logging
import time
import uuid
from typing import TYPE_CHECKING, Any, override
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.sessions import _session_util
from google.adk.sessions.base_session_service import (
BaseSessionService,
GetSessionConfig,
ListSessionsResponse,
)
from google.adk.sessions.session import Session
from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_transaction import async_transactional
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part
if TYPE_CHECKING:
from google import genai
from google.cloud.firestore_v1.async_client import AsyncClient
logger = logging.getLogger("google_adk." + __name__)
_COMPACTION_LOCK_TTL = 300 # seconds
@async_transactional
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
"""Atomically claim the compaction lock if it is free or stale."""
snapshot = await session_ref.get(transaction=transaction)
if not snapshot.exists:
return False
data = snapshot.to_dict() or {}
lock_time = data.get("compaction_lock")
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
return False
transaction.update(session_ref, {"compaction_lock": time.time()})
return True
class FirestoreSessionService(BaseSessionService):
"""A Firestore-backed implementation of BaseSessionService.
Firestore document layout (given ``collection_prefix="adk"``)::
adk_app_states/{app_name}
→ app-scoped state key/values
adk_user_states/{app_name}__{user_id}
→ user-scoped state key/values
adk_sessions/{app_name}__{user_id}__{session_id}
{app_name, user_id, session_id, state: {…}, last_update_time}
└─ events/{event_id} → serialised Event
"""
def __init__( # noqa: PLR0913
self,
*,
db: AsyncClient,
collection_prefix: str = "adk",
compaction_token_threshold: int | None = None,
compaction_model: str = "gemini-2.5-flash",
compaction_keep_recent: int = 10,
genai_client: genai.Client | None = None,
) -> None:
"""Initialize FirestoreSessionService.
Args:
db: Firestore async client
collection_prefix: Prefix for Firestore collections
compaction_token_threshold: Token count threshold for compaction
compaction_model: Model to use for summarization
compaction_keep_recent: Number of recent events to keep
genai_client: GenAI client for compaction summaries
"""
if compaction_token_threshold is not None and genai_client is None:
msg = "genai_client is required when compaction_token_threshold is set."
raise ValueError(msg)
self._db = db
self._prefix = collection_prefix
self._compaction_threshold = compaction_token_threshold
self._compaction_model = compaction_model
self._compaction_keep_recent = compaction_keep_recent
self._genai_client = genai_client
self._compaction_locks: dict[str, asyncio.Lock] = {}
self._active_tasks: set[asyncio.Task] = set()
# ------------------------------------------------------------------
# Document-reference helpers
# ------------------------------------------------------------------
def _app_state_ref(self, app_name: str) -> Any:
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
return self._db.collection(f"{self._prefix}_user_states").document(
f"{app_name}__{user_id}"
)
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._db.collection(f"{self._prefix}_sessions").document(
f"{app_name}__{user_id}__{session_id}"
)
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
return self._session_ref(app_name, user_id, session_id).collection("events")
# ------------------------------------------------------------------
# State helpers
# ------------------------------------------------------------------
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
snap = await self._app_state_ref(app_name).get()
return snap.to_dict() or {} if snap.exists else {}
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
snap = await self._user_state_ref(app_name, user_id).get()
return snap.to_dict() or {} if snap.exists else {}
@staticmethod
def _merge_state(
app_state: dict[str, Any],
user_state: dict[str, Any],
session_state: dict[str, Any],
) -> dict[str, Any]:
merged = dict(session_state)
for key, value in app_state.items():
merged[State.APP_PREFIX + key] = value
for key, value in user_state.items():
merged[State.USER_PREFIX + key] = value
return merged
# ------------------------------------------------------------------
# Compaction helpers
# ------------------------------------------------------------------
@staticmethod
def _events_to_text(events: list[Event]) -> str:
lines: list[str] = []
for event in events:
if event.content and event.content.parts:
text = "".join(p.text or "" for p in event.content.parts)
if text:
role = "User" if event.author == "user" else "Assistant"
lines.append(f"{role}: {text}")
return "\n\n".join(lines)
async def _generate_summary(
self, existing_summary: str, events: list[Event]
) -> str:
conversation_text = self._events_to_text(events)
previous = (
f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
if existing_summary
else ""
)
prompt = (
"Summarize the following conversation between a user and an "
"assistant. Preserve:\n"
"- Key decisions and conclusions\n"
"- User preferences and requirements\n"
"- Important facts, names, and numbers\n"
"- The overall topic and direction of the conversation\n"
"- Any pending tasks or open questions\n\n"
f"{previous}"
f"Conversation:\n{conversation_text}\n\n"
"Provide a clear, comprehensive summary."
)
if self._genai_client is None:
msg = "genai_client is required for compaction"
raise RuntimeError(msg)
response = await self._genai_client.aio.models.generate_content(
model=self._compaction_model,
contents=prompt,
)
return response.text or ""
async def _compact_session(self, session: Session) -> None:
app_name = session.app_name
user_id = session.user_id
session_id = session.id
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref.order_by("timestamp")
event_docs = await query.get()
if len(event_docs) <= self._compaction_keep_recent:
return
all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
events_to_summarize = all_events[: -self._compaction_keep_recent]
session_snap = await self._session_ref(app_name, user_id, session_id).get()
existing_summary = (session_snap.to_dict() or {}).get(
"conversation_summary", ""
)
try:
summary = await self._generate_summary(
existing_summary, events_to_summarize
)
except Exception:
logger.exception("Compaction summary generation failed; skipping.")
return
# Write summary BEFORE deleting events so a crash between the two
# steps leaves safe duplication rather than data loss.
await self._session_ref(app_name, user_id, session_id).update(
{"conversation_summary": summary}
)
docs_to_delete = event_docs[: -self._compaction_keep_recent]
for i in range(0, len(docs_to_delete), 500):
batch = self._db.batch()
for doc in docs_to_delete[i : i + 500]:
batch.delete(doc.reference)
await batch.commit()
logger.info(
"Compacted session %s: summarised %d events, kept %d.",
session_id,
len(docs_to_delete),
self._compaction_keep_recent,
)
async def _guarded_compact(self, session: Session) -> None:
"""Run compaction in the background with per-session locking."""
key = f"{session.app_name}__{session.user_id}__{session.id}"
lock = self._compaction_locks.setdefault(key, asyncio.Lock())
if lock.locked():
logger.debug("Compaction already running locally for %s; skipping.", key)
return
async with lock:
session_ref = self._session_ref(
session.app_name, session.user_id, session.id
)
try:
transaction = self._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
except Exception:
logger.exception("Failed to claim compaction lock for %s", key)
return
if not claimed:
logger.debug(
"Compaction lock held by another instance for %s; skipping.",
key,
)
return
try:
await self._compact_session(session)
except Exception:
logger.exception("Background compaction failed for %s", key)
finally:
try:
await session_ref.update({"compaction_lock": None})
except Exception:
logger.exception("Failed to release compaction lock for %s", key)
async def close(self) -> None:
"""Await all in-flight compaction tasks. Call before shutdown."""
if self._active_tasks:
await asyncio.gather(*self._active_tasks, return_exceptions=True)
# ------------------------------------------------------------------
# BaseSessionService implementation
# ------------------------------------------------------------------
@override
async def create_session(
self,
*,
app_name: str,
user_id: str,
state: dict[str, Any] | None = None,
session_id: str | None = None,
) -> Session:
if session_id and session_id.strip():
session_id = session_id.strip()
existing = await self._session_ref(app_name, user_id, session_id).get()
if existing.exists:
msg = f"Session with id {session_id} already exists."
raise AlreadyExistsError(msg)
else:
session_id = str(uuid.uuid4())
state_deltas = _session_util.extract_state_delta(state) # type: ignore[attr-defined]
app_state_delta = state_deltas["app"]
user_state_delta = state_deltas["user"]
session_state = state_deltas["session"]
write_coros: list = []
if app_state_delta:
write_coros.append(
self._app_state_ref(app_name).set(app_state_delta, merge=True)
)
if user_state_delta:
write_coros.append(
self._user_state_ref(app_name, user_id).set(
user_state_delta, merge=True
)
)
now = time.time()
write_coros.append(
self._session_ref(app_name, user_id, session_id).set(
{
"app_name": app_name,
"user_id": user_id,
"session_id": session_id,
"state": session_state or {},
"last_update_time": now,
}
)
)
await asyncio.gather(*write_coros)
app_state, user_state = await asyncio.gather(
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
merged = self._merge_state(app_state, user_state, session_state or {})
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
last_update_time=now,
)
@override
async def get_session(
self,
*,
app_name: str,
user_id: str,
session_id: str,
config: GetSessionConfig | None = None,
) -> Session | None:
snap = await self._session_ref(app_name, user_id, session_id).get()
if not snap.exists:
return None
session_data = snap.to_dict()
# Build events query
events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref
if config and config.after_timestamp:
query = query.where(
filter=FieldFilter("timestamp", ">=", config.after_timestamp)
)
query = query.order_by("timestamp")
event_docs, app_state, user_state = await asyncio.gather(
query.get(),
self._get_app_state(app_name),
self._get_user_state(app_name, user_id),
)
events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
if config and config.num_recent_events:
events = events[-config.num_recent_events :]
# Prepend conversation summary as synthetic context events
conversation_summary = session_data.get("conversation_summary")
if conversation_summary:
summary_event = Event(
id="summary-context",
author="user",
content=Content(
role="user",
parts=[
Part(
text=(
"[Conversation context from previous"
" messages]\n"
f"{conversation_summary}"
)
)
],
),
timestamp=0.0,
invocation_id="compaction-summary",
)
ack_event = Event(
id="summary-ack",
author=app_name,
content=Content(
role="model",
parts=[
Part(
text=(
"Understood, I have the context from our"
" previous conversation and will continue"
" accordingly."
)
)
],
),
timestamp=0.001,
invocation_id="compaction-summary",
)
events = [summary_event, ack_event, *events]
# Merge scoped state
merged = self._merge_state(app_state, user_state, session_data.get("state", {}))
return Session(
app_name=app_name,
user_id=user_id,
id=session_id,
state=merged,
events=events,
last_update_time=session_data.get("last_update_time", 0.0),
)
@override
async def list_sessions(
self, *, app_name: str, user_id: str | None = None
) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where(
filter=FieldFilter("app_name", "==", app_name)
)
if user_id is not None:
query = query.where(filter=FieldFilter("user_id", "==", user_id))
docs = await query.get()
if not docs:
return ListSessionsResponse()
doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]
# Pre-fetch app state and all distinct user states in parallel
unique_user_ids = list({d["user_id"] for d in doc_dicts})
app_state, *user_states = await asyncio.gather(
self._get_app_state(app_name),
*(self._get_user_state(app_name, uid) for uid in unique_user_ids),
)
user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))
sessions: list[Session] = []
for data in doc_dicts:
s_user_id = data["user_id"]
merged = self._merge_state(
app_state,
user_state_cache[s_user_id],
data.get("state", {}),
)
sessions.append(
Session(
app_name=app_name,
user_id=s_user_id,
id=data["session_id"],
state=merged,
events=[],
last_update_time=data.get("last_update_time", 0.0),
)
)
return ListSessionsResponse(sessions=sessions)
@override
async def delete_session(
self, *, app_name: str, user_id: str, session_id: str
) -> None:
ref = self._session_ref(app_name, user_id, session_id)
await self._db.recursive_delete(ref)
@override
async def append_event(self, session: Session, event: Event) -> Event:
if event.partial:
return event
t0 = time.monotonic()
app_name = session.app_name
user_id = session.user_id
session_id = session.id
# Base class: strips temp state, applies delta to in-memory session,
# appends event to session.events
event = await super().append_event(session=session, event=event)
session.last_update_time = event.timestamp
# Persist event document
event_data = event.model_dump(mode="json", exclude_none=True)
await (
self._events_col(app_name, user_id, session_id)
.document(event.id)
.set(event_data)
)
# Persist state deltas
session_ref = self._session_ref(app_name, user_id, session_id)
if event.actions and event.actions.state_delta:
state_deltas = _session_util.extract_state_delta(event.actions.state_delta)
write_coros: list = []
if state_deltas["app"]:
write_coros.append(
self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
)
if state_deltas["user"]:
write_coros.append(
self._user_state_ref(app_name, user_id).set(
state_deltas["user"], merge=True
)
)
if state_deltas["session"]:
field_updates: dict[str, Any] = {
FieldPath("state", k).to_api_repr(): v
for k, v in state_deltas["session"].items()
}
field_updates["last_update_time"] = event.timestamp
write_coros.append(session_ref.update(field_updates))
else:
write_coros.append(
session_ref.update({"last_update_time": event.timestamp})
)
await asyncio.gather(*write_coros)
else:
await session_ref.update({"last_update_time": event.timestamp})
# Log token usage
if event.usage_metadata:
meta = event.usage_metadata
logger.info(
"Token usage for session %s event %s: "
"prompt=%s, candidates=%s, total=%s",
session_id,
event.id,
meta.prompt_token_count,
meta.candidates_token_count,
meta.total_token_count,
)
# Trigger compaction if total token count exceeds threshold
if (
self._compaction_threshold is not None
and event.usage_metadata
and event.usage_metadata.total_token_count
and event.usage_metadata.total_token_count >= self._compaction_threshold
):
logger.info(
"Compaction triggered for session %s: "
"total_token_count=%d >= threshold=%d",
session_id,
event.usage_metadata.total_token_count,
self._compaction_threshold,
)
task = asyncio.create_task(self._guarded_compact(session))
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
elapsed = time.monotonic() - t0
logger.info(
"append_event completed for session %s event %s in %.3fs",
session_id,
event.id,
elapsed,
)
return event

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@

35
tests/conftest.py Normal file
View File

@@ -0,0 +1,35 @@
"""Shared fixtures for Firestore session service tests."""
from __future__ import annotations
import os
import uuid
import pytest
import pytest_asyncio
from google.cloud.firestore_v1.async_client import AsyncClient
from va_agent.session import FirestoreSessionService
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153")
@pytest_asyncio.fixture
async def db():
return AsyncClient(project="test-project")
@pytest_asyncio.fixture
async def service(db: AsyncClient):
prefix = f"test_{uuid.uuid4().hex[:8]}"
return FirestoreSessionService(db=db, collection_prefix=prefix)
@pytest.fixture
def app_name():
return f"app_{uuid.uuid4().hex[:8]}"
@pytest.fixture
def user_id():
return f"user_{uuid.uuid4().hex[:8]}"

515
tests/test_compaction.py Normal file
View File

@@ -0,0 +1,515 @@
"""Tests for conversation compaction in FirestoreSessionService."""
from __future__ import annotations
import asyncio
import time
from unittest.mock import AsyncMock, MagicMock, patch
import uuid
import pytest
import pytest_asyncio
from google import genai
from google.adk.events.event import Event
from google.cloud.firestore_v1.async_client import AsyncClient
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
from va_agent.session import FirestoreSessionService, _try_claim_compaction_txn
pytestmark = pytest.mark.asyncio
@pytest_asyncio.fixture
async def mock_genai_client():
client = MagicMock(spec=genai.Client)
response = MagicMock()
response.text = "Summary of the conversation so far."
client.aio.models.generate_content = AsyncMock(return_value=response)
return client
@pytest_asyncio.fixture
async def compaction_service(db: AsyncClient, mock_genai_client):
prefix = f"test_{uuid.uuid4().hex[:8]}"
return FirestoreSessionService(
db=db,
collection_prefix=prefix,
compaction_token_threshold=100,
compaction_keep_recent=2,
genai_client=mock_genai_client,
)
# ------------------------------------------------------------------
# __init__ validation
# ------------------------------------------------------------------
class TestCompactionInit:
async def test_requires_genai_client(self, db):
with pytest.raises(ValueError, match="genai_client is required"):
FirestoreSessionService(
db=db,
compaction_token_threshold=1000,
)
async def test_no_threshold_no_client_ok(self, db):
svc = FirestoreSessionService(db=db)
assert svc._compaction_threshold is None
# ------------------------------------------------------------------
# Compaction trigger
# ------------------------------------------------------------------
class TestCompactionTrigger:
async def test_compaction_triggered_above_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Add 5 events, last one with usage_metadata above threshold
base = time.time()
for i in range(4):
e = Event(
author="user" if i % 2 == 0 else app_name,
content=Content(
role="user" if i % 2 == 0 else "model",
parts=[Part(text=f"message {i}")],
),
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# This event crosses the threshold
trigger_event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="final response")]
),
timestamp=base + 4,
invocation_id="inv-4",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=200,
),
)
await compaction_service.append_event(session, trigger_event)
await compaction_service.close()
# Summary generation should have been called
mock_genai_client.aio.models.generate_content.assert_called_once()
# Fetch session: should have summary + only keep_recent events
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
# 2 synthetic summary events + 2 kept real events
assert len(fetched.events) == 4
assert fetched.events[0].id == "summary-context"
assert fetched.events[1].id == "summary-ack"
assert "Summary of the conversation" in fetched.events[0].content.parts[0].text
async def test_no_compaction_below_threshold(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="short reply")]
),
timestamp=time.time(),
invocation_id="inv-1",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=50,
),
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_no_compaction_without_usage_metadata(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
mock_genai_client.aio.models.generate_content.assert_not_called()
# ------------------------------------------------------------------
# Compaction with too few events (nothing to compact)
# ------------------------------------------------------------------
class TestCompactionEdgeCases:
async def test_skip_when_fewer_events_than_keep_recent(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Only 2 events, keep_recent=2 → nothing to summarize
for i in range(2):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Trigger compaction manually even though threshold wouldn't fire
await compaction_service._compact_session(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_summary_generation_failure_is_non_fatal(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Make summary generation fail
mock_genai_client.aio.models.generate_content = AsyncMock(
side_effect=RuntimeError("API error")
)
# Should not raise
await compaction_service._compact_session(session)
# All events should still be present
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 5
# ------------------------------------------------------------------
# get_session with summary
# ------------------------------------------------------------------
class TestGetSessionWithSummary:
async def test_no_summary_no_synthetic_events(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(
role="user", parts=[Part(text="hello")]
),
timestamp=time.time(),
invocation_id="inv-1",
)
await compaction_service.append_event(session, event)
fetched = await compaction_service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 1
assert fetched.events[0].author == "user"
# ------------------------------------------------------------------
# _events_to_text
# ------------------------------------------------------------------
class TestEventsToText:
async def test_formats_user_and_assistant(self):
events = [
Event(
author="user",
content=Content(
role="user", parts=[Part(text="Hi there")]
),
timestamp=1.0,
invocation_id="inv-1",
),
Event(
author="bot",
content=Content(
role="model", parts=[Part(text="Hello!")]
),
timestamp=2.0,
invocation_id="inv-2",
),
]
text = FirestoreSessionService._events_to_text(events)
assert "User: Hi there" in text
assert "Assistant: Hello!" in text
async def test_skips_events_without_text(self):
events = [
Event(
author="user",
timestamp=1.0,
invocation_id="inv-1",
),
]
text = FirestoreSessionService._events_to_text(events)
assert text == ""
# ------------------------------------------------------------------
# Firestore distributed lock
# ------------------------------------------------------------------
class TestCompactionLock:
async def test_claim_and_release(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
# Claim the lock
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
assert claimed is True
# Lock is now held — second claim should fail
transaction2 = compaction_service._db.transaction()
claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
assert claimed2 is False
# Release the lock
await session_ref.update({"compaction_lock": None})
# Can claim again after release
transaction3 = compaction_service._db.transaction()
claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
assert claimed3 is True
async def test_stale_lock_can_be_reclaimed(
self, compaction_service, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
# Set a stale lock (older than TTL)
await session_ref.update({"compaction_lock": time.time() - 600})
# Should be able to reclaim a stale lock
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, session_ref)
assert claimed is True
async def test_claim_nonexistent_session(self, compaction_service):
ref = compaction_service._session_ref("no_app", "no_user", "no_id")
transaction = compaction_service._db.transaction()
claimed = await _try_claim_compaction_txn(transaction, ref)
assert claimed is False
# ------------------------------------------------------------------
# Guarded compact
# ------------------------------------------------------------------
class TestGuardedCompact:
async def test_local_lock_skips_concurrent(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Hold the in-process lock so _guarded_compact skips
key = f"{app_name}__{user_id}__{session.id}"
lock = compaction_service._compaction_locks.setdefault(
key, asyncio.Lock()
)
async with lock:
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_firestore_lock_held_skips(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
# Set a fresh Firestore lock (simulating another instance)
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
await session_ref.update({"compaction_lock": time.time()})
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_claim_failure_logs_and_skips(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
with patch(
"va_agent.session._try_claim_compaction_txn",
side_effect=RuntimeError("Firestore down"),
):
await compaction_service._guarded_compact(session)
mock_genai_client.aio.models.generate_content.assert_not_called()
async def test_compaction_failure_releases_lock(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
# Make _compact_session raise an unhandled exception
with patch.object(
compaction_service,
"_compact_session",
side_effect=RuntimeError("unexpected crash"),
):
await compaction_service._guarded_compact(session)
# Lock should be released even after failure
session_ref = compaction_service._session_ref(
app_name, user_id, session.id
)
snap = await session_ref.get()
assert snap.to_dict().get("compaction_lock") is None
async def test_lock_release_failure_is_non_fatal(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
original_session_ref = compaction_service._session_ref
def patched_session_ref(an, uid, sid):
ref = original_session_ref(an, uid, sid)
original_update = ref.update
async def failing_update(data):
if "compaction_lock" in data:
raise RuntimeError("Firestore write failed")
return await original_update(data)
ref.update = failing_update
return ref
with patch.object(
compaction_service,
"_session_ref",
side_effect=patched_session_ref,
):
# Should not raise despite lock release failure
await compaction_service._guarded_compact(session)
# ------------------------------------------------------------------
# close()
# ------------------------------------------------------------------
class TestClose:
async def test_close_no_tasks(self, compaction_service):
await compaction_service.close()
async def test_close_awaits_tasks(
self, compaction_service, mock_genai_client, app_name, user_id
):
session = await compaction_service.create_session(
app_name=app_name, user_id=user_id
)
base = time.time()
for i in range(4):
e = Event(
author="user",
content=Content(
role="user", parts=[Part(text=f"msg {i}")]
),
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await compaction_service.append_event(session, e)
trigger = Event(
author=app_name,
content=Content(
role="model", parts=[Part(text="trigger")]
),
timestamp=base + 4,
invocation_id="inv-4",
usage_metadata=GenerateContentResponseUsageMetadata(
total_token_count=200,
),
)
await compaction_service.append_event(session, trigger)
assert len(compaction_service._active_tasks) > 0
await compaction_service.close()
assert len(compaction_service._active_tasks) == 0

View File

@@ -0,0 +1,428 @@
"""Tests for FirestoreSessionService against the Firestore emulator."""
from __future__ import annotations
import time
import uuid
import pytest
from google.adk.errors.already_exists_error import AlreadyExistsError
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.sessions.base_session_service import GetSessionConfig
from google.genai.types import Content, Part
pytestmark = pytest.mark.asyncio
# ------------------------------------------------------------------
# create_session
# ------------------------------------------------------------------
class TestCreateSession:
async def test_auto_generates_id(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
assert session.id
assert session.app_name == app_name
assert session.user_id == user_id
assert session.last_update_time > 0
async def test_custom_id(self, service, app_name, user_id):
sid = "my-custom-session"
session = await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
assert session.id == sid
async def test_duplicate_id_raises(self, service, app_name, user_id):
sid = "dup-session"
await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
with pytest.raises(AlreadyExistsError):
await service.create_session(
app_name=app_name, user_id=user_id, session_id=sid
)
async def test_session_state(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"count": 42},
)
assert session.state["count"] == 42
async def test_scoped_state(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={
"app:global_flag": True,
"user:lang": "es",
"local_key": "val",
},
)
assert session.state["app:global_flag"] is True
assert session.state["user:lang"] == "es"
assert session.state["local_key"] == "val"
async def test_temp_state_not_persisted(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"temp:scratch": "gone", "keep": "yes"},
)
retrieved = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert "temp:scratch" not in retrieved.state
assert retrieved.state["keep"] == "yes"
# ------------------------------------------------------------------
# get_session
# ------------------------------------------------------------------
class TestGetSession:
async def test_nonexistent_returns_none(self, service, app_name, user_id):
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id="nope"
)
assert result is None
async def test_roundtrip(self, service, app_name, user_id):
created = await service.create_session(
app_name=app_name,
user_id=user_id,
state={"foo": "bar"},
)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=created.id
)
assert fetched is not None
assert fetched.id == created.id
assert fetched.state["foo"] == "bar"
assert fetched.last_update_time == pytest.approx(
created.last_update_time, abs=0.01
)
async def test_returns_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(parts=[Part(text="hello")]),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 1
assert fetched.events[0].author == "user"
async def test_num_recent_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(5):
e = Event(
author="user",
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session.id,
config=GetSessionConfig(num_recent_events=2),
)
assert len(fetched.events) == 2
async def test_after_timestamp(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
base = time.time()
for i in range(3):
e = Event(
author="user",
timestamp=base + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name,
user_id=user_id,
session_id=session.id,
config=GetSessionConfig(after_timestamp=base + 1),
)
assert len(fetched.events) == 2
# ------------------------------------------------------------------
# list_sessions
# ------------------------------------------------------------------
class TestListSessions:
async def test_empty(self, service, app_name, user_id):
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
assert resp.sessions == [] or resp.sessions is None
async def test_returns_created_sessions(
self, service, app_name, user_id
):
s1 = await service.create_session(
app_name=app_name, user_id=user_id
)
s2 = await service.create_session(
app_name=app_name, user_id=user_id
)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
ids = {s.id for s in resp.sessions}
assert s1.id in ids
assert s2.id in ids
async def test_filter_by_user(self, service, app_name):
uid1 = f"user_{uuid.uuid4().hex[:8]}"
uid2 = f"user_{uuid.uuid4().hex[:8]}"
await service.create_session(app_name=app_name, user_id=uid1)
await service.create_session(app_name=app_name, user_id=uid2)
resp = await service.list_sessions(
app_name=app_name, user_id=uid1
)
assert len(resp.sessions) == 1
assert resp.sessions[0].user_id == uid1
async def test_sessions_have_merged_state(
self, service, app_name, user_id
):
await service.create_session(
app_name=app_name,
user_id=user_id,
state={"app:shared": "yes", "local": "val"},
)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
s = resp.sessions[0]
assert s.state["app:shared"] == "yes"
assert s.state["local"] == "val"
async def test_sessions_have_no_events(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user", timestamp=time.time(), invocation_id="inv-1"
)
await service.append_event(session, event)
resp = await service.list_sessions(
app_name=app_name, user_id=user_id
)
assert resp.sessions[0].events == []
# ------------------------------------------------------------------
# delete_session
# ------------------------------------------------------------------
class TestDeleteSession:
async def test_delete(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
await service.delete_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert result is None
async def test_delete_removes_events(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user", timestamp=time.time(), invocation_id="inv-1"
)
await service.append_event(session, event)
await service.delete_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
result = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert result is None
# ------------------------------------------------------------------
# append_event
# ------------------------------------------------------------------
class TestAppendEvent:
async def test_basic(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
content=Content(parts=[Part(text="hi")]),
timestamp=time.time(),
invocation_id="inv-1",
)
returned = await service.append_event(session, event)
assert returned.id == event.id
assert returned.timestamp > 0
async def test_partial_event_not_persisted(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="user",
partial=True,
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 0
async def test_session_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"counter": 1}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["counter"] == 1
async def test_app_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"app:version": "2.0"}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["app:version"] == "2.0"
async def test_user_state_delta(self, service, app_name, user_id):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"user:pref": "dark"}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.state["user:pref"] == "dark"
async def test_updates_last_update_time(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
original_time = session.last_update_time
event = Event(
author="user",
timestamp=time.time() + 10,
invocation_id="inv-1",
)
await service.append_event(session, event)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert fetched.last_update_time > original_time
async def test_multiple_events_accumulate(
self, service, app_name, user_id
):
session = await service.create_session(
app_name=app_name, user_id=user_id
)
for i in range(3):
e = Event(
author="user",
content=Content(parts=[Part(text=f"msg {i}")]),
timestamp=time.time() + i,
invocation_id=f"inv-{i}",
)
await service.append_event(session, e)
fetched = await service.get_session(
app_name=app_name, user_id=user_id, session_id=session.id
)
assert len(fetched.events) == 3
async def test_app_state_shared_across_sessions(
self, service, app_name, user_id
):
s1 = await service.create_session(
app_name=app_name, user_id=user_id
)
event = Event(
author="agent",
actions=EventActions(state_delta={"app:shared_val": 99}),
timestamp=time.time(),
invocation_id="inv-1",
)
await service.append_event(s1, event)
s2 = await service.create_session(
app_name=app_name, user_id=user_id
)
assert s2.state["app:shared_val"] == 99

1937
uv.lock generated

File diff suppressed because it is too large Load Diff