# -*- coding: utf-8 -*-
# @Time : 2022/3/15 21:26
# @Author : ruihan.wjn
# @File : pk-plm.py

"""
This code is implemented for the paper "Knowledge Prompting in Pre-trained
Language Models for Natural Language Understanding".

The module defines masked-language-model wrappers (RoBERTa, DeBERTa, BERT)
that add knowledge-aware pre-training objectives on top of the MLM loss: a
contrastive loss between the hidden state of a masked entity / relation span
and the embeddings of candidate entities / relations.
"""

from time import time
from collections import OrderedDict
from enum import Enum

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta import RobertaModel, RobertaPreTrainedModel, RobertaTokenizer, RobertaForMaskedLM
from transformers.models.deberta import DebertaModel, DebertaPreTrainedModel, DebertaTokenizer, DebertaForMaskedLM
from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertPreTrainingHeads
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead
from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaLMPredictionHead

# kg enhanced corpus structure example:
# {
#     "token_ids": [20, 46098, 3277, 680, 10, 4066, 278, 9, 11129, 4063, 877, 579, 8, 8750, 14720, 8, 22498, 548,
#                   19231, 46098, 3277, 6, 25, 157, 25, 130, 3753, 46098, 3277, 4, 3684, 19809, 10960, 9, 5, 30731,
#                   2788, 914, 5, 1675, 8151, 35],
#     "entity_pos": [[8, 11], [13, 15], [26, 27]],
#     "entity_qid": ["Q17582", "Q231978", "Q427013"],
#     "relation_pos": null,
#     "relation_pid": null
# }

# Values of ``task_id`` that route an example to a knowledge objective.
# (Value 1 selected the noise-detection objective, which is currently disabled.)
_ENTITY_TASK_ID = 2
_RELATION_TASK_ID = 3


