"""In-memory fake of the Firestore async surface used by this project.

Covers: AsyncClient, DocumentReference, CollectionReference, Query,
DocumentSnapshot, WriteBatch, and basic transaction support (enough for
``@async_transactional``).

All state lives in one :class:`FakeStore` (a flat dict keyed by document
path). Data is deep-copied on the way in and on the way out, so callers can
never mutate stored state through an object they still hold.
"""

from __future__ import annotations

import copy
import operator
import uuid
from collections.abc import AsyncIterator
from typing import Any

# Sentinel meaning "the field path does not resolve to a value" (distinct from None).
_MISSING: Any = object()

# Kinds of staged writes understood by ``_apply_writes``.
_SET = "set"
_UPDATE = "update"
_DELETE = "delete"


# ------------------------------------------------------------------ #
# DocumentSnapshot
# ------------------------------------------------------------------ #


class FakeDocumentSnapshot:
    """Point-in-time view of one document, as returned by ``get()``."""

    def __init__(self, *, exists: bool, data: dict[str, Any] | None, reference: FakeDocumentReference) -> None:
        self._exists = exists
        self._data = data
        self._reference = reference

    @property
    def exists(self) -> bool:
        return self._exists

    @property
    def reference(self) -> FakeDocumentReference:
        return self._reference

    @property
    def id(self) -> str:
        """Last path segment of the document (its ID within its collection)."""
        return self._reference.id

    def to_dict(self) -> dict[str, Any] | None:
        """Return a deep copy of the document data, or ``None`` if it does not exist."""
        if not self._exists:
            return None
        return copy.deepcopy(self._data)

    def get(self, field_path: str) -> Any:
        """Return a deep copy of the value at *field_path* (dotted, may be backtick-quoted).

        Returns ``None`` when the document does not exist and raises
        ``KeyError`` when the document exists but the field does not.
        """
        if not self._exists or self._data is None:
            return None
        value = _get_field(self._data, field_path)
        if value is _MISSING:
            raise KeyError(field_path)
        return copy.deepcopy(value)


# ------------------------------------------------------------------ #
# DocumentReference
# ------------------------------------------------------------------ #


class FakeDocumentReference:
    """Handle to a single document path; every call goes straight to the store."""

    def __init__(self, store: FakeStore, path: str) -> None:
        self._store = store
        self._path = path

    @property
    def path(self) -> str:
        return self._path

    @property
    def id(self) -> str:
        """Last path segment (the document ID)."""
        return self._path.rsplit("/", 1)[-1]

    # --- read ---

    async def get(self, *, transaction: FakeTransaction | None = None) -> FakeDocumentSnapshot:
        """Read the document. *transaction* is accepted for API parity and ignored."""
        data = self._store.get_doc(self._path)
        if data is None:
            return FakeDocumentSnapshot(exists=False, data=None, reference=self)
        return FakeDocumentSnapshot(exists=True, data=copy.deepcopy(data), reference=self)

    # --- write ---

    async def set(self, document_data: dict[str, Any], merge: bool = False) -> None:
        """Overwrite the document, or deep-merge *document_data* into it when *merge* is true.

        With ``merge=True`` nested maps are merged key by key rather than
        replaced wholesale; a missing document is created.
        """
        _apply_writes(self._store, [(_SET, self._path, (document_data, merge))])

    async def update(self, field_updates: dict[str, Any]) -> None:
        """Apply dotted field-path updates (``{"state.counter": 1}``) to an existing document.

        Raises:
            ValueError: if the document does not exist.
        """
        _apply_writes(self._store, [(_UPDATE, self._path, field_updates)])

    async def delete(self) -> None:
        """Delete this document only (subcollection documents are left in place)."""
        self._store.delete_doc(self._path)

    # --- subcollection ---

    def collection(self, subcollection_name: str) -> FakeCollectionReference:
        return FakeCollectionReference(self._store, f"{self._path}/{subcollection_name}")


