Spaces:
Runtime error
Runtime error
from langchain.retrievers import BM25Retriever, EnsembleRetriever | |
from langchain.vectorstores import FAISS, Chroma, Qdrant | |
from qdrant_client import QdrantClient | |
from langchain_pinecone import PineconeVectorStore | |
import os | |
from dotenv import load_dotenv | |
import pickle | |
# Populate os.environ from a .env file (when one is found) so that settings
# such as PINECONE_API_KEY can be read with os.getenv; variables already set in
# the real environment are left as they are.
load_dotenv(override=False)
class CreateBM25Retriever:
    """Build a BM25 keyword retriever from documents and pickle it to disk.

    The default output file, ``bm25retriever.pkl`` in the current working
    directory, is the same file the ``'ensemble'`` strategy of ``Retriever``
    reads back.
    """

    def __init__(self, docs, path="bm25retriever.pkl"):
        """Fit BM25 over *docs* and write the pickled retriever to *path*.

        Args:
            docs: Iterable of documents handed to
                ``BM25Retriever.from_documents`` (presumably LangChain
                ``Document`` objects -- confirm against callers).
            path: Destination file for the pickled retriever.

        Raises:
            ValueError: If *docs* is empty. BM25 cannot be fitted on an empty
                corpus, and nothing is written in that case.
        """
        # Materialise once so that an emptiness check also works for generators.
        docs = list(docs)
        if not docs:
            raise ValueError("CreateBM25Retriever needs at least one document")
        self.bm25_retriever = BM25Retriever.from_documents(docs)
        self.path = path
        with open(path, "wb") as outp:
            pickle.dump(self.bm25_retriever, outp, pickle.HIGHEST_PROTOCOL)
class Retriever:
    """Query a persisted vector store, optionally fused with a pickled BM25 retriever.

    Supported ``db`` values are ``'faiss'``, ``'chroma'``, ``'qdrant'`` and
    ``'pinecone'``. With ``strategy='ensemble'``, the dense retriever is
    combined with a BM25 retriever unpickled from ``bm25_path`` through
    ``EnsembleRetriever``. Any other strategy uses the dense retriever alone.
    """

    def __init__(self, db, per_dir, embeddings, strategy, k,
                 collection_name="mydocuments", bm25_path="bm25retriever.pkl",
                 weights=(0.4, 0.6)):
        """Open the vector store and build the retriever.

        Args:
            db: One of ``'faiss'``, ``'chroma'``, ``'qdrant'``, ``'pinecone'``.
            per_dir: Persistence directory for faiss/chroma/qdrant (unused
                for pinecone).
            embeddings: Embedding model passed to the vector store.
            strategy: ``'ensemble'`` to add BM25; anything else means dense only.
            k: Number of documents each underlying retriever returns.
            collection_name: Qdrant collection name, or Pinecone index name.
            bm25_path: Pickle file holding the BM25 retriever (ensemble only).
            weights: ``(bm25_weight, dense_weight)`` for the ensemble.

        Raises:
            ValueError: If *db* is not a supported store name.
        """
        self.db = db
        self.strategy = strategy
        self.per_dir = per_dir
        self.db_ = self._load_vectorstore(db, per_dir, embeddings, collection_name)
        self.retriever = self.db_.as_retriever(search_kwargs={"k": k})
        if strategy == "ensemble":
            # NOTE(security): pickle.load can execute arbitrary code. Only load
            # a file this application wrote itself (e.g. via CreateBM25Retriever).
            with open(bm25_path, "rb") as inp:
                self.bm25_retriever = pickle.load(inp)
            # The pickled retriever keeps the k it was built with; align it.
            self.bm25_retriever.k = k
            self.retriever = EnsembleRetriever(
                retrievers=[self.bm25_retriever, self.retriever],
                weights=list(weights),
            )

    @staticmethod
    def _load_vectorstore(db, per_dir, embeddings, collection_name):
        """Open the vector store named by *db*; raise ValueError if it is unsupported."""
        if db == "faiss":
            # FAISS persists a pickled docstore; only load directories this app wrote.
            return FAISS.load_local(per_dir, embeddings, allow_dangerous_deserialization=True)
        if db == "chroma":
            return Chroma(persist_directory=per_dir, embedding_function=embeddings)
        if db == "qdrant":
            return Qdrant(client=QdrantClient(path=per_dir),
                          collection_name=collection_name, embeddings=embeddings)
        if db == "pinecone":
            return PineconeVectorStore(pinecone_api_key=os.getenv("PINECONE_API_KEY"),
                                       index_name=collection_name, embedding=embeddings)
        raise ValueError(
            f"Unsupported db {db!r}; expected one of 'faiss', 'chroma', 'qdrant', 'pinecone'"
        )

    def get_docs(self, query):
        """Return the documents the configured retriever finds relevant to *query*."""
        # ``invoke`` is the Runnable entry point; ``get_relevant_documents`` is deprecated.
        return self.retriever.invoke(query)

    def get_context(self, query):
        """Retrieve documents for *query* and format them as prompt context.

        Returns:
            A ``(context, context_list)`` tuple.

            - ``context`` concatenates one chunk per document: two newlines,
              the page content, a newline, then ``Source: <source>``.
            - ``context_list`` holds the raw page contents in retrieval order.
        """
        docs = self.get_docs(query)
        # Documents may lack a 'source' entry; fall back instead of raising KeyError.
        context = "".join(
            f"\n\n{doc.page_content}\nSource: {doc.metadata.get('source', 'unknown')}"
            for doc in docs
        )
        context_list = [doc.page_content for doc in docs]
        return context, context_list