"""Gradio demo that fills [MASK] tokens with a fine-tuned BERT masked-LM.

The user enters an optional context passage and a cloze-style question
containing one or more [MASK] tokens; the app returns the question with
each mask replaced by the model's top-1 prediction.
"""

import gradio as gr
import torch
from transformers import BertForMaskedLM, BertTokenizer

# Load the fine-tuned BERT model once, at import time.
model_name = "fine_tuned_bert_model"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()  # make inference mode explicit (dropout off)


def _encode(context, question):
    """Tokenize *question* as segment A and, if given, *context* as segment B.

    The question comes first so its [MASK] tokens survive truncation: only
    the context is shortened when the pair exceeds the model's window.
    Returns a batch-of-one ``BatchEncoding`` already moved to the model's
    device. No padding is applied (a single example needs none).
    """
    # Also cap at the position-embedding size, in case the saved tokenizer
    # config does not carry a real model_max_length.
    max_length = min(tokenizer.model_max_length, model.config.max_position_embeddings)
    if context:
        encoded = tokenizer(
            question,
            context,
            truncation="only_second",
            max_length=max_length,
            return_tensors="pt",
        )
        # "only_second" cannot shorten the question itself; if the pair still
        # does not fit, fall through and encode the question on its own.
        if encoded["input_ids"].shape[-1] <= max_length:
            return encoded.to(model.device)
    encoded = tokenizer(
        question, truncation=True, max_length=max_length, return_tensors="pt"
    )
    return encoded.to(model.device)


def answer_question(context, question):
    """Fill each [MASK] in *question*, letting the model attend to *context*.

    Args:
        context: Optional supporting text (may be empty or None). When
            non-blank it is encoded as the second BERT segment.
        question: Cloze-style prompt that should contain at least one
            ``tokenizer.mask_token``.

    Returns:
        The question with every mask replaced by the highest-scoring token
        (special tokens stripped), or a short hint if the question contains
        no mask token.
    """
    context = (context or "").strip()
    encoded = _encode(context, question or "")

    input_ids = encoded["input_ids"][0]
    # Segment A (token_type_id 0) is [CLS] question [SEP]; the context, if
    # present, is segment B. Only masks inside the question are filled/returned.
    in_question = encoded["token_type_ids"][0] == 0
    is_mask = (input_ids == tokenizer.mask_token_id) & in_question
    if not is_mask.any():
        return f"Please include at least one {tokenizer.mask_token} token in the question."

    with torch.no_grad():
        logits = model(**encoded).logits[0]
    # Top-1 prediction at mask positions; original tokens everywhere else.
    filled = torch.where(is_mask, logits.argmax(dim=-1), input_ids)
    return tokenizer.decode(filled[in_question].tolist(), skip_special_tokens=True)


# Example rows are [context, question]; the question carries the [MASK] to fill.
examples = [
    ["Where did the Enron scandal occur?", "The Enron scandal occurred in [MASK]."],
    ["What was the outcome of the Enron scandal?", "The outcome of the Enron scandal was [MASK]."],
    ["When did Enron file for bankruptcy?", "Enron filed for bankruptcy in [MASK]."],
    ["How did Enron's stock price change during the scandal?", "During the Enron scandal, Enron's stock price [MASK]."],
]

# Gradio interface with examples
iface = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Textbox(label="Context"),
        gr.Textbox(label="Question (must contain [MASK])"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Enron Email Analysis",
    description="Ask questions about the Enron email dataset using a fine-tuned BERT model.",
    examples=examples,
)

if __name__ == "__main__":
    # Launch only when run as a script, not when this module is imported.
    iface.launch(share=True)