Add CI
This commit is contained in:
111
scripts/agent.py
Normal file
111
scripts/agent.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# ruff: noqa: INP001
|
||||
"""ADK agent that connects to the knowledge-search MCP server."""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from google.adk.agents.llm_agent import LlmAgent
|
||||
from google.adk.runners import Runner
|
||||
from google.adk.sessions import InMemorySessionService
|
||||
from google.adk.tools.mcp_tool import McpToolset
|
||||
from google.adk.tools.mcp_tool.mcp_session_manager import (
|
||||
SseConnectionParams,
|
||||
StdioConnectionParams,
|
||||
)
|
||||
from google.genai import types
|
||||
from mcp import StdioServerParameters
|
||||
|
||||
# ADK needs these env vars for Vertex AI; reuse the ones from .env
|
||||
os.environ.setdefault("GOOGLE_GENAI_USE_VERTEXAI", "True")
|
||||
if project := os.environ.get("PROJECT_ID"):
|
||||
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", project)
|
||||
if location := os.environ.get("LOCATION"):
|
||||
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", location)
|
||||
|
||||
SERVER_SCRIPT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src", "knowledge_search_mcp", "main.py")
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Knowledge Search Agent")
|
||||
parser.add_argument(
|
||||
"--remote",
|
||||
metavar="URL",
|
||||
help="Connect to an already-running MCP server at this SSE URL "
|
||||
"(e.g. http://localhost:8080/sse). Without this flag the agent "
|
||||
"spawns the server as a subprocess.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def async_main() -> None:
|
||||
args = _parse_args()
|
||||
|
||||
if args.remote:
|
||||
connection_params = SseConnectionParams(url=args.remote)
|
||||
else:
|
||||
connection_params = StdioConnectionParams(
|
||||
server_params=StdioServerParameters(
|
||||
command="uv",
|
||||
args=["run", "python", SERVER_SCRIPT],
|
||||
),
|
||||
)
|
||||
|
||||
toolset = McpToolset(connection_params=connection_params)
|
||||
|
||||
agent = LlmAgent(
|
||||
model="gemini-2.0-flash",
|
||||
name="knowledge_agent",
|
||||
instruction=(
|
||||
"You are a helpful assistant with access to a knowledge base. "
|
||||
"Use the knowledge_search tool to find relevant information "
|
||||
"when the user asks questions. Summarize the results clearly."
|
||||
),
|
||||
tools=[toolset],
|
||||
)
|
||||
|
||||
session_service = InMemorySessionService()
|
||||
session = await session_service.create_session(
|
||||
state={},
|
||||
app_name="knowledge_agent",
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
runner = Runner(
|
||||
app_name="knowledge_agent",
|
||||
agent=agent,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
print("Knowledge Search Agent ready. Type your query (Ctrl+C to exit):")
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
query = input("\n> ").strip()
|
||||
except EOFError:
|
||||
break
|
||||
if not query:
|
||||
continue
|
||||
|
||||
content = types.Content(
|
||||
role="user",
|
||||
parts=[types.Part(text=query)],
|
||||
)
|
||||
|
||||
async for event in runner.run_async(
|
||||
session_id=session.id,
|
||||
user_id=session.user_id,
|
||||
new_message=content,
|
||||
):
|
||||
if event.is_final_response() and event.content and event.content.parts:
|
||||
for part in event.content.parts:
|
||||
if part.text:
|
||||
print(part.text)
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutting down...")
|
||||
finally:
|
||||
await toolset.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(async_main())
|
||||
Reference in New Issue
Block a user