Spaces:
Runtime error
Runtime error
File size: 5,039 Bytes
b7e13eb 621bc57 b7e13eb 621bc57 b7e13eb 0d49d2e b7e13eb 621bc57 b7e13eb 0ed257c b7e13eb f823183 b7e13eb f823183 b7e13eb 3454a86 b7e13eb 3454a86 b7e13eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks import StreamlitCallbackHandler
from langchain.chains import RetrievalQA
from langchain.chains.conversation.memory import ConversationBufferMemory
from utils.ask_human import CustomAskHumanTool
from utils.model_params import get_model_params
from utils.prompts import create_agent_prompt, create_qa_prompt
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import HuggingFaceHubEmbeddings
from langchain import HuggingFaceHub
import torch
import streamlit as st
from langchain.utilities import SerpAPIWrapper
import os
# --- Credentials -----------------------------------------------------------
# Both tokens are mandatory: a missing environment variable raises KeyError
# as soon as this module is imported.
hf_token = os.environ['HF_TOKEN']
serp_token = os.environ['SERP_TOKEN']
HUGGINGFACEHUB_API_TOKEN = hf_token

# --- Embedding model -------------------------------------------------------
# Sentence-transformers model, accessed through the Hugging Face Hub wrapper
# with the "feature-extraction" task; used to embed the PDF snippets.
repo_id = "sentence-transformers/all-mpnet-base-v2"
hf = HuggingFaceHubEmbeddings(
    repo_id=repo_id,
    task="feature-extraction",
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)

# --- Language model --------------------------------------------------------
# Shared LLM: main() hands it to both the retrieval QA chain and the agent.
llm = HuggingFaceHub(
    repo_id='mistralai/Mistral-7B-Instruct-v0.2',
    huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN,
)
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain import PromptTemplate
### PAGE ELEMENTS
# st.set_page_config(
# page_title="RAG Agent Demo",
# page_icon="🦜",
# layout="centered",
# initial_sidebar_state="collapsed",
# )
# st.markdown("### Leveraging the User to Improve Agents in RAG Use Cases")
def _extract_pdf_text(pdf):
    """Return the concatenated text of every page of the uploaded PDF.

    A page whose ``extract_text()`` gives back nothing (``None`` or an empty
    string) contributes an empty string, so one unreadable page cannot break
    the whole extraction.
    """
    pdf_reader = PdfReader(pdf)
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)


def _build_agent(texts):
    """Index ``texts`` in FAISS and return an agent that can query them.

    The agent is given three tools, in this order: the project's ask-human
    tool, a retrieval-QA tool over the indexed snippets, and a SerpAPI web
    search tool.

    Args:
        texts: Non-empty list of text snippets taken from the PDF.

    Returns:
        The agent executor produced by ``initialize_agent``.
    """
    knowledge_base = FAISS.from_texts(texts, hf)
    retriever = knowledge_base.as_retriever(search_kwargs={"k": 3})
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False,
        chain_type_kwargs={
            "prompt": create_qa_prompt(),
        },
    )
    # NOTE(review): `k` looks like a ConversationBufferWindowMemory option;
    # confirm it has any effect on ConversationBufferMemory.
    conversational_memory = ConversationBufferMemory(
        memory_key="chat_history", k=3, return_messages=True
    )

    # Tool for document search. `qa_chain.run` returns the answer as a plain
    # string; calling the chain object directly would hand the agent the
    # whole result dict as its observation.
    db_search_tool = Tool(
        name="dbRetrievalTool",
        func=qa_chain.run,
        description="""Use this tool to answer document related questions. The input to this tool should be the question.""",
    )

    # Tool for live web search.
    search = SerpAPIWrapper(serpapi_api_key=serp_token)
    google_searchtool = Tool(
        name="Current Search",
        func=search.run,
        description="use this tool to answer real time or current search related questions.",
    )

    # Tool for asking the human a clarifying question.
    human_ask_tool = CustomAskHumanTool()

    # Agent prompt pieces supplied by the project's prompt module.
    prefix, format_instructions, suffix = create_agent_prompt()

    # The mode is currently fixed; the ask-human tool is only offered in
    # this mode and, when present, is listed first.
    mode = "Agent with AskHuman tool"
    tools = [db_search_tool, google_searchtool]
    if mode == "Agent with AskHuman tool":
        tools.insert(0, human_ask_tool)

    return initialize_agent(
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        tools=tools,
        llm=llm,
        verbose=True,
        max_iterations=5,
        early_stopping_method="generate",
        memory=conversational_memory,
        agent_kwargs={
            "prefix": prefix,
            "format_instructions": format_instructions,
            "suffix": suffix,
        },
        handle_parsing_errors=True,
    )


def main():
    """Render the Streamlit page: upload a PDF, then question an agent about it.

    Nothing below the uploader is rendered until a PDF has been provided.
    If no text can be extracted from the PDF, a warning is shown instead of
    attempting to build an empty vector index.
    """
    st.set_page_config(page_title="Ask your PDF powered by Search Agents")
    st.header("Ask your PDF powered by Search Agents 💬")

    # Upload file.
    pdf = st.file_uploader("Upload your PDF and chat with Agent", type="pdf")
    if pdf is None:
        return

    # Extract the text, then split it into snippets.
    text = _extract_pdf_text(pdf)
    text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
    texts = text_splitter.split_text(text)
    if not texts:
        # E.g. a scanned, image-only PDF: there is nothing to index.
        st.warning("No text could be extracted from this PDF.")
        return

    agent = _build_agent(texts)

    # Question form.
    with st.form(key="form"):
        user_input = st.text_input("Ask your question")
        submit_clicked = st.form_submit_button("Submit Question")

    # Output container.
    output_container = st.empty()
    if submit_clicked:
        output_container = output_container.container()
        output_container.chat_message("user").write(user_input)

        # The agent's intermediate steps are streamed into this message via
        # the Streamlit callback handler.
        answer_container = output_container.chat_message("assistant", avatar="🦜")
        st_callback = StreamlitCallbackHandler(answer_container)
        answer = agent.run(user_input, callbacks=[st_callback])

        # The final answer goes in its own assistant message.
        answer_container = output_container.container()
        answer_container.chat_message("assistant").write(answer)
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == '__main__':
    main()