# ChatDocumentQA — Gradio chatbot for question answering over documents, CSVs and URLs
# (Hugging Face Spaces app; "Spaces / Sleeping" page-status artifacts removed.)
from langchain.text_splitter import CharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain.chat_models import ChatOpenAI | |
from langchain_openai import AzureChatOpenAI,AzureOpenAIEmbeddings | |
from langchain.memory import ConversationBufferMemory | |
from langchain.chains import ConversationChain | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.document_loaders import UnstructuredFileLoader | |
from typing import List, Dict, Tuple | |
import gradio as gr | |
import validators | |
import requests | |
import mimetypes | |
import tempfile | |
import os | |
from langchain.chains.question_answering import load_qa_chain | |
from langchain.llms import OpenAI | |
from langchain.prompts import PromptTemplate | |
from langchain.prompts.prompt import PromptTemplate | |
import pandas as pd | |
from langchain_experimental.agents.agent_toolkits import create_csv_agent | |
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent | |
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor | |
from langchain.agents.agent_types import AgentType | |
# from langchain.agents import create_csv_agent | |
from langchain import OpenAI, LLMChain | |
from openai import AzureOpenAI | |
class ChatDocumentQA:
    """Question-answering chatbot over documents, CSV files and URLs.

    PDFs/office documents are chunked and indexed in a FAISS vector store;
    CSV files are handled by a pandas CSV agent. A Gradio UI wires user
    input to the knowledge base held in Gradio session state.
    """

    def __init__(self) -> None:
        # No per-instance state: the knowledge base lives in Gradio state.
        pass

    def _get_empty_state(self) -> Dict[str, None]:
        """Return a fresh state mapping with no knowledge base loaded."""
        return {"knowledge_base": None}
