# -*- coding: utf-8 -*-
r"""
BERT Encoder
==============
    Pretrained BERT encoder from Hugging Face.
"""
from argparse import Namespace
from typing import Dict

import torch
from transformers import AutoModel

from polos.models.encoders.encoder_base import Encoder
from polos.tokenizers_ import HFTextEncoder
from torchnlp.utils import lengths_to_mask


class BERTEncoder(Encoder):
    """BERT encoder.
    :param tokenizer: BERT text encoder.
    :param hparams: Namespace with the configuration options (e.g. ``pretrained_model``).
    """

    def __init__(
        self,
        tokenizer: HFTextEncoder,
        hparams: Namespace,
    ) -> None:
        super().__init__(tokenizer)
        self.model = AutoModel.from_pretrained(hparams.pretrained_model)
        self.model.config.output_hidden_states = True
        self._output_units = self.model.config.hidden_size
        self._n_layers = self.model.config.num_hidden_layers + 1
        self._max_pos = self.model.config.max_position_embeddings

    @classmethod
    def from_pretrained(cls, hparams: Namespace) -> Encoder:
        """Function that loads a pretrained encoder from Hugging Face.
        :param hparams: Namespace with a ``pretrained_model`` attribute.

        :return: Encoder model
        """
        tokenizer = HFTextEncoder(model=hparams.pretrained_model)
        model = BERTEncoder(tokenizer=tokenizer, hparams=hparams)
        return model
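
    # Usage sketch (assumption: "bert-base-uncased" stands in for whichever
    # checkpoint the experiment config selects via ``hparams.pretrained_model``):
    #
    #   hparams = Namespace(pretrained_model="bert-base-uncased")
    #   encoder = BERTEncoder.from_pretrained(hparams)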

    def freeze_embeddings(self) -> None:
        """ Frezees the embedding layer of the network to save some memory while training. """
        for param in self.model.embeddings.parameters():
            param.requires_grad = False

    def layerwise_lr(self, lr: float, decay: float):
        """
        :return: List with grouped model parameters with layer-wise decaying learning rate
        """
        # Embedding Layer
        opt_parameters = [
            {
                "params": self.model.embeddings.parameters(),
                "lr": lr * decay ** (self.num_layers),
            }
        ]
        # All layers
        opt_parameters += [
            {
                "params": self.model.encoder.layer[l].parameters(),
                "lr": lr * decay ** l,
            }
            for l in range(self.num_layers - 2, 0, -1)
        ]
        return opt_parameters
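
    # Example (sketch): the grouped parameters plug straight into a torch
    # optimizer; ``encoder`` below is a hypothetical BERTEncoder instance and
    # the lr/decay values are illustrative only:
    #
    #   optimizer = torch.optim.AdamW(encoder.layerwise_lr(lr=1e-5, decay=0.95))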

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes a batch of sequences.

        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].

        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]), `wordemb`
            (tensor with dims [batch_size x seq_len x output_units]), `mask` (input mask),
            `all_layers` (list with the word embeddings from every layer), `extra` (tuple with the
            last_hidden_state, the pooler_output representing the entire sentence, and the word
            embeddings from all BERT layers).
        """
        mask = lengths_to_mask(lengths, device=tokens.device)
        last_hidden_states, pooler_output, all_layers = self.model(
            input_ids=tokens,
            attention_mask=mask,
            output_hidden_states=True,
            return_dict=False,
        )
        return {
            "sentemb": pooler_output,
            "wordemb": last_hidden_states,
            "all_layers": all_layers,
            "mask": mask,
            "extra": (last_hidden_states, pooler_output, all_layers),
        }
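

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions: "bert-base-uncased" is only a
    # stand-in checkpoint, the hard-coded token ids below are illustrative
    # (a real pipeline would obtain them from the paired HFTextEncoder), and
    # Encoder behaves as a standard ``torch.nn.Module``.
    hparams = Namespace(pretrained_model="bert-base-uncased")
    encoder = BERTEncoder.from_pretrained(hparams)
    encoder.eval()

    # Two padded sequences with their true lengths.
    tokens = torch.tensor(
        [[101, 7592, 2088, 102, 0], [101, 2023, 2003, 2204, 102]]
    )
    lengths = torch.tensor([4, 5])

    with torch.no_grad():
        output = encoder(tokens, lengths)

    print(output["sentemb"].shape)  # [batch_size, output_units]
    print(output["wordemb"].shape)  # [batch_size, seq_len, output_units]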