soarhigh commited on
Commit
60aa3c2
·
1 Parent(s): dabdfcc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -3
app.py CHANGED
@@ -1,11 +1,36 @@
1
  import gradio as gr
2
  import torch
3
  from nextus_regressor_class import *
 
4
 
5
  model = NextUsRegressor()
6
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
7
  model.eval()
8
 
9
- gr.Interface(fn=model,
10
- inputs=gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요"),
11
- outputs=gr.Number(label="Slant Score")).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from nextus_regressor_class import *
4
+ import nltk
5
 
6
  model = NextUsRegressor()
7
  model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
8
  model.eval()
9
 
10
+ def shap(txt, tok_level):
11
+ batch = [txt]
12
+ if tok_level == "word":
13
+ print("word")
14
+ elif tok_level == "sentence":
15
+ print("sentence")
16
+ else:
17
+ print("this token granularity not supported")
18
+ #tokens = nltk
19
+ with torch.no_grad():
20
+ y_pred = model(txt)
21
+ return y_pred
22
+
23
+ demo = gr.Interface(fn=shap,
24
+ inputs=[
25
+ gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
26
+ gr.Radio(["sentence", "word"], value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
27
+ ],
28
+ outputs=gr.HighlightedText(
29
+ label="Diff",
30
+ combine_adjacent=True,
31
+ show_legend=True,
32
+ color_map={"+": "red", "-": "green"}),
33
+ theme=gr.themes.Base())
34
+
35
+ if __name__ == "__main__":
36
+ demo.launch()