Spaces:
Build error
Build error
attention viz added.
Browse files
app.py
CHANGED
@@ -33,29 +33,50 @@ class IMBDModel(DistilBertPreTrainedModel):
|
|
33 |
|
34 |
def forward(self, x):
|
35 |
|
36 |
-
|
37 |
-
|
|
|
38 |
|
39 |
x = self.fc(pooled_output)
|
40 |
-
|
41 |
-
return x
|
42 |
|
43 |
|
44 |
infer_path = "./model/fold0_epoch01_loss0.1403_val_loss0.1994_roc_auc0.9779/"
|
45 |
|
46 |
-
|
47 |
|
48 |
-
pretrained_model = IMBDModel.from_pretrained(infer_path, local_files_only=True)
|
49 |
pretrained_model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
def prediction(text):
|
52 |
|
53 |
-
tokens =
|
54 |
tokens = {k:torch.tensor([v]) for k, v in tokens.items()}
|
|
|
55 |
|
56 |
with torch.no_grad():
|
57 |
|
58 |
-
scores = pretrained_model(tokens)
|
59 |
scores = torch.sigmoid(scores).numpy()
|
60 |
|
61 |
scores = scores[0][0]
|
@@ -66,14 +87,38 @@ def prediction(text):
|
|
66 |
label = "Neutral"
|
67 |
else:
|
68 |
label = "Negative"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
|
|
|
71 |
|
|
|
72 |
|
73 |
demo = gr.Interface(
|
74 |
fn=prediction,
|
75 |
inputs=gr.Textbox(lines=5, placeholder="Text to analyze..."),
|
76 |
-
outputs=["text", "text"]
|
|
|
77 |
)
|
78 |
|
79 |
demo.launch(server_name="0.0.0.0")
|
|
|
33 |
|
34 |
def forward(self, x):
    """Encode the tokenized batch and classify from the first token position.

    ``x`` is a mapping that is unpacked as keyword arguments into
    ``self.distilbert``.

    Returns a ``(logits, attentions)`` pair: ``logits`` is ``self.fc``
    applied to the last hidden state at sequence position 0, and
    ``attentions`` is the ``attentions`` attribute of the backbone output.
    """
    backbone_out = self.distilbert(**x)

    # Sequence position 0 of the last hidden state serves as the pooled
    # representation of the whole input.
    first_token_state = backbone_out.last_hidden_state[:, 0]

    logits = self.fc(first_token_state)

    return logits, backbone_out.attentions
|
43 |
|
44 |
|
45 |
# Checkpoint directory; both the tokenizer and the model are loaded from it.
infer_path = "./model/fold0_epoch01_loss0.1403_val_loss0.1994_roc_auc0.9779/"

# local_files_only=True is passed to both loads so neither one attempts a
# Hub lookup for this on-disk checkpoint (previously only the model load
# set it).
pretrained_tokenizer = AutoTokenizer.from_pretrained(infer_path, local_files_only=True)

# output_attentions=True is required by forward(), which returns
# output.attentions for the attention visualization.
pretrained_model = IMBDModel.from_pretrained(infer_path, local_files_only=True, output_attentions=True)
pretrained_model.eval()

print("Model loaded.")
|
53 |
+
|
54 |
+
def get_attentions(attentions):
    """Return min-max scaled [CLS] attention weights for the first batch item.

    Args:
        attentions: Sequence of per-layer attention tensors; only the last
            entry is used. Each is indexed as
            [batch, heads, seq_len, seq_len].

    Returns:
        A 1-D tensor with one value per token in [0, 1], suitable for use
        as a CSS opacity. If every token receives the same averaged
        attention, all values are 0 instead of NaN.
    """
    # Last layer attentions.
    last_layer_att = attentions[-1]  # [batch, heads, seq_len, seq_len]
    cls_att = last_layer_att[:, :, 0, :]  # attentions of the [CLS] token
    cls_att_mean = cls_att.mean(dim=1)  # mean over heads

    cls_att_mean = cls_att_mean[0]

    # Min-max scaled because the values are used as opacity (0 - 1).
    shifted = cls_att_mean - cls_att_mean.min()
    span = shifted.max()
    if span == 0:
        # Flat attention: dividing by zero would yield NaN opacities, so
        # return the all-zero shifted tensor.
        return shifted

    return shifted / span
|
67 |
+
|
68 |
+
def wrap_text(word, score):
    """Wrap a token in a ``<span>`` whose blue background opacity is ``score``.

    Args:
        word: Token text to display. It is HTML-escaped so characters such
            as ``<``, ``>`` and ``&`` are shown literally instead of being
            parsed as markup.
        score: Opacity value, formatted with two decimals into the
            ``rgba`` alpha channel.

    Returns:
        The HTML snippet for this token as a string.
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    return f"<span style='background-color:rgba(0, 0, 255, {score:.2f});padding:2px;'>{escape(str(word))}</span>"
|
70 |
|
71 |
def prediction(text):
|
72 |
|
73 |
+
tokens = pretrained_tokenizer(text, truncation=True, max_length=512)
|
74 |
tokens = {k:torch.tensor([v]) for k, v in tokens.items()}
|
75 |
+
word_tokens= pretrained_tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
|
76 |
|
77 |
with torch.no_grad():
|
78 |
|
79 |
+
scores, attentions = pretrained_model(tokens)
|
80 |
scores = torch.sigmoid(scores).numpy()
|
81 |
|
82 |
scores = scores[0][0]
|
|
|
87 |
label = "Neutral"
|
88 |
else:
|
89 |
label = "Negative"
|
90 |
+
|
91 |
+
att_op = get_attentions(attentions)
|
92 |
+
|
93 |
+
html = "".join([wrap_text(w,s) for w,s in zip(word_tokens, att_op)])
|
94 |
+
html = f"<p style='word-wrap: break-word;'>{html}</p>"
|
95 |
+
|
96 |
+
return f"{label} feedback", f"{scores:.2f}", html
|
97 |
+
|
98 |
+
examples = [
|
99 |
+
|
100 |
+
"""
|
101 |
+
Infinity war is one of the best MCU protects. It has a great story, great acting, and awesome looking. If you aren't a Marvel fan or haven't watched most of the previous MCU movies this however, won't be something for you. Let's start with Thanos, definitely one of the best villains, he has a motive, is well played, you can even say that Infinity war tells his story and not the story of a hero. But also most of the other cast members were great in their role and again, if you love Marvel, watch this movie.
|
102 |
+
"""
|
103 |
+
,
|
104 |
+
|
105 |
+
"""
|
106 |
+
This is truly bottom of the barrel stuff. Nobody asked for this show but it was shoved on us hapless souls anyway.
|
107 |
+
|
108 |
+
Walters came across as obnoxious, vacuous and full of herself. There was no effort made at all towards character development in the first episode which looks like a poorly crafted music video from the 90s.
|
109 |
+
|
110 |
+
The first episode obviously serves as a placeholder for more shlock. This show is trying to be sex and the city with a pg13 rating and female Shrek.
|
111 |
|
112 |
+
Marvel's idea of setting up strong female characters with exposition dumps instead of focusing on having them go through a journey to realize their true potential.
|
113 |
+
"""
|
114 |
|
115 |
+
]
|
116 |
|
117 |
demo = gr.Interface(
|
118 |
fn=prediction,
|
119 |
inputs=gr.Textbox(lines=5, placeholder="Text to analyze..."),
|
120 |
+
outputs=["text", "text", "html"],
|
121 |
+
examples=examples
|
122 |
)
|
123 |
|
124 |
demo.launch(server_name="0.0.0.0")
|