import gradio as gr
import os
from pathlib import Path
import re
from unidecode import unidecode
import chromadb
from langchain_community.vectorstores import FAISS, ScaNN, Milvus
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from huggingface_hub import InferenceClient
import torch
# Hugging Face access token from the environment; None when HF_TOKEN is unset.
api_token = os.getenv("HF_TOKEN")

# Streaming client used by ``generate`` (bound to a single Mistral model).
client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

# Hub repo ids of the selectable models.
# NOTE(review): presumably these feed a model selector built in demo() — confirm.
list_llm = [
    "meta-llama/Meta-Llama-3-8B-Instruct",
    "mistralai/Mistral-7B-Instruct-v0.3",
]
# Short display names: each repo id with its "org/" prefix removed.
list_llm_simple = list(map(os.path.basename, list_llm))
def format_prompt(message, history):
    """Render past turns plus a new message in ``[INST] ... [/INST]`` layout.

    Each ``(user, reply)`` pair in *history* becomes
    ``"[INST] user [/INST] reply "``. The new *message* is appended last as an
    unanswered ``"[INST] message [/INST]"`` so the model continues from there.

    NOTE(review): Mistral's documented chat template also wraps turns in
    ``<s>``/``</s>``. They are not emitted here; confirm whether the endpoint
    adds them itself.
    """
    segments = []
    for past_user, past_reply in history:
        segments.append(f"[INST] {past_user} [/INST] {past_reply} ")
    segments.append(f"[INST] {message} [/INST]")
    return "".join(segments)
def generate(prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Stream a completion for *prompt* from the module-level ``client``.

    Args:
        prompt: The new user message.
        history: Prior ``(user, reply)`` turns, rendered with ``format_prompt``.
        temperature: Sampling temperature, clamped to at least ``0.01``.
        max_new_tokens: Token budget for the reply. Coerced to ``int`` unless
            ``None``.
        top_p: Nucleus-sampling probability mass (coerced to ``float``).
        repetition_penalty: Repetition penalty. Coerced to ``float`` unless
            ``None``.

    Yields:
        The reply text accumulated so far, after each received non-special
        token. This lets a chat UI render the answer progressively.

    Returns:
        The final reply text, as the generator's return value.
    """
    # UI sliders can deliver floats; the endpoint expects an integer token
    # budget and float sampling parameters. None is passed through unchanged.
    if max_new_tokens is not None:
        max_new_tokens = int(max_new_tokens)
    if repetition_penalty is not None:
        repetition_penalty = float(repetition_penalty)
    # Presumably the endpoint rejects non-positive temperatures while
    # sampling, hence the small floor.
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=float(top_p),
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    output = ""
    for response in stream:
        # Skip special tokens (e.g. the end-of-sequence marker) so their text
        # never leaks into the visible answer.
        if response.token.special:
            continue
        output += response.token.text
        yield output
    return output
def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: Target maximum chunk size for the splitter (measured in
            characters by the splitter's default length function).
        chunk_overlap: Amount of overlap between consecutive chunks.

    Returns:
        The chunked documents produced by ``split_documents``.
    """
    # All loaders are constructed first, then loaded in order.
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return splitter.split_documents(all_pages)
def create_db(splits, collection_name, db_type):
    """Embed document chunks and index them in the requested vector store.

    Args:
        splits: Document chunks to index (e.g. the output of ``load_doc``).
        collection_name: Collection name used by the ChromaDB and Milvus
            backends. It is ignored by FAISS and ScaNN.
        db_type: One of ``"ChromaDB"``, ``"FAISS"``, ``"ScaNN"``, ``"Milvus"``.

    Returns:
        The populated LangChain vector store.

    Raises:
        ValueError: If *db_type* is not one of the supported backends.
    """
    # Validate before building the embedding model, so an unsupported backend
    # fails immediately instead of after a potentially slow model load.
    if db_type not in ("ChromaDB", "FAISS", "ScaNN", "Milvus"):
        raise ValueError(f"Unsupported vector database type: {db_type}")

    embedding = HuggingFaceEmbeddings()
    if db_type == "ChromaDB":
        # Non-persistent in-process Chroma client.
        return Chroma.from_documents(
            documents=splits,
            embedding=embedding,
            client=chromadb.EphemeralClient(),
            collection_name=collection_name,
        )
    if db_type == "FAISS":
        return FAISS.from_documents(documents=splits, embedding=embedding)
    if db_type == "ScaNN":
        return ScaNN.from_documents(documents=splits, embedding=embedding)
    # NOTE(review): no connection_args are passed, so Milvus uses its default
    # connection settings — confirm a server is reachable there.
    return Milvus.from_documents(
        documents=splits,
        embedding=embedding,
        collection_name=collection_name,
    )
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, initial_prompt, progress=gr.Progress()):
    """Build a conversational retrieval QA chain over *vector_db*.

    Args:
        llm_model: Hugging Face Hub repo id of the model to query.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: ``max_new_tokens`` forwarded to the endpoint.
        top_k: Top-k sampling value forwarded to the endpoint.
        vector_db: LangChain vector store used as the retriever.
        initial_prompt: Optional first question used to prime the
            conversation. Blank or empty values are skipped.
        progress: Gradio progress tracker.

    Returns:
        A ``ConversationalRetrievalChain`` with buffer memory that also
        returns its source documents.
    """
    # NOTE: no tokenizer is loaded here; this is only a progress message.
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        # The chain returns both "answer" and "source_documents"; only the
        # answer is stored in memory.
        output_key='answer',
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    # Prime the conversation with the initial prompt only when one was given.
    # An empty question would still cost an LLM round-trip and leave an empty
    # turn in memory.
    if initial_prompt and initial_prompt.strip():
        qa_chain({"question": initial_prompt})
    progress(0.9, desc="Done!")
    return qa_chain
