# Polos-Demo/polos/models/encoders/encoder_base.py
# -*- coding: utf-8 -*-
r"""
Encoder Model base
====================
Module defining the common interface between all pretrained encoder models.
"""
import warnings
from argparse import Namespace
from typing import Dict, List
import torch
import torch.nn as nn
from polos.tokenizers_ import TextEncoderBase
class Encoder(nn.Module):
"""Base class for an encoder model.
:param output_units: Number of output features that will be passed to the Estimator.
"""
def __init__(
self,
tokenizer: TextEncoderBase,
) -> None:
super().__init__()
self.tokenizer = tokenizer
@property
def output_units(self):
""" Max number of tokens the encoder handles. """
return self._output_units
@property
def max_positions(self):
""" Max number of tokens the encoder handles. """
return self._max_pos
@property
def num_layers(self):
""" Number of model layers available. """
return self._n_layers
@property
def lm_head(self):
""" Language modeling head. """
raise NotImplementedError
@classmethod
def from_pretrained(cls, hparams: Namespace):
"""Function that loads a pretrained encoder and the respective tokenizer.
:return: Encoder model
"""
raise NotImplementedError
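    # A concrete subclass (a sketch, not prescribed by this interface) typically builds the
    # tokenizer and pretrained weights from `hparams` and sets the private attributes that
    # back the properties above: `_output_units`, `_max_pos` and `_n_layers`.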
def check_lengths(self, tokens: torch.Tensor, lengths: torch.Tensor):
""" Checks if lengths are not exceeded and warns user if so."""
if lengths.max() > self.max_positions:
warnings.warn(
"Encoder max length exceeded ({} > {}).".format(
lengths.max(), self.max_positions
),
category=RuntimeWarning,
)
lengths[lengths > self.max_positions] = self.max_positions
tokens = tokens[:, : self.max_positions]
return tokens, lengths
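    # Illustrative behaviour (assumed value): with `max_positions == 512`, a batch holding a
    # 600-token sequence raises a RuntimeWarning, its entry in `lengths` is clamped to 512,
    # and every row of `tokens` is truncated to the first 512 columns.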
    def prepare_sample(self, sample: List[str]) -> Dict[str, torch.Tensor]:
        """ Receives a list of strings and applies model specific tokenization and vectorization.
        :return: Dictionary with `tokens` and `lengths` tensors.
        """
tokens, lengths = self.tokenizer.batch_encode(sample)
tokens, lengths = self.check_lengths(tokens, lengths)
return {"tokens": tokens, "lengths": lengths}
def freeze(self) -> None:
""" Frezees the entire encoder network. """
for param in self.parameters():
param.requires_grad = False
def unfreeze(self) -> None:
""" Unfrezees the entire encoder network. """
for param in self.parameters():
param.requires_grad = True
def freeze_embeddings(self) -> None:
""" Frezees the embedding layer of the network to save some memory while training. """
raise NotImplementedError
def layerwise_lr(self, lr: float, decay: float):
"""
:return: List with grouped model parameters with layer-wise decaying learning rate
"""
raise NotImplementedError
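    # A common implementation (a sketch, not mandated here) returns optimizer parameter groups,
    # one per layer, where each successive layer's learning rate is scaled by `decay`, e.g.
    #   [{"params": layer.parameters(), "lr": lr * decay ** depth} for depth, layer in ...]
    # so that, with `decay < 1`, layers further from the output train with smaller rates.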
def forward(
self, tokens: torch.Tensor, lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Encodes a batch of sequences.
        :param tokens: Torch tensor with the input sequences [batch_size x seq_len].
        :param lengths: Torch tensor with the length of each sequence [batch_size].
        :return: Dictionary with `sentemb` (tensor with dims [batch_size x output_units]), `wordemb`
            (tensor with dims [batch_size x seq_len x output_units]), `mask` (input mask),
            `all_layers` (list with word embeddings from all layers), `extra` (model specific outputs).
"""
raise NotImplementedError
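

# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption, not part of the Polos library): a toy
# concrete subclass showing how the interface above is usually fulfilled.
# It uses a plain embedding table with mean pooling instead of a real
# pretrained model, so the return keys and shapes match the contract
# documented in Encoder.forward. `from_pretrained` is omitted because the toy
# model has no pretrained weights to load.
# ---------------------------------------------------------------------------
class _ToyEncoder(Encoder):
    """Minimal example subclass; real encoders wrap a pretrained model."""

    def __init__(self, tokenizer: TextEncoderBase, vocab_size: int = 30000) -> None:
        super().__init__(tokenizer)
        self._output_units = 128
        self._max_pos = 512
        self._n_layers = 1
        self.embeddings = nn.Embedding(vocab_size, self._output_units, padding_idx=0)

    def freeze_embeddings(self) -> None:
        # Only the embedding table is frozen; any other parameters keep training.
        for param in self.embeddings.parameters():
            param.requires_grad = False

    def layerwise_lr(self, lr: float, decay: float):
        # With a single "layer" the layer-wise schedule degenerates to one parameter group.
        return [{"params": self.embeddings.parameters(), "lr": lr}]

    def forward(
        self, tokens: torch.Tensor, lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # Build a [batch_size x seq_len] padding mask from the per-sequence lengths.
        mask = torch.arange(tokens.size(1), device=tokens.device)[None, :] < lengths[:, None]
        wordemb = self.embeddings(tokens)  # [batch_size x seq_len x output_units]
        # Mean-pool the non-padded positions to obtain a sentence embedding.
        sentemb = (wordemb * mask.unsqueeze(-1)).sum(dim=1) / lengths.unsqueeze(-1)
        return {
            "sentemb": sentemb,       # [batch_size x output_units]
            "wordemb": wordemb,       # [batch_size x seq_len x output_units]
            "mask": mask,             # [batch_size x seq_len]
            "all_layers": [wordemb],  # single "layer" in this toy model
            "extra": None,
        }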