121 lines
3.4 KiB
Python
121 lines
3.4 KiB
Python
# Copyright 2026 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Configuration helper for ADK agent with vector search."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from dataclasses import dataclass
|
|
from functools import cached_property
|
|
|
|
import vertexai
|
|
from pydantic_settings import (
|
|
BaseSettings,
|
|
PydanticBaseSettingsSource,
|
|
SettingsConfigDict,
|
|
YamlConfigSettingsSource,
|
|
)
|
|
from vertexai.language_models import TextEmbeddingModel
|
|
|
|
# Path of the YAML settings file. Read once at import time from the
# CONFIG_YAML environment variable; falls back to "config.yaml" when unset.
CONFIG_FILE_PATH = os.environ.get("CONFIG_YAML", "config.yaml")
|
|
|
|
|
|
@dataclass
class EmbeddingResult:
    """Container for the output of an embedding call.

    Attributes:
        embeddings: Embedding vectors, each one a list of floats.
    """

    embeddings: list[list[float]]
|
|
|
|
|
|
class VertexAIEmbedder:
    """Embedder using Vertex AI TextEmbeddingModel."""

    def __init__(self, model_name: str, project_id: str, location: str) -> None:
        """Initialize the embedder.

        Calls ``vertexai.init`` with the given project and location, then
        loads the model via ``TextEmbeddingModel.from_pretrained``.

        Args:
            model_name: Name of the embedding model (e.g., 'text-embedding-004')
            project_id: GCP project ID
            location: GCP location
        """
        vertexai.init(project=project_id, location=location)
        self.model = TextEmbeddingModel.from_pretrained(model_name)

    async def embed_query(self, query: str) -> EmbeddingResult:
        """Embed a single query string.

        ``get_embeddings`` is a synchronous call, so it is run on the event
        loop's default executor and awaited. Calling it inline would stall
        every other coroutine on the loop until the request returns.

        Args:
            query: Text to embed

        Returns:
            EmbeddingResult whose ``embeddings`` list holds one vector, the
            embedding of ``query``.
        """
        # Local import keeps this block self-contained; asyncio is stdlib.
        import asyncio

        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, self.model.get_embeddings, [query]
        )
        # Only one text was submitted, so only the first result is relevant.
        return EmbeddingResult(embeddings=[list(embeddings[0].values)])
|
|
|
|
|
|
class AgentSettings(BaseSettings):
    """Typed configuration for the ADK agent and its vector search index.

    All fields are strings declared without defaults. Values are read from
    the sources returned by ``settings_customise_sources``.
    """

    # --- Google Cloud ---
    project_id: str
    location: str
    bucket: str

    # --- Agent ---
    agent_name: str
    agent_instructions: str
    agent_language_model: str
    agent_embedding_model: str

    # --- Vector index ---
    index_name: str
    index_deployed_id: str
    index_endpoint: str

    model_config = SettingsConfigDict(
        yaml_file=CONFIG_FILE_PATH,
        extra="ignore",  # Keys in the YAML file that are not fields are ignored
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,  # noqa: ARG003
        file_secret_settings: PydanticBaseSettingsSource,  # noqa: ARG003
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Select environment variables and the YAML file as settings sources.

        The init, dotenv and secrets-file sources are accepted to satisfy the
        hook signature but are not included in the returned tuple.

        NOTE(review): pydantic-settings documents that earlier entries take
        priority, so environment variables should override YAML values;
        confirm against the pinned library version.

        Returns:
            A two-element tuple: the environment source followed by a
            ``YamlConfigSettingsSource`` built for ``settings_cls``.
        """
        yaml_source = YamlConfigSettingsSource(settings_cls)
        return (env_settings, yaml_source)

    @cached_property
    def embedder(self) -> VertexAIEmbedder:
        """Build the embedder for the configured embedding model.

        Wrapped in ``cached_property``, so the ``VertexAIEmbedder`` is
        constructed on first access and reused on this instance afterwards.
        """
        model_name = self.agent_embedding_model
        project_id = self.project_id
        location = self.location
        return VertexAIEmbedder(
            model_name=model_name,
            project_id=project_id,
            location=location,
        )
|
|
|
|
|
|
# Module-level settings instance, created at import time by validating an
# empty mapping.
# NOTE(review): field values are presumably supplied by the sources returned
# from AgentSettings.settings_customise_sources (environment variables and the
# YAML file) rather than by the empty dict itself — confirm.
settings = AgentSettings.model_validate({})
|