import torch
from torch._C import NoopLogger
import torch.nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss

from transformers import BertModel, BertPreTrainedModel
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.modeling_outputs import MultipleChoiceModelOutput, BaseModelOutput, Seq2SeqLMOutput

from model.prefix_encoder import PrefixEncoder
from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
from model import utils
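
# Note: beyond the standard BERT/RoBERTa/DeBERTa settings used below
# (hidden_size, num_hidden_layers, num_attention_heads, hidden_dropout_prob,
# num_labels), every class in this file expects `config.pre_seq_len` to be
# set, and the prefix classes additionally rely on whatever fields
# model/prefix_encoder.PrefixEncoder reads from the config.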


class BertForMultipleChoice(BertPreTrainedModel):
    """BERT model for multiple choice tasks.

    This module is composed of the BERT model with a linear layer on top of
    the pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary (see the token preprocessing logic in the
            scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token
            and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used when an input sequence is shorter
            than the maximum sequence length in the current batch, i.e. the mask we typically use for
            attention when a batch has sentences of varying length.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1], where `num_choices` is the size of the
            second dimension of `input_ids`.

    Outputs:
        if `labels` is not `None`:
            outputs the CrossEntropy classification loss between the logits and the labels.
        if `labels` is `None`:
            outputs the classification logits of shape [batch_size, num_choices].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])

    config = BertConfig(vocab_size=32000, hidden_size=768,
                        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)

    model = BertForMultipleChoice(config)
    logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids).logits
    ```
    """

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2]

        input_ids = input_ids.reshape(-1, input_ids.size(-1))
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1))
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
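

# ---------------------------------------------------------------------------
# Prefix-tuning variants.
#
# The classes below keep the pretrained backbone frozen (requires_grad=False)
# and train only the PrefixEncoder and the classification head. On each
# forward pass, get_prompt() turns `pre_seq_len` virtual prefix tokens into
# per-layer key/value tensors that are passed to the backbone through
# `past_key_values`, and the attention mask is extended with ones so that the
# real tokens can attend to the prefix positions.
# ---------------------------------------------------------------------------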


class BertPrefixForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))  # 9860105

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)
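
    # get_prompt() turns the virtual prefix tokens into per-layer key/value
    # tensors for `past_key_values`:
    #   prefix_tokens:        (batch, pre_seq_len) with ids 0 .. pre_seq_len - 1
    #   prefix_encoder(...):  (batch, pre_seq_len, 2 * n_layer * n_head * n_embd)
    #   view:                 (batch, pre_seq_len, 2 * n_layer, n_head, n_embd)
    #   permute + split(2):   n_layer tensors of shape (2, batch, n_head, pre_seq_len, n_embd),
    #                         i.e. one stacked (key, value) pair per transformer layer.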
    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd,
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        # Flatten (batch_size, num_choices, seq_len) -> (batch_size * num_choices, seq_len).
        input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # Prepend the learned prefix: past key/values for every layer plus an
        # all-ones attention mask covering the pre_seq_len prefix positions.
        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
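

# A minimal usage sketch for the prefix model (illustration only; the checkpoint
# name, tensors, and extra config fields are assumptions — `pre_seq_len` is read
# above, but anything else PrefixEncoder needs must be added to match
# model/prefix_encoder.py):
#
#   from transformers import BertConfig
#
#   config = BertConfig.from_pretrained("bert-base-uncased")
#   config.pre_seq_len = 16          # number of virtual prefix tokens (assumed value)
#   model = BertPrefixForMultipleChoice.from_pretrained("bert-base-uncased", config=config)
#
#   # input_ids / attention_mask / token_type_ids shaped (batch, num_choices, seq_len),
#   # labels shaped (batch,)
#   out = model(input_ids=input_ids, attention_mask=attention_mask,
#               token_type_ids=token_type_ids, labels=labels)
#   out.loss.backward()              # only the prefix encoder and classifier are trainable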


class RobertaPrefixForMultipleChoice(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()

        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        bert_param = 0
        for name, param in self.roberta.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd,
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)
        """
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
        flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)

        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class DebertaPrefixForMultipleChoice(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.deberta = DebertaModel(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.classifier = torch.nn.Linear(output_dim, 1)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.init_weights()

        for param in self.deberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        deberta_param = 0
        for name, param in self.deberta.named_parameters():
            deberta_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - deberta_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd,
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.deberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.deberta.device)
        flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)

        outputs = self.deberta(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
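

# ---------------------------------------------------------------------------
# Prompt-tuning variants.
#
# Unlike the prefix models above, the classes below do not touch the attention
# key/values. Instead, get_prompt() looks up `pre_seq_len` trainable prompt
# embeddings, concatenates them in front of the word embeddings, and feeds the
# result to the frozen backbone through `inputs_embeds`, again extending the
# attention mask with ones for the prompt positions.
# ---------------------------------------------------------------------------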


class BertPromptForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        self.embeddings = self.bert.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))  # 9860105

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        # Prepend the trainable prompt embeddings to the word embeddings and
        # extend the attention mask accordingly.
        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size * num_choices)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class RobertaPromptForMultipleChoice(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)

        self.roberta = RobertaModel(config)
        self.embeddings = self.roberta.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()

        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        bert_param = 0
        for name, param in self.roberta.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
            num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
            :obj:`input_ids` above)
        """
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size * num_choices)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )