"""Default Recurrent Neural Network Language Model in `lm_train.py`."""

from typing import Any
from typing import List
from typing import Tuple

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.lm_interface import LMInterface
from espnet.nets.pytorch_backend.e2e_asr import to_device
from espnet.nets.scorer_interface import BatchScorerInterface
from espnet.utils.cli_utils import strtobool


class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
    """Default RNNLM for `LMInterface` Implementation.

    The actual network lives in ``self.model``, a :class:`ClassifierWithState`
    wrapping an :class:`RNNLM`.

    Note:
        PyTorch seems to have memory leak when one GPU compute this after data parallel.
        If parallel GPUs compute this, it seems to be fine.
        See also https://github.com/espnet/espnet/issues/1075

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser.

        Args:
            parser: Parser object exposing ``add_argument``.

        Returns:
            The same ``parser`` object, with the RNNLM options registered.

        """
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--embed-unit",
            default=None,
            type=int,
            help="Number of hidden units in embedding layer, "
            "if it is not specified, it keeps the same number with hidden units.",
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations.
                see py:method:`add_arguments`. ``layer``, ``unit`` and ``type``
                are required; the other options fall back to defaults when the
                attribute is missing.

        """
        nn.Module.__init__(self)
        # NOTE: for a compatibility with less than 0.5.0 version models
        dropout_rate = getattr(args, "dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.6.1 version models
        embed_unit = getattr(args, "embed_unit", None)
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)

        self.model = ClassifierWithState(
            RNNLM(
                n_vocab,
                args.layer,
                args.unit,
                embed_unit,
                args.type,
                dropout_rate,
                emb_dropout_rate,
                tie_weights,
            )
        )

    def state_dict(self):
        """Dump state dict.

        Returns:
            The state dict of the wrapped ``self.model`` (keys are relative to
            ``self.model``, not to this wrapper).

        """
        return self.model.state_dict()

    def load_state_dict(self, d, *args, **kwargs):
        """Load state dict.

        Args:
            d: State dict as produced by :meth:`state_dict`.
            *args: Extra positional arguments forwarded unchanged to
                ``self.model.load_state_dict`` (e.g. ``strict``).
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``self.model.load_state_dict``.

        Returns:
            Whatever ``self.model.load_state_dict`` returns.

        """
        return self.model.load_state_dict(d, *args, **kwargs)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar): the accumulated loss divided by
                the batch size,
                the accumulated loss itself (scalar), used as the negative
                log-likelihood of t: -log p(t),
                and the number of non-zero elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

            At every time step the accumulated loss grows by the batch-mean
            of the per-sample loss multiplied by the number of non-zero ids
            in that column of ``x``.

        """
        loss = 0
        count = torch.tensor(0).long()
        state = None
        batch_size, sequence_length = x.shape
        for i in range(sequence_length):
            # Compute the loss at this time step and accumulate it
            state, loss_batch = self.model(state, x[:, i], t[:, i])
            # ids equal to 0 are not counted
            non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
            loss += loss_batch.mean() * non_zeros
            count += int(non_zeros)
        return loss / batch_size, loss, count.to(loss.device)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.
                Not used by this scorer.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        # only the last prefix token is fed; the history is carried by `state`
        new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
        return scores.squeeze(0), new_state

    def final_score(self, state):
        """Score eos.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score

        """
        return self.model.final(state)

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
                Only ``states[0]`` is inspected to decide whether the whole
                batch is still in its initial (``None``) state.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Not used by this scorer.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        # merge states
        n_batch = len(ys)
        n_layers = self.model.predictor.n_layers
        if self.model.predictor.typ == "lstm":
            keys = ("c", "h")
        else:
            keys = ("h",)

        if states[0] is None:
            states = None
        else:
            # transpose state of [batch, key, layer] into [key, layer, batch]
            states = {
                k: [
                    torch.stack([states[b][k][i] for b in range(n_batch)])
                    for i in range(n_layers)
                ]
                for k in keys
            }
        states, logp = self.model.predict(states, ys[:, -1])

        # transpose state of [key, layer, batch] into [batch, key, layer]
        return (
            logp,
            [
                {k: [states[k][i][b] for i in range(n_layers)] for k in keys}
                for b in range(n_batch)
            ],
        )


