# -*- coding: utf-8 -*-
r"""
BERT Encoder
==============
Pretrained BERT encoder from Hugging Face.
"""
from argparse import Namespace
from typing import Dict
import torch
from transformers import AutoModel
from polos.models.encoders.encoder_base import Encoder
from polos.tokenizers_ import HFTextEncoder
from torchnlp.utils import lengths_to_mask


class BERTEncoder(Encoder):
"""BERT encoder.
:param tokenizer: BERT text encoder.
:param hparams: ArgumentParser.
"""
def __init__(
self,
tokenizer: HFTextEncoder,
hparams: Namespace,
) -> None:
super().__init__(tokenizer)
self.model = AutoModel.from_pretrained(hparams.pretrained_model)
        # Ask the model to return the hidden states of every layer
        # (forward() also requests them explicitly).
        self.model.config.output_hidden_states = True
        self._output_units = self.model.config.hidden_size
        # Embedding layer + all transformer layers.
        self._n_layers = self.model.config.num_hidden_layers + 1
self._max_pos = self.model.config.max_position_embeddings

    @classmethod
    def from_pretrained(cls, hparams: Namespace) -> Encoder:
        """Loads a pretrained encoder and its tokenizer from Hugging Face.

        :param hparams: Namespace with the name of the `pretrained_model` to load.
        :return: Encoder model.
        """
tokenizer = HFTextEncoder(model=hparams.pretrained_model)
model = BERTEncoder(tokenizer=tokenizer, hparams=hparams)
return model

    def freeze_embeddings(self) -> None:
        """Freezes the embedding layer of the network to save some memory while training."""
        for param in self.model.embeddings.parameters():
            param.requires_grad = False

    def layerwise_lr(self, lr: float, decay: float):
        """Groups the model parameters with a layer-wise decaying learning rate.

        :param lr: Base learning rate.
        :param decay: Multiplicative decay factor applied per layer.
        :return: List with grouped model parameters with layer-wise decaying learning rate.
        """
        # Embedding layer: its learning rate is scaled by decay ** num_layers.
opt_parameters = [
{
"params": self.model.embeddings.parameters(),
"lr": lr * decay ** (self.num_layers),
}
]
        # All transformer layers: the learning rate of layer `l` is scaled by decay ** l.
        opt_parameters += [
            {
                "params": self.model.encoder.layer[l].parameters(),
                "lr": lr * decay ** l,
            }
            for l in range(self.num_layers - 2, -1, -1)
        ]
return opt_parameters

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Encodes a batch of sequences.

        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].
        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]),
            `wordemb` (tensor with dims [batch_size x seq_len x output_units]), `mask`
            (input padding mask), `all_layers` (list with the word embeddings from all
            layers) and `extra` (tuple with the last_hidden_state, the pooler_output
            representing the entire sentence, and the word embeddings from all BERT layers).
        """
        mask = lengths_to_mask(lengths, device=tokens.device)
        # With return_dict=False the model returns
        # (last_hidden_state, pooler_output, all hidden states).
        last_hidden_states, pooler_output, all_layers = self.model(
            tokens, mask, output_hidden_states=True, return_dict=False
        )
return {
"sentemb": pooler_output,
"wordemb": last_hidden_states,
"all_layers": all_layers,
"mask": mask,
"extra": (last_hidden_states, pooler_output, all_layers),
}
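

if __name__ == "__main__":
    # Minimal usage sketch: builds the encoder from a Namespace, freezes the
    # embeddings, groups the parameters for layer-wise learning rates, and runs
    # a forward pass on dummy token ids. The model name and the lr/decay values
    # below are illustrative assumptions, not polos defaults.
    from torch.optim import AdamW

    hparams = Namespace(pretrained_model="bert-base-uncased")
    encoder = BERTEncoder.from_pretrained(hparams)
    encoder.freeze_embeddings()

    # The parameter groups returned by layerwise_lr() plug directly into a
    # torch optimizer; groups that define "lr" override the optimizer default.
    optimizer = AdamW(encoder.layerwise_lr(lr=1e-5, decay=0.95))

    # Dummy batch: 2 sequences of random token ids padded to length 10.
    tokens = torch.randint(0, encoder.model.config.vocab_size, (2, 10))
    lengths = torch.tensor([10, 7])
    output = encoder(tokens, lengths)
    print(output["sentemb"].shape)  # [2, hidden_size]
    print(output["wordemb"].shape)  # [2, 10, hidden_size]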