# chatbot/app.py: Streamlit "Chat with your data" RAG app.
# Source: commit 7f07a51 ("add rag") by hail75.
import os
import streamlit as st
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers import OpenAIWhisperParser
from langchain_community.document_loaders.blob_loaders.youtube_audio import (
YoutubeAudioLoader,
)
from langchain_community.vectorstores import Chroma
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
# OpenAI credentials are read from the environment; the chat model below is
# constructed with this key.
openai_api_key = os.getenv("OPENAI_API_KEY")

# page_icon is the robot emoji (U+1F916); the escape avoids the source-file
# encoding problem that previously turned it into mojibake.
st.set_page_config(page_title="Chat with your data", page_icon="\U0001F916")
st.title("Chat with your data")

# Without a key every embedding/LLM call fails, so stop rendering here with a
# readable message instead of crashing later with a traceback.
if not openai_api_key:
    st.error("OPENAI_API_KEY is not set. Add it to the environment and restart the app.")
    st.stop()

st.header("Add your data for RAG")
data_type = st.radio("Choose the type of data to add:", ("Text", "PDF", "YouTube URL"))

# Vector store built by the most recent "Add"; None until the user adds data.
if "vectordb" not in st.session_state:
    st.session_state.vectordb = None
def add_text_to_chroma(text, *, chunk_size=1000, chunk_overlap=150):
    """Split raw text into overlapping chunks and index them in a new Chroma store.

    Args:
        text: Text entered by the user.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A Chroma vector store over the chunks, embedded with OpenAIEmbeddings.

    Raises:
        ValueError: If ``text`` is None, empty, or whitespace only (it would
            produce no chunks, which Chroma cannot index).
    """
    if not text or not text.strip():
        raise ValueError("No text provided to add.")
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = splitter.split_text(text)
    return Chroma.from_texts(texts=chunks, embedding=OpenAIEmbeddings())
def add_pdf_to_chroma(uploaded_pdf, *, chunk_size=1000, chunk_overlap=150):
    """Load a PDF, split its pages into chunks and index them in a new Chroma store.

    Args:
        uploaded_pdf: Either a filesystem path (str / os.PathLike) or a
            file-like object such as the upload returned by
            ``st.file_uploader``.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A Chroma vector store over the PDF chunks, embedded with OpenAIEmbeddings.

    Raises:
        ValueError: If ``uploaded_pdf`` is None or the PDF yields no text.
    """
    import tempfile

    if uploaded_pdf is None:
        raise ValueError("No PDF file provided to add.")

    if isinstance(uploaded_pdf, (str, os.PathLike)):
        pages = PyPDFLoader(os.fspath(uploaded_pdf)).load()
    else:
        # PyPDFLoader needs a path (or URL), not an in-memory upload, so write
        # the bytes to a temporary file first. delete=False plus closing the
        # handle before loading lets the loader reopen the file on every OS.
        if hasattr(uploaded_pdf, "getvalue"):
            data = uploaded_pdf.getvalue()
        else:
            data = uploaded_pdf.read()
        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
            tmp.write(data)
            tmp_path = tmp.name
        try:
            pages = PyPDFLoader(tmp_path).load()
        finally:
            os.remove(tmp_path)

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = splitter.split_documents(pages)
    if not docs:
        # e.g. a scanned PDF without a text layer; Chroma cannot index nothing.
        raise ValueError("No extractable text found in the PDF.")
    return Chroma.from_documents(documents=docs, embedding=OpenAIEmbeddings())
def add_youtube_to_chroma(
    youtube_url,
    *,
    save_dir="docs/youtube",
    persist_directory="chroma",
    chunk_size=1000,
    chunk_overlap=150,
):
    """Transcribe a YouTube video's audio and index the transcript in Chroma.

    Audio is fetched by YoutubeAudioLoader and parsed by OpenAIWhisperParser.
    The resulting documents are split into chunks and embedded with
    OpenAIEmbeddings.

    Args:
        youtube_url: URL of the video to index.
        save_dir: Directory handed to YoutubeAudioLoader for the audio files.
        persist_directory: Directory the Chroma store is persisted to (None
            keeps it in memory).
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        A Chroma vector store over the transcript chunks.

    Raises:
        ValueError: If the URL is empty or no transcript text was produced.
    """
    if not youtube_url or not youtube_url.strip():
        raise ValueError("No YouTube URL provided to add.")
    loader = GenericLoader(
        YoutubeAudioLoader([youtube_url.strip()], save_dir), OpenAIWhisperParser()
    )
    transcripts = loader.load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = splitter.split_documents(transcripts)
    if not docs:
        raise ValueError("No transcript text could be extracted from the video.")
    # NOTE(review): every call writes into the same persist_directory under
    # Chroma's default collection name, so chunks from earlier videos (from
    # any session) may also be returned by this store. Confirm that sharing
    # is intended, or pass a unique collection_name.
    return Chroma.from_documents(
        documents=docs,
        embedding=OpenAIEmbeddings(),
        persist_directory=persist_directory,
    )
