# utils.py - uploaded by TheComputerMan (revision 9944ebb, 9.85 kB)
"""
Taken from ESPNet, modified by Florian Lux
"""
import os
from abc import ABC
import torch
def cumsum_durations(durations):
    """
    Convert per-segment durations into cumulative boundaries and midpoints.

    Args:
        durations (Iterable): Duration of each consecutive segment.
    Returns:
        tuple[list, list]: The running totals, starting with 0 (so one entry
            more than there are durations), and the midpoint between every
            pair of neighbouring totals (one entry per duration).
    """
    boundaries = [0]
    for length in durations:
        boundaries.append(length + boundaries[-1])
    # Pair every boundary with its successor; the last one has no successor.
    midpoints = [(start + end) / 2 for start, end in zip(boundaries, boundaries[1:])]
    return boundaries, midpoints
def delete_old_checkpoints(checkpoint_dir, keep=5):
    """
    Delete all but the ``keep`` highest-numbered checkpoints in a directory.

    Only files named ``checkpoint_<step>.pt`` with a purely numeric ``<step>``
    are considered. ``best.pt`` and every other file (including unrelated
    ``.pt`` files) are left untouched instead of crashing the scan.

    Args:
        checkpoint_dir (str): Directory that holds the checkpoint files.
        keep (int): Number of most recent checkpoints to retain. Values
            below 1 remove every numbered checkpoint.
    """
    prefix, suffix = "checkpoint_", ".pt"
    found = []  # (step, actual filename) pairs
    for el in os.listdir(checkpoint_dir):
        if not (el.startswith(prefix) and el.endswith(suffix)):
            continue
        step = el[len(prefix):-len(suffix)]
        if step.isdecimal():
            found.append((int(step), el))

    # Compute the count explicitly: slicing with [:-keep] silently selects
    # nothing when keep is 0.
    excess = len(found) - max(keep, 0)
    if excess <= 0:
        return

    found.sort()
    for _, filename in found[:excess]:
        # Remove the file that was actually listed, not a name rebuilt from
        # the parsed step (which would miss zero-padded names).
        os.remove(os.path.join(checkpoint_dir, filename))
def get_most_recent_checkpoint(checkpoint_dir, verbose=True):
    """
    Return the path of the highest-numbered checkpoint in a directory.

    Only files named ``checkpoint_<step>.pt`` with a purely numeric ``<step>``
    are considered. ``best.pt`` and any other file are ignored rather than
    breaking the scan.

    Args:
        checkpoint_dir (str): Directory that holds the checkpoint files.
        verbose (bool): Whether to print which checkpoint is being reloaded.
            The "nothing found" notice is printed regardless.
    Returns:
        str or None: Path to the most recent checkpoint, or None when the
            directory contains no numbered checkpoint.
    """
    prefix, suffix = "checkpoint_", ".pt"
    latest = None  # (step, actual filename) of the best candidate so far
    for el in os.listdir(checkpoint_dir):
        if not (el.startswith(prefix) and el.endswith(suffix)):
            continue
        step = el[len(prefix):-len(suffix)]
        if not step.isdecimal():
            continue
        candidate = (int(step), el)
        if latest is None or candidate > latest:
            latest = candidate

    if latest is None:
        print("No previous checkpoints found, cannot reload.")
        return None

    # Use the filename as listed so the returned path is guaranteed to exist.
    filename = latest[1]
    if verbose:
        print("Reloading {}".format(filename))
    return os.path.join(checkpoint_dir, filename)
def make_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of padded part.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor and end up
            on its device.
        length_dim (int, optional): Dimension of ``xs`` that runs along the
            sequence. Must not refer to the batch dimension (0).
        device (torch.device or str, optional): Device on which the mask is
            built. Defaults to the current default device.
    Returns:
        Tensor: Mask tensor that is True at padded positions.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    Raises:
        ValueError: If ``length_dim`` refers to the batch dimension.
    Examples:
        >>> make_pad_mask([5, 3, 2])
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = len(lengths)
    maxlen = int(max(lengths)) if xs is None else xs.size(length_dim)

    # torch.arange / torch.tensor accept device=None, so no branching needed.
    seq_range = torch.arange(0, maxlen, dtype=torch.int64, device=device)
    seq_lengths = torch.tensor(lengths, dtype=torch.int64, device=device)
    # (1, maxlen) against (bs, 1) broadcasts to (bs, maxlen).
    mask = seq_range.unsqueeze(0) >= seq_lengths.unsqueeze(-1)

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
            if length_dim == 0:
                # A negative index that resolves to the batch axis would
                # otherwise fail later with an opaque expand error.
                raise ValueError("length_dim cannot be 0: {}".format(length_dim))
        # ind = (:, None, ..., None, :, None, ..., None)
        ind = tuple(slice(None) if i in (0, length_dim) else None for i in range(xs.dim()))
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1, device=None):
    """
    Make mask tensor containing indices of non-padded part.

    This is the element-wise inverse of ``make_pad_mask`` and forwards all
    arguments to it unchanged.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension indicator of the above tensor.
        device (optional): Passed through to ``make_pad_mask``.
    Returns:
        Tensor: Mask tensor marking the non-padded part.
            dtype=torch.uint8 in PyTorch 1.2-
            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
    """
    padded = make_pad_mask(lengths, xs, length_dim, device=device)
    return ~padded
