import torch
import torch.nn
from torch.nn import CrossEntropyLoss

from transformers import BertModel, BertPreTrainedModel
from transformers import RobertaModel, RobertaPreTrainedModel
from transformers.modeling_outputs import MultipleChoiceModelOutput

from model.prefix_encoder import PrefixEncoder
from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
from model import utils
class BertForMultipleChoice(BertPreTrainedModel):
    """BERT model for multiple choice tasks.

    This module is composed of the BERT model with a linear layer on top of the
    pooled output.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the
            scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`).
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
            and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller
            than the max input sequence length in the current batch. It's the mask that we typically use for
            attention when a batch has varying length sentences.
        `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
            with indices selected in [0, ..., num_choices - 1]. The number of choices is inferred from the
            second dimension of `input_ids`.

    Outputs:
        if `labels` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `labels` is `None`:
            Outputs the classification logits of shape [batch_size, num_choices].

    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
    input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
    token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])

    config = BertConfig(vocab_size=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForMultipleChoice(config)
    logits = model(input_ids, token_type_ids=token_type_ids, attention_mask=input_mask)
    ```
    """
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)

        self.init_weights()

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2]

        input_ids = input_ids.reshape(-1, input_ids.size(-1))
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1))
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
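# Shape convention shared by every head in this file: inputs arrive as
# [batch_size, num_choices, seq_len] and are flattened to
# [batch_size * num_choices, seq_len]; each choice is scored with a 1-unit
# linear head on the pooled output, and the scores are reshaped back to
# [batch_size, num_choices] so CrossEntropyLoss picks the correct choice
# per example.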
class BertPrefixForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)

        # Freeze the backbone; only the prefix encoder and the classifier are trained.
        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))  # 9860105

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
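# get_prompt() shape walkthrough (identical in all prefix classes below):
#   prefix_encoder(prefix_tokens) -> [B, pre_seq_len, 2 * n_layer * hidden_size]
#   .view(...)                    -> [B, pre_seq_len, 2 * n_layer, n_head, n_embd]
#   .permute([2, 0, 3, 1, 4])     -> [2 * n_layer, B, n_head, pre_seq_len, n_embd]
#   .split(2)                     -> n_layer tensors of [2, B, n_head, pre_seq_len, n_embd],
# i.e. one (key, value) pair per layer, passed to the frozen backbone through
# `past_key_values`. The attention mask is extended with `pre_seq_len` ones so
# the real tokens can attend to the prefix positions.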
class RobertaPrefixForMultipleChoice(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()

        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        bert_param = 0
        for name, param in self.roberta.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in
            ``[0, ..., num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input
            tensors. (See :obj:`input_ids` above)
        """
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
        flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)

        outputs = self.roberta(
            flat_input_ids,
            position_ids=flat_position_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
            head_mask=head_mask,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
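# Hedged sketch (illustrative helper, not part of the original code): in the
# prefix/prompt classes the backbone parameters are frozen above, so only the
# prefix encoder and the classifier head should keep `requires_grad=True`.
# A quick way to inspect what actually trains:
def _trainable_parameters(model):
    """Return {name: numel} for parameters that will receive gradients."""
    return {name: p.numel() for name, p in model.named_parameters() if p.requires_grad}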
class DebertaPrefixForMultipleChoice(DebertaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.deberta = DebertaModel(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.classifier = torch.nn.Linear(output_dim, 1)
        self.dropout = StableDropout(config.hidden_dropout_prob)
        self.init_weights()

        for param in self.deberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = PrefixEncoder(config)

        deberta_param = 0
        for name, param in self.deberta.named_parameters():
            deberta_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - deberta_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
        past_key_values = self.prefix_encoder(prefix_tokens)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.n_layer * 2,
            self.n_head,
            self.n_embd
        )
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
        return past_key_values

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.deberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        flat_inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.deberta.device)
        flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)

        outputs = self.deberta(
            flat_input_ids,
            attention_mask=flat_attention_mask,
            token_type_ids=flat_token_type_ids,
            position_ids=flat_position_ids,
            inputs_embeds=flat_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            past_key_values=past_key_values,
        )

        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
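# Hedged configuration sketch (assumption, not part of the original file): the
# prefix/prompt classes expect the config to carry a `pre_seq_len` field, and
# PrefixEncoder (model/prefix_encoder.py) typically also reads fields such as
# `prefix_projection` and `prefix_hidden_size`. A plausible setup looks like:
#
#   config = AutoConfig.from_pretrained("bert-base-uncased")
#   config.pre_seq_len = 16
#   config.prefix_projection = False
#   model = BertPrefixForMultipleChoice.from_pretrained("bert-base-uncased", config=config)
#
# Treat the exact field names as assumptions; they must match this repo's
# PrefixEncoder implementation.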
class BertPromptForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.bert = BertModel(config)
        self.embeddings = self.bert.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        bert_param = 0
        for name, param in self.bert.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))  # 9860105

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        utils.use_grad(self.bert, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size * num_choices)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.bert(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.reshape(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
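# The *Prompt* classes differ from the *Prefix* classes above in where the
# trainable vectors enter the model: instead of injecting per-layer key/value
# pairs through `past_key_values`, get_prompt() here returns `pre_seq_len`
# trainable embeddings (a plain nn.Embedding) that are concatenated in front
# of the word embeddings, so the prompt only affects the input layer.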
class RobertaPromptForMultipleChoice(RobertaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.roberta = RobertaModel(config)
        self.embeddings = self.roberta.embeddings
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, 1)
        self.init_weights()

        for param in self.roberta.parameters():
            param.requires_grad = False

        self.pre_seq_len = config.pre_seq_len
        self.n_layer = config.num_hidden_layers
        self.n_head = config.num_attention_heads
        self.n_embd = config.hidden_size // config.num_attention_heads

        self.prefix_tokens = torch.arange(self.pre_seq_len).long()
        self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)

        bert_param = 0
        for name, param in self.roberta.named_parameters():
            bert_param += param.numel()
        all_param = 0
        for name, param in self.named_parameters():
            all_param += param.numel()
        total_param = all_param - bert_param
        print('total param is {}'.format(total_param))

        self.embedding = utils.get_embeddings(self, config)
        self.embeddings_gradient = utils.GradientStorage(self.embedding)

    def get_prompt(self, batch_size):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
        prompts = self.prefix_encoder(prefix_tokens)
        return prompts

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        use_base_grad=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss. Indices should be in
            ``[0, ..., num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input
            tensors. (See :obj:`input_ids` above)
        """
        utils.use_grad(self.roberta, use_base_grad)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        raw_embedding = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        prompts = self.get_prompt(batch_size=batch_size * num_choices)
        inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
        prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
        attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        outputs = self.roberta(
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
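# Hedged smoke test (illustration only, not part of the original file): builds
# a tiny BERT config and checks that the prefix model yields one logit per
# choice. Config fields beyond the standard BertConfig ones (`pre_seq_len`,
# `prefix_projection`) are assumptions about model/prefix_encoder.py, and the
# model also relies on `utils.get_embeddings` / `utils.GradientStorage` from
# this repo.
if __name__ == "__main__":
    from transformers import BertConfig

    config = BertConfig(vocab_size=1000, hidden_size=64, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=128)
    config.pre_seq_len = 4            # assumed field read by PrefixEncoder
    config.prefix_projection = False  # assumed field read by PrefixEncoder

    model = BertPrefixForMultipleChoice(config)
    input_ids = torch.randint(0, 1000, (2, 3, 8))         # [batch=2, choices=3, seq_len=8]
    attention_mask = torch.ones(2, 3, 8)                  # float mask, matches the prefix mask dtype
    token_type_ids = torch.zeros(2, 3, 8, dtype=torch.long)

    out = model(input_ids=input_ids, attention_mask=attention_mask,
                token_type_ids=token_type_ids, return_dict=True)
    print(out.logits.shape)  # expected: torch.Size([2, 3])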