Switch to agent arch
This commit is contained in:
@@ -1,196 +1,84 @@
|
||||
import asyncio
|
||||
"""Pydantic AI agent with RAG tool for vector search."""
|
||||
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from google.genai import types
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
from pydantic import BaseModel
|
||||
from pydantic_ai import Agent, Embedder, RunContext
|
||||
from pydantic_ai.models.google import GoogleModel
|
||||
|
||||
from rag_eval.config import settings
|
||||
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
# Module-level structured logger for this module.
logger = structlog.get_logger(__name__)

# Upper bound on simultaneous LLM requests; the semaphore is shared by every
# Agent instance in this process so total concurrency stays capped.
MAX_CONCURRENT_LLM_CALLS = 20
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
||||
class Agent:
    """A class to handle the RAG workflow."""

    def __init__(self):
        """Initializes the Agent class.

        Wires up the LLM, vector-search, and embedding clients from the
        module-level settings and loads the index endpoint.
        """
        self.settings = settings.agent
        self.index_settings = settings.index
        self.model = self.settings.language_model
        self.system_prompt = self.settings.instructions

        # LLM client; "thinking" mode comes straight from configuration.
        self.llm = VertexAILLM(
            project=settings.project_id,
            location=settings.location,
            thinking=self.settings.thinking,
        )

        # Vector index client; the endpoint must be loaded before any query.
        self.vector_search = GoogleCloudVectorSearch(
            project_id=settings.project_id,
            location=settings.location,
            bucket=settings.bucket,
            index_name=self.index_settings.name,
        )
        self.vector_search.load_index_endpoint(self.index_settings.endpoint)

        # Query-side embedder; task type string is passed through as-is.
        self.embedder = VertexAIEmbedder(
            project=settings.project_id,
            location=settings.location,
            model_name=self.settings.embedding_model,
            task="RETRIEVAL_QUERY",
        )

        # Absolute similarity floor used by search()/async_search() filtering.
        self.min_sim = 0.60
||||
|
||||
def call(self, query: str | list[dict[str, str]]) -> str:
    """Calls the LLM with the provided query and tools.

    Args:
        query: The user's query — either a plain string or a chat history
            as a list of {"role", "content"} dicts (last turn is the query).

    Returns:
        The response from the LLM.
    """
    # The retrieval term is the latest user message; earlier turns are history.
    if isinstance(query, str):
        search_query = query
        history: list[dict[str, str]] = []
    else:
        search_query = query[-1]["content"]
        history = query[:-1]

    context = self.search(search_query)
    user_prompt = f"{search_query}\n\n{context}"

    # Replay prior turns (assistant -> "model"), then append the augmented query.
    contents = []
    for turn in history:
        role = "model" if turn["role"] == "assistant" else "user"
        contents.append(
            types.Content(role=role, parts=[types.Part(text=turn["content"])])
        )
    contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))

    generation = self.llm.generate(
        model=self.model,
        prompt=contents,
        system_prompt=self.system_prompt,
    )

    logger.info(f"total usage={generation.usage}")
    logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")

    return generation.text
||||
def search(self, query: str) -> str:
    """Searches the vector index for the given query.

    Args:
        query: The query to search for.

    Returns:
        A formatted string containing the search results; empty string when
        nothing survives filtering (or nothing was retrieved).
    """
    logger.debug(f"Search term: {query}")

    query_embedding = self.embedder.generate_embedding(query)
    search_results = self.vector_search.run_query(
        deployed_index_id=self.index_settings.deployment,
        query=query_embedding,
        limit=5,
    )

    # Fix: max() on an empty sequence raises ValueError, so only compute the
    # relative cutoff and filter when the index actually returned results.
    if search_results:
        max_sim = max(s["distance"] for s in search_results)
        cutoff = max_sim * 0.9
        # Keep chunks close to the best hit AND above the absolute floor.
        search_results = [
            s
            for s in search_results
            if s["distance"] > cutoff and s["distance"] > self.min_sim
        ]
        logger.debug(f"{max_sim=}")
        logger.debug(f"{cutoff=}")

    logger.debug(f"chunks={[s['id'] for s in search_results]}")
    logger.debug(f"distancias={[s['distance'] for s in search_results]}")

    return self._format_results(search_results)
||||
async def async_search(self, query: str) -> str:
    """Searches the vector index for the given query.

    Args:
        query: The query to search for.

    Returns:
        A formatted string containing the search results; empty string when
        nothing survives filtering (or nothing was retrieved).
    """
    t0 = time.perf_counter()

    query_embedding = await self.embedder.async_generate_embedding(query)
    t_embed = time.perf_counter()

    search_results = await self.vector_search.async_run_query(
        deployed_index_id=self.index_settings.deployment,
        query=query_embedding,
        limit=5,
    )
    t_search = time.perf_counter()

    # Fix: guard the max()/filter so an empty result set does not raise
    # ValueError; mirrors the same guard in the synchronous search().
    if search_results:
        max_sim = max(r["distance"] for r in search_results)
        cutoff = max_sim * 0.9
        search_results = [
            s
            for s in search_results
            if s["distance"] > cutoff and s["distance"] > self.min_sim
        ]

    logger.info(
        "async_search.timing",
        embedding_ms=round((t_embed - t0) * 1000, 1),
        vector_search_ms=round((t_search - t_embed) * 1000, 1),
        total_ms=round((t_search - t0) * 1000, 1),
        chunks=[s["id"] for s in search_results],
    )

    return self._format_results(search_results)
|
||||
def _format_results(self, search_results):
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
search_results = [
|
||||
s
|
||||
for s in search_results
|
||||
if s["distance"] > cutoff and s["distance"] > min_sim
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def async_call(self, query: str) -> str:
    """Calls the LLM with the provided query and tools.

    Args:
        query: The user's query.

    Returns:
        The response from the LLM.
    """
    t_start = time.perf_counter()

    # Retrieval stage.
    t0 = time.perf_counter()
    context = await self.async_search(query)
    t_search = time.perf_counter()

    prompt_text = f"{query}\n\n{context}"
    contents = [types.Content(role="user", parts=[types.Part(text=prompt_text)])]

    # Generation stage, throttled by the module-wide semaphore so concurrent
    # callers cannot exceed MAX_CONCURRENT_LLM_CALLS in-flight requests.
    t1 = time.perf_counter()
    async with _llm_semaphore:
        generation = await self.llm.async_generate(
            model=self.model,
            prompt=contents,
            system_prompt=self.system_prompt,
        )
        t_llm = time.perf_counter()

    t_end = time.perf_counter()
    logger.info(
        "async_call.timing",
        total_ms=round((t_end - t_start) * 1000, 1),
        stages=[
            {"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
            {"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
        ],
        llm_iterations=1,
        prompt_tokens=generation.usage.prompt_tokens,
        response_tokens=generation.usage.response_tokens,
        cost_mxn=generation.usage.get_cost(self.model),
    )

    return generation.text
|
||||
|
||||
Reference in New Issue
Block a user