# -*- coding: utf-8 -*-
r"""
Encoder Model base
====================
    Module defining the common interface between all pretrained encoder models.
"""
import warnings
from argparse import Namespace
from typing import Dict, List

import torch
import torch.nn as nn

from polos.tokenizers_ import TextEncoderBase


class Encoder(nn.Module):
    """Base class for an encoder model.

    :param tokenizer: Tokenizer used to convert raw text into the model inputs.
    """

    def __init__(
        self,
        tokenizer: TextEncoderBase,
    ) -> None:
        super().__init__()
        self.tokenizer = tokenizer

    @property
    def output_units(self):
        """ Max number of tokens the encoder handles. """
        return self._output_units

    @property
    def max_positions(self):
        """ Max number of tokens the encoder handles. """
        return self._max_pos

    @property
    def num_layers(self):
        """ Number of model layers available. """
        return self._n_layers

    @property
    def lm_head(self):
        """ Language modeling head. """
        raise NotImplementedError

    @classmethod
    def from_pretrained(cls, hparams: Namespace):
        """Function that loads a pretrained encoder and the respective tokenizer.

        :return: Encoder model
        """
        raise NotImplementedError

    def check_lengths(self, tokens: torch.Tensor, lengths: torch.Tensor):
        """ Checks if lengths are not exceeded and warns user if so."""
        if lengths.max() > self.max_positions:
            warnings.warn(
                "Encoder max length exceeded ({} > {}).".format(
                    lengths.max(), self.max_positions
                ),
                category=RuntimeWarning,
            )
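            # Clamp the lengths and truncate the token ids so that no sequence
            # exceeds the encoder's maximum number of positions.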
            lengths[lengths > self.max_positions] = self.max_positions
            tokens = tokens[:, : self.max_positions]

        return tokens, lengths

    def prepare_sample(self, sample: List[str]) -> Dict[str, torch.Tensor]:
        """ Receives a list of strings and applies model specific tokenization and vectorization.

        :return: Dictionary with the `tokens` and respective `lengths` tensors.
        """
        tokens, lengths = self.tokenizer.batch_encode(sample)
        tokens, lengths = self.check_lengths(tokens, lengths)
        return {"tokens": tokens, "lengths": lengths}

    def freeze(self) -> None:
        """ Frezees the entire encoder network. """
        for param in self.parameters():
            param.requires_grad = False

    def unfreeze(self) -> None:
        """ Unfrezees the entire encoder network. """
        for param in self.parameters():
            param.requires_grad = True

    def freeze_embeddings(self) -> None:
        """ Frezees the embedding layer of the network to save some memory while training. """
        raise NotImplementedError

    def layerwise_lr(self, lr: float, decay: float):
        """
        :param lr: Base learning rate for the encoder parameters.
        :param decay: Per-layer decay factor applied to the learning rate.

        :return: List with grouped model parameters with a layer-wise decaying learning rate.
        """
        raise NotImplementedError

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Encodes a batch of sequences.

        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].

        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]), `wordemb`
            (tensor with dims [batch_size x seq_len x output_units]), `mask` (input mask),
            `all_layers` (list with word embeddings from all layers), and `extra` (model specific outputs).
        """
        raise NotImplementedError
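

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library):
# ``DummyTokenizer`` and ``DummyEncoder`` are hypothetical stand-ins showing
# how the interface above is meant to be used; real subclasses wrap actual
# pretrained models and tokenizers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class DummyTokenizer:
        """ Stand-in exposing only the ``batch_encode`` call used by ``prepare_sample``. """

        def batch_encode(self, sample: List[str]):
            # Encode each string as its character codes, padded with zeros.
            lengths = torch.tensor([len(s) for s in sample])
            tokens = torch.zeros(len(sample), int(lengths.max()), dtype=torch.long)
            for i, s in enumerate(sample):
                tokens[i, : len(s)] = torch.tensor([ord(c) for c in s])
            return tokens, lengths

    class DummyEncoder(Encoder):
        """ Toy subclass that mean-pools random embeddings (real encoders use pretrained weights). """

        def __init__(self, tokenizer) -> None:
            super().__init__(tokenizer)
            self._output_units = 8
            self._max_pos = 16
            self._n_layers = 1
            self.embeddings = nn.Embedding(256, self._output_units)

        def forward(self, tokens, lengths):
            wordemb = self.embeddings(tokens)
            mask = torch.arange(tokens.size(1))[None, :] < lengths[:, None]
            sentemb = (wordemb * mask.unsqueeze(-1)).sum(1) / lengths.unsqueeze(-1)
            return {
                "sentemb": sentemb,
                "wordemb": wordemb,
                "mask": mask,
                "all_layers": [wordemb],
                "extra": None,
            }

    encoder = DummyEncoder(DummyTokenizer())
    batch = encoder.prepare_sample(["a toy sentence", "another one"])
    outputs = encoder(**batch)
    print(outputs["sentemb"].shape)  # torch.Size([2, 8])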