Use filters instead of string formatting

This commit is contained in:
ajac-zero
2026-02-21 22:09:41 -06:00
parent 3cb78afc3a
commit 52aa8bfe0d
2 changed files with 6 additions and 5 deletions

View File

@@ -19,6 +19,7 @@ from google.adk.sessions.base_session_service import (
from google.adk.sessions.session import Session from google.adk.sessions.session import Session
from google.adk.sessions.state import State from google.adk.sessions.state import State
from google.cloud.firestore_v1.async_client import AsyncClient from google.cloud.firestore_v1.async_client import AsyncClient
from google.cloud.firestore_v1.base_query import FieldFilter
from google.cloud.firestore_v1.field_path import FieldPath from google.cloud.firestore_v1.field_path import FieldPath
from google.genai.types import Content, Part from google.genai.types import Content, Part
from typing_extensions import override from typing_extensions import override
@@ -289,7 +290,7 @@ class FirestoreSessionService(BaseSessionService):
events_ref = self._events_col(app_name, user_id, session_id) events_ref = self._events_col(app_name, user_id, session_id)
query = events_ref query = events_ref
if config and config.after_timestamp: if config and config.after_timestamp:
query = query.where("timestamp", ">=", config.after_timestamp) query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
query = query.order_by("timestamp") query = query.order_by("timestamp")
event_docs = await query.get() event_docs = await query.get()
@@ -360,10 +361,10 @@ class FirestoreSessionService(BaseSessionService):
self, *, app_name: str, user_id: Optional[str] = None self, *, app_name: str, user_id: Optional[str] = None
) -> ListSessionsResponse: ) -> ListSessionsResponse:
query = self._db.collection(f"{self._prefix}_sessions").where( query = self._db.collection(f"{self._prefix}_sessions").where(
"app_name", "==", app_name filter=FieldFilter("app_name", "==", app_name)
) )
if user_id is not None: if user_id is not None:
query = query.where("user_id", "==", user_id) query = query.where(filter=FieldFilter("user_id", "==", user_id))
docs = await query.get() docs = await query.get()
if not docs: if not docs:

View File

@@ -247,7 +247,7 @@ class TestGetSessionWithSummary:
class TestEventsToText: class TestEventsToText:
def test_formats_user_and_assistant(self): async def test_formats_user_and_assistant(self):
events = [ events = [
Event( Event(
author="user", author="user",
@@ -270,7 +270,7 @@ class TestEventsToText:
assert "User: Hi there" in text assert "User: Hi there" in text
assert "Assistant: Hello!" in text assert "Assistant: Hello!" in text
def test_skips_events_without_text(self): async def test_skips_events_without_text(self):
events = [ events = [
Event( Event(
author="user", author="user",