File size: 3,693 Bytes
a576d9e
 
 
 
 
 
 
 
 
8d3fe9a
8ad43d6
a576d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4972105
a576d9e
 
 
4972105
a576d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e025643
 
 
 
 
 
 
 
a576d9e
e025643
a576d9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import torch
import torch.nn as nn

from typing import List
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.file_utils import ModelOutput

from .configuration_textcnn import TextCNNConfig


@dataclass
class TextCNNModelOutput(ModelOutput):
    """Output container for ``TextCNNModel.forward``.

    last_hidden_states: concatenation of the max-pooled conv features,
        shape (bsz, sum(config.num_filters)) per the shape comments in
        ``TextCNNModel.forward``.
    ngram_feature_maps: list with one tensor per conv branch / filter size.
    """
    last_hidden_states: torch.FloatTensor = None
    ngram_feature_maps: List[torch.FloatTensor] = None
        
        
# NOTE(review): class name misspells "Classifier" ("Classificer"); kept as-is
# because it is part of the public interface and is referenced by
# TextCNNForSequenceClassification.forward.
@dataclass
class TextCNNSequenceClassificerOutput(ModelOutput):
    """Output container for ``TextCNNForSequenceClassification.forward``.

    loss: cross-entropy loss; None when no labels were passed.
    logits: classification scores, shape (bsz, config.num_labels).
    last_hidden_states: pooled features forwarded from the encoder output.
    ngram_feature_maps: per-filter-size feature list forwarded from the encoder.
    """
    loss: torch.FloatTensor = None
    logits: torch.FloatTensor = None
    last_hidden_states: torch.FloatTensor = None
    ngram_feature_maps: List[torch.FloatTensor] = None
    

class TextCNNPreTrainedModel(PreTrainedModel):
    """Base class hooking the TextCNN models into the HF ``PreTrainedModel`` API."""

    config_class = TextCNNConfig
    base_model_prefix = "textcnn"

    def _init_weights(self, module):
        # BUG FIX: the original body was `return NotImplementedError`, which
        # returns the exception *class* instead of raising it. Since
        # PreTrainedModel's init machinery discards this return value, that
        # was a silent no-op — so we make the no-op explicit rather than
        # raising (raising here would break model instantiation). Submodules
        # keep PyTorch's default parameter initialization.
        pass

    @property
    def dummy_inputs(self):
        """A tiny fake batch (second row ends in padding) for tracing/summary.

        Returns a dict with ``input_ids`` of shape (2, 5) and a boolean
        ``attention_mask`` that is False exactly on pad positions.
        """
        pad_token = self.config.pad_token_id
        input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
        dummy_inputs = {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
        return dummy_inputs
        
        
class TextCNNModel(TextCNNPreTrainedModel):
    """ A style classifier Text-CNN """

    def __init__(self, config):
        super().__init__(config)
        self.embeder = nn.Embedding(config.vocab_size, config.embed_dim)
        # One Conv2d per (num_filters, filter_size) pair; each kernel spans
        # the full embedding dim, so it behaves as a 1-d n-gram detector.
        self.convs = nn.ModuleList([
            nn.Conv2d(1, n, (f, config.embed_dim))
            for (n, f) in zip(config.num_filters, config.filter_sizes)
        ])

    def get_input_embeddings(self):
        return self.embeder

    def set_input_embeddings(self, value):
        self.embeder = value

    def forward(self, input_ids):
        # input_ids.shape == (bsz, seq_len)
        # x.shape == (bsz, 1, seq_len, emb_dim) — channel dim added for Conv2d
        x = self.embeder(input_ids).unsqueeze(1)
        pools = []
        for conv in self.convs:
            # conv_output.shape == (bsz, n_filter[i], ngram_seq_len)
            conv_output = torch.relu(conv(x)).squeeze(3)
            # max-pool over time -> (bsz, n_filter[i])
            pooled = torch.max_pool1d(conv_output, conv_output.size(2)).squeeze(2)
            pools.append(pooled)
        # hidden.shape == (bsz, sum(num_filters))
        hidden = torch.cat(pools, dim=1)

        # BUG FIX: the original passed `ngram_feature_maps=pools` but never
        # defined `pools` — the per-branch list was named `outputs` and was
        # then overwritten by torch.cat, so forward() raised NameError.
        # Keeping the pooled per-branch tensors in their own list fixes it.
        return TextCNNModelOutput(
            last_hidden_states=hidden,
            ngram_feature_maps=pools,
        )
        
        
class TextCNNForSequenceClassification(TextCNNPreTrainedModel):
    """TextCNN encoder topped with a small dropout + 2-layer MLP head."""

    def __init__(self, config):
        super().__init__(config)
        self.feature_dim = sum(config.num_filters)
        self.textcnn = TextCNNModel(config)
        hidden_dim = int(self.feature_dim / 2)  # bottleneck = half the feature dim
        self.fc = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, config.num_labels),
        )

    def forward(self, input_ids, labels=None):
        """Encode ``input_ids`` (bsz, seq_len) and classify.

        When ``labels`` of shape (bsz,) are given, also computes the
        cross-entropy loss; otherwise ``loss`` is None.
        """
        encoder_out = self.textcnn(input_ids)
        # First element of the encoder output is the pooled feature vector
        # of shape (bsz, feature_dim).
        logits = self.fc(encoder_out[0])

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )

        return TextCNNSequenceClassificerOutput(
            loss=loss,
            logits=logits,
            last_hidden_states=encoder_out.last_hidden_states,
            ngram_feature_maps=encoder_out.ngram_feature_maps,
        )