|
|
|
|
|
|
|
"""Training/decoding definition for the speech recognition task.""" |
|
|
|
import copy |
|
import json |
|
import logging |
|
import math |
|
import os |
|
import sys |
|
|
|
from chainer import reporter as reporter_module |
|
from chainer import training |
|
from chainer.training import extensions |
|
from chainer.training.updater import StandardUpdater |
|
import numpy as np |
|
from tensorboardX import SummaryWriter |
|
import torch |
|
from torch.nn.parallel import data_parallel |
|
|
|
from espnet.asr.asr_utils import adadelta_eps_decay |
|
from espnet.asr.asr_utils import add_results_to_json |
|
from espnet.asr.asr_utils import CompareValueTrigger |
|
from espnet.asr.asr_utils import format_mulenc_args |
|
from espnet.asr.asr_utils import get_model_conf |
|
from espnet.asr.asr_utils import plot_spectrogram |
|
from espnet.asr.asr_utils import restore_snapshot |
|
from espnet.asr.asr_utils import snapshot_object |
|
from espnet.asr.asr_utils import torch_load |
|
from espnet.asr.asr_utils import torch_resume |
|
from espnet.asr.asr_utils import torch_snapshot |
|
from espnet.asr.pytorch_backend.asr_init import freeze_modules |
|
from espnet.asr.pytorch_backend.asr_init import load_trained_model |
|
from espnet.asr.pytorch_backend.asr_init import load_trained_modules |
|
import espnet.lm.pytorch_backend.extlm as extlm_pytorch |
|
from espnet.nets.asr_interface import ASRInterface |
|
from espnet.nets.beam_search_transducer import BeamSearchTransducer |
|
from espnet.nets.pytorch_backend.e2e_asr import pad_list |
|
import espnet.nets.pytorch_backend.lm.default as lm_pytorch |
|
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E |
|
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E |
|
from espnet.transform.spectrogram import IStft |
|
from espnet.transform.transformation import Transformation |
|
from espnet.utils.cli_writers import file_writer_helper |
|
from espnet.utils.dataset import ChainerDataLoader |
|
from espnet.utils.dataset import TransformDataset |
|
from espnet.utils.deterministic_utils import set_deterministic_pytorch |
|
from espnet.utils.dynamic_import import dynamic_import |
|
from espnet.utils.io_utils import LoadInputsAndTargets |
|
from espnet.utils.training.batchfy import make_batchset |
|
from espnet.utils.training.evaluator import BaseEvaluator |
|
from espnet.utils.training.iterators import ShufflingEnabler |
|
from espnet.utils.training.tensorboard_logger import TensorboardLogger |
|
from espnet.utils.training.train_utils import check_early_stop |
|
from espnet.utils.training.train_utils import set_early_stop |
|
|
|
import matplotlib |
|
|
|
matplotlib.use("Agg") |
|
|
|
# Python 2 cannot even parse this module (it uses f-strings), so the old
# ``izip_longest`` fallback was dead code; import the Python 3 name directly.
from itertools import zip_longest
|
|
|
|
|
def _recursive_to(xs, device):
    """Move every tensor in ``xs`` to ``device``, recursing into containers.

    Tuples, lists and dicts are traversed recursively and rebuilt with the same
    container kind (dict keys are kept unchanged). Any other non-tensor object
    is returned as-is.

    Args:
        xs: A tensor, a (possibly nested) tuple/list/dict of tensors, or any
            other object.
        device (torch.device): Destination device.

    Returns:
        ``xs`` with all contained tensors moved to ``device``.

    """
    if torch.is_tensor(xs):
        return xs.to(device)
    if isinstance(xs, tuple):
        return tuple(_recursive_to(x, device) for x in xs)
    # Converters in this module can emit lists (multi-encoder inputs) and
    # dicts (complex "real"/"imag" inputs); move their tensors as well.
    if isinstance(xs, list):
        return [_recursive_to(x, device) for x in xs]
    if isinstance(xs, dict):
        return {k: _recursive_to(v, device) for k, v in xs.items()}
    return xs
