Use filters instead of string formatting
This commit is contained in:
@@ -19,6 +19,7 @@ from google.adk.sessions.base_session_service import (
|
|||||||
from google.adk.sessions.session import Session
|
from google.adk.sessions.session import Session
|
||||||
from google.adk.sessions.state import State
|
from google.adk.sessions.state import State
|
||||||
from google.cloud.firestore_v1.async_client import AsyncClient
|
from google.cloud.firestore_v1.async_client import AsyncClient
|
||||||
|
from google.cloud.firestore_v1.base_query import FieldFilter
|
||||||
from google.cloud.firestore_v1.field_path import FieldPath
|
from google.cloud.firestore_v1.field_path import FieldPath
|
||||||
from google.genai.types import Content, Part
|
from google.genai.types import Content, Part
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
@@ -289,7 +290,7 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
events_ref = self._events_col(app_name, user_id, session_id)
|
events_ref = self._events_col(app_name, user_id, session_id)
|
||||||
query = events_ref
|
query = events_ref
|
||||||
if config and config.after_timestamp:
|
if config and config.after_timestamp:
|
||||||
query = query.where("timestamp", ">=", config.after_timestamp)
|
query = query.where(filter=FieldFilter("timestamp", ">=", config.after_timestamp))
|
||||||
query = query.order_by("timestamp")
|
query = query.order_by("timestamp")
|
||||||
|
|
||||||
event_docs = await query.get()
|
event_docs = await query.get()
|
||||||
@@ -360,10 +361,10 @@ class FirestoreSessionService(BaseSessionService):
|
|||||||
self, *, app_name: str, user_id: Optional[str] = None
|
self, *, app_name: str, user_id: Optional[str] = None
|
||||||
) -> ListSessionsResponse:
|
) -> ListSessionsResponse:
|
||||||
query = self._db.collection(f"{self._prefix}_sessions").where(
|
query = self._db.collection(f"{self._prefix}_sessions").where(
|
||||||
"app_name", "==", app_name
|
filter=FieldFilter("app_name", "==", app_name)
|
||||||
)
|
)
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
query = query.where("user_id", "==", user_id)
|
query = query.where(filter=FieldFilter("user_id", "==", user_id))
|
||||||
|
|
||||||
docs = await query.get()
|
docs = await query.get()
|
||||||
if not docs:
|
if not docs:
|
||||||
|
|||||||
@@ -247,7 +247,7 @@ class TestGetSessionWithSummary:
|
|||||||
|
|
||||||
|
|
||||||
class TestEventsToText:
|
class TestEventsToText:
|
||||||
def test_formats_user_and_assistant(self):
|
async def test_formats_user_and_assistant(self):
|
||||||
events = [
|
events = [
|
||||||
Event(
|
Event(
|
||||||
author="user",
|
author="user",
|
||||||
@@ -270,7 +270,7 @@ class TestEventsToText:
|
|||||||
assert "User: Hi there" in text
|
assert "User: Hi there" in text
|
||||||
assert "Assistant: Hello!" in text
|
assert "Assistant: Hello!" in text
|
||||||
|
|
||||||
def test_skips_events_without_text(self):
|
async def test_skips_events_without_text(self):
|
||||||
events = [
|
events = [
|
||||||
Event(
|
Event(
|
||||||
author="user",
|
author="user",
|
||||||
|
|||||||
Reference in New Issue
Block a user