"""Utility functions for transducer models."""

import os

import numpy as np
import torch

from espnet.nets.pytorch_backend.nets_utils import pad_list


def prepare_loss_inputs(ys_pad, hlens, blank_id=0, ignore_id=-1):
    """Prepare tensors for transducer loss computation.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        hlens (torch.Tensor): batch of hidden sequence lengths (B)
            or batch of masks (B, 1, Tmax)
        blank_id (int): index of blank label
        ignore_id (int): index of initial padding

    Returns:
        ys_in_pad (torch.Tensor): batch of padded target sequences
            with a leading blank (B, Lmax + 1)
        ys_out_pad (torch.Tensor): batch of padded target sequences
            with a trailing blank (B, Lmax + 1)
        target (torch.Tensor): batch of padded target sequences (B, Lmax)
        pred_len (torch.Tensor): batch of hidden sequence lengths (B)
        target_len (torch.Tensor): batch of target sequence lengths (B)

    """
    device = ys_pad.device

    ys = [y[y != ignore_id] for y in ys_pad]
    blank = ys[0].new([blank_id])

    ys_in_pad = pad_list([torch.cat([blank, y], dim=0) for y in ys], blank_id)
    ys_out_pad = pad_list([torch.cat([y, blank], dim=0) for y in ys], ignore_id)

    target = pad_list(ys, blank_id).type(torch.int32).to(device)
    target_len = torch.IntTensor([y.size(0) for y in ys]).to(device)

    if torch.is_tensor(hlens):
        if hlens.dim() > 1:
            hs = [h[h != 0] for h in hlens]
            hlens = list(map(int, [h.size(0) for h in hs]))
        else:
            hlens = list(map(int, hlens))

    pred_len = torch.IntTensor(hlens).to(device)

    return ys_in_pad, ys_out_pad, target, pred_len, target_len
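
# Worked example (illustrative, assuming blank_id=0 and ignore_id=-1): for
# ys_pad = [[1, 2], [3, -1]], the returned tensors hold
#     ys_in_pad  -> [[0, 1, 2], [0, 3, 0]]   (blank prepended, blank-padded)
#     ys_out_pad -> [[1, 2, 0], [3, 0, -1]]  (blank appended, ignore_id-padded)
#     target     -> [[1, 2], [3, 0]] with target_len -> [2, 1]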


def valid_aux_task_layer_list(aux_layer_ids, enc_num_layers):
    """Check whether input list of auxiliary layer ids is valid.

    Return the valid list, sorted and with duplicates removed.

    Args:
        aux_layer_ids (list): auxiliary layer ids
        enc_num_layers (int): number of encoder layers

    Returns:
        valid (list): validated list of layers for auxiliary task

    """
    if (
        not isinstance(aux_layer_ids, list)
        or not aux_layer_ids
        or not all(isinstance(layer, int) for layer in aux_layer_ids)
    ):
        raise ValueError("--aux-task-layer-list argument takes a list of layer ids.")

    sorted_list = sorted(set(aux_layer_ids), key=int)
    valid = list(filter(lambda x: 0 <= x < enc_num_layers, sorted_list))

    if sorted_list != valid:
        raise ValueError(
            "Provided list of layer ids for auxiliary task is incorrect. "
            "IDs should be in [0, %d]." % (enc_num_layers - 1)
        )

    return valid
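
# Illustrative call (hypothetical values): with a 6-layer encoder,
#     valid_aux_task_layer_list([3, 0, 3], 6) -> [0, 3]
# while any id outside [0, 5] raises a ValueError.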


def is_prefix(x, pref):
    """Check whether pref is a proper prefix of x.

    Args:
        x (list): token id sequence
        pref (list): token id sequence

    Returns:
        (bool): whether pref is a proper prefix of x

    """
    if len(pref) >= len(x):
        return False

    for i in range(len(pref)):
        if pref[i] != x[i]:
            return False

    return True
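
# Illustrative checks:
#     is_prefix([1, 2, 3], [1, 2])  -> True
#     is_prefix([1, 2, 3], [1, 3])  -> False
#     is_prefix([1, 2], [1, 2])     -> False  (equal sequences are not proper prefixes)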


def substract(x, subset):
    """Remove from x the hypotheses whose token id sequence appears in subset.

    Args:
        x (list): set of hypotheses
        subset (list): subset of hypotheses

    Returns:
        final (list): new set of hypotheses

    """
    final = []

    for x_ in x:
        if any(x_.yseq == sub.yseq for sub in subset):
            continue
        final.append(x_)

    return final
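
# Illustrative sketch, assuming hypothesis objects expose a ``yseq`` token id
# list as above: if x holds hypotheses with yseq [0, 1] and [0, 2] and subset
# holds one with yseq [0, 1], substract(x, subset) keeps only the [0, 2] one.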


def select_lm_state(lm_states, idx, lm_layers, is_wordlm):
    """Get LM state from batch for given id.

    Args:
        lm_states (list or dict): batch of LM states
        idx (int): index to extract state from batch state
        lm_layers (int): number of LM layers
        is_wordlm (bool): whether provided LM is a word-LM

    Returns:
        idx_state (dict): LM state for given id

    """
    if is_wordlm:
        idx_state = lm_states[idx]
    else:
        idx_state = {}

        idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)]
        idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)]

    return idx_state


def create_lm_batch_state(lm_states_list, lm_layers, is_wordlm):
    """Create batch of LM states.

    Args:
        lm_states_list (list): list of individual LM states
        lm_layers (int): number of LM layers
        is_wordlm (bool): whether provided LM is a word-LM

    Returns:
        batch_states (list or dict): batch of LM states

    """
    if is_wordlm:
        batch_states = lm_states_list
    else:
        batch_states = {}

        batch_states["c"] = [
            torch.stack([state["c"][layer] for state in lm_states_list])
            for layer in range(lm_layers)
        ]
        batch_states["h"] = [
            torch.stack([state["h"][layer] for state in lm_states_list])
            for layer in range(lm_layers)
        ]

    return batch_states
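
# Illustrative round trip for an RNNLM (is_wordlm=False), with hypothetical
# shapes: given per-hypothesis states whose "h"/"c" entries are lists of
# lm_layers tensors of shape (units,), create_lm_batch_state stacks them into
# per-layer tensors of shape (n_hyps, units), and select_lm_state(batch, i,
# lm_layers, False) recovers the state of hypothesis i.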


def init_lm_state(lm_model):
    """Initialize LM state.

    Args:
        lm_model (torch.nn.Module): LM module

    Returns:
        lm_state (dict): initial LM state

    """
    lm_layers = len(lm_model.rnn)
    lm_units_typ = lm_model.typ
    lm_units = lm_model.n_units

    p = next(lm_model.parameters())

    h = [
        torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
        for _ in range(lm_layers)
    ]

    lm_state = {"h": h}

    if lm_units_typ == "lstm":
        lm_state["c"] = [
            torch.zeros(lm_units).to(device=p.device, dtype=p.dtype)
            for _ in range(lm_layers)
        ]

    return lm_state
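
# Resulting structure (illustrative): for a 2-layer LSTM LM with 650 units,
#     {"h": [zeros(650), zeros(650)], "c": [zeros(650), zeros(650)]}
# while a non-LSTM LM (e.g. GRU) carries only the "h" entry.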


def recombine_hyps(hyps):
    """Recombine hypotheses with an equivalent output sequence.

    Args:
        hyps (list): list of hypotheses

    Returns:
        final (list): list of recombined hypotheses

    """
    final = []

    for hyp in hyps:
        seq_final = [f.yseq for f in final if f.yseq]

        if hyp.yseq in seq_final:
            seq_pos = seq_final.index(hyp.yseq)

            final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score)
        else:
            final.append(hyp)

    return final
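
# Scores are log-probabilities, so merging uses log-sum-exp:
# np.logaddexp(a, b) == log(exp(a) + exp(b)). Illustrative numbers: two
# hypotheses with the same yseq, each scored log(0.25), merge to log(0.5).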


def pad_sequence(seqlist, pad_token):
    """Left pad list of token id sequences.

    Args:
        seqlist (list): list of token id sequences
        pad_token (int): padding token id

    Returns:
        final (list): list of left-padded token id sequences

    """
    maxlen = max(len(x) for x in seqlist)

    final = [([pad_token] * (maxlen - len(x))) + x for x in seqlist]

    return final
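
# Illustrative call:
#     pad_sequence([[1, 2, 3], [4]], 0) -> [[1, 2, 3], [0, 0, 4]]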


def check_state(state, max_len, pad_token):
    """Check decoder states and left pad or trim them if necessary.

    Args:
        state (list): list of L decoder states (1, in_len, dec_dim)
        max_len (int): maximum length allowed
        pad_token (int): padding token id

    Returns:
        final (list): list of L left-padded decoder states (1, max_len, dec_dim)

    """
    if state is None or max_len < 1 or state[0].size(1) == max_len:
        return state

    curr_len = state[0].size(1)

    if curr_len > max_len:
        trim_val = curr_len - max_len

        for i, s in enumerate(state):
            state[i] = s[:, trim_val:, :]
    else:
        layers = len(state)
        ddim = state[0].size(2)

        final_dims = (1, max_len, ddim)
        final = [state[0].data.new(*final_dims).fill_(pad_token) for _ in range(layers)]

        for i, s in enumerate(state):
            final[i][:, (max_len - s.size(1)) : max_len, :] = s

        return final

    return state
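
# Shape sketch (hypothetical dimensions): with two states of shape (1, 5, 320)
# and max_len=8, each is left-padded with pad_token to (1, 8, 320); with
# max_len=3, each is trimmed to its last 3 positions instead.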


def check_batch_state(state, max_len, pad_token):
    """Check batch of decoder states and left pad or trim them if necessary.

    Args:
        state (list): list of B decoder states (in_len, dec_dim)
        max_len (int): maximum length allowed
        pad_token (int): padding token id

    Returns:
        final (torch.Tensor): batch of left-padded decoder states
            (B, max_len, dec_dim)

    """
    final_dims = (len(state), max_len, state[0].size(1))
    final = state[0].data.new(*final_dims).fill_(pad_token)

    for i, s in enumerate(state):
        curr_len = s.size(0)

        if curr_len < max_len:
            final[i, (max_len - curr_len) : max_len, :] = s
        else:
            final[i, :, :] = s[(curr_len - max_len) :, :]

    return final
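
# Shape sketch (hypothetical dimensions): given B=2 states of shapes (4, 320)
# and (7, 320) with max_len=5, the result is a single (2, 5, 320) tensor where
# the first state is left-padded with pad_token and the second keeps only its
# last 5 positions.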


def custom_torch_load(model_path, model, training=True):
    """Load transducer model modules and parameters with training-only ones removed.

    Args:
        model_path (str): model path
        model (torch.nn.Module): the model with pretrained modules
        training (bool): whether to keep the training-only (auxiliary task)
            parameters

    """
    if "snapshot" in os.path.basename(model_path):
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )["model"]
    else:
        model_state_dict = torch.load(
            model_path, map_location=lambda storage, loc: storage
        )

    if not training:
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if not k.startswith("aux")
        }

    model.load_state_dict(model_state_dict)

    del model_state_dict
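
# Illustrative call (hypothetical path): loading a training snapshot for
# decoding, dropping the auxiliary-task parameters saved during training:
#     custom_torch_load("exp/train/results/snapshot.ep.20", model, training=False)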