from torch import nn

from transformers import PreTrainedModel, PretrainedConfig
from transformers import BertModel, BertConfig
from transformers import AutoModelForTokenClassification, AutoConfig

from torchcrf import CRF


class BERT_CRF_Config(PretrainedConfig):
    model_type = "BERT_CRF"

    def __init__(self, bert_name=None, **kwargs):
        super().__init__(**kwargs)
        self.model_name = "BERT_CRF"
        # Name or path of the underlying BERT encoder loaded by BERT_CRF.
        self.bert_name = bert_name


class BERT_CRF(PreTrainedModel):
    config_class = BERT_CRF_Config

    def __init__(self, config):
        super().__init__(config)

        # Load the underlying BERT encoder, exposing attentions and hidden states.
        bert_config = BertConfig.from_pretrained(config.bert_name)
        bert_config.output_attentions = True
        bert_config.output_hidden_states = True
        self.bert = BertModel.from_pretrained(config.bert_name, config=bert_config)

        self.dropout = nn.Dropout(p=0.5)

        # Map each token representation to per-label emission scores for the CRF.
        self.linear = nn.Linear(self.bert.config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask, labels, labels_mask):
        # Encode the batch with BERT and turn token representations into
        # per-label emission scores. `labels_mask` is a boolean mask selecting
        # the positions that participate in the CRF (e.g. real tokens only).
        last_hidden_layer = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask)['last_hidden_state']
        last_hidden_layer = self.dropout(last_hidden_layer)
        logits = self.linear(last_hidden_layer)

        batch_size = logits.shape[0]
        output_tags = []

        if labels is not None:
            # Training: accumulate the CRF negative log-likelihood of each
            # sequence (token-mean reduction), restricted to masked positions.
            loss = 0
            for seq_logits, seq_labels, seq_mask in zip(logits, labels, labels_mask):
                seq_logits = seq_logits[seq_mask].unsqueeze(0)
                seq_labels = seq_labels[seq_mask].unsqueeze(0)
                if seq_logits.numel() != 0:
                    loss -= self.crf(seq_logits, seq_labels, reduction='token_mean')
            return loss / batch_size
        else:
            # Inference: Viterbi-decode the best tag sequence for each example.
            for seq_logits, seq_mask in zip(logits, labels_mask):
                seq_logits = seq_logits[seq_mask].unsqueeze(0)
                if seq_logits.numel() != 0:
                    tags = self.crf.decode(seq_logits)
                else:
                    tags = [[]]
                output_tags.append(tags[0])
            return output_tags
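

# A minimal usage sketch (illustrative only): the checkpoint name
# "bert-base-cased" and the label count of 9 are assumptions, not part of the
# model definition above; `input_ids`, `token_type_ids`, `attention_mask`,
# `labels` and `labels_mask` stand for batched tensors from your own pipeline.
#
#     config = BERT_CRF_Config(bert_name="bert-base-cased", num_labels=9)
#     model = BERT_CRF(config)
#
#     # With labels: returns the batch-averaged CRF negative log-likelihood.
#     loss = model(input_ids, token_type_ids, attention_mask, labels, labels_mask)
#
#     # Without labels: returns one decoded tag sequence per example.
#     tags = model(input_ids, token_type_ids, attention_mask, None, labels_mask)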


class ModelRegisterStep:
    def __call__(self, args):
        # Register the custom config and model so AutoConfig and
        # AutoModelForTokenClassification can resolve the "BERT_CRF" model type.
        AutoConfig.register("BERT_CRF", BERT_CRF_Config)
        AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)

        return {
            **args,
        }
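

# Registration sketch (illustrative only): after ModelRegisterStep has run,
# the Auto classes can resolve the custom "BERT_CRF" model type. The path
# below is a placeholder for a directory produced by `save_pretrained`.
#
#     ModelRegisterStep()({})
#     config = AutoConfig.from_pretrained("path/to/bert_crf_checkpoint")
#     model = AutoModelForTokenClassification.from_pretrained(
#         "path/to/bert_crf_checkpoint", config=config)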