import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
import json
import numpy as np
import pandas as pd
from rouge import Rouge
import string
import copy
import re
from transformers import AutoModel, AutoTokenizer
from underthesea import sent_tokenize, word_tokenize
from sklearn.metrics.pairwise import cosine_similarity

# Load PhoBERT once; its 768-d hidden states are used to embed every sentence node.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
phobert = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device)
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")

def getRouge2(ref, pred, kind):
    # ROUGE-2 score ('p', 'r' or 'f') between a reference and a predicted summary;
    # the rouge package raises ValueError on empty input, which we score as 0.
    try:
        return round(Rouge().get_scores(pred.lower(), ref.lower())[0]['rouge-2'][kind], 4)
    except ValueError:
        return 0.0
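
# Illustrative usage (hypothetical, word-segmented strings): the third argument picks
# ROUGE-2 precision ('p'), recall ('r') or F1 ('f') from the `rouge` package output.
#   getRouge2("câu tóm_tắt tham_chiếu", "câu tóm_tắt dự_đoán", 'f')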

class MLP(nn.Module):
    # Simple feed-forward stack: Linear -> activation -> dropout for every hidden layer,
    # optionally leaving the last layer linear (keep_last_layer=True).
    def __init__(self, dims: list, layers=2, act=nn.LeakyReLU(), dropout_p=0.3, keep_last_layer=False):
        super(MLP, self).__init__()
        assert len(dims) == layers + 1
        self.layers = layers
        self.act = act
        self.dropout = nn.Dropout(dropout_p)
        self.keep_last = keep_last_layer

        self.mlp_layers = nn.ModuleList([])
        for i in range(self.layers):
            self.mlp_layers.append(nn.Linear(dims[i], dims[i + 1]))

    def forward(self, x):
        for i in range(len(self.mlp_layers) - 1):
            x = self.dropout(self.act(self.mlp_layers[i](x)))
        if self.keep_last:
            x = self.mlp_layers[-1](x)
        else:
            x = self.act(self.mlp_layers[-1](x))
        return x
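
# Illustrative shape check (not called anywhere in the pipeline): with
# dims=[768, 1024, 768] and layers=2, MLP maps (..., 768) -> (..., 768).
def _example_mlp_shapes():
    mlp = MLP([768, 1024, 768], layers=2)
    x = torch.randn(5, 768)
    out = mlp(x)
    assert out.shape == (5, 768)
    return out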

class GraphAttentionLayer(nn.Module):
    # Multi-head graph attention (GAT-style): every node attends to its neighbours as
    # given by adj_mat; head outputs are either concatenated or averaged.
    def __init__(self, in_features: int, out_features: int, n_heads: int,
                 is_concat: bool = True,
                 dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        self.attn = nn.Linear(self.n_hidden * 2, 1, bias=False)
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
        n_nodes = h.shape[0]
        # Project node features and build all (i, j) pairs for the attention scores.
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1, 1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim=-1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)

        assert adj_mat.shape[0] == 1 or adj_mat.shape[0] == n_nodes
        assert adj_mat.shape[1] == 1 or adj_mat.shape[1] == n_nodes
        assert adj_mat.shape[2] == 1 or adj_mat.shape[2] == self.n_heads

        # Mask out non-edges before the softmax, then aggregate neighbour features.
        e = e.masked_fill(adj_mat == 0, float(-1e9))
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)

        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            return attn_res.mean(dim=1)

class GAT(nn.Module):
    # Two-layer GAT: a multi-head layer followed by a single-head output layer.
    # Inputs carry a leading batch dimension of 1, which is squeezed internally.
    def __init__(self, in_features: int, n_hidden: int, n_classes: int, n_heads: int, dropout: float):
        super().__init__()
        self.layer1 = GraphAttentionLayer(in_features, n_hidden, n_heads, is_concat=True, dropout=dropout)
        self.activation = nn.ELU()
        self.output = GraphAttentionLayer(n_hidden, n_classes, 1, is_concat=False, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, adj_mat: torch.Tensor, docnum, secnum):
        x = x.squeeze(0)
        adj_mat = adj_mat.squeeze(0)
        # Nodes with no neighbours in this view of the graph are zeroed in the output.
        adj_x = adj_mat.clone().sum(dim=1, keepdim=True).repeat(1, x.shape[1]).bool()
        adj_mat = adj_mat.unsqueeze(-1).bool()
        x = self.dropout(x)
        x = self.layer1(x, adj_mat, docnum, secnum)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.output(x, adj_mat, docnum, secnum).masked_fill(adj_x == 0, float(0))
        return x.unsqueeze(0)
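
# Illustrative smoke test (assumptions: a tiny fully-connected graph with random
# features; not part of training or inference). GAT expects a batch dimension of 1
# and an adjacency matrix of shape (1, n_nodes, n_nodes).
def _example_gat_forward():
    n_nodes, in_dim = 4, 768
    gat = GAT(in_features=in_dim, n_hidden=1024, n_classes=in_dim, n_heads=4, dropout=0.1)
    x = torch.randn(1, n_nodes, in_dim)
    adj = torch.ones(1, n_nodes, n_nodes)
    out = gat(x, adj, docnum=1, secnum=1)
    assert out.shape == (1, n_nodes, in_dim)
    return out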

class StepWiseGraphConvLayer(nn.Module):
    # Step-wise propagation over the hierarchical graph. Node order is
    # [sentences..., sections..., documents..., root], so the negative slicing below picks
    # out each node type: first sentence-level GATs run, then section-level, then
    # document-level, and finally section/document context is broadcast back to sentences.
    def __init__(self, in_dim, hid_dim, dropout_p=0.3, act=nn.LeakyReLU(), nheads=6, iter=1, final="att"):
        super().__init__()
        self.act = act
        self.dropout = nn.Dropout(dropout_p)
        self.iter = iter
        self.in_dim = in_dim
        self.gat = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
                                      dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
        self.gat2 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
                                       dropout=dropout_p, n_heads=nheads) for _ in range(iter)])
        self.gat3 = nn.ModuleList([GAT(in_features=in_dim, n_hidden=hid_dim, n_classes=in_dim,
                                       dropout=dropout_p, n_heads=nheads) for _ in range(iter)])

        self.out_ffn = MLP([in_dim * 3, hid_dim, hid_dim, in_dim], layers=3, dropout_p=dropout_p)

    def forward(self, feature, adj, docnum, secnum):
        # Sentence-level view: drop every edge that touches a section, document or root node.
        sen_adj = adj.clone()
        sen_adj[:, -docnum-secnum-1:, :] = 0
        sen_adj[:, :, -docnum-secnum-1:] = 0
        # Section-level view: only section rows stay active, attending to sentences and sections.
        sec_adj = adj.clone()
        sec_adj[:, :-docnum-secnum-1, :] = 0
        sec_adj[:, -docnum-1:, :] = 0
        sec_adj[:, :, -docnum-1:] = 0
        # Document-level view: only document and root rows stay active.
        doc_adj = adj.clone()
        doc_adj[:, :-docnum-1, :] = 0

        feature_sen = feature.clone()
        feature_resi = feature

        feature_sen_re = feature_sen.clone()
        for i in range(0, self.iter):
            feature_sen = self.gat[i](feature_sen, sen_adj, docnum, secnum)
        feature_sen = F.layer_norm(feature_sen + feature_sen_re, [self.in_dim])

        feature_sec = feature_sen.clone()
        feature_sec_re = feature_sec.clone()
        for i in range(0, self.iter):
            feature_sec = self.gat2[i](feature_sec, sec_adj, docnum, secnum)
        feature_sec = F.layer_norm(feature_sec + feature_sec_re, [self.in_dim])

        feature_doc = feature_sec.clone()
        feature_doc_re = feature_doc.clone()
        for i in range(0, self.iter):
            feature_doc = self.gat3[i](feature_doc, doc_adj, docnum, secnum)
        feature_doc = F.layer_norm(feature_doc + feature_doc_re, [self.in_dim])

        # Broadcast section features to their sentences, and document features to their
        # sections and sentences, before fusing the three views through out_ffn.
        feature_sec[:, :-docnum-secnum-1, :] = adj[:, :-docnum-secnum-1, -docnum-secnum-1:-docnum-1] @ feature_sec[:, -docnum-secnum-1:-docnum-1, :]
        feature_doc[:, -docnum-secnum-1:-docnum-1, :] = adj[:, -docnum-secnum-1:-docnum-1, -docnum-1:] @ feature_doc[:, -docnum-1:, :]
        feature_doc[:, :-docnum-secnum-1, :] = adj[:, :-docnum-secnum-1, -docnum-secnum-1:-docnum-1] @ feature_doc[:, -docnum-secnum-1:-docnum-1, :]
        feature = torch.concat([feature_doc, feature_sec, feature_sen], dim=-1)
        feature = F.layer_norm(self.out_ffn(feature) + feature_resi, [self.in_dim])
        return feature

class Contrast_Encoder(nn.Module):
    # Graph encoder + projection head used for the contrastive branch; sentence nodes get
    # sinusoidal position encodings (restarting per document) before graph propagation.
    def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.3):
        super(Contrast_Encoder, self).__init__()
        self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
                                                    dropout_p=dropout_p, act=act, nheads=heads, iter=1)
        self.common_proj_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act, keep_last_layer=False)

    def forward(self, p_gfeature, doc_lens, p_adj, docnum, secnum):
        # Position encodings restart at 0 for each document; section/doc/root nodes get zeros.
        posVec = torch.cat([PositionVec[:l] for l in doc_lens] + [torch.zeros(secnum + docnum + 1, 768).float().to(device)], dim=0)
        p_gfeature = p_gfeature + posVec.unsqueeze(0)
        pg = self.graph_encoder(p_gfeature, p_adj, docnum, secnum)
        pg = self.common_proj_mlp(pg)
        return pg

class End2End_Encoder(nn.Module):
    # Graph encoder + scoring head: produces one salience logit per sentence node.
    def __init__(self, input_dim, hidden_dim, heads, act=nn.LeakyReLU(0.1), dropout_p=0.3):
        super(End2End_Encoder, self).__init__()
        self.graph_encoder = StepWiseGraphConvLayer(in_dim=input_dim, hid_dim=hidden_dim,
                                                    dropout_p=dropout_p, act=act, nheads=heads, iter=1)
        self.dropout = nn.Dropout(dropout_p)
        self.out_proj_layer_mlp = MLP([input_dim, hidden_dim, input_dim], layers=2, dropout_p=dropout_p, act=act, keep_last_layer=False)
        self.linear = MLP([input_dim, 1], layers=1, dropout_p=dropout_p, act=act, keep_last_layer=True)

    def forward(self, x, doc_lens, adj, docnum, secnum):
        x = self.graph_encoder(x, adj, docnum, secnum)
        # Keep only the sentence nodes; section, document and root nodes are dropped.
        x = x[:, :-docnum-secnum-1, :]
        x = self.out_proj_layer_mlp(x)
        x = self.linear(x)
        return x

def _similarity(h1: torch.Tensor, h2: torch.Tensor):
    # Pairwise cosine similarity between the rows of h1 and the rows of h2.
    h1 = F.normalize(h1)
    h2 = F.normalize(h2)
    return h1 @ h2.t()

class InfoNCE(nn.Module):
    # InfoNCE-style contrastive loss over a cosine-similarity matrix, with explicit
    # positive / negative pair masks.
    def __init__(self, tau):
        super(InfoNCE, self).__init__()
        self.tau = tau

    def forward(self, anchor, sample, pos_mask, neg_mask, *args, **kwargs):
        sim = _similarity(anchor, sample) / self.tau
        if len(anchor) > 1:
            # Collapse the anchor dimension by keeping the highest similarity per sample.
            sim, _ = torch.max(sim, dim=0)
        exp_sim = torch.exp(sim)
        loss = torch.log((exp_sim * pos_mask).sum(dim=1)) - torch.log((exp_sim * (pos_mask + neg_mask)).sum(dim=1))
        return -loss.mean()
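
# Illustrative sketch of the contrastive loss (toy tensors, not the training setup):
# anchor/sample rows are compared with cosine similarity, and pos/neg masks select
# which pairs count as positives and negatives.
def _example_infonce():
    loss_fn = InfoNCE(tau=0.5)
    anchor = torch.randn(3, 8)
    sample = torch.randn(4, 8)
    pos_mask = torch.zeros(3, 4)
    pos_mask[:, 0] = 1          # treat column 0 as the positive for every anchor
    neg_mask = 1 - pos_mask     # everything else is a negative
    return loss_fn(anchor, sample, pos_mask, neg_mask)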

class Cluster:
    # Holds one cluster of related documents as a hierarchical graph whose node order is
    # [sentences..., sections..., documents..., root]; `feature` stores the node vectors
    # and `adj` the adjacency matrix built from the doc/section/sentence membership masks.
    def __init__(self, sent_texts, sent_vecs, doc_lens, doc_sec_mask, sec_sen_mask):
        assert len(sent_vecs) == len(sent_texts)
        self.docnum = len(doc_sec_mask)
        self.secnum = len(sec_sen_mask)
        self.feature = torch.cat((torch.stack(sent_vecs, dim=0), torch.zeros((self.secnum+self.docnum+1, sent_vecs[0].shape[0]))), dim=0).to(device)
        self.adj = torch.from_numpy(self.mask_to_adj(doc_sec_mask, sec_sen_mask)).float().to(device)
        self.sent_text = np.array(sent_texts)
        self.doc_lens = doc_lens
        self.init_node_vec()
        self.feature = self.feature.float()

    def init_node_vec(self):
        # Section nodes are initialised as the mean of their sentence vectors, document
        # nodes as the mean of their section vectors, and the root as the mean of the docs.
        docnum, secnum = self.docnum, self.secnum
        for i in range(-secnum-docnum-1, -docnum-1):
            mask = self.adj[i].clone()
            mask[-secnum-docnum-1:] = 0
            self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
        for i in range(-docnum-1, -1):
            mask = self.adj[i].clone()
            mask[-docnum-1:] = 0
            self.feature[i] = torch.mean(self.feature[mask.bool()], dim=0)
        self.feature[-1] = torch.mean(self.feature[-docnum-1:-1], dim=0)

    def mask_to_adj(self, doc_sec_mask, sec_sen_mask):
        # Build the symmetric adjacency matrix: sentence<->section, section<->document,
        # section<->section within the same document, sentence<->sentence within the same
        # section, and a fully connected block over the documents and the root node.
        sen_num = sec_sen_mask.shape[1]
        sec_num = sec_sen_mask.shape[0]
        doc_num = doc_sec_mask.shape[0]
        adj = np.zeros((sen_num+sec_num+doc_num+1, sen_num+sec_num+doc_num+1))

        adj[-sec_num-doc_num-1:-doc_num-1, 0:-sec_num-doc_num-1] = sec_sen_mask
        adj[0:-sec_num-doc_num-1, -sec_num-doc_num-1:-doc_num-1] = sec_sen_mask.T
        for i in range(0, doc_num):
            doc_mask = doc_sec_mask[i]
            doc_mask = doc_mask.reshape((1, len(doc_mask)))
            adj[sen_num:-doc_num-1, sen_num:-doc_num-1] += doc_mask * doc_mask.T

        adj[-doc_num-1:-1, -sec_num-doc_num-1:-doc_num-1] = doc_sec_mask
        adj[-sec_num-doc_num-1:-doc_num-1, -doc_num-1:-1] = doc_sec_mask.T
        adj[-doc_num-1:, -doc_num-1:] = 1

        for i in range(0, sec_num):
            sec_mask = sec_sen_mask[i]
            sec_mask = sec_mask.reshape((1, len(sec_mask)))
            adj[:sen_num, :sen_num] += sec_mask * sec_mask.T
        return adj
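
# Illustrative sketch (toy data, not real documents): a cluster with 2 documents,
# 3 sections and 5 sentences. The graph has 5 + 3 + 2 + 1 nodes in the order
# [sentences..., sections..., documents..., root], which is what the negative slicing
# throughout the model relies on.
def _example_cluster_graph():
    sent_texts = ['s0', 's1', 's2', 's3', 's4']
    sent_vecs = [torch.randn(768) for _ in sent_texts]
    doc_lens = [3, 2]
    doc_sec_mask = np.array([[1, 1, 0],       # doc 0 owns sections 0 and 1
                             [0, 0, 1]])      # doc 1 owns section 2
    sec_sen_mask = np.array([[1, 1, 0, 0, 0],    # section 0 -> sentences 0, 1
                             [0, 0, 1, 0, 0],    # section 1 -> sentence 2
                             [0, 0, 0, 1, 1]])   # section 2 -> sentences 3, 4
    cluster = Cluster(sent_texts, sent_vecs, doc_lens, doc_sec_mask, sec_sen_mask)
    assert cluster.adj.shape == (11, 11)
    return cluster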

def meanTokenVecs(text):
    # Embed a word-segmented sentence with PhoBERT and average the sub-word vectors of
    # each word (BPE pieces ending in '@@' are merged back into whole words), skipping
    # punctuation. Returns a single 768-d vector on the CPU.
    sent = text.lower()
    # The model may live on the GPU, so the input ids must be moved to the same device.
    input_ids = torch.tensor([tokenizer.encode(sent)]).to(device)
    tokenized_text = tokenizer.tokenize(sent)
    with torch.no_grad():
        features = phobert(input_ids)
    wordVecs, buffer, buffer_str = {}, [], ''
    for token in zip(tokenized_text, features.last_hidden_state[0, 1:-1, :]):
        if token[0][-2:] == '@@':
            buffer.append(token[1])
            buffer_str += token[0][:-2]
            continue
        if buffer:
            buffer.append(token[1])
            buffer_str += token[0]
            wordVecs[buffer_str] = torch.mean(torch.stack(buffer), dim=0)
            buffer, buffer_str = [], ''
        else:
            wordVecs[token[0]] = token[1]

    return torch.mean(torch.stack([vec for w, vec in wordVecs.items() if w not in string.punctuation]), dim=0).cpu()
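
# Illustrative usage (the string is a hypothetical word-segmented Vietnamese sentence,
# as produced by underthesea.word_tokenize(..., format="text")):
def _example_sentence_vector():
    vec = meanTokenVecs("hôm_nay trời đẹp")
    assert vec.shape == (768,)
    return vec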

def getPositionEncoding(pos, d=768, n=10000):
    # Standard sinusoidal position encoding: sin on even indices, cos on odd indices.
    P = np.zeros(d)
    for i in np.arange(int(d/2)):
        denominator = np.power(n, 2*i/d)
        P[2*i] = np.sin(pos/denominator)
        P[2*i+1] = np.cos(pos/denominator)
    return P
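
# Illustrative check: with the sinusoidal scheme above, position 0 encodes to
# [0, 1, 0, 1, ...] since sin(0) = 0 on even indices and cos(0) = 1 on odd indices.
def _example_position_encoding():
    p0 = getPositionEncoding(0)
    assert np.allclose(p0[0::2], 0.0) and np.allclose(p0[1::2], 1.0)
    return p0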

# Pre-computed position encodings for up to 200 sentences per document.
PositionVec = torch.stack([torch.from_numpy(getPositionEncoding(i, d=768)) for i in range(200)], dim=0).float().to(device)

# Vietnamese stop words (underscore-segmented) plus punctuation, used to clean text
# before the LDA section assignment.
stop_w = ['...']
with open('./vietnamese-stopwords-dash.txt', 'r', encoding='utf-8') as f:
    for w in f.readlines():
        stop_w.append(w.strip())
stop_w.extend([c for c in '!"#$%&\'()*+,./:;<=>?@[\\]^`{|}~…“”’‘'])

# One (vectoriser, LDA model) pair per category, loaded from a pre-trained pickle.
with open('./LDA_models.pkl', mode='rb') as fp:
    cate_models = pickle.load(fp)

def removeRedundant(text):
    text = text.lower()
    words = [w for w in text.split(' ') if w not in stop_w]
    return ' '.join(words)

def divideSection(doc_text, category):
    # Split a document into sentences and assign each sentence to a section, where a
    # section is the dominant LDA topic of the paragraph the sentence belongs to.
    # Returns (sentence index -> section id, list of word-segmented sentences).
    assert category in cate_models, 'invalid category'
    sent_para, para_sec, sent_sec = {}, {}, {}

    paras = [para for para in doc_text.split('\n') if para != '']
    all_sents = []

    sentcnt = 0
    for i, para in enumerate(paras):
        sents = [word_tokenize(sent, format="text") for sent in sent_tokenize(para) if sent != '']
        all_sents.extend(sents)
        for ii, sent in enumerate(sents):
            sent_para[sentcnt + ii] = i
        sentcnt += len(sents)

    # Assign every paragraph to its most probable LDA topic for this category.
    paras = [removeRedundant(para) for para in paras]
    tf, lda_model = cate_models[category]
    X = tf.transform(paras)
    lda_top = lda_model.transform(X)
    for i, para_top in enumerate(lda_top):
        para_sec[i] = para_top.argmax()

    for k, v in sent_para.items():
        sent_sec[k] = para_sec[v]
    return sent_sec, all_sents

def loadClusterData(docs_org, category):
    # Build a Cluster from raw documents: split each document into sections and
    # sentences, embed every sentence with PhoBERT, and construct the membership masks.
    seclist, docs = {}, []
    for d, doc in enumerate(docs_org):
        seclist[d], sentTexts = divideSection(doc, category)
        docs.append(sentTexts)

    # Re-number the per-document LDA topics so that section ids are globally unique.
    secnum = 0
    for k, val_dict in seclist.items():
        vals = set(val_dict.values())
        for ki, vi in val_dict.items():
            for i, v in enumerate(vals):
                if vi == v:
                    val_dict[ki] = i + secnum
                    break
        seclist[k] = val_dict
        secnum += len(vals)

    sents, sentVecs, secIDs, doc_lens = [], [], [], []
    sentnum = sum([len(doc.values()) for doc in seclist.values()])
    doc_sec_mask = np.zeros((len(docs), secnum))
    sec_sen_mask = np.zeros((secnum, sentnum))
    cursec, cursent = 0, 0

    for d, doc in enumerate(docs):
        doc_lens.append(len(doc))
        doc_endsec = max(seclist[d].values())
        doc_sec_mask[d][cursec:doc_endsec + 1] = 1
        cursec = doc_endsec + 1
        for s, sent in enumerate(doc):
            sents.append(sent)
            sentVecs.append(meanTokenVecs(sent))
            sec_sen_mask[seclist[d][s], cursent] = 1
            cursent += 1

    return Cluster(sents, sentVecs, doc_lens, doc_sec_mask, sec_sen_mask)

def val_e2e(data, model, max_word_num=200, c_model=None):
    # Run the full inference pass on one cluster: contrastive encoder -> end-to-end
    # encoder -> per-sentence scores -> greedy summary selection.
    model.eval()
    c_model.eval()

    feature = data.feature.unsqueeze(0)
    doc_lens = data.doc_lens
    adj = data.adj.unsqueeze(0)
    docnum = data.docnum
    secnum = data.secnum

    with torch.no_grad():
        feature = c_model(feature, doc_lens, adj, docnum, secnum)
        x = model(feature, doc_lens, adj, docnum, secnum)
        scores = torch.sigmoid(x.squeeze(-1))

    return get_summary(scores[0], data.sent_text, max_word_num)

def get_summary(scores, sents, max_word_num=200):
    # Greedy selection: take sentences in decreasing score order, skip near-duplicates
    # (ROUGE-2 precision >= 0.45 against an already chosen sentence), and stop once the
    # word budget is reached; the summary keeps the original sentence order.
    ranked_score_idxs = torch.argsort(scores, dim=0, descending=True)
    wordCnt = 0
    summSentIDList = []
    for i in ranked_score_idxs:
        if wordCnt >= max_word_num:
            break
        i = i.item()
        s = sents[i]

        replicated = False
        for chosedID in summSentIDList:
            if getRouge2(sents[chosedID], s, 'p') >= 0.45:
                replicated = True
                break
        if replicated:
            continue

        wordCnt += len(s.split(' '))
        summSentIDList.append(i)
    summSentIDList = sorted(summSentIDList)
    return ' '.join([s for i, s in enumerate(sents) if i in summSentIDList])

# Build the two encoders and load the pre-trained checkpoints.
c_model = Contrast_Encoder(768, 1024, 4, dropout_p=0.3).to(device)
model = End2End_Encoder(768, 1024, 4, dropout_p=0.3).to(device)
model.load_state_dict(torch.load('./e_25_0.3071.mdl', map_location=device), strict=False)
c_model.load_state_dict(torch.load('./c_25_0.3071.mdl', map_location=device), strict=False)

def infer(docs, category):
    # End-to-end entry point: clean the input documents, build the cluster graph, run the
    # model, then undo word segmentation ('_' -> ' ') and strip spaces before punctuation.
    docs = [text.strip() for text in docs]
    data_tree = loadClusterData(docs, category)
    summ = val_e2e(data_tree, model, c_model=c_model, max_word_num=200)
    summ = re.sub(r'\s+([.,;:"?()/!?])', r'\1', summ.replace('_', ' '))
    return summ, docs
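
# Illustrative end-to-end call (hypothetical inputs): `docs` is a list of raw article
# texts about the same event, and `category` must be one of the keys in LDA_models.pkl
# (the exact key names depend on how those LDA models were exported).
#
#   docs = [open(p, encoding='utf-8').read() for p in ['doc1.txt', 'doc2.txt']]
#   summary, cleaned_docs = infer(docs, category='some_category_key')
#   print(summary)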