From 5941c41296b5712c7212d6191173fd8f79966ce4 Mon Sep 17 00:00:00 2001 From: Anibal Angulo Date: Thu, 5 Mar 2026 05:55:09 +0000 Subject: [PATCH] Remove firestore emulator from test dependencies --- tests/conftest.py | 8 +- tests/fake_firestore.py | 284 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 287 insertions(+), 5 deletions(-) create mode 100644 tests/fake_firestore.py diff --git a/tests/conftest.py b/tests/conftest.py index 959b677..2c1c3ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,25 +2,23 @@ from __future__ import annotations -import os import uuid import pytest import pytest_asyncio -from google.cloud.firestore_v1.async_client import AsyncClient from va_agent.session import FirestoreSessionService -os.environ.setdefault("FIRESTORE_EMULATOR_HOST", "localhost:8602") +from .fake_firestore import FakeAsyncClient @pytest_asyncio.fixture async def db(): - return AsyncClient(project="test-project") + return FakeAsyncClient() @pytest_asyncio.fixture -async def service(db: AsyncClient): +async def service(db): prefix = f"test_{uuid.uuid4().hex[:8]}" return FirestoreSessionService(db=db, collection_prefix=prefix) diff --git a/tests/fake_firestore.py b/tests/fake_firestore.py new file mode 100644 index 0000000..f7fce9b --- /dev/null +++ b/tests/fake_firestore.py @@ -0,0 +1,284 @@ +"""In-memory fake of the Firestore async surface used by this project. + +Covers: AsyncClient, DocumentReference, CollectionReference, Query, +DocumentSnapshot, WriteBatch, and basic transaction support (enough for +``@async_transactional``). +""" + +from __future__ import annotations + +import copy +from typing import Any + + +# ------------------------------------------------------------------ # +# DocumentSnapshot +# ------------------------------------------------------------------ # + +class FakeDocumentSnapshot: + def __init__(self, *, exists: bool, data: dict[str, Any] | None, reference: FakeDocumentReference) -> None: + self._exists = exists + self._data = data + self._reference = reference + + @property + def exists(self) -> bool: + return self._exists + + @property + def reference(self) -> FakeDocumentReference: + return self._reference + + def to_dict(self) -> dict[str, Any] | None: + if not self._exists: + return None + return copy.deepcopy(self._data) + + +# ------------------------------------------------------------------ # +# DocumentReference +# ------------------------------------------------------------------ # + +class FakeDocumentReference: + def __init__(self, store: FakeStore, path: str) -> None: + self._store = store + self._path = path + + @property + def path(self) -> str: + return self._path + + # --- read --- + + async def get(self, *, transaction: FakeTransaction | None = None) -> FakeDocumentSnapshot: + data = self._store.get_doc(self._path) + if data is None: + return FakeDocumentSnapshot(exists=False, data=None, reference=self) + return FakeDocumentSnapshot(exists=True, data=copy.deepcopy(data), reference=self) + + # --- write --- + + async def set(self, document_data: dict[str, Any], merge: bool = False) -> None: + if merge: + existing = self._store.get_doc(self._path) or {} + existing.update(document_data) + self._store.set_doc(self._path, existing) + else: + self._store.set_doc(self._path, copy.deepcopy(document_data)) + + async def update(self, field_updates: dict[str, Any]) -> None: + data = self._store.get_doc(self._path) + if data is None: + msg = f"Document {self._path} does not exist" + raise ValueError(msg) + for key, value in field_updates.items(): + _nested_set(data, key, value) + self._store.set_doc(self._path, data) + + # --- subcollection --- + + def collection(self, subcollection_name: str) -> FakeCollectionReference: + return FakeCollectionReference(self._store, f"{self._path}/{subcollection_name}") + + +# ------------------------------------------------------------------ # +# Helpers for nested field-path updates ("state.counter" → data["state"]["counter"]) +# ------------------------------------------------------------------ # + +def _nested_set(data: dict[str, Any], dotted_key: str, value: Any) -> None: + parts = dotted_key.split(".") + for part in parts[:-1]: + # Backtick-quoted segments (Firestore FieldPath encoding) + part = part.strip("`") + data = data.setdefault(part, {}) + final = parts[-1].strip("`") + data[final] = value + + +# ------------------------------------------------------------------ # +# Query +# ------------------------------------------------------------------ # + +class FakeQuery: + """Supports chained .where() / .order_by() / .get().""" + + def __init__(self, store: FakeStore, collection_path: str) -> None: + self._store = store + self._collection_path = collection_path + self._filters: list[tuple[str, str, Any]] = [] + self._order_by_field: str | None = None + + def where(self, *, filter: Any) -> FakeQuery: # noqa: A002 + clone = FakeQuery(self._store, self._collection_path) + clone._filters = [*self._filters, (filter.field_path, filter.op_string, filter.value)] + clone._order_by_field = self._order_by_field + return clone + + def order_by(self, field_path: str) -> FakeQuery: + clone = FakeQuery(self._store, self._collection_path) + clone._filters = list(self._filters) + clone._order_by_field = field_path + return clone + + async def get(self) -> list[FakeDocumentSnapshot]: + docs = self._store.list_collection(self._collection_path) + results: list[tuple[str, dict[str, Any]]] = [] + + for doc_path, data in docs: + if all(_match(data, field, op, val) for field, op, val in self._filters): + results.append((doc_path, data)) + + if self._order_by_field: + field = self._order_by_field + results.sort(key=lambda item: item[1].get(field, 0)) + + return [ + FakeDocumentSnapshot( + exists=True, + data=copy.deepcopy(data), + reference=FakeDocumentReference(self._store, path), + ) + for path, data in results + ] + + +def _match(data: dict[str, Any], field: str, op: str, value: Any) -> bool: + doc_val = data.get(field) + if op == "==": + return doc_val == value + if op == ">=": + return doc_val is not None and doc_val >= value + return False + + +# ------------------------------------------------------------------ # +# CollectionReference (extends Query behaviour) +# ------------------------------------------------------------------ # + +class FakeCollectionReference(FakeQuery): + def document(self, document_id: str) -> FakeDocumentReference: + return FakeDocumentReference(self._store, f"{self._collection_path}/{document_id}") + + +# ------------------------------------------------------------------ # +# WriteBatch +# ------------------------------------------------------------------ # + +class FakeWriteBatch: + def __init__(self, store: FakeStore) -> None: + self._store = store + self._deletes: list[str] = [] + + def delete(self, doc_ref: FakeDocumentReference) -> None: + self._deletes.append(doc_ref.path) + + async def commit(self) -> None: + for path in self._deletes: + self._store.delete_doc(path) + + +# ------------------------------------------------------------------ # +# Transaction (minimal, supports @async_transactional) +# ------------------------------------------------------------------ # + +class FakeTransaction: + """Minimal transaction compatible with ``@async_transactional``. + + The decorator calls ``_clean_up()``, ``_begin()``, the wrapped function, + then ``_commit()``. On error it calls ``_rollback()``. + ``in_progress`` is a property that checks ``_id is not None``. + """ + + def __init__(self, store: FakeStore) -> None: + self._store = store + self._staged_updates: list[tuple[str, dict[str, Any]]] = [] + self._id: bytes | None = None + self._max_attempts = 1 + self._read_only = False + + @property + def in_progress(self) -> bool: + return self._id is not None + + def _clean_up(self) -> None: + self._id = None + + async def _begin(self, retry_id: bytes | None = None) -> None: + self._id = b"fake-txn" + + async def _commit(self) -> list: + for path, updates in self._staged_updates: + data = self._store.get_doc(path) + if data is not None: + for key, value in updates.items(): + _nested_set(data, key, value) + self._store.set_doc(path, data) + self._staged_updates.clear() + self._clean_up() + return [] + + async def _rollback(self) -> None: + self._staged_updates.clear() + self._clean_up() + + def update(self, doc_ref: FakeDocumentReference, field_updates: dict[str, Any]) -> None: + self._staged_updates.append((doc_ref.path, field_updates)) + + +# ------------------------------------------------------------------ # +# Document store (flat dict keyed by path) +# ------------------------------------------------------------------ # + +class FakeStore: + def __init__(self) -> None: + self._docs: dict[str, dict[str, Any]] = {} + + def get_doc(self, path: str) -> dict[str, Any] | None: + data = self._docs.get(path) + return data # returns reference, callers deepcopy where needed + + def set_doc(self, path: str, data: dict[str, Any]) -> None: + self._docs[path] = data + + def delete_doc(self, path: str) -> None: + self._docs.pop(path, None) + + def list_collection(self, collection_path: str) -> list[tuple[str, dict[str, Any]]]: + """Return (path, data) for every direct child doc of *collection_path*.""" + prefix = collection_path + "/" + results: list[tuple[str, dict[str, Any]]] = [] + for doc_path, data in self._docs.items(): + if not doc_path.startswith(prefix): + continue + # Must be a direct child (no further '/' after the prefix, except maybe subcollection paths) + remainder = doc_path[len(prefix):] + if "/" not in remainder: + results.append((doc_path, data)) + return results + + def recursive_delete(self, path: str) -> None: + """Delete a document and everything nested under it.""" + to_delete = [p for p in self._docs if p == path or p.startswith(path + "/")] + for p in to_delete: + del self._docs[p] + + +# ------------------------------------------------------------------ # +# FakeAsyncClient (drop-in for AsyncClient) +# ------------------------------------------------------------------ # + +class FakeAsyncClient: + def __init__(self, **_kwargs: Any) -> None: + self._store = FakeStore() + + def collection(self, collection_path: str) -> FakeCollectionReference: + return FakeCollectionReference(self._store, collection_path) + + def batch(self) -> FakeWriteBatch: + return FakeWriteBatch(self._store) + + def transaction(self, **kwargs: Any) -> FakeTransaction: + return FakeTransaction(self._store) + + async def recursive_delete(self, doc_ref: FakeDocumentReference) -> None: + self._store.recursive_delete(doc_ref.path)