# ------------------------------------------------------------------ #
# Field-path helpers ("state.counter" → data["state"]["counter"])
# ------------------------------------------------------------------ #


def _split_field_path(field_path: str) -> list[str]:
    """Split a dotted Firestore field path into its segments.

    Dots inside a backtick-quoted segment are literal, so ``"`a.b`.c"`` splits
    into ``["a.b", "c"]``. Inside backticks a backslash escapes the next
    character. The backticks themselves are not part of the segment names.
    """
    segments: list[str] = []
    current: list[str] = []
    quoted = False
    escaping = False
    for char in field_path:
        if escaping:
            current.append(char)
            escaping = False
        elif quoted and char == "\\":
            escaping = True
        elif char == "`":
            quoted = not quoted
        elif char == "." and not quoted:
            segments.append("".join(current))
            current = []
        else:
            current.append(char)
    segments.append("".join(current))
    return segments


def _nested_set(data: dict[str, Any], dotted_key: str, value: Any) -> None:
    """Assign *value* at *dotted_key* inside *data*, creating maps as needed.

    An intermediate segment that is missing, or that holds a non-map value, is
    replaced by an empty map instead of raising.
    """
    *parents, leaf = _split_field_path(dotted_key)
    node = data
    for segment in parents:
        child = node.get(segment)
        if not isinstance(child, dict):
            child = {}
            node[segment] = child
        node = child
    node[leaf] = value


def _get_field(data: dict[str, Any], field_path: str) -> Any:
    """Resolve *field_path* in *data*; return ``_MISSING`` if it does not resolve.

    A top-level key spelled exactly like *field_path* wins, which preserves the
    fake's earlier plain ``data.get(field)`` lookups. Otherwise the path is
    split on dots and walked through nested maps.
    """
    if field_path in data:
        return data[field_path]
    node: Any = data
    for segment in _split_field_path(field_path):
        if not isinstance(node, dict) or segment not in node:
            return _MISSING
        node = node[segment]
    return node


def _deep_merge(target: dict[str, Any], patch: dict[str, Any]) -> None:
    """Merge *patch* into *target* in place, recursing into non-empty nested maps.

    Leaf values (including empty dicts) are deep-copied so *target* never
    aliases objects owned by the caller.
    """
    for key, value in patch.items():
        existing = target.get(key)
        if isinstance(value, dict) and value and isinstance(existing, dict):
            _deep_merge(existing, value)
        else:
            target[key] = copy.deepcopy(value)


def _apply_writes(store: FakeStore, writes: list[tuple[str, str, Any]]) -> None:
    """Apply staged ``(kind, path, payload)`` writes to *store* all-or-nothing.

    Every write is first applied to private deep copies. The store is touched
    only after all writes validate, so a failing write leaves it unchanged.

    Raises:
        ValueError: an ``update`` targets a document that does not exist (at
            that point in the sequence), or the write kind is unknown.
    """
    pending: dict[str, dict[str, Any] | None] = {}

    def working_copy(path: str) -> dict[str, Any] | None:
        # Earlier writes in this same sequence take precedence over the store.
        if path in pending:
            return pending[path]
        stored = store.get_doc(path)
        return None if stored is None else copy.deepcopy(stored)

    for kind, path, payload in writes:
        if kind == _DELETE:
            pending[path] = None
        elif kind == _SET:
            document_data, merge = payload
            if merge:
                merged = working_copy(path) or {}
                _deep_merge(merged, document_data)
                pending[path] = merged
            else:
                pending[path] = copy.deepcopy(document_data)
        elif kind == _UPDATE:
            data = working_copy(path)
            if data is None:
                msg = f"Document {path} does not exist"
                raise ValueError(msg)
            for key, value in payload.items():
                _nested_set(data, key, copy.deepcopy(value))
            pending[path] = data
        else:
            msg = f"Unknown write kind: {kind!r}"
            raise ValueError(msg)

    for path, data in pending.items():
        if data is None:
            store.delete_doc(path)
        else:
            store.set_doc(path, data)


