diff --git a/README.md b/README.md index 552f3ba..0767644 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ Simplify LLM integration by exposing a single, consistent API that routes reques - **Azure OpenAI** (Azure-deployed models) - **Anthropic** (Claude) - **Google Generative AI** (Gemini) +- **Vertex AI** (Google Cloud-hosted Gemini models) Instead of managing multiple SDK integrations in your application, call one endpoint and let the gateway handle provider-specific implementations. @@ -24,7 +25,8 @@ latticelm (unified API) ├─→ OpenAI SDK ├─→ Azure OpenAI (OpenAI SDK + Azure auth) ├─→ Anthropic SDK -└─→ Google Gen AI SDK +├─→ Google Gen AI SDK +└─→ Vertex AI (Google Gen AI SDK + GCP auth) ``` ## Key Features @@ -45,11 +47,12 @@ latticelm (unified API) ## 🎉 Status: **WORKING!** -✅ **All four providers integrated with official Go SDKs:** +✅ **All providers integrated with official Go SDKs:** - OpenAI → `github.com/openai/openai-go/v3` - Azure OpenAI → `github.com/openai/openai-go/v3` (with Azure auth) - Anthropic → `github.com/anthropics/anthropic-sdk-go` - Google → `google.golang.org/genai` +- Vertex AI → `google.golang.org/genai` (with GCP auth) ✅ **Compiles successfully** (36MB binary) ✅ **Provider auto-selection** (gpt→Azure/OpenAI, claude→Anthropic, gemini→Google) diff --git a/config.example.yaml b/config.example.yaml index 9c6cc6e..2d25fa5 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -14,6 +14,12 @@ providers: type: "openai" api_key: "YOUR_OPENAI_API_KEY" endpoint: "https://api.openai.com" + # Vertex AI (Google Cloud) - optional + # Uses Application Default Credentials (ADC) or service account + # vertexai: + # type: "vertexai" + # project: "your-gcp-project-id" + # location: "us-central1" # or other GCP region # Azure OpenAI - optional # azureopenai: # type: "azureopenai" @@ -48,6 +54,8 @@ models: provider: "anthropic" - name: "gpt-4o-mini" provider: "openai" + # - name: "gemini-2.0-flash-exp" + # provider: "vertexai" # Use Vertex AI instead of Google AI API # - name: "gpt-4o" # provider: "azureopenai" # provider_model_id: "my-gpt4o-deployment" # optional: defaults to name diff --git a/internal/config/config.go b/internal/config/config.go index 9bb9c84..803e058 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,8 @@ type ProviderEntry struct { APIKey string `yaml:"api_key"` Endpoint string `yaml:"endpoint"` APIVersion string `yaml:"api_version"` + Project string `yaml:"project"` // For Vertex AI + Location string `yaml:"location"` // For Vertex AI } // ModelEntry maps a model name to a provider entry. @@ -78,6 +80,12 @@ type AzureAnthropicConfig struct { Model string `yaml:"model"` } +// VertexAIConfig contains Vertex AI-specific settings used internally by the Google provider. +type VertexAIConfig struct { + Project string `yaml:"project"` + Location string `yaml:"location"` +} + // Load reads and parses a YAML configuration file, expanding ${VAR} env references. func Load(path string) (*Config, error) { data, err := os.ReadFile(path) diff --git a/internal/providers/google/google.go b/internal/providers/google/google.go index bb31131..5be93f1 100644 --- a/internal/providers/google/google.go +++ b/internal/providers/google/google.go @@ -19,7 +19,7 @@ type Provider struct { client *genai.Client } -// New constructs a Provider using the provided configuration. +// New constructs a Provider using the Google AI API with API key authentication. func New(cfg config.ProviderConfig) *Provider { var client *genai.Client if cfg.APIKey != "" { @@ -38,13 +38,36 @@ func New(cfg config.ProviderConfig) *Provider { } } +// NewVertexAI constructs a Provider targeting Vertex AI. +// Vertex AI uses the same genai SDK but with GCP project/location configuration +// and Application Default Credentials (ADC) or service account authentication. +func NewVertexAI(vertexCfg config.VertexAIConfig) *Provider { + var client *genai.Client + if vertexCfg.Project != "" && vertexCfg.Location != "" { + var err error + client, err = genai.NewClient(context.Background(), &genai.ClientConfig{ + Project: vertexCfg.Project, + Location: vertexCfg.Location, + Backend: genai.BackendVertexAI, + }) + if err != nil { + // Log error but don't fail construction - will fail on Generate + fmt.Printf("warning: failed to create vertex ai client: %v\n", err) + } + } + return &Provider{ + cfg: config.ProviderConfig{ + // Vertex AI doesn't use API key, but set empty for consistency + APIKey: "", + }, + client: client, + } +} + func (p *Provider) Name() string { return Name } // Generate routes the request to Gemini and returns a ProviderResult. func (p *Provider) Generate(ctx context.Context, messages []api.Message, req *api.ResponseRequest) (*api.ProviderResult, error) { - if p.cfg.APIKey == "" { - return nil, fmt.Errorf("google api key missing") - } if p.client == nil { return nil, fmt.Errorf("google client not initialized") } @@ -96,10 +119,6 @@ func (p *Provider) GenerateStream(ctx context.Context, messages []api.Message, r defer close(deltaChan) defer close(errChan) - if p.cfg.APIKey == "" { - errChan <- fmt.Errorf("google api key missing") - return - } if p.client == nil { errChan <- fmt.Errorf("google client not initialized") return diff --git a/internal/providers/providers.go b/internal/providers/providers.go index a4aa170..a22f8a4 100644 --- a/internal/providers/providers.go +++ b/internal/providers/providers.go @@ -60,7 +60,8 @@ func NewRegistry(entries map[string]config.ProviderEntry, models []config.ModelE } func buildProvider(entry config.ProviderEntry) (Provider, error) { - if entry.APIKey == "" { + // Vertex AI doesn't require APIKey, so check for it separately + if entry.Type != "vertexai" && entry.APIKey == "" { return nil, nil } @@ -97,6 +98,14 @@ func buildProvider(entry config.ProviderEntry) (Provider, error) { APIKey: entry.APIKey, Endpoint: entry.Endpoint, }), nil + case "vertexai": + if entry.Project == "" || entry.Location == "" { + return nil, fmt.Errorf("project and location are required for vertexai") + } + return googleprovider.NewVertexAI(config.VertexAIConfig{ + Project: entry.Project, + Location: entry.Location, + }), nil default: return nil, fmt.Errorf("unknown provider type %q", entry.Type) }