dudududukim committed on
Commit
749e845
·
verified ·
1 Parent(s): 8ac702d
configuration_bert_concat.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from transformers import BertConfig
2
+
3
class BertConcatConfig(BertConfig):
    """Configuration for ``BertConcatClassifier``.

    Extends :class:`~transformers.BertConfig` with the hub id of the BERT
    backbone to load and the number of classification labels.
    """

    # Distinct model_type so this config is distinguishable from a plain
    # BertConfig when shared as custom code (see HF "Sharing custom models").
    model_type = "bert_concat"

    def __init__(self, bert_model_name='klue/bert-base', num_labels=2, **kwargs):
        """
        Args:
            bert_model_name: Hub id of the BERT backbone the model loads.
            num_labels: Number of output classes for the classifier head.
            **kwargs: Forwarded unchanged to ``BertConfig``.
        """
        # Forward num_labels to the parent instead of assigning it afterwards:
        # PretrainedConfig derives id2label/label2id from num_labels, so it
        # should see the value during initialization.
        super().__init__(num_labels=num_labels, **kwargs)
        self.bert_model_name = bert_model_name
modeling_bert_concat.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_bert_concat import BertConcatConfig
6
class BertConcatClassifier(PreTrainedModel):
    """Sequence classifier over three concatenated BERT sentence vectors.

    Builds three (batch, hidden) views of the input — the last layer's [CLS]
    vector, the 4th-from-last layer's [CLS] vector, and the mean over all
    token positions of the last layer — stacks them as 3 channels, fuses
    them with a pointwise Conv1d, and classifies with a linear head.
    """

    config_class = BertConcatConfig

    def __init__(self, config):
        """
        Args:
            config: ``BertConcatConfig`` carrying ``bert_model_name`` (hub id
                of the backbone) and ``num_labels``.
        """
        super().__init__(config)
        # output_hidden_states=True so forward() can read intermediate layers.
        self.bert = AutoModel.from_pretrained(config.bert_model_name, output_hidden_states=True)
        self.num_labels = config.num_labels

        # Read the hidden size from the loaded backbone instead of
        # hard-coding 768, so non-base backbones also work.
        hidden_size = self.bert.config.hidden_size

        # Pointwise conv fuses the 3 stacked vectors: (B, 3, H) -> (B, 1, H).
        self.conv = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
        self.relu = nn.ReLU()
        self.classifier = nn.Linear(hidden_size, self.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone and classifier.

        Args:
            input_ids: Token id tensor of shape (batch, seq_len).
            attention_mask: Optional mask of shape (batch, seq_len).
            labels: Optional class-id tensor; when given, cross-entropy loss
                is computed.

        Returns:
            ``SequenceClassifierOutput`` with logits of shape
            (batch, num_labels) and ``loss`` set iff ``labels`` was passed.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.hidden_states

        # Three sentence views, each (B, H).
        last_cls_vector = hidden_states[-1][:, 0, :]
        fourth_last_cls_vector = hidden_states[-4][:, 0, :]
        # NOTE(review): mean pooling averages over ALL positions, including
        # padding — a mask-weighted mean may be intended; confirm upstream.
        mean_pooled_vector = hidden_states[-1].mean(dim=1)

        # Stack as channels: (B, 3, H).
        concatenated_vector = torch.stack(
            (last_cls_vector, fourth_last_cls_vector, mean_pooled_vector),
            dim=1,
        )

        # Fuse and classify: (B, 3, H) -> (B, 1, H) -> (B, H) -> (B, num_labels).
        # The original squeezed dim 2 here — a no-op on a size-H dimension —
        # and only dropped the channel dim after the linear layer; squeezing
        # dim 1 is the intended reduction and yields identical logits.
        fused = self.conv(concatenated_vector).squeeze(1)
        logits = self.classifier(self.relu(fused))

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )