# -*- coding: utf-8 -*-
r"""
BERT Encoder
==============
Pretrained BERT encoder from Hugging Face.
"""
from argparse import Namespace
from typing import Dict

import torch
from transformers import AutoModel

from polos.models.encoders.encoder_base import Encoder
from polos.tokenizers_ import HFTextEncoder
from torchnlp.utils import lengths_to_mask


class BERTEncoder(Encoder):
    """BERT encoder.

    :param tokenizer: BERT text encoder.
    :param hparams: Namespace with the model hyperparameters (requires `pretrained_model`).
    """

    def __init__(
        self,
        tokenizer: HFTextEncoder,
        hparams: Namespace,
    ) -> None:
        super().__init__(tokenizer)
        self.model = AutoModel.from_pretrained(hparams.pretrained_model)
        self.model.encoder.output_hidden_states = True
        self._output_units = self.model.config.hidden_size
        # +1 accounts for the embedding layer in addition to the transformer layers.
        self._n_layers = self.model.config.num_hidden_layers + 1
        self._max_pos = self.model.config.max_position_embeddings

    @classmethod
    def from_pretrained(cls, hparams: Namespace) -> Encoder:
        """Loads a pretrained encoder from Hugging Face.

        :param hparams: Namespace with the model hyperparameters.
        :return: Encoder model.
        """
        tokenizer = HFTextEncoder(model=hparams.pretrained_model)
        model = BERTEncoder(tokenizer=tokenizer, hparams=hparams)
        return model
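
    # Usage sketch (assumption: "bert-base-uncased" is only an illustrative checkpoint
    # name, not a project default):
    #
    #     hparams = Namespace(pretrained_model="bert-base-uncased")
    #     encoder = BERTEncoder.from_pretrained(hparams)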

    def freeze_embeddings(self) -> None:
        """Freezes the embedding layer of the network to save some memory during training."""
        for param in self.model.embeddings.parameters():
            param.requires_grad = False

    def layerwise_lr(self, lr: float, decay: float):
        """Groups the model parameters with a layer-wise decaying learning rate.

        :param lr: Base learning rate.
        :param decay: Per-layer decay factor.

        :return: List of parameter groups with layer-wise decaying learning rates.
        """
        # Embedding layer: lr * decay ** num_layers
        opt_parameters = [
            {
                "params": self.model.embeddings.parameters(),
                "lr": lr * decay ** (self.num_layers),
            }
        ]
        # Encoder layers: layer l is scaled by decay ** l
        opt_parameters += [
            {
                "params": self.model.encoder.layer[l].parameters(),
                "lr": lr * decay ** l,
            }
            for l in range(self.num_layers - 2, 0, -1)
        ]
        return opt_parameters
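
    # Usage sketch (assumption: the lr/decay values below are illustrative, not project
    # defaults). Each returned group carries its own "lr", so the list can be handed
    # directly to an optimizer:
    #
    #     param_groups = encoder.layerwise_lr(lr=3e-5, decay=0.95)
    #     optimizer = torch.optim.AdamW(param_groups)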

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Encodes a batch of sequences.

        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].

        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]),
            `wordemb` (tensor with dims [batch_size x seq_len x output_units]),
            `mask` (input mask), `all_layers` (list with the word embeddings from all layers),
            and `extra` (tuple with the last_hidden_state, the pooler_output representing the
            entire sentence, and the word embeddings for all BERT layers).
        """
        mask = lengths_to_mask(lengths, device=tokens.device)
        # With output_hidden_states=True and return_dict=False the model returns
        # (last_hidden_state, pooler_output, hidden_states).
        last_hidden_states, pooler_output, all_layers = self.model(
            tokens, mask, output_hidden_states=True, return_dict=False
        )
        return {
            "sentemb": pooler_output,
            "wordemb": last_hidden_states,
            "all_layers": all_layers,
            "mask": mask,
            "extra": (last_hidden_states, pooler_output, all_layers),
        }
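

if __name__ == "__main__":
    # Minimal forward-pass sketch. Assumptions: "bert-base-uncased" as the checkpoint and
    # hand-written dummy token ids / lengths instead of the HFTextEncoder tokenizer, just
    # to illustrate the expected tensor shapes.
    hparams = Namespace(pretrained_model="bert-base-uncased")
    encoder = BERTEncoder.from_pretrained(hparams)
    encoder.eval()

    tokens = torch.tensor([[101, 7592, 2088, 2003, 102], [101, 2054, 102, 0, 0]])
    lengths = torch.tensor([5, 3])
    with torch.no_grad():
        out = encoder(tokens, lengths)

    print(out["sentemb"].shape)  # [batch_size, hidden_size], e.g. torch.Size([2, 768])
    print(out["wordemb"].shape)  # [batch_size, seq_len, hidden_size], e.g. torch.Size([2, 5, 768])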