ko-review / app.py
sigmadream's picture
Update app.py
fbdc62b
import gradio as gr
import fasttext
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import numpy as np
import pandas as pd
import torch
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
title = "์˜ํ™” ๋ฆฌ๋ทฐ ์ ์ˆ˜ ํŒ๋ณ„๊ธฐ"
description = "์˜ํ™”ํ‰์„ ์ž…๋ ฅํ•˜์—ฌ ๊ธ์ •์ ์ธ์ง€ ๋ถ€์ •์ ์ธ์ง€๋ฅผ ๋ถ„๋ฅ˜ํ•˜๋Š” ํ”„๋กœ๊ทธ๋žจ์ž…๋‹ˆ๋‹ค. \
ํ•œ๊ตญ์–ด ๋ฒ„์ „๊ณผ ์˜์–ด ๋ฒ„์ „ ์ค‘์—์„œ ์„ ํƒํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. \
ํ•œ๊ตญ์–ด์ธ์ง€ ์˜์–ด์ธ์ง€ ํŒ๋‹จํ•˜๊ณ  ์˜ˆ์ธกํ•ด์ฃผ๋Š” ""Default""๋ผ๋Š” ๋ฒ„์ „๋„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค."
class LanguageIdentification:
def __init__(self):
pretrained_lang_model = "./lid.176.ftz"
self.model = fasttext.load_model(pretrained_lang_model)
def predict_lang(self, text):
predictions = self.model.predict(text, k=200) # returns top 200 matching languages
return predictions
LANGUAGE = LanguageIdentification()
def tokenized_data(tokenizer, inputs):
return tokenizer.batch_encode_plus(
[inputs],
return_tensors="pt",
padding="max_length",
max_length=64,
truncation=True)
examples = []
df = pd.read_csv('examples.csv', sep='\t', index_col='Unnamed: 0')
np.random.seed(100)
idx = np.random.choice(50, size=5, replace=False)
eng_examples = [ ['Eng', df.iloc[i, 0]] for i in idx ]
kor_examples = [ ['Kor', df.iloc[i, 1]] for i in idx ]
examples = eng_examples + kor_examples
eng_model_name = "roberta-base"
eng_step = 1900
eng_tokenizer = AutoTokenizer.from_pretrained(eng_model_name)
eng_file_name = "{}-{}.pt".format(eng_model_name, eng_step)
eng_state_dict = torch.load(eng_file_name)
eng_model = AutoModelForSequenceClassification.from_pretrained(
eng_model_name, num_labels=2, id2label=id2label, label2id=label2id,
state_dict=eng_state_dict
)
kor_model_name = "klue/roberta-small"
kor_step = 2400
kor_tokenizer = AutoTokenizer.from_pretrained(kor_model_name)
kor_file_name = "{}-{}.pt".format(kor_model_name.replace('/', '_'), kor_step)
kor_state_dict = torch.load(kor_file_name)
kor_model = AutoModelForSequenceClassification.from_pretrained(
kor_model_name, num_labels=2, id2label=id2label, label2id=label2id,
state_dict=kor_state_dict
)
def builder(Lang, Text):
percent_kor, percent_eng = 0, 0
text_list = Text.split(' ')
# [ output_1 ]
if Lang == '์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ':
pred = LANGUAGE.predict_lang(Text)
if '__label__en' in pred[0]:
Lang = 'Eng'
idx = pred[0].index('__label__en')
p_eng = pred[1][idx]
if '__label__ko' in pred[0]:
Lang = 'Kor'
idx = pred[0].index('__label__ko')
p_kor = pred[1][idx]
# Normalize Percentage
percent_kor = p_kor / (p_kor+p_eng)
percent_eng = p_eng / (p_kor+p_eng)
if Lang == 'Eng':
model = eng_model
tokenizer = eng_tokenizer
if percent_eng==0: percent_eng=1
if Lang == 'Kor':
model = kor_model
tokenizer = kor_tokenizer
if percent_kor==0: percent_kor=1
# [ output_2 ]
inputs = tokenized_data(tokenizer, Text)
model.eval()
with torch.no_grad():
logits = model(input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask']).logits
m = torch.nn.Softmax(dim=1)
output = m(logits)
# print(logits, output)
# [ output_3 ]
output_analysis = []
for word in text_list:
tokenized_word = tokenized_data(tokenizer, word)
with torch.no_grad():
logit = model(input_ids=tokenized_word['input_ids'],
attention_mask=tokenized_word['attention_mask']).logits
word_output = m(logit)
if word_output[0][1] > 0.99:
output_analysis.append( (word, '+++') )
elif word_output[0][1] > 0.9:
output_analysis.append( (word, '++') )
elif word_output[0][1] > 0.8:
output_analysis.append( (word, '+') )
elif word_output[0][1] < 0.01:
output_analysis.append( (word, '---') )
elif word_output[0][1] < 0.1:
output_analysis.append( (word, '--') )
elif word_output[0][1] < 0.2:
output_analysis.append( (word, '-') )
else:
output_analysis.append( (word, None) )
return [ {'Kor': percent_kor, 'Eng': percent_eng},
{id2label[1]: output[0][1].item(), id2label[0]: output[0][0].item()},
output_analysis ]
# prediction = torch.argmax(logits, axis=1)
return id2label[prediction.item()]
# demo3 = gr.Interface.load("models/mdj1412/movie_review_score_discriminator_eng", inputs="text", outputs="text",
# title=title, theme="peach",
# allow_flagging="auto",
# description=description, examples=examples)
# demo = gr.Interface(builder, inputs=[gr.inputs.Dropdown(['Default', 'Eng', 'Kor']), gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์‹œ์˜ค.")],
# outputs=[ gr.Label(num_top_classes=3, label='Lang'),
# gr.Label(num_top_classes=2, label='Result'),
# gr.HighlightedText(label="Analysis", combine_adjacent=False)
# .style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"}) ],
# # outputs='label',
# title=title, description=description, examples=examples)
with gr.Blocks() as demo1:
gr.Markdown(
"""
<h1 align="center">
์˜ํ™” ๋ฆฌ๋ทฐ ์ ์ˆ˜ ํŒ๋ณ„๊ธฐ
</h1>
""")
gr.Markdown(
"""
์˜ํ™” ๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, ๋ฆฌ๋ทฐ๊ฐ€ ๊ธ์ •์ธ์ง€ ๋ถ€์ •์ธ์ง€ ํŒ๋ณ„ํ•ด์ฃผ๋Š” ๋ชจ๋ธ์ด๋‹ค. \
์˜์–ด์™€ ํ•œ๊ธ€์„ ์ง€์›ํ•˜๋ฉฐ, ์–ธ์–ด๋ฅผ ์ง์ ‘ ์„ ํƒํ• ์ˆ˜๋„, ํ˜น์€ ๋ชจ๋ธ์ด ์–ธ์–ด๊ฐ์ง€๋ฅผ ์ง์ ‘ ํ•˜๋„๋ก ํ•  ์ˆ˜ ์žˆ๋‹ค.
๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜๋ฉด, (1) ๊ฐ์ง€๋œ ์–ธ์–ด, (2) ๊ธ์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ ๊ณผ ๋ถ€์ • ๋ฆฌ๋ทฐ์ผ ํ™•๋ฅ , (3) ์ž…๋ ฅ๋œ ๋ฆฌ๋ทฐ์˜ ์–ด๋Š ๋‹จ์–ด๊ฐ€ ๊ธ์ •/๋ถ€์ • ๊ฒฐ์ •์— ์˜ํ–ฅ์„ ์ฃผ์—ˆ๋Š”์ง€ \
(๊ธ์ •์ผ ๊ฒฝ์šฐ ๋นจ๊ฐ•์ƒ‰, ๋ถ€์ •์ผ ๊ฒฝ์šฐ ํŒŒ๋ž€์ƒ‰)๋ฅผ ํ™•์ธํ•  ์ˆ˜ ์žˆ๋‹ค.
""")
with gr.Accordion(label="๋ชจ๋ธ์— ๋Œ€ํ•œ ์„ค๋ช… ( ์—ฌ๊ธฐ๋ฅผ ํด๋ฆญ ํ•˜์‹œ์˜ค. )", open=False):
gr.Markdown(
"""
์˜์–ด ๋ชจ๋ธ์€ bert-base-uncased ๊ธฐ๋ฐ˜์œผ๋กœ, ์˜์–ด ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ธ SST-2๋กœ ํ•™์Šต ๋ฐ ํ‰๊ฐ€๋˜์—ˆ๋‹ค.
ํ•œ๊ธ€ ๋ชจ๋ธ์€ klue/roberta-base ๊ธฐ๋ฐ˜์ด๋‹ค. ๊ธฐ์กด ํ•œ๊ธ€ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์ด ์กด์žฌํ•˜์ง€ ์•Š์•„, ๋„ค์ด๋ฒ„ ์˜ํ™”์˜ ๋ฆฌ๋ทฐ๋ฅผ ํฌ๋กค๋งํ•ด์„œ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ถ„์„ ๋ฐ์ดํ„ฐ์…‹์„ ์ œ์ž‘ํ•˜๊ณ , ์ด๋ฅผ ์ด์šฉํ•˜์—ฌ ๋ชจ๋ธ์„ ํ•™์Šต ๋ฐ ํ‰๊ฐ€ํ•˜์˜€๋‹ค.
์˜์–ด ๋ชจ๋ธ์€ SST-2์—์„œ 92.8%, ํ•œ๊ธ€ ๋ชจ๋ธ์€ ๋„ค์ด๋ฒ„ ์˜ํ™” ๋ฆฌ๋ทฐ ๋ฐ์ดํ„ฐ์…‹์—์„œ 94%์˜ ์ •ํ™•๋„๋ฅผ ๊ฐ€์ง„๋‹ค (test set ๊ธฐ์ค€).
์–ธ์–ด๊ฐ์ง€๋Š” fasttext์˜ language detector๋ฅผ ์‚ฌ์šฉํ•˜์˜€๋‹ค. ๋ฆฌ๋ทฐ์˜ ๋‹จ์–ด๋ณ„ ์˜ํ–ฅ๋ ฅ์€, ๋‹จ์–ด ๊ฐ๊ฐ์„ ๋ชจ๋ธ์— ๋„ฃ์—ˆ์„ ๋•Œ ๊ฒฐ๊ณผ๊ฐ€ ๊ธ์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€ ๋ถ€์ •์œผ๋กœ ๋‚˜์˜ค๋Š”์ง€๋ฅผ ๋ฐ”ํƒ•์œผ๋กœ ์ธก์ •ํ•˜์˜€๋‹ค.
""")
with gr.Row():
with gr.Column():
inputs_1 = gr.Dropdown(choices=['์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ', 'Eng', 'Kor'], value='์–ธ์–ด๊ฐ์ง€ ๊ธฐ๋Šฅ ์‚ฌ์šฉ', label='Lang')
inputs_2 = gr.Textbox(placeholder="๋ฆฌ๋ทฐ๋ฅผ ์ž…๋ ฅํ•˜์‹œ์˜ค.", label='Text')
with gr.Row():
# btn2 = gr.Button("ํด๋ฆฌ์–ด")
btn = gr.Button("์ œ์ถœํ•˜๊ธฐ")
with gr.Column():
output_1 = gr.Label(num_top_classes=3, label='Lang')
output_2 = gr.Label(num_top_classes=2, label='Result')
output_3 = gr.HighlightedText(label="Analysis", combine_adjacent=False) \
.style(color_map={"+++": "#CF0000", "++": "#FF3232", "+": "#FFD4D4", "---": "#0004FE", "--": "#4C47FF", "-": "#BEBDFF"})
# btn2.click(fn=fn2, inputs=[None, None], output=[output_1, output_2, output_3])
btn.click(fn=builder, inputs=[inputs_1, inputs_2], outputs=[output_1, output_2, output_3])
gr.Examples(examples, inputs=[inputs_1, inputs_2])
if __name__ == "__main__":
# print(examples)
# demo.launch()
demo1.launch()