|
|
|
|
|
class CustomEvaluator(BaseEvaluator):
    """Custom Evaluator for Pytorch.

    Runs the model over an evaluation iterator without gradients and returns
    the mean of all values the model reported during the pass.

    Args:
        model (torch.nn.Module): The model to evaluate.
        iterator (chainer.dataset.Iterator | dict[str, Iterator]): The iterator
            over evaluation data (``train()`` passes the validation iterator
            under the key ``'main'``).
        target (link | dict[str, link]) :Link object or a dictionary of
            links to evaluate. If this is just a link object, the link is
            registered by the name ``'main'``.
        device (torch.device): The device used.
        ngpu (int): The number of GPUs. When ``None`` it is inferred from
            ``device``: 0 for a CPU device, 1 otherwise.

    """

    def __init__(self, model, iterator, target, device, ngpu=None):
        super(CustomEvaluator, self).__init__(iterator, target)
        self.model = model
        self.device = device
        if ngpu is not None:
            self.ngpu = ngpu
        elif device.type == "cpu":
            self.ngpu = 0
        else:
            self.ngpu = 1

    def evaluate(self):
        """Main evaluate routine for CustomEvaluator.

        Returns:
            dict: Mean of every observation reported during evaluation.

        """
        iterator = self._iterators["main"]

        if self.eval_hook:
            self.eval_hook(self)

        # Restart a resettable iterator in place; otherwise evaluate a copy so
        # the original iterator's position is left untouched.
        if hasattr(iterator, "reset"):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        summary = reporter_module.DictSummary()

        self.model.eval()
        try:
            with torch.no_grad():
                for batch in it:
                    x = _recursive_to(batch, self.device)
                    observation = {}
                    # Values reported inside this scope are collected into
                    # ``observation`` and then aggregated by ``summary``.
                    with reporter_module.report_scope(observation):
                        if self.ngpu == 0:
                            self.model(*x)
                        else:
                            data_parallel(self.model, x, range(self.ngpu))
                    summary.add(observation)
        finally:
            # Always switch back to training mode, even if evaluation fails,
            # so the model is never left in eval mode for later updates.
            self.model.train()

        return summary.compute_mean()
|
|
|
|
|
class CustomUpdater(StandardUpdater):
    """Custom Updater for Pytorch.

    Each call to ``update`` runs one forward/backward pass. The optimizer is
    stepped once every ``accum_grad`` calls, or earlier when the training
    iterator has just moved into a new epoch.

    Args:
        model (torch.nn.Module): The model to update.
        grad_clip_threshold (float): The gradient clipping value to use.
        train_iter (chainer.dataset.Iterator): The training iterator.
        optimizer (torch.optim.optimizer): The training optimizer.
        device (torch.device): The device to use.
        ngpu (int): The number of gpus to use.
        grad_noise (bool): If True, add gradient noise to the model after
            every backward pass.
        accum_grad (int): Number of forward/backward passes whose gradients
            are accumulated before one optimizer step.
        use_apex (bool): The flag to use Apex in backprop.

    """

    def __init__(
        self,
        model,
        grad_clip_threshold,
        train_iter,
        optimizer,
        device,
        ngpu,
        grad_noise=False,
        accum_grad=1,
        use_apex=False,
    ):
        """Construct a CustomUpdater object."""
        super(CustomUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.grad_clip_threshold = grad_clip_threshold
        self.device = device
        self.ngpu = ngpu
        self.accum_grad = accum_grad
        # Backward passes accumulated since the last optimizer step (or
        # NaN-skipped step); reset to 0 in update_core when a step is attempted.
        self.forward_count = 0
        self.grad_noise = grad_noise
        # Incremented by ``update`` only after an accumulation cycle finishes;
        # also used as the step index for gradient noise.
        self.iteration = 0
        self.use_apex = use_apex

    def update_core(self):
        """Main update routine of the CustomUpdater.

        Runs forward/backward on one minibatch and, when the accumulation
        cycle is complete (or a new epoch just started), clips gradients and
        steps the optimizer unless the gradient norm is NaN.
        """
        # The single iterator/optimizer handed to StandardUpdater are looked
        # up under the name "main".
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")
        # Epoch before fetching, used below to detect an epoch boundary.
        epoch = train_iter.epoch

        batch = train_iter.next()
        # Move the converter's output (a tuple of tensors) to the target device.
        x = _recursive_to(batch, self.device)
        is_new_epoch = train_iter.epoch != epoch

        # Dividing by accum_grad makes the accumulated gradient an average over
        # the accumulated minibatches; .mean() reduces the model output (with
        # data_parallel, presumably one loss per replica) to a scalar.
        if self.ngpu == 0:
            loss = self.model(*x).mean() / self.accum_grad
        else:
            loss = (
                data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
            )
        if self.use_apex:
            from apex import amp

            # The noam optimizer wraps the real torch optimizer in
            # ``.optimizer``; apex must see the wrapped one.
            opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
            with amp.scale_loss(loss, opt) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if self.grad_noise:
            from espnet.asr.asr_utils import add_gradient_noise

            add_gradient_noise(
                self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55
            )

        # Keep accumulating gradients until accum_grad passes have been done,
        # unless the epoch just ended (then step with what has accumulated).
        self.forward_count += 1
        if not is_new_epoch and self.forward_count != self.accum_grad:
            return
        self.forward_count = 0

        # Clip in place; the returned total norm is used to detect NaN grads.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.grad_clip_threshold
        )
        logging.info("grad norm={}".format(grad_norm))
        if math.isnan(grad_norm):
            logging.warning("grad norm is nan. Do not update model.")
        else:
            optimizer.step()
        # Gradients are cleared whether or not the step was taken.
        optimizer.zero_grad()

    def update(self):
        """Run one update and count an iteration only after a finished cycle.

        This override calls only ``update_core``; ``self.iteration`` is
        incremented when ``update_core`` has reset ``forward_count`` to 0,
        i.e. once per gradient-accumulation cycle rather than per minibatch.
        """
        self.update_core()

        if self.forward_count == 0:
            self.iteration += 1
