93 lines
2.4 KiB
Python
93 lines
2.4 KiB
Python
"""Pydantic AI agent with RAG tool for vector search."""
|
|
|
|
import time
|
|
|
|
import structlog
|
|
from pydantic import BaseModel
|
|
from pydantic_ai import Agent, Embedder, RunContext
|
|
from pydantic_ai.models.google import GoogleModel
|
|
|
|
from rag_eval.config import settings
|
|
from rag_eval.vector_search.vertex_ai import GoogleCloudVectorSearch
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class Deps(BaseModel):
    """Runtime dependency bundle handed to the agent for each run.

    Attributes:
        vector_search: Client the search tool uses to query the vector index.
        embedder: Embedder the search tool uses to embed the query text.
    """

    # Allow field types pydantic cannot build a schema for on its own
    # (presumably required by the client/embedder classes -- confirm).
    model_config = {"arbitrary_types_allowed": True}

    vector_search: GoogleCloudVectorSearch
    embedder: Embedder
|
|
|
|
|
|
# Language model behind the agent; its name and provider are read from settings.
model = GoogleModel(settings.agent_language_model, provider=settings.provider)

# Agent that receives a `Deps` instance at run time and uses the system
# prompt configured in settings.
agent = Agent(model, deps_type=Deps, system_prompt=settings.agent_instructions)
|
|
|
|
|
|
@agent.tool
async def conocimiento(ctx: RunContext[Deps], query: str) -> str:
    """Look up documents related to ``query`` in the vector index.

    The query text is embedded, the deployed index is asked for up to five
    matches, and only matches whose ``distance`` is greater than both 0.6 and
    90% of the best ``distance`` in the batch are kept. Step timings and the
    ids of the kept chunks are logged under ``conocimiento.timing``.

    Args:
        ctx: Run context whose ``deps`` supply the embedder and the vector
            search client.
        query: Text to search for.

    Returns:
        The kept matches rendered as ``<document i name=id>`` blocks joined
        by newlines; an empty string when no match is kept.
    """
    started_at = time.perf_counter()
    absolute_floor = 0.6

    embedded = await ctx.deps.embedder.embed_query(query)
    embedded_at = time.perf_counter()

    hits = await ctx.deps.vector_search.async_run_query(
        deployed_index_id=settings.index_deployed_id,
        query=list(embedded.embeddings[0]),
        limit=5,
    )
    searched_at = time.perf_counter()

    if hits:
        # NOTE(review): `distance` is treated as higher-is-better here
        # (i.e. a similarity score) -- confirm against the index setup.
        top_score = max(hit["distance"] for hit in hits)
        relative_floor = top_score * 0.9
        kept = []
        for hit in hits:
            score = hit["distance"]
            # A match must clear both the relative and the absolute floor.
            if score > relative_floor and score > absolute_floor:
                kept.append(hit)
        hits = kept

    # Timings cover embedding and search only; filtering is not included.
    embed_ms = round((embedded_at - started_at) * 1000, 1)
    search_ms = round((searched_at - embedded_at) * 1000, 1)
    overall_ms = round((searched_at - started_at) * 1000, 1)
    chunk_ids = [hit["id"] for hit in hits]
    logger.info(
        "conocimiento.timing",
        embedding_ms=embed_ms,
        vector_search_ms=search_ms,
        total_ms=overall_ms,
        chunks=chunk_ids,
    )

    rendered = []
    for position, hit in enumerate(hits, start=1):
        rendered.append(
            f"<document {position} name={hit['id']}>\n"
            f"{hit['content']}\n"
            f"</document {position}>"
        )
    return "\n".join(rendered)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the agent as an interactive CLI, with its dependencies taken
    # from the project settings.
    deps = Deps(vector_search=settings.vector_search, embedder=settings.embedder)
    agent.to_cli_sync(deps=deps)
|