DeepLearning101's picture
Upload 3 files
5894ace
raw
history blame
37.4 kB
# -*- coding: utf-8 -*-
# @Time : 2022/3/15 21:26
# @Author : ruihan.wjn
# @File : pk-plm.py
"""
This code implements the paper "Knowledge Prompting in Pre-trained Language Models for Natural Language Understanding".
"""
from time import time
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from collections import OrderedDict
from transformers.models.bert import BertPreTrainedModel, BertModel
from transformers.models.roberta import RobertaModel, RobertaPreTrainedModel, RobertaTokenizer, RobertaForMaskedLM
from transformers.models.deberta import DebertaModel, DebertaPreTrainedModel, DebertaTokenizer, DebertaForMaskedLM
from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertPreTrainingHeads
from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead
from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaLMPredictionHead
"""
Example record of the knowledge-graph (KG) enhanced corpus:
{
"token_ids": [20, 46098, 3277, 680, 10, 4066, 278, 9, 11129, 4063, 877, 579, 8, 8750, 14720, 8, 22498, 548,
19231, 46098, 3277, 6, 25, 157, 25, 130, 3753, 46098, 3277, 4, 3684, 19809, 10960, 9, 5, 30731, 2788, 914, 5,
1675, 8151, 35], "entity_pos": [[8, 11], [13, 15], [26, 27]],
"entity_qid": ["Q17582", "Q231978", "Q427013"],
"relation_pos": null,
"relation_pid": null
}
"""
from enum import Enum
class SiameseDistanceMetric(Enum):
    """Row-wise distance functions usable as ``ContrastiveLoss.distance_metric``.

    Each attribute is a plain function ``(x, y) -> Tensor`` that returns one
    distance per pair of corresponding rows in ``x`` and ``y``. Functions in an
    ``Enum`` body are descriptors, so they are NOT turned into enum members:
    ``SiameseDistanceMetric.COSINE_DISTANCE`` is the callable itself.
    """

    def EUCLIDEAN(x, y):
        # L2 (p=2) distance between corresponding rows.
        return F.pairwise_distance(x, y, p=2)

    def MANHATTAN(x, y):
        # L1 (p=1) distance between corresponding rows.
        return F.pairwise_distance(x, y, p=1)

    def COSINE_DISTANCE(x, y):
        # One minus the cosine similarity of corresponding rows.
        return 1 - F.cosine_similarity(x, y)
class ContrastiveLoss(nn.Module):
    """Contrastive loss over row-aligned pairs of embeddings.

    For a pair labelled 1 the loss is the squared distance, which pulls the
    two embeddings together. For a pair labelled 0 it is the squared hinge
    ``relu(margin - distance) ** 2``, which pushes them apart until they are
    at least ``margin`` away. Both terms are scaled by 0.5.

    Reference (Hadsell, Chopra & LeCun, 2006):
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    :param distance_metric: Callable ``(x, y) -> Tensor`` returning one distance per
        row pair. ``SiameseDistanceMetric`` contains predefined metrics.
    :param margin: Negative pairs (label 0) are penalised while their distance is
        below this value.
    :param size_average: If True return the mean loss over pairs, otherwise the sum.
    """

    def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average: bool = True):
        super().__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.size_average = size_average

    def forward(self, sent_embs1, sent_embs2, labels: torch.Tensor):
        """Return the contrastive loss for two row-aligned embedding batches.

        :param sent_embs1: Anchor embeddings.
        :param sent_embs2: Embeddings paired row-wise with ``sent_embs1``.
        :param labels: Per-pair labels, 1 for positive and 0 for negative pairs.
            Integer, float and bool tensors are accepted.
        :return: Scalar loss (mean or sum depending on ``size_average``).
        """
        distances = self.distance_metric(sent_embs1, sent_embs2)
        # Cast once up-front: ``1 - labels`` is not supported on bool tensors.
        labels = labels.float()
        positive_term = labels * distances.pow(2)
        negative_term = (1 - labels) * F.relu(self.margin - distances).pow(2)
        losses = 0.5 * (positive_term + negative_term)
        return losses.mean() if self.size_average else losses.sum()
class NSPHead(nn.Module):
    """Two-way linear classifier applied to a single sentence-level vector.

    Used in this file as the knowledge-noise-detection head.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute name is part of the state_dict keys; keep it stable.
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Map ``[..., hidden_size]`` features to ``[..., 2]`` logits."""
        return self.seq_relationship(pooled_output)
