# slant / app.py
# soarhigh's picture
# Update app.py
# 8de1bb4
# raw
# history blame
# 1.7 kB
import gradio as gr
import torch
from nextus_regressor_class import *
import nltk
# Build the regressor and load its trained weights once, at import time.
model = NextUsRegressor()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; load_state_dict then copies the tensors into the model's
# own parameters wherever they live.
# NOTE(review): torch.load unpickles the file -- only load trusted
# checkpoints (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load("./nextus_regressor1012.pt", map_location="cpu"))
# Inference only; presumably NextUsRegressor is a torch.nn.Module, in which
# case this disables training-time behaviour such as dropout.
model.eval()
# NOTE(review): not referenced in the visible code; possibly intended as the
# replacement string for removed tokens in the attribution function.
mask = "[MASKED]"
def shap(txt, tok_level):
    """Explain the regressor's prediction for *txt* by leave-one-out ablation.

    The text is split into tokens (words or sentences, per *tok_level*).
    For each token, the model is re-run on the text with that token
    removed. The sign of the change in the model output says whether the
    token pushed the prediction up or down.

    Args:
        txt: Input text to explain.
        tok_level: Token granularity, either ``"word"`` or ``"sentence"``.

    Returns:
        A list of ``(token, label)`` pairs, one per token, in text order.
        The label is ``"+"`` if removing the token lowers the model output,
        and ``"-"`` otherwise. Returns ``[]`` when *txt* yields no tokens.
        The shape suits a ``gr.HighlightedText`` value.

    Raises:
        ValueError: If *tok_level* is neither ``"word"`` nor ``"sentence"``.
    """
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Previously fell through with `tokens` unbound (UnboundLocalError).
        raise ValueError(f"unsupported token granularity: {tok_level!r}")

    if not tokens:
        # Nothing to attribute; skip the model calls entirely.
        return []

    # Row 0 is the unmodified text. Row i + 1 is the text with tokens[i]
    # removed. The full text is kept in the batch so the ablated rows are
    # scored in the same batch composition as before.
    batch = [txt] + [" ".join(tokens[:i] + tokens[i + 1:]) for i in range(len(tokens))]

    with torch.no_grad():
        y_pred = model(txt)
        y_offs = model(batch)

    # Per-row change in output when a token is removed. Each row is indexed
    # with [0], so the difference is at least 2-D; presumably shape
    # (len(batch), 1) -- confirm against NextUsRegressor.
    diffs = (y_offs - y_pred).tolist()

    # Drop row 0 (the unablated text) so that diffs[k] lines up with
    # tokens[k]. Without this, every label was shifted by one token.
    labels = ["+" if row[0] < 0 else "-" for row in diffs[1:]]
    return list(zip(tokens, labels))
demo = gr.Interface(shap,
[
gr.Textbox(label="๊ธฐ์‚ฌ", lines=30, placeholder="๊ธฐ์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."),
gr.Radio(choices=["sentence", "word"], label="ํ•ด์„ค ํ‘œ์‹œ ๋‹จ์œ„", value="sentence", info="๋ฌธ์žฅ ๋‹จ์œ„์˜ ํ•ด์„ค์€ sentence๋ฅผ, ๋‹จ์–ด ๋‹จ์œ„์˜ ํ•ด์„ค์€ word๋ฅผ ์„ ํƒํ•˜์„ธ์š”.")
],
#gr.Textbox(label="Test Output"),
gr.HighlightedText(
label="Diff",
combine_adjacent=True,
show_legend=True,
color_map={"+": "red", "-": "green"}),
theme=gr.themes.Base())
demo.launch()