Merge pull request 'Replace local Tool call with MCP implementation' (#2) from mcp into main
Reviewed-on: va/legacy-rag#2
This commit was merged in pull request #2.
This commit is contained in:
@@ -1,2 +1,3 @@
|
|||||||
Use `uv` for project management.
|
Use `uv` for project management.
|
||||||
Use `uv run ruff check` for linting, and `uv run ty check` for type checking
|
Use `uv run ruff check` for linting, and `uv run ty check` for type checking
|
||||||
|
Use `uv run pytest` for testing.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "rag-eval"
|
name = "va-agent"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@@ -9,28 +9,20 @@ authors = [
|
|||||||
]
|
]
|
||||||
requires-python = "~=3.12.0"
|
requires-python = "~=3.12.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp>=3.13.3",
|
|
||||||
"gcloud-aio-auth>=5.4.2",
|
|
||||||
"gcloud-aio-storage>=9.6.1",
|
|
||||||
"google-adk>=1.14.1",
|
"google-adk>=1.14.1",
|
||||||
"google-cloud-aiplatform>=1.126.1",
|
"google-cloud-firestore>=2.23.0",
|
||||||
"google-cloud-storage>=2.19.0",
|
|
||||||
"pydantic-settings[yaml]>=2.13.1",
|
"pydantic-settings[yaml]>=2.13.1",
|
||||||
"structlog>=25.5.0",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
ragops = "rag_eval.cli:app"
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["uv_build>=0.8.3,<0.9.0"]
|
requires = ["uv_build>=0.8.3,<0.9.0"]
|
||||||
build-backend = "uv_build"
|
build-backend = "uv_build"
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
"clai>=1.62.0",
|
|
||||||
"marimo>=0.20.1",
|
|
||||||
"pytest>=8.4.1",
|
"pytest>=8.4.1",
|
||||||
|
"pytest-asyncio>=1.3.0",
|
||||||
|
"pytest-sugar>=1.1.1",
|
||||||
"ruff>=0.12.10",
|
"ruff>=0.12.10",
|
||||||
"ty>=0.0.1a19",
|
"ty>=0.0.1a19",
|
||||||
]
|
]
|
||||||
@@ -43,4 +35,10 @@ exclude = ["scripts"]
|
|||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = ['ALL']
|
select = ['ALL']
|
||||||
ignore = ['D203', 'D213', 'COM812']
|
ignore = [
|
||||||
|
'D203', # one-blank-line-before-class
|
||||||
|
'D213', # multi-line-summary-second-line
|
||||||
|
'COM812', # missing-trailing-comma
|
||||||
|
'ANN401', # dynamically-typed-any
|
||||||
|
'ERA001', # commented-out-code
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,59 +0,0 @@
|
|||||||
# Copyright 2026 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""ADK agent with vector search RAG tool."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from google.adk.agents.llm_agent import Agent
|
|
||||||
|
|
||||||
from .config_helper import settings
|
|
||||||
from .vector_search_tool import VectorSearchTool
|
|
||||||
|
|
||||||
# Set environment variables for Google GenAI Client to use Vertex AI
|
|
||||||
os.environ["GOOGLE_CLOUD_PROJECT"] = settings.project_id
|
|
||||||
os.environ["GOOGLE_CLOUD_LOCATION"] = settings.location
|
|
||||||
|
|
||||||
# Create vector search tool with configuration
|
|
||||||
vector_search_tool = VectorSearchTool(
|
|
||||||
name='conocimiento',
|
|
||||||
description='Search the vector index for company products and services information',
|
|
||||||
embedder=settings.embedder,
|
|
||||||
project_id=settings.project_id,
|
|
||||||
location=settings.location,
|
|
||||||
bucket=settings.bucket,
|
|
||||||
index_name=settings.index_name,
|
|
||||||
index_endpoint=settings.index_endpoint,
|
|
||||||
index_deployed_id=settings.index_deployed_id,
|
|
||||||
similarity_top_k=5,
|
|
||||||
min_similarity_threshold=0.6,
|
|
||||||
relative_threshold_factor=0.9,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create agent with vector search tool
|
|
||||||
# Configure model with Vertex AI fully qualified path
|
|
||||||
model_path = (
|
|
||||||
f'projects/{settings.project_id}/locations/{settings.location}/'
|
|
||||||
f'publishers/google/models/{settings.agent_language_model}'
|
|
||||||
)
|
|
||||||
|
|
||||||
root_agent = Agent(
|
|
||||||
model=model_path,
|
|
||||||
name=settings.agent_name,
|
|
||||||
description='A helpful assistant for user questions.',
|
|
||||||
instruction=settings.agent_instructions,
|
|
||||||
tools=[vector_search_tool],
|
|
||||||
)
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
"""Abstract base class for vector search providers."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResult(TypedDict):
|
|
||||||
"""A single vector search result."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
distance: float
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorSearch(ABC):
|
|
||||||
"""Abstract base class for a vector search provider.
|
|
||||||
|
|
||||||
This class defines the standard interface for creating a vector search
|
|
||||||
index and running queries against it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_index(
|
|
||||||
self, name: str, content_path: str, **kwargs: Any # noqa: ANN401
|
|
||||||
) -> None:
|
|
||||||
"""Create a new vector search index with the provided content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The desired name for the new index.
|
|
||||||
content_path: Path to the data used to populate the index.
|
|
||||||
**kwargs: Additional provider-specific arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_index(
|
|
||||||
self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401
|
|
||||||
) -> None:
|
|
||||||
"""Update an existing vector search index with new content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The name of the index to update.
|
|
||||||
content_path: Path to the data used to populate the index.
|
|
||||||
**kwargs: Additional provider-specific arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: list[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run a similarity search query against the index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
# Copyright 2026 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Configuration helper for ADK agent with vector search."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
import vertexai
|
|
||||||
from pydantic_settings import (
|
|
||||||
BaseSettings,
|
|
||||||
PydanticBaseSettingsSource,
|
|
||||||
SettingsConfigDict,
|
|
||||||
YamlConfigSettingsSource,
|
|
||||||
)
|
|
||||||
from vertexai.language_models import TextEmbeddingModel
|
|
||||||
|
|
||||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EmbeddingResult:
|
|
||||||
"""Result from embedding a query."""
|
|
||||||
|
|
||||||
embeddings: list[list[float]]
|
|
||||||
|
|
||||||
|
|
||||||
class VertexAIEmbedder:
|
|
||||||
"""Embedder using Vertex AI TextEmbeddingModel."""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str, project_id: str, location: str) -> None:
|
|
||||||
"""Initialize the embedder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_name: Name of the embedding model (e.g., 'text-embedding-004')
|
|
||||||
project_id: GCP project ID
|
|
||||||
location: GCP location
|
|
||||||
|
|
||||||
"""
|
|
||||||
vertexai.init(project=project_id, location=location)
|
|
||||||
self.model = TextEmbeddingModel.from_pretrained(model_name)
|
|
||||||
|
|
||||||
async def embed_query(self, query: str) -> EmbeddingResult:
|
|
||||||
"""Embed a single query string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to embed
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EmbeddingResult with embeddings list
|
|
||||||
|
|
||||||
"""
|
|
||||||
embeddings = self.model.get_embeddings([query])
|
|
||||||
return EmbeddingResult(embeddings=[list(embeddings[0].values)])
|
|
||||||
|
|
||||||
|
|
||||||
class AgentSettings(BaseSettings):
|
|
||||||
"""Settings for ADK agent with vector search."""
|
|
||||||
|
|
||||||
# Google Cloud settings
|
|
||||||
project_id: str
|
|
||||||
location: str
|
|
||||||
bucket: str
|
|
||||||
|
|
||||||
# Agent configuration
|
|
||||||
agent_name: str
|
|
||||||
agent_instructions: str
|
|
||||||
agent_language_model: str
|
|
||||||
agent_embedding_model: str
|
|
||||||
|
|
||||||
# Vector index configuration
|
|
||||||
index_name: str
|
|
||||||
index_deployed_id: str
|
|
||||||
index_endpoint: str
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
|
||||||
yaml_file=CONFIG_FILE_PATH,
|
|
||||||
extra="ignore", # Ignore extra fields from config.yaml
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def settings_customise_sources(
|
|
||||||
cls,
|
|
||||||
settings_cls: type[BaseSettings],
|
|
||||||
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
env_settings: PydanticBaseSettingsSource,
|
|
||||||
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
|
||||||
"""Use env vars and YAML as settings sources."""
|
|
||||||
return (
|
|
||||||
env_settings,
|
|
||||||
YamlConfigSettingsSource(settings_cls),
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def embedder(self) -> VertexAIEmbedder:
|
|
||||||
"""Return an embedder configured for the agent's embedding model."""
|
|
||||||
return VertexAIEmbedder(
|
|
||||||
model_name=self.agent_embedding_model,
|
|
||||||
project_id=self.project_id,
|
|
||||||
location=self.location,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
settings = AgentSettings.model_validate({})
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""File storage provider implementations."""
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""Abstract base class for file storage providers."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import BinaryIO
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFileStorage(ABC):
|
|
||||||
"""Abstract base class for a remote file processor.
|
|
||||||
|
|
||||||
Defines the interface for listing and processing files from
|
|
||||||
a remote source.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def upload_file(
|
|
||||||
self,
|
|
||||||
file_path: str,
|
|
||||||
destination_blob_name: str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a file to the remote source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The local path to the file to upload.
|
|
||||||
destination_blob_name: Name of the file in remote storage.
|
|
||||||
content_type: The content type of the file.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_files(self, path: str | None = None) -> list[str]:
|
|
||||||
"""List files from a remote location.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path to a specific file or directory. If None,
|
|
||||||
recursively lists all files in the bucket.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of file paths.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
|
||||||
"""Get a file from the remote source as a file-like object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The name of the file to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A file-like object containing the file data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
"""Google Cloud Storage file storage implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from typing import BinaryIO
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio.storage import Storage
|
|
||||||
from google.cloud import storage
|
|
||||||
|
|
||||||
from .base import BaseFileStorage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
HTTP_TOO_MANY_REQUESTS = 429
|
|
||||||
HTTP_SERVER_ERROR = 500
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudFileStorage(BaseFileStorage):
|
|
||||||
"""File storage backed by Google Cloud Storage."""
|
|
||||||
|
|
||||||
def __init__(self, bucket: str) -> None: # noqa: D107
|
|
||||||
self.bucket_name = bucket
|
|
||||||
|
|
||||||
self.storage_client = storage.Client()
|
|
||||||
self.bucket_client = self.storage_client.bucket(self.bucket_name)
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._aio_storage: Storage | None = None
|
|
||||||
self._cache: dict[str, bytes] = {}
|
|
||||||
|
|
||||||
def upload_file(
|
|
||||||
self,
|
|
||||||
file_path: str,
|
|
||||||
destination_blob_name: str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a file to Cloud Storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The local path to the file to upload.
|
|
||||||
destination_blob_name: Name of the blob in the bucket.
|
|
||||||
content_type: The content type of the file.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blob = self.bucket_client.blob(destination_blob_name)
|
|
||||||
blob.upload_from_filename(
|
|
||||||
file_path,
|
|
||||||
content_type=content_type,
|
|
||||||
if_generation_match=0,
|
|
||||||
)
|
|
||||||
self._cache.pop(destination_blob_name, None)
|
|
||||||
|
|
||||||
def list_files(self, path: str | None = None) -> list[str]:
|
|
||||||
"""List all files at the given path in the bucket.
|
|
||||||
|
|
||||||
If path is None, recursively lists all files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Prefix to filter files by.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of blob names.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blobs = self.storage_client.list_blobs(
|
|
||||||
self.bucket_name, prefix=path,
|
|
||||||
)
|
|
||||||
return [blob.name for blob in blobs]
|
|
||||||
|
|
||||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
|
||||||
"""Get a file as a file-like object, using cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The blob name to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A BytesIO stream with the file contents.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if file_name not in self._cache:
|
|
||||||
blob = self.bucket_client.blob(file_name)
|
|
||||||
self._cache[file_name] = blob.download_as_bytes()
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300, limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout, connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def _get_aio_storage(self) -> Storage:
|
|
||||||
if self._aio_storage is None:
|
|
||||||
self._aio_storage = Storage(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
)
|
|
||||||
return self._aio_storage
|
|
||||||
|
|
||||||
async def async_get_file_stream(
|
|
||||||
self, file_name: str, max_retries: int = 3,
|
|
||||||
) -> BinaryIO:
|
|
||||||
"""Get a file asynchronously with retry on transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The blob name to retrieve.
|
|
||||||
max_retries: Maximum number of retry attempts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A BytesIO stream with the file contents.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TimeoutError: If all retry attempts fail.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if file_name in self._cache:
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
storage_client = self._get_aio_storage()
|
|
||||||
last_exception: Exception | None = None
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
self._cache[file_name] = await storage_client.download(
|
|
||||||
self.bucket_name, file_name,
|
|
||||||
)
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
except TimeoutError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
logger.warning(
|
|
||||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
except aiohttp.ClientResponseError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
if (
|
|
||||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
|
||||||
or exc.status >= HTTP_SERVER_ERROR
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"HTTP %d downloading gs://%s/%s "
|
|
||||||
"(attempt %d/%d)",
|
|
||||||
exc.status,
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
delay = 0.5 * (2**attempt)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
msg = (
|
|
||||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
|
||||||
f"after {max_retries} attempts"
|
|
||||||
)
|
|
||||||
raise TimeoutError(msg) from last_exception
|
|
||||||
|
|
||||||
def delete_files(self, path: str) -> None:
|
|
||||||
"""Delete all files at the given path in the bucket.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Prefix of blobs to delete.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blobs = self.storage_client.list_blobs(
|
|
||||||
self.bucket_name, prefix=path,
|
|
||||||
)
|
|
||||||
for blob in blobs:
|
|
||||||
blob.delete()
|
|
||||||
self._cache.pop(blob.name, None)
|
|
||||||
@@ -1,176 +0,0 @@
|
|||||||
# Copyright 2026 Google LLC
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""A retrieval tool that uses Vertex AI Vector Search (not RAG Engine)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from google.adk.tools.tool_context import ToolContext
|
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
from .vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
from google.adk.tools.retrieval.base_retrieval_tool import BaseRetrievalTool
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .config_helper import VertexAIEmbedder
|
|
||||||
|
|
||||||
logger = logging.getLogger('google_adk.' + __name__)
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchTool(BaseRetrievalTool):
|
|
||||||
"""A retrieval tool using Vertex AI Vector Search (not RAG Engine).
|
|
||||||
|
|
||||||
This tool uses GoogleCloudVectorSearch to query a vector index directly,
|
|
||||||
which is useful when Vertex AI RAG Engine is not available in your GCP project.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
name: str,
|
|
||||||
description: str,
|
|
||||||
embedder: VertexAIEmbedder,
|
|
||||||
project_id: str,
|
|
||||||
location: str,
|
|
||||||
bucket: str,
|
|
||||||
index_name: str,
|
|
||||||
index_endpoint: str,
|
|
||||||
index_deployed_id: str,
|
|
||||||
similarity_top_k: int = 5,
|
|
||||||
min_similarity_threshold: float = 0.6,
|
|
||||||
relative_threshold_factor: float = 0.9,
|
|
||||||
):
|
|
||||||
"""Initialize the VectorSearchTool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Tool name for function declaration
|
|
||||||
description: Tool description for LLM
|
|
||||||
embedder: Embedder instance for query embedding
|
|
||||||
project_id: GCP project ID
|
|
||||||
location: GCP location (e.g., 'us-central1')
|
|
||||||
bucket: GCS bucket for content storage
|
|
||||||
index_name: Vector search index name
|
|
||||||
index_endpoint: Resource name of index endpoint
|
|
||||||
index_deployed_id: Deployed index ID
|
|
||||||
similarity_top_k: Number of results to retrieve (default: 5)
|
|
||||||
min_similarity_threshold: Minimum similarity score 0.0-1.0 (default: 0.6)
|
|
||||||
relative_threshold_factor: Factor of max similarity for dynamic filtering (default: 0.9)
|
|
||||||
"""
|
|
||||||
super().__init__(name=name, description=description)
|
|
||||||
|
|
||||||
self.embedder = embedder
|
|
||||||
self.index_endpoint = index_endpoint
|
|
||||||
self.index_deployed_id = index_deployed_id
|
|
||||||
self.similarity_top_k = similarity_top_k
|
|
||||||
self.min_similarity_threshold = min_similarity_threshold
|
|
||||||
self.relative_threshold_factor = relative_threshold_factor
|
|
||||||
|
|
||||||
# Initialize vector search (endpoint loaded lazily on first use)
|
|
||||||
self.vector_search = GoogleCloudVectorSearch(
|
|
||||||
project_id=project_id,
|
|
||||||
location=location,
|
|
||||||
bucket=bucket,
|
|
||||||
index_name=index_name,
|
|
||||||
)
|
|
||||||
self._endpoint_loaded = False
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
'VectorSearchTool initialized with index=%s, deployed_id=%s',
|
|
||||||
index_name,
|
|
||||||
index_deployed_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
@override
|
|
||||||
async def run_async(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
args: dict[str, Any],
|
|
||||||
tool_context: ToolContext,
|
|
||||||
) -> Any:
|
|
||||||
"""Execute vector search with the user's query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Dictionary containing 'query' key
|
|
||||||
tool_context: Tool execution context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted search results as XML-like documents or error message
|
|
||||||
"""
|
|
||||||
query = args['query']
|
|
||||||
logger.debug('VectorSearchTool query: %s', query)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Load index endpoint on first use (lazy loading)
|
|
||||||
if not self._endpoint_loaded:
|
|
||||||
self.vector_search.load_index_endpoint(self.index_endpoint)
|
|
||||||
self._endpoint_loaded = True
|
|
||||||
logger.info('Index endpoint loaded successfully')
|
|
||||||
|
|
||||||
# Embed the query using the configured embedder
|
|
||||||
embedding_result = await self.embedder.embed_query(query)
|
|
||||||
query_embedding = list(embedding_result.embeddings[0])
|
|
||||||
|
|
||||||
# Run vector search
|
|
||||||
search_results = await self.vector_search.async_run_query(
|
|
||||||
deployed_index_id=self.index_deployed_id,
|
|
||||||
query=query_embedding,
|
|
||||||
limit=self.similarity_top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply similarity filtering (dual threshold approach)
|
|
||||||
if search_results:
|
|
||||||
# Dynamic threshold based on max similarity
|
|
||||||
max_similarity = max(r['distance'] for r in search_results)
|
|
||||||
dynamic_cutoff = max_similarity * self.relative_threshold_factor
|
|
||||||
|
|
||||||
# Filter by both absolute and relative thresholds
|
|
||||||
search_results = [
|
|
||||||
result
|
|
||||||
for result in search_results
|
|
||||||
if (
|
|
||||||
result['distance'] > dynamic_cutoff
|
|
||||||
and result['distance'] > self.min_similarity_threshold
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
'VectorSearchTool results: %d documents after filtering',
|
|
||||||
len(search_results),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format results
|
|
||||||
if not search_results:
|
|
||||||
return (
|
|
||||||
f"No matching documents found for query: '{query}' "
|
|
||||||
f'(min_threshold={self.min_similarity_threshold})'
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format as XML-like documents (matching pydantic_ai pattern)
|
|
||||||
formatted_results = [
|
|
||||||
f'<document {i} name={result["id"]}>\n'
|
|
||||||
f'{result["content"]}\n'
|
|
||||||
f'</document {i}>'
|
|
||||||
for i, result in enumerate(search_results, start=1)
|
|
||||||
]
|
|
||||||
|
|
||||||
return '\n'.join(formatted_results)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error('VectorSearchTool error: %s', e, exc_info=True)
|
|
||||||
return f'Error during vector search: {str(e)}'
|
|
||||||
@@ -1,310 +0,0 @@
|
|||||||
"""Google Cloud Vertex AI Vector Search implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import google.auth
|
|
||||||
import google.auth.credentials
|
|
||||||
import google.auth.transport.requests
|
|
||||||
from gcloud.aio.auth import Token
|
|
||||||
from google.cloud import aiplatform
|
|
||||||
|
|
||||||
from .file_storage.google_cloud import GoogleCloudFileStorage
|
|
||||||
from .base import BaseVectorSearch, SearchResult
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudVectorSearch(BaseVectorSearch):
|
|
||||||
"""A vector search provider using Vertex AI Vector Search."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
location: str,
|
|
||||||
bucket: str,
|
|
||||||
index_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the GoogleCloudVectorSearch client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
project_id: The Google Cloud project ID.
|
|
||||||
location: The Google Cloud location (e.g., 'us-central1').
|
|
||||||
bucket: The GCS bucket to use for file storage.
|
|
||||||
index_name: The name of the index.
|
|
||||||
|
|
||||||
"""
|
|
||||||
aiplatform.init(project=project_id, location=location)
|
|
||||||
self.project_id = project_id
|
|
||||||
self.location = location
|
|
||||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
|
||||||
self.index_name = index_name
|
|
||||||
self._credentials: google.auth.credentials.Credentials | None = None
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._async_token: Token | None = None
|
|
||||||
|
|
||||||
def _get_auth_headers(self) -> dict[str, str]:
|
|
||||||
if self._credentials is None:
|
|
||||||
self._credentials, _ = google.auth.default(
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
if not self._credentials.token or self._credentials.expired:
|
|
||||||
self._credentials.refresh(
|
|
||||||
google.auth.transport.requests.Request(),
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {self._credentials.token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
|
||||||
if self._async_token is None:
|
|
||||||
self._async_token = Token(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
scopes=[
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
access_token = await self._async_token.get()
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300, limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout, connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def create_index(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
content_path: str,
|
|
||||||
*,
|
|
||||||
dimensions: int = 3072,
|
|
||||||
approximate_neighbors_count: int = 150,
|
|
||||||
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
|
|
||||||
**kwargs: Any, # noqa: ANN401, ARG002
|
|
||||||
) -> None:
|
|
||||||
"""Create a new Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The display name for the new index.
|
|
||||||
content_path: GCS URI to the embeddings JSON file.
|
|
||||||
dimensions: Number of dimensions in embedding vectors.
|
|
||||||
approximate_neighbors_count: Neighbors to find per vector.
|
|
||||||
distance_measure_type: The distance measure to use.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
|
|
||||||
display_name=name,
|
|
||||||
contents_delta_uri=content_path,
|
|
||||||
dimensions=dimensions,
|
|
||||||
approximate_neighbors_count=approximate_neighbors_count,
|
|
||||||
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
|
|
||||||
leaf_node_embedding_count=1000,
|
|
||||||
leaf_nodes_to_search_percent=10,
|
|
||||||
)
|
|
||||||
self.index = index
|
|
||||||
|
|
||||||
def update_index(
|
|
||||||
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
|
|
||||||
) -> None:
|
|
||||||
"""Update an existing Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The resource name of the index to update.
|
|
||||||
content_path: GCS URI to the new embeddings JSON file.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex(index_name=index_name)
|
|
||||||
index.update_embeddings(
|
|
||||||
contents_delta_uri=content_path,
|
|
||||||
)
|
|
||||||
self.index = index
|
|
||||||
|
|
||||||
def deploy_index(
|
|
||||||
self,
|
|
||||||
index_name: str,
|
|
||||||
machine_type: str = "e2-standard-2",
|
|
||||||
) -> None:
|
|
||||||
"""Deploy a Vertex AI Vector Search index to an endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The name of the index to deploy.
|
|
||||||
machine_type: The machine type for the endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
|
|
||||||
display_name=f"{index_name}-endpoint",
|
|
||||||
public_endpoint_enabled=True,
|
|
||||||
)
|
|
||||||
index_endpoint.deploy_index(
|
|
||||||
index=self.index,
|
|
||||||
deployed_index_id=(
|
|
||||||
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
|
|
||||||
),
|
|
||||||
machine_type=machine_type,
|
|
||||||
)
|
|
||||||
self.index_endpoint = index_endpoint
|
|
||||||
|
|
||||||
def load_index_endpoint(self, endpoint_name: str) -> None:
|
|
||||||
"""Load an existing Vertex AI Vector Search index endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
endpoint_name: The resource name of the index endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
|
||||||
endpoint_name,
|
|
||||||
)
|
|
||||||
if not self.index_endpoint.public_endpoint_domain_name:
|
|
||||||
msg = (
|
|
||||||
"The index endpoint does not have a public endpoint. "
|
|
||||||
"Ensure the endpoint is configured for public access."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
def run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: list[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run a similarity search query against the deployed index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
response = self.index_endpoint.find_neighbors(
|
|
||||||
deployed_index_id=deployed_index_id,
|
|
||||||
queries=[query],
|
|
||||||
num_neighbors=limit,
|
|
||||||
)
|
|
||||||
results = []
|
|
||||||
for neighbor in response[0]:
|
|
||||||
file_path = (
|
|
||||||
f"{self.index_name}/contents/{neighbor.id}.md"
|
|
||||||
)
|
|
||||||
content = (
|
|
||||||
self.storage.get_file_stream(file_path)
|
|
||||||
.read()
|
|
||||||
.decode("utf-8")
|
|
||||||
)
|
|
||||||
results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=neighbor.id,
|
|
||||||
distance=float(neighbor.distance or 0),
|
|
||||||
content=content,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def async_run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: Sequence[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run an async similarity search via the REST API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
domain = self.index_endpoint.public_endpoint_domain_name
|
|
||||||
endpoint_id = self.index_endpoint.name.split("/")[-1]
|
|
||||||
url = (
|
|
||||||
f"https://{domain}/v1/projects/{self.project_id}"
|
|
||||||
f"/locations/{self.location}"
|
|
||||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
|
||||||
)
|
|
||||||
payload = {
|
|
||||||
"deployed_index_id": deployed_index_id,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"datapoint": {"feature_vector": list(query)},
|
|
||||||
"neighbor_count": limit,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = await self._async_get_auth_headers()
|
|
||||||
session = self._get_aio_session()
|
|
||||||
async with session.post(
|
|
||||||
url, json=payload, headers=headers,
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
data = await response.json()
|
|
||||||
|
|
||||||
neighbors = (
|
|
||||||
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
|
||||||
)
|
|
||||||
content_tasks = []
|
|
||||||
for neighbor in neighbors:
|
|
||||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
|
||||||
file_path = (
|
|
||||||
f"{self.index_name}/contents/{datapoint_id}.md"
|
|
||||||
)
|
|
||||||
content_tasks.append(
|
|
||||||
self.storage.async_get_file_stream(file_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
file_streams = await asyncio.gather(*content_tasks)
|
|
||||||
results: list[SearchResult] = []
|
|
||||||
for neighbor, stream in zip(
|
|
||||||
neighbors, file_streams, strict=True,
|
|
||||||
):
|
|
||||||
results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=neighbor["datapoint"]["datapointId"],
|
|
||||||
distance=neighbor["distance"],
|
|
||||||
content=stream.read().decode("utf-8"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def delete_index(self, index_name: str) -> None:
|
|
||||||
"""Delete a Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The resource name of the index.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex(index_name)
|
|
||||||
index.delete()
|
|
||||||
|
|
||||||
def delete_index_endpoint(
|
|
||||||
self, index_endpoint_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Delete a Vertex AI Vector Search index endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_endpoint_name: The resource name of the endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
|
||||||
index_endpoint_name,
|
|
||||||
)
|
|
||||||
index_endpoint.undeploy_all()
|
|
||||||
index_endpoint.delete(force=True)
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from embedder.vertex_ai import VertexAIEmbedder
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
project = os.getenv("GOOGLE_CLOUD_PROJECT")
|
|
||||||
location = os.getenv("GOOGLE_CLOUD_LOCATION")
|
|
||||||
|
|
||||||
MODEL_NAME = "gemini-embedding-001"
|
|
||||||
CONTENT_LIST = [
|
|
||||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
|
||||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
|
||||||
"¿Qué es una hipoteca y cómo funciona?",
|
|
||||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
|
||||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
|
||||||
"¿Qué es la banca en línea y cómo me registro?",
|
|
||||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
|
||||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
|
||||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
|
||||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
|
||||||
]
|
|
||||||
TASK_TYPE = "RETRIEVAL_DOCUMENT"
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
logger.info(f"Initializing GenAI Client for project '{project}' in '{location}'")
|
|
||||||
embedder = VertexAIEmbedder(MODEL_NAME, project, location)
|
|
||||||
|
|
||||||
async def embed_content_task():
|
|
||||||
"""A single task to send one embedding request using the global client."""
|
|
||||||
content_to_embed = random.choice(CONTENT_LIST)
|
|
||||||
await embedder.async_generate_embedding(content_to_embed)
|
|
||||||
|
|
||||||
async def run_test(concurrency: int):
|
|
||||||
"""Continuously calls the embedding API and tracks requests."""
|
|
||||||
total_requests = 0
|
|
||||||
|
|
||||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on model '{MODEL_NAME}'.")
|
|
||||||
logger.info("Press Ctrl+C to stop.")
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Create tasks, passing project_id and location
|
|
||||||
tasks = [embed_content_task() for _ in range(concurrency)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
total_requests += concurrency
|
|
||||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Caught an error. Stopping test.")
|
|
||||||
print("\n--- STATS ---")
|
|
||||||
print(f"Total successful requests: {total_requests}")
|
|
||||||
print(f"Concurrent requests during failure: {concurrency}")
|
|
||||||
print(f"Error Type: {e.__class__.__name__}")
|
|
||||||
print(f"Error Details: {e}")
|
|
||||||
print("-------------")
|
|
||||||
break
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def main(
|
|
||||||
concurrency: int = typer.Option(
|
|
||||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
|
||||||
),
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
asyncio.run(run_test(concurrency))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import random
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
import typer
|
|
||||||
|
|
||||||
CONTENT_LIST = [
|
|
||||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
|
||||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
|
||||||
"¿Qué es una hipoteca y cómo funciona?",
|
|
||||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
|
||||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
|
||||||
"¿Qué es la banca en línea y cómo me registro?",
|
|
||||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
|
||||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
|
||||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
|
||||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
|
||||||
]
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
async def call_rag_endpoint_task(client: httpx.AsyncClient, url: str):
|
|
||||||
"""A single task to send one request to the RAG endpoint."""
|
|
||||||
question = random.choice(CONTENT_LIST)
|
|
||||||
json_payload = {
|
|
||||||
"sessionInfo": {
|
|
||||||
"parameters": {
|
|
||||||
"query": question
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
response = await client.post(url, json=json_payload)
|
|
||||||
response.raise_for_status() # Raise an exception for bad status codes
|
|
||||||
response_data = response.json()
|
|
||||||
response_text = response_data["sessionInfo"]["parameters"]["response"]
|
|
||||||
logger.info(f"Question: {question[:50]}... Response: {response_text[:100]}...")
|
|
||||||
|
|
||||||
async def run_test(concurrency: int, url: str, timeout_seconds: float):
|
|
||||||
"""Continuously calls the RAG endpoint and tracks requests."""
|
|
||||||
total_requests = 0
|
|
||||||
|
|
||||||
logger.info(f"Starting diagnostic test with {concurrency} concurrent requests on endpoint '{url}'.")
|
|
||||||
logger.info(f"Request timeout is set to {timeout_seconds} seconds.")
|
|
||||||
logger.info("Press Ctrl+C to stop.")
|
|
||||||
|
|
||||||
timeout = httpx.Timeout(timeout_seconds)
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
while True:
|
|
||||||
tasks = [call_rag_endpoint_task(client, url) for _ in range(concurrency)]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await asyncio.gather(*tasks)
|
|
||||||
total_requests += concurrency
|
|
||||||
logger.info(f"Successfully completed batch. Total requests so far: {total_requests}")
|
|
||||||
except httpx.TimeoutException as e:
|
|
||||||
logger.error(f"A request timed out: {e.request.method} {e.request.url}")
|
|
||||||
logger.error("Consider increasing the timeout with the --timeout option.")
|
|
||||||
break
|
|
||||||
except httpx.HTTPStatusError as e:
|
|
||||||
logger.error(f"An HTTP error occurred: {e.response.status_code} - {e.request.method} {e.request.url}")
|
|
||||||
logger.error(f"Response body: {e.response.text}")
|
|
||||||
break
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.error(f"A request error occurred: {e.request.method} {e.request.url}")
|
|
||||||
logger.error(f"Error details: {e}")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Caught an unexpected error. Stopping test.")
|
|
||||||
print("\n--- STATS ---")
|
|
||||||
print(f"Total successful requests: {total_requests}")
|
|
||||||
print(f"Concurrent requests during failure: {concurrency}")
|
|
||||||
print(f"Error Type: {e.__class__.__name__}")
|
|
||||||
print(f"Error Details: {e}")
|
|
||||||
print("-------------")
|
|
||||||
break
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def main(
|
|
||||||
concurrency: int = typer.Option(
|
|
||||||
10, "--concurrency", "-c", help="Number of concurrent requests to send in each batch."
|
|
||||||
),
|
|
||||||
url: str = typer.Option(
|
|
||||||
"http://127.0.0.1:8000/sigma-rag", "--url", "-u", help="The URL of the RAG endpoint to test."
|
|
||||||
),
|
|
||||||
timeout_seconds: float = typer.Option(
|
|
||||||
30.0, "--timeout", "-t", help="Request timeout in seconds."
|
|
||||||
)
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
asyncio.run(run_test(concurrency, url, timeout_seconds))
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("\nKeyboard interrupt received. Exiting.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
import concurrent.futures
|
|
||||||
import random
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# URL for the endpoint
|
|
||||||
url = "http://localhost:8000/sigma-rag"
|
|
||||||
|
|
||||||
# List of Spanish banking questions
|
|
||||||
spanish_questions = [
|
|
||||||
"¿Cuáles son los beneficios de una tarjeta de crédito?",
|
|
||||||
"¿Cómo puedo abrir una cuenta de ahorros?",
|
|
||||||
"¿Qué es una hipoteca y cómo funciona?",
|
|
||||||
"¿Cuáles son las tasas de interés para un préstamo personal?",
|
|
||||||
"¿Cómo puedo solicitar un préstamo para un coche?",
|
|
||||||
"¿Qué es la banca en línea y cómo me registro?",
|
|
||||||
"¿Cómo puedo reportar una tarjeta de crédito perdida o robada?",
|
|
||||||
"¿Qué es el phishing y cómo puedo protegerme?",
|
|
||||||
"¿Cuáles son los diferentes tipos de cuentas corrientes que ofrecen?",
|
|
||||||
"¿Cómo puedo transferir dinero a una cuenta internacional?",
|
|
||||||
]
|
|
||||||
|
|
||||||
# A threading Event to signal all threads to stop
|
|
||||||
stop_event = threading.Event()
|
|
||||||
|
|
||||||
def send_request(question, request_id):
|
|
||||||
"""Sends a single request and handles the response."""
|
|
||||||
if stop_event.is_set():
|
|
||||||
return
|
|
||||||
|
|
||||||
data = {"sessionInfo": {"parameters": {"query": question}}}
|
|
||||||
try:
|
|
||||||
response = requests.post(url, json=data)
|
|
||||||
|
|
||||||
if stop_event.is_set():
|
|
||||||
return
|
|
||||||
|
|
||||||
if response.status_code == 500:
|
|
||||||
print(f"Request {request_id}: Received 500 error with question: '{question}'.")
|
|
||||||
print("Stopping stress test.")
|
|
||||||
stop_event.set()
|
|
||||||
else:
|
|
||||||
print(f"Request {request_id}: Successful with status code {response.status_code}.")
|
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
if not stop_event.is_set():
|
|
||||||
print(f"Request {request_id}: An error occurred: {e}")
|
|
||||||
stop_event.set()
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""Runs the stress test with parallel requests."""
|
|
||||||
num_workers = 30 # Number of parallel requests
|
|
||||||
print(f"Starting stress test with {num_workers} parallel workers. Press Ctrl+C to stop.")
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
|
|
||||||
futures = {
|
|
||||||
executor.submit(send_request, random.choice(spanish_questions), i)
|
|
||||||
for i in range(1, num_workers + 1)
|
|
||||||
}
|
|
||||||
request_id_counter = num_workers + 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
while not stop_event.is_set():
|
|
||||||
# Wait for any future to complete
|
|
||||||
done, _ = concurrent.futures.wait(
|
|
||||||
futures, return_when=concurrent.futures.FIRST_COMPLETED
|
|
||||||
)
|
|
||||||
|
|
||||||
for future in done:
|
|
||||||
# Remove the completed future
|
|
||||||
futures.remove(future)
|
|
||||||
|
|
||||||
# If we are not stopping, submit a new one
|
|
||||||
if not stop_event.is_set():
|
|
||||||
futures.add(
|
|
||||||
executor.submit(
|
|
||||||
send_request,
|
|
||||||
random.choice(spanish_questions),
|
|
||||||
request_id_counter,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
request_id_counter += 1
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\nKeyboard interrupt received. Stopping threads.")
|
|
||||||
stop_event.set()
|
|
||||||
|
|
||||||
print("Stress test finished.")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
from typing import Annotated
|
|
||||||
|
|
||||||
import typer
|
|
||||||
from google.cloud import aiplatform
|
|
||||||
|
|
||||||
from rag_eval.config import settings
|
|
||||||
|
|
||||||
app = typer.Typer()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def main(
|
|
||||||
pipeline_spec_path: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--pipeline-spec-path",
|
|
||||||
"-p",
|
|
||||||
help="Path to the compiled pipeline YAML file.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
input_table: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--input-table",
|
|
||||||
"-i",
|
|
||||||
help="Full BigQuery table name for input (e.g., 'project.dataset.table')",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
output_table: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--output-table",
|
|
||||||
"-o",
|
|
||||||
help="Full BigQuery table name for output (e.g., 'project.dataset.table')",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
project_id: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--project-id",
|
|
||||||
help="Google Cloud project ID.",
|
|
||||||
),
|
|
||||||
] = settings.project_id,
|
|
||||||
location: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--location",
|
|
||||||
help="Google Cloud location for the pipeline job.",
|
|
||||||
),
|
|
||||||
] = settings.location,
|
|
||||||
display_name: Annotated[
|
|
||||||
str,
|
|
||||||
typer.Option(
|
|
||||||
"--display-name",
|
|
||||||
help="Display name for the pipeline job.",
|
|
||||||
),
|
|
||||||
] = "search-eval-pipeline-job",
|
|
||||||
):
|
|
||||||
"""Submits a Vertex AI pipeline job."""
|
|
||||||
parameter_values = {
|
|
||||||
"project_id": project_id,
|
|
||||||
"location": location,
|
|
||||||
"input_table": input_table,
|
|
||||||
"output_table": output_table,
|
|
||||||
}
|
|
||||||
|
|
||||||
job = aiplatform.PipelineJob(
|
|
||||||
display_name=display_name,
|
|
||||||
template_path=pipeline_spec_path,
|
|
||||||
pipeline_root=f"gs://{settings.bucket}/pipeline_root",
|
|
||||||
parameter_values=parameter_values,
|
|
||||||
project=project_id,
|
|
||||||
location=location,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Submitting pipeline job with parameters: {parameter_values}")
|
|
||||||
job.submit(
|
|
||||||
service_account="sa-cicd-gitlab@bnt-orquestador-cognitivo-dev.iam.gserviceaccount.com"
|
|
||||||
)
|
|
||||||
print(f"Pipeline job submitted. You can view it at: {job._dashboard_uri()}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
app()
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
from google.cloud import discoveryengine_v1 as discoveryengine
|
|
||||||
|
|
||||||
# TODO(developer): Uncomment these variables before running the sample.
|
|
||||||
project_id = "bnt-orquestador-cognitivo-dev"
|
|
||||||
|
|
||||||
client = discoveryengine.RankServiceClient()
|
|
||||||
|
|
||||||
# The full resource name of the ranking config.
|
|
||||||
# Format: projects/{project_id}/locations/{location}/rankingConfigs/default_ranking_config
|
|
||||||
ranking_config = client.ranking_config_path(
|
|
||||||
project=project_id,
|
|
||||||
location="global",
|
|
||||||
ranking_config="default_ranking_config",
|
|
||||||
)
|
|
||||||
request = discoveryengine.RankRequest(
|
|
||||||
ranking_config=ranking_config,
|
|
||||||
model="semantic-ranker-default@latest",
|
|
||||||
top_n=10,
|
|
||||||
query="What is Google Gemini?",
|
|
||||||
records=[
|
|
||||||
discoveryengine.RankingRecord(
|
|
||||||
id="1",
|
|
||||||
title="Gemini",
|
|
||||||
content="The Gemini zodiac symbol often depicts two figures standing side-by-side.",
|
|
||||||
),
|
|
||||||
discoveryengine.RankingRecord(
|
|
||||||
id="2",
|
|
||||||
title="Gemini",
|
|
||||||
content="Gemini is a cutting edge large language model created by Google.",
|
|
||||||
),
|
|
||||||
discoveryengine.RankingRecord(
|
|
||||||
id="3",
|
|
||||||
title="Gemini Constellation",
|
|
||||||
content="Gemini is a constellation that can be seen in the night sky.",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.rank(request=request)
|
|
||||||
|
|
||||||
# Handle the response
|
|
||||||
print(response)
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
import requests
|
|
||||||
|
|
||||||
# Test the /sigma-rag endpoint
|
|
||||||
url = "http://localhost:8000/sigma-rag"
|
|
||||||
data = {
|
|
||||||
"sessionInfo": {"parameters": {"query": "What are the benefits of a credit card?"}}
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, json=data)
|
|
||||||
|
|
||||||
print("Response from /sigma-rag:")
|
|
||||||
print(response.json())
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""RAG evaluation agent package."""
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Pydantic AI agent with RAG tool for vector search."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from pydantic_ai import Agent, Embedder, RunContext
|
|
||||||
from pydantic_ai.models.google import GoogleModel
|
|
||||||
|
|
||||||
from rag_eval.config import settings
|
|
||||||
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
logger = structlog.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Deps(BaseModel):
|
|
||||||
"""Dependencies injected into the agent at runtime."""
|
|
||||||
|
|
||||||
vector_search: GoogleCloudVectorSearch
|
|
||||||
embedder: Embedder
|
|
||||||
|
|
||||||
model_config = {"arbitrary_types_allowed": True}
|
|
||||||
|
|
||||||
|
|
||||||
model = GoogleModel(
|
|
||||||
settings.agent_language_model,
|
|
||||||
provider=settings.provider,
|
|
||||||
)
|
|
||||||
agent = Agent(
|
|
||||||
model,
|
|
||||||
deps_type=Deps,
|
|
||||||
system_prompt=settings.agent_instructions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@agent.tool
|
|
||||||
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
|
|
||||||
"""Search the vector index for the given query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx: The run context containing dependencies.
|
|
||||||
query: The query to search for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A formatted string containing the search results.
|
|
||||||
|
|
||||||
"""
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
min_sim = 0.6
|
|
||||||
|
|
||||||
query_embedding = await ctx.deps.embedder.embed_query(query)
|
|
||||||
t_embed = time.perf_counter()
|
|
||||||
|
|
||||||
search_results = await ctx.deps.vector_search.async_run_query(
|
|
||||||
deployed_index_id=settings.index_deployed_id,
|
|
||||||
query=list(query_embedding.embeddings[0]),
|
|
||||||
limit=5,
|
|
||||||
)
|
|
||||||
t_search = time.perf_counter()
|
|
||||||
|
|
||||||
if search_results:
|
|
||||||
max_sim = max(r["distance"] for r in search_results)
|
|
||||||
cutoff = max_sim * 0.9
|
|
||||||
search_results = [
|
|
||||||
s
|
|
||||||
for s in search_results
|
|
||||||
if s["distance"] > cutoff and s["distance"] > min_sim
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"conocimiento.timing",
|
|
||||||
embedding_ms=round((t_embed - t0) * 1000, 1),
|
|
||||||
vector_search_ms=round((t_search - t_embed) * 1000, 1),
|
|
||||||
total_ms=round((t_search - t0) * 1000, 1),
|
|
||||||
chunks=[s["id"] for s in search_results],
|
|
||||||
)
|
|
||||||
|
|
||||||
formatted_results = [
|
|
||||||
f"<document {i} name={result['id']}>\n"
|
|
||||||
f"{result['content']}\n"
|
|
||||||
f"</document {i}>"
|
|
||||||
for i, result in enumerate(search_results, start=1)
|
|
||||||
]
|
|
||||||
return "\n".join(formatted_results)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
deps = Deps(
|
|
||||||
vector_search=settings.vector_search,
|
|
||||||
embedder=settings.embedder,
|
|
||||||
)
|
|
||||||
agent.to_cli_sync(deps=deps)
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Application settings loaded from YAML and environment variables."""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from functools import cached_property
|
|
||||||
|
|
||||||
from pydantic_ai import Embedder
|
|
||||||
from pydantic_ai.providers.google import GoogleProvider
|
|
||||||
from pydantic_settings import (
|
|
||||||
BaseSettings,
|
|
||||||
PydanticBaseSettingsSource,
|
|
||||||
SettingsConfigDict,
|
|
||||||
YamlConfigSettingsSource,
|
|
||||||
)
|
|
||||||
|
|
||||||
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
||||||
|
|
||||||
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
"""Application settings loaded from config.yaml and env vars."""
|
|
||||||
|
|
||||||
project_id: str
|
|
||||||
location: str
|
|
||||||
bucket: str
|
|
||||||
|
|
||||||
agent_name: str
|
|
||||||
agent_instructions: str
|
|
||||||
agent_language_model: str
|
|
||||||
agent_embedding_model: str
|
|
||||||
agent_thinking: int
|
|
||||||
|
|
||||||
index_name: str
|
|
||||||
index_deployed_id: str
|
|
||||||
index_endpoint: str
|
|
||||||
index_dimensions: int
|
|
||||||
index_machine_type: str = "e2-standard-16"
|
|
||||||
index_origin: str
|
|
||||||
index_destination: str
|
|
||||||
index_chunk_limit: int
|
|
||||||
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(yaml_file=CONFIG_FILE_PATH)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def settings_customise_sources(
|
|
||||||
cls,
|
|
||||||
settings_cls: type[BaseSettings],
|
|
||||||
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
env_settings: PydanticBaseSettingsSource,
|
|
||||||
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
|
||||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
|
||||||
"""Use env vars and YAML as settings sources."""
|
|
||||||
return (
|
|
||||||
env_settings,
|
|
||||||
YamlConfigSettingsSource(settings_cls),
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def provider(self) -> GoogleProvider:
|
|
||||||
"""Return a Google provider configured for Vertex AI."""
|
|
||||||
return GoogleProvider(
|
|
||||||
project=self.project_id,
|
|
||||||
location=self.location,
|
|
||||||
)
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def vector_search(self) -> GoogleCloudVectorSearch:
|
|
||||||
"""Return a configured vector search client."""
|
|
||||||
vs = GoogleCloudVectorSearch(
|
|
||||||
project_id=self.project_id,
|
|
||||||
location=self.location,
|
|
||||||
bucket=self.bucket,
|
|
||||||
index_name=self.index_name,
|
|
||||||
)
|
|
||||||
vs.load_index_endpoint(self.index_endpoint)
|
|
||||||
return vs
|
|
||||||
|
|
||||||
@cached_property
|
|
||||||
def embedder(self) -> Embedder:
|
|
||||||
"""Return an embedder configured for the agent's embedding model."""
|
|
||||||
from pydantic_ai.embeddings.google import GoogleEmbeddingModel # noqa: PLC0415
|
|
||||||
|
|
||||||
model = GoogleEmbeddingModel(
|
|
||||||
self.agent_embedding_model,
|
|
||||||
provider=self.provider,
|
|
||||||
)
|
|
||||||
return Embedder(model)
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings.model_validate({})
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""File storage provider implementations."""
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""Abstract base class for file storage providers."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import BinaryIO
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFileStorage(ABC):
|
|
||||||
"""Abstract base class for a remote file processor.
|
|
||||||
|
|
||||||
Defines the interface for listing and processing files from
|
|
||||||
a remote source.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def upload_file(
|
|
||||||
self,
|
|
||||||
file_path: str,
|
|
||||||
destination_blob_name: str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a file to the remote source.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The local path to the file to upload.
|
|
||||||
destination_blob_name: Name of the file in remote storage.
|
|
||||||
content_type: The content type of the file.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_files(self, path: str | None = None) -> list[str]:
|
|
||||||
"""List files from a remote location.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path to a specific file or directory. If None,
|
|
||||||
recursively lists all files in the bucket.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of file paths.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
|
||||||
"""Get a file from the remote source as a file-like object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The name of the file to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A file-like object containing the file data.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -1,188 +0,0 @@
|
|||||||
"""Google Cloud Storage file storage implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
from typing import BinaryIO
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from gcloud.aio.storage import Storage
|
|
||||||
from google.cloud import storage
|
|
||||||
|
|
||||||
from rag_eval.file_storage.base import BaseFileStorage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
HTTP_TOO_MANY_REQUESTS = 429
|
|
||||||
HTTP_SERVER_ERROR = 500
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudFileStorage(BaseFileStorage):
|
|
||||||
"""File storage backed by Google Cloud Storage."""
|
|
||||||
|
|
||||||
def __init__(self, bucket: str) -> None: # noqa: D107
|
|
||||||
self.bucket_name = bucket
|
|
||||||
|
|
||||||
self.storage_client = storage.Client()
|
|
||||||
self.bucket_client = self.storage_client.bucket(self.bucket_name)
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._aio_storage: Storage | None = None
|
|
||||||
self._cache: dict[str, bytes] = {}
|
|
||||||
|
|
||||||
def upload_file(
|
|
||||||
self,
|
|
||||||
file_path: str,
|
|
||||||
destination_blob_name: str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Upload a file to Cloud Storage.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: The local path to the file to upload.
|
|
||||||
destination_blob_name: Name of the blob in the bucket.
|
|
||||||
content_type: The content type of the file.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blob = self.bucket_client.blob(destination_blob_name)
|
|
||||||
blob.upload_from_filename(
|
|
||||||
file_path,
|
|
||||||
content_type=content_type,
|
|
||||||
if_generation_match=0,
|
|
||||||
)
|
|
||||||
self._cache.pop(destination_blob_name, None)
|
|
||||||
|
|
||||||
def list_files(self, path: str | None = None) -> list[str]:
|
|
||||||
"""List all files at the given path in the bucket.
|
|
||||||
|
|
||||||
If path is None, recursively lists all files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Prefix to filter files by.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of blob names.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blobs = self.storage_client.list_blobs(
|
|
||||||
self.bucket_name, prefix=path,
|
|
||||||
)
|
|
||||||
return [blob.name for blob in blobs]
|
|
||||||
|
|
||||||
def get_file_stream(self, file_name: str) -> BinaryIO:
|
|
||||||
"""Get a file as a file-like object, using cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The blob name to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A BytesIO stream with the file contents.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if file_name not in self._cache:
|
|
||||||
blob = self.bucket_client.blob(file_name)
|
|
||||||
self._cache[file_name] = blob.download_as_bytes()
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300, limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout, connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def _get_aio_storage(self) -> Storage:
|
|
||||||
if self._aio_storage is None:
|
|
||||||
self._aio_storage = Storage(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
)
|
|
||||||
return self._aio_storage
|
|
||||||
|
|
||||||
async def async_get_file_stream(
|
|
||||||
self, file_name: str, max_retries: int = 3,
|
|
||||||
) -> BinaryIO:
|
|
||||||
"""Get a file asynchronously with retry on transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_name: The blob name to retrieve.
|
|
||||||
max_retries: Maximum number of retry attempts.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A BytesIO stream with the file contents.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TimeoutError: If all retry attempts fail.
|
|
||||||
|
|
||||||
"""
|
|
||||||
if file_name in self._cache:
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
storage_client = self._get_aio_storage()
|
|
||||||
last_exception: Exception | None = None
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
self._cache[file_name] = await storage_client.download(
|
|
||||||
self.bucket_name, file_name,
|
|
||||||
)
|
|
||||||
file_stream = io.BytesIO(self._cache[file_name])
|
|
||||||
file_stream.name = file_name
|
|
||||||
except TimeoutError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
logger.warning(
|
|
||||||
"Timeout downloading gs://%s/%s (attempt %d/%d)",
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
except aiohttp.ClientResponseError as exc:
|
|
||||||
last_exception = exc
|
|
||||||
if (
|
|
||||||
exc.status == HTTP_TOO_MANY_REQUESTS
|
|
||||||
or exc.status >= HTTP_SERVER_ERROR
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"HTTP %d downloading gs://%s/%s "
|
|
||||||
"(attempt %d/%d)",
|
|
||||||
exc.status,
|
|
||||||
self.bucket_name,
|
|
||||||
file_name,
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
return file_stream
|
|
||||||
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
delay = 0.5 * (2**attempt)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
msg = (
|
|
||||||
f"Failed to download gs://{self.bucket_name}/{file_name} "
|
|
||||||
f"after {max_retries} attempts"
|
|
||||||
)
|
|
||||||
raise TimeoutError(msg) from last_exception
|
|
||||||
|
|
||||||
def delete_files(self, path: str) -> None:
|
|
||||||
"""Delete all files at the given path in the bucket.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Prefix of blobs to delete.
|
|
||||||
|
|
||||||
"""
|
|
||||||
blobs = self.storage_client.list_blobs(
|
|
||||||
self.bucket_name, prefix=path,
|
|
||||||
)
|
|
||||||
for blob in blobs:
|
|
||||||
blob.delete()
|
|
||||||
self._cache.pop(blob.name, None)
|
|
||||||
@@ -1,47 +0,0 @@
|
|||||||
"""Structured logging configuration using structlog."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(*, json: bool = True, level: int = logging.INFO) -> None:
|
|
||||||
"""Configure structlog with JSON or console output."""
|
|
||||||
shared_processors: list[structlog.types.Processor] = [
|
|
||||||
structlog.contextvars.merge_contextvars,
|
|
||||||
structlog.stdlib.add_log_level,
|
|
||||||
structlog.stdlib.add_logger_name,
|
|
||||||
structlog.processors.TimeStamper(fmt="iso"),
|
|
||||||
structlog.stdlib.ProcessorFormatter.wrap_for_formatter,
|
|
||||||
]
|
|
||||||
|
|
||||||
if json:
|
|
||||||
formatter = structlog.stdlib.ProcessorFormatter(
|
|
||||||
processors=[
|
|
||||||
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
|
||||||
structlog.processors.JSONRenderer(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
formatter = structlog.stdlib.ProcessorFormatter(
|
|
||||||
processors=[
|
|
||||||
structlog.stdlib.ProcessorFormatter.remove_processors_meta,
|
|
||||||
structlog.dev.ConsoleRenderer(),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
handler = logging.StreamHandler(sys.stdout)
|
|
||||||
handler.setFormatter(formatter)
|
|
||||||
|
|
||||||
root = logging.getLogger()
|
|
||||||
root.handlers.clear()
|
|
||||||
root.addHandler(handler)
|
|
||||||
root.setLevel(level)
|
|
||||||
|
|
||||||
structlog.configure(
|
|
||||||
processors=shared_processors,
|
|
||||||
logger_factory=structlog.stdlib.LoggerFactory(),
|
|
||||||
wrapper_class=structlog.stdlib.BoundLogger,
|
|
||||||
cache_logger_on_first_use=True,
|
|
||||||
)
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
"""FastAPI server exposing the RAG agent endpoint."""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from typing import Literal
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import structlog
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from rag_eval.agent import Deps, agent
|
|
||||||
from rag_eval.config import settings
|
|
||||||
from rag_eval.logging import setup_logging
|
|
||||||
|
|
||||||
logger = structlog.get_logger(__name__)
|
|
||||||
|
|
||||||
setup_logging()
|
|
||||||
|
|
||||||
app = FastAPI(title="RAG Agent")
|
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
|
||||||
"""A single chat message."""
|
|
||||||
|
|
||||||
role: Literal["system", "user", "assistant"]
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRequest(BaseModel):
|
|
||||||
"""Request body for the agent endpoint."""
|
|
||||||
|
|
||||||
messages: list[Message]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentResponse(BaseModel):
|
|
||||||
"""Response body from the agent endpoint."""
|
|
||||||
|
|
||||||
response: str
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/agent")
|
|
||||||
async def run_agent(request: AgentRequest) -> AgentResponse:
|
|
||||||
"""Run the RAG agent with the provided messages."""
|
|
||||||
request_id = uuid4().hex[:8]
|
|
||||||
structlog.contextvars.clear_contextvars()
|
|
||||||
structlog.contextvars.bind_contextvars(request_id=request_id)
|
|
||||||
|
|
||||||
prompt = request.messages[-1].content
|
|
||||||
logger.info("request.start", prompt_length=len(prompt))
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
|
|
||||||
deps = Deps(
|
|
||||||
vector_search=settings.vector_search,
|
|
||||||
embedder=settings.embedder,
|
|
||||||
)
|
|
||||||
result = await agent.run(prompt, deps=deps)
|
|
||||||
|
|
||||||
elapsed = round((time.perf_counter() - t0) * 1000, 1)
|
|
||||||
logger.info("request.end", elapsed_ms=elapsed)
|
|
||||||
|
|
||||||
return AgentResponse(response=result.output)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Vector search provider implementations."""
|
|
||||||
@@ -1,68 +0,0 @@
|
|||||||
"""Abstract base class for vector search providers."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
class SearchResult(TypedDict):
|
|
||||||
"""A single vector search result."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
distance: float
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class BaseVectorSearch(ABC):
|
|
||||||
"""Abstract base class for a vector search provider.
|
|
||||||
|
|
||||||
This class defines the standard interface for creating a vector search
|
|
||||||
index and running queries against it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_index(
|
|
||||||
self, name: str, content_path: str, **kwargs: Any # noqa: ANN401
|
|
||||||
) -> None:
|
|
||||||
"""Create a new vector search index with the provided content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The desired name for the new index.
|
|
||||||
content_path: Path to the data used to populate the index.
|
|
||||||
**kwargs: Additional provider-specific arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def update_index(
|
|
||||||
self, index_name: str, content_path: str, **kwargs: Any # noqa: ANN401
|
|
||||||
) -> None:
|
|
||||||
"""Update an existing vector search index with new content.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The name of the index to update.
|
|
||||||
content_path: Path to the data used to populate the index.
|
|
||||||
**kwargs: Additional provider-specific arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: list[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run a similarity search query against the index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -1,310 +0,0 @@
|
|||||||
"""Google Cloud Vertex AI Vector Search implementation."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import google.auth
|
|
||||||
import google.auth.credentials
|
|
||||||
import google.auth.transport.requests
|
|
||||||
from gcloud.aio.auth import Token
|
|
||||||
from google.cloud import aiplatform
|
|
||||||
|
|
||||||
from rag_eval.file_storage.google_cloud import GoogleCloudFileStorage
|
|
||||||
from rag_eval.vector_search.base import BaseVectorSearch, SearchResult
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleCloudVectorSearch(BaseVectorSearch):
|
|
||||||
"""A vector search provider using Vertex AI Vector Search."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
project_id: str,
|
|
||||||
location: str,
|
|
||||||
bucket: str,
|
|
||||||
index_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Initialize the GoogleCloudVectorSearch client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
project_id: The Google Cloud project ID.
|
|
||||||
location: The Google Cloud location (e.g., 'us-central1').
|
|
||||||
bucket: The GCS bucket to use for file storage.
|
|
||||||
index_name: The name of the index.
|
|
||||||
|
|
||||||
"""
|
|
||||||
aiplatform.init(project=project_id, location=location)
|
|
||||||
self.project_id = project_id
|
|
||||||
self.location = location
|
|
||||||
self.storage = GoogleCloudFileStorage(bucket=bucket)
|
|
||||||
self.index_name = index_name
|
|
||||||
self._credentials: google.auth.credentials.Credentials | None = None
|
|
||||||
self._aio_session: aiohttp.ClientSession | None = None
|
|
||||||
self._async_token: Token | None = None
|
|
||||||
|
|
||||||
def _get_auth_headers(self) -> dict[str, str]:
|
|
||||||
if self._credentials is None:
|
|
||||||
self._credentials, _ = google.auth.default(
|
|
||||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
|
||||||
)
|
|
||||||
if not self._credentials.token or self._credentials.expired:
|
|
||||||
self._credentials.refresh(
|
|
||||||
google.auth.transport.requests.Request(),
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {self._credentials.token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _async_get_auth_headers(self) -> dict[str, str]:
|
|
||||||
if self._async_token is None:
|
|
||||||
self._async_token = Token(
|
|
||||||
session=self._get_aio_session(),
|
|
||||||
scopes=[
|
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
access_token = await self._async_token.get()
|
|
||||||
return {
|
|
||||||
"Authorization": f"Bearer {access_token}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_aio_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._aio_session is None or self._aio_session.closed:
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
limit=300, limit_per_host=50,
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=60)
|
|
||||||
self._aio_session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout, connector=connector,
|
|
||||||
)
|
|
||||||
return self._aio_session
|
|
||||||
|
|
||||||
def create_index(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
content_path: str,
|
|
||||||
*,
|
|
||||||
dimensions: int = 3072,
|
|
||||||
approximate_neighbors_count: int = 150,
|
|
||||||
distance_measure_type: str = "DOT_PRODUCT_DISTANCE",
|
|
||||||
**kwargs: Any, # noqa: ANN401, ARG002
|
|
||||||
) -> None:
|
|
||||||
"""Create a new Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: The display name for the new index.
|
|
||||||
content_path: GCS URI to the embeddings JSON file.
|
|
||||||
dimensions: Number of dimensions in embedding vectors.
|
|
||||||
approximate_neighbors_count: Neighbors to find per vector.
|
|
||||||
distance_measure_type: The distance measure to use.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex.create_tree_ah_index(
|
|
||||||
display_name=name,
|
|
||||||
contents_delta_uri=content_path,
|
|
||||||
dimensions=dimensions,
|
|
||||||
approximate_neighbors_count=approximate_neighbors_count,
|
|
||||||
distance_measure_type=distance_measure_type, # type: ignore[arg-type]
|
|
||||||
leaf_node_embedding_count=1000,
|
|
||||||
leaf_nodes_to_search_percent=10,
|
|
||||||
)
|
|
||||||
self.index = index
|
|
||||||
|
|
||||||
def update_index(
|
|
||||||
self, index_name: str, content_path: str, **kwargs: Any, # noqa: ANN401, ARG002
|
|
||||||
) -> None:
|
|
||||||
"""Update an existing Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The resource name of the index to update.
|
|
||||||
content_path: GCS URI to the new embeddings JSON file.
|
|
||||||
**kwargs: Additional arguments.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex(index_name=index_name)
|
|
||||||
index.update_embeddings(
|
|
||||||
contents_delta_uri=content_path,
|
|
||||||
)
|
|
||||||
self.index = index
|
|
||||||
|
|
||||||
def deploy_index(
|
|
||||||
self,
|
|
||||||
index_name: str,
|
|
||||||
machine_type: str = "e2-standard-2",
|
|
||||||
) -> None:
|
|
||||||
"""Deploy a Vertex AI Vector Search index to an endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The name of the index to deploy.
|
|
||||||
machine_type: The machine type for the endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
|
|
||||||
display_name=f"{index_name}-endpoint",
|
|
||||||
public_endpoint_enabled=True,
|
|
||||||
)
|
|
||||||
index_endpoint.deploy_index(
|
|
||||||
index=self.index,
|
|
||||||
deployed_index_id=(
|
|
||||||
f"{index_name.replace('-', '_')}_deployed_{uuid4().hex}"
|
|
||||||
),
|
|
||||||
machine_type=machine_type,
|
|
||||||
)
|
|
||||||
self.index_endpoint = index_endpoint
|
|
||||||
|
|
||||||
def load_index_endpoint(self, endpoint_name: str) -> None:
|
|
||||||
"""Load an existing Vertex AI Vector Search index endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
endpoint_name: The resource name of the index endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
self.index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
|
||||||
endpoint_name,
|
|
||||||
)
|
|
||||||
if not self.index_endpoint.public_endpoint_domain_name:
|
|
||||||
msg = (
|
|
||||||
"The index endpoint does not have a public endpoint. "
|
|
||||||
"Ensure the endpoint is configured for public access."
|
|
||||||
)
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
def run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: list[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run a similarity search query against the deployed index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
response = self.index_endpoint.find_neighbors(
|
|
||||||
deployed_index_id=deployed_index_id,
|
|
||||||
queries=[query],
|
|
||||||
num_neighbors=limit,
|
|
||||||
)
|
|
||||||
results = []
|
|
||||||
for neighbor in response[0]:
|
|
||||||
file_path = (
|
|
||||||
f"{self.index_name}/contents/{neighbor.id}.md"
|
|
||||||
)
|
|
||||||
content = (
|
|
||||||
self.storage.get_file_stream(file_path)
|
|
||||||
.read()
|
|
||||||
.decode("utf-8")
|
|
||||||
)
|
|
||||||
results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=neighbor.id,
|
|
||||||
distance=float(neighbor.distance or 0),
|
|
||||||
content=content,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def async_run_query(
|
|
||||||
self,
|
|
||||||
deployed_index_id: str,
|
|
||||||
query: Sequence[float],
|
|
||||||
limit: int,
|
|
||||||
) -> list[SearchResult]:
|
|
||||||
"""Run an async similarity search via the REST API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
deployed_index_id: The ID of the deployed index.
|
|
||||||
query: The embedding vector for the search query.
|
|
||||||
limit: Maximum number of nearest neighbors to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A list of matched items with id, distance, and content.
|
|
||||||
|
|
||||||
"""
|
|
||||||
domain = self.index_endpoint.public_endpoint_domain_name
|
|
||||||
endpoint_id = self.index_endpoint.name.split("/")[-1]
|
|
||||||
url = (
|
|
||||||
f"https://{domain}/v1/projects/{self.project_id}"
|
|
||||||
f"/locations/{self.location}"
|
|
||||||
f"/indexEndpoints/{endpoint_id}:findNeighbors"
|
|
||||||
)
|
|
||||||
payload = {
|
|
||||||
"deployed_index_id": deployed_index_id,
|
|
||||||
"queries": [
|
|
||||||
{
|
|
||||||
"datapoint": {"feature_vector": list(query)},
|
|
||||||
"neighbor_count": limit,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
headers = await self._async_get_auth_headers()
|
|
||||||
session = self._get_aio_session()
|
|
||||||
async with session.post(
|
|
||||||
url, json=payload, headers=headers,
|
|
||||||
) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
data = await response.json()
|
|
||||||
|
|
||||||
neighbors = (
|
|
||||||
data.get("nearestNeighbors", [{}])[0].get("neighbors", [])
|
|
||||||
)
|
|
||||||
content_tasks = []
|
|
||||||
for neighbor in neighbors:
|
|
||||||
datapoint_id = neighbor["datapoint"]["datapointId"]
|
|
||||||
file_path = (
|
|
||||||
f"{self.index_name}/contents/{datapoint_id}.md"
|
|
||||||
)
|
|
||||||
content_tasks.append(
|
|
||||||
self.storage.async_get_file_stream(file_path),
|
|
||||||
)
|
|
||||||
|
|
||||||
file_streams = await asyncio.gather(*content_tasks)
|
|
||||||
results: list[SearchResult] = []
|
|
||||||
for neighbor, stream in zip(
|
|
||||||
neighbors, file_streams, strict=True,
|
|
||||||
):
|
|
||||||
results.append(
|
|
||||||
SearchResult(
|
|
||||||
id=neighbor["datapoint"]["datapointId"],
|
|
||||||
distance=neighbor["distance"],
|
|
||||||
content=stream.read().decode("utf-8"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
def delete_index(self, index_name: str) -> None:
|
|
||||||
"""Delete a Vertex AI Vector Search index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_name: The resource name of the index.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index = aiplatform.MatchingEngineIndex(index_name)
|
|
||||||
index.delete()
|
|
||||||
|
|
||||||
def delete_index_endpoint(
|
|
||||||
self, index_endpoint_name: str,
|
|
||||||
) -> None:
|
|
||||||
"""Delete a Vertex AI Vector Search index endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
index_endpoint_name: The resource name of the endpoint.
|
|
||||||
|
|
||||||
"""
|
|
||||||
index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
|
|
||||||
index_endpoint_name,
|
|
||||||
)
|
|
||||||
index_endpoint.undeploy_all()
|
|
||||||
index_endpoint.delete(force=True)
|
|
||||||
@@ -4,7 +4,3 @@ import os
|
|||||||
|
|
||||||
# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
|
# Ensure the Google GenAI SDK talks to Vertex AI instead of the public Gemini API.
|
||||||
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")
|
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "true")
|
||||||
|
|
||||||
from .agent import root_agent
|
|
||||||
|
|
||||||
__all__ = ["root_agent"]
|
|
||||||
29
src/va_agent/agent.py
Normal file
29
src/va_agent/agent.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""ADK agent with vector search RAG tool."""
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.adk.agents.llm_agent import Agent
|
||||||
|
from google.adk.runners import Runner
|
||||||
|
from google.adk.tools.mcp_tool import McpToolset
|
||||||
|
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
|
||||||
|
from va_agent.config import settings
|
||||||
|
from va_agent.session import FirestoreSessionService
|
||||||
|
|
||||||
|
connection_params = SseConnectionParams(url=settings.mcp_remote_url)
|
||||||
|
toolset = McpToolset(connection_params=connection_params)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
model=settings.agent_model,
|
||||||
|
name=settings.agent_name,
|
||||||
|
instruction=settings.agent_instructions,
|
||||||
|
tools=[toolset],
|
||||||
|
)
|
||||||
|
|
||||||
|
session_service = FirestoreSessionService(
|
||||||
|
db=AsyncClient(database=settings.firestore_db),
|
||||||
|
compaction_token_threshold=10_000,
|
||||||
|
genai_client=genai.Client(),
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = Runner(app_name="va_agent", agent=agent, session_service=session_service)
|
||||||
53
src/va_agent/config.py
Normal file
53
src/va_agent/config.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Configuration helper for ADK agent."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
SettingsConfigDict,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
CONFIG_FILE_PATH = os.getenv("CONFIG_YAML", "config.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentSettings(BaseSettings):
|
||||||
|
"""Settings for ADK agent with vector search."""
|
||||||
|
|
||||||
|
google_cloud_project: str
|
||||||
|
google_cloud_location: str
|
||||||
|
|
||||||
|
# Agent configuration
|
||||||
|
agent_name: str
|
||||||
|
agent_instructions: str
|
||||||
|
agent_model: str
|
||||||
|
|
||||||
|
# Firestore configuration
|
||||||
|
firestore_db: str
|
||||||
|
|
||||||
|
# MCP configuration
|
||||||
|
mcp_remote_url: str
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
yaml_file=CONFIG_FILE_PATH,
|
||||||
|
extra="ignore", # Ignore extra fields from config.yaml
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def settings_customise_sources(
|
||||||
|
cls,
|
||||||
|
settings_cls: type[BaseSettings],
|
||||||
|
init_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||||
|
env_settings: PydanticBaseSettingsSource,
|
||||||
|
dotenv_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||||
|
file_secret_settings: PydanticBaseSettingsSource, # noqa: ARG003
|
||||||
|
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||||
|
"""Use env vars and YAML as settings sources."""
|
||||||
|
return (
|
||||||
|
env_settings,
|
||||||
|
YamlConfigSettingsSource(settings_cls),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = AgentSettings.model_validate({})
|
||||||
10
src/va_agent/server.py
Normal file
10
src/va_agent/server.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""FastAPI server exposing the RAG agent endpoint.
|
||||||
|
|
||||||
|
NOTE: This file is a stub. The rag_eval module was removed in the
|
||||||
|
lean MCP implementation. This file is kept for reference but is not
|
||||||
|
functional.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI(title="RAG Agent")
|
||||||
582
src/va_agent/session.py
Normal file
582
src/va_agent/session.py
Normal file
@@ -0,0 +1,582 @@
|
|||||||
|
"""Firestore-backed session service for Google ADK."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import TYPE_CHECKING, Any, override
|
||||||
|
|
||||||
|
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.adk.sessions import _session_util
|
||||||
|
from google.adk.sessions.base_session_service import (
|
||||||
|
BaseSessionService,
|
||||||
|
GetSessionConfig,
|
||||||
|
ListSessionsResponse,
|
||||||
|
)
|
||||||
|
from google.adk.sessions.session import Session
|
||||||
|
from google.adk.sessions.state import State
|
||||||
|
from google.cloud.firestore_v1.async_transaction import async_transactional
|
||||||
|
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||||
|
from google.cloud.firestore_v1.field_path import FieldPath
|
||||||
|
from google.genai.types import Content, Part
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from google import genai
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
|
||||||
|
logger = logging.getLogger("google_adk." + __name__)
|
||||||
|
|
||||||
|
_COMPACTION_LOCK_TTL = 300 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@async_transactional
|
||||||
|
async def _try_claim_compaction_txn(transaction: Any, session_ref: Any) -> bool:
|
||||||
|
"""Atomically claim the compaction lock if it is free or stale."""
|
||||||
|
snapshot = await session_ref.get(transaction=transaction)
|
||||||
|
if not snapshot.exists:
|
||||||
|
return False
|
||||||
|
data = snapshot.to_dict() or {}
|
||||||
|
lock_time = data.get("compaction_lock")
|
||||||
|
if lock_time and (time.time() - lock_time) < _COMPACTION_LOCK_TTL:
|
||||||
|
return False
|
||||||
|
transaction.update(session_ref, {"compaction_lock": time.time()})
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class FirestoreSessionService(BaseSessionService):
|
||||||
|
"""A Firestore-backed implementation of BaseSessionService.
|
||||||
|
|
||||||
|
Firestore document layout (given ``collection_prefix="adk"``)::
|
||||||
|
|
||||||
|
adk_app_states/{app_name}
|
||||||
|
→ app-scoped state key/values
|
||||||
|
|
||||||
|
adk_user_states/{app_name}__{user_id}
|
||||||
|
→ user-scoped state key/values
|
||||||
|
|
||||||
|
adk_sessions/{app_name}__{user_id}__{session_id}
|
||||||
|
→ {app_name, user_id, session_id, state: {…}, last_update_time}
|
||||||
|
└─ events/{event_id} → serialised Event
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__( # noqa: PLR0913
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
db: AsyncClient,
|
||||||
|
collection_prefix: str = "adk",
|
||||||
|
compaction_token_threshold: int | None = None,
|
||||||
|
compaction_model: str = "gemini-2.5-flash",
|
||||||
|
compaction_keep_recent: int = 10,
|
||||||
|
genai_client: genai.Client | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize FirestoreSessionService.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: Firestore async client
|
||||||
|
collection_prefix: Prefix for Firestore collections
|
||||||
|
compaction_token_threshold: Token count threshold for compaction
|
||||||
|
compaction_model: Model to use for summarization
|
||||||
|
compaction_keep_recent: Number of recent events to keep
|
||||||
|
genai_client: GenAI client for compaction summaries
|
||||||
|
|
||||||
|
"""
|
||||||
|
if compaction_token_threshold is not None and genai_client is None:
|
||||||
|
msg = "genai_client is required when compaction_token_threshold is set."
|
||||||
|
raise ValueError(msg)
|
||||||
|
self._db = db
|
||||||
|
self._prefix = collection_prefix
|
||||||
|
self._compaction_threshold = compaction_token_threshold
|
||||||
|
self._compaction_model = compaction_model
|
||||||
|
self._compaction_keep_recent = compaction_keep_recent
|
||||||
|
self._genai_client = genai_client
|
||||||
|
self._compaction_locks: dict[str, asyncio.Lock] = {}
|
||||||
|
self._active_tasks: set[asyncio.Task] = set()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Document-reference helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _app_state_ref(self, app_name: str) -> Any:
|
||||||
|
return self._db.collection(f"{self._prefix}_app_states").document(app_name)
|
||||||
|
|
||||||
|
def _user_state_ref(self, app_name: str, user_id: str) -> Any:
|
||||||
|
return self._db.collection(f"{self._prefix}_user_states").document(
|
||||||
|
f"{app_name}__{user_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _session_ref(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||||
|
return self._db.collection(f"{self._prefix}_sessions").document(
|
||||||
|
f"{app_name}__{user_id}__{session_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _events_col(self, app_name: str, user_id: str, session_id: str) -> Any:
|
||||||
|
return self._session_ref(app_name, user_id, session_id).collection("events")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# State helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _get_app_state(self, app_name: str) -> dict[str, Any]:
|
||||||
|
snap = await self._app_state_ref(app_name).get()
|
||||||
|
return snap.to_dict() or {} if snap.exists else {}
|
||||||
|
|
||||||
|
async def _get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]:
|
||||||
|
snap = await self._user_state_ref(app_name, user_id).get()
|
||||||
|
return snap.to_dict() or {} if snap.exists else {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_state(
|
||||||
|
app_state: dict[str, Any],
|
||||||
|
user_state: dict[str, Any],
|
||||||
|
session_state: dict[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
merged = dict(session_state)
|
||||||
|
for key, value in app_state.items():
|
||||||
|
merged[State.APP_PREFIX + key] = value
|
||||||
|
for key, value in user_state.items():
|
||||||
|
merged[State.USER_PREFIX + key] = value
|
||||||
|
return merged
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _events_to_text(events: list[Event]) -> str:
|
||||||
|
lines: list[str] = []
|
||||||
|
for event in events:
|
||||||
|
if event.content and event.content.parts:
|
||||||
|
text = "".join(p.text or "" for p in event.content.parts)
|
||||||
|
if text:
|
||||||
|
role = "User" if event.author == "user" else "Assistant"
|
||||||
|
lines.append(f"{role}: {text}")
|
||||||
|
return "\n\n".join(lines)
|
||||||
|
|
||||||
|
async def _generate_summary(
|
||||||
|
self, existing_summary: str, events: list[Event]
|
||||||
|
) -> str:
|
||||||
|
conversation_text = self._events_to_text(events)
|
||||||
|
previous = (
|
||||||
|
f"Previous summary of earlier conversation:\n{existing_summary}\n\n"
|
||||||
|
if existing_summary
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
"Summarize the following conversation between a user and an "
|
||||||
|
"assistant. Preserve:\n"
|
||||||
|
"- Key decisions and conclusions\n"
|
||||||
|
"- User preferences and requirements\n"
|
||||||
|
"- Important facts, names, and numbers\n"
|
||||||
|
"- The overall topic and direction of the conversation\n"
|
||||||
|
"- Any pending tasks or open questions\n\n"
|
||||||
|
f"{previous}"
|
||||||
|
f"Conversation:\n{conversation_text}\n\n"
|
||||||
|
"Provide a clear, comprehensive summary."
|
||||||
|
)
|
||||||
|
if self._genai_client is None:
|
||||||
|
msg = "genai_client is required for compaction"
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
response = await self._genai_client.aio.models.generate_content(
|
||||||
|
model=self._compaction_model,
|
||||||
|
contents=prompt,
|
||||||
|
)
|
||||||
|
return response.text or ""
|
||||||
|
|
||||||
|
    async def _compact_session(self, session: Session) -> None:
        """Summarize and delete old events, keeping the most recent ones.

        Loads all events in timestamp order; when more than
        ``_compaction_keep_recent`` exist, the older ones are folded into the
        session document's ``conversation_summary`` field and then deleted.
        Summary-generation failure is non-fatal: events are left intact.
        """
        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref.order_by("timestamp")
        event_docs = await query.get()

        # Nothing to do if everything already fits in the keep-recent window.
        if len(event_docs) <= self._compaction_keep_recent:
            return

        all_events = [Event.model_validate(doc.to_dict()) for doc in event_docs]
        events_to_summarize = all_events[: -self._compaction_keep_recent]

        session_snap = await self._session_ref(app_name, user_id, session_id).get()
        existing_summary = (session_snap.to_dict() or {}).get(
            "conversation_summary", ""
        )

        try:
            summary = await self._generate_summary(
                existing_summary, events_to_summarize
            )
        except Exception:
            # Best effort: a failed summary must never cost us events.
            logger.exception("Compaction summary generation failed; skipping.")
            return

        # Write summary BEFORE deleting events so a crash between the two
        # steps leaves safe duplication rather than data loss.
        await self._session_ref(app_name, user_id, session_id).update(
            {"conversation_summary": summary}
        )

        # Delete in chunks of 500 — Firestore's per-batch write limit.
        docs_to_delete = event_docs[: -self._compaction_keep_recent]
        for i in range(0, len(docs_to_delete), 500):
            batch = self._db.batch()
            for doc in docs_to_delete[i : i + 500]:
                batch.delete(doc.reference)
            await batch.commit()

        logger.info(
            "Compacted session %s: summarised %d events, kept %d.",
            session_id,
            len(docs_to_delete),
            self._compaction_keep_recent,
        )
|
||||||
|
|
||||||
|
    async def _guarded_compact(self, session: Session) -> None:
        """Run compaction in the background with per-session locking.

        Two guards are applied in order: a process-local ``asyncio.Lock``
        (cheap, catches concurrent triggers in this process) and a
        Firestore-level ``compaction_lock`` field claimed transactionally
        (catches other service instances). The Firestore lock is always
        released in the ``finally`` block, even if compaction fails.

        NOTE(review): ``_compaction_locks`` grows by one entry per session key
        and is never pruned — presumably acceptable for expected session
        counts; verify if long-lived processes handle many sessions.
        """
        key = f"{session.app_name}__{session.user_id}__{session.id}"
        lock = self._compaction_locks.setdefault(key, asyncio.Lock())

        if lock.locked():
            logger.debug("Compaction already running locally for %s; skipping.", key)
            return

        async with lock:
            session_ref = self._session_ref(
                session.app_name, session.user_id, session.id
            )
            try:
                transaction = self._db.transaction()
                claimed = await _try_claim_compaction_txn(transaction, session_ref)
            except Exception:
                logger.exception("Failed to claim compaction lock for %s", key)
                return

            if not claimed:
                logger.debug(
                    "Compaction lock held by another instance for %s; skipping.",
                    key,
                )
                return

            try:
                await self._compact_session(session)
            except Exception:
                logger.exception("Background compaction failed for %s", key)
            finally:
                # Always release the distributed lock, even on failure.
                try:
                    await session_ref.update({"compaction_lock": None})
                except Exception:
                    logger.exception("Failed to release compaction lock for %s", key)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Await all in-flight compaction tasks. Call before shutdown."""
|
||||||
|
if self._active_tasks:
|
||||||
|
await asyncio.gather(*self._active_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# BaseSessionService implementation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
    @override
    async def create_session(
        self,
        *,
        app_name: str,
        user_id: str,
        state: dict[str, Any] | None = None,
        session_id: str | None = None,
    ) -> Session:
        """Create a session document and persist any scoped state deltas.

        A blank/whitespace ``session_id`` is treated as absent and a UUID is
        generated. App- and user-scoped entries in *state* are written to
        their shared documents; session-scoped entries stay on the session
        document.

        Raises:
            AlreadyExistsError: if an explicit ``session_id`` already exists.

        NOTE(review): the exists-check and the later ``set`` are not in one
        transaction, so two concurrent creates with the same id could race —
        confirm whether that matters for callers.
        """
        if session_id and session_id.strip():
            session_id = session_id.strip()
            existing = await self._session_ref(app_name, user_id, session_id).get()
            if existing.exists:
                msg = f"Session with id {session_id} already exists."
                raise AlreadyExistsError(msg)
        else:
            session_id = str(uuid.uuid4())

        # Split incoming state by scope prefix (app:/user:/plain).
        state_deltas = _session_util.extract_state_delta(state)  # type: ignore[attr-defined]
        app_state_delta = state_deltas["app"]
        user_state_delta = state_deltas["user"]
        session_state = state_deltas["session"]

        write_coros: list = []
        if app_state_delta:
            write_coros.append(
                self._app_state_ref(app_name).set(app_state_delta, merge=True)
            )
        if user_state_delta:
            write_coros.append(
                self._user_state_ref(app_name, user_id).set(
                    user_state_delta, merge=True
                )
            )

        now = time.time()
        write_coros.append(
            self._session_ref(app_name, user_id, session_id).set(
                {
                    "app_name": app_name,
                    "user_id": user_id,
                    "session_id": session_id,
                    "state": session_state or {},
                    "last_update_time": now,
                }
            )
        )
        # All writes are independent documents; issue them concurrently.
        await asyncio.gather(*write_coros)

        # Re-read shared scopes so the returned Session reflects pre-existing
        # app/user state, not just the deltas written above.
        app_state, user_state = await asyncio.gather(
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        merged = self._merge_state(app_state, user_state, session_state or {})

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            last_update_time=now,
        )
|
||||||
|
|
||||||
|
    @override
    async def get_session(
        self,
        *,
        app_name: str,
        user_id: str,
        session_id: str,
        config: GetSessionConfig | None = None,
    ) -> Session | None:
        """Load a session with its events and merged scoped state.

        Returns None when the session document does not exist. When a
        ``conversation_summary`` is present (written by compaction), two
        synthetic events carrying that summary are prepended so the model
        regains earlier context.

        NOTE(review): ``config.after_timestamp`` / ``config.num_recent_events``
        are tested for truthiness, so a value of 0 is ignored — presumably
        intended (0 means "no filter"), but worth confirming.
        """
        snap = await self._session_ref(app_name, user_id, session_id).get()
        if not snap.exists:
            return None

        session_data = snap.to_dict()

        # Build events query
        events_ref = self._events_col(app_name, user_id, session_id)
        query = events_ref
        if config and config.after_timestamp:
            query = query.where(
                filter=FieldFilter("timestamp", ">=", config.after_timestamp)
            )
        query = query.order_by("timestamp")

        # Fetch events and both shared-state documents concurrently.
        event_docs, app_state, user_state = await asyncio.gather(
            query.get(),
            self._get_app_state(app_name),
            self._get_user_state(app_name, user_id),
        )
        events = [Event.model_validate(doc.to_dict()) for doc in event_docs]

        if config and config.num_recent_events:
            events = events[-config.num_recent_events :]

        # Prepend conversation summary as synthetic context events
        conversation_summary = session_data.get("conversation_summary")
        if conversation_summary:
            # Timestamps 0.0/0.001 keep the synthetic pair ordered before
            # any real event.
            summary_event = Event(
                id="summary-context",
                author="user",
                content=Content(
                    role="user",
                    parts=[
                        Part(
                            text=(
                                "[Conversation context from previous"
                                " messages]\n"
                                f"{conversation_summary}"
                            )
                        )
                    ],
                ),
                timestamp=0.0,
                invocation_id="compaction-summary",
            )
            ack_event = Event(
                id="summary-ack",
                author=app_name,
                content=Content(
                    role="model",
                    parts=[
                        Part(
                            text=(
                                "Understood, I have the context from our"
                                " previous conversation and will continue"
                                " accordingly."
                            )
                        )
                    ],
                ),
                timestamp=0.001,
                invocation_id="compaction-summary",
            )
            events = [summary_event, ack_event, *events]

        # Merge scoped state
        merged = self._merge_state(app_state, user_state, session_data.get("state", {}))

        return Session(
            app_name=app_name,
            user_id=user_id,
            id=session_id,
            state=merged,
            events=events,
            last_update_time=session_data.get("last_update_time", 0.0),
        )
|
||||||
|
|
||||||
|
    @override
    async def list_sessions(
        self, *, app_name: str, user_id: str | None = None
    ) -> ListSessionsResponse:
        """List sessions for an app, optionally restricted to one user.

        Returned sessions carry merged scoped state but empty ``events``
        lists (events are only loaded by :meth:`get_session`).
        """
        query = self._db.collection(f"{self._prefix}_sessions").where(
            filter=FieldFilter("app_name", "==", app_name)
        )
        if user_id is not None:
            query = query.where(filter=FieldFilter("user_id", "==", user_id))

        docs = await query.get()
        if not docs:
            return ListSessionsResponse()

        doc_dicts: list[dict[str, Any]] = [doc.to_dict() or {} for doc in docs]

        # Pre-fetch app state and all distinct user states in parallel
        unique_user_ids = list({d["user_id"] for d in doc_dicts})
        app_state, *user_states = await asyncio.gather(
            self._get_app_state(app_name),
            *(self._get_user_state(app_name, uid) for uid in unique_user_ids),
        )
        user_state_cache = dict(zip(unique_user_ids, user_states, strict=False))

        sessions: list[Session] = []
        for data in doc_dicts:
            s_user_id = data["user_id"]
            merged = self._merge_state(
                app_state,
                user_state_cache[s_user_id],
                data.get("state", {}),
            )

            sessions.append(
                Session(
                    app_name=app_name,
                    user_id=s_user_id,
                    id=data["session_id"],
                    state=merged,
                    events=[],
                    last_update_time=data.get("last_update_time", 0.0),
                )
            )

        return ListSessionsResponse(sessions=sessions)
|
||||||
|
|
||||||
|
@override
|
||||||
|
async def delete_session(
|
||||||
|
self, *, app_name: str, user_id: str, session_id: str
|
||||||
|
) -> None:
|
||||||
|
ref = self._session_ref(app_name, user_id, session_id)
|
||||||
|
await self._db.recursive_delete(ref)
|
||||||
|
|
||||||
|
    @override
    async def append_event(self, session: Session, event: Event) -> Event:
        """Persist *event*, apply its state deltas, and maybe trigger compaction.

        Partial (streaming) events are returned unchanged without persistence.
        State deltas are split by scope and written concurrently; the session
        document's ``last_update_time`` is always advanced. When the event's
        ``usage_metadata.total_token_count`` reaches the configured threshold,
        a background compaction task is started (fire-and-forget, tracked in
        ``_active_tasks`` so :meth:`close` can await it).
        """
        if event.partial:
            return event

        t0 = time.monotonic()

        app_name = session.app_name
        user_id = session.user_id
        session_id = session.id

        # Base class: strips temp state, applies delta to in-memory session,
        # appends event to session.events
        event = await super().append_event(session=session, event=event)
        session.last_update_time = event.timestamp

        # Persist event document
        event_data = event.model_dump(mode="json", exclude_none=True)
        await (
            self._events_col(app_name, user_id, session_id)
            .document(event.id)
            .set(event_data)
        )

        # Persist state deltas
        session_ref = self._session_ref(app_name, user_id, session_id)

        if event.actions and event.actions.state_delta:
            state_deltas = _session_util.extract_state_delta(event.actions.state_delta)

            write_coros: list = []
            if state_deltas["app"]:
                write_coros.append(
                    self._app_state_ref(app_name).set(state_deltas["app"], merge=True)
                )
            if state_deltas["user"]:
                write_coros.append(
                    self._user_state_ref(app_name, user_id).set(
                        state_deltas["user"], merge=True
                    )
                )

            if state_deltas["session"]:
                # FieldPath escapes keys so dots in keys don't split into
                # nested-field paths.
                field_updates: dict[str, Any] = {
                    FieldPath("state", k).to_api_repr(): v
                    for k, v in state_deltas["session"].items()
                }
                field_updates["last_update_time"] = event.timestamp
                write_coros.append(session_ref.update(field_updates))
            else:
                write_coros.append(
                    session_ref.update({"last_update_time": event.timestamp})
                )

            await asyncio.gather(*write_coros)
        else:
            await session_ref.update({"last_update_time": event.timestamp})

        # Log token usage
        if event.usage_metadata:
            meta = event.usage_metadata
            logger.info(
                "Token usage for session %s event %s: "
                "prompt=%s, candidates=%s, total=%s",
                session_id,
                event.id,
                meta.prompt_token_count,
                meta.candidates_token_count,
                meta.total_token_count,
            )

        # Trigger compaction if total token count exceeds threshold
        if (
            self._compaction_threshold is not None
            and event.usage_metadata
            and event.usage_metadata.total_token_count
            and event.usage_metadata.total_token_count >= self._compaction_threshold
        ):
            logger.info(
                "Compaction triggered for session %s: "
                "total_token_count=%d >= threshold=%d",
                session_id,
                event.usage_metadata.total_token_count,
                self._compaction_threshold,
            )
            # Fire-and-forget; discard callback prevents the set from growing.
            task = asyncio.create_task(self._guarded_compact(session))
            self._active_tasks.add(task)
            task.add_done_callback(self._active_tasks.discard)

        elapsed = time.monotonic() - t0
        logger.info(
            "append_event completed for session %s event %s in %.3fs",
            session_id,
            event.id,
            elapsed,
        )

        return event
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
35
tests/conftest.py
Normal file
35
tests/conftest.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Shared fixtures for Firestore session service tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
|
||||||
|
from va_agent.session import FirestoreSessionService
|
||||||
|
|
||||||
|
# Point the client at a local Firestore emulator unless the caller already
# configured one — tests must never touch a real project.
os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8153")


@pytest_asyncio.fixture
async def db():
    """Async Firestore client bound to the emulator's test project."""
    return AsyncClient(project="test-project")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def service(db: AsyncClient):
    """Session service with a unique collection prefix to isolate each test."""
    prefix = f"test_{uuid.uuid4().hex[:8]}"
    return FirestoreSessionService(db=db, collection_prefix=prefix)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def app_name():
    """Random app name so tests cannot collide on shared documents."""
    return f"app_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def user_id():
    """Random user id so tests cannot collide on shared documents."""
    return f"user_{uuid.uuid4().hex[:8]}"
|
||||||
515
tests/test_compaction.py
Normal file
515
tests/test_compaction.py
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
"""Tests for conversation compaction in FirestoreSessionService."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from google import genai
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.genai.types import Content, GenerateContentResponseUsageMetadata, Part
|
||||||
|
|
||||||
|
from va_agent.session import FirestoreSessionService, _try_claim_compaction_txn
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def mock_genai_client():
    """genai.Client stub whose generate_content returns a fixed summary."""
    client = MagicMock(spec=genai.Client)
    response = MagicMock()
    response.text = "Summary of the conversation so far."
    client.aio.models.generate_content = AsyncMock(return_value=response)
    return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
async def compaction_service(db: AsyncClient, mock_genai_client):
    """Service with a low token threshold (100) so compaction fires easily."""
    prefix = f"test_{uuid.uuid4().hex[:8]}"
    return FirestoreSessionService(
        db=db,
        collection_prefix=prefix,
        compaction_token_threshold=100,
        compaction_keep_recent=2,
        genai_client=mock_genai_client,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# __init__ validation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionInit:
    """Constructor validation of the compaction configuration."""

    async def test_requires_genai_client(self, db):
        # A threshold without a client is a configuration error.
        with pytest.raises(ValueError, match="genai_client is required"):
            FirestoreSessionService(
                db=db,
                compaction_token_threshold=1000,
            )

    async def test_no_threshold_no_client_ok(self, db):
        # With compaction disabled, no genai client is needed.
        svc = FirestoreSessionService(db=db)
        assert svc._compaction_threshold is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction trigger
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionTrigger:
    """Threshold-driven triggering of background compaction in append_event."""

    async def test_compaction_triggered_above_threshold(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        # Add 5 events, last one with usage_metadata above threshold
        base = time.time()
        for i in range(4):
            e = Event(
                author="user" if i % 2 == 0 else app_name,
                content=Content(
                    role="user" if i % 2 == 0 else "model",
                    parts=[Part(text=f"message {i}")],
                ),
                timestamp=base + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        # This event crosses the threshold
        trigger_event = Event(
            author=app_name,
            content=Content(
                role="model", parts=[Part(text="final response")]
            ),
            timestamp=base + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger_event)
        # close() awaits the fire-and-forget compaction task.
        await compaction_service.close()

        # Summary generation should have been called
        mock_genai_client.aio.models.generate_content.assert_called_once()

        # Fetch session: should have summary + only keep_recent events
        fetched = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        # 2 synthetic summary events + 2 kept real events
        assert len(fetched.events) == 4
        assert fetched.events[0].id == "summary-context"
        assert fetched.events[1].id == "summary-ack"
        assert "Summary of the conversation" in fetched.events[0].content.parts[0].text

    async def test_no_compaction_below_threshold(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # 50 tokens < threshold of 100 — must not summarize.
        event = Event(
            author=app_name,
            content=Content(
                role="model", parts=[Part(text="short reply")]
            ),
            timestamp=time.time(),
            invocation_id="inv-1",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=50,
            ),
        )
        await compaction_service.append_event(session, event)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_no_compaction_without_usage_metadata(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # No usage_metadata at all — trigger condition must short-circuit.
        event = Event(
            author="user",
            content=Content(
                role="user", parts=[Part(text="hello")]
            ),
            timestamp=time.time(),
            invocation_id="inv-1",
        )
        await compaction_service.append_event(session, event)

        mock_genai_client.aio.models.generate_content.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Compaction with too few events (nothing to compact)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionEdgeCases:
    """Compaction with too few events, and resilience to summary failures."""

    async def test_skip_when_fewer_events_than_keep_recent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        # Only 2 events, keep_recent=2 → nothing to summarize
        for i in range(2):
            e = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {i}")]
                ),
                timestamp=time.time() + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        # Trigger compaction manually even though threshold wouldn't fire
        await compaction_service._compact_session(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_summary_generation_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        for i in range(5):
            e = Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text=f"msg {i}")]
                ),
                timestamp=time.time() + i,
                invocation_id=f"inv-{i}",
            )
            await compaction_service.append_event(session, e)

        # Make summary generation fail
        mock_genai_client.aio.models.generate_content = AsyncMock(
            side_effect=RuntimeError("API error")
        )

        # Should not raise
        await compaction_service._compact_session(session)

        # All events should still be present
        fetched = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(fetched.events) == 5
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# get_session with summary
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSessionWithSummary:
    """Synthetic summary events are only injected when a summary exists."""

    async def test_no_summary_no_synthetic_events(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        event = Event(
            author="user",
            content=Content(
                role="user", parts=[Part(text="hello")]
            ),
            timestamp=time.time(),
            invocation_id="inv-1",
        )
        await compaction_service.append_event(session, event)

        fetched = await compaction_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        # No conversation_summary stored → only the real event comes back.
        assert len(fetched.events) == 1
        assert fetched.events[0].author == "user"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# _events_to_text
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEventsToText:
    """Transcript formatting used to build the summarization prompt."""

    async def test_formats_user_and_assistant(self):
        events = [
            Event(
                author="user",
                content=Content(
                    role="user", parts=[Part(text="Hi there")]
                ),
                timestamp=1.0,
                invocation_id="inv-1",
            ),
            Event(
                author="bot",
                content=Content(
                    role="model", parts=[Part(text="Hello!")]
                ),
                timestamp=2.0,
                invocation_id="inv-2",
            ),
        ]
        text = FirestoreSessionService._events_to_text(events)
        assert "User: Hi there" in text
        assert "Assistant: Hello!" in text

    async def test_skips_events_without_text(self):
        # An event with no content contributes nothing to the transcript.
        events = [
            Event(
                author="user",
                timestamp=1.0,
                invocation_id="inv-1",
            ),
        ]
        text = FirestoreSessionService._events_to_text(events)
        assert text == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Firestore distributed lock
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompactionLock:
    """Transactional claim/release semantics of the Firestore compaction lock."""

    async def test_claim_and_release(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )

        # Claim the lock
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, session_ref)
        assert claimed is True

        # Lock is now held — second claim should fail
        transaction2 = compaction_service._db.transaction()
        claimed2 = await _try_claim_compaction_txn(transaction2, session_ref)
        assert claimed2 is False

        # Release the lock
        await session_ref.update({"compaction_lock": None})

        # Can claim again after release
        transaction3 = compaction_service._db.transaction()
        claimed3 = await _try_claim_compaction_txn(transaction3, session_ref)
        assert claimed3 is True

    async def test_stale_lock_can_be_reclaimed(
        self, compaction_service, app_name, user_id
    ):
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        session_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )

        # Set a stale lock (older than TTL)
        await session_ref.update({"compaction_lock": time.time() - 600})

        # Should be able to reclaim a stale lock
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, session_ref)
        assert claimed is True

    async def test_claim_nonexistent_session(self, compaction_service):
        # Claiming on a missing session document must fail, not create it.
        ref = compaction_service._session_ref("no_app", "no_user", "no_id")
        transaction = compaction_service._db.transaction()
        claimed = await _try_claim_compaction_txn(transaction, ref)
        assert claimed is False
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Guarded compact
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGuardedCompact:
    """Locking behavior of `_guarded_compact`: skip when locked, release on error."""

    @staticmethod
    async def _seed_user_events(compaction_service, session, count):
        # Append `count` plain user messages so the session has history.
        for i in range(count):
            await compaction_service.append_event(
                session,
                Event(
                    author="user",
                    content=Content(
                        role="user", parts=[Part(text=f"msg {i}")]
                    ),
                    timestamp=time.time() + i,
                    invocation_id=f"inv-{i}",
                ),
            )

    async def test_local_lock_skips_concurrent(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """If the in-process lock is already held, compaction is skipped."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_user_events(compaction_service, session, 5)

        # Hold the per-session asyncio lock ourselves; _guarded_compact must bail.
        lock_key = f"{app_name}__{user_id}__{session.id}"
        local_lock = compaction_service._compaction_locks.setdefault(
            lock_key, asyncio.Lock()
        )
        async with local_lock:
            await compaction_service._guarded_compact(session)

        # No summarization call was made.
        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_firestore_lock_held_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """A fresh Firestore lock (held by another instance) blocks compaction."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        await self._seed_user_events(compaction_service, session, 5)

        # Write a current-timestamp lock, as another process would.
        doc_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        await doc_ref.update({"compaction_lock": time.time()})

        await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_claim_failure_logs_and_skips(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """An exception while claiming the lock skips compaction, not raises."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        with patch(
            "va_agent.session._try_claim_compaction_txn",
            side_effect=RuntimeError("Firestore down"),
        ):
            await compaction_service._guarded_compact(session)

        mock_genai_client.aio.models.generate_content.assert_not_called()

    async def test_compaction_failure_releases_lock(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """A crash inside `_compact_session` must not leak the Firestore lock."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        with patch.object(
            compaction_service,
            "_compact_session",
            side_effect=RuntimeError("unexpected crash"),
        ):
            await compaction_service._guarded_compact(session)

        # The lock field must be cleared despite the failure.
        doc_ref = compaction_service._session_ref(
            app_name, user_id, session.id
        )
        snapshot = await doc_ref.get()
        assert snapshot.to_dict().get("compaction_lock") is None

    async def test_lock_release_failure_is_non_fatal(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """A Firestore error while releasing the lock must be swallowed."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )

        real_session_ref = compaction_service._session_ref

        def ref_with_broken_release(an, uid, sid):
            # Wrap the real document ref so any lock-field update raises.
            ref = real_session_ref(an, uid, sid)
            real_update = ref.update

            async def failing_update(data):
                if "compaction_lock" in data:
                    raise RuntimeError("Firestore write failed")
                return await real_update(data)

            ref.update = failing_update
            return ref

        with patch.object(
            compaction_service,
            "_session_ref",
            side_effect=ref_with_broken_release,
        ):
            # Should not raise despite the lock-release failure.
            await compaction_service._guarded_compact(session)
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# close()
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestClose:
    """Shutdown behavior: close() drains in-flight background tasks."""

    async def test_close_no_tasks(self, compaction_service):
        """close() completes cleanly when no background work is pending."""
        await compaction_service.close()

    async def test_close_awaits_tasks(
        self, compaction_service, mock_genai_client, app_name, user_id
    ):
        """close() waits for every scheduled compaction task to finish."""
        session = await compaction_service.create_session(
            app_name=app_name, user_id=user_id
        )
        start = time.time()
        for i in range(4):
            await compaction_service.append_event(
                session,
                Event(
                    author="user",
                    content=Content(
                        role="user", parts=[Part(text=f"msg {i}")]
                    ),
                    timestamp=start + i,
                    invocation_id=f"inv-{i}",
                ),
            )

        # A model event with usage metadata schedules a background task
        # (asserted on _active_tasks below).
        trigger = Event(
            author=app_name,
            content=Content(
                role="model", parts=[Part(text="trigger")]
            ),
            timestamp=start + 4,
            invocation_id="inv-4",
            usage_metadata=GenerateContentResponseUsageMetadata(
                total_token_count=200,
            ),
        )
        await compaction_service.append_event(session, trigger)
        assert len(compaction_service._active_tasks) > 0

        await compaction_service.close()
        assert len(compaction_service._active_tasks) == 0
|
||||||
428
tests/test_firestore_session_service.py
Normal file
428
tests/test_firestore_session_service.py
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
"""Tests for FirestoreSessionService against the Firestore emulator."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||||
|
from google.adk.events.event import Event
|
||||||
|
from google.adk.events.event_actions import EventActions
|
||||||
|
from google.adk.sessions.base_session_service import GetSessionConfig
|
||||||
|
from google.genai.types import Content, Part
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# create_session
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCreateSession:
    """Behavior of FirestoreSessionService.create_session."""

    async def test_auto_generates_id(self, service, app_name, user_id):
        """Omitting session_id yields a generated id and populated metadata."""
        created = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        assert created.id
        assert created.app_name == app_name
        assert created.user_id == user_id
        assert created.last_update_time > 0

    async def test_custom_id(self, service, app_name, user_id):
        """A caller-supplied session id is used verbatim."""
        created = await service.create_session(
            app_name=app_name,
            user_id=user_id,
            session_id="my-custom-session",
        )
        assert created.id == "my-custom-session"

    async def test_duplicate_id_raises(self, service, app_name, user_id):
        """Re-using an existing session id raises AlreadyExistsError."""
        sid = "dup-session"
        await service.create_session(
            app_name=app_name, user_id=user_id, session_id=sid
        )
        with pytest.raises(AlreadyExistsError):
            await service.create_session(
                app_name=app_name, user_id=user_id, session_id=sid
            )

    async def test_session_state(self, service, app_name, user_id):
        """Initial state passed at creation is visible on the session."""
        created = await service.create_session(
            app_name=app_name, user_id=user_id, state={"count": 42}
        )
        assert created.state["count"] == 42

    async def test_scoped_state(self, service, app_name, user_id):
        """app:/user:-prefixed and plain keys all appear in merged state."""
        created = await service.create_session(
            app_name=app_name,
            user_id=user_id,
            state={
                "app:global_flag": True,
                "user:lang": "es",
                "local_key": "val",
            },
        )
        assert created.state["app:global_flag"] is True
        assert created.state["user:lang"] == "es"
        assert created.state["local_key"] == "val"

    async def test_temp_state_not_persisted(self, service, app_name, user_id):
        """temp:-prefixed keys are not persisted; other keys survive a re-read."""
        created = await service.create_session(
            app_name=app_name,
            user_id=user_id,
            state={"temp:scratch": "gone", "keep": "yes"},
        )
        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=created.id
        )
        assert "temp:scratch" not in fetched.state
        assert fetched.state["keep"] == "yes"
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# get_session
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetSession:
    """Behavior of FirestoreSessionService.get_session."""

    async def test_nonexistent_returns_none(self, service, app_name, user_id):
        """An unknown session id yields None rather than an error."""
        result = await service.get_session(
            app_name=app_name, user_id=user_id, session_id="nope"
        )
        assert result is None

    async def test_roundtrip(self, service, app_name, user_id):
        """A created session comes back with the same id, state, and time."""
        created = await service.create_session(
            app_name=app_name, user_id=user_id, state={"foo": "bar"}
        )
        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=created.id
        )
        assert fetched is not None
        assert fetched.id == created.id
        assert fetched.state["foo"] == "bar"
        # Timestamps survive the Firestore round trip within 10ms.
        assert fetched.last_update_time == pytest.approx(
            created.last_update_time, abs=0.01
        )

    async def test_returns_events(self, service, app_name, user_id):
        """Appended events are returned on a subsequent get."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            session,
            Event(
                author="user",
                content=Content(parts=[Part(text="hello")]),
                timestamp=time.time(),
                invocation_id="inv-1",
            ),
        )

        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(fetched.events) == 1
        assert fetched.events[0].author == "user"

    async def test_num_recent_events(self, service, app_name, user_id):
        """GetSessionConfig.num_recent_events limits the returned events."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        for i in range(5):
            await service.append_event(
                session,
                Event(
                    author="user",
                    timestamp=time.time() + i,
                    invocation_id=f"inv-{i}",
                ),
            )

        fetched = await service.get_session(
            app_name=app_name,
            user_id=user_id,
            session_id=session.id,
            config=GetSessionConfig(num_recent_events=2),
        )
        assert len(fetched.events) == 2

    async def test_after_timestamp(self, service, app_name, user_id):
        """GetSessionConfig.after_timestamp filters out older events."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        base = time.time()
        for i in range(3):
            await service.append_event(
                session,
                Event(
                    author="user",
                    timestamp=base + i,
                    invocation_id=f"inv-{i}",
                ),
            )

        fetched = await service.get_session(
            app_name=app_name,
            user_id=user_id,
            session_id=session.id,
            config=GetSessionConfig(after_timestamp=base + 1),
        )
        assert len(fetched.events) == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# list_sessions
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestListSessions:
    """Behavior of FirestoreSessionService.list_sessions."""

    async def test_empty(self, service, app_name, user_id):
        """A user with no sessions yields an empty (or absent) list."""
        resp = await service.list_sessions(
            app_name=app_name, user_id=user_id
        )
        assert resp.sessions == [] or resp.sessions is None

    async def test_returns_created_sessions(
        self, service, app_name, user_id
    ):
        """Every created session for the user appears in the listing."""
        first = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        second = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        resp = await service.list_sessions(
            app_name=app_name, user_id=user_id
        )
        listed_ids = {s.id for s in resp.sessions}
        assert first.id in listed_ids
        assert second.id in listed_ids

    async def test_filter_by_user(self, service, app_name):
        """Only sessions belonging to the requested user are returned."""
        uid1 = f"user_{uuid.uuid4().hex[:8]}"
        uid2 = f"user_{uuid.uuid4().hex[:8]}"
        await service.create_session(app_name=app_name, user_id=uid1)
        await service.create_session(app_name=app_name, user_id=uid2)

        resp = await service.list_sessions(
            app_name=app_name, user_id=uid1
        )
        assert len(resp.sessions) == 1
        assert resp.sessions[0].user_id == uid1

    async def test_sessions_have_merged_state(
        self, service, app_name, user_id
    ):
        """Listed sessions carry both app-scoped and session-local state."""
        await service.create_session(
            app_name=app_name,
            user_id=user_id,
            state={"app:shared": "yes", "local": "val"},
        )
        resp = await service.list_sessions(
            app_name=app_name, user_id=user_id
        )
        listed = resp.sessions[0]
        assert listed.state["app:shared"] == "yes"
        assert listed.state["local"] == "val"

    async def test_sessions_have_no_events(
        self, service, app_name, user_id
    ):
        """Listing is shallow: per-session event histories are not loaded."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            session,
            Event(
                author="user", timestamp=time.time(), invocation_id="inv-1"
            ),
        )

        resp = await service.list_sessions(
            app_name=app_name, user_id=user_id
        )
        assert resp.sessions[0].events == []
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# delete_session
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteSession:
    """Behavior of FirestoreSessionService.delete_session."""

    async def test_delete(self, service, app_name, user_id):
        """A deleted session is no longer retrievable."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.delete_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        result = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert result is None

    async def test_delete_removes_events(self, service, app_name, user_id):
        """Deleting a session that has events still removes it entirely."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            session,
            Event(
                author="user", timestamp=time.time(), invocation_id="inv-1"
            ),
        )

        await service.delete_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        result = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# append_event
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAppendEvent:
    """Behavior of FirestoreSessionService.append_event."""

    async def _append_state_delta(self, service, app_name, user_id, delta):
        # Create a session, append one agent event carrying `delta`,
        # then return the re-fetched session so callers can assert on state.
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            session,
            Event(
                author="agent",
                actions=EventActions(state_delta=delta),
                timestamp=time.time(),
                invocation_id="inv-1",
            ),
        )
        return await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )

    async def test_basic(self, service, app_name, user_id):
        """append_event returns the same event with a valid timestamp."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        event = Event(
            author="user",
            content=Content(parts=[Part(text="hi")]),
            timestamp=time.time(),
            invocation_id="inv-1",
        )
        returned = await service.append_event(session, event)
        assert returned.id == event.id
        assert returned.timestamp > 0

    async def test_partial_event_not_persisted(
        self, service, app_name, user_id
    ):
        """Streaming partial events are not written to the event log."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            session,
            Event(
                author="user",
                partial=True,
                timestamp=time.time(),
                invocation_id="inv-1",
            ),
        )

        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(fetched.events) == 0

    async def test_session_state_delta(self, service, app_name, user_id):
        """A plain state_delta key lands in session-local state."""
        fetched = await self._append_state_delta(
            service, app_name, user_id, {"counter": 1}
        )
        assert fetched.state["counter"] == 1

    async def test_app_state_delta(self, service, app_name, user_id):
        """An app:-prefixed state_delta key lands in app-scoped state."""
        fetched = await self._append_state_delta(
            service, app_name, user_id, {"app:version": "2.0"}
        )
        assert fetched.state["app:version"] == "2.0"

    async def test_user_state_delta(self, service, app_name, user_id):
        """A user:-prefixed state_delta key lands in user-scoped state."""
        fetched = await self._append_state_delta(
            service, app_name, user_id, {"user:pref": "dark"}
        )
        assert fetched.state["user:pref"] == "dark"

    async def test_updates_last_update_time(
        self, service, app_name, user_id
    ):
        """Appending an event advances the session's last_update_time."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        before = session.last_update_time

        await service.append_event(
            session,
            Event(
                author="user",
                timestamp=time.time() + 10,
                invocation_id="inv-1",
            ),
        )

        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert fetched.last_update_time > before

    async def test_multiple_events_accumulate(
        self, service, app_name, user_id
    ):
        """Each appended event adds to the persisted history."""
        session = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        for i in range(3):
            await service.append_event(
                session,
                Event(
                    author="user",
                    content=Content(parts=[Part(text=f"msg {i}")]),
                    timestamp=time.time() + i,
                    invocation_id=f"inv-{i}",
                ),
            )

        fetched = await service.get_session(
            app_name=app_name, user_id=user_id, session_id=session.id
        )
        assert len(fetched.events) == 3

    async def test_app_state_shared_across_sessions(
        self, service, app_name, user_id
    ):
        """app:-scoped state written in one session is visible in a new one."""
        first = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        await service.append_event(
            first,
            Event(
                author="agent",
                actions=EventActions(state_delta={"app:shared_val": 99}),
                timestamp=time.time(),
                invocation_id="inv-1",
            ),
        )

        second = await service.create_session(
            app_name=app_name, user_id=user_id
        )
        assert second.state["app:shared_val"] == 99
|
||||||
Reference in New Issue
Block a user