# Author: hyomin
# Commit: a246e1c ("Update app.py")
import pandas as pd
import numpy as np
# from konlpy.tag import Okt
from string import whitespace, punctuation
import re
import unicodedata
from sentence_transformers import SentenceTransformer, util
import gradio as gr
import pytorch_lightning as pl
import torch
from transformers import PreTrainedTokenizerFast, BartForConditionalGeneration
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast
from transformers.optimization import get_cosine_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
# classification
# Patterns used by CleanEnd(), compiled once at import time instead of on every call.
# The local part may contain dot-separated segments ("hong.gildong@..."); the old
# pattern only matched the segment right before '@' and left "hong." behind.
_CLEAN_EMAIL_RE = re.compile(
    r'[-_0-9a-z]+(?:\.[-_0-9a-z]+)*@[-_0-9a-z]+(?:\.[0-9a-z]+)+', flags=re.IGNORECASE)
# NOTE(review): this also matches any dotted ASCII token such as "3.5" or "e.g".
_CLEAN_URL_RE = re.compile(
    r'(?:https?:\/\/)?[-_0-9a-z]+(?:\.[-_0-9a-z]+)+', flags=re.IGNORECASE)
# Text after the final '.' that mentions a byline / copyright marker.
_CLEAN_TRAILER_RE = re.compile(
    r'\.([^\.]*(?:기자|특파원|교수|작가|대표|논설|고문|주필|부문장|팀장|장관|원장|연구원|이사장|위원|실장|차장|부장|에세이|화백|사설|소장|단장|과장|기획자|큐레이터|저작권|평론가|©|©|ⓒ|\@|\/|=|▶|무단|전재|재배포|금지|\[|\]|\(\))[^\.]*)$')
# NOTE(review): `\s` only belongs to the ◆...◆ alternative, and `.+` is greedy,
# so "[a] x [b]" loses everything up to the last ']'. Kept as-is.
_CLEAN_BRACKET_RE = re.compile(r'^((?:\[.+\])|(?:【.+】)|(?:<.+>)|(?:◆.+◆)\s)')


def CleanEnd(text):
    """Strip common news-article boilerplate from *text*.

    Steps, in order:
      1. remove e-mail addresses;
      2. remove URL-like tokens (optional http(s):// plus dotted ASCII name);
      3. if the text after the final '.' contains a byline or copyright marker
         (기자, 특파원, ©, ⓒ, 무단, 재배포, ...), drop that tail and keep the '.';
      4. remove one leading tag such as ``[속보]``, ``【단독】`` or ``<...>``;
      5. strip surrounding whitespace.

    Returns the cleaned string.
    """
    result = _CLEAN_EMAIL_RE.sub('', text)
    result = _CLEAN_URL_RE.sub('', result)
    result = _CLEAN_TRAILER_RE.sub('.', result)
    return _CLEAN_BRACKET_RE.sub('', result).strip()
# Separators: every whitespace char plus ASCII punctuation except '%'.
# re.escape keeps characters like '\\', ']' and '-' literal inside the class.
_TEXT_FILTER_PUNCT = ''.join(ch for ch in punctuation if ch != '%')
_TEXT_FILTER_SEPARATORS = re.compile(
    f'[{re.escape(whitespace + _TEXT_FILTER_PUNCT)}]+')
# Anything that is not '%', a space, a Hangul compatibility jamo (ㄱ-ㅣ)
# or a precomposed Hangul syllable (가-힣).
_TEXT_FILTER_NON_HANGUL = re.compile(r'[^% ㄱ-ㅣ가-힣]+')


def TextFilter(text):
    """Reduce *text* to Hangul characters, '%' and single spaces.

    Whitespace and punctuation runs (except '%') become one space. Every
    other non-Hangul run (Latin letters, digits, symbols) is also replaced
    by a space. The result is stripped, and internal space runs are collapsed.
    """
    result = _TEXT_FILTER_SEPARATORS.sub(' ', text)
    result = _TEXT_FILTER_NON_HANGUL.sub(' ', result).strip()
    # Removing non-Hangul runs can leave adjacent spaces; collapse them again.
    return _TEXT_FILTER_SEPARATORS.sub(' ', result)
# SentenceTransformer instances keyed by checkpoint path, so each model is
# read from disk once per process instead of on every call.
_CLICKBAIT_MODELS = {}


def _prepare_for_similarity(text):
    """Collapse whitespace runs, NFC-normalise, strip, then apply CleanEnd and TextFilter."""
    text = re.sub(f'[{whitespace}]+', ' ', text)
    text = unicodedata.normalize('NFC', text).strip()
    return TextFilter(CleanEnd(text))


