Update app.py
Browse files
app.py
CHANGED
@@ -2,11 +2,13 @@ import gradio as gr
|
|
2 |
import torch
|
3 |
from nextus_regressor_class import *
|
4 |
import nltk
|
|
|
5 |
|
6 |
model = NextUsRegressor()
|
7 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
8 |
model.eval()
|
9 |
mask = "[MASKED]"
|
|
|
10 |
def shap(txt, tok_level):
|
11 |
batch = [txt]
|
12 |
if tok_level == "word":
|
@@ -26,9 +28,21 @@ def shap(txt, tok_level):
|
|
26 |
y_offs = model(batch)
|
27 |
shaps = (y_offs - y_pred).tolist() # convert to list and make tuple to be returned
|
28 |
shapss = [s[0] for s in shaps]
|
29 |
-
labels =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
return list(zip(tokens, labels))
|
31 |
-
#return txt
|
32 |
|
33 |
demo = gr.Interface(shap,
|
34 |
[
|
@@ -43,4 +57,4 @@ demo = gr.Interface(shap,
|
|
43 |
color_map={"+": "red", "-": "green"}),
|
44 |
theme=gr.themes.Base())
|
45 |
|
46 |
-
demo.launch()
|
|
|
2 |
import torch
|
3 |
from nextus_regressor_class import *
|
4 |
import nltk
|
5 |
+
from pprint import pprint
|
6 |
|
7 |
model = NextUsRegressor()
|
8 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
9 |
model.eval()
|
10 |
mask = "[MASKED]"
|
11 |
+
threshold = 0.05
|
12 |
def shap(txt, tok_level):
|
13 |
batch = [txt]
|
14 |
if tok_level == "word":
|
|
|
28 |
y_offs = model(batch)
|
29 |
shaps = (y_offs - y_pred).tolist() # convert to list and make tuple to be returned
|
30 |
shapss = [s[0] for s in shaps]
|
31 |
+
labels = list()
|
32 |
+
for s in shapss:
|
33 |
+
if s <= -1.0*threshold:
|
34 |
+
labels.append("+")
|
35 |
+
elif s >= threshold:
|
36 |
+
labels.append("-")
|
37 |
+
else:
|
38 |
+
labels.append(None)
|
39 |
+
# labels = ["+" if s < -1.0*threshold "-" elif s > threshold else " " for s in shapss]
|
40 |
+
# print(len(tokens), len(labels))
|
41 |
+
# print(list(zip(tokens, labels)))
|
42 |
+
# pprint(list(zip(tokens, shapss)))
|
43 |
+
# return str(list(zip(tokens, labels)))
|
44 |
return list(zip(tokens, labels))
|
45 |
+
# return txt
|
46 |
|
47 |
demo = gr.Interface(shap,
|
48 |
[
|
|
|
57 |
color_map={"+": "red", "-": "green"}),
|
58 |
theme=gr.themes.Base())
|
59 |
|
60 |
+
demo.launch()
|