|
import gradio as gr |
|
import torch |
|
from nextus_regressor_class import * |
|
import nltk |
|
|
|
# Build the regressor once at import time and load its trained weights.
model = NextUsRegressor()

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# a CPU-only host. load_state_dict copies the values into the model's own
# parameters, so staging the tensors on CPU is safe whatever device the model
# itself lives on.
model.load_state_dict(torch.load("./nextus_regressor1012.pt", map_location="cpu"))

# Inference mode (standard nn.Module behaviour: disables dropout and switches
# batch-norm to running statistics, if the model has such layers).
model.eval()

# Placeholder text for a masked token.
# NOTE(review): not referenced anywhere in the visible code -- confirm whether
# it is still needed (presumably intended for mask-instead-of-delete occlusion).
mask = "[MASKED]"
|
def shap(txt, tok_level):
    """Label each token of *txt* by leave-one-out occlusion.

    Despite the name, this does not compute Shapley values. Each token (word
    or sentence) is removed in turn, the model is re-run on the shortened
    text, and the token is labelled by the sign of the change in the output.

    Args:
        txt: Input text (the article entered in the UI).
        tok_level: ``"word"`` to split with ``nltk.word_tokenize`` or
            ``"sentence"`` to split with ``nltk.sent_tokenize``.

    Returns:
        A list of ``(token, label)`` pairs for ``gr.HighlightedText``.
        ``label`` is ``"+"`` when removing the token lowers the model's
        output (the token pushes the score up) and ``"-"`` otherwise.
        Each token carries a trailing space so adjacent highlighted spans
        do not run together when rendered.

    Raises:
        ValueError: If ``tok_level`` is neither ``"word"`` nor ``"sentence"``.
    """
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Without this, ``tokens`` would be unbound below and the call would
        # die with an opaque UnboundLocalError.
        raise ValueError(
            f"tok_level must be 'word' or 'sentence', got {tok_level!r}"
        )

    if not tokens:
        # Nothing to attribute (empty or whitespace-only input); skip the model.
        return []

    # Row 0 is the unmodified text; row i + 1 is the text with token i removed.
    batch = [txt] + [
        " ".join(tokens[:i] + tokens[i + 1:]) for i, _ in enumerate(tokens)
    ]

    with torch.no_grad():
        preds = model(batch)

    # Assumes ``model`` returns a (batch, k) tensor whose column 0 is the score
    # (inferred from the original ``s[0]`` indexing) -- TODO confirm.
    # Compare each ablated row only against row 0 of the same batch, so the
    # delta at index i belongs to the row that removed token i.
    deltas = (preds[1:, 0] - preds[0, 0]).tolist()

    labels = ["+" if d < 0 else "-" for d in deltas]
    return [(tok + " ", label) for tok, label in zip(tokens, labels)]
|
|
|
demo = gr.Interface(shap, |
|
[ |
|
gr.Textbox(label="๊ธฐ์ฌ", lines=30, placeholder="๊ธฐ์ฌ๋ฅผ ์
๋ ฅํ์ธ์."), |
|
gr.Radio(choices=["sentence", "word"], label="ํด์ค ํ์ ๋จ์", value="sentence", info="๋ฌธ์ฅ ๋จ์์ ํด์ค์ sentence๋ฅผ, ๋จ์ด ๋จ์์ ํด์ค์ word๋ฅผ ์ ํํ์ธ์.") |
|
], |
|
|
|
gr.HighlightedText( |
|
label="Diff", |
|
combine_adjacent=True, |
|
show_legend=True, |
|
color_map={"+": "red", "-": "green"}), |
|
|
|
theme=gr.themes.Base()) |
|
|
|
demo.launch() |