import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel

from .configuration import OffensivenessEstimationConfig

class OffensivenessEstimationModel(PreTrainedModel):
    """Text classifier: a pretrained encoder followed by a dropout + linear head."""

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        super().__init__(config)
        self.text_encoder = PretrainedLanguageModel(config)
        # Classification head over the encoder's pooled representation.
        # The input width of 768 assumes a base-size encoder (e.g. BERT-base).
        self.decoder = nn.Sequential(
            nn.Dropout(p=config.dropout_rate),
            nn.Linear(768, config.output_class_num)
        )

    def forward(self, ids, mask):
        # Encode the token ids, then map the pooled representation to class logits.
        h = self.text_encoder(ids, mask)
        output = self.decoder(h)
        return output

class PretrainedLanguageModel(PreTrainedModel):
    """Wraps a Hugging Face encoder and optionally re-initializes its top layers."""

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        super().__init__(config)
        self.language_model = AutoModel.from_pretrained(config.language_model)
        self.reinit_n_layers = config.reinit_n_layers
        if self.reinit_n_layers > 0:
            self._do_reinit()

    def _do_reinit(self):
        # Re-initialize the last n transformer layers with the encoder's original
        # init scheme, a common trick for stabilizing fine-tuning.
        # Assumes a BERT-style model that exposes `encoder.layer`.
        for layer in self.language_model.encoder.layer[-self.reinit_n_layers:]:
            for module in layer.modules():
                if isinstance(module, nn.Linear):
                    module.weight.data.normal_(mean=0.0, std=self.language_model.config.initializer_range)
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif isinstance(module, nn.Embedding):
                    module.weight.data.normal_(mean=0.0, std=self.language_model.config.initializer_range)
                    if module.padding_idx is not None:
                        # Keep the padding embedding at zero after re-init.
                        module.weight.data[module.padding_idx].zero_()
                elif isinstance(module, nn.LayerNorm):
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)

    def forward(self, ids, mask):
        # Return the final hidden state of the first ([CLS]) token: shape (batch, hidden).
        output = self.language_model(ids, attention_mask=mask)
        return output[0][:, 0, :]
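

# --- Usage sketch (illustrative, not from the original file) ---
# A minimal end-to-end example, assuming OffensivenessEstimationConfig (in the
# companion configuration.py) is a PretrainedConfig subclass accepting the
# fields used above. The checkpoint name and field values here are assumptions.
#
#   from transformers import AutoTokenizer
#
#   config = OffensivenessEstimationConfig(
#       language_model="bert-base-uncased",  # assumed 768-dim encoder
#       dropout_rate=0.1,
#       output_class_num=2,                  # assumed number of classes
#       reinit_n_layers=2,
#   )
#   model = OffensivenessEstimationModel(config)
#   tokenizer = AutoTokenizer.from_pretrained(config.language_model)
#   batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
#   logits = model(batch["input_ids"], batch["attention_mask"])
#   # logits.shape == (1, config.output_class_num)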