class RoBertaKPPLMForProcessedWikiKGPLM(RobertaForMaskedLM):
    """RoBERTa masked-LM trained with the KP-PLM knowledge objectives.

    Besides the masked-LM loss, each example in a batch may belong to a
    knowledge task selected by ``task_id``:

    * ``task_id == 2``: masked-entity prediction, contrasted against ``entity_candidate``;
    * ``task_id == 3``: masked-relation prediction, contrasted against ``relation_candidate``.

    ``self.detector`` (knowledge-noise detection) is still created, so checkpoints
    keep the same parameters. Its loss is disabled and it is not run in ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Some configs have no ``classifier_dropout``; fall back to the hidden dropout.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the batch and sum every pre-training loss whose inputs are present.

        :param labels: Masked-LM targets; -100 entries are ignored by ``CrossEntropyLoss``.
        :param entity_candidate: Candidate entity token ids per example, indexed as
            ``[example, candidate, token]``. The LAST candidate is labelled positive.
        :param relation_candidate: Same layout as ``entity_candidate``, for relations.
        :param noise_detect_label: Accepted for interface compatibility; unused.
        :param task_id: Per-example task selector (2 = entity, 3 = relation).
        :param mask_id: Per-example token id marking the masked span in ``input_ids``.
        :return: ``OrderedDict`` with ``loss`` (sum of computed losses, or None when
            none applies), ``mlm_loss`` (shape ``[1]`` or None) and ``logits``
            (argmax token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)  # mlm head

        losses = []
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            losses.append(masked_lm_loss)

        if task_id is not None:
            entity_loss = self._candidate_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_candidate, task_id == 2, self.entity_mlp
            )
            if entity_loss is not None:
                losses.append(entity_loss)
            relation_loss = self._candidate_contrastive_loss(
                sequence_output, input_ids, mask_id, relation_candidate, task_id == 3, self.relation_mlp
            )
            if relation_loss is not None:
                losses.append(relation_loss)

        # torch.stack([]) raises, so an empty loss list yields None instead.
        total_loss = torch.stack(losses).sum() if losses else None
        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates, selector, projection):
        """Contrastive loss between masked-span queries and candidate token sequences.

        Only examples where ``selector`` is True take part.

        * Query: the mean encoder output over the positions whose token equals
          that example's ``mask_id``.
        * Candidate: the mean of its input-embedding vectors.

        Both are passed through ``projection``. Every (query, candidate) pair is
        scored with ``self.contrastive_loss_fn``, with the last candidate as the positive.

        :return: Scalar loss, or None when ``candidates`` is None or nothing is selected.
        """
        if candidates is None:
            return None
        candidates = candidates[selector]
        if len(candidates) == 0:
            return None
        num_examples, num_candidates = candidates.shape[0], candidates.shape[1]

        # Index the selected rows once rather than on every loop iteration.
        # NOTE(review): an example with no ``mask_id`` token produces a NaN query.
        query = torch.stack([
            torch.mean(hidden[ids == span_id], 0)
            for hidden, ids, span_id in zip(sequence_output[selector], input_ids[selector], mask_id[selector])
        ])  # [n, dim]
        query = projection(query)
        # Pair each query with each of its candidates: [n * num_candidates, dim].
        query = query.unsqueeze(1).repeat((1, num_candidates, 1)).view(-1, query.shape[-1])

        flat_candidates = candidates.view(-1, candidates.shape[-1])  # [n * num_candidates, len]
        candidate_embedding = self.roberta.embeddings(input_ids=flat_candidates)  # [n * num_candidates, len, dim]
        candidate_embedding = projection(torch.mean(candidate_embedding, 1))  # [n * num_candidates, dim]

        # 0 for the first num_candidates - 1 candidates, 1 for the last; built on
        # the embeddings' device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(num_candidates, dtype=torch.float, device=candidate_embedding.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.unsqueeze(0).repeat(num_examples, 1).view(-1)
        return self.contrastive_loss_fn(query, candidate_embedding, pair_labels)
class DeBertaKPPLMForProcessedWikiKGPLM(DebertaForMaskedLM):
    """DeBERTa masked-LM trained with the KP-PLM knowledge objectives.

    Besides the masked-LM loss, each example in a batch may belong to a
    knowledge task selected by ``task_id``:

    * ``task_id == 2``: masked-entity prediction, contrasted against ``entity_candidate``;
    * ``task_id == 3``: masked-relation prediction, contrasted against ``relation_candidate``.

    ``self.detector`` (knowledge-noise detection) is still created, so checkpoints
    keep the same parameters. Its loss is disabled and it is not run in ``forward``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        # Some configs have no ``classifier_dropout``; fall back to the hidden dropout.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        entity_candidate=None,
        relation_candidate=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the batch and sum every pre-training loss whose inputs are present.

        ``attention_mask``, ``head_mask``, ``encoder_hidden_states`` and
        ``encoder_attention_mask`` are accepted but not forwarded to the backbone.

        :param labels: Masked-LM targets; -100 entries are ignored by ``CrossEntropyLoss``.
        :param entity_candidate: Candidate entity token ids per example, indexed as
            ``[example, candidate, token]``. The LAST candidate is labelled positive.
        :param relation_candidate: Same layout as ``entity_candidate``, for relations.
        :param noise_detect_label: Accepted for interface compatibility; unused.
        :param task_id: Per-example task selector (2 = entity, 3 = relation).
        :param mask_id: Per-example token id marking the masked span in ``input_ids``.
        :return: ``OrderedDict`` with ``loss`` (sum of computed losses, or None when
            none applies), ``mlm_loss`` (shape ``[1]`` or None) and ``logits``
            (argmax token ids).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.deberta(
            input_ids,
            # NOTE(review): the caller's attention_mask is deliberately replaced by None,
            # so padding positions are presumably not masked out. Confirm this is intended.
            attention_mask=None,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # mlm head

        losses = []
        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            losses.append(masked_lm_loss)

        if task_id is not None:
            entity_loss = self._candidate_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_candidate, task_id == 2, self.entity_mlp
            )
            if entity_loss is not None:
                losses.append(entity_loss)
            relation_loss = self._candidate_contrastive_loss(
                sequence_output, input_ids, mask_id, relation_candidate, task_id == 3, self.relation_mlp
            )
            if relation_loss is not None:
                losses.append(relation_loss)

        # torch.stack([]) raises, so an empty loss list yields None instead.
        total_loss = torch.stack(losses).sum() if losses else None
        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
        ])

    def _candidate_contrastive_loss(self, sequence_output, input_ids, mask_id, candidates, selector, projection):
        """Contrastive loss between masked-span queries and candidate token sequences.

        Only examples where ``selector`` is True take part.

        * Query: the mean encoder output over the positions whose token equals
          that example's ``mask_id``.
        * Candidate: the mean of its input-embedding vectors.

        Both are passed through ``projection``. Every (query, candidate) pair is
        scored with ``self.contrastive_loss_fn``, with the last candidate as the positive.

        :return: Scalar loss, or None when ``candidates`` is None or nothing is selected.
        """
        if candidates is None:
            return None
        candidates = candidates[selector]
        if len(candidates) == 0:
            return None
        num_examples, num_candidates = candidates.shape[0], candidates.shape[1]

        # Index the selected rows once rather than on every loop iteration.
        # NOTE(review): an example with no ``mask_id`` token produces a NaN query.
        query = torch.stack([
            torch.mean(hidden[ids == span_id], 0)
            for hidden, ids, span_id in zip(sequence_output[selector], input_ids[selector], mask_id[selector])
        ])  # [n, dim]
        query = projection(query)
        # Pair each query with each of its candidates: [n * num_candidates, dim].
        query = query.unsqueeze(1).repeat((1, num_candidates, 1)).view(-1, query.shape[-1])

        flat_candidates = candidates.view(-1, candidates.shape[-1])  # [n * num_candidates, len]
        candidate_embedding = self.deberta.embeddings(input_ids=flat_candidates)  # [n * num_candidates, len, dim]
        candidate_embedding = projection(torch.mean(candidate_embedding, 1))  # [n * num_candidates, dim]

        # 0 for the first num_candidates - 1 candidates, 1 for the last; built on
        # the embeddings' device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(num_candidates, dtype=torch.float, device=candidate_embedding.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.unsqueeze(0).repeat(num_examples, 1).view(-1)
        return self.contrastive_loss_fn(query, candidate_embedding, pair_labels)
