|
|
|
|
|
from qdrant_client import QdrantClient |
|
from qdrant_client.http import models |
|
import os |
|
import logging |
|
import uuid |
|
|
|
|
|
# Configure the root logger (INFO level, timestamped lines). This runs as a
# side effect of importing this module.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-wide Qdrant client shared by the functions in this file. Connection
# settings come from the environment; os.getenv yields None for unset
# variables, and the client is then left to apply its own defaults.
# NOTE(review): confirm QdrantClient's fallback when QDRANT_URL is unset —
# a missing variable may silently point the client at a default host.
_qdrant_url = os.getenv('QDRANT_URL')
_qdrant_api_key = os.getenv('QDRANT_API_KEY')
qdrant_client = QdrantClient(url=_qdrant_url, api_key=_qdrant_api_key)
|
|
|
def create_collection_if_not_exists(collection_name, vector_size):
    """Ensure a Qdrant collection named ``collection_name`` exists.

    The existing collections are listed from the module-level ``qdrant_client``.
    If none has this name, a new collection is created whose vectors have
    ``vector_size`` dimensions and are compared by cosine distance. An
    existing collection is left untouched; its vector configuration is NOT
    compared against ``vector_size``.

    The list-then-create sequence is not atomic, so two concurrent callers
    may both decide to create the collection.

    Args:
        collection_name: Name of the collection to look up / create.
        vector_size: Dimensionality used when the collection has to be created.

    Raises:
        Exception: any error from the Qdrant client is logged with its
            traceback and re-raised unchanged.
    """
    try:
        existing_names = {
            collection.name
            for collection in qdrant_client.get_collections().collections
        }
        if collection_name in existing_names:
            logging.info(f"Collection {collection_name} already exists")
            return

        qdrant_client.create_collection(
            collection_name=collection_name,
            vectors_config=models.VectorParams(
                size=vector_size,
                distance=models.Distance.COSINE,
            ),
        )
        logging.info(f"Created new collection: {collection_name}")
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception(f"Error creating collection {collection_name}: {e}")
        raise
|
|
|
def store_embeddings(chunks, embeddings, user_id, data_source_id, file_id, organization_id, s3_bucket_key, total_tokens):
    """Upsert one Qdrant point per (chunk, embedding) pair into the "embed" collection.

    The collection is created on demand, sized from the first embedding.
    Every point gets a fresh random UUID string as its id. Its payload holds
    the caller-supplied identifiers, the chunk's position (``chunk_index``)
    and text (``chunk_text``), the S3 key and ``total_tokens``.

    Args:
        chunks: Chunk texts, parallel to ``embeddings`` (any iterable).
        embeddings: One vector per chunk. Items may expose ``tolist()``
            (e.g. numpy arrays) or be plain sequences of numbers.
        user_id, data_source_id, file_id, organization_id, s3_bucket_key,
        total_tokens: Copied verbatim into every point's payload.

    Returns:
        None. Empty input is logged and skipped without contacting Qdrant.

    Raises:
        ValueError: if ``chunks`` and ``embeddings`` have different lengths
            (previously ``zip`` silently dropped the excess items).
        Exception: any error from the Qdrant client is logged with its
            traceback and re-raised unchanged.
    """
    collection_name = "embed"
    try:
        # Materialize so lengths can be compared and embeddings[0] indexed;
        # this also keeps generator input for `chunks` working.
        chunks = list(chunks)
        embeddings = list(embeddings)

        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks and embeddings must have the same length "
                f"(got {len(chunks)} chunks and {len(embeddings)} embeddings)"
            )
        if not embeddings:
            logging.warning("No embeddings to store; skipping Qdrant upsert")
            return

        vector_size = len(embeddings[0])
        create_collection_if_not_exists(collection_name, vector_size)

        points = [
            models.PointStruct(
                id=str(uuid.uuid4()),
                # numpy arrays convert via tolist(); plain sequences via list().
                vector=(
                    embedding.tolist()
                    if hasattr(embedding, "tolist")
                    else list(embedding)
                ),
                payload={
                    "user_id": user_id,
                    "data_source_id": data_source_id,
                    "file_id": file_id,
                    "organization_id": organization_id,
                    "chunk_index": i,
                    "chunk_text": chunk,
                    "s3_bucket_key": s3_bucket_key,
                    "total_tokens": total_tokens,
                },
            )
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings))
        ]

        qdrant_client.upsert(collection_name=collection_name, points=points)
        logging.info(f"Successfully stored {len(points)} embeddings")
    except Exception as e:
        # logging.exception keeps the traceback, which logging.error dropped.
        logging.exception(f"Error storing embeddings in Qdrant: {e}")
        raise