# utils_st.py: hybrid retrieval utilities for the Streamlit Space
import pickle
import re
from concurrent.futures import ThreadPoolExecutor, as_completed

import faiss
import streamlit as st
from sentence_transformers import CrossEncoder

from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.retrievers import BaseRetriever
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
from llama_index.vector_stores.faiss import FaissVectorStore
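# The llama-index integrations above ship as separate PyPI packages; a likely
# install line (an assumption, adjust versions to your environment) is:
#   pip install streamlit faiss-cpu sentence-transformers llama-index \
#       llama-index-embeddings-huggingface llama-index-retrievers-bm25 \
#       llama-index-vector-stores-faiss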
# Caching is handled by load_models() below, so the class itself is not
# wrapped in st.cache_resource.
class SentenceTransformerRerank:
    """Rerank retrieved nodes with a cross-encoder over (query, passage) pairs."""

    def __init__(self, top_n, model, device="cpu"):
        self.model = CrossEncoder(model, max_length=512, device=device)
        self.top_n = top_n

    def predict(self, nodes, query=None):
        # BaseRetriever passes a QueryBundle; fall back to str() for raw strings.
        query_str = getattr(query, "query_str", str(query))
        query_and_nodes = [(query_str, str(node.text)) for node in nodes]

        def predict_score(pair):
            return self.model.predict([pair])[0]

        # Score every pair in a thread pool, keeping each pair's index so the
        # score can be assigned back to the right node afterwards.
        scores = []
        with ThreadPoolExecutor() as executor:
            future_to_index = {
                executor.submit(predict_score, pair): idx
                for idx, pair in enumerate(query_and_nodes)
            }
            for future in as_completed(future_to_index):
                idx = future_to_index[future]
                try:
                    scores.append((idx, future.result()))
                except Exception as exc:
                    print(f"Scoring pair {idx} raised an exception: {exc}")

        # Write scores back and return the top_n highest-scoring nodes.
        for idx, score in scores:
            nodes[idx].score = score
        return sorted(nodes, key=lambda x: x.score or 0, reverse=True)[: self.top_n]
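# Note (an alternative, not the Space's code): CrossEncoder.predict batches a
# list of pairs internally, so the thread pool above can be replaced with a
# single call that produces the same scores:
#
#     scores = reranker.model.predict(query_and_nodes)
#     for node, score in zip(nodes, scores):
#         node.score = score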
@st.cache_resource(show_spinner=False)
def load_data():
    embed_model, reranker = load_models()
    with open("nodes_clean.pkl", "rb") as file:
        nodes = pickle.load(file)

    # FAISS flat L2 index; BAAI/llm-embedder produces 768-dimensional embeddings.
    d = 768
    faiss_index = faiss.IndexFlatL2(d)
    vector_store = FaissVectorStore(faiss_index=faiss_index)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    index = VectorStoreIndex(nodes, embed_model=embed_model, storage_context=storage_context)

    retriever_dense = index.as_retriever(similarity_top_k=35)
    retriever_bm25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=10)
    return HybridRetriever(retriever_dense, retriever_bm25, reranker)
@st.cache_resource(show_spinner=False)
def load_models():
    EMBEDDING_MODEL = "BAAI/llm-embedder"
    RANK_MODEL_NAME = "BAAI/bge-reranker-base"
    embed_model = HuggingFaceEmbedding(model_name=EMBEDDING_MODEL, device="cpu")
    reranker = SentenceTransformerRerank(top_n=25, model=RANK_MODEL_NAME, device="cpu")
    return embed_model, reranker
class HybridRetriever(BaseRetriever):
    """Run BM25 and dense retrieval in parallel, rerank the head of the merged list."""

    def __init__(self, vector_retriever, bm25_retriever, reranker):
        self.vector_retriever = vector_retriever
        self.bm25_retriever = bm25_retriever
        self.reranker = reranker
        super().__init__()

    def _retrieve(self, query, **kwargs):
        # Run both retrievers concurrently.
        with ThreadPoolExecutor() as executor:
            bm25_future = executor.submit(self.bm25_retriever.retrieve, query, **kwargs)
            vector_future = executor.submit(self.vector_retriever.retrieve, query, **kwargs)
            bm25_nodes = bm25_future.result()
            vector_nodes = vector_future.result()

        # Only the head of each result list (top-2 BM25 + top-20 dense) is
        # reranked; the remaining tail is appended unchanged afterwards.
        dense_n = 20
        bm25_n = 2
        tail_nodes = vector_nodes[dense_n:] + bm25_nodes[bm25_n:]

        # Deduplicate the head by node_id before reranking.
        all_nodes = []
        node_ids = set()
        for n in bm25_nodes[:bm25_n] + vector_nodes[:dense_n]:
            if n.node.node_id not in node_ids:
                all_nodes.append(n)
                node_ids.add(n.node.node_id)

        reranked_nodes = self.reranker.predict(all_nodes, query)
        return reranked_nodes + tail_nodes
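# With the defaults above (similarity_top_k=35 dense, 10 BM25), the reranker
# sees at most 22 unique head nodes (2 + 20), all of which fit within
# top_n=25; the unreranked tail adds up to 15 dense + 8 BM25 nodes, so a
# caller receives roughly 45 nodes, with duplicates in the tail not removed.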
def clean_whitespace(text, k=5):
    """Drop short lines among the first k (headers/boilerplate), collapse whitespace, lowercase."""
    text = text.strip()
    lines = text.split("\n")
    # Among the first k lines, keep only those longer than 25 chars once stripped.
    text = " ".join([line for line in lines[:k] if len(line.strip()) > 25] + lines[k:])
    text = re.sub(r"\.EU", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.lower()
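# Hedged usage sketch (not part of the original Space code): how these pieces
# are typically wired together inside the Streamlit app; the query text and
# the number of passages shown are illustrative assumptions.
#
#     retriever = load_data()                         # cached across reruns
#     results = retriever.retrieve("what does the regulation say about X?")
#     for node in results[:5]:
#         st.write(node.score, clean_whitespace(node.text))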