forked from innovacion/Mayacontigo
48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
from langfuse.openai import AzureOpenAI
|
|
from openai.types.embedding import Embedding
|
|
|
|
from .base import BaseAda
|
|
|
|
|
|
class Ada(BaseAda):
    """Azure OpenAI embedding client built on :class:`BaseAda`.

    Wraps ``langfuse.openai.AzureOpenAI`` to produce embedding vectors for a
    single query string or a batch of texts.
    """

    # Azure OpenAI caps a single embeddings request at 2048 inputs;
    # batch_embed splits larger lists into chunks of this size.
    _MAX_BATCH_SIZE = 2048

    def __init__(
        self, model: str | None = None, *, endpoint: str, key: str, version: str
    ) -> None:
        """Initialize the Azure OpenAI client.

        Args:
            model: Default embedding model/deployment name. May be ``None``,
                in which case every embed call must supply ``model`` itself.
            endpoint: Azure OpenAI resource endpoint URL.
            key: Azure OpenAI API key.
            version: Azure OpenAI API version string.
        """
        super().__init__(model, endpoint=endpoint, key=key, version=version)
        self.client = AzureOpenAI(
            azure_endpoint=endpoint, api_key=key, api_version=version
        )

    def _resolve_model(self, model: str | None) -> str:
        """Return the model to use, falling back to the instance default.

        Args:
            model: Per-call model override, or ``None`` to use ``self.model``.

        Raises:
            ValueError: If neither ``model`` nor ``self.model`` is set.
        """
        if model is not None:
            return model
        if self.model is None:
            raise ValueError("No embedding model set")
        return self.model

    def embed(
        self, input: str | list[str], *, model: str | None = None
    ) -> list[float] | list[list[float]]:
        """Embed a single string or a list of strings.

        Dispatches to :meth:`embed_query` for ``str`` input and to
        :meth:`batch_embed` for a list of strings.

        Args:
            input: Text or texts to embed.
            model: Optional per-call model override.

        Returns:
            A single embedding vector for ``str`` input, or one vector per
            text for list input.
        """
        if isinstance(input, str):
            return self.embed_query(input, model)
        return self.batch_embed(input, model)

    def batch_embed(
        self, texts: list[str], model: str | None = None
    ) -> list[list[float]]:
        """Embed *texts*, splitting into API-sized request batches.

        Args:
            texts: Texts to embed; may exceed the per-request input cap.
            model: Optional per-call model override.

        Returns:
            One embedding vector per input text, in the original order.

        Raises:
            ValueError: If no model is set (see :meth:`_resolve_model`).
        """
        from itertools import chain

        model = self._resolve_model(model)
        size = self._MAX_BATch_SIZE if False else self._MAX_BATCH_SIZE
        batches = [texts[i : i + size] for i in range(0, len(texts), size)]
        # chain.from_iterable flattens in O(n), unlike the quadratic
        # sum(list_of_lists, []) idiom.
        flattened: list[Embedding] = list(
            chain.from_iterable(
                self.client.embeddings.create(input=batch, model=model).data
                for batch in batches
            )
        )
        return [item.embedding for item in flattened]

    def embed_query(self, text: str, model: str | None = None) -> list[float]:
        """Embed a single query string and return its embedding vector.

        Args:
            text: The query text to embed.
            model: Optional per-call model override.

        Raises:
            ValueError: If no model is set (see :meth:`_resolve_model`).
        """
        model = self._resolve_model(model)
        response = self.client.embeddings.create(input=text, model=model)
        return response.data[0].embedding