"""Training/decoding definition for the speech translation task."""

import json
import logging
import os
import sys

from chainer import training
from chainer.training import extensions
import numpy as np
from tensorboardX import SummaryWriter
import torch

from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import adam_lr_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules

from espnet.nets.pytorch_backend.e2e_asr import pad_list
from espnet.nets.st_interface import STInterface
from espnet.utils.dataset import ChainerDataLoader
from espnet.utils.dataset import TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop

from espnet.asr.pytorch_backend.asr import CustomConverter as ASRCustomConverter
from espnet.asr.pytorch_backend.asr import CustomEvaluator
from espnet.asr.pytorch_backend.asr import CustomUpdater

import matplotlib

matplotlib.use("Agg")

if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest as zip_longest


class CustomConverter(ASRCustomConverter):
    """Custom batch converter for Pytorch.

    Args:
        subsampling_factor (int): The subsampling factor.
        dtype (torch.dtype): Data type to convert.
        use_source_text (bool): Whether to also return the padded source
            transcription (used when the ASR/MT auxiliary tasks are enabled).
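
    Example:
        A minimal illustrative sketch of how the converter is called; the
        feature dimension, lengths, and token ids below are arbitrary
        placeholders:

        >>> import numpy as np
        >>> converter = CustomConverter(use_source_text=True)
        >>> xs = [np.random.randn(100, 83).astype(np.float32)]
        >>> ys = [np.array([1, 2, 3])]
        >>> ys_src = [np.array([4, 5])]
        >>> xs_pad, ilens, ys_pad, ys_pad_src = converter([(xs, ys, ys_src)])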

    """

    def __init__(
        self, subsampling_factor=1, dtype=torch.float32, use_source_text=False
    ):
        """Construct a CustomConverter object."""
        super().__init__(subsampling_factor=subsampling_factor, dtype=dtype)
        self.use_source_text = use_source_text

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.
            device (torch.device): The device to send to.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
                Padded input features, input lengths, padded target token ids,
                and padded source-transcription token ids (``None`` when
                ``use_source_text`` is ``False``).

        """
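        # batch should be located in a list of length one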
        assert len(batch) == 1
        xs, ys, ys_src = batch[0]

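        # get batch of lengths of input sequences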
        ilens = np.array([x.shape[0] for x in xs])
        ilens = torch.from_numpy(ilens).to(device)

        xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
            device, dtype=self.dtype
        )

        ys_pad = pad_list(
            [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys],
            self.ignore_id,
        ).to(device)

        if self.use_source_text:
            ys_pad_src = pad_list(
                [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys_src],
                self.ignore_id,
            ).to(device)
        else:
            ys_pad_src = None

        return xs_pad, ilens, ys_pad, ys_pad_src


def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

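    # check cuda availability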
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

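    # get input and output dimension info from the validation json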
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

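    # specify model architecture
    # (enc_init / dec_init initialize the model from pre-trained modules)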
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, interface=STInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, STInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

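    # write the model config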
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

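    # check the use of multi-gpu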
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

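    # set torch device and training dtype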
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

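    # Setup an optimizer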
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(
            model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

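    # setup apex.amp for mixed-precision training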
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype
            )
        else:
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.train_dtype
            )
        use_apex = True
    else:
        use_apex = False

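    # give the PyTorch optimizer the Chainer-style interface (target / serialize)
    # expected by the Chainer trainer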
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

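    # Setup a converter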
    converter = CustomConverter(
        subsampling_factor=model.subsample[0],
        dtype=dtype,
        use_source_text=args.asr_weight > 0 or args.mt_weight > 0,
    )

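    # read json data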
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0

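    # make minibatch lists (variable length)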
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},
    )

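    # each entry produced by make_batchset is already a full minibatch, so the
    # data loader uses batch_size=1 and a collate_fn that simply unwraps it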
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

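    # Set up a trainer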
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        args.grad_noise,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

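    # Resume from a snapshot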
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

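    # Evaluate the model on the validation set (per save_interval_iters iterations, or per epoch)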
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu)
        )

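    # Save attention weight plots each epoch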
    if args.num_save_attention > 0:
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

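    # Save CTC probability plots each epoch (only when the auxiliary ASR CTC loss is active)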
    if (args.asr_weight > 0 and args.mtlalpha > 0) and args.num_save_ctc > 0:
        data = sorted(
            list(valid_json.items())[: args.num_save_ctc],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            ctc_vis_fn = model.module.calculate_all_ctc_probs
            plot_class = model.module.ctc_plot_class
        else:
            ctc_vis_fn = model.calculate_all_ctc_probs
            plot_class = model.ctc_plot_class
        ctc_reporter = plot_class(
            ctc_vis_fn,
            data,
            args.outdir + "/ctc_prob",
            converter=converter,
            transform=load_cv,
            device=device,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(ctc_reporter, trigger=(1, "epoch"))
    else:
        ctc_reporter = None

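    # Make plots of training and validation values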
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_asr",
                "validation/main/loss_asr",
                "main/loss_mt",
                "validation/main/loss_mt",
                "main/loss_st",
                "validation/main/loss_st",
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            [
                "main/acc",
                "validation/main/acc",
                "main/acc_asr",
                "validation/main/acc_asr",
                "main/acc_mt",
                "validation/main/acc_mt",
            ],
            "epoch",
            file_name="acc.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png"
        )
    )

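    # Save the best models by validation loss and accuracy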
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

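    # save a snapshot which contains the model and optimizer states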
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

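    # when the validation criterion degrades, restore the best model and
    # decay eps (adadelta) or lr (adam)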
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

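    # Write a log of evaluation statistics for each epoch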
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_st",
        "main/loss_asr",
        "validation/main/loss",
        "validation/main/loss_st",
        "validation/main/loss_asr",
        "main/acc",
        "validation/main/acc",
    ]
    if args.asr_weight > 0:
        report_keys.append("main/acc_asr")
        report_keys.append("validation/main/acc_asr")
    report_keys += ["elapsed_time"]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.asr_weight > 0:
        if args.mtlalpha > 0:
            report_keys.append("main/cer_ctc")
            report_keys.append("validation/main/cer_ctc")
        if args.mtlalpha < 1:
            if args.report_cer:
                report_keys.append("validation/main/cer")
            if args.report_wer:
                report_keys.append("validation/main/wer")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(
                SummaryWriter(args.tensorboard_dir),
                att_reporter=att_reporter,
                ctc_reporter=ctc_reporter,
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )

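    # Run the training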
    trainer.run()
    check_early_stop(trainer, args.epochs)


def trans(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, STInterface)
    model.trans_args = args

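    # put the model on the GPU when a single GPU is used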
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

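    # read json data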
    with open(args.trans_json, "rb") as f:
        js = json.load(f)["utts"]
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

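    # decode utterances one by one when batchsize == 0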
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                batch = [(name, js[name])]
                feat = load_inputs_and_targets(batch)[0][0]
                nbest_hyps = model.translate(
                    feat,
                    args,
                    train_args.char_list,
                )
                new_js[name] = add_results_to_json(
                    js[name], nbest_hyps, train_args.char_list
                )

    else:

        def grouper(n, iterable, fillvalue=None):
            kargs = [iter(iterable)] * n
            return zip_longest(*kargs, fillvalue=fillvalue)

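        # sort utterances by input length (longest first) so that batch decoding pads less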
        keys = list(js.keys())
        if args.batchsize > 1:
            feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
            sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
            keys = [keys[i] for i in sorted_index]

        with torch.no_grad():
            for names in grouper(args.batchsize, keys, None):
                names = [name for name in names if name]
                batch = [(name, js[name]) for name in names]
                feats = load_inputs_and_targets(batch)[0]
                nbest_hyps = model.translate_batch(
                    feats,
                    args,
                    train_args.char_list,
                )

                for i, nbest_hyp in enumerate(nbest_hyps):
                    name = names[i]
                    new_js[name] = add_results_to_json(
                        js[name], nbest_hyp, train_args.char_list
                    )

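    # write the translation results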
    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )