# Loader/splitter imports are only needed when (re)building the chunked data;
# see the ingestion sketch at the end of the file.
from langchain_community.document_loaders import PyMuPDFLoader, PyPDFDirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.prompts import PromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import ChatOpenAI
import streamlit as st
import os
import pickle
import openai
import nest_asyncio

# Apply nest_asyncio so async LangChain clients work inside Streamlit's event loop
nest_asyncio.apply()

# LLM model credentials
# Mistral
# key = os.getenv('MISTRAL_API_KEY')
# os.environ["MISTRAL_API_KEY"] = key

# OpenAI
key = os.getenv('OPENAI_API_KEY')
openai.api_key = key
os.environ["OPENAI_API_KEY"] = key

# Load the pre-chunked documents from disk
with open('chunked_data/chunked_docs.pkl', 'rb') as file:
    docs = pickle.load(file)

# Create the open-source embedding function
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")

# Load the persisted Chroma vector store from disk
index = Chroma(persist_directory="./chroma_db", embedding_function=embedding_function)

### Retrievers
# Initialize the BM25 (keyword) retriever over the chunked documents
bm25_retriever = BM25Retriever.from_documents(docs)
bm25_retriever.k = 4

# Vector retriever backed by the Chroma index
vectordb = index.as_retriever(search_kwargs={"k": 4})

# Initialize the ensemble retriever, weighting keyword and vector results equally
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, vectordb], weights=[0.5, 0.5]
)

### Response generation
prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just politely refuse; don't try to make up an answer.

{context}

Question: {question}
Answer:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Streamlit UI
st.title("Emerging Social Studies QnA")

model = "gpt-3.5-turbo-1106"
llm = ChatOpenAI(model=model)
# llm = ChatMistralAI(api_key=key)

if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Enter your query:"):
    # Add the user message to the chat history
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve context for the query (hybrid keyword + vector search)
    st.success("Fetching info...")
    context_docs = ensemble_retriever.invoke(prompt)
    context = "\n\n".join(doc.page_content for doc in context_docs)

    # Generate the response from the retrieved context
    resp = llm.invoke(PROMPT.format(context=context, question=prompt)).content

    st.session_state.messages.append({"role": "assistant", "content": resp})
    with st.chat_message("assistant"):
        st.markdown(resp)
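
# ---------------------------------------------------------------------------
# Hedged sketch: the app above assumes ./chroma_db and
# chunked_data/chunked_docs.pkl already exist on disk. This function shows
# one plausible way to build them; the "pdfs" source directory and the
# chunk_size/chunk_overlap values are assumptions, not taken from the
# original code. It is defined but deliberately never called here; run it
# once offline before launching the Streamlit app.
# ---------------------------------------------------------------------------
def build_index_sketch(pdf_dir: str = "pdfs") -> None:
    """Assumed ingestion step producing chunked_docs.pkl and ./chroma_db."""
    # Load every PDF in the (hypothetical) source directory
    raw_docs = PyPDFDirectoryLoader(pdf_dir).load()

    # Split into overlapping chunks; sizes here are illustrative assumptions
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(raw_docs)

    # Persist the chunks for the BM25 retriever used by the app
    os.makedirs("chunked_data", exist_ok=True)
    with open("chunked_data/chunked_docs.pkl", "wb") as f:
        pickle.dump(chunks, f)

    # Build and persist the Chroma index with the same embedding model as the app
    Chroma.from_documents(
        chunks,
        embedding=SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2"),
        persist_directory="./chroma_db",
    )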