|
|
|
|
|
class CustomConverter(object):
    """Custom batch converter for Pytorch.

    Turns one pre-loaded minibatch of numpy features and label sequences into
    padded torch tensors on a target device.

    Args:
        subsampling_factor (int): The subsampling factor; every
            ``subsampling_factor``-th input frame is kept when it is > 1.
        dtype (torch.dtype): Data type to convert.

    """

    def __init__(self, subsampling_factor=1, dtype=torch.float32):
        """Construct a CustomConverter object."""
        self.dtype = dtype
        # Padding value used for label sequences.
        self.ignore_id = -1
        self.subsampling_factor = subsampling_factor

    def _pad_to_device(self, arrays, device):
        """Zero-pad numpy arrays into one float tensor of ``self.dtype`` on ``device``."""
        tensors = [torch.from_numpy(arr).float() for arr in arrays]
        return pad_list(tensors, 0).to(device, dtype=self.dtype)

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform; a one-element list holding
                an ``(xs, ys)`` pair of feature arrays and label sequences.
            device (torch.device): The device to send to.

        Returns:
            tuple(torch.Tensor | dict, torch.Tensor, torch.Tensor): Padded
            features (a ``{"real", "imag"}`` dict for complex input), input
            lengths and padded labels.

        """
        assert len(batch) == 1
        feats, targets = batch[0]

        step = self.subsampling_factor
        if step > 1:
            feats = [feat[::step, :] for feat in feats]

        # Lengths are taken after subsampling.
        lengths = np.array([feat.shape[0] for feat in feats])

        if feats[0].dtype.kind != "c":
            feats_pad = self._pad_to_device(feats, device)
        else:
            # Complex-valued input is padded as separate real and imaginary parts.
            feats_pad = {
                "real": self._pad_to_device([feat.real for feat in feats], device),
                "imag": self._pad_to_device([feat.imag for feat in feats], device),
            }

        lengths = torch.from_numpy(lengths).to(device)

        label_tensors = []
        for y in targets:
            arr = np.array(y[0][:]) if isinstance(y, tuple) else y
            label_tensors.append(torch.from_numpy(arr).long())
        ys_pad = pad_list(label_tensors, self.ignore_id).to(device)

        return feats_pad, lengths, ys_pad
|
|
|
|
|
class CustomConverterMulEnc(object):
    """Custom batch converter for Pytorch in multi-encoder case.

    Args:
        subsamping_factors (list): List of subsampling factors for each encoder
            (the misspelled keyword is kept for backward compatibility).
            Defaults to ``[1, 1]``.
        dtype (torch.dtype): Data type to convert.

    """

    def __init__(self, subsamping_factors=None, dtype=torch.float32):
        """Initialize the converter."""
        if subsamping_factors is None:
            # Avoid a shared mutable default; [1, 1] is the previous default.
            subsamping_factors = [1, 1]
        self.subsampling_factors = list(subsamping_factors)
        # Historical misspelled attribute, kept for backward compatibility.
        self.subsamping_factors = self.subsampling_factors
        self.ignore_id = -1
        self.dtype = dtype
        self.num_encs = len(self.subsampling_factors)

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform; a one-element list whose item
                holds one feature list per encoder followed (last) by the
                label sequences.
            device (torch.device): The device to send to.

        Returns:
            tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor)

        """
        assert len(batch) == 1
        xs_list = batch[0][: self.num_encs]
        ys = batch[0][-1]

        # Per-encoder frame subsampling. The previous code read a misspelled
        # attribute here and raised AttributeError whenever a factor was > 1.
        xs_list = [
            [x[::factor, :] for x in xs] if factor > 1 else xs
            for factor, xs in zip(self.subsampling_factors, xs_list)
        ]

        # Input lengths (after subsampling) for each encoder stream.
        ilens_list = [np.array([x.shape[0] for x in xs]) for xs in xs_list]

        xs_list_pad = [
            pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
                device, dtype=self.dtype
            )
            for xs in xs_list
        ]

        ilens_list = [torch.from_numpy(ilens).to(device) for ilens in ilens_list]

        ys_pad = pad_list(
            [
                torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
                for y in ys
            ],
            self.ignore_id,
        ).to(device)

        return xs_list_pad, ilens_list, ys_pad
