from torch import nn
from transformers import PreTrainedModel, PretrainedConfig
from transformers import BertModel, BertConfig
from transformers import AutoModelForTokenClassification, AutoConfig
from torchcrf import CRF


class BERT_CRF_Config(PretrainedConfig):
    """Configuration object for :class:`BERT_CRF`.

    All keyword arguments are forwarded to ``PretrainedConfig``. The model
    reads ``bert_name`` (passed to ``from_pretrained`` for the encoder) and
    ``num_labels`` (size of the tag set) from this config.
    """

    model_type = "BERT_CRF"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_name = "BERT_CRF"


class BERT_CRF(PreTrainedModel):
    """BERT encoder followed by dropout, a linear projection and a CRF layer."""

    config_class = BERT_CRF_Config

    def __init__(self, config):
        """Build the encoder, the projection head and the CRF.

        Args:
            config: a config exposing ``bert_name`` and ``num_labels``.
        """
        super().__init__(config)

        bert_config = BertConfig.from_pretrained(config.bert_name)
        # NOTE(review): forward() only consumes 'last_hidden_state'; these two
        # flags make the encoder also return attentions and all hidden
        # states. Kept as in the original - confirm whether anything
        # downstream relies on them.
        bert_config.output_attentions = True
        bert_config.output_hidden_states = True

        self.bert = BertModel.from_pretrained(config.bert_name,
                                              config=bert_config)
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(self.bert.config.hidden_size,
                                config.num_labels)
        self.crf = CRF(config.num_labels, batch_first=True)

    def forward(self, input_ids, token_type_ids, attention_mask,
                labels=None, labels_mask=None):
        """Compute the CRF training loss or decode the best tag sequences.

        Args:
            input_ids: passed to the BERT encoder as ``input_ids``.
            token_type_ids: passed to the BERT encoder as ``token_type_ids``.
            attention_mask: passed to the BERT encoder as ``attention_mask``.
            labels: per-token tag ids. When ``None`` (the default) the method
                runs in decoding mode.
            labels_mask: per-sequence mask selecting the positions that are
                handed to the CRF (the first subtoken of each word). It is
                cast to bool before indexing. Required.

        Returns:
            Training (``labels`` given): the sum over sequences of the
            negative CRF log-likelihood (``reduction='token_mean'``), divided
            by the batch size.
            Decoding (``labels`` is ``None``): a list with one list of tag
            ids per sequence; a sequence with no selected position yields
            ``[]``.

        Raises:
            ValueError: if ``labels_mask`` is ``None``.
        """
        if labels_mask is None:
            raise ValueError(
                "labels_mask is required to select the positions passed "
                "to the CRF")

        encoded = self.bert(input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask)['last_hidden_state']
        logits = self.linear(self.dropout(encoded))

        if labels is None:
            output_tags = []
            for seq_logits, seq_mask in zip(logits, labels_mask):
                # .bool() guarantees mask semantics; an integer 0/1 tensor
                # would otherwise be interpreted as row indices.
                seq_logits = seq_logits[seq_mask.bool()].unsqueeze(0)
                if seq_logits.numel() != 0:
                    # decode() returns one path per batch entry; the "batch"
                    # here is a single sequence, so unpack it.
                    output_tags.append(self.crf.decode(seq_logits)[0])
                else:
                    output_tags.append([])
            return output_tags

        batch_size = logits.shape[0]
        nll_terms = []
        for seq_logits, seq_labels, seq_mask in zip(logits, labels,
                                                    labels_mask):
            # Index logits and labels using the prediction mask to pass only
            # the first subtoken of each word to the CRF.
            keep = seq_mask.bool()
            seq_logits = seq_logits[keep].unsqueeze(0)
            if seq_logits.numel() == 0:
                continue
            seq_labels = seq_labels[keep].unsqueeze(0)
            nll_terms.append(
                -self.crf(seq_logits, seq_labels, reduction='token_mean'))

        if not nll_terms:
            # Nothing was selected in the whole batch: return a zero that is
            # still attached to the graph so callers can call .backward()
            # (the original returned a plain Python float here).
            return logits.sum() * 0.0
        return sum(nll_terms) / batch_size


class ModelRegisterStep():
    """Pipeline step that registers BERT_CRF with the transformers Auto classes."""

    def __call__(self, args):
        """Register the config and model classes, then pass ``args`` through.

        Args:
            args: mapping of pipeline arguments.

        Returns:
            A new dict with the same items as ``args``.
        """
        # NOTE(review): transformers may raise if "BERT_CRF" is already
        # registered, so invoking this step twice in one process could fail -
        # confirm against the installed transformers version.
        AutoConfig.register("BERT_CRF", BERT_CRF_Config)
        AutoModelForTokenClassification.register(BERT_CRF_Config, BERT_CRF)
        return {**args}