# ------------------------------------------------------------------ #
# Query
# ------------------------------------------------------------------ #


class FakeQuery:
    """Query over the direct children of one collection.

    Supports chained ``.where()`` / ``.order_by()`` / ``.limit()`` and the
    terminal ``.get()`` / ``.stream()``. Every builder returns a new query and
    leaves the receiver untouched.
    """

    ASCENDING = "ASCENDING"
    DESCENDING = "DESCENDING"

    def __init__(self, store: FakeStore, collection_path: str) -> None:
        self._store = store
        self._collection_path = collection_path
        self._filters: list[tuple[str, str, Any]] = []
        self._order_by_field: str | None = None
        self._order_descending = False
        self._limit: int | None = None

    def _clone(self) -> FakeQuery:
        """Return an independent copy of this query's state."""
        clone = FakeQuery(self._store, self._collection_path)
        clone._filters = list(self._filters)
        clone._order_by_field = self._order_by_field
        clone._order_descending = self._order_descending
        clone._limit = self._limit
        return clone

    def where(
        self,
        field_path: str | None = None,
        op_string: str | None = None,
        value: Any = None,
        *,
        filter: Any = None,  # noqa: A002
    ) -> FakeQuery:
        """Add a filter given either as ``filter=FieldFilter(...)`` or positionally.

        *filter* only needs ``field_path``, ``op_string`` and ``value`` attributes.
        """
        if filter is not None:
            condition = (filter.field_path, filter.op_string, filter.value)
        elif field_path is not None and op_string is not None:
            condition = (field_path, op_string, value)
        else:
            msg = "where() needs either filter=... or field_path and op_string"
            raise ValueError(msg)
        clone = self._clone()
        clone._filters.append(condition)
        return clone

    def order_by(self, field_path: str, direction: str = ASCENDING) -> FakeQuery:
        """Sort results by *field_path*; ``direction="DESCENDING"`` reverses the order.

        Only a single sort key is kept; a later call replaces an earlier one.
        """
        clone = self._clone()
        clone._order_by_field = field_path
        clone._order_descending = str(direction).upper() == self.DESCENDING
        return clone

    def limit(self, count: int) -> FakeQuery:
        """Return at most *count* results (applied after filtering and sorting)."""
        clone = self._clone()
        clone._limit = count
        return clone

    async def get(self, *, transaction: FakeTransaction | None = None) -> list[FakeDocumentSnapshot]:
        """Run the query. *transaction* is accepted for API parity and ignored."""
        results = [
            (path, data)
            for path, data in self._store.list_collection(self._collection_path)
            if all(_match(data, field, op, value) for field, op, value in self._filters)
        ]
        if self._order_by_field:
            field = self._order_by_field
            results.sort(key=lambda item: _order_key(item[1], field), reverse=self._order_descending)
        if self._limit is not None:
            results = results[: self._limit]
        return [
            FakeDocumentSnapshot(
                exists=True,
                data=copy.deepcopy(data),
                reference=FakeDocumentReference(self._store, path),
            )
            for path, data in results
        ]

    async def stream(self, *, transaction: FakeTransaction | None = None) -> AsyncIterator[FakeDocumentSnapshot]:
        """Async-iterate over the same snapshots ``get()`` would return."""
        for snapshot in await self.get(transaction=transaction):
            yield snapshot


# Range comparisons: a missing/None field never satisfies them.
_RANGE_OPS = {"<": operator.lt, "<=": operator.le, ">": operator.gt, ">=": operator.ge}


def _order_key(data: dict[str, Any], field: str) -> Any:
    """Sort key for ``order_by``.

    NOTE(review): real Firestore drops documents lacking the order_by field.
    This fake keeps them and sorts them as 0, preserved for backward
    compatibility.
    """
    value = _get_field(data, field)
    return 0 if value is _MISSING else value


