# textcnn-ko-dialect-classifier / modeling_textcnn.py
# jinmang2 — "Update modeling_textcnn.py" (e025643)
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.file_utils import ModelOutput

from .configuration_textcnn import TextCNNConfig
@dataclass
class TextCNNModelOutput(ModelOutput):
    """Output container returned by ``TextCNNModel.forward``.

    Args:
        last_hidden_states: Concatenated pooled features of all convolutions,
            shape ``(batch_size, sum(config.num_filters))``.
        ngram_feature_maps: List with one tensor per n-gram convolution
            (one entry per filter size); see ``TextCNNModel.forward`` for
            exactly which tensors are collected.
    """

    # Field order matters: ModelOutput supports tuple-style indexing by position.
    last_hidden_states: Optional[torch.FloatTensor] = None
    ngram_feature_maps: Optional[List[torch.FloatTensor]] = None
@dataclass
class TextCNNSequenceClassificerOutput(ModelOutput):
    """Output container returned by ``TextCNNForSequenceClassification.forward``.

    NOTE: the class name keeps its original spelling ("Classificer") because
    other code may import it under that name.

    Args:
        loss: Cross-entropy loss; only set when ``labels`` are passed.
        logits: Unnormalized class scores, shape ``(batch_size, num_labels)``.
        last_hidden_states: Pooled TextCNN features fed to the classifier head.
        ngram_feature_maps: Per-convolution features forwarded from the encoder.
    """

    # Field order matters: ModelOutput supports tuple-style indexing by position.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_states: Optional[torch.FloatTensor] = None
    ngram_feature_maps: Optional[List[torch.FloatTensor]] = None
class TextCNNPreTrainedModel(PreTrainedModel):
    """Common base class binding TextCNN models to ``TextCNNConfig``."""

    config_class = TextCNNConfig
    base_model_prefix = "textcnn"

    def _init_weights(self, module):
        """Intentionally leave ``module``'s parameters untouched.

        ``nn.Embedding``, ``nn.Conv2d`` and ``nn.Linear`` already initialize
        their parameters in their constructors, so no extra initialization is
        applied here. This hook must stay a no-op rather than raise:
        presumably the ``transformers`` init/loading machinery calls it, and
        raising would break model construction or loading.
        """
        return None

    @property
    def dummy_inputs(self):
        """Return a tiny fixed batch of two sequences for smoke tests.

        The second row ends with ``config.pad_token_id``. ``attention_mask`` is
        ``False`` exactly at the positions whose id equals the pad id.
        """
        pad_token = self.config.pad_token_id
        # NOTE(review): assumes config.pad_token_id is an int; a None pad id
        # cannot be placed into the integer tensor below.
        input_ids = torch.tensor(
            [[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device
        )
        return {
            "attention_mask": input_ids.ne(pad_token),
            "input_ids": input_ids,
        }
class TextCNNModel(TextCNNPreTrainedModel):
    """Text-CNN encoder for style classification.

    Embeds tokens, runs one ``Conv2d`` per n-gram size across the full
    embedding width, applies ReLU and max-over-time pooling, and concatenates
    the pooled features.
    """

    def __init__(self, config):
        super().__init__(config)
        # Attribute names ("embeder", "convs") are part of the state_dict keys
        # of saved checkpoints, so they are kept as-is.
        self.embeder = nn.Embedding(config.vocab_size, config.embed_dim)
        # One convolution per n-gram size: kernel (f, embed_dim) covers f tokens
        # and the whole embedding width, producing n output channels.
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(1, n, (f, config.embed_dim))
                for n, f in zip(config.num_filters, config.filter_sizes)
            ]
        )

    def get_input_embeddings(self):
        return self.embeder

    def set_input_embeddings(self, value):
        self.embeder = value

    def forward(self, input_ids, attention_mask=None):
        """Encode ``input_ids`` into pooled n-gram features.

        Args:
            input_ids: Token ids of shape ``(batch_size, seq_len)``. Because the
                convolutions use no padding, ``seq_len`` must be at least
                ``max(config.filter_sizes)``.
            attention_mask: Accepted so that ``model(**model.dummy_inputs)`` and
                tokenizer outputs can be passed directly. It is currently not
                used: padding positions are convolved like any other token.

        Returns:
            TextCNNModelOutput: ``last_hidden_states`` of shape
            ``(batch_size, sum(config.num_filters))`` and ``ngram_feature_maps``,
            the list of per-convolution max-pooled features, each of shape
            ``(batch_size, num_filters[i])``.
        """
        # (bsz, seq_len) -> (bsz, 1, seq_len, emb_dim): add one input channel.
        x = self.embeder(input_ids).unsqueeze(1)

        pools = []
        for conv in self.convs:
            # Kernel spans the full embedding width, so the last dim is 1:
            # (bsz, n_filter[i], ngram_seq_len, 1) -> (bsz, n_filter[i], ngram_seq_len)
            conv_output = torch.relu(conv(x)).squeeze(3)
            # Max over time: -> (bsz, n_filter[i])
            pooled = torch.max_pool1d(conv_output, conv_output.size(2)).squeeze(2)
            pools.append(pooled)

        # (bsz, sum(num_filters))
        last_hidden_states = torch.cat(pools, dim=1)
        return TextCNNModelOutput(
            last_hidden_states=last_hidden_states,
            ngram_feature_maps=pools,
        )
class TextCNNForSequenceClassification(TextCNNPreTrainedModel):
    """TextCNN encoder followed by a two-layer MLP classification head."""

    def __init__(self, config):
        super().__init__(config)
        self.feature_dim = sum(config.num_filters)
        self.textcnn = TextCNNModel(config)
        hidden_dim = self.feature_dim // 2
        # Layer order must stay fixed: checkpoint keys are "fc.1.*" and "fc.3.*".
        self.fc = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, config.num_labels),
        )

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Classify ``input_ids`` and optionally compute the training loss.

        Args:
            input_ids: Token ids of shape ``(batch_size, seq_len)``.
            labels: Optional class indices of shape ``(batch_size,)``. When
                given, a cross-entropy loss is computed.
            attention_mask: Accepted so that ``model(**model.dummy_inputs)`` and
                tokenizer outputs can be passed directly. It is not used and is
                not forwarded to the encoder. It comes last so existing
                positional ``(input_ids, labels)`` calls keep working.

        Returns:
            TextCNNSequenceClassificerOutput: ``logits`` of shape
            ``(batch_size, num_labels)``, ``loss`` (or ``None``), and the
            encoder's features.
        """
        outputs = self.textcnn(input_ids)
        # (bsz, feature_dim) -> (bsz, num_labels)
        logits = self.fc(outputs.last_hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return TextCNNSequenceClassificerOutput(
            loss=loss,
            logits=logits,
            last_hidden_states=outputs.last_hidden_states,
            ngram_feature_maps=outputs.ngram_feature_maps,
        )