# -*- coding: utf-8 -*-
# @Time : 2022/4/21 5:30 PM
# @Author : JianingWang
# @File : global_pointer.py
"""GlobalPointer and Efficient GlobalPointer span-scoring heads.

``RawGlobalPointer`` wraps an arbitrary encoder and returns one
``seq_len x seq_len`` score matrix per entity type.  The ``*ForEffiGlobalPointer``
classes attach the "efficient" variant (one shared query/key projection plus
per-type start/end biases) to BERT, RoBERTa, RoFormer and Megatron-BERT
encoders and return a loss together with the top-scoring span candidates.
"""

from typing import Optional
import torch
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from torch.nn import BCEWithLogitsLoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.file_utils import ModelOutput
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
from roformer import RoFormerPreTrainedModel, RoFormerModel

# Large finite constant used as "infinity" when masking logits.
_MASK_VALUE = 1e12
# Maximum number of span candidates returned per entity type.
_MAX_TOP_K = 50


def _apply_rope(x, cos_pos, sin_pos):
    """Apply a rotary position embedding to ``x`` along its last dimension.

    The last dimension is treated as consecutive (even, odd) pairs; each pair
    ``(a, b)`` contributes ``(-b, a)`` to the rotated companion tensor, which is
    then mixed with ``x`` using the supplied cosine / sine tables.

    Args:
        x: tensor whose last dimension has even size.
        cos_pos: cosine table broadcastable to ``x``.
        sin_pos: sine table broadcastable to ``x``.

    Returns:
        ``x * cos_pos + rotate(x) * sin_pos`` with the same shape as ``x``.
    """
    rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1)
    rotated = rotated.reshape(x.shape)
    return x * cos_pos + rotated * sin_pos


def _sequence_masking(x, mask, value="-inf", axis=None):
    """Replace the positions of ``x`` where ``mask`` is 0 with ``value``.

    Args:
        x: tensor to mask.
        mask: 0/1 tensor whose first dimension lines up with ``x``'s first
            dimension, or ``None`` to return ``x`` untouched.
        value: fill value; the strings ``"-inf"`` / ``"inf"`` map to
            ``-1e12`` / ``1e12``.
        axis: dimension of ``x`` that the mask's second dimension refers to;
            must be greater than 0.

    Returns:
        ``x * mask + value * (1 - mask)`` after broadcasting ``mask``.

    Raises:
        AssertionError: if ``axis`` is ``None`` or not greater than 0.
    """
    if mask is None:
        return x
    if value == "-inf":
        value = -_MASK_VALUE
    elif value == "inf":
        value = _MASK_VALUE
    # Raised explicitly (rather than via ``assert``) so the check survives
    # ``python -O`` while keeping the exception type callers already see.
    if axis is None or axis <= 0:
        raise AssertionError("axis must be greater than 0")
    # Insert singleton dims so the mask's sequence dim lands on ``axis`` ...
    for _ in range(axis - 1):
        mask = torch.unsqueeze(mask, 1)
    # ... then pad trailing singleton dims until the ranks match.
    for _ in range(x.ndim - mask.ndim):
        mask = torch.unsqueeze(mask, mask.ndim)
    return x * mask + value * (1 - mask)


def _add_mask_tril(logits, mask):
    """Mask padded rows/columns and the strict lower triangle of ``logits``.

    Args:
        logits: tensor whose last two dimensions index (start, end) positions.
        mask: 0/1 padding mask; cast to ``logits.dtype`` when the dtypes differ.

    Returns:
        ``logits`` with padded positions and ``start > end`` cells pushed
        to about ``-1e12``.
    """
    if mask.dtype != logits.dtype:
        mask = mask.type(logits.dtype)
    logits = _sequence_masking(logits, mask, "-inf", logits.ndim - 2)
    logits = _sequence_masking(logits, mask, "-inf", logits.ndim - 1)
    # Exclude the strict lower triangle (span end before span start).
    lower = torch.tril(torch.ones_like(logits), diagonal=-1)
    return logits - lower * _MASK_VALUE


