Spaces:
Runtime error
Runtime error
from torch import nn | |
from torch.nn import CrossEntropyLoss, MSELoss | |
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel | |
from gp import GPClassificationHead | |
class BertForUQSequenceClassification(BertPreTrainedModel):
    """BERT sequence classifier whose output layer is a ``GPClassificationHead``.

    The encoder is a plain :class:`BertModel`; element ``[1]`` of its output
    is passed through dropout and then through the head imported from the
    local ``gp`` module.

    Set the ``return_gp_cov`` attribute to ``True`` to have :meth:`forward`
    also return the covariance produced by the head.
    NOTE(review): "UQ" presumably stands for uncertainty quantification --
    confirm against the ``gp`` module.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Sub-modules are created in a fixed order: encoder, dropout, GP head.
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = GPClassificationHead(
            hidden_size=config.hidden_size,
            num_classes=config.num_labels,
            num_inducing=512,
        )

        # Off by default; when switched on, forward() returns a 2-tuple
        # ``(outputs, gp_cov)`` instead of just ``outputs``.
        self.return_gp_cov = False

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
    ):
        r"""Encode the inputs with BERT and score them with the GP head.

        All arguments except ``labels`` are forwarded unchanged to
        :class:`BertModel`.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Targets used to compute the loss. When given, the loss is
            prepended to the returned tuple.
            With :obj:`self.num_labels == 1` a Mean-Square (regression) loss
            is used; otherwise a Cross-Entropy (classification) loss is used,
            in which case indices should lie in
            :obj:`[0, ..., config.num_labels - 1]`.

        Returns:
            :obj:`tuple` laid out as ``(loss), logits, (hidden_states), (attentions)``:

            loss (`optional`, present only when :obj:`labels` is provided):
                Regression or classification loss, as described above.
            logits:
                Output of the GP classification head.
            hidden_states, attentions (`optional`):
                Whatever the encoder returned from index 2 onwards (present
                when the corresponding ``output_*`` options are enabled).

            When :obj:`self.return_gp_cov` is true, the return value is
            instead the pair ``(outputs, gp_cov)``, where ``outputs`` is the
            tuple described above and ``gp_cov`` is the second value returned
            by the head.
        """
        encoder_out = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        # Element 1 of the encoder output (the pooled representation) is the
        # feature vector fed to the head, after dropout.
        features = self.dropout(encoder_out[1])

        if not self.return_gp_cov:
            logits = self.classifier(features)
        else:
            # Ask the head for its covariance as well, without updating it.
            logits, gp_cov = self.classifier(
                features,
                return_gp_cov=True,
                update_cov=False,
            )

        # Keep hidden states / attentions if the encoder produced them.
        result = (logits,) + encoder_out[2:]

        if labels is not None:
            if self.num_labels == 1:
                # Single output unit: treat the task as regression.
                criterion = MSELoss()
                predictions = logits.view(-1)
            else:
                criterion = CrossEntropyLoss()
                predictions = logits.view(-1, self.num_labels)
            loss = criterion(predictions, labels.view(-1))
            result = (loss,) + result

        if self.return_gp_cov:
            return result, gp_cov
        return result  # (loss), logits, (hidden_states), (attentions)