# NOTE: removed Hugging Face Spaces page-scrape artifacts ("Spaces:", "Runtime
# error", file-size line, and the line-number ribbon) that were not valid Python.
# Run this script independently (`python create_collection.py`) to create a Qdrant collection. A Qdrant collection is a set of vectors among which you can search.
# All the legal documents over which search needs to be enabled need to be converted to their embedding representation and inserted into a Qdrant collection for search feature to work.
import os
import cohere
from datasets import load_dataset
from qdrant_client import QdrantClient
from qdrant_client import models
from qdrant_client.http import models as rest
from constants import (
ENGLISH_EMBEDDING_MODEL,
MULTILINGUAL_EMBEDDING_MODEL,
USE_MULTILINGUAL_EMBEDDING,
CREATE_QDRANT_COLLECTION_NAME,
)
# load environment variables
# Connection/auth settings for the Qdrant vector database and the Cohere API.
# All three must be set in the environment before running this script;
# os.environ.get returns None (not an error) when a variable is missing.
QDRANT_HOST = os.environ.get("QDRANT_HOST")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")
def get_embedding_size():
    """
    Get the dimensions of the embeddings produced by the configured model.

    Returns:
        embedding_size (`int`):
            768 when the multilingual embedding model is in use,
            4096 for the English embedding model.
    """
    # The two Cohere models emit fixed-size vectors of different widths.
    return 768 if USE_MULTILINGUAL_EMBEDDING else 4096
def create_qdrant_collection(vector_size):
    """
    (Re)-create the Qdrant collection (name taken from CREATE_QDRANT_COLLECTION_NAME)
    that will hold the vectors representing all the legal documents.

    Args:
        vector_size (`int`):
            The dimensions of the embeddings that will be added to the collection.
    """
    # The multilingual embedding model was trained with dot-product similarity;
    # the English model pairs with cosine distance.
    distance_measure = (
        rest.Distance.DOT if USE_MULTILINGUAL_EMBEDDING else rest.Distance.COSINE
    )
    print("CREATE_QDRANT_COLLECTION_NAME:", CREATE_QDRANT_COLLECTION_NAME)
    # Drops any existing collection with this name and creates a fresh one.
    # NOTE(review): relies on the module-level `qdrant_client` created under
    # the __main__ guard.
    qdrant_client.recreate_collection(
        collection_name=CREATE_QDRANT_COLLECTION_NAME,
        vectors_config=models.VectorParams(size=vector_size, distance=distance_measure),
    )
def embed_legal_docs(legal_docs):
    """
    Create embeddings and ids which will be used to represent the legal documents
    upon which search needs to be enabled.

    Args:
        legal_docs (`List[str]`):
            A list of documents for which embeddings need to be created.
    Returns:
        doc_embeddings (`List[List[float]]`):
            A list of embeddings corresponding to each document.
        doc_ids (`List[int]`):
            A list of unique ids which will be used as identifiers for the
            points (documents) in a qdrant collection.
    """
    if USE_MULTILINGUAL_EMBEDDING:
        model_name = MULTILINGUAL_EMBEDDING_MODEL
    else:
        model_name = ENGLISH_EMBEDDING_MODEL
    # NOTE(review): relies on the module-level `cohere_client` created under
    # the __main__ guard.
    legal_docs_embeds = cohere_client.embed(
        texts=legal_docs,
        model=model_name,
    )
    # Coerce every component to a plain Python float for Qdrant insertion.
    doc_embeddings = [
        list(map(float, vector)) for vector in legal_docs_embeds.embeddings
    ]
    # BUGFIX: the original enumerated the Cohere *response object*
    # (`enumerate(legal_docs_embeds)`), not the embeddings list, so the number
    # of ids was not tied to the number of embeddings (and shadowed builtin
    # `id`). Ids are simply the 0..N-1 positions of the embeddings.
    doc_ids = list(range(len(doc_embeddings)))
    return doc_embeddings, doc_ids
def upsert_data_in_collection(vectors, ids, payload):
    """
    Insert or update the given document vectors (with their payload metadata)
    in the Qdrant collection, as a single batch.

    Args:
        vectors (`List`):
            A list of embeddings corresponding to each document which needs to be added to the collection.
        ids (`List`):
            A list of unique ids which will be used as identifiers for the points (documents) in a qdrant collection.
        payload (`List`):
            A list of additional information or metadata corresponding to each document being added to the collection.
    Returns:
        The Qdrant update result on success, or `None` if the upsert failed.
    """
    try:
        update_result = qdrant_client.upsert(
            collection_name=CREATE_QDRANT_COLLECTION_NAME,
            points=rest.Batch(
                ids=ids,
                vectors=vectors,
                payloads=payload,
            ),
        )
        return update_result
    # BUGFIX: the original bare `except:` swallowed *everything*, including
    # KeyboardInterrupt/SystemExit, and gave no hint of what went wrong. Keep
    # the best-effort contract (return None) but catch only Exception and
    # surface the error for diagnosis.
    except Exception as err:
        print(f"Failed to upsert points into Qdrant collection: {err}")
        return None
def fetch_legal_documents_and_payload():
    """
    Get the legal documents and additional information (payload) related to them which will be used as part of the search module.
    Returns:
        payload (`List[Dict]`):
            Additional information related to the documents that are being used as part of the search module.
        legal_docs (`List[str]`):
            The documents that will be used as part of the search module.
    """
    # Train split of the COVID-19 emergency-event legal dataset from the HF Hub.
    legal_dataset = load_dataset("joelito/covid19_emergency_event", split="train")
    legal_docs = legal_dataset["text"]
    # prepare payload (additional information or metadata for documents being inserted)
    # Each dataset row (all columns, as a dict) becomes the payload of one point.
    payload = list(legal_dataset)
    # NOTE: payload comes first, then the documents (docstring order matches).
    return payload, legal_docs
if __name__ == "__main__":
    # create qdrant and cohere clients (module-level globals used by the
    # helper functions above)
    cohere_client = cohere.Client(COHERE_API_KEY)
    qdrant_client = QdrantClient(
        host=QDRANT_HOST,
        prefer_grpc=True,
        api_key=QDRANT_API_KEY,
    )
    # fetch the size of the embeddings depending on which model is being used to create embeddings for documents
    vector_size = get_embedding_size()
    # create a collection in Qdrant
    create_qdrant_collection(vector_size)
    # load the set of documents which will be inserted into the Qdrant collection
    payload, legal_docs = fetch_legal_documents_and_payload()
    # create embeddings for documents and IDs for documents before insertion into Qdrant collection
    doc_embeddings, doc_ids = embed_legal_docs(legal_docs)
    # insert/update documents in the previously created qdrant collection
    update_result = upsert_data_in_collection(doc_embeddings, doc_ids, payload)
    collection_info = qdrant_client.get_collection(
        collection_name=CREATE_QDRANT_COLLECTION_NAME
    )
    # BUGFIX: previously, when the upsert failed (update_result is None) the
    # script printed nothing at all — the failure branch only covered a
    # vector-count mismatch. Report success only when the upsert succeeded AND
    # the collection holds one vector per document; otherwise report failure.
    if update_result is not None and collection_info.vectors_count == len(legal_docs):
        print("All documents have been successfully added to Qdrant Collection!")
    else:
        print("Failed to add documents to Qdrant collection")
# NOTE: removed trailing "|" page-scrape artifact (not valid Python).