# Author: arubenruben
# Commit 2968f5e: "Upload BERT_CRF"
import torch
from torch import nn
from torchcrf import CRF
from transformers import AutoModelForTokenClassification, AutoConfig
from transformers import BertModel, BertConfig
from transformers import PreTrainedModel, PretrainedConfig
class BERT_CRF_Config(PretrainedConfig):
    """Configuration object for :class:`BERT_CRF`.

    Every keyword argument is handed to ``PretrainedConfig.__init__``
    unchanged. ``BERT_CRF`` later reads ``config.bert_name`` and
    ``config.num_labels`` from the resulting object, so callers are expected
    to supply those here.
    """

    # Identifier used when registering this config with AutoConfig.
    model_type = "BERT_CRF"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Set after the base initializer so it always ends up as "BERT_CRF",
        # even if a "model_name" keyword was passed in.
        self.model_name = "BERT_CRF"
class BERT_CRF(PreTrainedModel):
    """BERT encoder -> dropout -> linear emission layer -> linear-chain CRF.

    Only the token positions selected by ``labels_mask`` are scored by the
    CRF. Per the training setup, these are the first sub-token of each word.
    """

    config_class = BERT_CRF_Config

    def __init__(self, config):
        super().__init__(config)

        # ``config.bert_name`` is not declared on BERT_CRF_Config; it must be
        # supplied as an extra keyword argument when the config is built.
        bert_config = BertConfig.from_pretrained(config.bert_name)
        # NOTE(review): forward() only reads 'last_hidden_state', so asking
        # BERT for attentions and all hidden states looks like wasted memory.
        # Kept unchanged so that self.bert.config stays the same.
        bert_config.output_attentions = True
        bert_config.output_hidden_states = True
        self.bert = BertModel.from_pretrained(config.bert_name, config=bert_config)

        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(self.bert.config.hidden_size, config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, labels_mask=None):
        """Return the CRF loss when ``labels`` is given, otherwise decoded tags.

        Args:
            input_ids: Token ids, passed straight to ``BertModel``.
            token_type_ids: Optional segment ids, passed to ``BertModel``.
            attention_mask: Optional attention mask, passed to ``BertModel``.
            labels: Gold tag ids, indexed per sequence with the same mask as
                the logits. When given, the loss is returned instead of tags.
            labels_mask: Per-token mask selecting the positions the CRF
                scores. Boolean or 0/1 integer masks are accepted.

        Returns:
            With ``labels``: a scalar tensor. It is the sum over sequences of
            the negated token-mean CRF log-likelihood, divided by the batch
            size. Sequences whose mask selects nothing contribute 0 but still
            count in the divisor.
            Without ``labels``: a list holding one list of tag ids per
            sequence. The inner list is empty when the mask selects nothing.

        Raises:
            ValueError: If ``labels_mask`` is None.
        """
        if labels_mask is None:
            raise ValueError(
                "labels_mask is required to select the positions scored by the CRF"
            )

        last_hidden_layer = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )["last_hidden_state"]
        logits = self.linear(self.dropout(last_hidden_layer))

        # Convert to a boolean tensor so that a 0/1 integer mask acts as a
        # mask; an integer tensor would otherwise be treated as gather
        # indices. Also moves the mask to the logits' device.
        labels_mask = torch.as_tensor(labels_mask, dtype=torch.bool, device=logits.device)

        if labels is not None:
            return self._crf_loss(logits, labels, labels_mask)
        return self._crf_decode(logits, labels_mask)

    def _crf_loss(self, logits, labels, labels_mask):
        """Average the per-sequence negative CRF log-likelihood over the batch."""
        # Start from a tensor (not the int 0) so the result is always a tensor,
        # even when every sequence's mask is empty.
        loss = logits.new_zeros(())
        for seq_logits, seq_labels, seq_mask in zip(logits, labels, labels_mask):
            # Keep only the masked positions; the CRF expects a batch
            # dimension, hence unsqueeze(0).
            seq_logits = seq_logits[seq_mask].unsqueeze(0)
            seq_labels = seq_labels[seq_mask].unsqueeze(0)
            # The CRF cannot score an empty sequence, so skip it.
            if seq_logits.numel() != 0:
                loss = loss - self.crf(seq_logits, seq_labels, reduction="token_mean")
        return loss / logits.shape[0]

    def _crf_decode(self, logits, labels_mask):
        """Viterbi-decode the masked positions of each sequence independently."""
        output_tags = []
        for seq_logits, seq_mask in zip(logits, labels_mask):
            seq_logits = seq_logits[seq_mask].unsqueeze(0)
            if seq_logits.numel() != 0:
                # decode() returns one tag list per sequence in its
                # size-1 "batch"; unpack it.
                output_tags.append(self.crf.decode(seq_logits)[0])
            else:
                output_tags.append([])
        return output_tags
class ModelRegisterStep:
    """Pipeline step that registers BERT_CRF with the transformers Auto* factories."""

    def __call__(self, args):
        """Register the config/model pair, then return a shallow copy of *args*.

        Args:
            args: Mapping of pipeline arguments; it is copied and not modified.

        Returns:
            A new dict with the same key/value pairs as *args*.
        """
        # Let AutoConfig resolve the "BERT_CRF" model type to its config class.
        AutoConfig.register("BERT_CRF", BERT_CRF_Config)
        # Let AutoModelForTokenClassification build BERT_CRF from that config.
        AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)

        forwarded_args = {**args}
        return forwarded_args