|
|
|
|
|
def train(args):
    """Train with the given args.

    Builds the model, optimizer, data iterators and a Chainer ``Trainer`` with
    evaluation, plotting, snapshot and logging extensions, then runs training
    and checks for early stopping.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    if args.num_encs > 1:
        args = format_mulenc_args(args)

    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # Input/output dimensions are read from the first validation utterance.
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim_list = [
        int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
    ]
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    for i in range(args.num_encs):
        logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
    logging.info("#output dims: " + str(odim))

    # Select the training mode (transducer / CTC / attention / hybrid).
    if "transducer" in args.model_module:
        if (
            getattr(args, "etype", False) == "custom"
            or getattr(args, "dtype", False) == "custom"
        ):
            mtl_mode = "custom_transducer"
        else:
            mtl_mode = "transducer"
        logging.info("Pure transducer mode")
    elif args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
        model = load_trained_modules(idim_list[0], odim, args)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(
            idim_list[0] if args.num_encs == 1 else idim_list, odim, args
        )
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

    logging.info(
        " Total parameter of the model = "
        + str(sum(p.numel() for p in model.parameters()))
    )

    if args.rnnlm is not None:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
        )
        torch_load(args.rnnlm, rnnlm)
        model.rnnlm = rnnlm

    # Write (idim, odim, args) to model.json so the model can be rebuilt later.
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)),
                indent=4,
                ensure_ascii=False,
                sort_keys=True,
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # With data parallelism each GPU gets a slice, so scale the batch size.
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu
        if args.num_encs > 1:
            raise NotImplementedError(
                "Data parallel is not supported for multi-encoder setup."
            )

    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    # apex opt levels (O0-O3) keep float32 parameters; apex handles precision.
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    if args.freeze_mods:
        model, model_params = freeze_modules(model, args.freeze_mods)
    else:
        model_params = model.parameters()

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        # Models built from block architectures expose their dominant
        # dimension as ``most_dom_dim``; others use args.adim.
        if hasattr(args, "enc_block_arch") or hasattr(args, "dec_block_arch"):
            adim = model.most_dom_dim
        else:
            adim = args.adim

        optimizer = get_std_opt(
            model_params, adim, args.transformer_warmup_steps, args.transformer_lr
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # Mixed precision through NVIDIA apex.
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        # The noam optimizer wraps the torch optimizer in ``.optimizer``.
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True

        from espnet.nets.pytorch_backend.ctc import CTC

        # Keep the CTC loss computation in float precision.
        amp.register_float_function(CTC, "loss_fn")
        amp.init()
        logging.warning("register ctc as float function")
    else:
        use_apex = False

    # Attach the reporter so Chainer can (de)serialize it with the optimizer.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    if args.num_encs == 1:
        converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
    else:
        converter = CustomConverterMulEnc(
            [i[0] for i in model.subsample_list], dtype=dtype
        )

    # ``valid_json`` was already loaded above and has not been modified since.
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0

    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # Each dataset item is already a whole minibatch (from make_batchset), so
    # the loader uses batch_size=1 and unwraps it in collate_fn.
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # SortaGrad: enable shuffling once the sorted warm-up epochs are over.
    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

    is_attn_plot = (
        "transformer" in args.model_module
        or "conformer" in args.model_module
        or mtl_mode in ["att", "mtl", "custom_transducer"]
    )

    if args.num_save_attention > 0 and is_attn_plot:
        # NOTE(review): this key is input shape[1]; since idim is read from
        # shape[-1] above, for 2-D shapes this is the feature dim (constant
        # across utterances) — confirm whether shape[0] (length) was intended.
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0:
        data = sorted(
            list(valid_json.items())[: args.num_save_ctc],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            ctc_vis_fn = model.module.calculate_all_ctc_probs
            plot_class = model.module.ctc_plot_class
        else:
            ctc_vis_fn = model.calculate_all_ctc_probs
            plot_class = model.ctc_plot_class
        ctc_reporter = plot_class(
            ctc_vis_fn,
            data,
            args.outdir + "/ctc_prob",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(ctc_reporter, trigger=(1, "epoch"))
    else:
        ctc_reporter = None

    # Per-encoder CTC keys; only defined (and only used) when num_encs > 1.
    if args.num_encs > 1:
        report_keys_loss_ctc = [
            "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)]
        report_keys_cer_ctc = [
            "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
        ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)]

    if hasattr(model, "is_rnnt"):
        trainer.extend(
            extensions.PlotReport(
                [
                    "main/loss",
                    "validation/main/loss",
                    "main/loss_trans",
                    "validation/main/loss_trans",
                    "main/loss_ctc",
                    "validation/main/loss_ctc",
                    "main/loss_lm",
                    "validation/main/loss_lm",
                    "main/loss_aux_trans",
                    "validation/main/loss_aux_trans",
                    "main/loss_aux_symm_kl",
                    "validation/main/loss_aux_symm_kl",
                ],
                "epoch",
                file_name="loss.png",
            )
        )
    else:
        trainer.extend(
            extensions.PlotReport(
                [
                    "main/loss",
                    "validation/main/loss",
                    "main/loss_ctc",
                    "validation/main/loss_ctc",
                    "main/loss_att",
                    "validation/main/loss_att",
                ]
                + ([] if args.num_encs == 1 else report_keys_loss_ctc),
                "epoch",
                file_name="loss.png",
            )
        )

    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )
    # The CER plot must use the per-encoder CER keys (previously the loss
    # keys were appended here by mistake).
    trainer.extend(
        extensions.PlotReport(
            ["main/cer_ctc", "validation/main/cer_ctc"]
            + ([] if args.num_encs == 1 else report_keys_cer_ctc),
            "epoch",
            file_name="cer.png",
        )
    )

    # Keep the best models by validation loss / accuracy.
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode not in ["ctc", "transducer", "custom_transducer"]:
        trainer.extend(
            snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )

    trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # Adadelta: when the validation criterion gets worse, restore the best
    # model and decay eps.
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
        # Decay eps only, without restoring the best model.
        elif args.criterion == "loss_eps_decay_only":
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )

    if hasattr(model, "is_rnnt"):
        report_keys = [
            "epoch",
            "iteration",
            "main/loss",
            "main/loss_trans",
            "main/loss_ctc",
            "main/loss_lm",
            "main/loss_aux_trans",
            "main/loss_aux_symm_kl",
            "validation/main/loss",
            "validation/main/loss_trans",
            "validation/main/loss_ctc",
            "validation/main/loss_lm",
            "validation/main/loss_aux_trans",
            "validation/main/loss_aux_symm_kl",
            "elapsed_time",
        ]
    else:
        report_keys = [
            "epoch",
            "iteration",
            "main/loss",
            "main/loss_ctc",
            "main/loss_att",
            "validation/main/loss",
            "validation/main/loss_ctc",
            "validation/main/loss_att",
            "main/acc",
            "validation/main/acc",
            "main/cer_ctc",
            "validation/main/cer_ctc",
            "elapsed_time",
        ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc)

    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    if args.report_cer:
        report_keys.append("validation/main/cer")
    if args.report_wer:
        report_keys.append("validation/main/wer")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(
                SummaryWriter(args.tensorboard_dir),
                att_reporter=att_reporter,
                ctc_reporter=ctc_reporter,
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )

    trainer.run()
    check_early_stop(trainer, args.epochs)
|
|
|
|
|
def recog(args): |
|
"""Decode with the given args. |
|
|
|
Args: |
|
args (namespace): The program arguments. |
|
|
|
""" |
|
set_deterministic_pytorch(args) |
|
model, train_args = load_trained_model(args.model, training=False) |
|
assert isinstance(model, ASRInterface) |
|
model.recog_args = args |
|
|
|
if args.streaming_mode and "transformer" in train_args.model_module: |
|
raise NotImplementedError("streaming mode for transformer is not implemented") |
|
logging.info( |
|
" Total parameter of the model = " |
|
+ str(sum(p.numel() for p in model.parameters())) |
|
) |
|
|
|
|
|
if args.rnnlm: |
|
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) |
|
if getattr(rnnlm_args, "model_module", "default") != "default": |
|
raise ValueError( |
|
"use '--api v2' option to decode with non-default language model" |
|
) |
|
rnnlm = lm_pytorch.ClassifierWithState( |
|
lm_pytorch.RNNLM( |
|
len(train_args.char_list), |
|
rnnlm_args.layer, |
|
rnnlm_args.unit, |
|
getattr(rnnlm_args, "embed_unit", None), |
|
) |
|
) |
|
torch_load(args.rnnlm, rnnlm) |
|
rnnlm.eval() |
|
else: |
|
rnnlm = None |
|
|
|
if args.word_rnnlm: |
|
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) |
|
word_dict = rnnlm_args.char_list_dict |
|
char_dict = {x: i for i, x in enumerate(train_args.char_list)} |
|
word_rnnlm = lm_pytorch.ClassifierWithState( |
|
lm_pytorch.RNNLM( |
|
len(word_dict), |
|
rnnlm_args.layer, |
|
rnnlm_args.unit, |
|
getattr(rnnlm_args, "embed_unit", None), |
|
) |
|
) |
|
torch_load(args.word_rnnlm, word_rnnlm) |
|
word_rnnlm.eval() |
|
|
|
if rnnlm is not None: |
|
rnnlm = lm_pytorch.ClassifierWithState( |
|
extlm_pytorch.MultiLevelLM( |
|
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict |
|
) |
|
) |
|
else: |
|
rnnlm = lm_pytorch.ClassifierWithState( |
|
extlm_pytorch.LookAheadWordLM( |
|
word_rnnlm.predictor, word_dict, char_dict |
|
) |
|
) |
|
|
|
|
|
if args.ngpu == 1: |
|
gpu_id = list(range(args.ngpu)) |
|
logging.info("gpu id: " + str(gpu_id)) |
|
model.cuda() |
|
if rnnlm: |
|
rnnlm.cuda() |
|
|
|
|
|
with open(args.recog_json, "rb") as f: |
|
js = json.load(f)["utts"] |
|
new_js = {} |
|
|
|
load_inputs_and_targets = LoadInputsAndTargets( |
|
mode="asr", |
|
load_output=False, |
|
sort_in_input_length=False, |
|
preprocess_conf=train_args.preprocess_conf |
|
if args.preprocess_conf is None |
|
else args.preprocess_conf, |
|
preprocess_args={"train": False}, |
|
) |
|
|
|
|
|
if hasattr(model, "is_rnnt"): |
|
if hasattr(model, "dec"): |
|
trans_decoder = model.dec |
|
else: |
|
trans_decoder = model.decoder |
|
joint_network = model.joint_network |
|
|
|
beam_search_transducer = BeamSearchTransducer( |
|
decoder=trans_decoder, |
|
joint_network=joint_network, |
|
beam_size=args.beam_size, |
|
nbest=args.nbest, |
|
lm=rnnlm, |
|
lm_weight=args.lm_weight, |
|
search_type=args.search_type, |
|
max_sym_exp=args.max_sym_exp, |
|
u_max=args.u_max, |
|
nstep=args.nstep, |
|
prefix_alpha=args.prefix_alpha, |
|
score_norm=args.score_norm, |
|
) |
|
|
|
if args.batchsize == 0: |
|
with torch.no_grad(): |
|
for idx, name in enumerate(js.keys(), 1): |
|
logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) |
|
batch = [(name, js[name])] |
|
feat = load_inputs_and_targets(batch) |
|
feat = ( |
|
feat[0][0] |
|
if args.num_encs == 1 |
|
else [feat[idx][0] for idx in range(model.num_encs)] |
|
) |
|
if args.streaming_mode == "window" and args.num_encs == 1: |
|
logging.info( |
|
"Using streaming recognizer with window size %d frames", |
|
args.streaming_window, |
|
) |
|
se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) |
|
for i in range(0, feat.shape[0], args.streaming_window): |
|
logging.info( |
|
"Feeding frames %d - %d", i, i + args.streaming_window |
|
) |
|
se2e.accept_input(feat[i : i + args.streaming_window]) |
|
logging.info("Running offline attention decoder") |
|
se2e.decode_with_attention_offline() |
|
logging.info("Offline attention decoder finished") |
|
nbest_hyps = se2e.retrieve_recognition() |
|
elif args.streaming_mode == "segment" and args.num_encs == 1: |
|
logging.info( |
|
"Using streaming recognizer with threshold value %d", |
|
args.streaming_min_blank_dur, |
|
) |
|
nbest_hyps = [] |
|
for n in range(args.nbest): |
|
nbest_hyps.append({"yseq": [], "score": 0.0}) |
|
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) |
|
r = np.prod(model.subsample) |
|
for i in range(0, feat.shape[0], r): |
|
hyps = se2e.accept_input(feat[i : i + r]) |
|
if hyps is not None: |
|
text = "".join( |
|
[ |
|
train_args.char_list[int(x)] |
|
for x in hyps[0]["yseq"][1:-1] |
|
if int(x) != -1 |
|
] |
|
) |
|
text = text.replace( |
|
"\u2581", " " |
|
).strip() |
|
text = text.replace(model.space, " ") |
|
text = text.replace(model.blank, "") |
|
logging.info(text) |
|
for n in range(args.nbest): |
|
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) |
|
nbest_hyps[n]["score"] += hyps[n]["score"] |
|
elif hasattr(model, "is_rnnt"): |
|
nbest_hyps = model.recognize(feat, beam_search_transducer) |
|
else: |
|
nbest_hyps = model.recognize( |
|
feat, args, train_args.char_list, rnnlm |
|
) |
|
new_js[name] = add_results_to_json( |
|
js[name], nbest_hyps, train_args.char_list |
|
) |
|
|
|
else: |
|
|
|
def grouper(n, iterable, fillvalue=None): |
|
kargs = [iter(iterable)] * n |
|
return zip_longest(*kargs, fillvalue=fillvalue) |
|
|
|
|
|
keys = list(js.keys()) |
|
if args.batchsize > 1: |
|
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] |
|
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) |
|
keys = [keys[i] for i in sorted_index] |
|
|
|
with torch.no_grad(): |
|
for names in grouper(args.batchsize, keys, None): |
|
names = [name for name in names if name] |
|
batch = [(name, js[name]) for name in names] |
|
feats = ( |
|
load_inputs_and_targets(batch)[0] |
|
if args.num_encs == 1 |
|
else load_inputs_and_targets(batch) |
|
) |
|
if args.streaming_mode == "window" and args.num_encs == 1: |
|
raise NotImplementedError |
|
elif args.streaming_mode == "segment" and args.num_encs == 1: |
|
if args.batchsize > 1: |
|
raise NotImplementedError |
|
feat = feats[0] |
|
nbest_hyps = [] |
|
for n in range(args.nbest): |
|
nbest_hyps.append({"yseq": [], "score": 0.0}) |
|
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) |
|
r = np.prod(model.subsample) |
|
for i in range(0, feat.shape[0], r): |
|
hyps = se2e.accept_input(feat[i : i + r]) |
|
if hyps is not None: |
|
text = "".join( |
|
[ |
|
train_args.char_list[int(x)] |
|
for x in hyps[0]["yseq"][1:-1] |
|
if int(x) != -1 |
|
] |
|
) |
|
text = text.replace( |
|
"\u2581", " " |
|
).strip() |
|
text = text.replace(model.space, " ") |
|
text = text.replace(model.blank, "") |
|
logging.info(text) |
|
for n in range(args.nbest): |
|
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) |
|
nbest_hyps[n]["score"] += hyps[n]["score"] |
|
nbest_hyps = [nbest_hyps] |
|
else: |
|
nbest_hyps = model.recognize_batch( |
|
feats, args, train_args.char_list, rnnlm=rnnlm |
|
) |
|
|
|
for i, nbest_hyp in enumerate(nbest_hyps): |
|
name = names[i] |
|
new_js[name] = add_results_to_json( |
|
js[name], nbest_hyp, train_args.char_list |
|
) |
|
|
|
with open(args.result_label, "wb") as f: |
|
f.write( |
|
json.dumps( |
|
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True |
|
).encode("utf_8") |
|
) |
|
|
|
|
|
def _build_istft(args, preprocess_conf):
    """Create the inverse STFT used to convert enhanced spectra to waveforms.

    When ``args.apply_istft`` is set, the STFT parameters are taken from the
    first ``"stft"`` entry of the preprocessing config (if one is given and
    contains such an entry); otherwise they fall back to the
    ``args.istft_*`` command line values.

    Args:
        args (namespace): The program arguments.
        preprocess_conf (str or None): Path to the JSON preprocessing config.

    Returns:
        tuple: ``(istft, frame_shift)``. ``istft`` is an ``IStft`` instance,
        or None when ``args.apply_istft`` is false. ``frame_shift`` is the hop
        size passed to ``plot_spectrogram``.

    """
    istft = None
    frame_shift = args.istft_n_shift
    if not args.apply_istft:
        return istft, frame_shift

    if preprocess_conf is not None:
        # Reuse the front-end STFT settings so the inverse transform matches
        # the forward transform applied during preprocessing.
        with open(preprocess_conf) as f:
            conf = json.load(f)
        assert "process" in conf, conf
        for p in conf["process"]:
            if p["type"] == "stft":
                istft = IStft(
                    win_length=p["win_length"],
                    n_shift=p["n_shift"],
                    window=p.get("window", "hann"),
                )
                logging.info(
                    "stft is found in {}. "
                    "Setting istft config from it\n{}".format(preprocess_conf, istft)
                )
                frame_shift = p["n_shift"]
                break

    if istft is None:
        # No stft entry in the preprocessing config: use the CLI arguments.
        istft = IStft(
            win_length=args.istft_win_length,
            n_shift=args.istft_n_shift,
            window=args.istft_window,
        )
        logging.info(
            "Setting istft config from the command line args\n{}".format(istft)
        )
    return istft, frame_shift


def _plot_enhancement(args, name, feat, mas, enh, frame_shift, ref_ch=0):
    """Save a 4-row figure (mask / noisy / masked / enhanced) for one utterance.

    The figure is written to ``os.path.join(args.image_dir, name + ".png")``.

    Args:
        args (namespace): The program arguments (``fs`` and ``image_dir`` are read).
        name (str): Utterance id, used as the image file name.
        feat: Network input for this utterance, indexed as ``feat[:, ref_ch]``.
        mas: Estimated mask, indexed as ``mas[:, ref_ch]``.
        enh: Enhanced spectrum, trimmed to the utterance length.
        frame_shift (int): Hop size forwarded to ``plot_spectrogram``.
        ref_ch (int): Channel index plotted for the multi-channel inputs.

    """
    import matplotlib.pyplot as plt

    plt.figure(figsize=(20, 10))

    plt.subplot(4, 1, 1)
    plt.title("Mask [ref={}ch]".format(ref_ch))
    plot_spectrogram(
        plt,
        mas[:, ref_ch].T,
        fs=args.fs,
        mode="linear",
        frame_shift=frame_shift,
        bottom=False,
        labelbottom=False,
    )

    plt.subplot(4, 1, 2)
    plt.title("Noisy speech [ref={}ch]".format(ref_ch))
    plot_spectrogram(
        plt,
        feat[:, ref_ch].T,
        fs=args.fs,
        mode="db",
        frame_shift=frame_shift,
        bottom=False,
        labelbottom=False,
    )

    plt.subplot(4, 1, 3)
    plt.title("Masked speech [ref={}ch]".format(ref_ch))
    plot_spectrogram(
        plt,
        (feat[:, ref_ch] * mas[:, ref_ch]).T,
        frame_shift=frame_shift,
        fs=args.fs,
        mode="db",
        bottom=False,
        labelbottom=False,
    )

    plt.subplot(4, 1, 4)
    plt.title("Enhanced speech")
    plot_spectrogram(plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift)

    plt.savefig(os.path.join(args.image_dir, name + ".png"))
    # Close (not merely clear) the figure: a new figure is created for every
    # utterance, and figures left open accumulate inside pyplot.
    plt.close()


def _match_length(wav, length):
    """Trim or zero-pad ``wav`` along its first axis to exactly ``length`` items.

    Args:
        wav (np.ndarray): Array to adjust; trailing axes are left untouched.
        length (int): Target length of the first axis.

    Returns:
        np.ndarray: ``wav`` truncated or right-padded with zeros.

    """
    if len(wav) > length:
        return wav[:length]
    if len(wav) < length:
        padwidth = [(0, length - len(wav))] + [(0, 0)] * (wav.ndim - 1)
        return np.pad(wav, padwidth, mode="constant")
    return wav


def enhance(args):
    """Dumping enhanced speech and mask.

    Runs ``model.enhance`` over every utterance in ``args.recog_json`` and, if
    ``args.enh_wspecifier`` is given, writes the enhanced output with
    ``file_writer_helper`` (optionally converted to a waveform by an inverse
    STFT). Up to ``args.num_images`` diagnostic figures are saved to
    ``args.image_dir`` when it is not None.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    # Read the training configuration.
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # Enhancement is only supported for single-encoder models.
    assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format(
        args.num_encs
    )

    # Load the trained model parameters.
    logging.info("reading model parameters from " + args.model)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        # Preprocessing is applied below via `transform`, so that the raw
        # inputs (org_feats) stay available for length matching.
        preprocess_conf=None,
    )
    if args.batchsize == 0:
        args.batchsize = 1

    # Writer for the network outputs (optional).
    if args.enh_wspecifier is not None:
        enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype)
    else:
        enh_writer = None

    # The CLI preprocessing config overrides the one used in training.
    preprocess_conf = (
        train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf
    )
    if preprocess_conf is not None:
        logging.info(f"Use preprocessing: {preprocess_conf}")
        transform = Transformation(preprocess_conf)
    else:
        transform = None

    istft, frame_shift = _build_istft(args, preprocess_conf)

    # Process the longest utterances first.
    keys = list(js.keys())
    feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
    sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
    keys = [keys[i] for i in sorted_index]

    def grouper(n, iterable, fillvalue=None):
        kargs = [iter(iterable)] * n
        return zip_longest(*kargs, fillvalue=fillvalue)

    num_images = 0
    if args.image_dir is not None:
        os.makedirs(args.image_dir, exist_ok=True)

    try:
        for names in grouper(args.batchsize, keys, None):
            # The last group is padded with None when len(keys) is not a
            # multiple of batchsize; drop the padding before indexing `js`.
            names = [name for name in names if name is not None]
            batch = [(name, js[name]) for name in names]

            # Keep the un-transformed inputs for the keep_length adjustment.
            org_feats = load_inputs_and_targets(batch)[0]
            if transform is not None:
                feats = transform(org_feats, train=False)
            else:
                feats = org_feats

            with torch.no_grad():
                enhanced, mask, ilens = model.enhance(feats)

            for idx, name in enumerate(names):
                # Strip batch padding from the network outputs.
                enh = enhanced[idx][: ilens[idx]]
                mas = mask[idx][: ilens[idx]]
                feat = feats[idx]

                if args.image_dir is not None and num_images < args.num_images:
                    num_images += 1
                    _plot_enhancement(args, name, feat, mas, enh, frame_shift)

                if enh_writer is not None:
                    if istft is not None:
                        enh = istft(enh)

                    if args.keep_length:
                        # Compare against this utterance's raw input, not
                        # the whole batch.
                        enh = _match_length(enh, len(org_feats[idx]))

                    if args.enh_filetype in ("sound", "sound.hdf5"):
                        enh_writer[name] = (args.fs, enh)
                    else:
                        enh_writer[name] = enh

                if num_images >= args.num_images and enh_writer is None:
                    # Nothing else would be produced: stop the whole run,
                    # not just the current batch.
                    logging.info("Breaking the process.")
                    return
    finally:
        if enh_writer is not None:
            enh_writer.close()