def initialize_llm_no_doc(llm_model, temperature, max_tokens, top_k, initial_prompt, progress=gr.Progress()):
    """Build a plain (retrieval-free) conversation chain around a Hub model.

    The chain takes the user message under ``"question"`` and returns the
    reply under ``"answer"``, with conversation history kept in buffer memory
    under ``"chat_history"``.

    Args:
        llm_model: Hugging Face Hub repo id of the model to query.
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: ``max_new_tokens`` forwarded to the endpoint.
        top_k: Top-k sampling value forwarded to the endpoint.
        initial_prompt: Optional first question used to prime the
            conversation. Blank or empty values are skipped.
        progress: Gradio progress tracker.

    Returns:
        The configured ``ConversationChain``.
    """
    from langchain.prompts import PromptTemplate  # only this builder needs it

    # NOTE: no tokenizer is loaded here; this is only a progress message.
    progress(0.1, desc="Initializing HF tokenizer...")
    progress(0.5, desc="Initializing HF Hub...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        # Plain-text history ("Human: ... / AI: ...") because it is
        # substituted into a string prompt template.
        return_messages=False,
    )
    # ConversationChain requires its prompt variables to be exactly the
    # memory key(s) plus its input key. The default prompt uses
    # {history}/{input}, which makes construction raise ValueError with this
    # memory. This prompt mirrors the default wording with matching variables.
    prompt = PromptTemplate(
        input_variables=["chat_history", "question"],
        template=(
            "The following is a friendly conversation between a human and an AI. "
            "The AI is talkative and provides lots of specific details from its "
            "context. If the AI does not know the answer to a question, it "
            "truthfully says it does not know.\n\n"
            "Current conversation:\n{chat_history}\nHuman: {question}\nAI:"
        ),
    )
    conversation_chain = ConversationChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
        input_key="question",
        output_key="answer",
        verbose=False,
    )
    # Prime the conversation only when an initial prompt was actually given.
    if initial_prompt and initial_prompt.strip():
        conversation_chain({"question": initial_prompt})
    progress(0.9, desc="Done!")
    return conversation_chain
