djagatiya commited on
Commit
bb9a235
·
1 Parent(s): d4b4f1f

attention viz added.

Browse files
Files changed (1) hide show
  1. app.py +55 -10
app.py CHANGED
@@ -33,29 +33,50 @@ class IMBDModel(DistilBertPreTrainedModel):
33
 
34
  def forward(self, x):
35
 
36
- x = self.distilbert(**x).last_hidden_state
37
- pooled_output = x[:, 0]
 
38
 
39
  x = self.fc(pooled_output)
40
-
41
- return x
42
 
43
 
44
  infer_path = "./model/fold0_epoch01_loss0.1403_val_loss0.1994_roc_auc0.9779/"
45
 
46
- tokenizer = AutoTokenizer.from_pretrained(infer_path)
47
 
48
- pretrained_model = IMBDModel.from_pretrained(infer_path, local_files_only=True)
49
  pretrained_model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def prediction(text):
52
 
53
- tokens = tokenizer(text, padding='max_length', truncation=True, max_length=512)
54
  tokens = {k:torch.tensor([v]) for k, v in tokens.items()}
 
55
 
56
  with torch.no_grad():
57
 
58
- scores = pretrained_model(tokens)
59
  scores = torch.sigmoid(scores).numpy()
60
 
61
  scores = scores[0][0]
@@ -66,14 +87,38 @@ def prediction(text):
66
  label = "Neutral"
67
  else:
68
  label = "Negative"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
- return f"{label} feedback", f"{scores:.2f}"
 
71
 
 
72
 
73
  demo = gr.Interface(
74
  fn=prediction,
75
  inputs=gr.Textbox(lines=5, placeholder="Text to analyze..."),
76
- outputs=["text", "text"]
 
77
  )
78
 
79
  demo.launch(server_name="0.0.0.0")
 
33
 
34
  def forward(self, x):
35
 
36
+ output = self.distilbert(**x)
37
+
38
+ pooled_output = output.last_hidden_state[:, 0]
39
 
40
  x = self.fc(pooled_output)
41
+
42
+ return x, output.attentions
43
 
44
 
45
  infer_path = "./model/fold0_epoch01_loss0.1403_val_loss0.1994_roc_auc0.9779/"
46
 
47
+ pretrained_tokenizer = AutoTokenizer.from_pretrained(infer_path)
48
 
49
+ pretrained_model = IMBDModel.from_pretrained(infer_path, local_files_only=True, output_attentions=True)
50
  pretrained_model.eval()
51
+
52
+ print("Model loaded.")
53
+
54
+ def get_attentions(attentions):
55
+
56
+ # last layer attentions
57
+ layer_layer_att = attentions[-1] # [batch, heads, seq_len, seq_len]
58
+ cls_att = layer_layer_att[:,:,0,:] # attentions of [CLS] token
59
+ cls_att_mean = cls_att.mean(dim=1) # mean over heads
60
+
61
+ cls_att_mean = cls_att_mean[0]
62
+
63
+ # min-max scaled because we are using for opicity (0 - 1)
64
+ cls_att_mean = (cls_att_mean - cls_att_mean.min()) / (cls_att_mean.max() - cls_att_mean.min())
65
+
66
+ return cls_att_mean
67
+
68
+ def wrap_text(word, score):
69
+ return f"<span style='background-color:rgba(0, 0, 255, {score:.2f});padding:2px;'>{word}</span>"
70
 
71
  def prediction(text):
72
 
73
+ tokens = pretrained_tokenizer(text, truncation=True, max_length=512)
74
  tokens = {k:torch.tensor([v]) for k, v in tokens.items()}
75
+ word_tokens= pretrained_tokenizer.convert_ids_to_tokens(tokens['input_ids'][0])
76
 
77
  with torch.no_grad():
78
 
79
+ scores, attentions = pretrained_model(tokens)
80
  scores = torch.sigmoid(scores).numpy()
81
 
82
  scores = scores[0][0]
 
87
  label = "Neutral"
88
  else:
89
  label = "Negative"
90
+
91
+ att_op = get_attentions(attentions)
92
+
93
+ html = "".join([wrap_text(w,s) for w,s in zip(word_tokens, att_op)])
94
+ html = f"<p style='word-wrap: break-word;'>{html}</p>"
95
+
96
+ return f"{label} feedback", f"{scores:.2f}", html
97
+
98
+ examples = [
99
+
100
+ """
101
+ Infinity war is one of the best MCU protects. It has a great story, great acting, and awesome looking. If you aren't a Marvel fan or haven't watched most of the previous MCU movies this however, won't be something for you. Let's start with Thanos, definitely one of the best villains, he has a motive, is well played, you can even say that Infinity war tells his story and not the story of a hero. But also most of the other cast members were great in their role and again, if you love Marvel, watch this movie.
102
+ """
103
+ ,
104
+
105
+ """
106
+ This is truly bottom of the barrel stuff. Nobody asked for this show but it was shoved on us hapless souls anyway.
107
+
108
+ Walters came across as obnoxious, vacuous and full of herself. There was no effort made at all towards character development in the first episode which looks like a poorly crafted music video from the 90s.
109
+
110
+ The first episode obviously serves as a placeholder for more shlock. This show is trying to be sex and the city with a pg13 rating and female Shrek.
111
 
112
+ Marvel's idea of setting up strong female characters with exposition dumps instead of focusing on having them go through a journey to realize their true potential.
113
+ """
114
 
115
+ ]
116
 
117
  demo = gr.Interface(
118
  fn=prediction,
119
  inputs=gr.Textbox(lines=5, placeholder="Text to analyze..."),
120
+ outputs=["text", "text", "html"],
121
+ examples=examples
122
  )
123
 
124
  demo.launch(server_name="0.0.0.0")