import argparse
import os
import re

import numpy as np
import streamlit as st
import torch

from data_preprocessing import (remove_xem_them, remove_emojis,
                                remove_stopwords, format_punctuation,
                                remove_punctuation, clean_text,
                                normalize_format, word_segment,
                                format_price, format_price_v2)


class inferSSCL():

    def __init__(self, args=None):
        self.args = args
        self.base_models = {}
        self.batch_data = {}
        self.test_data = []

    def load_vocab_pretrain(self, file_pretrain_vocab, file_pretrain_vec, pad_tokens=True):
        # Index 0 is reserved for the padding token ''.
        vocab2id = {'': 0}
        id2vocab = {0: ''}
        cnt = len(id2vocab)
        with open(file_pretrain_vocab, 'r', encoding='utf-8') as fp:
            for line in fp:
                arr = line.strip().split(' ')
                vocab2id[arr[1]] = cnt
                id2vocab[cnt] = arr[1]
                cnt += 1
        # Prepend a zero vector for the padding token to the pretrained embeddings.
        pretrain_vec = np.load(file_pretrain_vec)
        pad_vec = np.zeros([1, pretrain_vec.shape[1]])
        pretrain_vec = np.vstack((pad_vec, pretrain_vec))
        return vocab2id, id2vocab, pretrain_vec

    def load_vocabulary(self):
        cluster_dir = './'
        file_wordvec = 'vectors.npy'
        file_vocab = 'vocab.txt'
        file_kmeans_centroid = 'aspect_centroid.txt'
        file_aspect_mapping = 'aspect_mapping.txt'

        vocab2id, id2vocab, pretrain_vec = self.load_vocab_pretrain(
            os.path.join(cluster_dir, file_vocab),
            os.path.join(cluster_dir, file_wordvec))
        self.batch_data['vocab2id'] = vocab2id
        self.batch_data['id2vocab'] = id2vocab
        self.batch_data['pretrain_emb'] = pretrain_vec
        self.batch_data['vocab_size'] = len(vocab2id)

        # Zero out the centroids of clusters mapped to "none" so they cannot
        # contribute to the aspect predictions.
        aspect_vec = np.loadtxt(os.path.join(cluster_dir, file_kmeans_centroid), dtype=float)
        tmp = []
        with open(os.path.join(cluster_dir, file_aspect_mapping), 'r') as fp:
            for line in fp:
                line = re.sub(r'[0-9]+', '', line)
                line = line.replace(' ', '').replace('\n', '')
                if line == 'none':
                    tmp.append([0.] * 256)
                else:
                    tmp.append([1.] * 256)
        aspect_vec = aspect_vec * np.array(tmp)
        aspect_vec = torch.FloatTensor(aspect_vec).to(device)
        self.batch_data['aspect_centroid'] = aspect_vec
        self.batch_data['n_aspects'] = aspect_vec.shape[0]
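    # A sketch of the on-disk artifacts load_vocabulary() expects. The file
    # names come from this script; the line layouts are inferred from the
    # parsing code above, so treat them as assumptions:
    #   vocab.txt            one "<index> <word>" pair per line
    #   vectors.npy          pretrained word vectors, shape [vocab_size, 256]
    #   aspect_centroid.txt  one 256-dim k-means centroid per row
    #   aspect_mapping.txt   one "<cluster_id> <aspect_name>" pair per line;
    #                        rows named "none" are zero-masked above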
    def load_models(self):
        self.base_models['embedding'] = torch.nn.Embedding(
            self.batch_data['vocab_size'], emb_size).to(device)
        emb_para = torch.FloatTensor(self.batch_data['pretrain_emb']).to(device)
        self.base_models['embedding'].weight = torch.nn.Parameter(emb_para)

        self.base_models['asp_weight'] = torch.nn.Linear(
            emb_size, self.batch_data['n_aspects']).to(device)
        self.base_models['asp_weight'].load_state_dict(
            torch.load('./asp_weight.model', map_location=device))

        self.base_models['attn_kernel'] = torch.nn.Linear(emb_size, emb_size).to(device)
        self.base_models['attn_kernel'].load_state_dict(
            torch.load('./attn_kernel.model', map_location=device), strict=False)

    def build_pipe(self):
        attn_pos, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask'])
        # Per-token attention weights, truncated to the comment length
        # (kept for inspection; not used downstream).
        outw = np.around(attn_pos.data.cpu().numpy(), 4).tolist()
        outw = outw[0][:len(self.batch_data['comment'][0].split())]
        # Map the sentence embedding to a distribution over aspects.
        asp_weight = self.base_models['asp_weight'](lbl_pos)
        asp_weight = torch.softmax(asp_weight, dim=1)
        return asp_weight

    def encoder(self, input_, mask_):
        with torch.no_grad():
            emb_ = self.base_models['embedding'](input_)
            emb_ = emb_ * mask_.unsqueeze(2)
            # Mean-pool the non-pad token embeddings into a query vector.
            emb_avg = torch.sum(emb_, dim=1)
            norm = torch.sum(mask_, dim=1, keepdim=True) + 1e-20
            enc_ = emb_avg.div(norm.expand_as(emb_avg))
            # W_e x + b_e
            emb_trn = self.base_models['attn_kernel'](emb_)
            # Alignment scores: query . (W_e x + b_e)
            attn_ = enc_.unsqueeze(1) @ emb_trn.transpose(1, 2)
            attn_ = attn_.squeeze(1)
            attn_ = self.args.smooth_factor * torch.tanh(attn_)
            attn_ = attn_.masked_fill(mask_ == 0, -1e20)
            # Attention weights over tokens.
            attn_ = torch.softmax(attn_, dim=1)
            # Attention-weighted sentence embedding.
            lbl_ = attn_.unsqueeze(1) @ emb_
            lbl_ = lbl_.squeeze(1)
        return attn_, lbl_

    def build_batch(self, review):
        vocab2id = self.batch_data['vocab2id']
        # Map in-vocabulary tokens to ids; cap the sequence length at emb_size.
        senid = [vocab2id[wd] for wd in review.split() if wd in vocab2id]
        sen_text = [senid]
        cmt = [review]
        sen_text_len = min(len(senid), emb_size)
        # Truncate or pad every sentence to sen_text_len with the pad id.
        sen_text = [itm[:sen_text_len] + [vocab2id['']] * (sen_text_len - len(itm))
                    for itm in sen_text]
        sen_text_var = torch.LongTensor(sen_text).to(device)
        # Binary mask: 1 for real tokens, 0 for padding.
        sen_pad_mask = (sen_text_var != vocab2id['']).long()
        self.batch_data['comment'] = cmt
        self.batch_data['pos_sen_var'] = sen_text_var
        self.batch_data['pos_pad_mask'] = sen_pad_mask

    def calculate_atten_weight(self):
        attn_pos, lbl_pos = self.encoder(
            self.batch_data['pos_sen_var'],
            self.batch_data['pos_pad_mask'])
        asp_weight = self.base_models['asp_weight'](lbl_pos)
        asp_weight = torch.softmax(asp_weight, dim=1)
        return asp_weight

    def get_test_data(self):
        asp_weight = self.calculate_atten_weight()
        asp_weight = asp_weight.data.cpu().numpy().tolist()
        output = {
            'comment': self.batch_data['comment'],
            'aspect_weight': asp_weight[0],
        }
        self.test_data.append(output)

    def select_top(self, data):
        # Modified z-score: deviation from the median, scaled by the median
        # absolute deviation (MAD).
        data = np.asarray(data)
        d = np.abs(data - np.median(data))
        mdev = np.median(d)
        return d / mdev if mdev else np.zeros_like(d)
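    # A worked toy example for select_top() (hypothetical numbers):
    #   data = [0.05, 0.05, 0.70, 0.10, 0.10]
    #   median(data) = 0.10  ->  d = [0.05, 0.05, 0.60, 0.00, 0.00]
    #   MAD = median(d) = 0.05  ->  scores = [1, 1, 12, 0, 0]
    # With get_predict()'s default threshold of 1, only the aspect at index 2
    # (score 12 > 1) is flagged.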
    def get_predict(self, top_pred, aspect_label, threshold=1):
        pred = {'none': 0, 'do_an': 0, 'gia_ca': 0, 'khong_gian': 0, 'phuc_vu': 0}
        for i in range(len(top_pred)):
            if top_pred[i] > threshold and aspect_label[i] in pred:
                pred[aspect_label[i]] = 1
        return pred

    def get_evaluate_result(self, input_):
        aspect_label = []
        with open('./aspect_mapping.txt', 'r', encoding='utf8') as fp:
            for line in fp:
                aspect_label.append(line.split()[1])
        top_score = self.select_top(input_['aspect_weight'])
        curr_pred = self.get_predict(top_score, aspect_label)
        aspect_key = [key for key, value in curr_pred.items() if value == 1]
        return self.get_aspect(aspect_key)

    def get_aspect(self, pred, ignore='none'):
        # Drop the 'none' bucket; report ['None'] when nothing else remains.
        aspects = [p for p in pred if p != ignore]
        return aspects if aspects else ['None']

    def infer(self, text=''):
        self.args.task = 'sscl-infer'
        # Normalization pipeline for raw Vietnamese review text.
        text = remove_xem_them(text)
        text = remove_emojis(text)
        text = format_punctuation(text)
        text = remove_punctuation(text)
        text = clean_text(text)
        text = normalize_format(text)
        text = word_segment(text)
        text = remove_stopwords(text)
        text = format_price(text)
        input_ = format_price_v2(text)

        self.load_vocabulary()
        self.load_models()
        self.build_batch(input_)
        self.get_test_data()
        return self.get_evaluate_result(self.test_data[0])


parser = argparse.ArgumentParser()
parser.add_argument('--task', default='infer')
parser.add_argument('--smooth_factor', type=float, default=0.9)
args = parser.parse_args(args=[])

device = 'cpu'
emb_size = 256

model = inferSSCL(args)
cmt = st.text_area('Enter some text:')
if cmt:
    output = model.infer(cmt)
    if output:
        st.title(', '.join(output))
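# To run this app locally (a sketch; "app.py" stands in for this file's
# actual name), with vocab.txt, vectors.npy, aspect_centroid.txt,
# aspect_mapping.txt, asp_weight.model and attn_kernel.model in the working
# directory:
#
#   streamlit run app.py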