import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFaceEndpoint
from pathlib import Path
import chromadb
from unidecode import unidecode
from transformers import AutoTokenizer
import transformers
import torch
import tqdm
import accelerate
import re
import torch
from sacrebleu import corpus_bleu
from rouge_score import rouge_scorer
from bert_score import score
from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
import nltk
from nltk.util import ngrams

# Hugging Face API token taken from the API_KEY environment variable.
# NOTE(review): this is None when API_KEY is unset; presumably
# HuggingFaceEndpoint then falls back to its own token lookup — confirm.
api_key = os.getenv('API_KEY')

# Hugging Face Hub repo ids selectable in the UI (initialize_LLM indexes
# into this list).
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
# Short display names: the repo name without its organisation prefix.
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def load_doc(list_file_path, chunk_size, chunk_overlap):
    """Load every PDF in *list_file_path* and split its pages into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        chunk_size: ``chunk_size`` for RecursiveCharacterTextSplitter.
        chunk_overlap: ``chunk_overlap`` for RecursiveCharacterTextSplitter.

    Returns:
        The documents produced by ``split_documents`` over all loaded pages.
    """
    loaders = [PyPDFLoader(path) for path in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(pages)


def create_db(splits, collection_name):
    """Embed *splits* into a new in-memory Chroma collection named *collection_name*."""
    embedding = HuggingFaceEmbeddings()
    # Ephemeral client: the collection is not persisted to disk.
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb


def load_db():
    """Return a Chroma vector store using the default HuggingFace embeddings."""
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb


def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a ConversationalRetrievalChain backed by a HF inference endpoint.

    Args:
        llm_model: Hugging Face Hub repo id (one of ``list_llm``).
        temperature: Sampling temperature forwarded to the endpoint.
        max_tokens: ``max_new_tokens`` for the endpoint (TinyLlama is capped
            at 250).
        top_k: Top-k sampling parameter forwarded to the endpoint.
        vector_db: Vector store whose ``as_retriever()`` feeds the chain.
        progress: Gradio progress tracker.

    Returns:
        The configured ConversationalRetrievalChain (returns source documents).

    Raises:
        gr.Error: If the selected model cannot be served on the free endpoint.
    """
    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub uses HF inference endpoints
    progress(0.5, desc="Initializing HF Hub...")

    # Models rejected up front (the original code after these raises was unreachable).
    if llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
        raise gr.Error("LLM model is too large to be loaded automatically on free inference endpoint")
    if llm_model == "meta-llama/Llama-2-7b-chat-hf":
        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")

    # Shared endpoint settings. The token must be the api_key *variable*;
    # previously the literal string 'api_key' was sent as the token.
    endpoint_kwargs = {
        "repo_id": llm_model,
        "temperature": temperature,
        "max_new_tokens": max_tokens,
        "top_k": top_k,
        "huggingfacehub_api_token": api_key,
    }
    # Per-model overrides (see langchain issue
    # https://github.com/langchain-ai/langchain/issues/6080 regarding
    # trust_remote_code).
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        endpoint_kwargs["load_in_8bit"] = True
    elif llm_model == "microsoft/phi-2":
        endpoint_kwargs["trust_remote_code"] = True
        endpoint_kwargs["torch_dtype"] = "auto"
    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
        endpoint_kwargs["max_new_tokens"] = 250
    llm = HuggingFaceEndpoint(**endpoint_kwargs)

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain


def create_collection_name(filepath):
    """Derive a vector-store collection name from *filepath*.

    The name is the file stem with the following applied:
    spaces -> '-', Unicode transliterated to ASCII, runs of
    non-alphanumerics -> '-', truncated to 50 chars, padded with 'xyz' when
    shorter than 3 chars, and first/last characters forced alphanumeric.
    (Presumably these mirror Chroma's collection-name rules — confirm.)
    """
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    print('Filepath: ', filepath)
    print('Collection name: ', collection_name)
    return collection_name


def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
    """Load the uploaded PDFs and build an in-memory vector database.

    Args:
        list_file_obj: Uploaded file objects (each exposing ``.name``);
            ``None`` entries are ignored.
        chunk_size: Chunk size for splitting.
        chunk_overlap: Chunk overlap for splitting.
        progress: Gradio progress tracker.

    Returns:
        Tuple ``(vector_db, collection_name, "Complete!")``.

    Raises:
        gr.Error: If no file was provided.
    """
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    # Without this guard list_file_path[0] below raised IndexError.
    if not list_file_path:
        raise gr.Error("Please upload at least one PDF document before generating the vector database.")
    progress(0.1, desc="Creating collection name...")
    # The collection is named after the first uploaded file.
    collection_name = create_collection_name(list_file_path[0])
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()): # print("llm_option",llm_option) llm_name = list_llm[llm_option] print("llm_name: ",llm_name) qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress) return qa_chain, "Complete!" def format_chat_history(message, chat_history): formatted_chat_history = [] for user_message, bot_message in chat_history: formatted_chat_history.append(f"User: {user_message}") formatted_chat_history.append(f"Assistant: {bot_message}") return formatted_chat_history #---------------------------------------------------------------------------------- def load_gpt2_model(): model = GPT2LMHeadModel.from_pretrained('gpt2') tokenizer = GPT2Tokenizer.from_pretrained('gpt2') return model, tokenizer gpt2_model, gpt2_tokenizer = load_gpt2_model() bias_pipeline = pipeline("zero-shot-classification", model="Hate-speech-CNERG/dehatebert-mono-english") def evaluate_bleu_rouge(candidates, references): bleu_score = corpus_bleu(candidates, [references]).score scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True) rouge_scores = [scorer.score(ref, cand) for ref, cand in zip(references, candidates)] rouge1 = sum([score['rouge1'].fmeasure for score in rouge_scores]) / len(rouge_scores) return bleu_score, rouge1 def evaluate_bert_score(candidates, references): P, R, F1 = score(candidates, references, lang="en", model_type='bert-base-multilingual-cased') return P.mean().item(), R.mean().item(), F1.mean().item() def evaluate_perplexity(text, model, tokenizer): encodings = tokenizer(text, return_tensors='pt') max_length = model.config.n_positions stride = 512 lls = [] for i in range(0, encodings.input_ids.size(1), stride): begin_loc = max(i + stride - max_length, 0) end_loc = min(i + stride, encodings.input_ids.size(1)) trg_len = end_loc - i input_ids = encodings.input_ids[:, begin_loc:end_loc] target_ids = input_ids.clone() target_ids[:, 
:-trg_len] = -100 with torch.no_grad(): outputs = model(input_ids, labels=target_ids) log_likelihood = outputs[0] * trg_len lls.append(log_likelihood) ppl = torch.exp(torch.stack(lls).sum() / end_loc) return ppl.item() def evaluate_diversity(texts): all_tokens = [tok for text in texts for tok in text.split()] unique_bigrams = set(ngrams(all_tokens, 2)) diversity_score = len(unique_bigrams) / len(all_tokens) if all_tokens else 0 return diversity_score def evaluate_racial_bias(text, pipeline): results = pipeline([text], candidate_labels=["hate speech", "not hate speech"]) bias_score = results[0]['scores'][results[0]['labels'].index('hate speech')] return bias_score def evaluate_all(question, response, reference, gpt2_model, gpt2_tokenizer, bias_pipeline): candidates = [response] references = [reference] bleu, rouge1 = evaluate_bleu_rouge(candidates, references) bert_p, bert_r, bert_f1 = evaluate_bert_score(candidates, references) perplexity = evaluate_perplexity(response, gpt2_model, gpt2_tokenizer) diversity = evaluate_diversity(candidates) racial_bias = evaluate_racial_bias(response, bias_pipeline) return { "BLEU": bleu, "ROUGE-1": rouge1, "BERT P": bert_p, "BERT R": bert_r, "BERT F1": bert_f1, "Perplexity": perplexity, "Diversity": diversity, "Racial Bias": racial_bias } #--------------------------------------------------------------------------------- def display_metrics(metrics): result = "" for k, v in metrics.items(): if k == 'BLEU': result += f"BLEU measures the overlap between the generated output and reference text based on n-grams. Higher scores indicate better match. Score obtained: {v}\n\n" elif k == "ROUGE-1": result += f"ROUGE-1 measures the overlap of unigrams between the generated output and reference text. Higher scores indicate better match. 
Score obtained: {v}\n\n" elif k == 'BERT P': result += "BERTScore evaluates the semantic similarity between the generated output and reference text using BERT embeddings.\n\n" result += f"**BERT Precision**: {metrics['BERT P']}\n" result += f"**BERT Recall**: {metrics['BERT R']}\n" result += f"**BERT F1 Score**: {metrics['BERT F1']}\n\n" elif k == 'Perplexity': result += f"Perplexity measures how well a language model predicts the text. Lower values indicate better fluency and coherence. Score obtained: {v}\n\n" elif k == 'Diversity': result += f"Diversity measures the uniqueness of bigrams in the generated output. Higher values indicate more diverse and varied output. Score obtained: {v}\n\n" elif k == 'Racial Bias': result += f"Racial Bias score indicates the presence of biased language in the generated output. Higher scores indicate more bias. Score obtained: {v}\n\n" return result #--------------------------------------------------------------------------------------------------------------------------------------------------- def conversation(qa_chain, message, history, gpt2_model, gpt2_tokenizer, bias_pipeline): formatted_chat_history = format_chat_history(message, history) question_by_user = message response = qa_chain({"question": message, "chat_history": formatted_chat_history}) response_answer = response["answer"] answer_of_question = response["answer"] if response_answer.find("Helpful Answer:") != -1: response_answer = response_answer.split("Helpful Answer:")[-1] response_sources = response["source_documents"] context = " ".join([d.page_content for d in response_sources]) response_source1 = response_sources[0].page_content.strip() response_source2 = response_sources[1].page_content.strip() response_source3 = response_sources[2].page_content.strip() response_source1_page = response_sources[0].metadata["page"] + 1 response_source2_page = response_sources[1].metadata["page"] + 1 response_source3_page = response_sources[2].metadata["page"] + 1 new_history = 
history + [(message, response_answer)] # Evaluate the metrics metrics = evaluate_all(question_by_user, answer_of_question, context) evaluation_metrics = display_metrics(metrics) return (qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page, evaluation_metrics) # def interact(qa_chain, message, history): # return conversation(qa_chain, message, history, evaluator) def upload_file(file_obj): list_file_path = [] for idx, file in enumerate(file_obj): file_path = file_obj.name list_file_path.append(file_path) # print(file_path) # initialize_database(file_path, progress) return list_file_path #################################### def demo(): with gr.Blocks(theme="base") as demo: vector_db = gr.State() qa_chain = gr.State() collection_name = gr.State() history = gr.State([]) # Initialize history as an empty list gr.Markdown( """