|
|
|
|
|
def ctc_align(args):
    """CTC forced alignments with the given args.

    Loads the trained model, aligns every utterance in ``args.align_json``
    against its reference label with ``model.ctc.forced_align`` and writes
    ``{"utts": {name: {"ctc_alignment": ...}}}`` to ``args.result_label``.

    Args:
        args (namespace): The program arguments.

    Raises:
        NotImplementedError: If ``args.ngpu > 1`` or ``args.batchsize != 0``.

    """

    def add_alignment_to_json(js, alignment, char_list):
        """Convert a CTC alignment into a json-serializable utterance entry.

        Args:
            js (dict[str, Any]): Groundtruth utterance dict (not used by the
                conversion itself).
            alignment (list[int]): Token ids returned by ``model.ctc.forced_align``.
            char_list (list[str]): List of characters.

        Returns:
            dict[str, Any]: ``{"ctc_alignment": "<space-separated tokens>"}``.

        """
        return {"ctc_alignment": " ".join(char_list[a] for a in alignment)}

    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, ASRInterface)

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    if args.ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    device = "cuda" if args.ngpu == 1 else "cpu"
    dtype = getattr(torch, args.dtype)
    logging.info(f"Decoding device={device}, dtype={dtype}")
    model.to(device=device, dtype=dtype).eval()

    with open(args.align_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                # Pass the name as an argument rather than concatenating it
                # into the format string, so a '%' in a name cannot break it.
                logging.info("(%d/%d) aligning %s", idx, len(js), name)
                batch = [(name, js[name])]
                feat, label = load_inputs_and_targets(batch)
                feat = feat[0]
                label = label[0]
                # The model was cast to `dtype` above; cast the input as well,
                # otherwise non-float32 decoding fails on a dtype mismatch.
                x = torch.as_tensor(feat).to(device=device, dtype=dtype)
                enc = model.encode(x).unsqueeze(0)
                alignment = model.ctc.forced_align(enc, label)
                new_js[name] = add_alignment_to_json(
                    js[name], alignment, train_args.char_list
                )
    else:
        raise NotImplementedError("Align_batch is not implemented.")

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
|
|