# Page banner captured with the source (not part of the module): "Spaces: Runtime error"
import uuid
from typing import Callable, Iterable, Optional

import tqdm
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import ScoredPoint

from embedding import Embedding
from model.document import Document
from model.record import Record
from model.user import User
class Index:
    """Interface for a per-user document index.

    Every method here is a placeholder that does nothing and returns
    ``None``; concrete backends (``QDrantVectorStore`` in this module)
    override them.
    """

    # Short identifier of the backend; set by subclasses.
    type: str

    def load_or_update_document(
            self,
            user: User,
            document: Document,
            progress: Optional[Callable[[Iterable], Iterable]] = None):
        """Index ``document`` for ``user``, replacing any previous version.

        Args:
            user: Owner of the document.
            document: Document whose records should be indexed.
            progress: Optional callable that wraps an iterable and returns an
                iterable (``tqdm.tqdm`` itself fits), used for progress
                reporting.
        """

    def remove_document(self, user: User, document: Document):
        """Remove ``document`` of ``user`` from the index."""

    def query_index(self, user: User, query: str, top_k: int = 10,
                    threshold: float = 0.5) -> list[Record]:
        """Search all documents of ``user`` for ``query``.

        Args:
            user: Owner whose documents are searched.
            query: Free-text query.
            top_k: Maximum number of records to return.
            threshold: Minimum score a record needs to be returned.
        """

    def query_document(self, user: User, document: Document, query: str,
                       top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Search a single ``document`` of ``user`` for ``query``.

        Args:
            user: Owner of the document.
            document: Document the search is restricted to.
            query: Free-text query.
            top_k: Maximum number of records to return.
            threshold: Minimum score a record needs to be returned.
        """

    def contains(self, user: User, document: Document) -> bool:
        """Tell whether ``document`` of ``user`` is present in the index."""
class QDrantVectorStore(Index):
    """``Index`` implementation that keeps records as points in one Qdrant collection.

    Each point's payload carries ``group_id`` (the owner's ``user_name``) and
    ``document_id``; every read and delete filters on those fields, so the
    points of several users can live in the same collection.
    """

    _client: QdrantClient
    _embedding: Embedding
    collection_name: str
    batch_size: int = 10  # number of records sent per upsert request
    type: str = 'qdrant'

    def __init__(
            self,
            client: QdrantClient,
            embedding: Embedding,
            collection_name: str):
        """Remember the client, embedding provider and collection name.

        Only stores the arguments; no request is sent to Qdrant here.
        """
        self._embedding = embedding
        self.collection_name = collection_name
        self._client = client

    @staticmethod
    def _scope_filter(user: User, document: Optional[Document] = None) -> models.Filter:
        """Build the payload filter for ``user`` and, if given, ``document``."""
        conditions = []
        if document is not None:
            conditions.append(
                models.FieldCondition(
                    key="document_id",
                    match=models.MatchValue(value=document.name),
                )
            )
        conditions.append(
            models.FieldCondition(
                key="group_id",
                match=models.MatchValue(value=user.user_name),
            )
        )
        return models.Filter(must=conditions)

    def _response_to_records(self, response: list[ScoredPoint]) -> list[Record]:
        """Convert scored points returned by a search into ``Record`` objects.

        Returns a real list (not a generator) so the result matches the
        annotation and can be measured, indexed and iterated repeatedly.
        """
        # NOTE(review): the searches in this class do not pass `with_vectors`,
        # so `point.vector` is probably None here -- confirm against the
        # qdrant-client version in use.
        return [
            Record(
                embedding=point.vector,
                meta_data=point.payload['meta_data'],
                content=point.payload['content'],
                document_id=point.payload['document_id'],
                timestamp=point.payload['timestamp'],
            )
            for point in response
        ]

    def _search(self, query: str, query_filter: models.Filter,
                top_k: int, threshold: float) -> list[Record]:
        """Embed ``query``, run a filtered similarity search and map the hits."""
        response = self._client.search(
            collection_name=self.collection_name,
            query_vector=self._embedding.generate_embedding(query),
            limit=top_k,
            query_filter=query_filter,
            score_threshold=threshold,
        )
        return self._response_to_records(response)

    def create_collection(self):
        """(Re)create the collection sized for the embedding, using cosine distance.

        Uses ``recreate_collection``, which replaces a collection of the same
        name if one already exists.
        """
        self._client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self._embedding.vector_size,
                distance=models.Distance.COSINE),
        )

    def if_collection_exists(self) -> bool:
        """Return True when ``get_collection`` succeeds for the collection.

        Any exception is reported as "does not exist".
        """
        # NOTE(review): this also maps connection/timeout errors to False.
        # Because `create_collection_if_not_exists` then calls
        # `recreate_collection`, a transient failure here can lead to an
        # existing collection being replaced. Narrowing the exception depends
        # on the installed qdrant-client version, so it is left as-is.
        try:
            self._client.get_collection(self.collection_name)
        except Exception:
            return False
        return True

    def create_collection_if_not_exists(self):
        """Create the collection unless ``if_collection_exists`` reports it."""
        if not self.if_collection_exists():
            self.create_collection()

    def load_or_update_document(
            self,
            user: User,
            document: Document,
            progress: Optional[Callable[[Iterable], Iterable]] = None):
        """Upsert all records of ``document`` for ``user``, replacing old ones.

        Points previously stored for the same user/document pair are deleted
        first. Vectors are taken from ``record.embedding``; this method does
        not compute embeddings itself.

        Args:
            user: Owner; ``user.user_name`` is stored as ``group_id``.
            document: Source of records via ``document.load_records()``.
            progress: Optional callable applied to the range of batch offsets
                (e.g. ``tqdm.tqdm``); its result is what gets iterated.
        """
        self.create_collection_if_not_exists()
        if self.contains(user, document):
            self.remove_document(user, document)

        group_id = user.user_name
        # Materialise so the records can be sliced into batches.
        records = list(document.load_records())
        batch_range = range(0, len(records), self.batch_size)
        if progress is not None:
            batch_range = progress(batch_range)

        for start in batch_range:
            batch = records[start:start + self.batch_size]
            payloads = [
                {
                    'content': record.content,
                    'meta_data': record.meta_data,
                    'document_id': record.document_id,
                    'group_id': group_id,
                    'timestamp': record.timestamp,
                }
                for record in batch
            ]
            self._client.upsert(
                collection_name=self.collection_name,
                points=models.Batch(
                    payloads=payloads,
                    # Fresh random ids: old points are removed above rather
                    # than overwritten in place.
                    ids=[str(uuid.uuid4()) for _ in batch],
                    vectors=[record.embedding for record in batch],
                ),
            )

    def remove_document(self, user: User, document: Document):
        """Delete every point of ``user`` whose ``document_id`` is ``document.name``.

        Does nothing when the collection does not exist.
        """
        if not self.if_collection_exists():
            return
        self._client.delete(
            collection_name=self.collection_name,
            points_selector=models.FilterSelector(
                filter=self._scope_filter(user, document),
            ),
        )

    def contains(self, user: User, document: Document) -> bool:
        """Return True when at least one point exists for ``user`` and ``document``.

        Returns False when the collection does not exist, consistent with the
        other read paths of this class, instead of letting the count request
        fail.
        """
        if not self.if_collection_exists():
            return False
        result = self._client.count(
            collection_name=self.collection_name,
            count_filter=self._scope_filter(user, document),
            exact=True,
        )
        return result.count > 0

    def query_index(self, user: User, query: str, top_k: int = 10,
                    threshold: float = 0.5) -> list[Record]:
        """Search all points of ``user`` for ``query``.

        Args:
            user: Owner whose points are searched.
            query: Text passed to the embedding provider.
            top_k: Passed to the search as ``limit``.
            threshold: Passed to the search as ``score_threshold``.

        Returns:
            A list of records; empty when the collection does not exist.
        """
        if not self.if_collection_exists():
            return []
        return self._search(query, self._scope_filter(user), top_k, threshold)

    def query_document(self, user: User, document: Document, query: str,
                       top_k: int = 10, threshold: float = 0.5) -> list[Record]:
        """Search the points of one ``document`` of ``user`` for ``query``.

        Args:
            user: Owner of the document.
            document: Document the search is restricted to (by ``name``).
            query: Text passed to the embedding provider.
            top_k: Passed to the search as ``limit``.
            threshold: Passed to the search as ``score_threshold``.

        Returns:
            A list of records; empty when the collection does not exist.
        """
        if not self.if_collection_exists():
            return []
        return self._search(
            query, self._scope_filter(user, document), top_k, threshold)