First commit
This commit is contained in:
196
src/rag_eval/agent.py
Normal file
196
src/rag_eval/agent.py
Normal file
@@ -0,0 +1,196 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import structlog
|
||||
from embedder.vertex_ai import VertexAIEmbedder
|
||||
from google.genai import types
|
||||
from llm.vertex_ai import VertexAILLM
|
||||
from vector_search.vertex_ai import GoogleCloudVectorSearch
|
||||
|
||||
from rag_eval.config import settings
|
||||
|
||||
# Module-level structured logger, named after this module.
logger = structlog.get_logger(__name__)


# Cap on simultaneous in-flight LLM requests; async_call gates each generation
# on this semaphore — presumably to stay under provider rate limits (confirm).
MAX_CONCURRENT_LLM_CALLS = 20
_llm_semaphore = asyncio.Semaphore(MAX_CONCURRENT_LLM_CALLS)
|
||||
|
||||
class Agent:
|
||||
"""A class to handle the RAG workflow."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initializes the Agent class."""
|
||||
self.settings = settings.agent
|
||||
self.index_settings = settings.index
|
||||
self.model = self.settings.language_model
|
||||
self.system_prompt = self.settings.instructions
|
||||
self.llm = VertexAILLM(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
thinking=self.settings.thinking,
|
||||
)
|
||||
self.vector_search = GoogleCloudVectorSearch(
|
||||
project_id=settings.project_id,
|
||||
location=settings.location,
|
||||
bucket=settings.bucket,
|
||||
index_name=self.index_settings.name,
|
||||
)
|
||||
self.vector_search.load_index_endpoint(self.index_settings.endpoint)
|
||||
self.embedder = VertexAIEmbedder(
|
||||
project=settings.project_id,
|
||||
location=settings.location,
|
||||
model_name=self.settings.embedding_model, task="RETRIEVAL_QUERY"
|
||||
)
|
||||
self.min_sim = 0.60
|
||||
|
||||
def call(self, query: str | list[dict[str, str]]) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
if isinstance(query, str):
|
||||
search_query = query
|
||||
else:
|
||||
search_query = query[-1]["content"]
|
||||
|
||||
context = self.search(search_query)
|
||||
user_prompt = f"{search_query}\n\n{context}"
|
||||
|
||||
contents = []
|
||||
if isinstance(query, str):
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
else:
|
||||
for turn in query[:-1]:
|
||||
role = "model" if turn["role"] == "assistant" else "user"
|
||||
contents.append(
|
||||
types.Content(role=role, parts=[types.Part(text=turn["content"])])
|
||||
)
|
||||
contents.append(types.Content(role="user", parts=[types.Part(text=user_prompt)]))
|
||||
|
||||
generation = self.llm.generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
|
||||
logger.info(f"total usage={generation.usage}")
|
||||
logger.info(f"costo ${generation.usage.get_cost(self.model)} MXN")
|
||||
|
||||
return generation.text
|
||||
|
||||
def search(self, query: str ) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
logger.debug(f"Search term: {query}")
|
||||
|
||||
query_embedding = self.embedder.generate_embedding(query)
|
||||
search_results = self.vector_search.run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
cutoff= max_sim * 0.9
|
||||
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.debug(f"{max_sim=}")
|
||||
logger.debug(f"{cutoff=}")
|
||||
logger.debug(f"chunks={[s['id'] for s in search_results]}")
|
||||
logger.debug(f"distancias={[s['distance'] for s in search_results]}")
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
async def async_search(self, query: str) -> str:
|
||||
"""Searches the vector index for the given query.
|
||||
|
||||
Args:
|
||||
query: The query to search for.
|
||||
|
||||
Returns:
|
||||
A formatted string containing the search results.
|
||||
"""
|
||||
t0 = time.perf_counter()
|
||||
|
||||
query_embedding = await self.embedder.async_generate_embedding(query)
|
||||
t_embed = time.perf_counter()
|
||||
|
||||
search_results = await self.vector_search.async_run_query(
|
||||
deployed_index_id=self.index_settings.deployment,
|
||||
query=query_embedding,
|
||||
limit=5,
|
||||
)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
max_sim = max(search_results, key=lambda x: x["distance"])["distance"]
|
||||
cutoff = max_sim * 0.9
|
||||
search_results = [s for s in search_results if s["distance"] > cutoff and s["distance"] > self.min_sim]
|
||||
|
||||
logger.info(
|
||||
"async_search.timing",
|
||||
embedding_ms=round((t_embed - t0) * 1000, 1),
|
||||
vector_search_ms=round((t_search - t_embed) * 1000, 1),
|
||||
total_ms=round((t_search - t0) * 1000, 1),
|
||||
chunks=[s["id"] for s in search_results],
|
||||
)
|
||||
|
||||
return self._format_results(search_results)
|
||||
|
||||
def _format_results(self, search_results):
|
||||
formatted_results = [
|
||||
f"<document {i} name={result['id']}>\n{result['content']}\n</document {i}>"
|
||||
for i, result in enumerate(search_results, start=1)
|
||||
]
|
||||
return "\n".join(formatted_results)
|
||||
|
||||
async def async_call(self, query: str) -> str:
|
||||
"""Calls the LLM with the provided query and tools.
|
||||
|
||||
Args:
|
||||
query: The user's query.
|
||||
|
||||
Returns:
|
||||
The response from the LLM.
|
||||
"""
|
||||
t_start = time.perf_counter()
|
||||
|
||||
t0 = time.perf_counter()
|
||||
context = await self.async_search(query)
|
||||
t_search = time.perf_counter()
|
||||
|
||||
contents = [types.Content(role="user", parts=[types.Part(text=f"{query}\n\n{context}")])]
|
||||
|
||||
t1 = time.perf_counter()
|
||||
async with _llm_semaphore:
|
||||
generation = await self.llm.async_generate(
|
||||
model=self.model,
|
||||
prompt=contents,
|
||||
system_prompt=self.system_prompt,
|
||||
)
|
||||
t_llm = time.perf_counter()
|
||||
|
||||
t_end = time.perf_counter()
|
||||
logger.info(
|
||||
"async_call.timing",
|
||||
total_ms=round((t_end - t_start) * 1000, 1),
|
||||
stages=[
|
||||
{"stage": "search", "ms": round((t_search - t0) * 1000, 1)},
|
||||
{"stage": "llm", "ms": round((t_llm - t1) * 1000, 1)},
|
||||
],
|
||||
llm_iterations=1,
|
||||
prompt_tokens=generation.usage.prompt_tokens,
|
||||
response_tokens=generation.usage.response_tokens,
|
||||
cost_mxn=generation.usage.get_cost(self.model),
|
||||
)
|
||||
|
||||
return generation.text
|
||||
Reference in New Issue
Block a user