Aabbhishekk's picture
Update app.py
2f49d42
raw
history blame
4.97 kB
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import RetrievalQA
from langchain.chains.conversation.memory import ConversationBufferMemory
from utils.ask_human import CustomAskHumanTool
from utils.model_params import get_model_params
from utils.prompts import create_agent_prompt, create_qa_prompt
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain import HuggingFaceHub
import torch
import streamlit as st
from langchain.utilities import SerpAPIWrapper
import os
hf_token = os.environ['HF_TOKEN']
serp_token = os.environ['SERP_TOKEN']
repo_id = "sentence-transformers/all-mpnet-base-v2"
HUGGINGFACEHUB_API_TOKEN= hf_token
hf = HuggingFaceHubEmbeddings(
repo_id=repo_id,
task="feature-extraction",
huggingfacehub_api_token= HUGGINGFACEHUB_API_TOKEN,
)
llm = HuggingFaceHub(
repo_id='mistralai/Mistral-7B-Instruct-v0.2',
huggingfacehub_api_token = HUGGINGFACEHUB_API_TOKEN,
)
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
### PAGE ELEMENTS
# st.set_page_config(
# page_title="RAG Agent Demo",
# page_icon="🦜",
# layout="centered",
# initial_sidebar_state="collapsed",
# )
# st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
def main():
st.set_page_config(page_title="Ask your PDF powered by Search Agents")
st.header("Ask your PDF powered by Search Agents 💬")
# upload file
pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
# extract the text
if pdf is not None:
pdf_reader = PdfReader(pdf)
text = ""
for page in pdf_reader.pages:
text += page.extract_text()
# Split documents and create text snippets
text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
texts = text_splitter.split_text(text)
embeddings = hf
knowledge_base = FAISS.from_texts(texts, embeddings)
retriever = knowledge_base.as_retriever(search_kwargs={"k":3})
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
return_source_documents=False,
chain_type_kwargs={
"prompt": create_qa_prompt(),
},
)
conversational_memory = ConversationBufferMemory(
memory_key="chat_history", k=3, return_messages=True
)
# tool for db search
db_search_tool = Tool(
name="dbRetrievalTool",
func=qa_chain,
description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
)
search = SerpAPIWrapper(serpapi_api_key=serp_token)
google_searchtool= Tool(
name="Current Search",
func=search.run,
description="use this tool to answer real time or current search related questions.",
)
# tool for asking human
human_ask_tool = CustomAskHumanTool()
# agent prompt
prefix, format_instructions, suffix = create_agent_prompt()
mode = "Agent with AskHuman tool"
# initialize agent
agent = initialize_agent(
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
tools=[db_search_tool,google_searchtool],
llm=llm,
verbose=True,
max_iterations=5,
early_stopping_method="generate",
memory=conversational_memory,
agent_kwargs={
"prefix": prefix,
"format_instructions": format_instructions,
"suffix": suffix,
},
handle_parsing_errors=True,
)
# question form
with st.form(key="form"):
user_input = st.text_input("Ask your question")
submit_clicked = st.form_submit_button("Submit Question")
# output container
output_container = st.empty()
if submit_clicked:
output_container = output_container.container()
output_container.chat_message("user").write(user_input)
answer_container = output_container.chat_message("assistant", avatar="🦜")
st_callback = StreamlitCallbackHandler(answer_container)
answer = agent.run(user_input, callbacks=[st_callback])
answer_container = output_container.container()
answer_container.chat_message("assistant").write(answer)
if __name__ == '__main__':
main()