# Build a vector store from whichever source the user picked. Inputs are
# validated first so an empty submission shows a warning instead of being
# handed to the loaders.
if data_type == "Text":
    user_text = st.text_area("Enter text data")
    if st.button("Add"):
        if user_text and user_text.strip():
            with st.spinner("Embedding text..."):
                st.session_state.vectordb = add_text_to_chroma(user_text)
            st.success("Text added.")
        else:
            st.warning("Please enter some text before clicking Add.")
elif data_type == "PDF":
    uploaded_pdf = st.file_uploader("Upload PDF", type="pdf")
    if st.button("Add"):
        if uploaded_pdf is not None:
            with st.spinner("Embedding PDF..."):
                st.session_state.vectordb = add_pdf_to_chroma(uploaded_pdf)
            st.success("PDF added.")
        else:
            st.warning("Please upload a PDF before clicking Add.")
else:
    youtube_url = st.text_input("Enter YouTube URL")
    if st.button("Add"):
        if youtube_url and youtube_url.strip():
            with st.spinner("Transcribing and embedding video..."):
                st.session_state.vectordb = add_youtube_to_chroma(youtube_url)
            st.success("YouTube video added.")
        else:
            st.warning("Please enter a YouTube URL before clicking Add.")
# Single chat model shared by the query-rewriting retriever chain and the
# answer-generating chain.
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.2, api_key=openai_api_key)
def get_context_retreiver_chain(vectordb):
    """Build a retriever that turns the conversation into a search query.

    The shared ``llm`` sees the prior ``chat_history`` and the new ``input``
    and is asked to generate a search query. That query is run against a
    retriever created from ``vectordb``.

    Args:
        vectordb: Vector store exposing ``as_retriever()`` (the Chroma store
            built by the add_*_to_chroma helpers).

    Returns:
        The runnable produced by ``create_history_aware_retriever``.
    """
    query_prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
            (
                "user",
                "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation",
            ),
        ]
    )
    return create_history_aware_retriever(llm, vectordb.as_retriever(), query_prompt)
def get_conversational_rag_chain(retriever_chain):
    """Combine a retriever with an LLM chain that answers from retrieved context.

    Retrieved documents fill the system prompt's ``{context}`` slot, followed
    by the chat history and the new user input.

    Args:
        retriever_chain: Retriever runnable, e.g. from
            ``get_context_retreiver_chain``.

    Returns:
        The runnable produced by ``create_retrieval_chain``.
    """
    answer_prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "Answer the user's questions based on the below context:\n\n{context}",
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )
    documents_chain = create_stuff_documents_chain(llm, answer_prompt)
    return create_retrieval_chain(retriever_chain, documents_chain)
def get_response(user_input):
    """Answer ``user_input`` with the conversational RAG chain.

    Args:
        user_input: The latest user message.

    Returns:
        The assistant's reply as a string. If no data has been added yet
        (``st.session_state.vectordb`` is None), a prompt asking the user to
        add data first.
    """
    vectordb = st.session_state.vectordb
    if vectordb is None:
        return "Please add data first"
    retriever_chain = get_context_retreiver_chain(vectordb)
    rag_chain = get_conversational_rag_chain(retriever_chain)
    result = rag_chain.invoke(
        {
            "chat_history": st.session_state.chat_history,
            "input": user_input,
        }
    )
    # The retrieval chain returns a dict (input, chat_history, context,
    # answer). Callers render the reply and wrap it in an AIMessage, so hand
    # back only the generated text.
    return result["answer"]
user_query = st.chat_input("Your message")

if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# Redraw the stored conversation before handling the new message.
for message in st.session_state.chat_history:
    role = "Human" if isinstance(message, HumanMessage) else "AI"
    with st.chat_message(role):
        st.markdown(message.content)

if user_query:
    with st.chat_message("Human"):
        st.markdown(user_query)
    with st.chat_message("AI"):
        ai_response = get_response(user_query)
        # A raw retrieval-chain result is a dict carrying the reply under
        # "answer". Unwrap it so both the display and AIMessage get text.
        if isinstance(ai_response, dict):
            ai_response = ai_response.get("answer", "")
        st.markdown(ai_response)
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    st.session_state.chat_history.append(AIMessage(content=ai_response))