|
import streamlit as st |
|
|
|
import numpy as np |
|
|
|
import torch |
|
from torch.autograd import Variable |
|
|
|
import argparse |
|
import os |
|
import re |
|
|
|
from data_preprocessing import remove_xem_them, remove_emojis, remove_stopwords, format_punctuation, remove_punctuation, clean_text, normalize_format, word_segment, format_price, format_price_v2 |
|
|
|
class inferSSCL():
    """Run aspect detection on a single free-text comment.

    The class loads a vocabulary, pretrained word vectors, aspect centroids
    and two linear layers from files in the current directory, encodes one
    comment with an attention-weighted average of its word embeddings and
    maps the resulting aspect scores to aspect labels.

    The module-level names ``device`` and ``emb_size`` must be defined before
    any method that touches tensors is called.
    """

    def __init__(self, args='None'):
        """Store the run configuration and create empty state containers.

        Args:
            args: Configuration object. ``encoder`` reads
                ``args.smooth_factor`` and ``infer`` writes ``args.task``, so
                the string default is only a placeholder and cannot be used
                for inference.
        """
        self.args = args
        self.base_models = {}
        self.batch_data = {}
        self.test_data = []
        self.output = []

    def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
        """Load the vocabulary and its pretrained vectors.

        Each line of the vocabulary file is split on single spaces and the
        second field is taken as the word. Id 0 is reserved for ``<pad>`` and
        a zero vector is prepended to the embedding matrix to match it.

        Args:
            file_pretrain_vocab: Path to the UTF-8 vocabulary file.
            file_pretrain_vec: Path to the ``.npy`` file with the vectors.
            pad_tokens: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(vocab2id, id2vocab, pretrain_vec)``.
        """
        vocab2id = {'<pad>': 0}
        id2vocab = {0: '<pad>'}

        with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
            for cnt, line in enumerate(fp, start=len(id2vocab)):
                # Strip only the newline: slicing off the last character would
                # truncate the final word of a file without trailing newline.
                arr = line.rstrip('\n').split(' ')
                vocab2id[arr[1]] = cnt
                id2vocab[cnt] = arr[1]

        pretrain_vec = np.load(file_pretrain_vec)
        pad_vec = np.zeros([1, pretrain_vec.shape[1]])
        pretrain_vec = np.vstack((pad_vec, pretrain_vec))
        return vocab2id, id2vocab, pretrain_vec

    def load_vocabulary(self):
        """Load vocabulary, embeddings and aspect centroids into ``batch_data``.

        Centroid rows whose entry in ``aspect_mapping.txt`` is ``none`` are
        zeroed out; all other rows are kept unchanged.
        """
        cluster_dir = './'
        file_wordvec = 'vectors.npy'
        file_vocab = 'vocab.txt'
        file_kmeans_centroid = 'aspect_centroid.txt'
        file_aspect_mapping = 'aspect_mapping.txt'

        vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
            os.path.join(cluster_dir, file_vocab),
            os.path.join(cluster_dir, file_wordvec),
        )

        self.batch_data['vocab2id'] = vocab2id
        self.batch_data['id2vocab'] = id2vocab
        self.batch_data['pretrain_emb'] = pretrain_vec
        self.batch_data['vocab_size'] = len(vocab2id)

        aspect_vec = np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)

        # One 0/1 flag per mapping line, shaped as a column so it broadcasts
        # over whatever width the centroid rows have.
        keep_rows = []
        with open(os.path.join(cluster_dir, file_aspect_mapping), 'r', encoding='utf-8') as fp:
            for line in fp:
                label = re.sub(r'[0-9]+', '', line)
                label = label.replace(' ', '').replace('\n', '')
                keep_rows.append([0.] if label == "none" else [1.])

        aspect_vec = aspect_vec * np.array(keep_rows)
        aspect_vec = torch.FloatTensor(aspect_vec).to(device)
        self.batch_data['aspect_centroid'] = aspect_vec
        self.batch_data['n_aspects'] = aspect_vec.shape[0]

    def load_models(self):
        """Build the embedding and linear layers and load their weights.

        Requires ``load_vocabulary`` to have filled ``batch_data`` first.
        """
        embedding = torch.nn.Embedding(self.batch_data['vocab_size'], emb_size).to(device)
        emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
        embedding.weight = torch.nn.Parameter(emb_para)
        self.base_models['embedding'] = embedding

        asp_weight = torch.nn.Linear(emb_size, self.batch_data['n_aspects']).to(device)
        asp_weight.load_state_dict(
            torch.load('./asp_weight.model', map_location=torch.device('cpu')))
        self.base_models['asp_weight'] = asp_weight

        attn_kernel = torch.nn.Linear(emb_size, emb_size).to(device)
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        attn_kernel.load_state_dict(
            torch.load('./attn_kernel.model', map_location=torch.device('cpu')),
            strict=False)
        self.base_models['attn_kernel'] = attn_kernel

    def build_pipe(self):
        """Return the softmax-normalised aspect scores for the current batch.

        Equivalent to ``calculate_atten_weight``. The previous body also
        computed rounded attention weights that were never used and called
        ``.split()`` on the list stored under ``batch_data['comment']``,
        which always raised ``AttributeError``.
        """
        return self.calculate_atten_weight()

    def encoder(self, input_, mask_):
        """Encode a batch of token ids with masked self-attention.

        Args:
            input_: Token-id tensor fed to the embedding layer.
            mask_: Tensor with the same leading two dimensions as ``input_``,
                non-zero at real tokens and zero at padding.

        Returns:
            Tuple ``(attn_, lbl_)``: the attention weights over positions and
            the attention-weighted sum of the masked embeddings.
        """
        with torch.no_grad():
            emb_ = self.base_models['embedding'](input_)
            # Zero out the embeddings at padded positions.
            emb_ = emb_ * mask_.unsqueeze(2)

            # Mean embedding over real tokens; the epsilon avoids 0/0 when
            # the mask is entirely zero.
            emb_avg = torch.sum(emb_, dim=1)
            norm = torch.sum(mask_, dim=1, keepdim=True) + 1e-20
            enc_ = emb_avg.div(norm.expand_as(emb_avg))

            # Score every position against the mean embedding.
            emb_trn = self.base_models['attn_kernel'](emb_)
            attn_ = enc_.unsqueeze(1) @ emb_trn.transpose(1, 2)
            attn_ = attn_.squeeze(1)
            attn_ = self.args.smooth_factor * torch.tanh(attn_)

            # Padded positions get a huge negative score so softmax ignores them.
            attn_ = attn_.masked_fill(mask_ == 0, -1e20)
            attn_ = torch.softmax(attn_, dim=1)

            lbl_ = attn_.unsqueeze(1) @ emb_
            lbl_ = lbl_.squeeze(1)

        return attn_, lbl_

    def build_batch(self, review):
        """Convert one preprocessed comment into id and mask tensors.

        Words missing from the vocabulary are dropped and the sequence is
        truncated to at most ``emb_size`` tokens. The results are stored in
        ``batch_data`` under ``pos_sen_var``, ``pos_pad_mask`` and ``comment``.

        Args:
            review: Whitespace-tokenised comment text.
        """
        vocab2id = self.batch_data['vocab2id']
        pad_id = vocab2id['<pad>']

        senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
        # NOTE(review): emb_size is reused here as the maximum sequence length.
        sen_text_len = min(len(senid), emb_size)
        # A single sentence never needs padding, only truncation.
        sen_text = [senid[:sen_text_len]]

        sen_text_var = torch.LongTensor(sen_text).to(device)
        # 1 at real tokens, 0 at padding.
        sen_pad_mask = (sen_text_var != pad_id).long()

        self.batch_data['comment'] = [review]
        self.batch_data['pos_sen_var'] = sen_text_var
        self.batch_data['pos_pad_mask'] = sen_pad_mask

    def calculate_atten_weight(self):
        """Return the softmax-normalised aspect scores for the current batch."""
        _, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask'],
        )
        asp_weight = self.base_models['asp_weight'](lbl_pos)
        return torch.softmax(asp_weight, dim=1)

    def get_test_data(self):
        """Score the current batch and append the result to ``test_data``."""
        asp_weight = self.calculate_atten_weight()
        asp_weight = asp_weight.detach().cpu().numpy().tolist()

        self.test_data.append({
            'comment': self.batch_data['comment'],
            'aspect_weight': asp_weight[0],
        })

    def select_top(self, data):
        """Return each value's distance from the median in units of the MAD.

        Args:
            data: Sequence of numeric scores.

        Returns:
            Array with one entry per input value. When the median absolute
            deviation is zero the result is all zeros (previously a scalar
            ``0``, which broke ``get_predict``).
        """
        values = np.asarray(data, dtype=float)
        d = np.abs(values - np.median(values))
        mdev = np.median(d)
        if not mdev:
            return np.zeros_like(d)
        return d / mdev

    def get_predict(self, top_pred, aspect_label, threshold=3):
        """Flag every aspect whose score exceeds ``threshold``.

        Args:
            top_pred: Scores as returned by ``select_top``.
            aspect_label: Label for each score position.
            threshold: Minimum score (exclusive) for an aspect to be flagged.

        Returns:
            Dict mapping aspect label to 0 or 1.
        """
        pred = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}
        # zip stops at the shorter input, so a length mismatch cannot raise;
        # atleast_1d keeps a scalar score usable.
        for score, label in zip(np.atleast_1d(top_pred), aspect_label):
            if score > threshold:
                pred[label] = 1
        return pred

    def get_evaluate_result(self, input_):
        """Turn one scored comment into a list of aspect labels.

        Args:
            input_: Dict with an ``aspect_weight`` list, as built by
                ``get_test_data``.

        Returns:
            The label list that was appended to ``self.output``.
        """
        # The label is the second whitespace-separated field of each line.
        with open('./aspect_mapping.txt', 'r', encoding='utf8') as fp:
            aspect_label = [line.split()[1] for line in fp]

        top_score = self.select_top(input_['aspect_weight'])
        curr_pred = self.get_predict(top_score, aspect_label)

        aspect_key = [key for key, value in curr_pred.items() if int(value) == 1]
        return self.get_aspect(aspect_key)

    def get_aspect(self, pred, ignore='none'):
        """Append the predicted aspects, minus ``ignore``, to ``self.output``.

        Args:
            pred: Predicted aspect labels.
            ignore: Label to drop from the result.

        Returns:
            The appended list: the remaining labels, or ``['None']`` when
            nothing is left.
        """
        # Filter by name instead of dropping the first element, which lost a
        # real aspect whenever `ignore` was not predicted.
        aspects = [label for label in pred if label != ignore]
        result = aspects if aspects else ['None']
        self.output.append(result)
        return result

    def infer(self, text=''):
        """Preprocess ``text``, score it and record its aspects.

        Args:
            text: Raw comment text.

        Returns:
            The aspect list for this comment (also appended to
            ``self.output``).
        """
        self.args.task = 'sscl-infer'

        preprocessing = (
            remove_xem_them,
            remove_emojis,
            format_punctuation,
            remove_punctuation,
            clean_text,
            normalize_format,
            word_segment,
            remove_stopwords,
            format_price,
            format_price_v2,
        )
        for step in preprocessing:
            text = step(text)

        # Load the resources once per instance instead of on every call.
        if 'vocab2id' not in self.batch_data:
            self.load_vocabulary()
        if not self.base_models:
            self.load_models()

        self.build_batch(text)
        self.get_test_data()

        # Evaluate the comment just scored, not the first one ever seen.
        return self.get_evaluate_result(self.test_data[-1])
