"""Gradio app for chatting with PDF documents (or a web page rendered to PDF).

Uploaded documents are split into chunks, embedded into an ephemeral Chroma
collection, and queried through a LangChain ``ConversationalRetrievalChain``
backed by a Hugging Face Hub model.
"""

import os
import re
from getpass import getpass
from pathlib import Path

import chromadb
import gradio as gr
import matplotlib.pyplot as plt
import weasyprint
from dotenv import load_dotenv
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import ConversationChain
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma

# Optional OCR fallback (disabled):
# from doctr.models import ocr_predictor
# from doctr.io import DocumentFile

load_dotenv()
# model = ocr_predictor(pretrained = True)

huggingfacehub_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
# NOTE(review): the variable read here is "OPEN_API_KEY", not "OPENAI_API_KEY";
# confirm that this matches the .env file. The value is unused in this module.
openai_key = os.getenv("OPEN_API_KEY")

# default_persist_directory = './chroma_HF/'

# Hugging Face Hub repositories selectable from the UI.
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Short display names (repository name without the organisation prefix).
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def extract_value_from_response(response):
    """Flatten a doctr OCR response into one space-separated string.

    Walks ``response.pages -> blocks -> lines -> words`` and concatenates every
    ``word.value``, each preceded by a single space (so a non-empty result
    starts with a space).
    """
    return "".join(
        " " + word.value
        for page in response.pages
        for block in page.blocks
        for line in block.lines
        for word in line.words
    )


def create_pdf_from_url(url):
    """Render the page at ``url`` to ``pdfDir/url_pdf.pdf`` and return that path.

    The output file name is fixed, so each call overwrites the previous PDF.
    """
    pdf = weasyprint.HTML(url).write_pdf()
    output_dir = "pdfDir"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, "url_pdf.pdf")
    with open(file_path, "wb") as f:
        f.write(pdf)
    return file_path


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load every PDF in ``list_file_path`` and split the pages into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: Maximum chunk size passed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The list of split documents produced by
        ``RecursiveCharacterTextSplitter.split_documents``.
    """
    pages = []
    for path in list_file_path:
        pages.extend(PyPDFLoader(path).load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    doc_splits = text_splitter.split_documents(pages)
    # Disabled OCR fallback for PDFs without a text layer:
    # if len(doc_splits) == 0:
    #     doc = DocumentFile.from_pdf(list_file_path[0])
    #     result = model(doc)
    #     response = extract_value_from_response(result)
    #     doc_splits = text_splitter.split_documents(response)
    return doc_splits


def create_db(splits, collection_name):
    """Embed ``splits`` into a new Chroma collection named ``collection_name``."""
    embedding = HuggingFaceEmbeddings()
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
        # persist_directory=default_persist_directory
    )
    return vectordb


def load_db():
    """Return a Chroma vector store using the default Hugging Face embeddings."""
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a conversational retrieval chain over ``vector_db``.

    Args:
        llm_model: Hugging Face Hub repository id of the model.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: ``max_new_tokens`` forwarded to the model (TinyLlama is
            pinned to 250 regardless of this value).
        top_k: ``top_k`` forwarded to the model.
        vector_db: Vector store exposing ``as_retriever()``.
        progress: Gradio progress tracker.

    Returns:
        The configured ``ConversationalRetrievalChain``.
    """
    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub uses HF inference endpoints
    progress(0.5, desc="Initializing HF Hub...")
    # Use of trust_remote_code as model_kwargs
    # Warning: langchain issue
    # URL: https://github.com/langchain-ai/langchain/issues/6080
    model_kwargs = {
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
    }
    # Per-model overrides on top of the shared settings.
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        model_kwargs["load_in_8bit"] = True
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        model_kwargs["max_new_tokens"] = 250
    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs=model_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    progress(0.8, desc="Defining retrieval chain...")
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        # combine_docs_chain_kwargs={"prompt": your_prompt})
        return_source_documents=True,
        # return_generated_question=False,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain


