"""Utility functions for transducer models."""
import os
import numpy as np
import torch
from espnet.nets.pytorch_backend.nets_utils import pad_list
def prepare_loss_inputs(ys_pad, hlens, blank_id=0, ignore_id=-1):
"""Prepare tensors for transducer loss computation.
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
hlens (torch.Tensor): batch of hidden sequence lengthts (B)
or batch of masks (B, 1, Tmax)
blank_id (int): index of blank label
ignore_id (int): index of initial padding
Returns:
ys_in_pad (torch.Tensor): batch of padded target sequences + blank (B, Lmax + 1)
target (torch.Tensor): batch of padded target sequences (B, Lmax)
pred_len (torch.Tensor): batch of hidden sequence lengths (B)
target_len (torch.Tensor): batch of output sequence lengths (B)
"""
device = ys_pad.device
ys = [y[y != ignore_id] for y in ys_pad]
blank = ys[0].new([blank_id])
ys_in_pad = pad_list([torch.cat([blank, y], dim=0) for y in ys], blank_id)
ys_out_pad = pad_list([torch.cat([y, blank], dim=0) for y in ys], ignore_id)
target = pad_list(ys, blank_id).type(torch.int32).to(device)
target_len = torch.IntTensor([y.size(0) for y in ys]).to(device)
if torch.is_tensor(hlens):
if hlens.dim() > 1:
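            # hlens arrived as a mask (B, 1, Tmax): count the nonzero frames
            # of each utterance to recover the encoder output lengths.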
hs = [h[h != 0] for h in hlens]
hlens = list(map(int, [h.size(0) for h in hs]))
else:
hlens = list(map(int, hlens))
pred_len = torch.IntTensor(hlens).to(device)
return ys_in_pad, ys_out_pad, target, pred_len, target_len
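
# A minimal usage sketch with made-up values (not part of the original docs):
#
#     ys_pad = torch.tensor([[1, 2, -1], [3, 4, 5]])  # two targets, padded with -1
#     hlens = torch.tensor([10, 8])  # encoder output lengths
#     ys_in, ys_out, target, pred_len, target_len = prepare_loss_inputs(ys_pad, hlens)
#     # ys_in:  [[0, 1, 2, 0], [0, 3, 4, 5]]   (blank prepended, padded with blank_id)
#     # ys_out: [[1, 2, 0, -1], [3, 4, 5, 0]]  (blank appended, padded with ignore_id)
#     # target_len: [2, 3], pred_len: [10, 8]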
def valid_aux_task_layer_list(aux_layer_ids, enc_num_layers):
"""Check whether input list of auxiliary layer ids is valid.
Return the valid list sorted with duplicated removed.
Args:
aux_layer_ids (list): Auxiliary layers ids
enc_num_layers (int): Number of encoder layers
Returns:
valid (list): Validated list of layers for auxiliary task
"""
if (
not isinstance(aux_layer_ids, list)
or not aux_layer_ids
or not all(isinstance(layer, int) for layer in aux_layer_ids)
):
raise ValueError("--aux-task-layer-list argument takes a list of layer ids.")
    # Sort and deduplicate, as promised by the docstring.
    sorted_list = sorted(set(aux_layer_ids))
    valid = [layer for layer in sorted_list if 0 <= layer < enc_num_layers]
if sorted_list != valid:
        raise ValueError(
            "Provided list of layer ids for auxiliary task is incorrect. "
            "IDs should be in [0, %d]." % (enc_num_layers - 1)
        )
return valid
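
# Sketch of expected behavior (illustrative values only):
#
#     valid_aux_task_layer_list([3, 1, 3], enc_num_layers=6)  # -> [1, 3]
#     valid_aux_task_layer_list([1, 8], enc_num_layers=6)     # ValueError: 8 out of range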
def is_prefix(x, pref):
"""Check prefix.
Args:
x (list): token id sequence
pref (list): token id sequence
Returns:
(boolean): whether pref is a prefix of x.
"""
if len(pref) >= len(x):
return False
for i in range(len(pref)):
if pref[i] != x[i]:
return False
return True
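
# Illustration (hypothetical token ids):
#
#     is_prefix([1, 2, 3], [1, 2])     # True
#     is_prefix([1, 2, 3], [1, 2, 3])  # False: a sequence is not its own proper prefix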
def substract(x, subset):
"""Remove elements of subset if corresponding token id sequence exist in x.
Args:
x (list): set of hypotheses
subset (list): subset of hypotheses
Returns:
final (list): new set
"""
final = []
for x_ in x:
if any(x_.yseq == sub.yseq for sub in subset):
continue
final.append(x_)
return final
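
# Sketch, assuming hypothesis objects that expose their token id list as
# `.yseq` (as the transducer beam-search hypotheses do):
#
#     remaining = substract(hyps, kept_hyps)  # drops from hyps every hypothesis
#                                             # whose yseq appears in kept_hyps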
def select_lm_state(lm_states, idx, lm_layers, is_wordlm):
"""Get LM state from batch for given id.
Args:
lm_states (list or dict): batch of LM states
idx (int): index to extract state from batch state
lm_layers (int): number of LM layers
is_wordlm (bool): whether provided LM is a word-LM
Returns:
idx_state (dict): LM state for given id
"""
if is_wordlm:
idx_state = lm_states[idx]
else:
idx_state = {}
idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]
return idx_state
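
# Sketch: for an RNN-LM, `lm_states` is a dict of per-layer batched tensors,
# {"c": [(B, units)] * lm_layers, "h": [(B, units)] * lm_layers}; this slices
# out the state of a single hypothesis:
#
#     lm_state = select_lm_state(lm_states, idx=0, lm_layers=2, is_wordlm=False)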
def create_lm_batch_state(lm_states_list, lm_layers, is_wordlm):
"""Create batch of LM states.
Args:
        lm_states_list (list): list of individual LM states
        lm_layers (int): number of LM layers
        is_wordlm (bool): whether provided LM is a word-LM
    Returns:
        batch_states (list or dict): batch of LM states
"""
if is_wordlm:
batch_states = lm_states_list
else:
batch_states = {}
batch_states["c"] = [
torch.stack([state["c"][layer] for state in lm_states_list])
for layer in range(lm_layers)
]
batch_states["h"] = [
torch.stack([state["h"][layer] for state in lm_states_list])
for layer in range(lm_layers)
]
return batch_states
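
# Sketch of the round trip with select_lm_state (illustrative only; B is the
# batch size): selecting every index and re-batching rebuilds the batch state.
#
#     states = [select_lm_state(lm_states, i, lm_layers, is_wordlm) for i in range(B)]
#     batch = create_lm_batch_state(states, lm_layers, is_wordlm)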
def init_lm_state(lm_model):
"""Initialize LM state.
Args:
lm_model (torch.nn.Module): LM module
Returns:
lm_state (dict): initial LM state
"""
lm_layers = len(lm_model.rnn)
lm_units_typ = lm_model.typ
lm_units = lm_model.n_units
p = next(lm_model.parameters())
h = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
lm_state = {"h": h}
if lm_units_typ == "lstm":
lm_state["c"] = [
torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
for _ in range(lm_layers)
]
return lm_state
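
# Sketch, assuming `lm_model` is an ESPnet RNN-LM predictor exposing `.rnn`,
# `.typ` and `.n_units`: the initial state holds one zero vector per layer on
# the LM's device/dtype, plus cell states when the LM uses LSTM cells.
#
#     lm_state = init_lm_state(lm_model)  # {"h": [...]} or {"h": [...], "c": [...]}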
def recombine_hyps(hyps):
"""Recombine hypotheses with equivalent output sequence.
Args:
hyps (list): list of hypotheses
Returns:
final (list): list of recombined hypotheses
"""
final = []
for hyp in hyps:
        seq_final = [f.yseq for f in final]
if hyp.yseq in seq_final:
seq_pos = seq_final.index(hyp.yseq)
final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
else:
final.append(hyp)
    return final
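
# Sketch (hypothetical scores): hypotheses sharing the same yseq are merged,
# the surviving score being the log-sum-exp of the duplicates.
#
#     # hyps = [Hyp(yseq=[1, 2], score=-1.0), Hyp(yseq=[1, 2], score=-2.0)]
#     # recombine_hyps(hyps) -> one hypothesis, score = logaddexp(-1.0, -2.0)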
def pad_sequence(seqlist, pad_token):
"""Left pad list of token id sequences.
Args:
seqlist (list): list of token id sequences
pad_token (int): padding token id
Returns:
final (list): list of padded token id sequences
"""
maxlen = max(len(x) for x in seqlist)
final = [([pad_token] * (maxlen - len(x))) + x for x in seqlist]
return final
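
# Illustration (made-up ids):
#
#     pad_sequence([[1, 2, 3], [4]], pad_token=0)  # -> [[1, 2, 3], [0, 0, 4]]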
def check_state(state, max_len, pad_token):
"""Check state and left pad or trim if necessary.
Args:
        state (list): list of L decoder states, each of shape (1, in_len, dec_dim)
max_len (int): maximum length authorized
pad_token (int): padding token id
Returns:
final (list): list of L padded decoder states (1, max_len, dec_dim)
"""
if state is None or max_len < 1 or state[0].size(1) == max_len:
return state
curr_len = state[0].size(1)
if curr_len > max_len:
trim_val = int(state[0].size(1) - max_len)
for i, s in enumerate(state):
state[i] = s[:, trim_val:, :]
else:
layers = len(state)
ddim = state[0].size(2)
final_dims = (1, max_len, ddim)
final = [state[0].data.new(*final_dims).fill_(pad_token) for _ in range(layers)]
for i, s in enumerate(state):
final[i][:, (max_len - s.size(1)) : max_len, :] = s
return final
return state
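
# Sketch with assumed shapes (2 layers, dec_dim=4): trimming keeps the last
# max_len steps, while padding is applied on the left.
#
#     state = [torch.zeros(1, 5, 4) for _ in range(2)]
#     check_state(state, max_len=3, pad_token=0)[0].size()  # -> (1, 3, 4)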
def check_batch_state(state, max_len, pad_token):
"""Check batch of states and left pad or trim if necessary.
Args:
        state (list): list of B decoder states, each of shape (?, dec_dim)
        max_len (int): maximum length authorized
        pad_token (int): padding token id
    Returns:
        final (torch.Tensor): batch of decoder states (B, max_len, dec_dim)
"""
final_dims = (len(state), max_len, state[0].size(1))
final = state[0].data.new(*final_dims).fill_(pad_token)
for i, s in enumerate(state):
curr_len = s.size(0)
if curr_len < max_len:
final[i, (max_len - curr_len) : max_len, :] = s
else:
final[i, :, :] = s[(curr_len - max_len) :, :]
return final
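
# Sketch with assumed shapes (dec_dim=4): per-hypothesis states of varying
# length are merged into a single left-padded batch tensor.
#
#     states = [torch.zeros(2, 4), torch.zeros(5, 4)]
#     check_batch_state(states, max_len=3, pad_token=0).size()  # -> (2, 3, 4)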
def custom_torch_load(model_path, model, training=True):
"""Load transducer model modules and parameters with training-only ones removed.
Args:
model_path (str): Model path
model (torch.nn.Module): The model with pretrained modules
"""
if "snapshot" in os.path.basename(model_path):
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)["model"]
else:
model_state_dict = torch.load(
model_path, map_location=lambda storage, loc: storage
)
if not training:
model_state_dict = {
k: v for k, v in model_state_dict.items() if not k.startswith("aux")
}
model.load_state_dict(model_state_dict)
del model_state_dict
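
# Usage sketch (the path is a placeholder): at decoding time, drop the
# training-only auxiliary task parameters before loading.
#
#     custom_torch_load("exp/train/results/model.loss.best", model, training=False)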