def _match(data: dict[str, Any], field: str, op: str, value: Any) -> bool:
    """Return whether document *data* satisfies the filter ``field op value``.

    ``==`` treats a missing field as ``None`` (the fake's historic behaviour).
    ``!=``, ``in`` and ``not-in`` require the field to exist. Values that
    cannot be compared (``TypeError``) never match.

    Raises:
        ValueError: for an operator this fake does not understand.
    """
    found = _get_field(data, field)
    exists = found is not _MISSING
    doc_val = found if exists else None
    try:
        if op == "==":
            return doc_val == value
        if op in _RANGE_OPS:
            return doc_val is not None and _RANGE_OPS[op](doc_val, value)
        if op == "!=":
            return exists and doc_val != value
        if op == "in":
            return exists and doc_val in value
        if op == "not-in":
            return exists and doc_val not in value
        if op in ("array_contains", "array-contains"):
            return isinstance(doc_val, list) and value in doc_val
        if op in ("array_contains_any", "array-contains-any"):
            return isinstance(doc_val, list) and any(item in doc_val for item in value)
    except TypeError:
        # Mixed-type comparison (e.g. str < int): treat as "no match".
        return False
    msg = f"Unsupported query operator: {op!r}"
    raise ValueError(msg)


# ------------------------------------------------------------------ #
# CollectionReference (extends Query behaviour)
# ------------------------------------------------------------------ #


class FakeCollectionReference(FakeQuery):
    """A collection path; querying it returns its direct child documents."""

    @property
    def id(self) -> str:
        """Last path segment (the collection ID)."""
        return self._collection_path.rsplit("/", 1)[-1]

    def document(self, document_id: str | None = None) -> FakeDocumentReference:
        """Reference a child document; a random 20-character ID is generated when omitted."""
        if document_id is None:
            document_id = uuid.uuid4().hex[:20]
        return FakeDocumentReference(self._store, f"{self._collection_path}/{document_id}")


# ------------------------------------------------------------------ #
# WriteBatch
# ------------------------------------------------------------------ #


class FakeWriteBatch:
    """Collects set/update/delete writes and applies them atomically on ``commit()``."""

    def __init__(self, store: FakeStore) -> None:
        self._store = store
        self._writes: list[tuple[str, str, Any]] = []

    def set(self, doc_ref: FakeDocumentReference, document_data: dict[str, Any], merge: bool = False) -> None:
        self._writes.append((_SET, doc_ref.path, (copy.deepcopy(document_data), merge)))

    def update(self, doc_ref: FakeDocumentReference, field_updates: dict[str, Any]) -> None:
        self._writes.append((_UPDATE, doc_ref.path, copy.deepcopy(field_updates)))

    def delete(self, doc_ref: FakeDocumentReference) -> None:
        self._writes.append((_DELETE, doc_ref.path, None))

    async def commit(self) -> None:
        """Apply all staged writes (all-or-nothing) and empty the batch.

        Raises:
            ValueError: if a staged ``update`` targets a missing document.
        """
        writes, self._writes = self._writes, []
        _apply_writes(self._store, writes)


# ------------------------------------------------------------------ #
# Transaction (minimal, supports @async_transactional)
# ------------------------------------------------------------------ #