def _sanitize_collection_name(name):
    """Turn a file stem into a name Chroma accepts for a collection.

    The result is 3 to 50 characters long, contains only ``[A-Za-z0-9._-]``,
    has no consecutive periods, and starts and ends with an alphanumeric
    character.
    """
    # Spaces and any other character Chroma rejects become "-".
    name = re.sub(r"[^A-Za-z0-9._-]", "-", name)
    # Consecutive periods are rejected as well.
    name = re.sub(r"\.{2,}", ".", name)
    # Pad names that are too short, then cap the length.
    if len(name) < 3:
        name += "-XX"
    name = name[:50]
    # First and last characters must be alphanumeric.
    if not name[0].isalnum():
        name = "A" + name[1:]
    if not name[-1].isalnum():
        name = name[:-1] + "Z"
    return name


def initialize_database(list_file_obj, chunk_size, chunk_overlap, vector_db, url, progress=gr.Progress()):
    """Create the vector database from a URL or from the uploaded PDF files.

    A non-empty ``url`` takes precedence over uploaded files: the page is
    rendered to PDF and indexed on its own.

    Args:
        list_file_obj: Uploaded file objects (each exposing ``.name``), or None.
        chunk_size: Chunk size for the text splitter.
        chunk_overlap: Chunk overlap for the text splitter.
        vector_db: Current vector store state (replaced by the new one).
        url: URL to index instead of uploaded files; may be empty.
        progress: Gradio progress tracker.

    Returns:
        ``(vector_db, collection_name, update clearing the URL box, status)``.

    Raises:
        gr.Error: If neither a URL nor a file was provided, or if no text
            could be extracted from the documents.
    """
    url = (url or "").strip()
    if url:
        list_file_path = [create_pdf_from_url(url)]
    else:
        # Create list of documents (when valid); tolerate "nothing uploaded".
        list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    if not list_file_path:
        raise gr.Error("Please upload at least one PDF document or enter a URL.")

    # The collection is named after the first document.
    progress(0.1, desc="Creating collection name...")
    collection_name = _sanitize_collection_name(Path(list_file_path[0]).stem)
    print('Collection name:', collection_name)

    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    if not doc_splits:
        # e.g. scanned PDFs without a text layer
        raise gr.Error("No text could be extracted from the provided document(s).")

    progress(0.7, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    return vector_db, collection_name, gr.update(value=""), "Complete!"


def re_initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build a QA chain for the model at index ``llm_option`` of ``list_llm``.

    Raises:
        gr.Error: If no model is selected or the vector database has not been
            created yet.
    """
    if llm_option is None:
        raise gr.Error("Please choose an LLM model.")
    if vector_db is None:
        raise gr.Error("Please submit a document or URL before choosing an LLM model.")
    llm_name = list_llm[llm_option]
    print("llm_name: ", llm_name)
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain


def format_chat_history(message, chat_history):
    """Flatten ``(user, bot)`` pairs into ``"User: ..."`` / ``"Assistant: ..."`` lines.

    ``message`` is accepted for interface compatibility and is not used.
    """
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.extend(
            (f"User: {user_message}", f"Assistant: {bot_message}")
        )
    return formatted_chat_history


def conversation(qa_chain, message, history, llm_option):
    """Send ``message`` to the QA chain and append the exchange to ``history``.

    Returns:
        ``(qa_chain, update clearing the message box, new chat history)``.

    Raises:
        gr.Error: If the QA chain has not been initialised yet.
    """
    if qa_chain is None:
        raise gr.Error("Please submit a document and choose an LLM model before chatting.")
    history = history or []
    formatted_chat_history = format_chat_history(message, history)

    # Generate response using QA chain
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    # Keep only the text after the last "Helpful Answer:" marker, if present.
    response_answer = response["answer"].split("Helpful Answer:")[-1]
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history


def upload_file(file_obj):
    """Return the ``.name`` path of every uploaded file in ``file_obj``."""
    return [file.name for file in file_obj]


def demo():
    """Build the Gradio interface, wire its events, and launch the app."""
    # NOTE(review): the layout nesting below was reconstructed from a source
    # whose indentation was lost — confirm it matches the intended layout.
    with gr.Blocks(theme="base") as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown("\nPDF Document Chatbot\n")

        with gr.Row():
            with gr.Row():
                with gr.Column():
                    document = gr.Files(
                        file_count="multiple",
                        file_types=[".pdf"],
                        interactive=True,
                        label="Upload your PDF documents (single or multiple)",
                    )
                    with gr.Row():
                        gr.Markdown("\nOR\n")
                    with gr.Row():
                        url = gr.Textbox(placeholder="Enter your URL Here")
                    with gr.Row():
                        db_btn = gr.Radio(
                            ["ChromaDB"],
                            label="Vector database type",
                            value="ChromaDB",
                            type="index",
                            info="Choose your vector database",
                            visible=False,
                        )
                    with gr.Accordion(
                        "Advanced options - Document text splitter",
                        open=False,
                        visible=False,
                    ):
                        with gr.Row():
                            slider_chunk_size = gr.Slider(
                                minimum=100,
                                maximum=1000,
                                value=600,
                                step=20,
                                label="Chunk size",
                                info="Chunk size",
                                interactive=True,
                                visible=False,
                            )
                        with gr.Row():
                            slider_chunk_overlap = gr.Slider(
                                minimum=10,
                                maximum=200,
                                value=40,
                                step=10,
                                label="Chunk overlap",
                                info="Chunk overlap",
                                interactive=True,
                                visible=False,
                            )
                    llm_btn = gr.Radio(
                        list_llm_simple,
                        label="LLM models",
                        type="index",
                        info="Choose your LLM model",
                    )
                    db_progress = gr.Textbox(
                        label="Vector database initialization",
                        value="None",
                    )
                    with gr.Row():
                        submit_file = gr.Button("Submit File")

            with gr.Row():
                with gr.Column():
                    chatbot = gr.Chatbot()
                    msg = gr.Textbox(placeholder="Type Your Message")
                    with gr.Accordion("Advanced options - LLM model", open=False):
                        with gr.Row():
                            slider_temperature = gr.Slider(
                                minimum=0.0,
                                maximum=1.0,
                                value=0.7,
                                step=0.1,
                                label="Temperature",
                                info="Model temperature",
                                interactive=True,
                            )
                        with gr.Row():
                            slider_maxtokens = gr.Slider(
                                minimum=224,
                                maximum=4096,
                                value=1024,
                                step=32,
                                label="Max Tokens",
                                info="Model max tokens",
                                interactive=True,
                            )
                        with gr.Row():
                            slider_topk = gr.Slider(
                                minimum=1,
                                maximum=10,
                                value=3,
                                step=1,
                                label="top-k samples",
                                info="Model top-k samples",
                                interactive=True,
                            )
                    with gr.Row():
                        submit_btn = gr.Button("Submit")
                        # clear_btn = gr.ClearButton([msg2, chatbot])

        # Preprocessing events
        # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
        submit_file.click(
            initialize_database,
            inputs=[document, slider_chunk_size, slider_chunk_overlap, vector_db, url],
            outputs=[vector_db, collection_name, url, db_progress],
        )
        # Selecting a model (re)builds the QA chain on the current database.
        llm_btn.change(
            re_initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain],
        )

        # Chat events: Enter in the textbox and the Submit button do the same.
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot, llm_btn],
            outputs=[qa_chain, msg, chatbot],
            queue=False,
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot, llm_btn],
            outputs=[qa_chain, msg, chatbot],
            queue=False,
        )

    app.queue().launch(share=True, debug=True)


if __name__ == "__main__":
    demo()