|
|
|
|
|
|
|
"""Training/decoding definition for the speech recognition task.""" |
|
|
|
import json |
|
import logging |
|
import os |
|
import six |
|
|
|
|
|
import chainer |
|
|
|
from chainer import training |
|
|
|
from chainer.datasets import TransformDataset |
|
from chainer.training import extensions |
|
|
|
|
|
from espnet.asr.asr_utils import adadelta_eps_decay |
|
from espnet.asr.asr_utils import add_results_to_json |
|
from espnet.asr.asr_utils import chainer_load |
|
from espnet.asr.asr_utils import CompareValueTrigger |
|
from espnet.asr.asr_utils import get_model_conf |
|
from espnet.asr.asr_utils import restore_snapshot |
|
from espnet.nets.asr_interface import ASRInterface |
|
from espnet.utils.deterministic_utils import set_deterministic_chainer |
|
from espnet.utils.dynamic_import import dynamic_import |
|
from espnet.utils.io_utils import LoadInputsAndTargets |
|
from espnet.utils.training.batchfy import make_batchset |
|
from espnet.utils.training.evaluator import BaseEvaluator |
|
from espnet.utils.training.iterators import ShufflingEnabler |
|
from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator |
|
from espnet.utils.training.iterators import ToggleableShufflingSerialIterator |
|
from espnet.utils.training.train_utils import check_early_stop |
|
from espnet.utils.training.train_utils import set_early_stop |
|
|
|
|
|
import espnet.lm.chainer_backend.extlm as extlm_chainer |
|
import espnet.lm.chainer_backend.lm as lm_chainer |
|
|
|
|
|
import matplotlib |
|
|
|
from espnet.utils.training.tensorboard_logger import TensorboardLogger |
|
from tensorboardX import SummaryWriter |
|
|
|
# Force the non-interactive Agg backend so the plotting extensions
# (PlotReport, attention visualisation) work on headless machines.
# NOTE(review): assumed to run before matplotlib.pyplot is first
# imported anywhere in the process — confirm no earlier pyplot import.
matplotlib.use("Agg")
|
|
|
|
|
def train(args):
    """Train an ASR model with the chainer backend.

    Builds the model named by ``args.model_module``, wraps the train/valid
    json data in chainer iterators and an updater, attaches reporting,
    snapshotting, LR-scheduling and early-stopping extensions, and runs
    the chainer ``Trainer`` loop.

    Args:
        args (namespace): The program arguments.

    """
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # Warn (but do not fail) when GPU acceleration is unavailable.
    if not chainer.cuda.available:
        logging.warning("cuda is not available")
    if not chainer.cuda.cudnn_enabled:
        logging.warning("cudnn is not available")

    # Read input/output dimensions from the first utterance of the valid
    # set (shape index 1 is the feature/label dimension).
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # Decide the training mode from the CTC/attention interpolation weight.
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    # Instantiate the model class given by its dotted import path.
    logging.info("import model module: " + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

    # Persist (idim, odim, args) as model.json so decoding can rebuild the
    # exact same model later via get_model_conf().
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # Device placement: device 0 is always "main"; with multiple GPUs the
    # remaining devices become parallel-updater "sub_N" devices.
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()
        logging.info("single gpu calculation.")
    elif ngpu > 1:
        gpu_id = 0
        devices = {"main": gpu_id}
        for gid in six.moves.xrange(1, ngpu):
            devices["sub_%d" % gid] = gid
        logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
        logging.warning(
            "batch size is automatically increased (%d -> %d)"
            % (args.batch_size, args.batch_size * args.ngpu)
        )
    else:
        gpu_id = -1
        logging.info("cpu calculation")

    # Optimizer. "noam" starts Adam at alpha=0; the effective learning rate
    # is driven by the VaswaniRule extension attached below.
    if args.opt == "adadelta":
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == "adam":
        optimizer = chainer.optimizers.Adam()
    elif args.opt == "noam":
        optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise NotImplementedError("args.opt={}".format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Converter turns a loaded minibatch into model inputs, applying the
    # encoder's first subsampling factor.
    converter = model.custom_converter(subsampling_factor=model.subsample[0])

    # Read the train/valid utterance dictionaries.
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # Feature loaders; preprocessing is enabled only for the training side
    # (preprocess_args={"train": True/False}).
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # sortagrad: feed shortest utterances first for the initial epochs
    # (-1 means "for all epochs").
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # Single-device path: one batchset, one iterator, serial updater.
        train = make_batchset(
            train_json,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            shortest_first=use_sortagrad,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            iaxis=0,
            oaxis=0,
        )

        # batch_size=1 because each element of the batchset is itself a
        # pre-grouped minibatch of utterances.
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
            ]

        # Updater that runs forward/backward and applies the optimizer.
        updater = model.custom_updater(
            train_iters[0],
            optimizer,
            converter=converter,
            device=gpu_id,
            accum_grad=accum_grad,
        )
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented "
                "in chainer multi gpu"
            )

        # Multi-device path: split utterances round-robin across GPUs and
        # build one batchset per device.
        train_subsets = []
        for gid in six.moves.xrange(ngpu):
            # Every ngpu-th utterance (offset gid) goes to this device.
            train_json_subset = {
                k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
            }

            train_subsets += [
                make_batchset(
                    train_json_subset,
                    args.batch_size,
                    args.maxlen_in,
                    args.maxlen_out,
                    args.minibatches,
                )
            ]

        # All per-device iterators must yield the same number of batches per
        # epoch, so pad shorter subsets by repeating their leading batches.
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in six.moves.xrange(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # One iterator per device (multiprocess loading when requested).
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
                for gid in six.moves.xrange(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
                for gid in six.moves.xrange(ngpu)
            ]

        # Parallel updater spreads the per-device iterators over `devices`.
        updater = model.custom_parallel_updater(
            train_iters, optimizer, converter=converter, devices=devices
        )

    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Re-enable shuffling once the sortagrad warm-up epochs have passed.
    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )
    # Noam warm-up/decay schedule on Adam's alpha for transformer training.
    if args.opt == "noam":
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule

        trainer.extend(
            VaswaniRule(
                "alpha",
                d=args.adim,
                warmup_steps=args.transformer_warmup_steps,
                scale=args.transformer_lr,
            ),
            trigger=(1, "iteration"),
        )

    # Resume from a full trainer snapshot if requested.
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Validation batchset (no shortest_first; evaluated in fixed order).
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
        )
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
        )

    # Evaluate on the validation set every epoch.
    trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention-weight plots each epoch (only meaningful when the
    # model has an attention branch, i.e. mtlalpha != 1.0).
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # Take the first num_save_attention utterances, longest input first.
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        # Unwrap a data-parallel container if present.
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info("Using custom PlotAttentionReport")
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=gpu_id,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Full trainer snapshot every epoch (what --resume loads).
    trainer.extend(
        extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
        trigger=(1, "epoch"),
    )

    # Render loss/accuracy curves as PNGs in the output directory.
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )

    # Keep the best model by validation loss, and by validation accuracy
    # when an attention branch exists.
    trainer.extend(
        extensions.snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            extensions.snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # AdaDelta schedule: when the validation criterion degrades, restore
    # the best model so far and shrink eps (both extensions share the same
    # CompareValueTrigger condition).
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.acc.best"),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.loss.best"),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Periodic console/log reporting.
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        # Also report the (decaying) AdaDelta eps value.
        trainer.extend(
            extensions.observe_value(
                "eps", lambda trainer: trainer.updater.get_optimizer("main").eps
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    # Optional early stopping and tensorboard logging.
    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer, att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )

    # Run the training loop, then raise if early stopping fired.
    trainer.run()
    check_early_stop(trainer, args.epochs)
|
|
|
|
|
def recog(args):
    """Decode with the given args.

    Rebuilds the trained model from its saved configuration, optionally
    attaches character- and/or word-level RNN language models, decodes
    every utterance in ``args.recog_json``, and writes the n-best results
    as JSON to ``args.result_label``.

    Args:
        args (namespace): The program arguments.

    """
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # Read the training-time configuration (dims + args) saved by train().
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # Rebuild the model; older configs have no model_module attribute, so
    # fall back to the default chainer E2E model.
    logging.info("reading model parameters from " + args.model)
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    chainer_load(args.model, model)

    # Optional character-level RNNLM for shallow fusion during search.
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
            )
        )
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    # Optional word-level RNNLM: combined with the character LM when both
    # are given (MultiLevelLM), otherwise used alone via LookAheadWordLM.
    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
        )
        chainer_load(args.word_rnnlm, word_rnnlm)

        if rnnlm is not None:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(
                    word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
                )
            )
        else:
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(
                    word_rnnlm.predictor, word_dict, char_dict
                )
            )

    # Read the recognition data set.
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    # Feature loader; decoding reuses the training-time preprocessing
    # unless explicitly overridden on the command line.
    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    # Decode each utterance; no gradients are needed at inference time.
    new_js = {}
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            # Pass ``name`` as a lazy %-argument rather than concatenating
            # it into the format string: a literal '%' in an utterance id
            # would otherwise break logging's %-formatting.
            logging.info("(%d/%d) decoding %s", idx, len(js), name)
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    # Write the augmented utterance dictionary as UTF-8 JSON.
    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
|
|