class ClassifierWithState(nn.Module):
    """A wrapper for pytorch RNNLM."""

    # NOTE(review): the default `lossfun` module is instantiated once, when this
    # `def` is evaluated, and is therefore shared by every instance that relies
    # on the default.
    def __init__(
        self, predictor, lossfun=nn.CrossEntropyLoss(reduction="none"), label_key=-1
    ):
        """Initialize class.

        :param torch.nn.Module predictor : The RNNLM
        :param function lossfun : The loss function to use
        :param int/str label_key : position in ``args`` (int) or name in
            ``kwargs`` (str) of the ground-truth labels passed to ``forward``
        :raises TypeError: if ``label_key`` is neither int nor str

        """
        if not (isinstance(label_key, (int, str))):
            raise TypeError("label_key must be int or str, but is %s" % type(label_key))
        super(ClassifierWithState, self).__init__()
        self.lossfun = lossfun
        self.y = None
        self.loss = None
        self.label_key = label_key
        self.predictor = predictor

    def forward(self, state, *args, **kwargs):
        """Compute the loss value for an input and label pair.

        Notes:
            The predictor output and the loss are also stored
            in the attributes ``y`` and ``loss``.
            When ``label_key`` is ``int``, the corresponding element in ``args``
            is treated as ground truth labels. And when it is ``str``, the
            element in ``kwargs`` is used.
            The all elements of ``args`` and ``kwargs`` except the groundtruth
            labels are features.
            It feeds features to the predictor and compare the result
            with ground truth labels.

        :param torch.Tensor state : the LM state
        :param list[torch.Tensor] args : Input minibatch
        :param dict[torch.Tensor] kwargs : Input minibatch
        :return a tuple (new state, loss value)
        :rtype (torch.Tensor, torch.Tensor)
        :raises ValueError: if the label cannot be found under ``label_key``

        """
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = "Label key %d is out of bounds" % self.label_key
                raise ValueError(msg)
            t = args[self.label_key]
            # drop the label so that only the features reach the predictor
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[: self.label_key] + args[self.label_key + 1 :]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]

        # reset the cached results before running the predictor
        self.y = None
        self.loss = None
        state, self.y = self.predictor(state, *args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        return state, self.loss

    def predict(self, state, x):
        """Predict log probabilities for given state and input x using the predictor.

        If the predictor has a truthy ``normalized`` attribute its output is
        returned as is; otherwise ``log_softmax`` over ``dim=1`` is applied.

        :param torch.Tensor state : The current state
        :param torch.Tensor x : The input
        :return a tuple (new state, log prob vector)
        :rtype (torch.Tensor, torch.Tensor)

        """
        if hasattr(self.predictor, "normalized") and self.predictor.normalized:
            return self.predictor(state, x)
        else:
            state, z = self.predictor(state, x)
            return state, F.log_softmax(z, dim=1)

    def buff_predict(self, state, x, n):
        """Predict new tokens from buffered inputs.

        For a predictor whose class is named ``RNNLM`` the whole buffer is
        predicted in a single call; otherwise the first ``n`` entries of ``x``
        are predicted one by one, each with its own entry of ``state``.

        :param state: The current state(s), or None
        :param torch.Tensor x : The buffered inputs
        :param int n : The number of buffered inputs to process
        :return a tuple (new state(s), log prob vectors)

        """
        if self.predictor.__class__.__name__ == "RNNLM":
            return self.predict(state, x)

        new_state = []
        new_log_y = []
        for i in range(n):
            state_i = None if state is None else state[i]
            state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
            new_state.append(state_i)
            new_log_y.append(log_y)

        return new_state, torch.cat(new_log_y)

    def final(self, state, index=None):
        """Predict final log probabilities for given state using the predictor.

        :param state: The state
        :param index: If not None, ``state[index]`` is used instead of ``state``
        :return The final log probabilities, or 0.0 when the predictor
            has no ``final`` attribute
        :rtype torch.Tensor

        """
        if hasattr(self.predictor, "final"):
            if index is not None:
                return self.predictor.final(state[index])
            else:
                return self.predictor.final(state)
        else:
            return 0.0


# Definition of a recurrent net for language modeling
class RNNLM(nn.Module):
    """A pytorch RNNLM."""

    def __init__(
        self,
        n_vocab,
        n_layers,
        n_units,
        n_embed=None,
        typ="lstm",
        dropout_rate=0.5,
        emb_dropout_rate=0.0,
        tie_weights=False,
    ):
        """Initialize class.

        :param int n_vocab: The size of the vocabulary
        :param int n_layers: The number of layers to create
        :param int n_units: The number of units per layer
        :param int n_embed: The embedding size (defaults to ``n_units``)
        :param str typ: The RNN type ("lstm"; any other value builds GRU cells)
        :param float dropout_rate: Dropout applied before every layer
            and before the output projection
        :param float emb_dropout_rate: Dropout applied to the embedding
            (disabled when 0.0)
        :param bool tie_weights: Share the embedding and output weights

        """
        super(RNNLM, self).__init__()
        if n_embed is None:
            n_embed = n_units

        self.embed = nn.Embedding(n_vocab, n_embed)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        # the first cell consumes the embedding, the others the previous layer
        if typ == "lstm":
            self.rnn = nn.ModuleList(
                [nn.LSTMCell(n_embed, n_units)]
                + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
            )
        else:
            self.rnn = nn.ModuleList(
                [nn.GRUCell(n_embed, n_units)]
                + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
            )

        # one dropout per layer input plus one in front of the output layer
        self.dropout = nn.ModuleList(
            [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
        )
        self.lo = nn.Linear(n_units, n_vocab)
        self.n_layers = n_layers
        self.n_units = n_units
        self.typ = typ

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))

        if tie_weights:
            assert (
                n_embed == n_units
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.lo.weight = self.embed.weight

        # initialize parameters from uniform distribution
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)

    def zero_state(self, batchsize):
        """Initialize state.

        :param int batchsize: The number of rows of the state
        :return a zero tensor of shape (batchsize, n_units) with the device
            and dtype of the first model parameter
        :rtype torch.Tensor

        """
        p = next(self.parameters())
        return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)

    def forward(self, state, x):
        """Forward neural networks.

        :param dict state: The previous state: ``{"h": [...]}`` with one
            entry per layer, plus ``"c"`` for LSTM; None starts from zeros
        :param torch.Tensor x: The input token ids
        :return a tuple (new state, output of the final linear layer)
        :rtype (dict, torch.Tensor)

        """
        if state is None:
            # NOTE(review): `to_device` is imported from espnet; presumably it
            # places the zero state on the device of `x` - confirm.
            h = [to_device(x, self.zero_state(x.size(0))) for _ in range(self.n_layers)]
            state = {"h": h}
            if self.typ == "lstm":
                c = [
                    to_device(x, self.zero_state(x.size(0)))
                    for _ in range(self.n_layers)
                ]
                state = {"c": c, "h": h}

        # from here on `h` (and `c`) collect the NEW per-layer states
        h = [None] * self.n_layers
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)
        if self.typ == "lstm":
            c = [None] * self.n_layers
            h[0], c[0] = self.rnn[0](
                self.dropout[0](emb), (state["h"][0], state["c"][0])
            )
            for n in range(1, self.n_layers):
                h[n], c[n] = self.rnn[n](
                    self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
                )
            state = {"c": c, "h": h}
        else:
            h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
            for n in range(1, self.n_layers):
                h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
            state = {"h": h}
        y = self.lo(self.dropout[-1](h[-1]))
        return state, y