slant / app.py
soarhigh's picture
module name fix
f65da74
raw
history blame
3.18 kB
import io
import gradio as gr
import torch
from nextus_regressor_class import *
import nltk
from pprint import pprint
import pandas as pd
# Load the regressor once at import time and switch it to inference mode.
# map_location="cpu" lets a checkpoint that was saved on a GPU machine load on
# a CPU-only host; load_state_dict copies the tensors into the model's own
# parameters, so deserializing them on CPU first is always safe.
model = NextUsRegressor()
model.load_state_dict(torch.load("./nextus_regressor1012.pt", map_location="cpu"))
model.eval()
# Placeholder token string; not referenced by the code visible in this file.
mask = "[MASKED]"
# Minimum absolute score change for a token to be labelled "+" or "-" in shap().
threshold = 0.05
def shap(txt, tok_level):
    """Explain the model's slant score for ``txt`` by leave-one-out ablation.

    Each token (word or sentence, depending on ``tok_level``) is removed in
    turn, the remaining tokens are re-joined with single spaces and re-scored,
    and the change relative to the full-text score is recorded. This is a
    leave-one-out attribution, not an exact Shapley value.

    Args:
        txt: The input article text.
        tok_level: ``"word"`` to split with ``nltk.word_tokenize`` or
            ``"sentence"`` to split with ``nltk.sent_tokenize``.

    Returns:
        A ``(highlights, explanation)`` tuple. ``highlights`` is a list of
        ``(token, label)`` pairs: ``label`` is ``"+"`` when removing the token
        lowers the score by at least ``threshold``, ``"-"`` when removing it
        raises the score by at least ``threshold``, and ``None`` otherwise.
        ``explanation`` is a Korean message naming the token whose removal
        changes the score the most, and by how much. For input that yields
        no tokens, ``([], "")`` is returned.

    Raises:
        ValueError: If ``tok_level`` is neither ``"word"`` nor ``"sentence"``.
    """
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Previously this fell through and crashed later with a NameError.
        raise ValueError(f"unsupported token granularity: {tok_level!r}")
    if not tokens:
        # Nothing to ablate (empty or whitespace-only input).
        return [], ""

    # Row 0 is the full text; row i + 1 is the text with tokens[i] removed.
    batch = [txt]
    batch.extend(" ".join(tokens[:i] + tokens[i + 1:]) for i in range(len(tokens)))

    with torch.no_grad():
        y_pred = model(txt)
        y_offs = model(batch)

    # Drop row 0 (the unablated text) so that shapss[i] lines up with tokens[i];
    # without this, every label was paired with the wrong token.
    diffs = (y_offs - y_pred)[1:]
    shapss = [row[0] for row in diffs.tolist()]

    labels = []
    for s in shapss:
        if s <= -threshold:
            # Removing the token lowers the score: the token pushes it up.
            labels.append("+")
        elif s >= threshold:
            # Removing the token raises the score: the token pushes it down.
            labels.append("-")
        else:
            labels.append(None)

    # Debug output of the raw per-token score changes.
    pprint(list(zip(tokens, shapss)))

    # "Largest influence" means largest change in either direction.
    top = max(range(len(shapss)), key=lambda k: abs(shapss[k]))
    largest_shap = shapss[top]
    largest_shap_span = tokens[top]
    # "The text with the largest influence is '<span>'; without it, the Slant
    # score deviates from <score> by <delta>."
    explanation = (
        f"가장 큰 영향을 미친 텍스트는\n'{largest_shap_span}'\n"
        f"이며, 해당 텍스트가 없을 경우 Slant 스코어\n{round(y_pred.item(), 4)}"
        f"\n에서\n{round(largest_shap, 4)}\n만큼 벗어납니다."
    )
    return list(zip(tokens, labels)), explanation
def _read_articles(path):
    """Return the first column of a CSV or Excel file as a list of strings.

    Empty cells are dropped and the remaining cells are converted to ``str``.

    Raises:
        ValueError: If ``path`` does not end in ``.csv``, ``.xls`` or ``.xlsx``
            (case-insensitive).
    """
    lowered = str(path).lower()
    if lowered.endswith(".csv"):
        frame = pd.read_csv(path)
    elif lowered.endswith((".xls", ".xlsx")):
        frame = pd.read_excel(path)
    else:
        raise ValueError(f"unsupported file type (expected .csv, .xls or .xlsx): {path}")
    return frame.iloc[:, 0].dropna().astype(str).to_list()


def parse_file_input(f):
    """Score every article in an uploaded CSV/Excel file with the regressor.

    Args:
        f: The uploaded file. Either an object exposing the file path as
            ``.name`` (what ``gr.File(type="file")`` passes) or a plain path.

    Returns:
        The model output for the list of articles read from the file's first
        column, computed without gradient tracking.

    Raises:
        ValueError: If no file was given, the file type is unsupported, or the
            file contains no articles.
    """
    if f is None:
        raise ValueError("no file uploaded")
    path = getattr(f, "name", f)
    articles = _read_articles(path)
    if not articles:
        raise ValueError(f"no articles found in the first column of {path}")
    # Inference only: no_grad avoids building an autograd graph (and a
    # grad_fn in the displayed tensor).
    with torch.no_grad():
        scores = model(articles)
    return scores
# Gradio UI: upload a single CSV/Excel file of articles (first column) and
# display the model's slant scores for them.
# NOTE: the per-article explanation UI built on shap() is currently disabled.
# It used a gr.Textbox for the article plus a
# gr.Radio(choices=["sentence", "word"], value="sentence") for the token
# granularity as inputs, and a gr.HighlightedText(combine_adjacent=True,
# show_legend=True, color_map={"+": "red", "-": "green"}) as output; shap()
# also returns an explanation string, which needs a second output component.
demo = gr.Interface(
    parse_file_input,
    [
        gr.File(
            file_count="single",
            file_types=[".csv", ".xls", ".xlsx"],
            type="file",
            # "Upload an article file (csv/excel)"
            label="기사 파일(csv/excel)을 업로드하세요",
        ),
    ],
    gr.Textbox(label="Slant Scores"),
    theme=gr.themes.Base(),
)
demo.launch()