# "Spaces: Sleeping" banner text captured from the Hugging Face Spaces page
# during copy-paste; not part of the program.
# Standard library
import pickle
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

# Third-party
import chromadb
import faiss
import streamlit as st
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer, AutoModel

# LlamaIndex
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import NodeWithScore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore

# %pip install llama-index-vector-stores-chroma
# pip install chromadb
class SentenceTransformerRerank:
    """Cross-encoder reranker: scores (query, node-text) pairs and keeps the best `top_n`."""

    def __init__(self, top_n, model, device="cpu"):
        # max_length=512 truncates long passages to the cross-encoder's input window.
        self.model = CrossEncoder(model, max_length=512, device=device)
        self.top_n = top_n

    def predict(self, nodes, query=None):
        """Score every node against `query`, attach scores in place, and return the
        `top_n` nodes sorted by descending score.

        A node whose scoring raises keeps whatever `score` it already had (the
        exception is printed, best-effort); falsy scores sort as 0.
        """
        pairs = [(str(query), str(node.text)) for node in nodes]

        def score_one(pair):
            # CrossEncoder.predict expects a batch, so wrap the single pair;
            # one pair per task lets the thread pool overlap the calls.
            return self.model.predict([pair])[0]

        with ThreadPoolExecutor() as executor:
            future_to_index = {
                executor.submit(score_one, pair): idx
                for idx, pair in enumerate(pairs)
            }
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    nodes[idx].score = future.result()
                except Exception as exc:
                    # Best-effort: report and leave this node's previous score untouched.
                    print(f'Generated an exception: {exc}')

        # Ascending sort on -score == descending by score; missing/zero scores sink.
        ranked = sorted(nodes, key=lambda n: -n.score if n.score else 0)
        return ranked[: self.top_n]
def load_data():
    """Build and return the hybrid (dense FAISS + BM25 + reranker) retriever.

    Loads pre-cleaned nodes from ``nodes_clean.pkl``, embeds them into an
    in-memory FAISS L2 index, and wires dense retrieval, BM25 retrieval and a
    cross-encoder reranker into a single HybridRetriever.
    """
    embed_model, reranker = load_models()

    # Only the pickle read needs the file handle; keep the slow index build
    # outside the `with` block so the file is closed promptly.
    # NOTE(review): pickle.load is only safe on our own locally produced
    # artifact — never point this at untrusted files.
    with open('nodes_clean.pkl', 'rb') as file:
        nodes = pickle.load(file)

    d = 768  # embedding dimensionality of BAAI/llm-embedder — TODO confirm
    faiss_index = faiss.IndexFlatL2(d)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex(nodes, embed_model=embed_model, storage_context=storage_context)
    retriever_dense = index.as_retriever(similarity_top_k=35, embedding=True)
    retrieverBM25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)

    return HybridRetriever(retriever_dense, retrieverBM25, reranker)
def load_models():
    """Instantiate the CPU embedding model and cross-encoder reranker.

    Returns:
        (embed_model, reranker) — a HuggingFaceEmbedding over BAAI/llm-embedder
        and a SentenceTransformerRerank over BAAI/bge-reranker-base (top_n=25).
    """
    embed_model = HuggingFaceEmbedding("BAAI/llm-embedder", device='cpu')
    reranker = SentenceTransformerRerank(
        top_n=25, model="BAAI/bge-reranker-base", device='cpu'
    )
    return embed_model, reranker
class HybridRetriever(BaseRetriever):
    """Fuses dense-vector and BM25 retrieval, reranking only the strongest candidates."""

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Run both retrievers concurrently; the calls are independent.
        with ThreadPoolExecutor() as pool:
            bm25_job = pool.submit(self.bm25_retriever.retrieve, query, **kwargs)
            dense_job = pool.submit(self.vector_retriever.retrieve, query, **kwargs)
            bm25_nodes = bm25_job.result()
            dense_nodes = dense_job.result()

        dense_head_n = 20  # top dense hits sent to the reranker
        bm25_head_n = 2    # top BM25 hits sent to the reranker

        # Everything past the heads skips reranking and is appended unchanged.
        tail = dense_nodes[dense_head_n:] + bm25_nodes[bm25_head_n:]

        # Deduplicate the head candidates (BM25 hits first) by node id.
        head = []
        seen_ids = set()
        for candidate in bm25_nodes[:bm25_head_n] + dense_nodes[:dense_head_n]:
            node_id = candidate.node.node_id
            if node_id not in seen_ids:
                seen_ids.add(node_id)
                head.append(candidate)

        # Cross-encoder reranks only the deduplicated head; tail keeps its order.
        return self.reranker.predict(head, query) + tail
import re | |
def clean_whitespace(text, k=5):
    """Normalize scraped text: drop short leading lines, strip '.EU', collapse whitespace, lowercase.

    The first `k` lines are kept only when their stripped length exceeds 25
    characters (short lines there are presumably boilerplate headers — TODO
    confirm against the crawl); every line after the first `k` is kept
    unconditionally.

    Args:
        text: raw text to clean.
        k: number of leading lines subject to the length filter.

    Returns:
        The cleaned, lowercased, single-line string.
    """
    lines = text.strip().split("\n")  # split once instead of twice
    kept = [line for line in lines[:k] if len(line.strip()) > 25] + lines[k:]
    text = " ".join(kept)
    text = re.sub(r"\.EU", "", text)  # remove the literal '.EU' artifact (pre-lowercasing)
    text = re.sub(r"\s+", " ", text)  # collapse any whitespace run to one space
    return text.lower()