|
from transformers import PreTrainedModel |
|
from .configuration import * |
|
|
|
import torch.nn as nn |
|
import torch |
|
from transformers import AutoModel |
|
|
|
class OffensivenessEstimationModel(PreTrainedModel):
    """Classifier made of a pretrained text encoder and a dropout + linear head.

    The encoder (``PretrainedLanguageModel``) turns each input sequence into a
    single vector; the head maps that vector to ``config.output_class_num``
    scores. No softmax or other activation is applied to the scores here.
    """

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        """Build the encoder and the classification head.

        Args:
            config: Configuration object. This class reads ``dropout_rate``
                and ``output_class_num`` from it and hands the whole object
                to ``PretrainedLanguageModel``.
        """
        super().__init__(config)
        self.text_encoder = PretrainedLanguageModel(config)

        # Size the head from the encoder that was actually loaded rather than
        # assuming a 768-wide model; a wider or narrower encoder would
        # otherwise fail with a shape mismatch in forward(). The 768 fallback
        # keeps the previous behaviour for configs that do not expose
        # ``hidden_size``.
        encoder_config = self.text_encoder.language_model.config
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.decoder = nn.Sequential(
            nn.Dropout(p=config.dropout_rate),
            nn.Linear(hidden_size, config.output_class_num),
        )

    def forward(self, ids, mask):
        """Encode the inputs and return the head's output.

        Args:
            ids: Forwarded to the encoder as its first positional argument
                (presumably token ids -- confirm against the caller).
            mask: Forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of the linear head applied to the encoder's
            per-sequence vector.
        """
        h = self.text_encoder(ids, mask)
        output = self.decoder(h)
        return output
|
|
|
class PretrainedLanguageModel(PreTrainedModel):
    """Pretrained encoder wrapper that yields one vector per input sequence.

    The encoder named by ``config.language_model`` is loaded with
    ``AutoModel.from_pretrained``. When ``config.reinit_n_layers`` is
    positive, the last that-many entries of ``encoder.layer`` are
    re-initialised immediately after loading.
    """

    config_class = OffensivenessEstimationConfig

    def __init__(self, config):
        """Load the encoder and optionally reset its top layers.

        Args:
            config: Configuration object providing ``language_model`` (the
                identifier passed to ``AutoModel.from_pretrained``) and
                ``reinit_n_layers``.
        """
        super().__init__(config)
        self.language_model = AutoModel.from_pretrained(config.language_model)
        self.reinit_n_layers = config.reinit_n_layers

        needs_reinit = self.reinit_n_layers > 0
        if needs_reinit:
            self._do_reinit()

    def _do_reinit(self):
        """Re-initialise the last ``reinit_n_layers`` encoder layers in place.

        Linear and embedding weights are redrawn from a normal distribution
        with mean 0 and the encoder config's ``initializer_range`` as the
        standard deviation; linear biases and the embedding padding row are
        zeroed; layer-norm parameters are reset to weight 1 and bias 0.
        """

        def redraw(weight):
            # The std is looked up at call time, i.e. only once a matching
            # module has actually been found.
            weight.data.normal_(
                mean=0.0, std=self.language_model.config.initializer_range
            )

        top_layers = self.language_model.encoder.layer[-self.reinit_n_layers:]
        for block in top_layers:
            for sub in block.modules():
                if isinstance(sub, nn.Linear):
                    redraw(sub.weight)
                    if sub.bias is not None:
                        sub.bias.data.zero_()
                    continue

                if isinstance(sub, nn.Embedding):
                    redraw(sub.weight)
                    pad = sub.padding_idx
                    if pad is not None:
                        sub.weight.data[pad].zero_()
                    continue

                if isinstance(sub, nn.LayerNorm):
                    sub.bias.data.zero_()
                    sub.weight.data.fill_(1.0)

    def forward(self, ids, mask):
        """Run the encoder and return the vector at sequence position 0.

        Args:
            ids: Passed to the encoder as its first positional argument.
            mask: Passed to the encoder as ``attention_mask``.

        Returns:
            Element 0 of the encoder output, indexed ``[:, 0, :]``
            (presumably the first-token / [CLS] hidden state for BERT-style
            encoders -- confirm for the model in use).
        """
        first_output = self.language_model(ids, attention_mask=mask)[0]
        return first_output[:, 0, :]
|
|