def is_clickbait(title, content, threshold=0.815,
                 model_path='./model/onlineContrastive'):
    """Classify an article by the embedding similarity of its title and body.

    Args:
        title: headline text.
        content: article body text.
        threshold: cosine-similarity cut-off; a score below it means clickbait.
        model_path: SentenceTransformer checkpoint to load (cached per path).

    Returns:
        ``(label, similarity)``. ``label`` is 0 for clickbait and 1 for
        non-clickbait. ``similarity`` is the cosine score as a NumPy scalar.
    """
    model = _CLICKBAIT_MODELS.get(model_path)
    if model is None:
        model = SentenceTransformer(model_path)
        _CLICKBAIT_MODELS[model_path] = model

    # Noun extraction with konlpy's Okt was tried and is disabled; the
    # cleaned full text is embedded instead.
    title = _prepare_for_similarity(title)
    content = _prepare_for_similarity(content)

    title_embedding = model.encode(title, convert_to_tensor=True)
    content_embedding = model.encode(content, convert_to_tensor=True)
    # encode() may return tensors on the model's device (e.g. CUDA);
    # .numpy() only works on CPU tensors.
    similarity = util.cos_sim(title_embedding, content_embedding).cpu().numpy()[0][0]
    if similarity < threshold:
        return 0, similarity  # clickbait
    return 1, similarity  # non-clickbait
# Generation
# Two-row frame with the column names DatasetFromDataframe reads. Its
# '1'/'2' values look like a placeholder that only satisfies the data
# module (generation is configured with 0 epochs) — confirm before reuse.
df_train = pd.DataFrame({
    'input_text': ['1', '2'],
    'target_text': ['1', '2'],
})
# E-mail pattern for generated titles. The local part may contain
# dot-separated segments ("hong.gildong@..."), which the old pattern cut in half.
_GENERATED_EMAIL_RE = re.compile(
    r'[-_0-9a-z]+(?:\.[-_0-9a-z]+)*@[-_0-9a-z]+(?:\.[0-9a-z]+)+', flags=re.IGNORECASE)


def CleanEnd_g(text):
    """Remove e-mail addresses from a generated headline.

    Only e-mail removal is applied here. URLs, byline trailers and leading
    bracket tags are deliberately left untouched for generated text.
    """
    return _GENERATED_EMAIL_RE.sub('', text)
class DatasetFromDataframe(Dataset):
    """Map-style dataset that turns DataFrame rows into BART seq2seq tensors.

    Args:
        df: DataFrame with ``input_text`` and ``target_text`` columns.
        dataset_args: dict with ``max_length`` (fixed sequence length) and
            ``tokenizer`` (an object exposing ``encode`` and ``pad_token_id``).
    """

    def __init__(self, df, dataset_args):
        self.data = df
        self.max_length = dataset_args['max_length']
        self.tokenizer = dataset_args['tokenizer']
        self.start_token = '<s>'
        self.end_token = '</s>'

    def __len__(self):
        return len(self.data)

    def create_tokens(self, text):
        """Encode ``<s>text</s>`` and pad or truncate it to ``max_length``.

        Returns ``(token_ids, attention_mask)`` as lists. Short inputs are
        right-padded with ``pad_token_id`` (mask 0). Long inputs keep the first
        ``max_length - 1`` ids followed by the encoded end token. Both lists
        have length ``max_length`` when ``encode('</s>')`` yields a single
        id (assumed — TODO confirm for the tokenizer in use).
        """
        tokens = self.tokenizer.encode(self.start_token + text + self.end_token)
        token_length = len(tokens)
        remain = self.max_length - token_length
        if remain >= 0:
            tokens = tokens + [self.tokenizer.pad_token_id] * remain
            attention_mask = [1] * token_length + [0] * remain
        else:
            tokens = tokens[: self.max_length - 1] + \
                self.tokenizer.encode(self.end_token)
            attention_mask = [1] * self.max_length
        return tokens, attention_mask

    def __getitem__(self, index):
        """Return the model inputs for row *index* as LongTensors."""
        record = self.data.iloc[index]
        question, answer = record['input_text'], record['target_text']
        input_id, input_mask = self.create_tokens(question)
        output_id, output_mask = self.create_tokens(answer)
        # Labels are the decoder inputs shifted left by one. Padding positions
        # (decoder mask 0) become -100 so the cross-entropy loss ignores them
        # instead of training the model to predict pad ids. The trailing slot
        # has nothing left to predict and is -100 as well.
        shifted_ids = output_id[1:(self.max_length + 1)]
        shifted_mask = output_mask[1:(self.max_length + 1)]
        label = [tok if keep else -100 for tok, keep in zip(shifted_ids, shifted_mask)]
        label = label + (self.max_length - len(label)) * [-100]
        return {
            'input_ids': torch.LongTensor(input_id),
            'attention_mask': torch.LongTensor(input_mask),
            'decoder_input_ids': torch.LongTensor(output_id),
            'decoder_attention_mask': torch.LongTensor(output_mask),
            "labels": torch.LongTensor(label)
        }
