Update app.py
Browse files
app.py
CHANGED
@@ -24,21 +24,22 @@ def shap(txt, tok_level):
|
|
24 |
with torch.no_grad():
|
25 |
y_pred = model(txt)
|
26 |
y_offs = model(batch)
|
27 |
-
shaps = (y_offs - y_pred).tolist()
|
28 |
-
|
29 |
-
|
|
|
30 |
|
31 |
demo = gr.Interface(shap,
|
32 |
[
|
33 |
gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
|
34 |
gr.Radio(choices=["sentence", "word"], label="해설 표시 단위", value="sentence", info="문장 단위의 해설은 sentence를, 단어 단위의 해설은 word를 선택하세요.")
|
35 |
],
|
36 |
-
gr.Textbox(label="Test Output"),
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
|
43 |
theme=gr.themes.Base())
|
44 |
|
|
|
24 |
with torch.no_grad():
|
25 |
y_pred = model(txt)
|
26 |
y_offs = model(batch)
|
27 |
+
shaps = (y_offs - y_pred).tolist() # convert to list and make tuple to be returned
|
28 |
+
shapss = [s[0] for s in shaps]
|
29 |
+
labels = ["+" if s < 0 else "-" for s in shapss]
|
30 |
+
return list(zip(tokens, labels))
|
31 |
|
32 |
demo = gr.Interface(shap,
|
33 |
[
|
34 |
gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
|
35 |
gr.Radio(choices=["sentence", "word"], label="해설 표시 단위", value="sentence", info="문장 단위의 해설은 sentence를, 단어 단위의 해설은 word를 선택하세요.")
|
36 |
],
|
37 |
+
#gr.Textbox(label="Test Output"),
|
38 |
+
gr.HighlightedText(
|
39 |
+
label="Diff",
|
40 |
+
combine_adjacent=True,
|
41 |
+
show_legend=True,
|
42 |
+
color_map={"+": "red", "-": "green"}),
|
43 |
|
44 |
theme=gr.themes.Base())
|
45 |
|