class SiameseDistanceMetric(Enum):
    """
    The metric for the contrastive loss.

    NOTE: functions assigned in an ``Enum`` body are not turned into enum
    members, so each attribute below is a plain callable ``(x, y) -> distance``.
    """
    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss over pairs of embeddings with a binary label.

    If the label == 1, the distance between the two embeddings is reduced.
    If the label == 0, the distance is pushed up to at least ``margin``.

    Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Callable returning a distance between two batches of embeddings.
        The class SiameseDistanceMetric contains pre-defined metrics that can be used.
    :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
    :param size_average: If True the per-pair losses are averaged, otherwise summed.
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5,
                 size_average: bool = True):
        super(ContrastiveLoss, self).__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def forward(self, sent_embs1, sent_embs2, labels: torch.Tensor):
        """
        :param sent_embs1: First batch of embeddings.
        :param sent_embs2: Second batch of embeddings, paired row-by-row with ``sent_embs1``.
        :param labels: 1 for a positive pair, 0 for a negative pair.
        :return: Scalar loss (mean or sum over pairs, depending on ``size_average``).
        """
        rep_anchor, rep_other = sent_embs1, sent_embs2
        distances = self.distance_metric(rep_anchor, rep_other)
        # Positive pairs pay distance^2; negative pairs pay only while they are inside the margin.
        losses = 0.5 * (labels.float() * distances.pow(2)
                        + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()


class NSPHead(nn.Module):
    """Binary classification head: a single linear layer from ``hidden_size`` to 2 logits."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score


def _classifier_dropout_prob(config):
    """Return ``config.classifier_dropout``, or ``hidden_dropout_prob`` when it is missing or None."""
    dropout_prob = getattr(config, "classifier_dropout", None)
    return config.hidden_dropout_prob if dropout_prob is None else dropout_prob


def _masked_lm_loss(prediction_scores, mlm_labels, vocab_size):
    """Return the token-level cross-entropy of the MLM head, or None when no labels are given."""
    if mlm_labels is None:
        return None
    # CrossEntropyLoss ignores targets equal to -100 (its default ignore_index).
    loss_fct = CrossEntropyLoss()
    return loss_fct(prediction_scores.view(-1, vocab_size), mlm_labels.view(-1))


def _masked_position_embedding(sequence_output, input_ids, mask_id):
    """
    For every example, average the hidden states at the positions whose token id
    equals that example's ``mask_id``; returns one vector per example.

    NOTE(review): an example without any such position yields a NaN vector
    (mean over an empty selection) - callers are expected to supply one.
    """
    pooled = [
        torch.mean(hidden_states[token_ids == target_id], 0)  # [hidden_dim]
        for hidden_states, token_ids, target_id in zip(sequence_output, input_ids, mask_id)
    ]
    return torch.stack(pooled)  # [bz, dim]


def _embed_candidates(embeddings, token_ids, projection):
    """
    Embed candidate token ids with the encoder's embedding layer only (the
    transformer layers are not run), average over the token axis and project.
    """
    return projection(torch.mean(embeddings(input_ids=token_ids), 1))


def _query_candidate_contrastive_loss(loss_fn, queries, candidates):
    """
    Contrast each query with its candidates, the LAST candidate being the positive.

    :param loss_fn: Callable ``(emb1, emb2, labels) -> loss``.
    :param queries: Tensor indexed ``[batch, dim]``.
    :param candidates: Tensor indexed ``[batch, candidate_num, dim]``.
    """
    batch_size, candidate_num, hidden_dim = candidates.shape
    # Repeat every query once per candidate so rows line up after flattening.
    tiled_queries = queries.unsqueeze(1).repeat(1, candidate_num, 1).reshape(-1, hidden_dim)
    # Built on the candidates' device instead of a hard-coded .cuda(), so CPU runs work too.
    labels = torch.tensor([0.0] * (candidate_num - 1) + [1.0], device=candidates.device)
    labels = labels.repeat(batch_size)  # [bz * candidate_num]
    return loss_fn(tiled_queries, candidates.reshape(-1, hidden_dim), labels)


def _task_contrastive_loss(loss_fn, embeddings, projection, sequence_output, input_ids, mask_id,
                           task_id, task_value, candidates):
    """
    Contrastive loss for the examples of one task (``task_id == task_value``).

    Returns None when no candidates / task ids were passed or when no example
    in the batch belongs to the task.
    """
    if candidates is None or task_id is None:
        return None
    selected = task_id == task_value
    candidates = candidates[selected]
    if len(candidates) == 0:
        return None

    batch_size, candidate_num = candidates.shape[0], candidates.shape[1]
    # Embedding of the masked span for each selected example.
    queries = projection(
        _masked_position_embedding(sequence_output[selected], input_ids[selected], mask_id[selected])
    )  # [bz, dim]
    # Representation of the positive and negative candidates.
    flat_candidates = candidates.reshape(-1, candidates.shape[-1])  # [bz * candidate_num, len]
    candidate_embedding = _embed_candidates(embeddings, flat_candidates, projection)
    candidate_embedding = candidate_embedding.view(batch_size, candidate_num, -1)
    return _query_candidate_contrastive_loss(loss_fn, queries, candidate_embedding)


def _kpplm_outputs(model, embeddings, sequence_output, prediction_scores, input_ids, mlm_labels,
                   entity_candidate, relation_candidate, task_id, mask_id):
    """
    Shared loss computation of the ``*KPPLMForProcessedWikiKGPLM`` models.

    The total loss is the sum of every active term (MLM, entity contrastive,
    relation contrastive); it is None when no term is active.
    """
    active_losses = []

    masked_lm_loss = _masked_lm_loss(prediction_scores, mlm_labels, model.config.vocab_size)
    if masked_lm_loss is not None:
        active_losses.append(masked_lm_loss)

    entity_loss = _task_contrastive_loss(
        model.contrastive_loss_fn, embeddings, model.entity_mlp, sequence_output, input_ids, mask_id,
        task_id, _ENTITY_TASK_ID, entity_candidate,
    )
    if entity_loss is not None:
        active_losses.append(entity_loss)

    relation_loss = _task_contrastive_loss(
        model.contrastive_loss_fn, embeddings, model.relation_mlp, sequence_output, input_ids, mask_id,
        task_id, _RELATION_TASK_ID, relation_candidate,
    )
    if relation_loss is not None:
        active_losses.append(relation_loss)

    total_loss = torch.sum(torch.stack(active_losses), -1) if active_losses else None

    return OrderedDict([
        ("loss", total_loss),
        ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
        ("logits", prediction_scores.argmax(2)),
    ])


def _wiki_kg_outputs(model, embeddings, sequence_output, prediction_scores, noise_detect_scores, input_ids,
                     mlm_labels, noise_detect_label, entity_label, entity_negative, mask_id):
    """
    Shared loss computation of the ``*ForWikiKGPLM`` models.

    The total loss is the sum of every active term (MLM, noise detection,
    entity contrastive); it is None when no term is active. Per-term entries
    of the returned dict are None when the term was not computed.
    """
    active_losses = []

    masked_lm_loss = _masked_lm_loss(prediction_scores, mlm_labels, model.config.vocab_size)
    if masked_lm_loss is not None:
        active_losses.append(masked_lm_loss)

    noise_detect_loss = None
    if noise_detect_label is not None:
        loss_fct = CrossEntropyLoss()
        noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
        active_losses.append(noise_detect_loss)

    entity_loss = None
    if entity_label is not None and entity_negative is not None:
        batch_size = input_ids.shape[0]
        # Embedding of the masked entity span.
        queries = model.entity_mlp(_masked_position_embedding(sequence_output, input_ids, mask_id))  # [bz, dim]
        # Representation of the positive and the negatives.
        # Original author's shape notes: entity_label [bz, len], entity_negative [bz, 10, len].
        positive_embedding = _embed_candidates(embeddings, entity_label, model.entity_mlp).unsqueeze(1)
        flat_negatives = entity_negative.reshape(-1, entity_negative.shape[-1])
        negative_embedding = _embed_candidates(embeddings, flat_negatives, model.entity_mlp)
        negative_embedding = negative_embedding.view(batch_size, -1, negative_embedding.shape[-1])
        # Negatives first, positive last - the order the contrastive labels assume.
        candidate_embedding = torch.cat([negative_embedding, positive_embedding], 1)
        entity_loss = _query_candidate_contrastive_loss(model.contrastive_loss_fn, queries, candidate_embedding)
        active_losses.append(entity_loss)

    total_loss = torch.sum(torch.stack(active_losses), -1) if active_losses else None

    return OrderedDict([
        ("loss", total_loss),
        ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
        ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
        ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
        ("logits", prediction_scores.argmax(2)),
        ("noise_detect_logits", noise_detect_scores.argmax(-1) if noise_detect_scores is not None else None),
    ])


class RoBertaKPPLMForProcessedWikiKGPLM(RobertaForMaskedLM):
    """
    RoBERTa masked LM extended with entity / relation contrastive objectives.

    Each example carries a ``task_id``; examples with task id 2 contribute an
    entity contrastive loss and examples with task id 3 a relation contrastive
    loss, both added to the MLM loss.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.dropout = nn.Dropout(_classifier_dropout_prob(config))
        # Knowledge Noise Detection head. Its loss is currently disabled, so it is
        # not run in forward(); the module is kept so checkpoints keep their keys.
        self.detector = NSPHead(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        :param labels: MLM targets; the MLM loss is skipped when None.
        :param entity_candidate: Candidate entity token ids per example, positive candidate last
            (author's shape note: [bz, 11, len]); skipped when None.
        :param relation_candidate: Same as ``entity_candidate`` for relations.
        :param noise_detect_label: Accepted for interface compatibility; currently unused.
        :param task_id: Per-example task selector (2 = entity, 3 = relation).
        :param mask_id: Per-example token id marking the positions to pool as the query;
            required whenever a contrastive term is active.
        :return: OrderedDict with ``loss`` (sum of active terms or None), ``mlm_loss``
            (1-element tensor or None) and ``logits`` (argmax token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)  # mlm head

        return _kpplm_outputs(
            self, self.roberta.embeddings, sequence_output, prediction_scores, input_ids, labels,
            entity_candidate, relation_candidate, task_id, mask_id,
        )


class DeBertaKPPLMForProcessedWikiKGPLM(DebertaForMaskedLM):
    """
    DeBERTa masked LM extended with entity / relation contrastive objectives.

    Mirrors ``RoBertaKPPLMForProcessedWikiKGPLM`` but uses the DeBERTa encoder,
    its ``cls`` MLM head and its embedding layer.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.dropout = nn.Dropout(_classifier_dropout_prob(config))
        # Knowledge Noise Detection head. Its loss is currently disabled, so it is
        # not run in forward(); the module is kept so checkpoints keep their keys.
        self.detector = NSPHead(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        :param attention_mask: Accepted for interface compatibility; NOT forwarded to the encoder.
        :param head_mask: Accepted for interface compatibility; not forwarded to the encoder.
        :param encoder_hidden_states: Accepted for interface compatibility; not forwarded.
        :param encoder_attention_mask: Accepted for interface compatibility; not forwarded.
        :param labels: MLM targets; the MLM loss is skipped when None.
        :param entity_candidate: Candidate entity token ids per example, positive candidate last
            (author's shape note: [bz, 11, len]); skipped when None.
        :param relation_candidate: Same as ``entity_candidate`` for relations.
        :param noise_detect_label: Accepted for interface compatibility; currently unused.
        :param task_id: Per-example task selector (2 = entity, 3 = relation).
        :param mask_id: Per-example token id marking the positions to pool as the query;
            required whenever a contrastive term is active.
        :return: OrderedDict with ``loss`` (sum of active terms or None), ``mlm_loss``
            (1-element tensor or None) and ``logits`` (argmax token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            # NOTE(review): the original author replaced ``attention_mask=attention_mask``
            # with None here; kept as-is - confirm this is still intended.
            attention_mask=None,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # mlm head

        return _kpplm_outputs(
            self, self.deberta.embeddings, sequence_output, prediction_scores, input_ids, labels,
            entity_candidate, relation_candidate, task_id, mask_id,
        )


class RoBertaForWikiKGPLM(RobertaPreTrainedModel):
    """
    RoBERTa encoder with an MLM head, a noise-detection head on the pooled
    output and an entity contrastive objective (positive label + negatives).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        self.dropout = nn.Dropout(_classifier_dropout_prob(config))
        self.lm_head = RobertaLMHead(config)  # Masked Language Modeling head
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()
        # Loads tokenizer files for ``config.name_or_path`` at construction time.
        self.tokenizer = RobertaTokenizer.from_pretrained(config.name_or_path)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        :param mlm_labels: MLM targets; the MLM loss is skipped when None.
        :param entity_label: Token ids of the positive entity (author's shape note: [bz, len]).
        :param entity_negative: Token ids of the negative entities (author's shape note: [bz, 10, len]).
            The entity loss is computed only when both entity arguments are given.
        :param relation_label: Accepted for interface compatibility; currently unused.
        :param relation_negative: Accepted for interface compatibility; currently unused.
        :param noise_detect_label: Binary targets for the noise-detection head; skipped when None.
        :param task_id: Accepted for interface compatibility; currently unused.
        :param mask_id: Per-example token id marking the positions to pool as the entity query;
            required when the entity loss is active.
        :return: OrderedDict with ``loss`` (sum of all active terms or None), ``mlm_loss``,
            ``noise_detect_loss``, ``entity_loss`` (1-element tensors or None), ``logits``
            and ``noise_detect_logits`` (argmax predictions).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.lm_head(sequence_output)  # mlm head
        noise_detect_scores = self.detector(pooled_output)  # knowledge noise detector

        return _wiki_kg_outputs(
            self, self.roberta.embeddings, sequence_output, prediction_scores, noise_detect_scores,
            input_ids, mlm_labels, noise_detect_label, entity_label, entity_negative, mask_id,
        )


class BertForWikiKGPLM(BertPreTrainedModel):
    """
    BERT encoder with the standard pre-training heads (MLM + binary sequence
    head, the latter used for noise detection) and an entity contrastive
    objective (positive label + negatives).
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(_classifier_dropout_prob(config))
        # forward() needs a head returning (prediction_scores, seq_relationship_score).
        self.cls = BertPreTrainingHeads(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """
        :param mlm_labels: MLM targets; the MLM loss is skipped when None.
        :param entity_label: Token ids of the positive entity (author's shape note: [bz, len]).
        :param entity_negative: Token ids of the negative entities (author's shape note: [bz, 10, len]).
            The entity loss is computed only when both entity arguments are given.
        :param relation_label: Accepted for interface compatibility; currently unused.
        :param relation_negative: Accepted for interface compatibility; currently unused.
        :param noise_detect_label: Binary targets for the sequence-level head; skipped when None.
        :param task_id: Accepted for interface compatibility; currently unused.
        :param mask_id: Per-example token id marking the positions to pool as the entity query;
            required when the entity loss is active.
        :return: OrderedDict with ``loss`` (sum of all active terms or None), ``mlm_loss``,
            ``noise_detect_loss``, ``entity_loss`` (1-element tensors or None), ``logits``
            and ``noise_detect_logits`` (argmax predictions).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The attention mask is forwarded unchanged, as in RoBertaForWikiKGPLM
        # (leftover debugging code used to print it and then discard it).
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        return _wiki_kg_outputs(
            self, self.bert.embeddings, sequence_output, prediction_scores, seq_relationship_score,
            input_ids, mlm_labels, noise_detect_label, entity_label, entity_negative, mask_id,
        )