"""Streamlit document question-answering demo.

Loads documents from the local ``data`` directory, indexes them with FAISS,
and answers questions with a 4-bit quantized Llama 2 chat model via llama.cpp.
"""

import os
import pathlib
import time

import streamlit as st
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.document_loaders import (
    Docx2txtLoader,
    PyPDFLoader,
    TextLoader,
    UnstructuredHTMLLoader,
    UnstructuredPowerPointLoader,
)
from langchain.document_loaders.image import UnstructuredImageLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

# Sidebar contents
with st.sidebar:
    st.title('DOC-QA DEMO')
    st.markdown('''
    ## About
    Details of this application:
    - LLM model: llama2-7b-chat-4bit
    - Hardware: Hugging Face Space, 8 vCPU / 32 GB RAM
    ''')


class UploadDoc:
    """Walks a directory and loads every supported file into LangChain Documents."""

    def __init__(self, path_data):
        self.path_data = path_data

    def prepare_filetype(self):
        """Group every file under path_data by extension."""
        extension_lists = {
            ".docx": [],
            ".pdf": [],
            ".html": [],
            ".png": [],
            ".pptx": [],
            ".txt": [],
        }
        path_list = []
        for path, subdirs, files in os.walk(self.path_data):
            for name in files:
                path_list.append(os.path.join(path, name))

        # Categorize files by extension; unsupported types are ignored.
        for filename in path_list:
            file_extension = pathlib.Path(filename).suffix
            if file_extension in extension_lists:
                extension_lists[file_extension].append(filename)
        return extension_lists

    def upload_docx(self, extension_lists):
        # Word documents
        data_docxs = []
        for doc in extension_lists[".docx"]:
            loader = Docx2txtLoader(doc)
            data_docxs.extend(loader.load())
        return data_docxs

    def upload_pdf(self, extension_lists):
        # PDF documents (one Document per page)
        data_pdf = []
        for doc in extension_lists[".pdf"]:
            loader = PyPDFLoader(doc)
            data_pdf.extend(loader.load_and_split())
        return data_pdf

    def upload_html(self, extension_lists):
        # HTML pages
        data_html = []
        for doc in extension_lists[".html"]:
            loader = UnstructuredHTMLLoader(doc)
            data_html.extend(loader.load())
        return data_html

    def upload_png_ocr(self, extension_lists):
        # PNG images via OCR
        data_png = []
        for doc in extension_lists[".png"]:
            loader = UnstructuredImageLoader(doc)
            data_png.extend(loader.load())
        return data_png

    def upload_pptx(self, extension_lists):
        # PowerPoint slides
        data_pptx = []
        for doc in extension_lists[".pptx"]:
            loader = UnstructuredPowerPointLoader(doc)
            data_pptx.extend(loader.load())
        return data_pptx

    def upload_txt(self, extension_lists):
        # Plain-text files
        data_txt = []
        for doc in extension_lists[".txt"]:
            loader = TextLoader(doc)
            data_txt.extend(loader.load())
        return data_txt

    def count_files(self, extension_lists):
        # Count how many files were found per extension.
        file_extension_counts = {ext: len(file_list)
                                 for ext, file_list in extension_lists.items()}
        print(f"number of files: {file_extension_counts}")
        return file_extension_counts

    def create_document(self):
        """Load every supported file type and return a single list of Documents."""
        documents = []
        extension_lists = self.prepare_filetype()
        self.count_files(extension_lists)

        upload_functions = {
            ".docx": self.upload_docx,
            ".pdf": self.upload_pdf,
            ".html": self.upload_html,
            ".png": self.upload_png_ocr,
            ".pptx": self.upload_pptx,
            ".txt": self.upload_txt,
        }

        for extension, upload_function in upload_functions.items():
            if len(extension_lists[extension]) > 0:
                documents.extend(upload_function(extension_lists))
        return documents
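# Illustrative usage of UploadDoc (mirrors what main() does below; assumes a
# local "data" directory exists):
#
#   docs = UploadDoc("data").create_document()
#   print(f"loaded {len(docs)} documents")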
def split_docs(documents, chunk_size=500):
    """Split Documents into overlapping chunks for embedding."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,
                                                   chunk_overlap=100)
    sp_docs = text_splitter.split_documents(documents)
    return sp_docs


@st.cache_resource
def load_llama2_llamaCpp():
    """Load the 4-bit quantized Llama 2 chat model through llama.cpp (CPU)."""
    core_model_name = "llama-2-7b-chat.Q4_0.gguf"
    n_batch = 32
    callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
    llm = LlamaCpp(
        model_path=core_model_name,
        n_batch=n_batch,
        callback_manager=callback_manager,
        verbose=True,
        n_ctx=1024,
        temperature=0.1,
        max_tokens=256,
    )
    return llm


def set_custom_prompt():
    """Prompt template for the retrieval QA chain."""
    custom_prompt_template = """Use the following pieces of information from the context to answer the user's question.
If you don't know the answer, don't try to make up an answer.

Context : {context}
Question : {question}

Only return the helpful answer below and nothing else.
Helpful answer:
"""
    prompt = PromptTemplate(template=custom_prompt_template,
                            input_variables=['context', 'question'])
    return prompt


@st.cache_resource
def load_embeddings():
    """Load the sentence-embedding model used to build the FAISS index."""
    embeddings = HuggingFaceEmbeddings(model_name="thenlper/gte-base",
                                       model_kwargs={'device': 'cpu'})
    return embeddings
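# Illustrative check of the prompt wiring (not executed by the app): the
# template is filled with the retrieved context and the user's question.
#
#   prompt = set_custom_prompt()
#   print(prompt.format(context="Llama 2 is a chat model released by Meta.",
#                       question="Who released Llama 2?"))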
def main():
    data = []

    if "messages" not in st.session_state:
        st.session_state.messages = []

    st.header("DOCUMENT QUESTION ANSWERING IS2")

    # Load every supported document under ./data and split it into chunks.
    directory = "data"
    data_dir = UploadDoc(directory).create_document()
    data.extend(data_dir)

    sp_docs = split_docs(documents=data)
    st.write(f"This document set has {len(sp_docs)} chunks")

    # Build the in-memory FAISS index and the retrieval QA chain.
    embeddings = load_embeddings()
    db = FAISS.from_documents(sp_docs, embeddings)
    llm = load_llama2_llamaCpp()
    qa_prompt = set_custom_prompt()

    # RetrievalQA reads "query" and writes "result", so the memory keys must match.
    memory = ConversationBufferMemory(memory_key="chat_history",
                                      return_messages=True,
                                      input_key="query",
                                      output_key="result")
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=True,
        memory=memory,
        chain_type_kwargs={"prompt": qa_prompt},
    )

    # Replay the chat history.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Accept user input.
    if query := st.chat_input("What is up?"):
        # Display the user message and add it to the chat history.
        with st.chat_message("user"):
            st.markdown(query)
        st.session_state.messages.append({"role": "user", "content": query})

        start = time.time()
        response = qa_chain({'query': query})
        end = time.time()

        # Display the assistant response and add it to the chat history.
        with st.chat_message("assistant"):
            st.markdown(response['result'])
        st.write("Response time:", int(end - start), "sec")
        st.session_state.messages.append({"role": "assistant",
                                          "content": response['result']})

    clear_button = st.button("Start new conversation")
    if clear_button:
        st.session_state.messages = []
        qa_chain.memory.chat_memory.clear()


if __name__ == '__main__':
    main()
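# To run the demo locally (a sketch; "app.py" is an assumed filename): install
# streamlit, langchain, llama-cpp-python, faiss-cpu, sentence-transformers,
# pypdf, docx2txt, and unstructured, place llama-2-7b-chat.Q4_0.gguf in the
# working directory, then:
#
#   streamlit run app.py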