# EDITH / app.py
# LivisLiquoro's picture
# Upload 6 files
# c163b71 verified
# raw
# history blame
# 8.79 kB
# NOTE: "!pip install -q gradio" is IPython/notebook shell syntax and raises a
# SyntaxError in a plain .py module, so it must not live in this file.
# Install the dependency outside the script instead (for example by listing
# `gradio` in requirements.txt or running `pip install -q gradio` in a shell).
import warnings

# Suppress every warning emitted after this point.
warnings.filterwarnings('ignore')
# Import necessary libraries
import gradio as gr
import torch
from transformers import (
BertTokenizerFast,
BertForQuestionAnswering,
AutoTokenizer,
BartForQuestionAnswering,
DistilBertTokenizerFast,
DistilBertForQuestionAnswering
)
import gc
# Module-level application state.
# `context_store` starts out as an empty list.
context_store = list()
# Tracks the selected model; None means no model has been selected yet.
selected_model = None
# Define models and tokenizers
def load_bert_model_and_tokenizer():
    """Load the BERT question-answering model together with its tokenizer.

    Both are created with ``from_pretrained`` from the
    ``LivisLiquoro/BERT_Model_Squad1.1`` identifier. The model is put in
    evaluation mode and moved to CUDA when it is available, otherwise to
    the CPU.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    checkpoint = "LivisLiquoro/BERT_Model_Squad1.1"
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    qa_model = BertForQuestionAnswering.from_pretrained(checkpoint)
    qa_tokenizer = BertTokenizerFast.from_pretrained(checkpoint)

    qa_model.eval()
    qa_model.to(target)

    # Run a garbage-collection pass and release PyTorch's cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    return qa_tokenizer, qa_model, target
def load_bart_model_and_tokenizer():
    """Load the BART question-answering model together with its tokenizer.

    Both are created with ``from_pretrained`` from the
    ``valhalla/bart-large-finetuned-squadv1`` identifier. The model is put
    in evaluation mode and moved to CUDA when it is available, otherwise to
    the CPU.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    qa_model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    qa_tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1")

    qa_model.eval()
    qa_model.to(target)

    # Run a garbage-collection pass and release PyTorch's cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    return qa_tokenizer, qa_model, target
def load_distilbert_model_and_tokenizer():
    """Load the DistilBERT question-answering model together with its tokenizer.

    Both are created with ``from_pretrained`` from the
    ``LivisLiquoro/DistilBert_model_squad1.1`` identifier. The model is put
    in evaluation mode and moved to CUDA when it is available, otherwise to
    the CPU.

    Returns:
        A ``(tokenizer, model, device)`` tuple.
    """
    checkpoint = "LivisLiquoro/DistilBert_model_squad1.1"
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")

    qa_model = DistilBertForQuestionAnswering.from_pretrained(checkpoint)
    qa_tokenizer = DistilBertTokenizerFast.from_pretrained(checkpoint)

    qa_model.eval()
    qa_model.to(target)

    # Run a garbage-collection pass and release PyTorch's cached CUDA memory.
    gc.collect()
    torch.cuda.empty_cache()
    return qa_tokenizer, qa_model, target
def clean_answer(tokens):
    """Turn a sequence of answer tokens into a readable string.

    Special tokens are dropped and the remaining tokens are handed to the
    module-level ``tokenizer`` to be joined back into text.

    Args:
        tokens: Iterable of token strings, as produced by
            ``tokenizer.convert_ids_to_tokens``.

    Returns:
        The stripped answer text, or ``None`` when nothing is left after
        removing special tokens.
    """
    # Always drop the BERT markers the original code handled, plus whatever
    # special tokens the active tokenizer declares (e.g. BART's <s> / </s>).
    special_tokens = {'[SEP]', '[CLS]'}
    special_tokens.update(getattr(tokenizer, 'all_special_tokens', ()))

    kept_tokens = [token for token in tokens if token and token not in special_tokens]

    # The '##' continuation prefix is deliberately left on the tokens:
    # WordPiece tokenizers use it inside convert_tokens_to_string to glue
    # sub-word pieces together. Stripping it beforehand yields "play ing"
    # instead of "playing".
    return tokenizer.convert_tokens_to_string(kept_tokens).strip() or None
