Files
agent/src/rag_eval/agent.py
Anibal Angulo 733c65ce85 Configure cli
2026-02-20 15:23:53 +00:00

93 lines
2.4 KiB
Python

"""Pydantic AI agent with RAG tool for vector search."""
import time
import structlog
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel
from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
# Module-level structlog logger named after this module; the `conocimiento`
# tool emits its "conocimiento.timing" event through it.
logger = structlog.get_logger(__name__)
class Deps(BaseModel):
    """Runtime dependencies handed to the agent's tools through ``RunContext``.

    Attributes:
        vector_search: Client used to query the deployed vector index.
        embedder: Embedder that turns a text query into a query vector.
    """

    # Neither field type is a pydantic model, so pydantic has to be told to
    # accept arbitrary classes for them.
    model_config = {"arbitrary_types_allowed": True}

    vector_search: GoogleCloudVectorSearch
    embedder: Embedder
# Google model selected by name and provider from the project settings.
model = GoogleModel(settings.agent_language_model, provider=settings.provider)

# Agent whose tools receive a `Deps` instance via RunContext; the configured
# instructions are supplied as the agent's system prompt.
agent = Agent(
    model,
    system_prompt=settings.agent_instructions,
    deps_type=Deps,
)
# Retrieval knobs for `conocimiento`. They are module constants rather than
# tool parameters on purpose: pydantic-ai derives the tool's JSON schema from
# the function signature, so extra parameters would be exposed to the model.
_CONOCIMIENTO_TOP_K = 5  # neighbours requested from the vector index
_CONOCIMIENTO_MIN_SIMILARITY = 0.6  # absolute floor on a hit's score
_CONOCIMIENTO_RELATIVE_CUTOFF = 0.9  # fraction of the best hit's score


def _filter_by_similarity(
    results: list,
    min_sim: float = _CONOCIMIENTO_MIN_SIMILARITY,
    relative_cutoff: float = _CONOCIMIENTO_RELATIVE_CUTOFF,
) -> list:
    """Drop hits below the absolute floor or the cutoff relative to the best.

    A hit survives only if its ``"distance"`` is strictly greater than both
    ``min_sim`` and ``relative_cutoff`` times the highest ``"distance"`` in
    ``results``. A falsy ``results`` is returned unchanged.

    NOTE(review): higher ``"distance"`` is treated as "more similar", which
    assumes a similarity-style index metric (e.g. dot product); confirm
    against the deployed index configuration.
    """
    if not results:
        return results
    cutoff = max(r["distance"] for r in results) * relative_cutoff
    return [
        r
        for r in results
        if r["distance"] > cutoff and r["distance"] > min_sim
    ]


def _format_documents(results) -> str:
    """Render hits as numbered ``<document N name=ID>`` blocks, one per hit."""
    return "\n".join(
        f"<document {i} name={result['id']}>\n"
        f"{result['content']}\n"
        f"</document {i}>"
        for i, result in enumerate(results, start=1)
    )


@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
    """Search the vector index for the given query.

    Args:
        ctx: The run context containing dependencies.
        query: The query to search for.

    Returns:
        A formatted string containing the search results.
    """
    # NOTE: pydantic-ai sends the docstring above to the model as the tool
    # description, so it stays the model-facing contract; developer notes
    # live in comments. Flow: embed query -> top-k vector search ->
    # similarity filter -> log timings -> format documents for the model.
    t0 = time.perf_counter()
    query_embedding = await ctx.deps.embedder.embed_query(query)
    t_embed = time.perf_counter()

    search_results = await ctx.deps.vector_search.async_run_query(
        deployed_index_id=settings.index_deployed_id,
        query=list(query_embedding.embeddings[0]),
        limit=_CONOCIMIENTO_TOP_K,
    )
    t_search = time.perf_counter()

    search_results = _filter_by_similarity(search_results)

    # Timings cover embedding and vector search only (not filtering/format).
    logger.info(
        "conocimiento.timing",
        embedding_ms=round((t_embed - t0) * 1000, 1),
        vector_search_ms=round((t_search - t_embed) * 1000, 1),
        total_ms=round((t_search - t0) * 1000, 1),
        chunks=[s["id"] for s in search_results],
    )
    return _format_documents(search_results)
if __name__ == "__main__":
    # Build the runtime dependencies from settings and hand them straight to
    # pydantic-ai's interactive CLI.
    agent.to_cli_sync(
        deps=Deps(
            vector_search=settings.vector_search,
            embedder=settings.embedder,
        )
    )