!pip install -q gradio transformers torch
import warnings
warnings.filterwarnings('ignore')

# Import necessary libraries
import gradio as gr
import torch
from transformers import (
    BertTokenizerFast,
    BertForQuestionAnswering,
    AutoTokenizer,
    BartForQuestionAnswering,
    DistilBertTokenizerFast,
    DistilBertForQuestionAnswering
)
import gc

# Globals populated by handle_question() once a model is selected;
# clean_answer() and generate_answer() read them directly.
tokenizer = None
model = None
device = None

# Define models and tokenizers
def load_bert_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_save_path = "LivisLiquoro/BERT_Model_Squad1.1"
    model = BertForQuestionAnswering.from_pretrained(model_save_path)
    tokenizer = BertTokenizerFast.from_pretrained(model_save_path)
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device
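
# Note: from_pretrained() downloads weights from the Hugging Face Hub on
# first use and caches them locally, so subsequent loads are much faster.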

def load_bart_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device

def load_distilbert_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_save_path = "LivisLiquoro/DistilBert_model_squad1.1"
    model = DistilBertForQuestionAnswering.from_pretrained(model_save_path)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_save_path)
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device
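
# Reloading a model on every question is slow. An optional in-memory cache
# (a minimal sketch; `_model_cache` and `get_model` are additions, not part
# of the original script) keeps each model resident after its first use:
_model_cache = {}

def get_model(choice):
    if choice not in _model_cache:
        loader = {
            "BERT": load_bert_model_and_tokenizer,
            "BART": load_bart_model_and_tokenizer,
            "DistilBERT": load_distilbert_model_and_tokenizer,
        }[choice]
        _model_cache[choice] = loader()  # caches the (tokenizer, model, device) triple
    return _model_cache[choice]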

def clean_answer(tokens):
    """
    Drop special tokens (e.g. [CLS]/[SEP] for BERT, <s>/</s> for BART) and let
    the tokenizer reassemble the remaining word pieces into a string.
    """
    special = set(tokenizer.all_special_tokens)
    cleaned_tokens = [token for token in tokens if token not in special]
    # convert_tokens_to_string() handles '##' (WordPiece) and 'Ġ' (BPE)
    # fragments itself, so no manual prefix stripping is needed.
    return tokenizer.convert_tokens_to_string(cleaned_tokens).strip() or None

def generate_answer(context, question):
    best_answer = None

    # Chunk long contexts so each piece fits the model's 512-token window.
    # Splitting on characters is only a rough proxy for tokens; see the
    # token-level alternative sketched below.
    max_length = 512
    chunks = [context[i:i + max_length] for i in range(0, len(context), max_length)]

    for chunk in chunks:
        # Question first, then context: the segment order SQuAD-style QA
        # models were fine-tuned with.
        inputs = tokenizer(question, chunk, return_tensors='pt', truncation=True, max_length=max_length).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            answer_start = torch.argmax(outputs.start_logits)
            answer_end = torch.argmax(outputs.end_logits) + 1

        if answer_start < answer_end:
            answer = clean_answer(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
            if answer:
                best_answer = answer.capitalize()
                break  # Stop at the first chunk that yields a valid span

    return best_answer or "❌ No valid answer found."

# Define the Gradio interface with light theme and organized layout
def chatbot_interface():
    with gr.Blocks() as demo:
        # Custom CSS for light theme and layout
        # (gr.Blocks(css=...) is Gradio's documented hook for custom CSS)
        gr.Markdown("""
            <style>
                body { background-color: #f9f9f9; }
                .chatbot-container { background-color: #ffffff; border-radius: 10px; padding: 20px; color: #333; font-family: Arial, sans-serif; }
                .gr-button { background-color: #4CAF50; color: white; border: none; border-radius: 5px; padding: 10px 20px; font-size: 14px; cursor: pointer; }
                .gr-button:hover { background-color: #45a049; }
                .gr-textbox { background-color: #ffffff; color: #333; border-radius: 5px; border: 1px solid #ddd; padding: 10px; }
                .gr-chatbot { background-color: #e6e6e6; border-radius: 10px; padding: 15px; color: #333; }
                .footer { text-align: right; font-size: 12px; color: #777; font-style: italic; }
                .note { text-align: right; font-size: 10px; color: #777; font-style: italic; position: absolute; bottom: 10px; right: 10px; }
            </style>
        """)

        # Header
        gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>EDITH: Multi-Model Question Answering Platform</h1>")
        gr.Markdown("<p style='text-align: center; color: #777;'>Switch between BERT, BART, and DistilBERT models and ask questions based on the context.</p>")

        context_state = gr.State()
        model_choice_state = gr.State(value="BERT")  # Default model is BERT

        with gr.Row():
            with gr.Column(scale=11):  # Left panel for chatbot and question input (55% width)
                chatbot = gr.Chatbot(label="Chatbot")
                question_input = gr.Textbox(label="Ask a Question", placeholder="Enter your question here...", lines=1)
                submit_btn = gr.Button("Submit Question")

            with gr.Column(scale=9):  # Right panel for context and instructions (45% width)
                context_input = gr.Textbox(label="Set Context", placeholder="Enter the context here...", lines=4)
                set_context_btn = gr.Button("Set Context")
                clear_context_btn = gr.Button("Clear Context")

                # Model selection buttons
                model_selection = gr.Radio(choices=["BERT", "BART", "DistilBERT"], label="Select Model", value="BERT")
                status_message = gr.Markdown("")

                gr.Markdown("<strong>Instructions:</strong><br>1. Set a context.<br>2. Select the model (BERT, BART, or DistilBERT).<br>3. Ask questions based on the context.<br><br><strong>Note:</strong> <span class='note'>The BART model is pre-trained from Hugging Face. Credits to Hugging Face and the person who fine-tuned this model ('valhalla/bart-large-finetuned-squadv1')</span>")

        footer = gr.Markdown("<div class='footer'>Prepared by: Team EDITH</div>")

        def set_context(context):
            if not context.strip():
                return gr.update(), "Please enter a valid context.", None
            return gr.update(visible=False), "Context has been set. You can now ask questions.", context

        def clear_context():
            return gr.update(visible=True), "Context has been cleared. Please set a new context.", None

        def handle_question(question, history, context, model_choice):
            global tokenizer, model, device

            if not context:
                return history, "Please set the context before asking questions."
            if not question.strip():
                return history, "Please enter a valid question."

            # Load the selected model and tokenizer (reloaded on every
            # question; the optional get_model() cache above avoids this cost)
            if model_choice == "BERT":
                tokenizer, model, device = load_bert_model_and_tokenizer()
                model_name = "BERT"
            elif model_choice == "BART":
                tokenizer, model, device = load_bart_model_and_tokenizer()
                model_name = "BART"
            elif model_choice == "DistilBERT":
                tokenizer, model, device = load_distilbert_model_and_tokenizer()
                model_name = "DistilBERT"

            answer = generate_answer(context, question)
            history = history + [[f"👤: {question}", f"🤖 ({model_name}): {answer}"]]  # Show the selected model with the answer
            return history, ""

        set_context_btn.click(set_context, inputs=context_input, outputs=[context_input, status_message, context_state])
        clear_context_btn.click(clear_context, inputs=None, outputs=[context_input, status_message, context_state])
        submit_btn.click(handle_question, inputs=[question_input, chatbot, context_state, model_selection], outputs=[chatbot, question_input])

        # Enable "Enter" key to trigger the "Submit" button
        question_input.submit(handle_question, inputs=[question_input, chatbot, context_state, model_selection], outputs=[chatbot, question_input])

    return demo

# Run the Gradio interface
interface = chatbot_interface()
interface.launch(share=True)
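# share=True requests a temporary public URL via Gradio's tunneling service;
# omit it (or pass share=False) to serve only on localhost.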