def generate_answer(context, question):
    """Extract an answer to *question* from *context* with the active model.

    The context is split into fixed-size character chunks. Each chunk is run
    through the module-level ``tokenizer`` and ``model`` once, and the answer
    span with the highest combined start/end logit across all chunks is
    returned.

    Args:
        context: The passage to search for an answer.
        question: The question to answer.

    Returns:
        The best answer with its first character upper-cased, or
        ``"❌ No valid answer found."`` when no chunk yields a usable span.
    """
    max_length = 512
    # NOTE(review): chunks are cut by characters while the tokenizer's
    # max_length counts tokens, so chunks are much shorter than the model
    # limit and may split words. The argument order (chunk, question) is kept
    # from the original code - confirm it matches how the models were trained.
    chunks = [context[i:i + max_length] for i in range(0, len(context), max_length)]

    best_answer = None
    best_score = None

    # One pass is enough: the model runs in eval mode under no_grad, so
    # repeating the same forward pass cannot produce a different result.
    for chunk in chunks:
        inputs = tokenizer(chunk, question, return_tensors='pt', truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            outputs = model(**inputs)

        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        start_index = int(torch.argmax(start_logits))
        end_index = int(torch.argmax(end_logits))
        if start_index > end_index:
            continue  # The predicted end comes before the start: no span here

        span_ids = inputs['input_ids'][0][start_index:end_index + 1].tolist()
        answer = clean_answer(tokenizer.convert_ids_to_tokens(span_ids))
        # Validate answer and ensure it's direct
        if not answer or answer.lower() == "no valid answer found":
            continue

        # Rank candidates from different chunks by summed start/end logits so
        # a later chunk with a stronger span beats an earlier weak one.
        score = float(start_logits[start_index] + end_logits[end_index])
        if best_score is None or score > best_score:
            best_score = score
            # Upper-case only the first character; str.capitalize() would
            # lower-case the rest and damage names and acronyms.
            best_answer = answer[0].upper() + answer[1:]

    if best_answer:
        return best_answer
    return "❌ No valid answer found."
# Define the Gradio interface with light theme and organized layout
def chatbot_interface():
    """Build the Gradio Blocks UI for the question-answering chatbot.

    The layout has a chat panel with a question box on the left and, on the
    right, the context box, the model selector and a status line. Event
    handlers for setting/clearing the context and for answering questions
    are defined and wired up here.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # Maps each radio choice to the function that loads that model.
    loaders = {
        "BERT": load_bert_model_and_tokenizer,
        "BART": load_bart_model_and_tokenizer,
        "DistilBERT": load_distilbert_model_and_tokenizer,
    }

    with gr.Blocks() as demo:
        # Custom CSS for light theme and layout
        gr.Markdown("""
<style>
body { background-color: #f9f9f9; }
.chatbot-container { background-color: #ffffff; border-radius: 10px; padding: 20px; color: #333; font-family: Arial, sans-serif; }
.gr-button { background-color: #4CAF50; color: white; border: none; border-radius: 5px; padding: 10px 20px; font-size: 14px; cursor: pointer; }
.gr-button:hover { background-color: #45a049; }
.gr-textbox { background-color: #ffffff; color: #333; border-radius: 5px; border: 1px solid #ddd; padding: 10px; }
.gr-chatbot { background-color: #e6e6e6; border-radius: 10px; padding: 15px; color: #333; }
.footer { text-align: right; font-size: 12px; color: #777; font-style: italic; }
.note { text-align: right; font-size: 10px; color: #777; font-style: italic; position: absolute; bottom: 10px; right: 10px; }
</style>
""")
        # Header
        gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>EDITH: Multi-Model Question Answering Platform</h1>")
        gr.Markdown("<p style='text-align: center; color: #777;'>Switch between BERT, BART, and DistilBERT models and ask questions based on the context.</p>")

        context_state = gr.State()
        # Not wired to any event below; the radio's value is used directly.
        model_choice_state = gr.State(value="BERT")  # Default model is BERT

        with gr.Row():
            # Left panel for chatbot and question input (11 of 20 parts, ~55%)
            with gr.Column(scale=11):
                chatbot = gr.Chatbot(label="Chatbot")
                question_input = gr.Textbox(label="Ask a Question", placeholder="Enter your question here...", lines=1)
                submit_btn = gr.Button("Submit Question")
            # Right panel for setting context and instructions (9 of 20 parts, ~45%)
            with gr.Column(scale=9):
                context_input = gr.Textbox(label="Set Context", placeholder="Enter the context here...", lines=4)
                set_context_btn = gr.Button("Set Context")
                clear_context_btn = gr.Button("Clear Context")
                # Model selection buttons
                model_selection = gr.Radio(choices=["BERT", "BART", "DistilBERT"], label="Select Model", value="BERT")
                status_message = gr.Markdown("")
                gr.Markdown("<strong>Instructions:</strong><br>1. Set a context.<br>2. Select the model (BERT, BART, or DistilBERT).<br>3. Ask questions based on the context.<br><br><strong>Note:</strong> <span class='note'>The BART model is pre-trained from Hugging Face. Credits to Hugging Face and the person who fine-tuned this model ('valhalla/bart-large-finetuned-squadv1')</span>")
        gr.Markdown("<div class='footer'>Prepared by: Team EDITH</div>")

        def set_context(context):
            """Store the entered context and hide the context box."""
            if not context.strip():
                return gr.update(), "Please enter a valid context.", None
            return gr.update(visible=False), "Context has been set. You can now ask questions.", context

        def clear_context():
            """Forget the stored context and show the context box again."""
            return gr.update(visible=True), "Context has been cleared. Please set a new context.", None

        def handle_question(question, history, context, model_choice):
            """Answer *question* and append the exchange to the chat history.

            Returns the updated history, the new question-box text and the
            status-line text, in that order.
            """
            global tokenizer, model, device, selected_model
            history = history or []  # The chatbot value may be None before the first message

            # Validation messages go to the status line; the question box keeps
            # the user's text instead of being overwritten with an error.
            if not context:
                return history, question, "Please set the context before asking questions."
            if not question or not question.strip():
                return history, question, "Please enter a valid question."

            loader = loaders.get(model_choice)
            if loader is None:
                return history, question, "Please select a model before asking questions."

            # Load the selected model and tokenizer only when the choice has
            # changed, instead of reloading the weights for every question.
            if selected_model != model_choice:
                tokenizer, model, device = loader()
                selected_model = model_choice

            answer = generate_answer(context, question)
            # Show the selected model with the answer
            history = history + [[f"👤: {question}", f"🤖 ({model_choice}): {answer}"]]
            return history, "", ""

        set_context_btn.click(set_context, inputs=context_input, outputs=[context_input, status_message, context_state])
        clear_context_btn.click(clear_context, inputs=None, outputs=[context_input, status_message, context_state])

        question_inputs = [question_input, chatbot, context_state, model_selection]
        question_outputs = [chatbot, question_input, status_message]
        submit_btn.click(handle_question, inputs=question_inputs, outputs=question_outputs)
        # Enable "Enter" key to trigger the "Submit" button
        question_input.submit(handle_question, inputs=question_inputs, outputs=question_outputs)
    return demo
# Run the Gradio interface
# The UI is built and the server started as soon as this module executes.
# share=True is passed through to Gradio's launch(); per its documentation
# this requests a publicly shareable link.
interface = chatbot_interface()
interface.launch(share=True)