"""Tests for FirestoreSessionService against the Firestore emulator.""" from __future__ import annotations import time import uuid import pytest from google.adk.errors.already_exists_error import AlreadyExistsError from google.adk.events.event import Event from google.adk.events.event_actions import EventActions from google.adk.sessions.base_session_service import GetSessionConfig from google.genai.types import Content, Part pytestmark = pytest.mark.asyncio # ------------------------------------------------------------------ # create_session # ------------------------------------------------------------------ class TestCreateSession: async def test_auto_generates_id(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) assert session.id assert session.app_name == app_name assert session.user_id == user_id assert session.last_update_time > 0 async def test_custom_id(self, service, app_name, user_id): sid = "my-custom-session" session = await service.create_session( app_name=app_name, user_id=user_id, session_id=sid ) assert session.id == sid async def test_duplicate_id_raises(self, service, app_name, user_id): sid = "dup-session" await service.create_session( app_name=app_name, user_id=user_id, session_id=sid ) with pytest.raises(AlreadyExistsError): await service.create_session( app_name=app_name, user_id=user_id, session_id=sid ) async def test_session_state(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id, state={"count": 42}, ) assert session.state["count"] == 42 async def test_scoped_state(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id, state={ "app:global_flag": True, "user:lang": "es", "local_key": "val", }, ) assert session.state["app:global_flag"] is True assert session.state["user:lang"] == "es" assert session.state["local_key"] == "val" async def test_temp_state_not_persisted(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id, state={"temp:scratch": "gone", "keep": "yes"}, ) retrieved = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert "temp:scratch" not in retrieved.state assert retrieved.state["keep"] == "yes" # ------------------------------------------------------------------ # get_session # ------------------------------------------------------------------ class TestGetSession: async def test_nonexistent_returns_none(self, service, app_name, user_id): result = await service.get_session( app_name=app_name, user_id=user_id, session_id="nope" ) assert result is None async def test_roundtrip(self, service, app_name, user_id): created = await service.create_session( app_name=app_name, user_id=user_id, state={"foo": "bar"}, ) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=created.id ) assert fetched is not None assert fetched.id == created.id assert fetched.state["foo"] == "bar" assert fetched.last_update_time == pytest.approx( created.last_update_time, abs=0.01 ) async def test_returns_events(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="user", content=Content(parts=[Part(text="hello")]), timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert len(fetched.events) == 1 assert fetched.events[0].author == "user" async def test_num_recent_events(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) for i in range(5): e = Event( author="user", timestamp=time.time() + i, invocation_id=f"inv-{i}", ) await service.append_event(session, e) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=GetSessionConfig(num_recent_events=2), ) assert len(fetched.events) == 2 async def test_after_timestamp(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) base = time.time() for i in range(3): e = Event( author="user", timestamp=base + i, invocation_id=f"inv-{i}", ) await service.append_event(session, e) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id, config=GetSessionConfig(after_timestamp=base + 1), ) assert len(fetched.events) == 2 # ------------------------------------------------------------------ # list_sessions # ------------------------------------------------------------------ class TestListSessions: async def test_empty(self, service, app_name, user_id): resp = await service.list_sessions( app_name=app_name, user_id=user_id ) assert resp.sessions == [] or resp.sessions is None async def test_returns_created_sessions( self, service, app_name, user_id ): s1 = await service.create_session( app_name=app_name, user_id=user_id ) s2 = await service.create_session( app_name=app_name, user_id=user_id ) resp = await service.list_sessions( app_name=app_name, user_id=user_id ) ids = {s.id for s in resp.sessions} assert s1.id in ids assert s2.id in ids async def test_filter_by_user(self, service, app_name): uid1 = f"user_{uuid.uuid4().hex[:8]}" uid2 = f"user_{uuid.uuid4().hex[:8]}" await service.create_session(app_name=app_name, user_id=uid1) await service.create_session(app_name=app_name, user_id=uid2) resp = await service.list_sessions( app_name=app_name, user_id=uid1 ) assert len(resp.sessions) == 1 assert resp.sessions[0].user_id == uid1 async def test_sessions_have_merged_state( self, service, app_name, user_id ): await service.create_session( app_name=app_name, user_id=user_id, state={"app:shared": "yes", "local": "val"}, ) resp = await service.list_sessions( app_name=app_name, user_id=user_id ) s = resp.sessions[0] assert s.state["app:shared"] == "yes" assert s.state["local"] == "val" async def test_sessions_have_no_events( self, service, app_name, user_id ): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="user", timestamp=time.time(), invocation_id="inv-1" ) await service.append_event(session, event) resp = await service.list_sessions( app_name=app_name, user_id=user_id ) assert resp.sessions[0].events == [] # ------------------------------------------------------------------ # delete_session # ------------------------------------------------------------------ class TestDeleteSession: async def test_delete(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) await service.delete_session( app_name=app_name, user_id=user_id, session_id=session.id ) result = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert result is None async def test_delete_removes_events(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="user", timestamp=time.time(), invocation_id="inv-1" ) await service.append_event(session, event) await service.delete_session( app_name=app_name, user_id=user_id, session_id=session.id ) result = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert result is None # ------------------------------------------------------------------ # append_event # ------------------------------------------------------------------ class TestAppendEvent: async def test_basic(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="user", content=Content(parts=[Part(text="hi")]), timestamp=time.time(), invocation_id="inv-1", ) returned = await service.append_event(session, event) assert returned.id == event.id assert returned.timestamp > 0 async def test_partial_event_not_persisted( self, service, app_name, user_id ): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="user", partial=True, timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert len(fetched.events) == 0 async def test_session_state_delta(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="agent", actions=EventActions(state_delta={"counter": 1}), timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert fetched.state["counter"] == 1 async def test_app_state_delta(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="agent", actions=EventActions(state_delta={"app:version": "2.0"}), timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert fetched.state["app:version"] == "2.0" async def test_user_state_delta(self, service, app_name, user_id): session = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="agent", actions=EventActions(state_delta={"user:pref": "dark"}), timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert fetched.state["user:pref"] == "dark" async def test_updates_last_update_time( self, service, app_name, user_id ): session = await service.create_session( app_name=app_name, user_id=user_id ) original_time = session.last_update_time event = Event( author="user", timestamp=time.time() + 10, invocation_id="inv-1", ) await service.append_event(session, event) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert fetched.last_update_time > original_time async def test_multiple_events_accumulate( self, service, app_name, user_id ): session = await service.create_session( app_name=app_name, user_id=user_id ) for i in range(3): e = Event( author="user", content=Content(parts=[Part(text=f"msg {i}")]), timestamp=time.time() + i, invocation_id=f"inv-{i}", ) await service.append_event(session, e) fetched = await service.get_session( app_name=app_name, user_id=user_id, session_id=session.id ) assert len(fetched.events) == 3 async def test_app_state_shared_across_sessions( self, service, app_name, user_id ): s1 = await service.create_session( app_name=app_name, user_id=user_id ) event = Event( author="agent", actions=EventActions(state_delta={"app:shared_val": 99}), timestamp=time.time(), invocation_id="inv-1", ) await service.append_event(s1, event) s2 = await service.create_session( app_name=app_name, user_id=user_id ) assert s2.state["app:shared_val"] == 99