class OneSourceDataModule(pl.LightningDataModule):
    """Lightning data module serving one DataFrame for train, val and test.

    Keyword Args:
        data: DataFrame with ``input_text``/``target_text`` columns. When it
            is omitted (None), the module-level ``df_train`` is used.
        dataset_args: dict forwarded to ``DatasetFromDataframe``.
        batch_size: loader batch size; 32 when missing or falsy.
        train_size: stored (0.9 when missing or falsy) but not used yet,
            because no train/test split is performed.
    """

    def __init__(
        self,
        **kwargs
    ):
        super().__init__()
        self.data = kwargs.get('data')
        self.dataset_args = kwargs.get("dataset_args")
        self.batch_size = kwargs.get("batch_size") or 32
        self.train_size = kwargs.get("train_size") or 0.9

    def setup(self, stage=""):
        """Build the datasets from ``self.data``; the same frame backs every split."""
        # Previously this always read the global df_train and ignored the
        # `data` passed to the constructor.
        # NOTE(review): a train_test_split(..., train_size=self.train_size)
        # was once planned here; until then train and eval data are identical.
        data = self.data if self.data is not None else df_train
        self.trainset = DatasetFromDataframe(data, self.dataset_args)
        self.testset = DatasetFromDataframe(data, self.dataset_args)

    def _make_loader(self, dataset):
        """Wrap *dataset* in a DataLoader with the configured batch size."""
        return DataLoader(dataset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._make_loader(self.trainset)

    def val_dataloader(self):
        return self._make_loader(self.testset)

    def test_dataloader(self):
        return self._make_loader(self.testset)
class KoBARTConditionalGeneration(pl.LightningModule):
    """PyTorch Lightning wrapper around a BART conditional-generation model.

    Args:
        hparams: dict merged into ``self.hparams``. The methods below read
            ``lr``, ``warmup_ratio``, ``batch_size``, ``max_epochs`` and
            ``max_length``.
        **kwargs: must contain ``model`` (a seq2seq model with ``generate``)
            and ``tokenizer`` (exposing ``encode``, ``decode`` and
            ``pad_token_id``).
    """

    def __init__(self, hparams, **kwargs):
        super(KoBARTConditionalGeneration, self).__init__()
        self.hparams.update(hparams)
        self.model = kwargs['model']
        self.tokenizer = kwargs['tokenizer']
        self.model.train()

    def configure_optimizers(self):
        """Build AdamW plus a per-step cosine schedule with linear warmup.

        Bias and LayerNorm parameters get no weight decay; all others use 0.01.
        """
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params': [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.lr
        )
        # NOTE(review): relies on self.train_dataloader() being reachable from
        # the module (older Lightning wired the datamodule's loaders onto it);
        # newer releases expose trainer.estimated_stepping_batches instead.
        data_len = len(self.train_dataloader().dataset)
        print(f'학습 데이터 양: {data_len}')  # training set size
        num_train_steps = int(
            data_len / self.hparams.batch_size * self.hparams.max_epochs)
        print(f'Step 수: {num_train_steps}')  # total optimisation steps
        num_warmup_steps = int(num_train_steps * self.hparams.warmup_ratio)
        print(f'Warmup Step 수: {num_warmup_steps}')  # warmup steps
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_train_steps
        )
        lr_scheduler = {
            'scheduler': scheduler,
            'monitor': 'loss',
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [lr_scheduler]

    def forward(self, inputs):
        """Run the wrapped model on a batch dict produced by DatasetFromDataframe."""
        return self.model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            decoder_input_ids=inputs['decoder_input_ids'],
            decoder_attention_mask=inputs['decoder_attention_mask'],
            labels=inputs['labels'],
            return_dict=True
        )

    def training_step(self, batch, batch_idx):
        loss = self(batch).loss
        return loss

    def validation_step(self, batch, batch_idx):
        # The loss used to be computed and then discarded; report it.
        loss = self(batch).loss
        self.log('val_loss', loss)
        return loss

    def test(self, text):
        """Generate an output sequence for *text* with 10-beam search.

        The input is wrapped in ``<s> ... </s>`` and padded or truncated to
        ``hparams.max_length`` (same scheme as DatasetFromDataframe). The
        decoded top beam is returned; special tokens are NOT skipped.
        """
        max_length = self.hparams.max_length
        tokens = self.tokenizer.encode("<s>" + text + "</s>")
        token_length = len(tokens)
        remain = max_length - token_length
        if remain >= 0:
            tokens = tokens + [self.tokenizer.pad_token_id] * remain
            attention_mask = [1] * token_length + [0] * remain
        else:
            tokens = tokens[: max_length - 1] + \
                self.tokenizer.encode("</s>")
            attention_mask = [1] * max_length
        input_ids = torch.LongTensor([tokens])
        attention_mask = torch.LongTensor([attention_mask])
        # __init__ puts the model in train mode, and generate() does not call
        # eval() itself. Without this, dropout stays active and beam search
        # returns a different headline on every call. The previous mode is
        # restored afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            result = self.model.generate(
                input_ids,
                max_length=max_length,
                attention_mask=attention_mask,
                num_beams=10
            )[0]
        finally:
            self.model.train(was_training)
        return self.tokenizer.decode(result)
