Update app.py
Browse files
app.py
CHANGED
@@ -1,11 +1,36 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from nextus_regressor_class import *
|
|
|
4 |
|
5 |
model = NextUsRegressor()
|
6 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
7 |
model.eval()
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
from nextus_regressor_class import *
|
4 |
+
import nltk
|
5 |
|
6 |
model = NextUsRegressor()
|
7 |
model.load_state_dict(torch.load("./nextus_regressor1012.pt"))
|
8 |
model.eval()
|
9 |
|
10 |
+
def shap(txt, tok_level):
|
11 |
+
batch = [txt]
|
12 |
+
if tok_level == "word":
|
13 |
+
print("word")
|
14 |
+
elif tok_level == "sentence":
|
15 |
+
print("sentence")
|
16 |
+
else:
|
17 |
+
print("this token granularity not supported")
|
18 |
+
#tokens = nltk
|
19 |
+
with torch.no_grad():
|
20 |
+
y_pred = model(txt)
|
21 |
+
return y_pred
|
22 |
+
|
23 |
+
demo = gr.Interface(fn=shap,
|
24 |
+
inputs=[
|
25 |
+
gr.Textbox(label="기사", lines=30, placeholder="기사를 입력하세요."),
|
26 |
+
gr.Radio(["sentence", "word"], value="sentence", info="문장 단위의 해설은 sentence를 단어 단위의 해설은 word를 선택하세요.")
|
27 |
+
],
|
28 |
+
outputs=gr.HighlightedText(
|
29 |
+
label="Diff",
|
30 |
+
combine_adjacent=True,
|
31 |
+
show_legend=True,
|
32 |
+
color_map={"+": "red", "-": "green"}),
|
33 |
+
theme=gr.themes.Base())
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
demo.launch()
|