def _extract_text_from_pdfs(self, file_paths: List[str]) -> List[str]: | |
"""Extract text content from PDF files. | |
Args: | |
file_paths (List[str]): List of file paths. | |
Returns: | |
List[str]: Extracted text from the PDFs. | |
""" | |
docs = [] | |
loaders = [UnstructuredFileLoader(file_obj, strategy="fast") for file_obj in file_paths] | |
for loader in loaders: | |
docs.extend(loader.load()) | |
return docs | |
def _get_content_from_url(self, urls: str) -> List[str]: | |
"""Fetch content from given URLs. | |
Args: | |
urls (str): Comma-separated URLs. | |
Returns: | |
List[str]: List of text content fetched from the URLs. | |
""" | |
file_paths = [] | |
for url in urls.split(','): | |
if validators.url(url): | |
headers = {'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36',} | |
r = requests.get(url, headers=headers) | |
if r.status_code != 200: | |
raise ValueError("Check the url of your file; returned status code %s" % r.status_code) | |
content_type = r.headers.get("content-type") | |
file_extension = mimetypes.guess_extension(content_type) | |
temp_file = tempfile.NamedTemporaryFile(suffix=file_extension, delete=False) | |
temp_file.write(r.content) | |
file_paths.append(temp_file.name) | |
print("File_Paths:",file_paths) | |
docs = self._extract_text_from_pdfs(file_paths) | |
return docs | |
def _split_text_into_chunks(self, text: str) -> List[str]: | |
"""Split text into smaller chunks. | |
Args: | |
text (str): Input text to be split. | |
Returns: | |
List[str]: List of smaller text chunks. | |
""" | |
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=6000, chunk_overlap=0, length_function=len) | |
chunks = text_splitter.split_documents(text) | |
return chunks | |
def _create_vector_store_from_text_chunks(self, text_chunks: List[str]) -> FAISS: | |
"""Create a vector store from text chunks. | |
Args: | |
text_chunks (List[str]): List of text chunks. | |
Returns: | |
FAISS: Vector store created from the text chunks. | |
""" | |
embeddings = AzureOpenAIEmbeddings( | |
azure_deployment="text-embedding-3-large", | |
) | |
return FAISS.from_documents(documents=text_chunks, embedding=embeddings) | |
def _create_conversation_chain(self,vectorstore): | |
_template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language. | |
Chat History: {chat_history} | |
Follow Up Input: {question} | |
Standalone question:""" | |
CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template) | |
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) | |
# llm = ChatOpenAI(temperature=0) | |
llm=AzureChatOpenAI(azure_deployment = "GPT-4o") | |
return ConversationalRetrievalChain.from_llm(llm=llm, retriever=vectorstore.as_retriever(), | |
condense_question_prompt=CONDENSE_QUESTION_PROMPT, | |
memory=memory) | |
def _get_documents_knowledge_base(self, file_paths: List[str]) -> Tuple[str, Dict[str, FAISS]]: | |
"""Build knowledge base from uploaded files. | |
Args: | |
file_paths (List[str]): List of file paths. | |
Returns: | |
Tuple[str, Dict]: Tuple containing a status message and the knowledge base. | |
""" | |
file_path = file_paths[0].name | |
file_extension = os.path.splitext(file_path)[1] | |
if file_extension == '.csv': | |
# agent = self.create_agent(file_path) | |
# tools = self.get_agent_tools(agent) | |
# memory,tools,prompt = self.create_memory_for_csv_qa(tools) | |
# agent_chain = self.create_agent_chain_for_csv_qa(memory,tools,prompt) | |
agent_chain = create_csv_agent( | |
AzureChatOpenAI(azure_deployment = "GPT-4o"), | |
file_path, | |
verbose=True, | |
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, | |
) | |
return "file uploaded", {"knowledge_base": agent_chain} | |
else: | |
pdf_docs = [file_path.name for file_path in file_paths] | |
raw_text = self._extract_text_from_pdfs(pdf_docs) | |
text_chunks = self._split_text_into_chunks(raw_text) | |
vectorstore = self._create_vector_store_from_text_chunks(text_chunks) | |
return "file uploaded", {"knowledge_base": vectorstore} | |
def _get_urls_knowledge_base(self, urls: str) -> Tuple[str, Dict[str, FAISS]]: | |
"""Build knowledge base from URLs. | |
Args: | |
urls (str): Comma-separated URLs. | |
Returns: | |
Tuple[str, Dict]: Tuple containing a status message and the knowledge base. | |
""" | |
webpage_text = self._get_content_from_url(urls) | |
text_chunks = self._split_text_into_chunks(webpage_text) | |
vectorstore = self._create_vector_store_from_text_chunks(text_chunks) | |
return "file uploaded", {"knowledge_base": vectorstore} | |
#************************ | |
# csv qa | |
#************************ | |
def create_agent(self,file_path): | |
agent_chain = create_csv_agent( | |
AzureChatOpenAI(azure_deployment = "GPT-4o"), | |
file_path, | |
verbose=True, | |
agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, | |
) | |
return agent_chain | |
def get_agent_tools(self,agent): | |
# search = agent | |
tools = [ | |
Tool( | |
name="dataframe qa", | |
func=agent.run, | |
description="useful for when you need to answer questions about table data and dataframe data", | |
) | |
] | |
return tools | |
def create_memory_for_csv_qa(self,tools): | |
prefix = """Have a conversation with a human, answering the following questions about table data and dataframe data as best you can. You have access to the following tools:""" | |
suffix = """Begin!" | |
{chat_history} | |
Question: {input} | |
{agent_scratchpad}""" | |
prompt = ZeroShotAgent.create_prompt( | |
tools, | |
prefix=prefix, | |
suffix=suffix, | |
input_variables=["input", "chat_history", "agent_scratchpad"], | |
) | |
memory = ConversationBufferMemory(memory_key="chat_history",return_messages=True) | |
return memory,tools,prompt | |
def create_agent_chain_for_csv_qa(self,memory,tools,prompt): | |
llm_chain = LLMChain(llm=AzureChatOpenAI(azure_deployment = "GPT-4o"), prompt=prompt) | |
agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True) | |
agent_chain = AgentExecutor.from_agent_and_tools( | |
agent=agent, tools=tools, verbose=True, memory=memory | |
) | |
return agent_chain | |
def _get_response(self, message: str, chat_history: List[Tuple[str, str]], state: Dict[str, FAISS],file_paths) -> Tuple[str, List[Tuple[str, str]]]: | |
"""Get a response from the chatbot. | |
Args: | |
message (str): User's message/question. | |
chat_history (List[Tuple[str, str]]): List of chat history as tuples of (user_message, bot_response). | |
state (dict): State containing the knowledge base. | |
Returns: | |
Tuple[str, List[Tuple[str, str]]]: Tuple containing a status message and updated chat history. | |
""" | |
try: | |
if file_paths: | |
file_path = file_paths[0].name | |
file_extension = os.path.splitext(file_path)[1] | |
if file_extension == '.csv': | |
agent_chain = state["knowledge_base"] | |
response = agent_chain.run(input = message) | |
chat_history.append((message, response)) | |
return "", chat_history | |
else: | |
vectorstore = state["knowledge_base"] | |
chat = self._create_conversation_chain(vectorstore) | |
response = chat({"question": message,"chat_history": chat_history}) | |
chat_history.append((message, response["answer"])) | |
return "", chat_history | |
else: | |
vectorstore = state["knowledge_base"] | |
chat = self._create_conversation_chain(vectorstore) | |
response = chat({"question": message,"chat_history": chat_history}) | |
chat_history.append((message, response["answer"])) | |
return "", chat_history | |
except: | |
chat_history.append((message, "Please Upload Document or URL")) | |
return "", chat_history | |
    def gradio_interface(self) -> None:
        """Create and launch the Gradio chat UI.

        Wires three inputs into the shared session state:
          * uploaded files -> _get_documents_knowledge_base
          * submitted URLs -> _get_urls_knowledge_base
          * chat messages  -> _get_response
        """
        with gr.Blocks(css="#textbox_id textarea {color: white}",theme='SherlockRamos/Feliz') as demo:
            # Hide Gradio's default footer variants and style the custom footer.
            gr.HTML("""
            <style>
            .footer {
                display: none !important;
            }
            footer {
                display: none !important;
            }
            #foot {
                display: none !important;
            }
            .svelte-1fzp3xt {
                display: none !important;
            }
            #root > div > div > div {
                padding-bottom: 0 !important;
            }
            .custom-footer {
                text-align: center;
                padding: 10px;
                font-size: 14px;
                color: #333;
            }
            </style>
            """)
            # Header banner: inline base64-encoded JPEG logo plus the app title.
            gr.HTML("""<div><img src="data:image/jpeg;base64,/9j/4AAQSkZJRgABAQAAAQABAAD/2wCEAAkGBxAQDw4NDQ8PDg0OEA0NDQ4NDQ8NCQ0NFhEWFhURFRUYHSgsJBomGxMVIT0hJSs3Li4wFx8/RDM4QygtLisBCgoKDg0OGhAQFysdHiYrKy0tKy0rNy0tLS0tLSstKy0tLSs3LS0tLS0tLS0tLS0tLS0tLSsrLS0tKystLS0tLf/AABEIAMgAyAMBEQACEQEDEQH/xAAbAAEBAQEBAQEBAAAAAAAAAAAAAQMCBAUGB//EAEcQAAIBAgEDDA8GBQUAAAAAAAABAgMEEQYSUQUHFiExQVRzgZKToRMUFyIyNDVhcXSxsrPB0SMzQlJik0OCosLwJGRht9f/aAAwDAQACEQMRAD8A3FHjnoJgA5lNLdYHmqaqW8Pv1qcfxTiva8jMc9HdPj2T9BwXtNtFNRVu+1D+pI8s1ucum8TT4xPl7OJZXalrdvbP/wBRS/cNGeDy7fhv/cf6n2WV+pT3NUbL8ri3/cNGeBGPw0/+4/29ttqra1fu7mhU4urTl7mTKXtTetVbNUT+X6lNPcePmPl6xMS7TT3dwD51K6pxbjKpCMluqU4poD+gAAAAAAAAAAAAAACdgZ1riZcTtpysrJpV0vt6+5J0cVtQj/Nht6PTuelujPXLU4/HzbnQo8eMvyOvXnUk51JSnOW3Kc5OU5PztnvEZNDVVNU5y4KgAAAAfqnUlBqUJOMltqUW1JPzNARk2uJnxO2lKysmlXS+3r7kxwxW1CP82G3o9O55V15aoberAY+bk6FHjxl+S168651JOdSUpzltynOTlOT87Z7RGTR1VTVOcuCuQAAAAAAAAB//2Q==" alt="Intercontinental Exchange" style="float:left;width:80px;height:80px;"><h1 style="color:#000;margin-left:4in;padding-top:10px">Virtual Assistant Chatbot</h1></div>""")
            # Per-session state holding the knowledge base (agent or vector store).
            state = gr.State(self._get_empty_state())
            chatbot = gr.Chatbot()
            with gr.Row():
                with gr.Column(scale=0.85):
                    msg = gr.Textbox(label="Question", elem_id="textbox_id")
                with gr.Column(scale=0.15):
                    file_output = gr.Textbox(label="File Status")
            with gr.Row():
                with gr.Column(scale=0.85):
                    clear = gr.ClearButton([msg, chatbot])
                with gr.Column(scale=0.15):
                    upload_button = gr.UploadButton(
                        "Browse File",
                        file_types=[".txt", ".pdf", ".doc", ".docx", ".csv"],
                        file_count="multiple", variant="primary"
                    )
            with gr.Row():
                with gr.Column(scale=1):
                    input_url = gr.Textbox(label="urls", elem_id="textbox_id")
            # URL submission builds the knowledge base from web content.
            input_url.submit(self._get_urls_knowledge_base, input_url, [file_output, state])
            # File upload builds the knowledge base from the uploaded documents.
            upload_button.upload(self._get_documents_knowledge_base, upload_button, [file_output, state])
            # NOTE(review): upload_button is passed as the file_paths argument —
            # Gradio supplies the uploaded-files list; confirm with the pinned
            # Gradio version.
            msg.submit(self._get_response, [msg, chatbot, state,upload_button], [msg, chatbot])
        # debug=True blocks and streams errors to the console (Spaces-friendly).
        demo.launch(debug=True,allowed_paths=["/content/"])
if __name__ == "__main__":
    # Build the app and start the (blocking) Gradio server.
    app = ChatDocumentQA()
    app.gradio_interface()