"""Hybrid (dense FAISS + BM25) retrieval with cross-encoder reranking for a Streamlit app."""

import logging
import pickle
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

# Third-party imports are kept in their original order (torch is pulled in via
# sentence_transformers before faiss is imported).
from sentence_transformers import SentenceTransformer
import streamlit as st
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModel
import faiss
from llama_index.core import VectorStoreIndex, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.core.schema import NodeWithScore
from llama_index.core.retrievers import BaseRetriever
# Requires: pip install llama-index-vector-stores-chroma chromadb
import chromadb
from llama_index.vector_stores.chroma import ChromaVectorStore

logger = logging.getLogger(__name__)


# NOTE(review): decorating a class with st.cache_resource turns the name into a
# cached callable (instances are reused per distinct constructor arguments); it
# is kept so existing callers see the same object, but load_models() is already
# cached, so this is likely redundant.
@st.cache_resource(show_spinner=False)
class SentenceTransformerRerank():
    """Rerank retrieved nodes with a sentence-transformers CrossEncoder."""

    def __init__(
        self,
        top_n,
        model,
        device="cpu",
    ):
        """Load the cross-encoder.

        Args:
            top_n: Maximum number of nodes returned by ``predict``.
            model: Model name or path passed to ``CrossEncoder``.
            device: Torch device string, ``"cpu"`` by default.
        """
        self.model = CrossEncoder(model, max_length=512, device=device)
        self.top_n = top_n

    def predict(self, nodes, query=None):
        """Score ``nodes`` against ``query`` and return the best ``top_n`` of them.

        Each successfully scored node has its ``score`` overwritten in place with
        the cross-encoder score. A node whose scoring raised keeps its previous
        (retriever) score but is ranked after every successfully scored node, so
        a score on a different scale cannot outrank cross-encoder scores.

        Args:
            nodes: Sequence of retrieved nodes exposing ``.text`` and ``.score``.
            query: The query; converted with ``str()`` before scoring.

        Returns:
            A list of at most ``self.top_n`` nodes, highest reranker score first.
            Ties keep their input order.
        """
        if not nodes:
            return []

        query_text = str(query)
        pairs = [(query_text, str(node.text)) for node in nodes]

        def predict_score(pair):
            return self.model.predict([pair])[0]

        # Scores are computed one pair per task; failures are logged and the
        # node is pushed to the end of the ranking instead of aborting the query.
        rerank_scores = {}
        with ThreadPoolExecutor() as executor:
            future_to_index = {
                executor.submit(predict_score, pair): idx
                for idx, pair in enumerate(pairs)
            }
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    rerank_scores[idx] = float(future.result())
                except Exception:
                    logger.warning(
                        "Generated an exception while reranking node %d",
                        idx,
                        exc_info=True,
                    )

        for idx, score in rerank_scores.items():
            nodes[idx].score = score

        ranked = sorted(
            range(len(nodes)),
            key=lambda idx: rerank_scores.get(idx, float("-inf")),
            reverse=True,
        )
        return [nodes[idx] for idx in ranked[: self.top_n]]


@st.cache_resource(show_spinner=False)
def load_data():
    """Build the hybrid retriever over the pickled nodes (cached by Streamlit).

    Loads ``nodes_clean.pkl``, embeds the nodes into an in-memory FAISS flat-L2
    index, and combines a dense retriever (top 35), a BM25 retriever (top 10)
    and the cross-encoder reranker into a ``HybridRetriever``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load a trusted,
    # locally produced file here. The file is closed before the models load.
    with open('nodes_clean.pkl', 'rb') as file:
        nodes = pickle.load(file)

    embed_model, reranker = load_models()

    # Must equal the vector size produced by the embedding model in load_models().
    embedding_dim = 768
    faiss_index = faiss.IndexFlatL2(embedding_dim)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    index = VectorStoreIndex(nodes, embed_model=embed_model, storage_context=storage_context)
    # NOTE(review): `embedding=True` does not look like a documented retriever
    # option; kept unchanged — confirm whether it has any effect.
    retriever_dense = index.as_retriever(similarity_top_k=35, embedding=True)
    retriever_bm25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)

    return HybridRetriever(retriever_dense, retriever_bm25, reranker)


@st.cache_resource(show_spinner=False)
def load_models():
    """Load (once, via Streamlit's resource cache) the embedding model and reranker.

    Returns:
        ``(embed_model, reranker)``: a CPU ``HuggingFaceEmbedding`` and a CPU
        ``SentenceTransformerRerank`` that keeps the top 25 nodes.
    """
    EMBEDDING_MODEL = "BAAI/llm-embedder"
    RANK_MODEL_NAME = "BAAI/bge-reranker-base"
    embed_model = HuggingFaceEmbedding(EMBEDDING_MODEL, device='cpu')
    reranker = SentenceTransformerRerank(top_n=25, model=RANK_MODEL_NAME, device='cpu')
    return embed_model, reranker


def _unique_nodes(candidates, seen_ids):
    """Return ``candidates`` in order, skipping nodes whose id is in ``seen_ids``.

    ``seen_ids`` is updated in place with the ids of the returned nodes, so the
    same set can be threaded through several calls to deduplicate across lists.
    """
    unique = []
    for candidate in candidates:
        node_id = candidate.node.node_id
        if node_id not in seen_ids:
            seen_ids.add(node_id)
            unique.append(candidate)
    return unique


class HybridRetriever(BaseRetriever):
    """Combine dense and BM25 retrieval, reranking only the strongest candidates.

    The top ``bm25_n`` BM25 hits and top ``dense_n`` dense hits are deduplicated
    and reranked with the cross-encoder; the remaining hits from both
    retrievers are appended afterwards (dense first), skipping any node already
    returned, so every node appears at most once in the result.
    """

    def __init__(self, vector_retriever, bm25_retriever, reranker, dense_n=20, bm25_n=2):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker
        self.dense_n = dense_n  # dense hits sent to the reranker
        self.bm25_n = bm25_n  # BM25 hits sent to the reranker
        super().__init__()

    def _retrieve(self, query, **kwargs):
        """Run both retrievers concurrently, rerank the head, append the tail."""
        with ThreadPoolExecutor() as executor:
            bm25_future = executor.submit(self.bm25_retriever.retrieve, query, **kwargs)
            vector_future = executor.submit(self.vector_retriever.retrieve, query, **kwargs)
            bm25_nodes = bm25_future.result()
            vector_nodes = vector_future.result()

        seen_ids = set()
        # Only the best candidates of each retriever go through the (expensive) reranker.
        head = _unique_nodes(
            bm25_nodes[: self.bm25_n] + vector_nodes[: self.dense_n], seen_ids
        )
        reranked_nodes = self.reranker.predict(head, query)

        # Lower-ranked hits follow unreranked; seen_ids already holds every head
        # node, so nothing returned above (or dropped by the reranker) reappears.
        tail = _unique_nodes(
            vector_nodes[self.dense_n:] + bm25_nodes[self.bm25_n:], seen_ids
        )
        return reranked_nodes + tail


def clean_whitespace(text, k=5):
    """Normalize a document's text for indexing.

    Among the first ``k`` newline-separated lines, only those longer than 25
    characters (after stripping) are kept; later lines are kept unconditionally.
    The literal ``.EU`` is removed, every whitespace run becomes a single space,
    the result is stripped of leading/trailing spaces, and lower-cased.
    """
    lines = text.strip().split("\n")
    kept_head = [line for line in lines[:k] if len(line.strip()) > 25]
    text = " ".join(kept_head + lines[k:])
    text = re.sub(r"\.EU", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip().lower()