|
|
|
|
|
# Command-line style configuration; parsed from an empty list so that the
# defaults are always used when the script runs under Streamlit.
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='infer')
parser.add_argument('--smooth_factor', type=float, default=0.9)

# Module-level settings read by inferSSCL's methods.
device = 'cpu'
emb_size = 256

args = parser.parse_args(args=[])
model = inferSSCL(args)

# Display title for every aspect label the model can emit.
ASPECT_TITLES = {
    'do_an': ':blue[Đồ ăn]',
    'gia_ca': ':blue[Giá cả]',
    'khong_gian': ':blue[Không gian]',
    'phuc_vu': ':blue[Phục vụ]',
}

cmt = st.text_area('Nhập nhận xét của bạn vào đây:')
if cmt == '':
    st.title('Nội dung bình luận của bạn!')
else:
    model.infer(cmt)

    # Read the result of the call just made rather than the first one stored.
    outputs = model.output[-1]
    if outputs:
        for output in outputs:
            # Any label without a display title (including the 'None'
            # placeholder) is rendered as 'None'.
            st.title(ASPECT_TITLES.get(output, 'None'))
    else:
        st.title('None')
    # NOTE(review): the original indentation was lost; the divider is assumed
    # to close the result section — confirm against the intended layout.
    st.divider()