forked from innovacion/Mayacontigo
ic
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import google.oauth2.service_account as sa
|
||||
import vertexai
|
||||
import vertexai.generative_models as gm
|
||||
from PIL.Image import Image
|
||||
|
||||
|
||||
class Gemini:
|
||||
def __init__(
|
||||
self, model: str | None = None, *, account_info: dict[str, str] | None = None
|
||||
) -> None:
|
||||
if account_info is None:
|
||||
account_info = json.loads(os.environ["GCP_SERVICE_ACCOUNT"])
|
||||
assert account_info is not None
|
||||
|
||||
credentials = sa.Credentials.from_service_account_info(account_info)
|
||||
vertexai.init(project=account_info["project_id"], credentials=credentials)
|
||||
|
||||
self.model = gm.GenerativeModel(model) if model else None
|
||||
|
||||
def set_model(self, model: str):
|
||||
self.model = gm.GenerativeModel(model)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, model: str, path: Path):
|
||||
account_info = json.loads(path.read_text())
|
||||
return cls(model, account_info=account_info)
|
||||
|
||||
def generate(self, contents, response_schema=None):
|
||||
if self.model is None:
|
||||
raise ValueError(
|
||||
"No model set; Please choose a model before calling 'generate'"
|
||||
)
|
||||
|
||||
generation_config = None
|
||||
if response_schema:
|
||||
generation_config = gm.GenerationConfig(
|
||||
response_mime_type="application/json", response_schema=response_schema
|
||||
)
|
||||
return self.model.generate_content(
|
||||
contents, generation_config=generation_config
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_part_from_pdf_bytes(pdf_bytes: bytes):
|
||||
part = gm.Part.from_data(
|
||||
data=pdf_bytes,
|
||||
mime_type="application/pdf",
|
||||
)
|
||||
return part
|
||||
|
||||
@staticmethod
|
||||
def create_part_from_PIL_image(pil_image: Image, format="jpeg"):
|
||||
with io.BytesIO() as img_buffer:
|
||||
pil_image.save(img_buffer, format=format.upper())
|
||||
img_bytes = img_buffer.getvalue()
|
||||
|
||||
part = gm.Part.from_data(
|
||||
data=img_bytes,
|
||||
mime_type="image/" + format,
|
||||
)
|
||||
|
||||
return part
|
||||
|
||||
@classmethod
|
||||
def from_vault(
|
||||
cls,
|
||||
vault: str,
|
||||
*,
|
||||
model: str | None = None,
|
||||
url: str | None = None,
|
||||
token: str | None = None,
|
||||
mount_point: str = "secret",
|
||||
):
|
||||
from hvac import Client
|
||||
|
||||
client = Client(url=url or "https://vault.ia-innovacion.work", token=token)
|
||||
|
||||
if not client.is_authenticated():
|
||||
raise Exception("Vault authentication failed")
|
||||
|
||||
secret_map = client.secrets.kv.v2.read_secret_version(
|
||||
path=vault, mount_point=mount_point
|
||||
)["data"]["data"]
|
||||
|
||||
return cls(
|
||||
account_info=secret_map["gcp_service_account"],
|
||||
model=model,
|
||||
)
|
||||
Reference in New Issue
Block a user