Medivocate: An AI-powered platform exploring African history, culture, and
traditional medicine, fostering understanding and appreciation of the
continent's rich heritage.
import os
from typing import List, Optional, Union

from langchain.retrievers import EnsembleRetriever, MultiQueryRetriever
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from tqdm import tqdm
from transformers import AutoTokenizer

from ..utilities.llm_models import get_llm_model_embedding
from .document_loader import DocumentLoader
from .prompts import DEFAULT_QUERY_PROMPT
from .vector_store import get_collection_name
class VectorStoreManager:
    """
    Manages vector store initialization, updates, and retrieval.

    Maintains two complementary retrievers over the same document set:

    * ``"chroma"`` — a dense vector store persisted under
      ``persist_directory``, and
    * ``"bm25"`` — an in-memory sparse BM25 retriever rebuilt from the
      documents each time the store is (re)initialized.

    ``create_retriever`` combines both into a weighted ensemble wrapped in a
    ``MultiQueryRetriever``.
    """

    def __init__(self, persist_directory: str, batch_size: int = 64):
        """
        Initializes the VectorStoreManager with the given parameters.

        Args:
            persist_directory (str): Directory to persist the Chroma vector
                store.
            batch_size (int): Number of documents to embed per batch when
                building the store.
        """
        self.persist_directory = persist_directory
        self.batch_size = batch_size
        self.embeddings = get_llm_model_embedding()
        self.collection_name = get_collection_name()
        # Both retrievers start unset; initialize_vector_store() fills them in.
        self.vector_stores: dict[str, Optional[Union[Chroma, BM25Retriever]]] = {
            "chroma": None,
            "bm25": None,
        }
        # Tokenizer used when building BM25 over freshly supplied documents.
        self.tokenizer = AutoTokenizer.from_pretrained(
            os.getenv("HF_MODEL", "meta-llama/Llama-3.2-1B")
        )
        self.vs_initialized = False
        # Set by create_retriever(); holds the final MultiQueryRetriever.
        self.vector_store = None

    def _batch_process_documents(self, documents: List[Document]) -> None:
        """
        Processes documents in batches to build/extend the Chroma store, then
        rebuilds the BM25 retriever over the full document list.

        Args:
            documents (List[Document]): List of documents to process.
        """
        for start in tqdm(
            range(0, len(documents), self.batch_size), desc="Processing documents"
        ):
            batch = documents[start : start + self.batch_size]
            if not self.vs_initialized:
                # First batch creates (or reopens) the persisted collection.
                self.vector_stores["chroma"] = Chroma.from_documents(
                    collection_name=self.collection_name,
                    documents=batch,
                    embedding=self.embeddings,
                    persist_directory=self.persist_directory,
                )
                self.vs_initialized = True
            else:
                self.vector_stores["chroma"].add_documents(batch)
        # BM25 is in-memory only, so it is rebuilt from scratch over all docs.
        # NOTE(review): `tokenizer=` is forwarded via **kwargs here, while the
        # reload path in initialize_vector_store() builds BM25 without it —
        # confirm BM25Retriever.from_documents accepts this kwarg and whether
        # the two call sites should match.
        self.vector_stores["bm25"] = BM25Retriever.from_documents(
            documents, tokenizer=self.tokenizer
        )

    def initialize_vector_store(self, documents: Optional[List[Document]] = None) -> None:
        """
        Initializes or loads the vector store.

        If ``documents`` is provided, embeds them batch-wise into Chroma and
        builds BM25 from them. Otherwise, reopens the persisted Chroma
        collection and reconstructs BM25 from its stored contents.

        Args:
            documents (Optional[List[Document]]): Documents to index.
                Defaults to None, which loads the existing collection.
        """
        if documents:
            self._batch_process_documents(documents)
        else:
            self.vector_stores["chroma"] = Chroma(
                collection_name=self.collection_name,
                persist_directory=self.persist_directory,
                embedding_function=self.embeddings,
            )
            # Pull everything back out of Chroma so BM25 can be rebuilt.
            # ("ids" are always returned by .get(), even if not in `include`.)
            stored = self.vector_stores["chroma"].get(
                include=["documents", "metadatas"]
            )
            documents = [
                Document(page_content=content, id=doc_id, metadata=metadata)
                for content, doc_id, metadata in zip(
                    stored["documents"],
                    stored["ids"],
                    stored["metadatas"],
                )
            ]
            self.vector_stores["bm25"] = BM25Retriever.from_documents(documents)
        self.vs_initialized = True

    def create_retriever(
        self, llm, n_documents: int, bm25_portion: float = 0.8
    ) -> MultiQueryRetriever:
        """
        Creates an ensemble retriever combining Chroma and BM25, wrapped in a
        MultiQueryRetriever that expands the user query via the LLM.

        Must be called after ``initialize_vector_store``.

        Args:
            llm: Language model used to generate query variants.
            n_documents (int): Number of documents each sub-retriever returns.
            bm25_portion (float): Weight of the BM25 retriever in the
                ensemble; Chroma receives ``1 - bm25_portion``.

        Returns:
            MultiQueryRetriever: The created multi-query ensemble retriever.
        """
        self.vector_stores["bm25"].k = n_documents
        self.vector_store = MultiQueryRetriever.from_llm(
            retriever=EnsembleRetriever(
                retrievers=[
                    self.vector_stores["bm25"],
                    self.vector_stores["chroma"].as_retriever(
                        search_kwargs={"k": n_documents}
                    ),
                ],
                weights=[bm25_portion, 1 - bm25_portion],
            ),
            llm=llm,
            include_original=True,
            prompt=DEFAULT_QUERY_PROMPT,
        )
        return self.vector_store

    def load_and_process_documents(self, doc_dir: str) -> List[Document]:
        """
        Loads and processes documents from the specified directory.

        Args:
            doc_dir (str): Directory containing the source documents.

        Returns:
            List[Document]: List of loaded and processed documents.
        """
        loader = DocumentLoader(doc_dir)
        return loader.load_documents()