jinmang2 committed on
Commit
a576d9e
·
1 Parent(s): 556d5b1

Create modeling_textcnn.py

Browse files
Files changed (1) hide show
  1. modeling_textcnn.py +106 -0
modeling_textcnn.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import List
5
+ from dataclasses import dataclass
6
+
7
+ from transformers import PreTrainedModel
8
+ from transformers.file_utils import ModelOutput
9
+
10
+
11
@dataclass
class TextCNNModelOutput(ModelOutput):
    """Output container returned by ``TextCNNModel.forward``.

    Field order matters: index-based access on the output follows the
    declaration order below.
    """

    # Concatenation of every pooled n-gram feature, shape (bsz, feature_dim).
    last_hidden_states: torch.FloatTensor = None
    # One pooled tensor per convolution branch, each of shape (bsz, n_filter[i]).
    ngram_feature_maps: List[torch.FloatTensor] = None
15
+
16
+
17
@dataclass
class TextCNNSequenceClassificerOutput(ModelOutput):
    """Output container returned by ``TextCNNForSequenceClassification.forward``.

    Carries the classification results together with the backbone features
    they were computed from.
    """

    # Cross-entropy loss; stays ``None`` when no labels were supplied.
    loss: torch.FloatTensor = None
    # Raw class scores produced by the classification head.
    logits: torch.FloatTensor = None
    # Backbone sentence features, shape (bsz, feature_dim).
    last_hidden_states: torch.FloatTensor = None
    # Per-branch pooled features forwarded unchanged from the backbone output.
    ngram_feature_maps: List[torch.FloatTensor] = None
23
+
24
+
25
class TextCNNPreTrainedModel(PreTrainedModel):
    """Common base class for the Text-CNN models in this file.

    Declares the configuration class and the attribute name (``textcnn``)
    under which task-specific heads store the backbone model.
    """

    # NOTE(review): ``TextCNNConfig`` is not imported in the visible part of
    # this file; it has to be brought into scope for the module to import.
    config_class = TextCNNConfig
    base_model_prefix = "textcnn"

    def _init_weights(self, module):
        """Keep the initialization ``module`` received when it was constructed.

        No custom initialization scheme is defined for Text-CNN, so this hook
        is intentionally a no-op. It must not raise: the hook is presumably
        applied to every submodule by the ``PreTrainedModel`` machinery
        (verify against the installed transformers version), and an exception
        here would make the model impossible to build or load.

        Args:
            module: The submodule being visited. It is left untouched.
        """
        return None

    @property
    def dummy_inputs(self):
        """Build a tiny two-sequence batch of token ids on the model's device.

        The last position of the second sequence holds the padding id, and
        the mask is ``True`` wherever a token differs from that padding id.

        Returns:
            dict: ``{"attention_mask": BoolTensor, "input_ids": LongTensor}``,
            both of shape (2, 5).
        """
        # NOTE(review): requires ``config.pad_token_id`` to be an int; a
        # ``None`` value cannot be placed inside the tensor literal below.
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor(
            [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]],
            device=self.device,
        )
        # NOTE(review): the ``forward`` methods in this file accept only
        # ``input_ids`` (and ``labels``), so "attention_mask" has to be
        # dropped before unpacking this dict into a forward call.
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs
41
+
42
+
43
class TextCNNModel(TextCNNPreTrainedModel):
    """Text-CNN encoder used as the backbone of a style classifier.

    Token ids are embedded, run through one 2-D convolution per configured
    filter size, rectified, max-pooled over the sequence axis and finally
    concatenated into a single feature vector per example.
    """

    def __init__(self, config):
        """Create the embedding table and one convolution per filter size.

        Args:
            config: Configuration providing ``vocab_size``, ``embed_dim``,
                ``num_filters`` and ``filter_sizes`` (the last two are
                consumed pairwise).
        """
        super().__init__(config)
        self.embeder = nn.Embedding(config.vocab_size, config.embed_dim)
        # Each kernel spans the whole embedding width, so a branch can only
        # slide along the sequence axis and acts as an n-gram detector.
        branches = []
        for out_channels, window in zip(config.num_filters, config.filter_sizes):
            branches.append(
                nn.Conv2d(1, out_channels, (window, config.embed_dim))
            )
        self.convs = nn.ModuleList(branches)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embeder

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embeder = value

    def forward(self, input_ids):
        """Encode a batch of token ids into fixed-size feature vectors.

        Args:
            input_ids: Token ids of shape (bsz, seq_len).

        Returns:
            TextCNNModelOutput: ``last_hidden_states`` has shape
            (bsz, feature_dim) and ``ngram_feature_maps`` lists the pooled
            output of every branch, each of shape (bsz, n_filter[i]).
        """
        embedded = self.embeder(input_ids)
        # Insert a channel axis: (bsz, 1, seq_len, emb_dim).
        embedded = embedded.unsqueeze(1)

        pooled = []
        for branch in self.convs:
            # The full-width kernel leaves a size-1 last axis; remove it to
            # obtain (bsz, n_filter[i], ngram_seq_len).
            feature_map = torch.relu(branch(embedded)).squeeze(3)
            # Max over every n-gram position -> (bsz, n_filter[i]).
            strongest = torch.max_pool1d(feature_map, feature_map.size(2))
            pooled.append(strongest.squeeze(2))

        # Join all branches: (bsz, feature_dim).
        features = torch.cat(pooled, 1)

        return TextCNNModelOutput(
            last_hidden_states=features,
            ngram_feature_maps=pooled,
        )
75
+
76
+
77
class TextCNNForSequenceClassification(TextCNNPreTrainedModel):
    """Text-CNN backbone followed by a two-layer classification head."""

    def __init__(self, config):
        """Build the backbone and the dropout -> linear -> ReLU -> linear head.

        Args:
            config: Configuration providing ``num_filters``, ``dropout`` and
                ``num_labels`` in addition to what ``TextCNNModel`` reads.
        """
        super().__init__(config)
        # The backbone concatenates every branch, so its output width is the
        # total number of filters.
        self.feature_dim = sum(config.num_filters)
        self.textcnn = TextCNNModel(config)
        hidden_dim = self.feature_dim // 2
        self.fc = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, config.num_labels),
        )

    def forward(self, input_ids, labels=None):
        """Classify a batch of sequences and optionally compute the loss.

        Args:
            input_ids: Token ids of shape (bsz, seq_len).
            labels: Optional class indices of shape (bsz,). When given, a
                cross-entropy loss is computed against the logits.

        Returns:
            TextCNNSequenceClassificerOutput: ``loss`` (``None`` without
            labels), ``logits`` and the backbone features.
        """
        encoded = self.textcnn(input_ids)
        # Sentence features of shape (bsz, feature_dim) feed the head.
        logits = self.fc(encoded.last_hidden_states)

        if labels is None:
            loss = None
        else:
            criterion = nn.CrossEntropyLoss()
            flat_logits = logits.view(-1, self.config.num_labels)
            loss = criterion(flat_logits, labels.view(-1))

        return TextCNNSequenceClassificerOutput(
            loss=loss,
            logits=logits,
            last_hidden_states=encoded.last_hidden_states,
            ngram_feature_maps=encoded.ngram_feature_maps,
        )