class FakeTransaction:
    """Minimal transaction compatible with ``@async_transactional``.

    The decorator calls ``_clean_up()``, ``_begin()``, the wrapped function,
    then ``_commit()``. On error it calls ``_rollback()``. ``in_progress`` is a
    property that checks ``_id is not None``.

    Writes (``set`` / ``update`` / ``delete``) are buffered and applied
    all-or-nothing on ``_commit()``; reads are not isolated.
    """

    def __init__(self, store: FakeStore) -> None:
        self._store = store
        self._staged_writes: list[tuple[str, str, Any]] = []
        self._id: bytes | None = None
        self._max_attempts = 1
        self._read_only = False

    @property
    def in_progress(self) -> bool:
        return self._id is not None

    def _clean_up(self) -> None:
        self._id = None

    async def _begin(self, retry_id: bytes | None = None) -> None:
        self._id = b"fake-txn"

    async def _commit(self) -> list:
        """Apply every staged write atomically, then end the transaction.

        Raises:
            ValueError: if a staged ``update`` targets a missing document. In
                that case nothing is written and the transaction is still
                cleaned up.
        """
        writes, self._staged_writes = self._staged_writes, []
        try:
            _apply_writes(self._store, writes)
        finally:
            self._clean_up()
        return []

    async def _rollback(self) -> None:
        self._staged_writes.clear()
        self._clean_up()

    def set(self, doc_ref: FakeDocumentReference, document_data: dict[str, Any], merge: bool = False) -> None:
        self._staged_writes.append((_SET, doc_ref.path, (copy.deepcopy(document_data), merge)))

    def update(self, doc_ref: FakeDocumentReference, field_updates: dict[str, Any]) -> None:
        self._staged_writes.append((_UPDATE, doc_ref.path, copy.deepcopy(field_updates)))

    def delete(self, doc_ref: FakeDocumentReference) -> None:
        self._staged_writes.append((_DELETE, doc_ref.path, None))


# ------------------------------------------------------------------ #
# Document store (flat dict keyed by path)
# ------------------------------------------------------------------ #


class FakeStore:
    """Flat ``{document_path: data}`` dict backing every fake object."""

    def __init__(self) -> None:
        self._docs: dict[str, dict[str, Any]] = {}

    def get_doc(self, path: str) -> dict[str, Any] | None:
        """Return the stored dict for *path* (NOT a copy), or ``None``.

        Callers that expose data outside the fake must deep-copy it themselves.
        """
        return self._docs.get(path)

    def set_doc(self, path: str, data: dict[str, Any]) -> None:
        self._docs[path] = data

    def delete_doc(self, path: str) -> None:
        self._docs.pop(path, None)

    def list_collection(self, collection_path: str) -> list[tuple[str, dict[str, Any]]]:
        """Return ``(path, data)`` for every direct child document of *collection_path*.

        Documents in nested subcollections (e.g. ``users/u1/posts/p1`` when
        listing ``users``) are excluded because their path contains a further
        ``/``. Results follow the insertion order of the underlying dict.
        """
        prefix = f"{collection_path}/"
        return [
            (doc_path, data)
            for doc_path, data in self._docs.items()
            if doc_path.startswith(prefix) and "/" not in doc_path[len(prefix):]
        ]

    def recursive_delete(self, path: str) -> None:
        """Delete a document and everything nested under it."""
        to_delete = [p for p in self._docs if p == path or p.startswith(f"{path}/")]
        for p in to_delete:
            del self._docs[p]


# ------------------------------------------------------------------ #
# FakeAsyncClient (drop-in for AsyncClient)
# ------------------------------------------------------------------ #


class FakeAsyncClient:
    """Drop-in for ``AsyncClient``; any constructor keyword arguments are accepted and ignored."""

    def __init__(self, **_kwargs: Any) -> None:
        self._store = FakeStore()

    def collection(self, collection_path: str) -> FakeCollectionReference:
        return FakeCollectionReference(self._store, collection_path)

    def document(self, *document_path: str) -> FakeDocumentReference:
        """Reference the document at ``"/".join(document_path)`` (e.g. ``"users/u1"``)."""
        return FakeDocumentReference(self._store, "/".join(document_path))

    def batch(self) -> FakeWriteBatch:
        return FakeWriteBatch(self._store)

    def transaction(self, **kwargs: Any) -> FakeTransaction:
        return FakeTransaction(self._store)

    async def recursive_delete(self, doc_ref: FakeDocumentReference) -> None:
        self._store.recursive_delete(doc_ref.path)