not-lain commited on
Commit
83dab7c
·
1 Parent(s): 0526e55

Create modeling_tunbert.py

Browse files
Files changed (1) hide show
  1. modeling_tunbert.py +31 -0
modeling_tunbert.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForSequenceClassification, PreTrainedModel,AutoConfig, BertModel
3
+ from transformers.modeling_outputs import SequenceClassifierOutput
4
+
5
+ class classifier(nn.Module):
6
+ def __init__(self,config):
7
+ super().__init__()
8
+
9
+ self.layer0 = nn.Linear(in_features=config.hidden_size, out_features=config.hidden_size, bias=True)
10
+ self.layer1 = nn.Linear(in_features=config.hidden_size, out_features=config.type_vocab_size, bias=True)
11
+ def forward(self,tensor):
12
+ out1 = self.layer0(tensor)
13
+ return self.layer1(out1)
14
+
15
+
16
+ class TunBERT(PreTrainedModel):
17
+ def __init__(self, config):
18
+ super().__init__(config)
19
+ self.BertModel = BertModel(config)
20
+ self.dropout = nn.Dropout(p=0.1, inplace=False)
21
+ self.classifier = classifier(config)
22
+
23
+ def forward(self,input_ids=None,token_type_ids=None,attention_mask=None,labels=None) :
24
+ outputs = self.BertModel(input_ids,token_type_ids,attention_mask)
25
+ sequence_output = self.dropout(outputs.last_hidden_state)
26
+ logits = self.classifier(sequence_output)
27
+ loss =None
28
+ if labels is not None :
29
+ loss_func = nn.CrossentropyLoss()
30
+ loss = loss_func(logits.view(-1,self.config.type_vocab_size),labels.view(-1))
31
+ return SequenceClassifierOutput(loss = loss, logits= logits, hidden_states=outputs.last_hidden_state,attentions=outputs.attentions)