# Copyright 2019 Kyoto University (Hirofumi Inaguma)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

"""Training/decoding definition for the speech translation task."""

import json
import logging
import os
import sys

from chainer import training
from chainer.training import extensions
import numpy as np
from tensorboardX import SummaryWriter
import torch

from espnet.asr.asr_utils import adadelta_eps_decay
from espnet.asr.asr_utils import adam_lr_decay
from espnet.asr.asr_utils import add_results_to_json
from espnet.asr.asr_utils import CompareValueTrigger
from espnet.asr.asr_utils import restore_snapshot
from espnet.asr.asr_utils import snapshot_object
from espnet.asr.asr_utils import torch_load
from espnet.asr.asr_utils import torch_resume
from espnet.asr.asr_utils import torch_snapshot
from espnet.asr.pytorch_backend.asr_init import load_trained_model
from espnet.asr.pytorch_backend.asr_init import load_trained_modules
from espnet.nets.pytorch_backend.e2e_asr import pad_list
from espnet.nets.st_interface import STInterface
from espnet.utils.dataset import ChainerDataLoader
from espnet.utils.dataset import TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop
from espnet.utils.training.train_utils import set_early_stop

from espnet.asr.pytorch_backend.asr import CustomConverter as ASRCustomConverter
from espnet.asr.pytorch_backend.asr import CustomEvaluator
from espnet.asr.pytorch_backend.asr import CustomUpdater

import matplotlib

matplotlib.use("Agg")

if sys.version_info[0] == 2:
    from itertools import izip_longest as zip_longest
else:
    from itertools import zip_longest as zip_longest


class CustomConverter(ASRCustomConverter):
    """Custom batch converter for PyTorch.

    Args:
        subsampling_factor (int): The subsampling factor.
        dtype (torch.dtype): Data type to convert.
        use_source_text (bool): Whether to also batch the source transcription.

    """

    def __init__(
        self, subsampling_factor=1, dtype=torch.float32, use_source_text=False
    ):
        """Construct a CustomConverter object."""
        super().__init__(subsampling_factor=subsampling_factor, dtype=dtype)
        self.use_source_text = use_source_text

    def __call__(self, batch, device=torch.device("cpu")):
        """Transform a batch and send it to a device.

        Args:
            batch (list): The batch to transform.
            device (torch.device): The device to send to.

        Returns:
            tuple(torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor):
                padded inputs, input lengths, padded targets, and padded
                source-text targets (``None`` when ``use_source_text`` is False).

        """
        # batch should be located in list
        assert len(batch) == 1
        xs, ys, ys_src = batch[0]

        # get batch of lengths of input sequences
        ilens = np.array([x.shape[0] for x in xs])
        ilens = torch.from_numpy(ilens).to(device)

        xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
            device, dtype=self.dtype
        )
        ys_pad = pad_list(
            [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys],
            self.ignore_id,
        ).to(device)

        if self.use_source_text:
            ys_pad_src = pad_list(
                [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys_src],
                self.ignore_id,
            ).to(device)
        else:
            ys_pad_src = None

        return xs_pad, ilens, ys_pad, ys_pad_src


def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
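    # Overall flow of this function:
    #   1. read input/output dimensions from the validation json
    #   2. build the ST model (optionally initialized from pre-trained
    #      ASR/MT modules) and write model.json
    #   3. set up the optimizer, optional apex/amp, batch converter, and
    #      data iterators
    #   4. build the chainer Trainer with evaluation, plotting, snapshot,
    #      and eps/lr-decay extensions, then run it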
""" set_deterministic_pytorch(args) # check cuda availability if not torch.cuda.is_available(): logging.warning("cuda is not available") # get input and output dimension info with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] utts = list(valid_json.keys()) idim = int(valid_json[utts[0]]["input"][0]["shape"][-1]) odim = int(valid_json[utts[0]]["output"][0]["shape"][-1]) logging.info("#input dims : " + str(idim)) logging.info("#output dims: " + str(odim)) # Initialize with pre-trained ASR encoder and MT decoder if args.enc_init is not None or args.dec_init is not None: model = load_trained_modules(idim, odim, args, interface=STInterface) else: model_class = dynamic_import(args.model_module) model = model_class(idim, odim, args) assert isinstance(model, STInterface) total_subsampling_factor = model.get_total_subsampling_factor() # write model config if not os.path.exists(args.outdir): os.makedirs(args.outdir) model_conf = args.outdir + "/model.json" with open(model_conf, "wb") as f: logging.info("writing a model config file to " + model_conf) f.write( json.dumps( (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") ) for key in sorted(vars(args).keys()): logging.info("ARGS: " + key + ": " + str(vars(args)[key])) reporter = model.reporter # check the use of multi-gpu if args.ngpu > 1: if args.batch_size != 0: logging.warning( "batch size is automatically increased (%d -> %d)" % (args.batch_size, args.batch_size * args.ngpu) ) args.batch_size *= args.ngpu # set torch device device = torch.device("cuda" if args.ngpu > 0 else "cpu") if args.train_dtype in ("float16", "float32", "float64"): dtype = getattr(torch, args.train_dtype) else: dtype = torch.float32 model = model.to(device=device, dtype=dtype) logging.warning( "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), sum(p.numel() for p in model.parameters() if p.requires_grad) * 100.0 / sum(p.numel() for p in model.parameters()), ) ) # Setup an optimizer if args.opt == "adadelta": optimizer = torch.optim.Adadelta( model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay ) elif args.opt == "adam": optimizer = torch.optim.Adam( model.parameters(), lr=args.lr, weight_decay=args.weight_decay ) elif args.opt == "noam": from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt optimizer = get_std_opt( model.parameters(), args.adim, args.transformer_warmup_steps, args.transformer_lr, ) else: raise NotImplementedError("unknown optimizer: " + args.opt) # setup apex.amp if args.train_dtype in ("O0", "O1", "O2", "O3"): try: from apex import amp except ImportError as e: logging.error( f"You need to install apex for --train-dtype {args.train_dtype}. 
" "See https://github.com/NVIDIA/apex#linux" ) raise e if args.opt == "noam": model, optimizer.optimizer = amp.initialize( model, optimizer.optimizer, opt_level=args.train_dtype ) else: model, optimizer = amp.initialize( model, optimizer, opt_level=args.train_dtype ) use_apex = True else: use_apex = False # FIXME: TOO DIRTY HACK setattr(optimizer, "target", reporter) setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) # Setup a converter converter = CustomConverter( subsampling_factor=model.subsample[0], dtype=dtype, use_source_text=args.asr_weight > 0 or args.mt_weight > 0, ) # read json data with open(args.train_json, "rb") as f: train_json = json.load(f)["utts"] with open(args.valid_json, "rb") as f: valid_json = json.load(f)["utts"] use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 # make minibatch list (variable length) train = make_batchset( train_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, shortest_first=use_sortagrad, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=0, ) valid = make_batchset( valid_json, args.batch_size, args.maxlen_in, args.maxlen_out, args.minibatches, min_batch_size=args.ngpu if args.ngpu > 1 else 1, count=args.batch_count, batch_bins=args.batch_bins, batch_frames_in=args.batch_frames_in, batch_frames_out=args.batch_frames_out, batch_frames_inout=args.batch_frames_inout, iaxis=0, oaxis=0, ) load_tr = LoadInputsAndTargets( mode="asr", load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={"train": True}, # Switch the mode of preprocessing ) load_cv = LoadInputsAndTargets( mode="asr", load_output=True, preprocess_conf=args.preprocess_conf, preprocess_args={"train": False}, # Switch the mode of preprocessing ) # hack to make batchsize argument as 1 # actual bathsize is included in a list # default collate function converts numpy array to pytorch tensor # we used an empty collate function instead which returns list train_iter = ChainerDataLoader( dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), batch_size=1, num_workers=args.n_iter_processes, shuffle=not use_sortagrad, collate_fn=lambda x: x[0], ) valid_iter = ChainerDataLoader( dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), batch_size=1, shuffle=False, collate_fn=lambda x: x[0], num_workers=args.n_iter_processes, ) # Set up a trainer updater = CustomUpdater( model, args.grad_clip, {"main": train_iter}, optimizer, device, args.ngpu, args.grad_noise, args.accum_grad, use_apex=use_apex, ) trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) if use_sortagrad: trainer.extend( ShufflingEnabler([train_iter]), trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), ) # Resume from a snapshot if args.resume: logging.info("resumed from %s" % args.resume) torch_resume(args.resume, trainer) # Evaluate the model with the test dataset for each epoch if args.save_interval_iters > 0: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend( CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu) ) # Save attention weight at each epoch if args.num_save_attention > 0: data = sorted( list(valid_json.items())[: args.num_save_attention], key=lambda x: 
int(x[1]["input"][0]["shape"][1]), reverse=True, ) if hasattr(model, "module"): att_vis_fn = model.module.calculate_all_attentions plot_class = model.module.attention_plot_class else: att_vis_fn = model.calculate_all_attentions plot_class = model.attention_plot_class att_reporter = plot_class( att_vis_fn, data, args.outdir + "/att_ws", converter=converter, transform=load_cv, device=device, subsampling_factor=total_subsampling_factor, ) trainer.extend(att_reporter, trigger=(1, "epoch")) else: att_reporter = None # Save CTC prob at each epoch if (args.asr_weight > 0 and args.mtlalpha > 0) and args.num_save_ctc > 0: # NOTE: sort it by output lengths data = sorted( list(valid_json.items())[: args.num_save_ctc], key=lambda x: int(x[1]["output"][0]["shape"][0]), reverse=True, ) if hasattr(model, "module"): ctc_vis_fn = model.module.calculate_all_ctc_probs plot_class = model.module.ctc_plot_class else: ctc_vis_fn = model.calculate_all_ctc_probs plot_class = model.ctc_plot_class ctc_reporter = plot_class( ctc_vis_fn, data, args.outdir + "/ctc_prob", converter=converter, transform=load_cv, device=device, subsampling_factor=total_subsampling_factor, ) trainer.extend(ctc_reporter, trigger=(1, "epoch")) else: ctc_reporter = None # Make a plot for training and validation values trainer.extend( extensions.PlotReport( [ "main/loss", "validation/main/loss", "main/loss_asr", "validation/main/loss_asr", "main/loss_mt", "validation/main/loss_mt", "main/loss_st", "validation/main/loss_st", ], "epoch", file_name="loss.png", ) ) trainer.extend( extensions.PlotReport( [ "main/acc", "validation/main/acc", "main/acc_asr", "validation/main/acc_asr", "main/acc_mt", "validation/main/acc_mt", ], "epoch", file_name="acc.png", ) ) trainer.extend( extensions.PlotReport( ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png" ) ) # Save best models trainer.extend( snapshot_object(model, "model.loss.best"), trigger=training.triggers.MinValueTrigger("validation/main/loss"), ) trainer.extend( snapshot_object(model, "model.acc.best"), trigger=training.triggers.MaxValueTrigger("validation/main/acc"), ) # save snapshot which contains model and optimizer states if args.save_interval_iters > 0: trainer.extend( torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), trigger=(args.save_interval_iters, "iteration"), ) else: trainer.extend(torch_snapshot(), trigger=(1, "epoch")) # epsilon decay in the optimizer if args.opt == "adadelta": if args.criterion == "acc": trainer.extend( restore_snapshot( model, args.outdir + "/model.acc.best", load_fn=torch_load ), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, ), ) elif args.criterion == "loss": trainer.extend( restore_snapshot( model, args.outdir + "/model.loss.best", load_fn=torch_load ), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) trainer.extend( adadelta_eps_decay(args.eps_decay), trigger=CompareValueTrigger( "validation/main/loss", lambda best_value, current_value: best_value < current_value, ), ) elif args.opt == "adam": if args.criterion == "acc": trainer.extend( restore_snapshot( model, args.outdir + "/model.acc.best", load_fn=torch_load ), trigger=CompareValueTrigger( "validation/main/acc", lambda best_value, current_value: best_value > current_value, 
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.acc.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(
                    model, args.outdir + "/model.loss.best", load_fn=torch_load
                ),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_st",
        "main/loss_asr",
        "validation/main/loss",
        "validation/main/loss_st",
        "validation/main/loss_asr",
        "main/acc",
        "validation/main/acc",
    ]
    if args.asr_weight > 0:
        report_keys.append("main/acc_asr")
        report_keys.append("validation/main/acc_asr")
    report_keys += ["elapsed_time"]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "eps"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
                    "lr"
                ],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.asr_weight > 0:
        if args.mtlalpha > 0:
            report_keys.append("main/cer_ctc")
            report_keys.append("validation/main/cer_ctc")
        if args.mtlalpha < 1:
            if args.report_cer:
                report_keys.append("validation/main/cer")
            if args.report_wer:
                report_keys.append("validation/main/wer")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        trainer.extend(
            TensorboardLogger(
                SummaryWriter(args.tensorboard_dir),
                att_reporter=att_reporter,
                ctc_reporter=ctc_reporter,
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)


def trans(args):
    """Decode with the given args.

    Args:
        args (namespace): The program arguments.

    """
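    # Overall flow of this function:
    #   1. load the trained ST model and its training config
    #   2. read the decoding json and translate each utterance, either one by
    #      one (args.batchsize == 0) or in length-sorted minibatches
    #   3. attach the n-best hypotheses to the json and write args.result_label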
""" set_deterministic_pytorch(args) model, train_args = load_trained_model(args.model) assert isinstance(model, STInterface) model.trans_args = args # gpu if args.ngpu == 1: gpu_id = list(range(args.ngpu)) logging.info("gpu id: " + str(gpu_id)) model.cuda() # read json data with open(args.trans_json, "rb") as f: js = json.load(f)["utts"] new_js = {} load_inputs_and_targets = LoadInputsAndTargets( mode="asr", load_output=False, sort_in_input_length=False, preprocess_conf=train_args.preprocess_conf if args.preprocess_conf is None else args.preprocess_conf, preprocess_args={"train": False}, ) if args.batchsize == 0: with torch.no_grad(): for idx, name in enumerate(js.keys(), 1): logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) batch = [(name, js[name])] feat = load_inputs_and_targets(batch)[0][0] nbest_hyps = model.translate( feat, args, train_args.char_list, ) new_js[name] = add_results_to_json( js[name], nbest_hyps, train_args.char_list ) else: def grouper(n, iterable, fillvalue=None): kargs = [iter(iterable)] * n return zip_longest(*kargs, fillvalue=fillvalue) # sort data if batchsize > 1 keys = list(js.keys()) if args.batchsize > 1: feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) keys = [keys[i] for i in sorted_index] with torch.no_grad(): for names in grouper(args.batchsize, keys, None): names = [name for name in names if name] batch = [(name, js[name]) for name in names] feats = load_inputs_and_targets(batch)[0] nbest_hyps = model.translate_batch( feats, args, train_args.char_list, ) for i, nbest_hyp in enumerate(nbest_hyps): name = names[i] new_js[name] = add_results_to_json( js[name], nbest_hyp, train_args.char_list ) with open(args.result_label, "wb") as f: f.write( json.dumps( {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True ).encode("utf_8") )