TestApp / components /vector_db_operations.py
menikev's picture
Upload 9 files
d2ed505 verified
raw
history blame contribute delete
No virus
4.51 kB
import pandas as pd
import os
import chromadb
from chromadb.utils import embedding_functions
import math
def create_domain_identification_database(
    vdb_path: str,
    collection_name: str,
    df: pd.DataFrame,
    batch_size: int = 166,
) -> None:
    """Embed a labelled dataset and store it as a new ChromaDB collection.

    Each row of ``df`` becomes one record in the collection:

    * document: the row's ``query`` column (embedded with the
      ``sentence-transformers/LaBSE`` model),
    * metadata: ``{"domain": row.domain, "label": row.label}``,
    * id: ``"domain_id <row index>"``.

    Args:
        vdb_path (str): Path of the location of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to create. This uses
            ``create_collection`` (not get-or-create), so it presumably fails if
            the name already exists. Confirm against the installed ChromaDB version.
        df (pd.DataFrame): Task scheduling dataset. It must have ``query``,
            ``domain`` and ``label`` columns.
        batch_size (int): Maximum number of records passed to each
            ``collection.add`` call. Defaults to 166, the value that was
            previously hard-coded here.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. This is checked before
            anything is written to the database.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Open (or create) the persistent ChromaDB store at vdb_path.
    chroma_client = chromadb.PersistentClient(path=vdb_path)
    # Sentence-transformer model used to embed the documents.
    embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    domain_identification_collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_function,
    )

    # Materialise the rows once instead of re-iterating the DataFrame per list.
    rows = list(df.itertuples())
    documents = [row.query for row in rows]
    metadatas = [{"domain": row.domain, "label": row.label} for row in rows]
    ids = ["domain_id " + str(row.Index) for row in rows]

    # Add the records in bounded chunks rather than in a single call, presumably
    # to limit how many texts are embedded and sent per request.
    # Slicing past the end is safe, so the final partial batch needs no special case.
    for start in range(0, len(ids), batch_size):
        end = start + batch_size
        domain_identification_collection.add(
            documents=documents[start:end],
            metadatas=metadatas[start:end],
            ids=ids[start:end],
        )
    return None
def delete_collection_from_vector_db(vdb_path: str, collection_name: str) -> None:
    """Drop one named collection from a persistent ChromaDB store.

    Args:
        vdb_path (str): Filesystem path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to remove.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    client.delete_collection(collection_name)
def list_collections_from_vector_db(vdb_path: str) -> None:
    """Print whatever ``list_collections`` reports for a persistent ChromaDB store.

    The result is only printed to stdout; nothing is returned.

    Args:
        vdb_path (str): Filesystem path of the persistent ChromaDB instance.
    """
    available = chromadb.PersistentClient(path=vdb_path).list_collections()
    print(available)
def get_collection_from_vector_db(
    vdb_path: str, collection_name: str
) -> chromadb.Collection:
    """Open an existing collection from a persistent ChromaDB store.

    The collection is bound to the ``sentence-transformers/LaBSE`` embedding
    function, so that text queries against it are embedded with LaBSE.

    Args:
        vdb_path (str): Filesystem path of the persistent ChromaDB instance.
        collection_name (str): Name of the collection to fetch.

    Returns:
        chromadb.Collection: The requested collection handle.
    """
    client = chromadb.PersistentClient(path=vdb_path)
    labse_embedder = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="sentence-transformers/LaBSE"
    )
    return client.get_collection(
        name=collection_name,
        embedding_function=labse_embedder,
    )
def retrieval(input_text: str,
              num_results: int,
              collection: "chromadb.Collection") -> tuple:
    """Fetch the closest domain for a piece of text by semantic similarity.

    The collection is queried for ``num_results`` neighbours of ``input_text``.
    Only the single best match (the first result) is used; any further
    neighbours are ignored.

    Args:
        input_text: The received text, e.g. a news item, post, or tweet.
        num_results: Number of neighbours to request from the collection.
        collection: The collection to search. It is assumed to store
            ``domain``/``label`` metadata, as written by
            ``create_domain_identification_database``.

    Returns:
        tuple: ``(domain, label, distance)`` for the top-ranked match.

    Raises:
        IndexError: If the query returns no matches, for example when the
            collection is empty.
    """
    fetched_domain = collection.query(
        query_texts=[input_text],
        n_results=num_results,
    )
    # Results are nested per query text. Exactly one text was sent, so take
    # entry 0. Guard against missing/empty lists so the failure is explicit.
    top_metadatas = (fetched_domain["metadatas"] or [[]])[0]
    top_distances = (fetched_domain["distances"] or [[]])[0]
    if not top_metadatas or not top_distances:
        raise IndexError(
            "Query returned no results; the collection may be empty."
        )

    # Extract the domain name and label from the best match.
    best_match = top_metadatas[0]
    return best_match["domain"], best_match["label"], top_distances[0]