# -*- coding: utf-8 -*-
# @Time    : 2022/4/21 5:30 PM
# @Author  : JianingWang
# @File    : fusion_siamese.py
"""BERT classifiers that fuse mean-pooled token-span features.

* ``BertForFusionSiamese`` concatenates a pooled first-token vector with the
  means of two token spans (both spans share one projection) and classifies.
* ``BertForWSC`` concatenates the means of two token spans and classifies.
"""
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import MegatronBertModel, MegatronBertPreTrainedModel
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.bert import BertModel, BertPreTrainedModel

from loss.focal_loss import FocalLoss


class BertPooler(nn.Module):
    """Pool a sequence by projecting its first token through dense + activation.

    The activation is looked up by name in ``ACT2FN`` (the models below pass
    ``config.hidden_act``).
    """

    def __init__(self, hidden_size, hidden_act):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = ACT2FN[hidden_act]

    def forward(self, features):
        """Return ``activation(dense(features[:, 0, :]))``.

        Args:
            features: hidden states indexed as [batch, seq_len, hidden_size].

        Returns:
            Tensor of shape [batch, hidden_size].
        """
        x = features[:, 0, :]  # first token (equivalent to [CLS])
        x = self.dense(x)
        return self.activation(x)


def _encode(encoder, **inputs):
    """Call ``encoder`` with only the non-None keyword inputs and ``return_dict=True``.

    Forcing ``return_dict=True`` means the result can always be read by field
    name (``last_hidden_state``, ``hidden_states``, ``attentions``). The
    caller's own ``return_dict`` only controls the format of the final output
    (see ``_pack_output``).
    """
    inputs = {k: v for k, v in inputs.items() if v is not None}
    inputs["return_dict"] = True
    return encoder(**inputs)


def _mean_pool_spans(sequence_output, spans, start_offset=0):
    """Mean-pool two token spans for every example in the batch.

    Args:
        sequence_output: encoder hidden states indexed as [batch, seq_len, dim].
        spans: indexable per example; ``spans[i]`` unpacks into
            ``(seg1_start, seg1_end, seg2_start, seg2_end)``. Ends are exclusive
            (standard slice semantics). NOTE(review): presumably a LongTensor of
            shape [batch, 4] or a list of 4-tuples; confirm against the collator.
        start_offset: added to both start indices before slicing.

    Returns:
        ``(seg1_embeddings, seg2_embeddings)``, each of shape [batch, dim].
        An empty span (``start + start_offset >= end``) yields a NaN row,
        because ``torch.mean`` of an empty slice is NaN.
    """
    seg1_embeddings, seg2_embeddings = [], []
    for ei, sentence_embeddings in enumerate(sequence_output):  # [seq_len, dim]
        seg1_start, seg1_end, seg2_start, seg2_end = spans[ei]
        seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start + start_offset: seg1_end], 0))
        seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start + start_offset: seg2_end], 0))
    return torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)


