ko-review / app.py
sigmadream's picture
Update app.py
ee9cdb8
raw
history blame
7.6 kB
import gradio as gr
import fasttext
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
import pandas as pd
import torch
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
title = "ํ•œ๊ตญ์–ด/์˜์–ด ๊ฐ์ • ๋ถ„์„ ์˜ˆ์ œ(๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ํ™œ์šฉ)"
description = "์˜ํ™”ํ‰์„ ์ž…๋ ฅํ•˜์—ฌ ๊ธ์ •์ ์ธ์ง€ ๋ถ€์ •์ ์ธ์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. \
ํ•œ๊ตญ์–ด์ธ์ง€ ์˜์–ด์ธ์ง€ ํŒ๋‹จํ•˜๊ณ  ์˜ˆ์ธกํ•ด์ฃผ๋Š” ""Default""๋ผ๋Š” ๋ฒ„์ „๋„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค." \
ํ•œ๊ตญ์–ด ๋ฒ„์ „๊ณผ ์˜์–ด ๋ฒ„์ „ ์ค‘์—์„œ ์„ ํƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
class LanguageIdentification:
def __init__(self):
pretrained_lang_model = "./lid.176.ftz"
self.model = fasttext.load_model(pretrained_lang_model)
def predict_lang(self, text):
predictions = self.model.predict(text, k=200) # returns top 200 matching languages
return predictions
LANGUAGE = LanguageIdentification()
def tokenized_data(tokenizer, inputs):
return tokenizer.batch_encode_plus(
[inputs],
return_tensors="pt",
padding="max_length",
max_length=64,
truncation=True)
examples = []
df = pd.read_csv('examples.csv', sep='\t', index_col='Unnamed: 0')
np.random.seed(100)
idx = np.random.choice(50, size=5, replace=False)
eng_examples = [ ['Eng', df.iloc[i, 0]] for i in idx ]
kor_examples = [ ['Kor', df.iloc[i, 1]] for i in idx ]
examples = eng_examples + kor_examples
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer = AutoTokenizer.from_pretrained(eng_model_name)
eng_file_name = "{}-{}.pt".format(eng_model_name, eng_step)
eng_state_dict = torch.load(eng_file_name)
eng_model = AutoModelForSequenceClassification.from_pretrained(
eng_model_name, num_labels=2, id2label=id2label, label2id=label2id,
state_dict=eng_state_dict
)
kor_model_name = "klue/roberta-small"
kor_step = 2400
kor_tokenizer = AutoTokenizer.from_pretrained(kor_model_name)
kor_file_name = "{}-{}.pt".format(kor_model_name.replace('/', '_'), kor_step)
kor_state_dict = torch.load(kor_file_name)
kor_model = AutoModelForSequenceClassification.from_pretrained(
kor_model_name, num_labels=2, id2label=id2label, label2id=label2id,
state_dict=kor_state_dict
)
def builder(Lang, Text):
percent_kor, percent_eng = 0, 0
text_list = Text.split(' ')
# [ output_1 ]
if Lang == '์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ':
pred = LANGUAGE.predict_lang(Text)
if '__label__en' in pred[0]:
Lang = 'Eng'
idx = pred[0].index('__label__en')
p_eng = pred[1][idx]
if '__label__ko' in pred[0]:
Lang = 'Kor'
idx = pred[0].index('__label__ko')
p_kor = pred[1][idx]
# Normalize Percentage
percent_kor = p_kor / (p_kor+p_eng)
percent_eng = p_eng / (p_kor+p_eng)
if Lang == 'Eng':
model = eng_model
tokenizer = eng_tokenizer
if percent_eng==0: percent_eng=1
if Lang == 'Kor':
model = kor_model
tokenizer = kor_tokenizer
if percent_kor==0: percent_kor=1
# [ output_2 ]
inputs = tokenized_data(tokenizer, Text)
model.eval()
with torch.no_grad():
logits = model(input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask']).logits
m = torch.nn.Softmax(dim=1)
output = m(logits)
# print(logits, output)
# [ output_3 ]
output_analysis = []
for word in text_list:
tokenized_word = tokenized_data(tokenizer, word)
with torch.no_grad():
logit = model(input_ids=tokenized_word['input_ids'],
attention_mask=tokenized_word['attention_mask']).logits
word_output = m(logit)
if word_output[0][1] > 0.99:
output_analysis.append( (word, '+++') )
elif word_output[0][1] > 0.9:
output_analysis.append( (word, '++') )
elif word_output[0][1] > 0.8:
output_analysis.append( (word, '+') )
elif word_output[0][1] < 0.01:
output_analysis.append( (word, '---') )
elif word_output[0][1] < 0.1:
output_analysis.append( (word, '--') )
elif word_output[0][1] < 0.2:
output_analysis.append( (word, '-') )
else:
output_analysis.append( (word, None) )
return [ {'Kor': percent_kor, 'Eng': percent_eng},
{id2label[1]: output[0][1].item(), id2label[0]: output[0][0].item()},
output_analysis ]
# prediction = torch.argmax(logits, axis=1)
return id2label[prediction.item()]
with gr.Blocks() as demo1:
gr.Markdown(
"""
<h1 align="center">
ํ•œ๊ตญ์–ด/์˜์–ด ๊ฐ์ • ๋ถ„์„ ์˜ˆ์ œ(๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ํ™œ์šฉ)
</h1>
""")
gr.Markdown(
"""
์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, ๊ธ์ •์ ์ธ ๊ฐ์ •์ธ์ง€ ๋ถ€์ •์ ์ธ ๊ฐ์ •์ธ์ง€ ํŒ๋ณ„ํ•˜๋Š” ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค. \
์˜์–ด์™€ ํ•œ๊ธ€์„ ์ง€์›ํ•˜๋ฉฐ, ์–ธ์–ด๋ฅผ ์ง์ ‘ ์„ ํƒํ• ์ˆ˜๋„, ํ˜น์€ ๋ชจ๋ธ์ด ์–ธ์–ด๊ฐ์ง€๋ฅผ ์ง์ ‘ ํ•˜๋„๋ก ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, (1) ๊ฐ์ง€๋œ ์–ธ์–ด, (2) ๊ธ์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ ๊ณผ ๋ถ€์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ , (3) ์ž…๋ ฅ๋œ ๋ฆฌ๋ทฐ์˜ ์–ด๋Š ๋‹จ์–ด๊ฐ€ ๊ธ์ •/๋ถ€์ • ๊ฒฐ์ •์— ์˜ํ–ฅ์„ ์ฃผ์—ˆ๋Š”์ง€ \
(๊ธ์ •์ผ ๊ฒฝ์šฐ ๋นจ๊ฐ•์ƒ‰, ๋ถ€์ •์ผ ๊ฒฝ์šฐ ํŒŒ๋ž€์ƒ‰)๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
""")
with gr.Accordion(label="๋ชจ๋ธ์— ๋Œ€ํ•œ ์„ค๋ช… ( ์—ฌ๊ธฐ๋ฅผ ํด๋ฆญ ํ•˜์‹œ์˜ค. )", open=False):
gr.Markdown(
"""
์˜์–ด ๋ชจ๋ธ์€ bert-base-uncased ๊ธฐ๋ฐ˜์œผ๋กœ, ์˜์–ด ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ธ SST-2๋กœ ํ•™์Šต ๋ฐ ํ‰๊ฐ€๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
ํ•œ๊ธ€ ๋ชจ๋ธ์€ klue/roberta-base ๊ธฐ๋ฐ˜์ด๋‹ค. ๊ธฐ์กด ํ•œ๊ธ€ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ด ์กด์žฌํ•˜์ง€ ์•Š์•„, ๋„ค์ด๋ฒ„ ์˜ํ™”์˜ ๋ฆฌ๋ทฐ๋ฅผ ํฌ๋กค๋งํ•ด์„œ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์„ ์ œ์ž‘ํ•˜๊ณ , ์ด๋ฅผ ์ด์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต ๋ฐ ํ‰๊ฐ€ํ•˜์˜€์Šต๋‹ˆ๋‹ค.
์˜์–ด ๋ชจ๋ธ์€ SST-2์—์„œ 92.8%, ํ•œ๊ธ€ ๋ชจ๋ธ์€ ๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ฐ์ดํ„ฐ์…‹์—์„œ 94%์˜ ์ •ํ™•๋„๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค(test set ๊ธฐ์ค€).
์–ธ์–ด๊ฐ์ง€๋Š” fasttext์˜ language detector๋ฅผ ์‚ฌ์šฉํ•˜์˜€๋‹ค. ๋ฆฌ๋ทฐ์˜ ๋‹จ์–ด๋ณ„ ์˜ํ–ฅ๋ ฅ์€, ๋‹จ์–ด ๊ฐ๊ฐ์„ ๋ชจ๋ธ์— ๋„ฃ์—ˆ์„ ๋•Œ ๊ฒฐ๊ณผ๊ฐ€ ๊ธ์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€ ๋ถ€์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์ธก์ •ํ•˜์˜€์Šต๋‹ˆ๋‹ค.
""")
with gr.Row():
with gr.Column():
inputs_1 = gr.Dropdown(choices=['์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ', 'Eng', 'Kor'], value='์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ', label='Lang')
inputs_2 = gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์‹œ์˜ค.", label='Text')
with gr.Row():
# btn2 = gr.Button("ํด๋ฆฌ์–ด")
btn = gr.Button("์ œ์ถœํ•˜๊ธฐ")
with gr.Column():
output_1 = gr.Label(num_top_classes=3, label='Lang')
output_2 = gr.Label(num_top_classes=2, label='Result')
output_3 = gr.HighlightedText(label="Analysis", combine_adjacent=False) \
.style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"})
# btn2.click(fn=fn2, inputs=[None, None], output=[output_1, output_2, output_3])
btn.click(fn=builder, inputs=[inputs_1, inputs_2], outputs=[output_1, output_2, output_3])
gr.Examples(examples, inputs=[inputs_1, inputs_2])
if __name__ == "__main__":
# print(examples)
# demo.launch()
demo1.launch()