class RoBertaForWikiKGPLM(RobertaPreTrainedModel):
    """RoBERTa with masked-LM, knowledge-noise-detection and entity-contrastive losses.

    ``forward`` adds together every loss whose inputs are provided.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.roberta = RobertaModel(config)
        # Some configs have no ``classifier_dropout``; fall back to the hidden dropout.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.lm_head = RobertaLMHead(config)  # Masked Language Modeling head
        self.detector = NSPHead(config)  # Knowledge Noise Detection head
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()
        # Loaded from ``config.name_or_path``; not used by ``forward``.
        self.tokenizer = RobertaTokenizer.from_pretrained(config.name_or_path)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the batch and sum the masked-LM, noise-detection and entity losses.

        Each loss is computed only when its inputs are given, and ``loss`` is the
        sum of all computed ones. Previously the entity loss overwrote the
        noise-detection term, and MLM-only batches produced ``loss=None``.

        :param mlm_labels: Masked-LM targets; -100 entries are ignored.
        :param entity_label: Positive entity token ids per example (``[batch, len]`` per
            the original shape comments).
        :param entity_negative: Negative entity token ids, ``[batch, num_negative, len]``.
        :param noise_detect_label: Class index (0/1) per example for the 2-way detector.
        :param mask_id: Per-example token id marking the masked entity in ``input_ids``.
        :param relation_label: Accepted for interface compatibility; unused.
        :param relation_negative: Accepted for interface compatibility; unused.
        :param task_id: Accepted for interface compatibility; unused.
        :return: ``OrderedDict`` of losses (shape ``[1]`` or None), argmax token
            ``logits`` and ``noise_detect_logits``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.lm_head(sequence_output)  # mlm head
        noise_detect_scores = self.detector(pooled_output)  # knowledge noise detector

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None
        losses = []
        if mlm_labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
            losses.append(masked_lm_loss)
        if noise_detect_label is not None:
            loss_fct = CrossEntropyLoss()
            noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
            losses.append(noise_detect_loss)
        if entity_label is not None and entity_negative is not None:
            entity_loss = self._entity_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_label, entity_negative
            )
            losses.append(entity_loss)
        total_loss = torch.stack(losses).sum() if losses else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
            ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
            ("noise_detect_logits", noise_detect_scores.argmax(-1)),
        ])

    def _entity_contrastive_loss(self, sequence_output, input_ids, mask_id, entity_label, entity_negative):
        """Contrastive loss of each masked-entity query against its negatives and positive.

        * Query: the mean encoder output over positions whose token equals that
          example's ``mask_id``.
        * Candidate: the mean of its input-embedding vectors.

        Both are projected by ``entity_mlp``. Candidates are ordered negatives
        first, then the positive, which is labelled 1.
        """
        batch_size = input_ids.shape[0]
        negative_num = entity_negative.shape[1]

        # NOTE(review): an example with no ``mask_id`` token produces a NaN query.
        query = torch.stack([
            torch.mean(hidden[ids == span_id], 0)
            for hidden, ids, span_id in zip(sequence_output, input_ids, mask_id)
        ])  # [bz, dim]
        query = self.entity_mlp(query)
        query = query.unsqueeze(1).repeat((1, negative_num + 1, 1)).view(-1, query.shape[-1])  # [bz * (neg + 1), dim]

        positive = self.roberta.embeddings(input_ids=entity_label)  # [bz, len, dim]
        positive = self.entity_mlp(torch.mean(positive, 1)).unsqueeze(1)  # [bz, 1, dim]
        flat_negative = entity_negative.view(-1, entity_negative.shape[-1])  # [bz * neg, len]
        negative = self.roberta.embeddings(input_ids=flat_negative)  # [bz * neg, len, dim]
        negative = self.entity_mlp(torch.mean(negative, 1))  # [bz * neg, dim]
        negative = negative.view(batch_size, -1, negative.shape[-1])  # [bz, neg, dim]

        candidates = torch.cat([negative, positive], 1)  # [bz, neg + 1, dim]
        candidates = candidates.view(-1, candidates.shape[-1])  # [bz * (neg + 1), dim]

        # Built on the embeddings' device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(negative_num + 1, dtype=torch.float, device=candidates.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.unsqueeze(0).repeat(batch_size, 1).view(-1)  # [bz * (neg + 1)]
        return self.contrastive_loss_fn(query, candidates, pair_labels)
class BertForWikiKGPLM(BertPreTrainedModel):
    """BERT with masked-LM, knowledge-noise-detection and entity-contrastive losses.

    ``forward`` adds together every loss whose inputs are provided.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        # Some configs have no ``classifier_dropout``; fall back to the hidden dropout.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        # MLM head plus a 2-way head on the pooled output. ``forward`` unpacks both
        # outputs, which BertPreTrainedModel (used here before) does not provide.
        self.cls = BertPreTrainingHeads(config)
        self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
        self.contrastive_loss_fn = ContrastiveLoss()
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        mlm_labels=None,
        entity_label=None,
        entity_negative=None,
        relation_label=None,
        relation_negative=None,
        noise_detect_label=None,
        task_id=None,
        mask_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the batch and sum the masked-LM, noise-detection and entity losses.

        Each loss is computed only when its inputs are given, and ``loss`` is the
        sum of all computed ones (or None when none applies). The caller's
        ``attention_mask`` is now passed through; debug prints previously
        replaced it with None.

        :param mlm_labels: Masked-LM targets; -100 entries are ignored.
        :param entity_label: Positive entity token ids per example (``[batch, len]`` per
            the original shape comments).
        :param entity_negative: Negative entity token ids, ``[batch, num_negative, len]``.
        :param noise_detect_label: Class index (0/1) per example for the 2-way head.
        :param mask_id: Per-example token id marking the masked entity in ``input_ids``.
        :param relation_label: Accepted for interface compatibility; unused.
        :param relation_negative: Accepted for interface compatibility; unused.
        :param task_id: Accepted for interface compatibility; unused.
        :return: ``OrderedDict`` of losses (shape ``[1]`` or None), argmax token
            ``logits`` and ``noise_detect_logits``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        masked_lm_loss, noise_detect_loss, entity_loss = None, None, None
        losses = []
        if mlm_labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
            losses.append(masked_lm_loss)
        if noise_detect_label is not None:
            loss_fct = CrossEntropyLoss()
            noise_detect_loss = loss_fct(seq_relationship_score.view(-1, 2), noise_detect_label.view(-1))
            losses.append(noise_detect_loss)
        if entity_label is not None and entity_negative is not None:
            entity_loss = self._entity_contrastive_loss(
                sequence_output, input_ids, mask_id, entity_label, entity_negative
            )
            losses.append(entity_loss)
        total_loss = torch.stack(losses).sum() if losses else None

        return OrderedDict([
            ("loss", total_loss),
            ("mlm_loss", masked_lm_loss.unsqueeze(0) if masked_lm_loss is not None else None),
            ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
            ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
            ("logits", prediction_scores.argmax(2)),
            # The 2-way scores are 2-D, so argmax over the last axis.
            ("noise_detect_logits", seq_relationship_score.argmax(-1)),
        ])

    def _entity_contrastive_loss(self, sequence_output, input_ids, mask_id, entity_label, entity_negative):
        """Contrastive loss of each masked-entity query against its negatives and positive.

        * Query: the mean encoder output over positions whose token equals that
          example's ``mask_id``.
        * Candidate: the mean of its input-embedding vectors.

        Both are projected by ``entity_mlp``. Queries and candidates are
        flattened to aligned ``[batch * (negatives + 1), dim]`` rows; the
        positive (last candidate) is labelled 1.
        """
        batch_size = input_ids.shape[0]
        negative_num = entity_negative.shape[1]

        # NOTE(review): an example with no ``mask_id`` token produces a NaN query.
        query = torch.stack([
            torch.mean(hidden[ids == span_id], 0)
            for hidden, ids, span_id in zip(sequence_output, input_ids, mask_id)
        ])  # [bz, dim]
        query = self.entity_mlp(query)
        query = query.unsqueeze(1).repeat((1, negative_num + 1, 1)).view(-1, query.shape[-1])  # [bz * (neg + 1), dim]

        positive = self.bert.embeddings(input_ids=entity_label)  # [bz, len, dim]
        positive = self.entity_mlp(torch.mean(positive, 1)).unsqueeze(1)  # [bz, 1, dim]
        flat_negative = entity_negative.view(-1, entity_negative.shape[-1])  # [bz * neg, len]
        negative = self.bert.embeddings(input_ids=flat_negative)  # [bz * neg, len, dim]
        negative = self.entity_mlp(torch.mean(negative, 1))  # [bz * neg, dim]
        negative = negative.view(batch_size, -1, negative.shape[-1])  # [bz, neg, dim]

        candidates = torch.cat([negative, positive], 1)  # [bz, neg + 1, dim]
        candidates = candidates.view(-1, candidates.shape[-1])  # [bz * (neg + 1), dim]

        # Built on the embeddings' device instead of a hard-coded ``.cuda()``.
        pair_labels = torch.zeros(negative_num + 1, dtype=torch.float, device=candidates.device)
        pair_labels[-1] = 1.0
        pair_labels = pair_labels.unsqueeze(0).repeat(batch_size, 1).view(-1)  # [bz * (neg + 1)]
        return self.contrastive_loss_fn(query, candidates, pair_labels)