Add Vertex AI support
This commit is contained in:
@@ -48,6 +48,8 @@ type ProviderEntry struct {
|
||||
APIKey string `yaml:"api_key"`
|
||||
Endpoint string `yaml:"endpoint"`
|
||||
APIVersion string `yaml:"api_version"`
|
||||
Project string `yaml:"project"` // For Vertex AI
|
||||
Location string `yaml:"location"` // For Vertex AI
|
||||
}
|
||||
|
||||
// ModelEntry maps a model name to a provider entry.
|
||||
@@ -78,6 +80,12 @@ type AzureAnthropicConfig struct {
|
||||
Model string `yaml:"model"`
|
||||
}
|
||||
|
||||
// VertexAIConfig contains Vertex AI-specific settings used internally by the Google provider.
|
||||
type VertexAIConfig struct {
|
||||
Project string `yaml:"project"`
|
||||
Location string `yaml:"location"`
|
||||
}
|
||||
|
||||
// Load reads and parses a YAML configuration file, expanding ${VAR} env references.
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
|
||||
@@ -19,7 +19,7 @@ type Provider struct {
|
||||
client *genai.Client
|
||||
}
|
||||
|
||||
// New constructs a Provider using the provided configuration.
|
||||
// New constructs a Provider using the Google AI API with API key authentication.
|
||||
func New(cfg config.ProviderConfig) *Provider {
|
||||
var client *genai.Client
|
||||
if cfg.APIKey != "" {
|
||||
@@ -38,13 +38,36 @@ func New(cfg config.ProviderConfig) *Provider {
|
||||
}
|
||||
}
|
||||
|
||||
// NewVertexAI constructs a Provider targeting Vertex AI.
|
||||
// Vertex AI uses the same genai SDK but with GCP project/location configuration
|
||||
// and Application Default Credentials (ADC) or service account authentication.
|
||||
func NewVertexAI(vertexCfg config.VertexAIConfig) *Provider {
|
||||
var client *genai.Client
|
||||
if vertexCfg.Project != "" && vertexCfg.Location != "" {
|
||||
var err error
|
||||
client, err = genai.NewClient(context.Background(), &genai.ClientConfig{
|
||||
Project: vertexCfg.Project,
|
||||
Location: vertexCfg.Location,
|
||||
Backend: genai.BackendVertexAI,
|
||||
})
|
||||
if err != nil {
|
||||
// Log error but don't fail construction - will fail on Generate
|
||||
fmt.Printf("warning: failed to create vertex ai client: %v\n", err)
|
||||
}
|
||||
}
|
||||
return &Provider{
|
||||
cfg: config.ProviderConfig{
|
||||
// Vertex AI doesn't use API key, but set empty for consistency
|
||||
APIKey: "",
|
||||
},
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) Name() string { return Name }
|
||||
|
||||
// Generate routes the request to Gemini and returns a ProviderResult.
|
||||
func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) {
|
||||
if p.cfg.APIKey == "" {
|
||||
return nil, fmt.Errorf("google api key missing")
|
||||
}
|
||||
if p.client == nil {
|
||||
return nil, fmt.Errorf("google client not initialized")
|
||||
}
|
||||
@@ -96,10 +119,6 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r
|
||||
defer close(deltaChan)
|
||||
defer close(errChan)
|
||||
|
||||
if p.cfg.APIKey == "" {
|
||||
errChan <- fmt.Errorf("google api key missing")
|
||||
return
|
||||
}
|
||||
if p.client == nil {
|
||||
errChan <- fmt.Errorf("google client not initialized")
|
||||
return
|
||||
|
||||
@@ -60,7 +60,8 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE
|
||||
}
|
||||
|
||||
func buildProvider(entry config.ProviderEntry) (Provider, error) {
|
||||
if entry.APIKey == "" {
|
||||
// Vertex AI doesn't require APIKey, so check for it separately
|
||||
if entry.Type != "vertexai" && entry.APIKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -97,6 +98,14 @@ func buildProvider(entry config.ProviderEntry) (Provider, error) {
|
||||
APIKey: entry.APIKey,
|
||||
Endpoint: entry.Endpoint,
|
||||
}), nil
|
||||
case "vertexai":
|
||||
if entry.Project == "" || entry.Location == "" {
|
||||
return nil, fmt.Errorf("project and location are required for vertexai")
|
||||
}
|
||||
return googleprovider.NewVertexAI(config.VertexAIConfig{
|
||||
Project: entry.Project,
|
||||
Location: entry.Location,
|
||||
}), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown provider type %q", entry.Type)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user