import torch
import torch.nn
from torch.nn import CrossEntropyLoss

from transformers import BertPreTrainedModel, BertModel, RobertaPreTrainedModel, RobertaModel
from transformers.modeling_outputs import QuestionAnsweringModelOutput

from model.prefix_encoder import PrefixEncoder
from model.deberta import DebertaPreTrainedModel, DebertaModel

class BertForQuestionAnswering(BertPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
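
# A minimal usage sketch (not part of the original module): it builds a small toy
# BertConfig and dummy inputs purely to illustrate that the frozen-backbone model
# above returns per-token start/end logits of shape (batch_size, sequence_length).
# The config values below are illustrative, not the training setup of this repo.
def _example_bert_qa_forward():
    from transformers import BertConfig

    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=128)
    model = BertForQuestionAnswering(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16)

    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.start_logits.shape, outputs.end_logits.shape  # both torch.Size([2, 16])
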
class BertPrefixForQuestionAnswering(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.bert = BertModel(config, add_pooling_layer=False)
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.prefix_encoder = PrefixEncoder(config)
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()

        for param in self.bert.parameters():
            param.requires_grad = False

        self.init_weights()
    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        bsz, seqlen, _ = past_key_values.shape
        # Reshape the flat prefix into per-layer key/value slots:
        # (batch, pre_seq_len, 2 * n_layer * hidden) -> (batch, pre_seq_len, 2 * n_layer, n_head, head_dim)
        past_key_values = past_key_values.view(
            bsz,
            seqlen,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        # Reorder to (2 * n_layer, batch, n_head, pre_seq_len, head_dim) and split into
        # n_layer tensors, each holding a stacked (key, value) pair for one layer.
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        # The prefix key/value states are always attended to, so the attention mask is
        # extended by pre_seq_len ones on the left.
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
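
# A shape-check sketch for the prefix path. Assumption: besides the standard BERT
# fields, the config must carry the prefix-tuning attributes consumed by this module
# and by PrefixEncoder (pre_seq_len here; a prefix_projection flag and possibly
# prefix_hidden_size depending on the local PrefixEncoder implementation). The
# values below are illustrative only.
def _example_prefix_shapes():
    from transformers import BertConfig

    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=128)
    config.pre_seq_len = 8
    config.prefix_projection = False  # assumed PrefixEncoder flag

    model = BertPrefixForQuestionAnswering(config)
    past_key_values = model.get_prompt(batch_size=4)

    # get_prompt returns one stacked (key, value) tensor per layer:
    # n_layer entries, each of shape (2, batch, n_head, pre_seq_len, head_dim).
    assert len(past_key_values) == config.num_hidden_layers
    assert past_key_values[0].shape == (
        2, 4, config.num_attention_heads, config.pre_seq_len,
        config.hidden_size // config.num_attention_heads,
    )
    return [kv.shape for kv in past_key_values]
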
class RobertaPrefixModelForQuestionAnswering(RobertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.prefix_encoder = PrefixEncoder(config)
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()

        for param in self.roberta.parameters():
            param.requires_grad = False
    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(
            bsz,
            seqlen,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
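
# The DeBERTa variant below reports its non-backbone parameter count in __init__.
# Equivalent bookkeeping for any of the prefix models in this module, sketched as a
# small helper: with the backbone frozen, the trainable parameters are exactly the
# prefix encoder plus the QA head.
def _count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
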
class DebertaPrefixModelForQuestionAnswering(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.deberta = DebertaModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

        for param in self.deberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        # Use a two layered MLP to encode the prefix
        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        deberta_param = 0
        for name, param in self.deberta.named_parameters():
            deberta_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - deberta_param
        print('total param is {}'.format(total_param))  # 9860105
    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        # bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        # head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
            sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size = input_ids.shape[0]
        past_key_values = self.get_prompt(batch_size=batch_size)
        prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.deberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the position tensors may carry an extra dimension; squeeze it away.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
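
# End-to-end smoke-test sketch. Assumptions: the toy config values and the
# prefix_projection flag are illustrative and not the training setup used by the
# original repository; the Bert prefix model is used here because the DeBERTa
# backbone comes from the local model.deberta module.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=2,
                        intermediate_size=128)
    config.pre_seq_len = 8
    config.prefix_projection = False  # assumed PrefixEncoder flag

    model = BertPrefixForQuestionAnswering(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16)
    start_positions = torch.tensor([3, 5])
    end_positions = torch.tensor([4, 7])

    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        start_positions=start_positions,
        end_positions=end_positions,
    )
    # The loss is the average of the start and end cross-entropies; the logits cover
    # only the real tokens, not the prefix positions.
    print(outputs.loss.item(), outputs.start_logits.shape)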