Switch to agent arch

This commit is contained in:
2026-02-20 08:59:43 +00:00
parent a53f8fcf62
commit 259a8528e3
113 changed files with 788 additions and 7820 deletions

View File

@@ -1,196 +1,84 @@
import asyncio
"""Pydantic AI agent with RAG tool for vector search."""
import time
import structlog
from embedder.vertex_ai import VertexAIEmbedder
from google.genai import types
from llm.vertex_ai import VertexAILLM
from vector_search.vertex_ai import GoogleCloudVectorSearch
from pydantic import BaseModel
from pydantic_ai import Agent, Embedder, RunContext
from pydantic_ai.models.google import GoogleModel
from rag_eval.config import settings
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
logger = structlog.get_logger(__name__)
MAX_CONCURRENT_LLM_CALLS = 20
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
class Agent:
"""A class to handle the RAG workflow."""
class Deps(BaseModel):
"""Dependencies injected into the agent at runtime."""
def __init__(self):
"""Initializes the Agent class."""
self.settings = settings.agent
self.index_settings = settings.index
self.model = self.settings.language_model
self.system_prompt = self.settings.instructions
self.llm = VertexAILLM(
project=settings.project_id,
location=settings.location,
thinking=self.settings.thinking,
)
self.vector_search = GoogleCloudVectorSearch(
project_id=settings.project_id,
location=settings.location,
bucket=settings.bucket,
index_name=self.index_settings.name,
)
self.vector_search.load_index_endpoint(self.index_settings.endpoint)
self.embedder = VertexAIEmbedder(
project=settings.project_id,
location=settings.location,
model_name=self.settings.embedding_model, task="RETRIEVAL_QUERY"
)
self.min_sim = 0.60
vector_search: GoogleCloudVectorSearch
embedder: Embedder
def call(self, query: str | list[dict[str, str]]) -> str:
"""Calls the LLM with the provided query and tools.
model_config = {"arbitrary_types_allowed": True}
Args:
query: The user's query.
Returns:
The response from the LLM.
"""
if isinstance(query, str):
search_query = query
else:
search_query = query[-1]["content"]
model = GoogleModel(
settings.agent_language_model,
provider=settings.provider,
)
agent = Agent(
model,
deps_type=Deps,
system_prompt=settings.agent_instructions,
)
context = self.search(search_query)
user_prompt = f"{search_query}\n\n{context}"
contents = []
if isinstance(query, str):
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
else:
for turn in query[:-1]:
role = "model" if turn["role"] == "assistant" else "user"
contents.append(
types.Content(role=role, parts=[types.Part(text=turn["content"])])
)
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
"""Search the vector index for the given query.
generation = self.llm.generate(
model=self.model,
prompt=contents,
system_prompt=self.system_prompt,
)
Args:
ctx: The run context containing dependencies.
query: The query to search for.
logger.info(f"total usage={generation.usage}")
logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")
Returns:
A formatted string containing the search results.
return generation.text
"""
t0 = time.perf_counter()
min_sim = 0.6
def search(self, query: str ) -> str:
"""Searches the vector index for the given query.
query_embedding = await ctx.deps.embedder.embed_query(query)
t_embed = time.perf_counter()
Args:
query: The query to search for.
search_results = await ctx.deps.vector_search.async_run_query(
deployed_index_id=settings.index_name,
query=list(query_embedding.embeddings[0]),
limit=5,
)
t_search = time.perf_counter()
Returns:
A formatted string containing the search results.
"""
logger.debug(f"Search term: {query}")
query_embedding = self.embedder.generate_embedding(query)
search_results = self.vector_search.run_query(
deployed_index_id=self.index_settings.deployment,
query=query_embedding,
limit=5,
)
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
cutoff= max_sim * 0.9
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
logger.debug(f"{max_sim=}")
logger.debug(f"{cutoff=}")
logger.debug(f"chunks={[s['id'] for s in search_results]}")
logger.debug(f"distancias={[s['distance'] for s in search_results]}")
return self._format_results(search_results)
async def async_search(self, query: str) -> str:
"""Searches the vector index for the given query.
Args:
query: The query to search for.
Returns:
A formatted string containing the search results.
"""
t0 = time.perf_counter()
query_embedding = await self.embedder.async_generate_embedding(query)
t_embed = time.perf_counter()
search_results = await self.vector_search.async_run_query(
deployed_index_id=self.index_settings.deployment,
query=query_embedding,
limit=5,
)
t_search = time.perf_counter()
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
if search_results:
max_sim = max(r["distance"] for r in search_results)
cutoff = max_sim * 0.9
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
logger.info(
"async_search.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
return self._format_results(search_results)
def _format_results(self, search_results):
formatted_results = [
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
for i, result in enumerate(search_results, start=1)
search_results = [
s
for s in search_results
if s["distance"] > cutoff and s["distance"] > min_sim
]
return "\n".join(formatted_results)
async def async_call(self, query: str) -> str:
"""Calls the LLM with the provided query and tools.
logger.info(
"conocimiento.timing",
embedding_ms=round((t_embed - t0) * 1000, 1),
vector_search_ms=round((t_search - t_embed) * 1000, 1),
total_ms=round((t_search - t0) * 1000, 1),
chunks=[s["id"] for s in search_results],
)
Args:
query: The user's query.
Returns:
The response from the LLM.
"""
t_start = time.perf_counter()
t0 = time.perf_counter()
context = await self.async_search(query)
t_search = time.perf_counter()
contents = [types.Content(role="user", parts=[types.Part(text=f"{query}\n\n{context}")])]
t1 = time.perf_counter()
async with _llm_semaphore:
generation = await self.llm.async_generate(
model=self.model,
prompt=contents,
system_prompt=self.system_prompt,
)
t_llm = time.perf_counter()
t_end = time.perf_counter()
logger.info(
"async_call.timing",
total_ms=round((t_end - t_start) * 1000, 1),
stages=[
{"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
{"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
],
llm_iterations=1,
prompt_tokens=generation.usage.prompt_tokens,
response_tokens=generation.usage.response_tokens,
cost_mxn=generation.usage.get_cost(self.model),
)
return generation.text
formatted_results = [
f"<document {i} name={result['id']}>\n"
f"{result['content']}\n"
f"</document {i}>"
for i, result in enumerate(search_results, start=1)
]
return "\n".join(formatted_results)