# reverse-RAG/ask_app.py
"""
complete, functional RAG App
stores vectors in session state, or locally.
add function to display retrieved documents
"""
# import time
# import openai
# import tiktoken
import os
from datetime import datetime

import faiss  # raw FAISS library, used by merge_faiss_indices
import numpy as np
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.llms import HuggingFaceHub
from langchain_community.vectorstores import FAISS  # LangChain wrapper around FAISS

import faiss_utils
from html_templates import css, bot_template, user_template

def merge_faiss_indices(index1, index2):
    """
    Merge two FAISS indices into a new index, assuming both are of the same
    type and dimensionality.

    Args:
        index1 (faiss.Index): The first FAISS index.
        index2 (faiss.Index): The second FAISS index.

    Returns:
        faiss.Index: A new FAISS index containing all vectors from index1 and index2.
    """
    # Both indices must be of the same type and the same dimensionality.
    if type(index1) is not type(index2):
        raise ValueError("Indices are of different types")
    if index1.d != index2.d:
        raise ValueError("Indices have different dimensionality")
    d = index1.d
    if isinstance(index1, faiss.IndexFlatL2):
        # Flat indices: reconstruct the stored vectors and add them to a fresh index.
        xb1 = index1.reconstruct_n(0, index1.ntotal)
        xb2 = index2.reconstruct_n(0, index2.ntotal)
        new_index = faiss.IndexFlatL2(d)
        new_index.add(np.vstack((xb1, xb2)))
        return new_index
    elif isinstance(index1, faiss.IndexIVFFlat):
        # IVF indices: rebuild with the same configuration, train the new index
        # on the combined vectors (a fresh IVF index must be trained before
        # vectors can be added), then add everything.
        quantizer = faiss.IndexFlatL2(d)  # re-create the appropriate quantizer
        new_index = faiss.IndexIVFFlat(quantizer, d, index1.nlist, faiss.METRIC_L2)
        # Reconstructing vectors from an IVF index requires a direct map.
        index1.make_direct_map()
        index2.make_direct_map()
        vecs = np.vstack((index1.reconstruct_n(0, index1.ntotal),
                          index2.reconstruct_n(0, index2.ntotal)))
        new_index.train(vecs)
        new_index.add(vecs)
        return new_index
    else:
        raise TypeError("Index type not supported for merging in this function")
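
# Minimal usage sketch for merge_faiss_indices (illustrative only; dimension 4
# and the random vectors are assumptions, not values used by this app):
#
#   a = faiss.IndexFlatL2(4)
#   a.add(np.random.rand(10, 4).astype("float32"))
#   b = faiss.IndexFlatL2(4)
#   b.add(np.random.rand(5, 4).astype("float32"))
#   merged = merge_faiss_indices(a, b)
#   assert merged.ntotal == 15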

def get_pdf_text(pdf_docs):
    """Concatenate the text of every page of every uploaded PDF."""
    text = ""
    for pdf in pdf_docs:
        pdf_reader = PdfReader(pdf)
        for page in pdf_reader.pages:
            # extract_text() may return None/"" for pages without a text layer
            text += page.extract_text() or ""
    return text

def get_text_chunks(text):
    """Split raw text into overlapping chunks for embedding."""
    text_splitter = CharacterTextSplitter(
        separator="\n",
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    return text_splitter.split_text(text)
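
# Illustrative example (assumed numbers): a ~2,400-character text yields three
# chunks covering roughly [0:1000], [800:1800] and [1600:2400]; the ~200-char
# overlap keeps sentences that straddle a chunk boundary retrievable from at
# least one chunk.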

def get_faiss_vectorstore(text_chunks):
    """Embed the chunks (OpenAI or Instructor, per the session toggle) into a FAISS store."""
    if sst.openai:
        my_embeddings = OpenAIEmbeddings()
    else:
        my_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl")
    return FAISS.from_texts(texts=text_chunks, embedding=my_embeddings)
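
# FAISS.from_texts embeds every chunk and pairs the resulting index with a
# docstore mapping index ids back to the chunk texts, so similarity search
# returns documents rather than bare vectors.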

def get_conversation_chain(vectorstore):
    """Build a conversational retrieval chain over the given vectorstore."""
    if sst.openai:
        llm = ChatOpenAI()
    else:
        llm = HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs={"temperature": 0.5, "max_length": 512})
    memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        memory=memory
    )
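
# On each turn, ConversationalRetrievalChain condenses the new question plus
# the buffered chat history into a standalone query, retrieves the most
# similar chunks from the vectorstore, and has the LLM answer from them.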

def handle_userinput(user_question):
    response = sst.conversation({'question': user_question})
    sst.chat_history = response['chat_history']
    # ConversationBufferMemory stores alternating user/AI turns, so even
    # indices are user messages and odd indices are bot replies.
    for i, message in enumerate(sst.chat_history):
        if i % 2 == 0:
            # Display user message
            st.write(user_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
        else:
            # Display AI response
            st.write(bot_template.replace("{{MSG}}", message.content), unsafe_allow_html=True)
            # Display source document information if available in the message
            if hasattr(message, 'source') and message.source:
                st.write(f"Source Document: {message.source}", unsafe_allow_html=True)

# API configuration, read from environment variables
BASE_URL = "https://api.vectara.io/v1"
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
OPENAI_ORG_ID = os.environ["OPENAI_ORG_ID"]
PINECONE_API_KEY = os.environ["PINECONE_API_KEY_LCBIM"]
HUGGINGFACEHUB_API_TOKEN = os.environ["HUGGINGFACEHUB_API_TOKEN"]
VECTARA_API_KEY = os.environ["VECTARA_API_KEY"]
VECTARA_CUSTOMER_ID = os.environ["VECTARA_CUSTOMER_ID"]
headers = {"Authorization": f"Bearer {VECTARA_API_KEY}", "Content-Type": "application/json"}

def main():
st.set_page_config(page_title="Anna Seiler Haus KI-Assistent", page_icon=":hospital:")
st.write(css, unsafe_allow_html=True)
if "conversation" not in sst:
sst.conversation = None
if "chat_history" not in sst:
sst.chat_history = None
if "page" not in sst:
sst.page = "home"
if "openai" not in sst:
sst.openai = True
if "login" not in sst:
sst.login = False
if 'submitted_user_query' not in sst:
sst.submitted_user_query = ''
if 'submitted_user_safe' not in sst:
sst.submitted_user_safe = ''
if 'submitted_user_load' not in sst:
sst.submitted_user_load = ''
def submit_user_query():
sst.submitted_user_query = sst.widget_user_query
sst.widget_user_query = ''
    def submit_user_safe():
        sst.submitted_user_safe = sst.widget_user_safe
        sst.widget_user_safe = ''
        if "vectorstore" in sst:
            # faiss_name = str(datetime.now().strftime("%Y%m%d%H%M%S")) + "faiss_index"
            faiss_utils.save_local(sst.vectorstore, path=sst.submitted_user_safe)
            st.sidebar.success("Saved")
        else:
            st.sidebar.warning("No embeddings to save. Please process documents first.")
    def submit_user_load():
        sst.submitted_user_load = sst.widget_user_load
        sst.widget_user_load = ''
        new_db = None
        if os.path.exists(sst.submitted_user_load):
            new_db = faiss_utils.load_vectorstore(f"{sst.submitted_user_load}/faiss_index.index")
        if new_db is not None:
            if "vectorstore" in sst:
                # Merge the loaded index into the one already in the session
                sst.vectorstore.merge_from(new_db)
            else:
                sst.vectorstore = new_db
            sst.conversation = get_conversation_chain(sst.vectorstore)
            st.sidebar.success("FAISS index loaded")
        else:
            st.sidebar.warning("Couldn't load/find embeddings")
st.header("Anna Seiler Haus KI-Assistent ASH :hospital:")
if st.text_input("ASK_ASH_PASSWORD: ", type="password") == ASK_ASH_PASSWORD:
#user_question = st.text_input("Ask a question about your documents:", key="user_query", on_change=handle_query)
st.text_input('Ask a question about your documents:', key='widget_user_query', on_change=submit_user_query)
#sst.openai = st.toggle(label="use openai?")
if sst.submitted_user_query:
if "vectorstore" in sst:
handle_userinput(sst.submitted_user_query)
else:
st.warning("no vectorstore loaded.")
with st.sidebar:
st.subheader("Your documents")
pdf_docs = st.file_uploader("Upload your PDFs here and click on 'Process'", accept_multiple_files=True)
if st.button("Process"):
with st.spinner("Processing"):
vec = get_faiss_vectorstore(get_text_chunks(get_pdf_text(pdf_docs)))
sst.vectorstore = vec
sst.conversation = get_conversation_chain(vec)
st.success("embedding complete")
st.text_input('Safe Embeddings to: (copy path of folder)', key='widget_user_safe',
on_change=submit_user_safe)
st.text_input('Load Embeddings from: (copy path of folder)', key='widget_user_load',
on_change=submit_user_load)
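
# Assumed invocation (standard for Streamlit apps; requires the env vars
# listed at the top):
#   streamlit run ask_app.py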
if __name__ == '__main__':
sst = st.session_state
ASK_ASH_PASSWORD = os.environ["ASK_ASH_PASSWORD"]
main()