class RawGlobalPointer(nn.Module):
    """GlobalPointer head on top of an arbitrary encoder.

    Args:
        encoder: module called as ``encoder(input_ids, attention_mask,
            token_type_ids)`` whose first output is the last hidden state and
            which exposes ``config.hidden_size``.
        ent_type_size: number of entity types (one score matrix per type).
        inner_dim: per-type query/key dimension.
        RoPE: whether to apply rotary position embeddings to queries and keys.
    """

    def __init__(self, encoder, ent_type_size, inner_dim, RoPE=True):
        super().__init__()
        self.encoder = encoder
        self.ent_type_size = ent_type_size
        self.inner_dim = inner_dim
        self.hidden_size = encoder.config.hidden_size
        # One query and one key vector of size ``inner_dim`` per entity type.
        self.dense = nn.Linear(self.hidden_size, self.ent_type_size * self.inner_dim * 2)
        self.RoPE = RoPE

    def sinusoidal_position_embedding(self, batch_size, seq_len, output_dim):
        """Build a sinusoidal position table.

        Returns:
            Tensor of shape ``(batch_size, seq_len, output_dim)`` with sine
            values at even and cosine values at odd positions of the last
            dimension, placed on the device of the most recent ``forward``
            input (or of ``self.dense`` if ``forward`` has not run yet).
        """
        position_ids = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(-1)
        indices = torch.arange(0, output_dim // 2, dtype=torch.float)
        indices = torch.pow(10000, -2 * indices / output_dim)
        angles = position_ids * indices
        embeddings = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        embeddings = embeddings.repeat((batch_size, *([1] * len(embeddings.shape))))
        embeddings = torch.reshape(embeddings, (batch_size, seq_len, output_dim))
        # ``self.device`` is only assigned inside ``forward``; fall back to the
        # projection layer's device so this method also works on its own.
        device = getattr(self, "device", None)
        if device is None:
            device = self.dense.weight.device
        return embeddings.to(device)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Score every (start, end) token pair for every entity type.

        Returns:
            Tensor of shape ``(batch_size, ent_type_size, seq_len, seq_len)``;
            padded key positions and the strict lower triangle are pushed to a
            large negative value before the final ``1 / sqrt(inner_dim)``
            scaling.
        """
        self.device = input_ids.device

        context_outputs = self.encoder(input_ids, attention_mask, token_type_ids)
        # last_hidden_state: (batch_size, seq_len, hidden_size)
        last_hidden_state = context_outputs[0]
        batch_size, seq_len = last_hidden_state.size()[:2]

        outputs = self.dense(last_hidden_state)
        # Split into ``ent_type_size`` chunks of ``2 * inner_dim`` and stack
        # them on a new type axis: (batch, seq_len, ent_type_size, 2 * inner_dim).
        outputs = torch.split(outputs, self.inner_dim * 2, dim=-1)
        outputs = torch.stack(outputs, dim=-2)
        qw, kw = outputs[..., :self.inner_dim], outputs[..., self.inner_dim:]

        if self.RoPE:
            # pos_emb: (batch_size, seq_len, inner_dim)
            pos_emb = self.sinusoidal_position_embedding(batch_size, seq_len, self.inner_dim)
            cos_pos = pos_emb[..., None, 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos_emb[..., None, ::2].repeat_interleave(2, dim=-1)
            qw = _apply_rope(qw, cos_pos, sin_pos)
            kw = _apply_rope(kw, cos_pos, sin_pos)

        # logits: (batch_size, ent_type_size, seq_len, seq_len)
        logits = torch.einsum("bmhd,bnhd->bhmn", qw, kw)

        # Padding mask, broadcast over entity types and query positions.
        pad_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(
            batch_size, self.ent_type_size, seq_len, seq_len
        )
        logits = logits * pad_mask - (1 - pad_mask) * _MASK_VALUE

        # Exclude the strict lower triangle (span end before span start).
        lower = torch.tril(torch.ones_like(logits), -1)
        logits = logits - lower * _MASK_VALUE

        return logits / self.inner_dim ** 0.5


class SinusoidalPositionEmbedding(nn.Module):
    """Sin-cos position embedding.

    Args:
        output_dim: size of the embedding's last dimension.
        merge_mode: ``"add"`` returns ``inputs + emb``, ``"mul"`` returns
            ``inputs * (emb + 1)``, ``"zero"`` returns the embedding alone.
        custom_position_ids: when true, ``forward`` expects a pair
            ``(inputs, position_ids)`` instead of a single tensor.
    """

    def __init__(
            self, output_dim, merge_mode="add", custom_position_ids=False):
        super().__init__()
        self.output_dim = output_dim
        self.merge_mode = merge_mode
        self.custom_position_ids = custom_position_ids

    def forward(self, inputs):
        """Compute the position embedding and merge it according to ``merge_mode``.

        Args:
            inputs: a tensor whose second dimension is the sequence length, or
                an ``(inputs, position_ids)`` pair when ``custom_position_ids``
                is set (``position_ids`` must be 2-D).

        Raises:
            ValueError: if ``merge_mode`` is not ``"add"``, ``"mul"`` or ``"zero"``.
        """
        if self.custom_position_ids:
            # Unpack first: the pair itself has no ``.shape``.
            inputs, position_ids = inputs
            seq_len = inputs.shape[1]
            position_ids = position_ids.type(torch.float)
        else:
            seq_len = inputs.shape[1]
            position_ids = torch.arange(seq_len).type(torch.float)[None]

        # Build the frequencies next to ``position_ids`` so custom ids that
        # live on an accelerator do not trigger a device mismatch in einsum.
        indices = torch.arange(
            self.output_dim // 2, dtype=torch.float, device=position_ids.device
        )
        indices = torch.pow(10000.0, -2 * indices / self.output_dim)
        angles = torch.einsum("bn,d->bnd", position_ids, indices)
        embeddings = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        embeddings = torch.reshape(embeddings, (-1, seq_len, self.output_dim))

        if self.merge_mode == "add":
            return inputs + embeddings.to(inputs.device)
        if self.merge_mode == "mul":
            return inputs * (embeddings + 1.0).to(inputs.device)
        if self.merge_mode == "zero":
            return embeddings.to(inputs.device)
        raise ValueError(
            f"Unsupported merge_mode {self.merge_mode!r}; expected 'add', 'mul' or 'zero'"
        )


def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy over the last dimension.

    For 0/1 targets this evaluates
    ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))`` per row and returns
    the mean over all rows.

    Args:
        y_pred: raw scores (logits).
        y_true: 0/1 targets with the same shape as ``y_pred``.

    Returns:
        Scalar loss tensor.
    """
    y_pred = (1 - 2 * y_true) * y_pred  # flip the sign of positive-class scores
    y_pred_neg = y_pred - y_true * _MASK_VALUE  # hide positive classes
    y_pred_pos = y_pred - (1 - y_true) * _MASK_VALUE  # hide negative classes
    # The appended zero column supplies the "1 +" term inside each logsumexp.
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()


def multilabel_categorical_crossentropy2(y_pred, y_true):
    """Variant of ``multilabel_categorical_crossentropy`` masking with true ``-inf``.

    Entries are selected with ``y_true > 0`` (positive) and ``y_true < 1``
    (negative) and set to ``-inf`` instead of subtracting ``1e12``.

    Args:
        y_pred: raw scores (logits).
        y_true: targets with the same shape as ``y_pred``.

    Returns:
        Scalar loss tensor.
    """
    y_pred = (1 - 2 * y_true) * y_pred  # flip the sign of positive-class scores
    y_pred_neg = y_pred.clone()
    y_pred_pos = y_pred.clone()
    y_pred_neg[y_true > 0] -= float("inf")  # hide positive classes
    y_pred_pos[y_true < 1] -= float("inf")  # hide negative classes
    # The appended zero column supplies the "1 +" term inside each logsumexp.
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).mean()


