from langchain_community.document_loaders import PyMuPDFLoader, PyPDFDirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain.retrievers import EnsembleRetriever
from langchain_community.retrievers import BM25Retriever
from langchain.prompts import PromptTemplate
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_openai import ChatOpenAI
import streamlit as st
import os
import pickle
import openai
import nest_asyncio
# Allow nested asyncio event loops (some LangChain clients start their own loop
# inside Streamlit's already-running loop)
nest_asyncio.apply()
# LLM model credentials
# Mistral (uncomment to use, together with the ChatMistralAI line below)
# key = os.getenv('MISTRAL_API_KEY')
# os.environ["MISTRAL_API_KEY"] = key
# OpenAI (note: the os.environ assignment raises TypeError if OPENAI_API_KEY is unset)
key = os.getenv('OPENAI_API_KEY')
openai.api_key = key
os.environ["OPENAI_API_KEY"] = key
# Load the pre-chunked documents (the BM25 retriever needs the raw chunks)
with open('chunked_data/chunked_docs.pkl', 'rb') as file:
    docs = pickle.load(file)
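
# The otherwise-unused loader/splitter imports above suggest the chunks were built
# in a separate ingestion step, roughly like this (a sketch; the source directory
# and chunking parameters are assumptions, not taken from this file):
#   raw_docs = PyPDFDirectoryLoader("pdfs/").load()
#   splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#   docs = splitter.split_documents(raw_docs)
#   with open("chunked_data/chunked_docs.pkl", "wb") as f:
#       pickle.dump(docs, f)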
# Create the open-source embedding function and load the persisted Chroma index from disk
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
index = Chroma(persist_directory="./chroma_db", embedding_function=embedding_function)
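# Note: the embedding model here must match the one used when the Chroma index was
# built; otherwise query vectors and stored vectors live in incompatible spaces.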
### Retrievers
# Initialize the BM25 (keyword) retriever over the raw chunks
bm25_retriever = BM25Retriever.from_documents(docs)
bm25_retriever.k = 4
# Vector retriever (dense similarity search over the Chroma index, also top-4)
vectordb = index.as_retriever(search_kwargs={"k": 4})
# initialize the ensemble retriever
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, vectordb], weights=[0.5, 0.5])
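# EnsembleRetriever merges the two ranked lists with Reciprocal Rank Fusion; the
# equal weights give keyword and dense retrieval the same influence on the fused ranking.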
### Response Generation
prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just politely refuse, don't try to make up an answer.
{context}
Question: {question}
Answer:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)
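# PROMPT.format(context=..., question=...) fills the template and returns a plain
# string, which is what gets sent to the chat model below.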
# Streamlit UI
st.title("Emerging Social Studies QnA")
model = "gpt-3.5-turbo-1106"
llm = ChatOpenAI(model=model)
# llm = ChatMistralAI(api_key=key)
if "messages" not in st.session_state:
    st.session_state.messages = []
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
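# Streamlit reruns this whole script on every interaction, so the chat history must
# be replayed from st.session_state rather than kept in local variables.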
if prompt := st.chat_input("Enter your query:"):
    # Add the user message to the chat history and render it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Retrieve context with the ensemble retriever (invoke supersedes the
    # deprecated get_relevant_documents)
    st.success("Fetching info...")
    context_docs = ensemble_retriever.invoke(prompt)
    # Join page contents so the prompt gets clean text, not raw Document reprs
    context = "\n\n".join(doc.page_content for doc in context_docs)

    # Generate and render the response (invoke supersedes the deprecated predict;
    # ChatOpenAI returns an AIMessage, whose .content is the answer string)
    resp = llm.invoke(PROMPT.format(context=context, question=prompt)).content
    st.session_state.messages.append({"role": "assistant", "content": resp})
    with st.chat_message("assistant"):
        st.markdown(resp)
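
# Launch the app with: streamlit run app.py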