from transformers import PretrainedConfig, PreTrainedModel, BertModel, BertConfig
from torch import nn


class SimBertModel(PreTrainedModel):
    """SimBert sentence-pair similarity model: a BERT encoder with a
    two-way classification head whose positive-class probability is
    regressed against a float similarity label via MSE."""

    config_class = BertConfig

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__(config)
        # BERT encoder with the pooling layer enabled so that
        # outputs.pooler_output is available in forward().
        self.bert = BertModel(config=config, add_pooling_layer=True)
        self.fc = nn.Linear(config.hidden_size, 2)
        # self.loss_fct = nn.CrossEntropyLoss()
        self.loss_fct = nn.MSELoss()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Pooled [CLS] representation of the sentence pair.
        pooled_output = outputs.pooler_output
        logits = self.fc(pooled_output)
        # Keep only the probability of the "similar" class (index 1),
        # giving one similarity score in [0, 1] per pair.
        probs = self.softmax(logits)[:, 1]
        if labels is not None:
            # MSE between the predicted probability and the float label.
            loss = self.loss_fct(probs.view(-1), labels.view(-1))
            return loss, probs
        return None, probs
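

# --- Minimal usage sketch. The checkpoint name `bert-base-chinese` and the
# example sentence pair are illustrative assumptions, not from the source;
# the `fc` head starts randomly initialized and would need fine-tuning on a
# sentence-similarity dataset before the scores are meaningful. ---
if __name__ == "__main__":
    import torch
    from transformers import BertTokenizer

    model = SimBertModel.from_pretrained("bert-base-chinese")
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

    # Encode a sentence pair; token_type_ids distinguish the two segments.
    enc = tokenizer("今天天气不错", "今天天气很好", return_tensors="pt")
    # Similarity labels are floats in [0, 1] to match the MSE objective.
    labels = torch.tensor([1.0])

    loss, probs = model(
        input_ids=enc["input_ids"],
        token_type_ids=enc["token_type_ids"],
        attention_mask=enc["attention_mask"],
        labels=labels,
    )
    print(loss.item(), probs.item())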