def format_chat_history(message, chat_history):
    """Flatten ``(user, assistant)`` pairs into alternating role-prefixed lines.

    Args:
        message: Not used by this function.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs.

    Returns:
        A list of the form ``["User: ...", "Assistant: ...", ...]``, two
        entries per pair, in order.
    """
    return [
        line
        for user_text, bot_text in chat_history
        for line in (f"User: {user_text}", f"Assistant: {bot_text}")
    ]
def conversation(qa_chain, message, history):
    """Ask the retrieval chain *message* and return the updated UI state.

    Args:
        qa_chain: Chain built by ``initialize_llmchain``. It is called with a
            dict and must return ``"answer"`` and, ideally,
            ``"source_documents"``.
        message: The user's question.
        history: List of prior ``(user, assistant)`` pairs.

    Returns:
        A 9-tuple with these items, in order:

        - ``qa_chain``.
        - ``gr.update(value="")``.
        - The extended history.
        - Three ``(source text, 1-based page)`` pairs, flattened.

        When fewer than three source documents are returned, the missing
        slots are filled with ``""`` and ``None`` instead of raising
        ``IndexError``. A document without ``"page"`` metadata gets ``None``
        as its page.
    """
    formatted_chat_history = format_chat_history(message, history)
    # NOTE(review): when the chain carries memory (as initialize_llmchain
    # builds it), LangChain merges memory variables over the call inputs, so
    # this explicit chat_history is presumably superseded — confirm.
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Keep only the text after the "Helpful Answer:" marker if the model
    # echoed it.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    source_outputs = []
    for doc in (response.get("source_documents") or [])[:3]:
        page = doc.metadata.get("page")
        source_outputs.append(doc.page_content.strip())
        # PDF page metadata is 0-based; the UI shows 1-based page numbers.
        source_outputs.append(page + 1 if page is not None else None)
    # Pad so callers always receive exactly three (text, page) pairs.
    while len(source_outputs) < 6:
        source_outputs.extend(("", None))

    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history, *source_outputs)
def conversation_no_doc(llm, message, history):
    """Send *message* to the retrieval-free chain and extend *history*.

    Args:
        llm: Conversation chain called with ``"question"`` and
            ``"chat_history"`` inputs. It must return an ``"answer"`` key.
        message: The user's message.
        history: List of prior ``(user, assistant)`` pairs.

    Returns:
        A 3-tuple: ``llm`` unchanged, ``gr.update(value="")``, and the
        history extended with the new exchange.
    """
    chain_inputs = {
        "question": message,
        "chat_history": format_chat_history(message, history),
    }
    reply = llm(chain_inputs)["answer"]
    extended_history = history + [(message, reply)]
    return llm, gr.update(value=""), extended_history
def upload_file(file_obj):
    """Return the filesystem paths of uploaded files.

    Args:
        file_obj: Value of a multi-file upload component. Each item is either
            a file wrapper exposing ``.name`` (its path on disk) or a plain
            path string / ``os.PathLike``; which one depends on the Gradio
            version and component type. ``None`` means nothing was uploaded.

    Returns:
        The list of paths, in upload order. Empty for ``None`` or no files.
    """
    if file_obj is None:
        return []
    # str/PathLike are checked first because pathlib.Path also has a `.name`
    # attribute, but that is only the basename.
    return [
        os.fspath(item) if isinstance(item, (str, os.PathLike)) else item.name
        for item in file_obj
    ]
def demo():
with gr.Blocks(theme="base") as demo:
vector_db = gr.State()
qa_chain = gr.State()
collection_name = gr.State()
initial_prompt = gr.State("")
llm_no_doc = gr.State()
gr.Markdown(
"""