# Models used by generation(), loaded on the first call and then reused so
# the checkpoints are not re-read for every request.
_GENERATION_MODELS = {}


def _load_generation_models():
    """Return ``(tokenizer, summarizer, title_model)``, loading them on first use."""
    if not _GENERATION_MODELS:
        tokenizer = PreTrainedTokenizerFast.from_pretrained(
            "gogamza/kobart-summarization")
        summarizer = BartForConditionalGeneration.from_pretrained(
            "gogamza/kobart-summarization")
        title_model = KoBARTConditionalGeneration(
            {
                "lr": 5e-6,
                "warmup_ratio": 0.1,
                "batch_size": 32,
                "max_length": 128,
                "max_epochs": 0,
            },
            tokenizer=tokenizer,
            model=BartForConditionalGeneration.from_pretrained('./model/final2.h5'),
        )
        # Fill the cache only after every load succeeded, so a failure is retried.
        _GENERATION_MODELS.update(
            tokenizer=tokenizer, summarizer=summarizer, title_model=title_model)
    return (_GENERATION_MODELS['tokenizer'],
            _GENERATION_MODELS['summarizer'],
            _GENERATION_MODELS['title_model'])


def _summarize(tokenizer, summarizer, text):
    """Summarise the first 500 characters of *text* with 5-beam search."""
    input_ids = tokenizer.encode(text[:500], return_tensors="pt")
    summary = summarizer.generate(
        input_ids=input_ids,
        bos_token_id=summarizer.config.bos_token_id,
        eos_token_id=summarizer.config.eos_token_id,
        length_penalty=.3,  # bigger than 1 = longer, smaller than 1 = shorter summary
        max_length=35,
        min_length=25,
        num_beams=5)
    return tokenizer.decode(summary[0], skip_special_tokens=True)


def _clean_title(text):
    """Collapse whitespace, NFC-normalise, strip, then apply CleanEnd_g and TextFilter."""
    text = re.sub(f'[{whitespace}]+', ' ', text)
    text = unicodedata.normalize('NFC', text).strip()
    return TextFilter(CleanEnd_g(text))


def generation(szContent):
    """Generate a new headline for the article body *szContent*.

    The body is first summarised with gogamza/kobart-summarization. The
    summary is then rewritten into a headline by the fine-tuned model at
    ./model/final2.h5. Finally the headline is cleaned down to Hangul, '%'
    and single spaces.
    """
    tokenizer, summarizer, title_model = _load_generation_models()
    szSummary = _summarize(tokenizer, summarizer, szContent)
    print(szSummary)
    # No Trainer.fit() here. The previous max_epochs=0 fit performed no
    # optimisation steps, and its `gpus=` argument is rejected by Lightning 2.x.
    szTitle = title_model.test(szSummary)
    return _clean_title(szTitle)
def new_headline(title, content):
    """Return a regenerated headline for clickbait articles, otherwise a notice.

    ``is_clickbait`` labels clickbait as 0; those articles get a new headline
    from ``generation(content)``. Any other label returns the message
    "낚시성 기사가 아닙니다." ("this is not a clickbait article").
    """
    label, _similarity = is_clickbait(title, content)
    if label == 0:
        return generation(content)
    return '낚시성 기사가 아닙니다.'
# gradio
with gr.Blocks() as demo1:
    gr.Markdown(
        """
<h1 align="center">
clickbait news classifier and new headline generator
</h1>
""")
    # Korean description: "Enter a news headline and body to check whether it
    # is clickbait; if it is, a new headline is generated."
    gr.Markdown(
        """
뉴스 기사 제목과 본문을 입력하면 낚시성 기사인지 분류하고,
낚시성 기사이면 새로운 제목을 생성해주는 프로그램입니다.
""")
    with gr.Row():
        # Left column: headline/body inputs and the submit button.
        with gr.Column():
            inputs = [gr.Textbox(placeholder="뉴스기사 제목을 입력해주세요", label='headline'),
                      gr.Textbox(
                          lines=10, placeholder="뉴스기사 본문을 입력해주세요", label='content')]
            with gr.Row():
                btn = gr.Button("결과 출력")
        # Right column: the classification / generated-headline result.
        with gr.Column():
            output = gr.Text(label='Result')
    btn.click(fn=new_headline, inputs=inputs, outputs=output)

if __name__ == "__main__":
    demo1.launch()