def _classification_loss(logits, labels, num_labels, pseudo_label=None):
    """Cross-entropy loss, optionally mixing gold and pseudo-labelled examples.

    Args:
        logits: classifier outputs; viewed as [-1, num_labels].
        labels: integer class ids; viewed as [-1].
        num_labels: number of classes.
        pseudo_label: optional per-example score aligned with ``labels``.
            Examples with ``pseudo_label > 0.9`` form the "train" subset and
            are weighted 0.9. Examples with ``pseudo_label < 0.1`` form the
            "pseudo" subset and are weighted 0.1. Examples in between are
            ignored.

    Returns:
        A scalar tensor. If both subsets are empty, this is a zero still
        attached to the autograd graph, so ``loss.backward()`` keeps working
        (the original returned the Python float ``0.0`` here).
    """
    loss_fct = CrossEntropyLoss()
    if pseudo_label is None:
        return loss_fct(logits.view(-1, num_labels), labels.view(-1))

    train_mask, pseudo_mask = pseudo_label > 0.9, pseudo_label < 0.1
    train_logits, pseudo_logits = logits[train_mask], logits[pseudo_mask]
    train_labels, pseudo_labels = labels[train_mask], labels[pseudo_mask]
    # An empty subset contributes 0 instead of a NaN cross-entropy.
    train_loss = loss_fct(train_logits.view(-1, num_labels), train_labels.view(-1)) if train_labels.nelement() else 0
    pseudo_loss = loss_fct(pseudo_logits.view(-1, num_labels), pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
    loss = 0.9 * train_loss + 0.1 * pseudo_loss
    if not torch.is_tensor(loss):
        loss = logits.sum() * 0.0
    return loss


def _pack_output(loss, logits, outputs, return_dict):
    """Build the model output in the format requested by ``return_dict``.

    If ``return_dict`` is true, returns a ``SequenceClassifierOutput``.
    Otherwise returns a tuple ``(loss, logits, *outputs[2:])``, the usual
    transformers convention, with ``loss`` omitted when it is None.
    ``outputs`` must be a ``ModelOutput`` (see ``_encode``).
    """
    if return_dict:
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    output = (logits,) + tuple(outputs[2:])
    return ((loss,) + output) if loss is not None else output


class BertForFusionSiamese(BertPreTrainedModel):
    """BERT classifier over [pooled first token ; span-1 mean ; span-2 mean].

    Both spans go through the same ``dense_1`` projection and activation
    (shared weights), so the classifier input has ``3 * hidden_size`` features.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act
        # Submodule attribute names (including the "bert_poor" spelling) become
        # state_dict keys; renaming them breaks loading existing checkpoints.
        self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
        self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
        # NOTE(review): dense_2 is never used in forward(); presumably kept for
        # checkpoint compatibility, so verify saved weights before removing it.
        self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
        cls_dropout_rate = getattr(config, "cls_dropout_rate", config.hidden_dropout_prob)
        self.dropout = nn.Dropout(cls_dropout_rate)
        self.classifier = nn.Linear(3 * self.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pseudo_label=None,
        segment_spans=None,
        pseuso_proba=None,
    ):
        """Classify each example from its first token and two mean-pooled spans.

        Args:
            segment_spans: required. Per example,
                ``(seg1_start, seg1_end, seg2_start, seg2_end)`` token indices,
                ends exclusive.
            labels: optional class ids. When given, the loss is computed.
            pseudo_label: optional per-example score that switches the loss to
                the weighted gold/pseudo mix described in ``_classification_loss``.
            pseuso_proba: accepted for caller compatibility; unused.
            return_dict: ``SequenceClassifierOutput`` if true, else a tuple.
                Defaults to ``config.use_return_dict``.

        Raises:
            ValueError: if ``segment_spans`` is None. The classifier expects
                ``3 * hidden_size`` features, so span pooling cannot be skipped.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if segment_spans is None:
            raise ValueError(
                "BertForFusionSiamese requires `segment_spans`: its classifier expects "
                "3 * hidden_size features (pooled first token + two pooled segments)."
            )

        outputs = _encode(
            self.bert,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state  # [bz, seq_len, dim]
        cls_output = self.bert_poor(sequence_output)  # [bz, dim]

        seg1_embeddings, seg2_embeddings = _mean_pool_spans(sequence_output, segment_spans)  # [bz, dim] each
        # Shared projection for both segments (siamese weights).
        seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
        seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))

        features = torch.cat([cls_output, seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 3*dim]
        logits = self.classifier(self.dropout(features))

        loss = None
        if labels is not None:
            loss = _classification_loss(logits, labels, self.num_labels, pseudo_label)
        return _pack_output(loss, logits, outputs, return_dict)


class BertForWSC(BertPreTrainedModel):
    """BERT classifier over [span-1 mean ; span-2 mean] (``2 * hidden_size`` features)."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.hidden_size = config.hidden_size
        self.hidden_act = config.hidden_act
        # NOTE(review): bert_poor, dense_1 and dense_2 are not used in forward();
        # they still appear in state_dict, so check checkpoints before removing.
        self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
        self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
        cls_dropout_rate = getattr(config, "cls_dropout_rate", config.hidden_dropout_prob)
        self.dropout = nn.Dropout(cls_dropout_rate)
        self.classifier = nn.Linear(2 * self.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        pseudo_label=None,
        span=None,
        pseuso_proba=None,
    ):
        """Classify each example from the means of two token spans.

        Args:
            span: required. Per example,
                ``(seg1_start, seg1_end, seg2_start, seg2_end)`` token indices.
                Pooling covers ``[start + 1, end)`` for each span.
            labels: optional class ids. When given, the loss is computed.
            pseudo_label: optional per-example score that switches the loss to
                the weighted gold/pseudo mix described in ``_classification_loss``.
            pseuso_proba: accepted for caller compatibility; unused.
            return_dict: ``SequenceClassifierOutput`` if true, else a tuple.
                Defaults to ``config.use_return_dict``.

        Raises:
            ValueError: if ``span`` is None.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if span is None:
            raise ValueError("BertForWSC requires `span`: the classifier input is built from two pooled spans.")

        outputs = _encode(
            self.bert,
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = outputs.last_hidden_state  # [bz, seq_len, dim]

        # Start indices are shifted by one before pooling. NOTE(review): this
        # presumably skips a marker token at each span start; confirm against
        # the preprocessing that builds `span`.
        seg1_embeddings, seg2_embeddings = _mean_pool_spans(sequence_output, span, start_offset=1)

        features = torch.cat([seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 2*dim]
        logits = self.classifier(self.dropout(features))

        loss = None
        if labels is not None:
            loss = _classification_loss(logits, labels, self.num_labels, pseudo_label)
        return _pack_output(loss, logits, outputs, return_dict)