@dataclass
class GlobalPointerOutput(ModelOutput):
    """Output of the ``*ForEffiGlobalPointer`` models."""

    # Training loss; ``None`` when no labels were passed to ``forward``.
    loss: Optional[torch.FloatTensor] = None
    # Highest masked sigmoid scores per (batch item, entity type).
    topk_probs: torch.FloatTensor = None
    # Indices of those scores in the flattened (start, end) score grid.
    topk_indices: torch.IntTensor = None


class _EffiGlobalPointerHeadMixin:
    """Efficient GlobalPointer head shared by all encoder-specific models.

    Listed last among the base classes so it never shadows anything provided
    by the pretrained-model base class.
    """

    def _init_effi_head(self, config):
        """Create the head's layers; ``config`` must provide ``ent_type_size``,
        ``inner_dim``, ``hidden_size`` and ``RoPE``."""
        self.ent_type_size = config.ent_type_size
        self.inner_dim = config.inner_dim
        self.hidden_size = config.hidden_size
        self.RoPE = config.RoPE
        # Shared query/key projection (queries and keys interleaved).
        self.dense_1 = nn.Linear(self.hidden_size, self.inner_dim * 2)
        # Per-type start/end biases.  In the original formulation dense_2 has
        # shape (inner_dim * 2, ent_type_size * 2).
        self.dense_2 = nn.Linear(self.hidden_size, self.ent_type_size * 2)

    def sequence_masking(self, x, mask, value="-inf", axis=None):
        """Replace positions of ``x`` where ``mask`` is 0 with ``value``.

        ``"-inf"`` / ``"inf"`` map to ``-1e12`` / ``1e12``; ``axis`` must be
        greater than 0, otherwise ``AssertionError`` is raised.
        """
        return _sequence_masking(x, mask, value, axis)

    def add_mask_tril(self, logits, mask):
        """Mask padded rows/columns and the strict lower triangle of ``logits``."""
        return _add_mask_tril(logits, mask)

    def _score_spans(self, last_hidden_state, attention_mask, labels=None):
        """Turn encoder states into a loss and top-k span candidates.

        Args:
            last_hidden_state: encoder output, ``[batch, seq_len, hidden]``.
            attention_mask: 0/1 padding mask, ``[batch, seq_len]``.
            labels: optional targets, reshaped to
                ``[batch * ent_type_size, seq_len * seq_len]``.

        Returns:
            ``GlobalPointerOutput`` whose top-k tensors hold at most
            ``min(50, seq_len * seq_len)`` entries per entity type.
        """
        batch_size = last_hidden_state.shape[0]

        outputs = self.dense_1(last_hidden_state)  # [batch, seq_len, 2 * inner_dim]
        # Even positions of the last dim are queries, odd positions are keys.
        qw, kw = outputs[..., ::2], outputs[..., 1::2]

        if self.RoPE:
            pos = SinusoidalPositionEmbedding(self.inner_dim, "zero")(outputs)
            # e.g. [0.34, 0.90] -> [0.34, 0.34, 0.90, 0.90]
            cos_pos = pos[..., 1::2].repeat_interleave(2, dim=-1)
            sin_pos = pos[..., ::2].repeat_interleave(2, dim=-1)
            qw = _apply_rope(qw, cos_pos, sin_pos)
            kw = _apply_rope(kw, cos_pos, sin_pos)

        logits = torch.einsum("bmd,bnd->bmn", qw, kw) / self.inner_dim ** 0.5
        bias = torch.einsum("bnh->bhn", self.dense_2(last_hidden_state)) / 2
        # ``logits[:, None]`` adds the entity-type axis; even bias channels are
        # broadcast over rows (per column) and odd ones over columns (per row).
        logits = logits[:, None] + bias[:, ::2, None] + bias[:, 1::2, :, None]

        # Valid cells: both tokens unpadded and start <= end (upper triangle).
        mask = torch.triu(attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1))

        loss = None
        if labels is not None:
            y_pred = logits - (1 - mask.unsqueeze(1)) * _MASK_VALUE
            # ``reshape`` also accepts non-contiguous label tensors.
            y_true = labels.reshape(batch_size * self.ent_type_size, -1)
            y_pred = y_pred.reshape(batch_size * self.ent_type_size, -1)
            loss = multilabel_categorical_crossentropy(y_pred, y_true)

        with torch.no_grad():
            prob = torch.sigmoid(logits) * mask.unsqueeze(1)
            flat_prob = prob.reshape(batch_size, self.ent_type_size, -1)
            # Short sequences have fewer than 50 cells; topk would raise.
            k = min(_MAX_TOP_K, flat_prob.size(-1))
            topk = torch.topk(flat_prob, k, dim=-1)

        return GlobalPointerOutput(
            loss=loss,
            topk_probs=topk.values,
            topk_indices=topk.indices
        )


