# quora-similar-questions/build_quora_index.py
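"""Build a Qdrant 'questions' collection from the Quora dataset.

Unique question strings are embedded with the BAAI/bge-small-en-v1.5
sentence-transformer and uploaded to Qdrant in batches, so they can later
be searched for semantically similar questions.
"""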
import os
from tqdm import tqdm
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient, models

# Cap on the number of dataset rows scanned for question text.
MAX_QUESTIONS = 1000


def compute_embedding(sentences, emb_model):
    # Encode a list of sentences into dense vectors with the given model.
    return emb_model.encode(sentences=sentences)


def get_questions(ds):
    # Collect unique question strings from the first MAX_QUESTIONS rows.
    # Each Quora row stores a question pair under item['questions']['text'].
    questions_text = set()
    for i, item in enumerate(ds):
        if i == MAX_QUESTIONS:
            break
        for q_text in item['questions']['text']:
            questions_text.add(q_text)
    unique_questions = list(questions_text)
    return [{'question': q} for q in unique_questions]


def build_index():
    # Connect to Qdrant using credentials from the environment.
    qdrant = QdrantClient(
        url=os.environ['QDRANT_URL'],
        api_key=os.environ['QDRANT_API_KEY'],
    )
    # Embedding model whose vectors will populate the collection.
    encoder = SentenceTransformer(model_name_or_path='BAAI/bge-small-en-v1.5')

    # Stream the Quora question-pairs dataset and pull out unique questions.
    quora_ds = load_dataset(path='quora', split='train', streaming=True)
    quora_questions = get_questions(ds=quora_ds)

    # (Re)create the collection with cosine distance and the model's output size.
    qdrant.recreate_collection(
        collection_name='questions',
        vectors_config=models.VectorParams(
            size=encoder.get_sentence_embedding_dimension(),
            distance=models.Distance.COSINE
        )
    )

    # Embed and upload questions in fixed-size batches.
    BATCH_SIZE = 100

    def upload_batch(batch):
        # Embed the batched question texts and push them to Qdrant as records.
        questions_list = [item['payload']['question'] for item in batch]
        embedding_batch = compute_embedding(questions_list, encoder).tolist()
        records = [
            models.Record(
                id=item['id'],
                payload=item['payload'],
                vector=embedding
            ) for item, embedding in zip(batch, embedding_batch)
        ]
        qdrant.upload_records(
            collection_name='questions',
            records=records
        )

    question_batch = []
    for idx, entry in enumerate(tqdm(quora_questions, desc='Uploading vector embeddings in batches of {}'.format(BATCH_SIZE))):
        question_batch.append({
            'payload': entry,
            'id': idx
        })
        if len(question_batch) == BATCH_SIZE:
            upload_batch(question_batch)
            question_batch = []

    # Flush any leftover questions that did not fill a complete batch.
    if question_batch:
        upload_batch(question_batch)
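

def search_similar(question, top_k=5):
    # Sketch of how the built index could be queried; this helper and its
    # parameters (question, top_k) are not part of the original script and
    # are added only as an assumed usage example.
    qdrant = QdrantClient(
        url=os.environ['QDRANT_URL'],
        api_key=os.environ['QDRANT_API_KEY'],
    )
    encoder = SentenceTransformer(model_name_or_path='BAAI/bge-small-en-v1.5')
    hits = qdrant.search(
        collection_name='questions',
        query_vector=compute_embedding([question], encoder)[0].tolist(),
        limit=top_k
    )
    return [(hit.payload['question'], hit.score) for hit in hits]


if __name__ == '__main__':
    # Assumed entry point: the original listing does not show how build_index()
    # is invoked, so this guard is added to make the script runnable as-is.
    build_index()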