forked from innovacion/Mayacontigo
100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
from collections.abc import Sequence
|
|
from typing import Any
|
|
|
|
from langfuse.decorators import langfuse_context, observe
|
|
from qdrant_client import QdrantClient, models
|
|
|
|
from .base import BaseQdrant
|
|
|
|
|
|
class Qdrant(BaseQdrant):
    """Synchronous Qdrant vector-store wrapper built on ``QdrantClient``.

    Every collection-scoped method accepts an optional ``collection`` argument;
    when it is omitted, the instance-level ``self.collection`` is used instead.
    """

    def __init__(
        self, *, url: str, api_key: str | None, collection: str | None = None
    ) -> None:
        """Create the wrapper and its underlying ``QdrantClient``.

        Args:
            url: Qdrant server URL, passed to both the base class and the client.
            api_key: API key for the Qdrant server, or ``None``.
            collection: Default collection used when a method is called without
                an explicit ``collection`` argument.
        """
        # NOTE(review): assumes BaseQdrant.__init__ stores ``collection`` as
        # ``self.collection`` -- every method below falls back to it; confirm in base.py.
        super().__init__(url=url, api_key=api_key, collection=collection)
        self.client = QdrantClient(url=url, api_key=api_key)

    def _resolve_collection(self, collection: str | None, error_message: str) -> str:
        """Return ``collection`` if given, otherwise the instance default.

        Raises:
            ValueError: With ``error_message`` when neither ``collection`` nor
                ``self.collection`` is set.
        """
        if collection is not None:
            return collection
        if self.collection is None:
            raise ValueError(error_message)
        return self.collection

    def list_collections(self) -> Sequence[str]:
        """Return the names of all collections reported by the Qdrant server."""
        return [
            collection.name for collection in self.client.get_collections().collections
        ]

    @observe(capture_input=False)
    def semantic_search(
        self,
        embedding: Sequence[float] | models.NamedVector,
        collection: str | None = None,
        limit: int = 10,
        conditions: Any | None = None,
        threshold: float | None = None,
        **kwargs,
    ) -> Sequence[dict[str, Any]]:
        """Search ``collection`` with ``embedding`` and return the hits' payloads.

        Args:
            embedding: Query vector, or a ``models.NamedVector`` targeting a
                specific named vector.
            collection: Collection to search; defaults to ``self.collection``.
            limit: Maximum number of points requested from Qdrant.
            conditions: Optional filter, forwarded as ``query_filter``.
            threshold: Optional minimum score, forwarded as ``score_threshold``.
            **kwargs: Extra keyword arguments forwarded to ``QdrantClient.search``.

        Returns:
            Payloads of the returned points, in the order Qdrant returned them;
            points without a payload are omitted.

        Raises:
            ValueError: If neither ``collection`` nor ``self.collection`` is set.
        """
        collection = self._resolve_collection(
            collection,
            "No collection set; Please set a collection before calling 'semantic_search'",
        )

        # The decorator does not capture inputs (capture_input=False), so the
        # observation input is recorded explicitly, without ``self``.
        langfuse_context.update_current_observation(
            input={
                "collection": collection,
                "limit": limit,
                "embedding": embedding,
                "conditions": conditions,
            }
        )

        points = self.client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=conditions,
            limit=limit,
            with_payload=True,
            with_vectors=False,
            score_threshold=threshold,
            **kwargs,
        )

        return [point.payload for point in points if point.payload is not None]

    def create_collection_if_not_exists(
        self,
        *,
        collection: str | None = None,
        vector_config: dict[str, models.VectorParams],
    ) -> bool:
        """Create ``collection`` unless a collection with that name already exists.

        Args:
            collection: Collection name; defaults to ``self.collection``.
            vector_config: Named vector configurations, passed to Qdrant as
                ``vectors_config``.

        Returns:
            The result of ``QdrantClient.create_collection`` when the collection
            is created, or ``False`` when it already existed.

        Raises:
            ValueError: If neither ``collection`` nor ``self.collection`` is set.
        """
        collection = self._resolve_collection(
            collection,
            "No collection is set; Please set a collection before calling 'create_collection_if_not_exists'",
        )

        # Check-then-create is two separate requests, so it is not atomic.
        if collection in self.list_collections():
            return False

        return self.client.create_collection(
            collection_name=collection,
            vectors_config=vector_config,
        )

    def upload_to_collection(
        self,
        *,
        points: list[models.PointStruct],
        collection: str | None = None,
        batch_size: int = 1,
    ) -> None:
        """Upsert ``points`` into ``collection``.

        Points are sent in consecutive ``upsert`` requests of at most
        ``batch_size`` points each; every request uses ``wait=True`` and is
        issued only after the previous one returned. The default of 1 sends one
        request per point; larger values reduce the number of round trips.

        Args:
            points: Points to upsert; any iterable of ``PointStruct`` works.
            collection: Target collection; defaults to ``self.collection``.
            batch_size: Maximum number of points per ``upsert`` request.

        Raises:
            ValueError: If neither ``collection`` nor ``self.collection`` is set,
                or if ``batch_size`` is less than 1.
        """
        collection = self._resolve_collection(
            collection,
            "No collection is set; Please set a collection before calling 'upload_to_collection'",
        )
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Accumulate into a fresh list per request so any iterable (not only a
        # sliceable list) can be uploaded.
        batch: list[models.PointStruct] = []
        for point in points:
            batch.append(point)
            if len(batch) == batch_size:
                self.client.upsert(collection_name=collection, points=batch, wait=True)
                batch = []
        if batch:
            self.client.upsert(collection_name=collection, points=batch, wait=True)
|