def initialize(model, init):
    """
    Initialize weights of a neural network module.

    Parameters with more than one dimension are filled by the chosen scheme,
    one-dimensional parameters are zeroed, and afterwards Embedding and
    LayerNorm modules are put back to their own default initialization.

    Args:
        model: Target module.
        init: Name of the initialization method, one of "xavier_uniform",
            "xavier_normal", "kaiming_uniform" or "kaiming_normal".
    Raises:
        ValueError: If ``init`` is not one of the names above and the model
            has at least one parameter with more than one dimension.
    """
    schemes = {
        "xavier_uniform": torch.nn.init.xavier_uniform_,
        "xavier_normal": torch.nn.init.xavier_normal_,
        "kaiming_uniform": lambda w: torch.nn.init.kaiming_uniform_(w, nonlinearity="relu"),
        "kaiming_normal": lambda w: torch.nn.init.kaiming_normal_(w, nonlinearity="relu"),
    }

    # weights: everything that is at least a matrix
    for param in model.parameters():
        if param.dim() <= 1:
            continue
        if init not in schemes:
            raise ValueError("Unknown initialization: " + init)
        schemes[init](param.data)

    # biases: every vector-shaped parameter
    for param in model.parameters():
        if param.dim() == 1:
            param.data.zero_()

    # modules that should keep their library default initialization
    default_init_types = (torch.nn.Embedding, torch.nn.LayerNorm)
    for module in model.modules():
        if isinstance(module, default_init_types):
            module.reset_parameters()
def pad_list(xs, pad_value):
    """
    Stack a list of variable-length tensors into one padded batch tensor.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value written to every position past a
            sequence's own length.
    Returns:
        Tensor: Padded tensor (B, Tmax, `*`), with the dtype and device of
            the first tensor in ``xs``.
    """
    batch_size = len(xs)
    longest = max(seq.size(0) for seq in xs)
    template = xs[0]
    trailing_dims = template.size()[1:]
    padded = template.new(batch_size, longest, *trailing_dims).fill_(pad_value)
    for row, seq in enumerate(xs):
        padded[row, : seq.size(0)] = seq
    return padded
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """
    Create a lower-triangular mask of shape (size, size).

    Entry (i, j) is set when j <= i, i.e. every position may see itself and
    all earlier positions but none of the subsequent ones.

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    """
    mask = torch.ones(size, size, device=device, dtype=dtype)
    # Zero the upper triangle in place; no second tensor is allocated.
    mask.tril_()
    return mask
class ScorerInterface:
    """
    Scorer interface for beam search.

    A scorer assigns a score to every token in the vocabulary. Only
    ``score`` has to be implemented by subclasses; the other methods come
    with neutral defaults.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor
        Returns: initial state; the default implementation has none.
        """
        return None

    def select_state(self, state, i, new_id=None):
        """
        Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary
                (unused by the default implementation)
        Returns:
            state: pruned state, or None when there is no state
        """
        if state is None:
            return None
        return state[i]

    def score(self, y, state, x):
        """
        Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.
        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys
        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError

    def final_score(self, state):
        """
        Score eos (optional).

        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score; 0.0 by default
        """
        return 0.0
class BatchScorerInterface(ScorerInterface, ABC):
    """
    Scorer interface extended with batched variants of its methods.

    The defaults fall back to the per-item methods, so a subclass only has
    to override them when it can score a whole batch at once.
    """

    def batch_init_state(self, x):
        """
        Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor
        Returns: initial state, as produced by ``init_state``
        """
        return self.init_state(x)

    def batch_score(self, ys, states, xs):
        """
        Score new token batch (required).

        The default implementation calls ``score`` once per batch entry and
        concatenates the results.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        per_item = [self.score(y, state, x) for y, state, x in zip(ys, states, xs)]
        item_scores = [item_score for item_score, _ in per_item]
        outstates = [item_state for _, item_state in per_item]
        # Join the per-item scores and fold them back to one row per entry.
        batched = torch.cat(item_scores, 0).view(ys.shape[0], -1)
        return batched, outstates
def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Object whose device is the
            target. For a module, the device of its first parameter is used,
            falling back to its first buffer.
        x (Tensor): Torch tensor.
    Returns:
        Tensor: Torch tensor located in the same place as ``m``. If ``m`` is
            a module without any parameters or buffers it has no device of
            its own, so ``x`` is returned unchanged.
    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        # A bare next(m.parameters()) leaks StopIteration for modules that
        # hold no parameters (e.g. activations), so supply a default.
        anchor = next(m.parameters(), None)
        if anchor is None:
            anchor = next(m.buffers(), None)
        if anchor is None:
            return x
        device = anchor.device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        raise TypeError(
            f"Expected torch.nn.Module or torch.Tensor, but got: {type(m)}"
        )
    return x.to(device)