class BertForEffiGlobalPointer(BertPreTrainedModel, _EffiGlobalPointerHeadMixin):
    """Efficient GlobalPointer on a BERT encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self._init_effi_head(config)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        """Encode the batch and score spans; ``short_labels`` is accepted but unused."""
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        return self._score_spans(context_outputs.last_hidden_state, attention_mask, labels)


class RobertaForEffiGlobalPointer(RobertaPreTrainedModel, _EffiGlobalPointerHeadMixin):
    """Efficient GlobalPointer on a RoBERTa encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self._init_effi_head(config)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        """Encode the batch and score spans; ``short_labels`` is accepted but unused."""
        context_outputs = self.roberta(input_ids, attention_mask, token_type_ids)
        return self._score_spans(context_outputs.last_hidden_state, attention_mask, labels)


class RoformerForEffiGlobalPointer(RoFormerPreTrainedModel, _EffiGlobalPointerHeadMixin):
    """Efficient GlobalPointer on a RoFormer encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.roformer = RoFormerModel(config)
        self._init_effi_head(config)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        """Encode the batch and score spans; ``short_labels`` is accepted but unused."""
        context_outputs = self.roformer(input_ids, attention_mask, token_type_ids)
        return self._score_spans(context_outputs.last_hidden_state, attention_mask, labels)


class MegatronForEffiGlobalPointer(MegatronBertPreTrainedModel, _EffiGlobalPointerHeadMixin):
    """Efficient GlobalPointer on a Megatron-BERT encoder."""

    def __init__(self, config):
        super().__init__(config)
        self.bert = MegatronBertModel(config)
        self._init_effi_head(config)

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None, short_labels=None):
        """Encode the batch and score spans; ``short_labels`` is accepted but unused."""
        context_outputs = self.bert(input_ids, attention_mask, token_type_ids)
        return self._score_spans(context_outputs.last_hidden_state, attention_mask, labels)