import torch
from transformers import DistilBertModel, PreTrainedModel

from .configuration_sentence_label import DistillBERTSentenceLabelConfig


class DistillBERTSentenceLabel(PreTrainedModel):
    config_class = DistillBERTSentenceLabelConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone: pretrained DistilBERT encoder (hidden size 768).
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 1)
        # https://glassboxmedicine.com/2019/05/26/classification-sigmoid-vs-softmax/
        # self.softmax = torch.nn.Softmax(dim=1)
        # self.sigmoid = torch.nn.Sigmoid()  # apply sigmoid to a 1x4 output vector

    def forward(self, input_ids=None, attention_mask=None):
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = output_1[0]           # (batch, seq_len, 768)
        pooler = hidden_state[:, 0]          # [CLS] token representation
        pooler = self.pre_classifier(pooler)
        pooler = torch.nn.functional.relu(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)     # raw logits, no activation applied
        # output = self.sigmoid(output)
        # output = self.softmax(output)
        return {"logits": output}
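
# A minimal usage sketch, under assumptions not confirmed by the module itself:
# that DistillBERTSentenceLabelConfig can be constructed without required
# arguments, and that the distilbert-base-uncased tokenizer matches the
# backbone. Run as a module (e.g. `python -m <package>.<this_module>`), since
# the relative import above needs package context.
if __name__ == "__main__":
    from transformers import DistilBertTokenizerFast

    config = DistillBERTSentenceLabelConfig()
    model = DistillBERTSentenceLabel(config)
    model.eval()

    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    batch = tokenizer(["An example sentence."], return_tensors="pt")

    with torch.no_grad():
        out = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"])

    # The forward pass returns raw logits; apply sigmoid here for a
    # probability, or feed the logits directly to torch.nn.BCEWithLogitsLoss
    # when training, which is numerically more stable than sigmoid + BCELoss.
    print(torch.sigmoid(out["logits"]))  # shape: (batch, 1)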