File size: 1,698 Bytes
fc44eab
 
 
60aa3c2
fc44eab
 
6d3ec42
fc44eab
e1e2421
60aa3c2
 
 
e1e2421
 
60aa3c2
e1e2421
 
60aa3c2
e1e2421
 
60aa3c2
e1e2421
 
60aa3c2
 
e1e2421
8de1bb4
 
 
 
60aa3c2
b9a8e24
a760798
 
aa6ef25
a760798
8de1bb4
 
 
 
 
 
f74875b
 
60aa3c2
9966196
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import torch
from nextus_regressor_class import *
import nltk

# Build the regressor and load its trained weights from a checkpoint path
# relative to the current working directory.
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only host. The freshly constructed model is never moved to another device
# here, so its parameters are presumably on CPU already; load_state_dict copies
# the loaded tensors into them either way.
model = NextUsRegressor()
model.load_state_dict(torch.load("./nextus_regressor1012.pt", map_location="cpu"))
model.eval()  # switch to inference behaviour (e.g. dropout disabled)
mask = "[MASKED]"  # placeholder string; not referenced by the code visible here
def shap(txt, tok_level):
    """Attribute the regressor's prediction to tokens by leave-one-out ablation.

    The text is split into words or sentences. For each token, the model is
    re-run on the text with that token removed. The sign of the change relative
    to the prediction on the full text decides the token's label.

    Args:
        txt: Input text to explain.
        tok_level: Token granularity, either ``"word"`` or ``"sentence"``.

    Returns:
        A list of ``(token, label)`` pairs suitable for ``gr.HighlightedText``.
        The label is ``"+"`` when removing the token lowers the prediction
        (the token pushes the score up) and ``"-"`` otherwise. An input that
        yields no tokens returns ``[]``.

    Raises:
        ValueError: If ``tok_level`` is neither ``"word"`` nor ``"sentence"``.
    """
    if tok_level == "word":
        tokens = nltk.word_tokenize(txt)
    elif tok_level == "sentence":
        tokens = nltk.sent_tokenize(txt)
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(
            f"unsupported tok_level {tok_level!r}; expected 'word' or 'sentence'"
        )
    if not tokens:
        # Nothing to attribute (e.g. empty input), so skip the model entirely.
        return []

    # Row 0 is the unmodified text and row i + 1 drops token i. The full text
    # is kept at row 0 so the perturbed rows are scored in the same batch as
    # before.
    batch = [txt]
    batch.extend(
        " ".join(tok for j, tok in enumerate(tokens) if j != i)
        for i in range(len(tokens))
    )
    with torch.no_grad():
        y_pred = model(txt)
        y_offs = model(batch)
        # Each row of the difference is indexed with [0] below, so the model is
        # assumed to emit one score per text (shape (N, 1)) -- TODO confirm.
        deltas = (y_offs - y_pred).tolist()
    # Drop row 0 (full text vs. itself) so deltas[i] lines up with tokens[i].
    # Zipping without the slice shifted every label by one token.
    labels = ["+" if d[0] < 0 else "-" for d in deltas[1:]]
    return list(zip(tokens, labels))

demo = gr.Interface(shap,
                    [
                         gr.Textbox(label="๊ธฐ์‚ฌ", lines=30, placeholder="๊ธฐ์‚ฌ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”."),
                         gr.Radio(choices=["sentence", "word"], label="ํ•ด์„ค ํ‘œ์‹œ ๋‹จ์œ„", value="sentence", info="๋ฌธ์žฅ ๋‹จ์œ„์˜ ํ•ด์„ค์€ sentence๋ฅผ, ๋‹จ์–ด ๋‹จ์œ„์˜ ํ•ด์„ค์€ word๋ฅผ ์„ ํƒํ•˜์„ธ์š”.")
                    ],
                    #gr.Textbox(label="Test Output"),
                    gr.HighlightedText(
                        label="Diff",
                        combine_adjacent=True,
                        show_legend=True,
                        color_map={"+": "red", "-": "green"}),
                   
                    theme=gr.themes.Base())

demo.launch()