Elrmnd commited on
Commit
486c8a9
1 Parent(s): 0ed36a0

Create App.py

Browse files
Files changed (1) hide show
  1. App.py +59 -0
App.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import BertTokenizer, BertForMaskedLM
4
+
5
+ # Load the fine-tuned BERT model
6
+ model_name = "/content/fine_tuned_bert_model"
7
+ tokenizer = BertTokenizer.from_pretrained(model_name)
8
+ model = BertForMaskedLM.from_pretrained(model_name)
9
+ model.to("cuda" if torch.cuda.is_available() else "cpu")
10
+
11
+ # Function to answer questions using the fine-tuned model
12
+ def answer_question(context, question):
13
+ # Preprocess the context and question
14
+ context_tokens = tokenizer(context, truncation=True, padding="max_length", max_length=128, return_tensors="pt")
15
+ question_tokens = tokenizer(question, truncation=True, padding="max_length", max_length=16, return_tensors="pt")
16
+
17
+ # Move tensors to device
18
+ context_tokens = context_tokens.to(model.device)
19
+ question_tokens = question_tokens.to(model.device)
20
+
21
+ with torch.no_grad():
22
+ # Generate masked LM predictions for each token in the question
23
+ outputs = model(**question_tokens)
24
+ predictions = torch.argmax(outputs.logits, dim=-1)
25
+
26
+ # Replace masked tokens in the question with predicted tokens
27
+ answer_tokens = []
28
+ for i in range(len(question_tokens["input_ids"][0])):
29
+ if question_tokens["input_ids"][0][i] == tokenizer.mask_token_id:
30
+ answer_tokens.append(predictions[0][i].item())
31
+ else:
32
+ answer_tokens.append(question_tokens["input_ids"][0][i].item())
33
+
34
+ # Decode tokens and remove special tokens
35
+ answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
36
+
37
+ # Return the answer
38
+ return answer
39
+
40
+ # Define example questions
41
+ examples = [
42
+ ["Where did the Enron scandal occur?", "The Enron scandal occurred in [MASK]."],
43
+ ["What was the outcome of the Enron scandal?", "The outcome of the Enron scandal was [MASK]."],
44
+ ["When did Enron file for bankruptcy?", "Enron filed for bankruptcy in [MASK]."],
45
+ ["How did Enron's stock price change during the scandal?", "During the Enron scandal, Enron's stock price [MASK]."]
46
+ ]
47
+
48
+ # Gradio interface with examples
49
+ iface = gr.Interface(
50
+ fn=answer_question,
51
+ inputs=["text", "text"],
52
+ outputs="text",
53
+ title="Enron Email Analysis",
54
+ description="Ask questions about the Enron email dataset using a fine-tuned BERT model.",
55
+ examples=examples
56
+ )
57
+
58
+ # Launch the Gradio interface
59
+ iface.launch(share=True)