soarhigh commited on
Commit
877eb8b
·
1 Parent(s): d816cca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -3
app.py CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
2
  import torch
3
  from nextus_regressor_class import *
4
  import nltk
 
5
 
6
  model = NextUsRegressor()
7
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
8
  model.eval()
9
  mask = "[MASKED]"
 
10
  def shap(txt, tok_level):
11
  batch = [txt]
12
  if tok_level == "word":
@@ -26,9 +28,21 @@ def shap(txt, tok_level):
26
  y_offs = model(batch)
27
  shaps = (y_offs - y_pred).tolist() # convert to list and make tuple to be returned
28
  shapss = [s[0] for s in shaps]
29
- labels = ["+" if s < 0 else "-" for s in shapss]
 
 
 
 
 
 
 
 
 
 
 
 
30
  return list(zip(tokens, labels))
31
- #return txt
32
 
33
  demo = gr.Interface(shap,
34
  [
@@ -43,4 +57,4 @@ demo = gr.Interface(shap,
43
  color_map={"+": "red", "-": "green"}),
44
  theme=gr.themes.Base())
45
 
46
- demo.launch()
 
2
  import torch
3
  from nextus_regressor_class import *
4
  import nltk
5
+ from pprint import pprint
6
 
7
  model = NextUsRegressor()
8
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
9
  model.eval()
10
  mask = "[MASKED]"
11
+ threshold = 0.05
12
  def shap(txt, tok_level):
13
  batch = [txt]
14
  if tok_level == "word":
 
28
  y_offs = model(batch)
29
  shaps = (y_offs - y_pred).tolist() # convert to list and make tuple to be returned
30
  shapss = [s[0] for s in shaps]
31
+ labels = list()
32
+ for s in shapss:
33
+ if s <= -1.0*threshold:
34
+ labels.append("+")
35
+ elif s >= threshold:
36
+ labels.append("-")
37
+ else:
38
+ labels.append(None)
39
+ # labels = ["+" if s < -1.0*threshold "-" elif s > threshold else " " for s in shapss]
40
+ # print(len(tokens), len(labels))
41
+ # print(list(zip(tokens, labels)))
42
+ # pprint(list(zip(tokens, shapss)))
43
+ # return str(list(zip(tokens, labels)))
44
  return list(zip(tokens, labels))
45
+ # return txt
46
 
47
  demo = gr.Interface(shap,
48
  [
 
57
  color_map={"+": "red", "-": "green"}),
58
  theme=gr.themes.Base())
59
 
60
+ demo.launch()