Add testing
This commit is contained in:
430
tests/test_firestore_session_service.py
Normal file
430
tests/test_firestore_session_service.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""Tests for FirestoreSessionService against the Firestore emulator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from google.adk.errors.already_exists_error import AlreadyExistsError
|
||||
from google.adk.events.event import Event
|
||||
from google.adk.events.event_actions import EventActions
|
||||
from google.adk.sessions.base_session_service import GetSessionConfig
|
||||
from google.genai.types import Content, Part
|
||||
|
||||
from adk_firestore_sessionmanager import FirestoreSessionService
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# create_session
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateSession:
|
||||
async def test_auto_generates_id(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
assert session.id
|
||||
assert session.app_name == app_name
|
||||
assert session.user_id == user_id
|
||||
assert session.last_update_time > 0
|
||||
|
||||
async def test_custom_id(self, service, app_name, user_id):
|
||||
sid = "my-custom-session"
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=sid
|
||||
)
|
||||
assert session.id == sid
|
||||
|
||||
async def test_duplicate_id_raises(self, service, app_name, user_id):
|
||||
sid = "dup-session"
|
||||
await service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=sid
|
||||
)
|
||||
with pytest.raises(AlreadyExistsError):
|
||||
await service.create_session(
|
||||
app_name=app_name, user_id=user_id, session_id=sid
|
||||
)
|
||||
|
||||
async def test_session_state(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state={"count": 42},
|
||||
)
|
||||
assert session.state["count"] == 42
|
||||
|
||||
async def test_scoped_state(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state={
|
||||
"app:global_flag": True,
|
||||
"user:lang": "es",
|
||||
"local_key": "val",
|
||||
},
|
||||
)
|
||||
assert session.state["app:global_flag"] is True
|
||||
assert session.state["user:lang"] == "es"
|
||||
assert session.state["local_key"] == "val"
|
||||
|
||||
async def test_temp_state_not_persisted(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state={"temp:scratch": "gone", "keep": "yes"},
|
||||
)
|
||||
retrieved = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert "temp:scratch" not in retrieved.state
|
||||
assert retrieved.state["keep"] == "yes"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# get_session
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSession:
|
||||
async def test_nonexistent_returns_none(self, service, app_name, user_id):
|
||||
result = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id="nope"
|
||||
)
|
||||
assert result is None
|
||||
|
||||
async def test_roundtrip(self, service, app_name, user_id):
|
||||
created = await service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state={"foo": "bar"},
|
||||
)
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=created.id
|
||||
)
|
||||
assert fetched is not None
|
||||
assert fetched.id == created.id
|
||||
assert fetched.state["foo"] == "bar"
|
||||
assert fetched.last_update_time == pytest.approx(
|
||||
created.last_update_time, abs=0.01
|
||||
)
|
||||
|
||||
async def test_returns_events(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="user",
|
||||
content=Content(parts=[Part(text="hello")]),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert len(fetched.events) == 1
|
||||
assert fetched.events[0].author == "user"
|
||||
|
||||
async def test_num_recent_events(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(5):
|
||||
e = Event(
|
||||
author="user",
|
||||
timestamp=time.time() + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await service.append_event(session, e)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session.id,
|
||||
config=GetSessionConfig(num_recent_events=2),
|
||||
)
|
||||
assert len(fetched.events) == 2
|
||||
|
||||
async def test_after_timestamp(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
base = time.time()
|
||||
for i in range(3):
|
||||
e = Event(
|
||||
author="user",
|
||||
timestamp=base + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await service.append_event(session, e)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
session_id=session.id,
|
||||
config=GetSessionConfig(after_timestamp=base + 1),
|
||||
)
|
||||
assert len(fetched.events) == 2
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# list_sessions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListSessions:
|
||||
async def test_empty(self, service, app_name, user_id):
|
||||
resp = await service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
assert resp.sessions == [] or resp.sessions is None
|
||||
|
||||
async def test_returns_created_sessions(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
s1 = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
s2 = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
resp = await service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
ids = {s.id for s in resp.sessions}
|
||||
assert s1.id in ids
|
||||
assert s2.id in ids
|
||||
|
||||
async def test_filter_by_user(self, service, app_name):
|
||||
uid1 = f"user_{uuid.uuid4().hex[:8]}"
|
||||
uid2 = f"user_{uuid.uuid4().hex[:8]}"
|
||||
await service.create_session(app_name=app_name, user_id=uid1)
|
||||
await service.create_session(app_name=app_name, user_id=uid2)
|
||||
|
||||
resp = await service.list_sessions(
|
||||
app_name=app_name, user_id=uid1
|
||||
)
|
||||
assert len(resp.sessions) == 1
|
||||
assert resp.sessions[0].user_id == uid1
|
||||
|
||||
async def test_sessions_have_merged_state(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
await service.create_session(
|
||||
app_name=app_name,
|
||||
user_id=user_id,
|
||||
state={"app:shared": "yes", "local": "val"},
|
||||
)
|
||||
resp = await service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
s = resp.sessions[0]
|
||||
assert s.state["app:shared"] == "yes"
|
||||
assert s.state["local"] == "val"
|
||||
|
||||
async def test_sessions_have_no_events(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="user", timestamp=time.time(), invocation_id="inv-1"
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
resp = await service.list_sessions(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
assert resp.sessions[0].events == []
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_session
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteSession:
|
||||
async def test_delete(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
await service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
result = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
async def test_delete_removes_events(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="user", timestamp=time.time(), invocation_id="inv-1"
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
await service.delete_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
result = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# append_event
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAppendEvent:
|
||||
async def test_basic(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="user",
|
||||
content=Content(parts=[Part(text="hi")]),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
returned = await service.append_event(session, event)
|
||||
assert returned.id == event.id
|
||||
assert returned.timestamp > 0
|
||||
|
||||
async def test_partial_event_not_persisted(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="user",
|
||||
partial=True,
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert len(fetched.events) == 0
|
||||
|
||||
async def test_session_state_delta(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="agent",
|
||||
actions=EventActions(state_delta={"counter": 1}),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert fetched.state["counter"] == 1
|
||||
|
||||
async def test_app_state_delta(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="agent",
|
||||
actions=EventActions(state_delta={"app:version": "2.0"}),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert fetched.state["app:version"] == "2.0"
|
||||
|
||||
async def test_user_state_delta(self, service, app_name, user_id):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="agent",
|
||||
actions=EventActions(state_delta={"user:pref": "dark"}),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert fetched.state["user:pref"] == "dark"
|
||||
|
||||
async def test_updates_last_update_time(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
original_time = session.last_update_time
|
||||
|
||||
event = Event(
|
||||
author="user",
|
||||
timestamp=time.time() + 10,
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(session, event)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert fetched.last_update_time > original_time
|
||||
|
||||
async def test_multiple_events_accumulate(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
session = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
for i in range(3):
|
||||
e = Event(
|
||||
author="user",
|
||||
content=Content(parts=[Part(text=f"msg {i}")]),
|
||||
timestamp=time.time() + i,
|
||||
invocation_id=f"inv-{i}",
|
||||
)
|
||||
await service.append_event(session, e)
|
||||
|
||||
fetched = await service.get_session(
|
||||
app_name=app_name, user_id=user_id, session_id=session.id
|
||||
)
|
||||
assert len(fetched.events) == 3
|
||||
|
||||
async def test_app_state_shared_across_sessions(
|
||||
self, service, app_name, user_id
|
||||
):
|
||||
s1 = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
event = Event(
|
||||
author="agent",
|
||||
actions=EventActions(state_delta={"app:shared_val": 99}),
|
||||
timestamp=time.time(),
|
||||
invocation_id="inv-1",
|
||||
)
|
||||
await service.append_event(s1, event)
|
||||
|
||||
s2 = await service.create_session(
|
||||
app_name=app_name, user_id=user_id
|
||||
)
|
||||
assert s2.state["app:shared_val"] == 99
|
||||
Reference in New Issue
Block a user