diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..6769e21d99a63338394e47bc4c7d0aba1e88d5a5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..841cebf5f194f39a122794c35ac6a460dbf408ca --- /dev/null +++ b/app.py @@ -0,0 +1,180 @@ +from espnet2.bin.tts_inference import Text2Speech +import torch +from parallel_wavegan.utils import download_pretrained_model, load_model +from phonemizer import phonemize +from phonemizer.separator import Separator +import gradio as gr + +s = Separator(word=None, phone=" ") +config_path = "config.yaml" +model_path = "model.pth" + +vocoder_tag = "ljspeech_parallel_wavegan.v3" + +vocoder = load_model(download_pretrained_model(vocoder_tag)).to("cpu").eval() +vocoder.remove_weight_norm() + +global_styles = { + "Style 1": torch.load("style1.pt"), + "Style 2": torch.load("style2.pt"), + "Style 3": torch.load("style3.pt"), + "Style 4": torch.load("style4.pt"), + "Style 5": torch.load("style5.pt"), + "Style 6": torch.load("style6.pt"), +} + + +def inference(text, global_style, alpha, prev_fg_inds, input_fg_inds): + with torch.no_grad(): + text2speech = Text2Speech( + config_path, + model_path, + device="cpu", + # Only for Tacotron 2 + threshold=0.5, + minlenratio=0.0, + maxlenratio=10.0, + use_att_constraint=False, + backward_window=1, + forward_window=3, + # Only for FastSpeech & FastSpeech2 + speed_control_alpha=alpha, + ) + text2speech.spc2wav = None # Disable griffin-lim + + style_emb = torch.flatten(global_styles[global_style]) + + phoneme_string = phonemize( + text, language="mb-us1", backend="espeak-mbrola", separator=s + ) + phonemes = phoneme_string.split(" ") + + max_edit_index = -1 + for i in range(len(input_fg_inds) - 1, -1, -1): + if input_fg_inds[i] != "": + max_edit_index = i + break + + if max_edit_index == -1: + _, c, _, _, _, _, _, output_fg_inds = text2speech( + phoneme_string, ref_embs=style_emb + ) + + else: + input_fg_inds_int_list = [] + for i in range(max_edit_index + 1): + if input_fg_inds[i] != "": + input_fg_inds_int_list.append(int(input_fg_inds[i])) + else: + input_fg_inds_int_list.append(prev_fg_inds[i][1]) + input_fg_inds = input_fg_inds_int_list + + prev_fg_inds_list = [[[row[1], row[2], row[3]] for row in prev_fg_inds]] + prev_fg_inds = torch.tensor(prev_fg_inds_list, dtype=torch.int64) + + fg_inds = torch.tensor(input_fg_inds_int_list).unsqueeze(0) + _, c, _, _, _, _, _, part_output_fg_inds = text2speech( + phoneme_string, ref_embs=style_emb, fg_inds=fg_inds + ) + + prev_fg_inds[0, max_edit_index + 1 :, :] = part_output_fg_inds[0] + output_fg_inds = prev_fg_inds + + output_fg_inds_list = output_fg_inds.tolist()[0] + padded_phonemes = ["", *phonemes] + dataframe_values = [ + [phoneme, *fgs] + for phoneme, fgs in zip(padded_phonemes, output_fg_inds_list) + ] + selected_inds = [ + [input_fg_inds[i]] if i < len(input_fg_inds) else [""] + for i in range(len(padded_phonemes)) + ] + wav = vocoder.inference(c) + + return [ + (22050, wav.view(-1).cpu().numpy()), + dataframe_values, + selected_inds, + ] + + +demo = gr.Blocks() + +with demo: + gr.Markdown( + """ + + # ConEx Demo + + This demo shows the capabilities of ConEx, a model for **Con**trollable **Ex**pressive speech synthesis. + ConEx allows you to generate speech in a certain speaking style, and gives you the ability to edit the prosody* of the generated speech at a fine level. 
+ We proposed ConEx in our paper titled ["Interactive Multi-Level Prosody Control for Expressive Speech Synthesis"](https://jessa.github.io/assets/pdf/cornille2022icassp.pdf), published in the proceedings of the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP) 2022. + + To convert text to speech: input some text, choose the desired speaking style, set the duration factor (higher = slower speech), and press "Generate speech". + + **prosody refers to speech characteristics such as intonation, stress, and rhythm* + """ + ) + + with gr.Row(): + text_input = gr.Textbox( + label="Input text", + lines=4, + placeholder="E.g. I didn't say he stole the money", + ) + + with gr.Column(): + global_style_dropdown = gr.Dropdown( + ["Style 1", "Style 2", "Style 3", "Style 4", "Style 5", "Style 6"], + value="Style 1", + label="Global speaking style", + ) + alpha_slider = gr.Slider( + 0.1, 2, value=1, step=0.1, label="Alpha (duration factor)" + ) + + audio = gr.Audio() + with gr.Row(): + button = gr.Button("Generate Speech") + + gr.Markdown( + """ + + ### Fine-grained prosody editor + Once you've generated some speech, the following table will show the id of the prosody embedding used for each phoneme. + A prosody embedding determines the prosody of the phoneme. + The table not only shows the prosody embeddings that are used by default (the top predictions), but also the next two most likely prosody embeddings. + + In order to change the prosody of a phoneme, write a new prosody embedding id in the "Chosen prosody embeddings" column and press "Generate speech" again. + You can use any number from 0-31, but the 2nd and 3rd predictions are more likely to give a fitting prosody. + Based on your edit, new prosody embeddings will be generated for the phonemes after the edit. + Thus, you can iteratively change the prosody by starting from the beginning of the utterance and working your way through the utterance, making edits as you see fit. + The prosody embeddings before your edit will remain the same as before, and will be copied to the "Chosen prosody embeddings" column. 
+ """ + ) + + with gr.Row(): + phoneme_preds_df = gr.Dataframe( + headers=["Phoneme", "🥇 Top pred", "🥈 2nd pred", "🥉 3rd pred"], + type="array", + col_count=(4, "static"), + ) + phoneme_edits_df = gr.Dataframe( + headers=["Chosen prosody embeddings"], type="array", col_count=(1, "static") + ) + + button.click( + inference, + inputs=[ + text_input, + global_style_dropdown, + alpha_slider, + phoneme_preds_df, + phoneme_edits_df, + ], + outputs=[audio, phoneme_preds_df, phoneme_edits_df], + ) + + +demo.launch() diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bef55a0734efd38898557463e390d9676c7fc40f --- /dev/null +++ b/config.yaml @@ -0,0 +1,266 @@ +config: conf/ar_prior_train.yaml +print_config: false +log_level: INFO +dry_run: false +iterator_type: sequence +output_dir: exp/tts_finetune_ar_prior +ngpu: 1 +seed: 0 +num_workers: 1 +num_att_plot: 3 +dist_backend: nccl +dist_init_method: env:// +dist_world_size: null +dist_rank: null +local_rank: 0 +dist_master_addr: null +dist_master_port: null +dist_launcher: null +multiprocessing_distributed: false +unused_parameters: false +sharded_ddp: false +cudnn_enabled: true +cudnn_benchmark: false +cudnn_deterministic: true +collect_stats: false +write_collected_feats: false +max_epoch: 500 +patience: null +val_scheduler_criterion: +- valid +- loss +early_stopping_criterion: +- valid +- loss +- min +best_model_criterion: +- - valid + - loss + - min +- - train + - loss + - min +keep_nbest_models: 5 +grad_clip: 1.0 +grad_clip_type: 2.0 +grad_noise: false +accum_grad: 8 +no_forward_run: false +resume: true +train_dtype: float32 +use_amp: false +log_interval: null +use_tensorboard: true +use_wandb: false +wandb_project: null +wandb_id: null +detect_anomaly: false +pretrain_path: null +init_param: +- /data/leuven/339/vsc33942/espnet-mirror/egs2/acapela_blizzard/tts1/exp/tts_train_raw_phn_none/valid.loss.best.pth:::tts.prosody_encoder.ar_prior +freeze_param: +- encoder.,prosody_encoder.ref_encoder.,prosody_encoder.fg_encoder.,prosody_encoder.global_encoder.,prosody_encoder.global_projection.,prosody_encoder.vq_layer.,prosody_encoder.qfg_projection,duration_predictor.,length_regulator,decoder.,feat_out,postnet +num_iters_per_epoch: 50 +batch_size: 20 +valid_batch_size: null +batch_bins: 3000000 +valid_batch_bins: null +train_shape_file: +- exp/tts_stats_raw_phn_none/train/text_shape.phn +- exp/tts_stats_raw_phn_none/train/speech_shape +valid_shape_file: +- exp/tts_stats_raw_phn_none/valid/text_shape.phn +- exp/tts_stats_raw_phn_none/valid/speech_shape +batch_type: numel +valid_batch_type: null +fold_length: +- 150 +- 204800 +sort_in_batch: descending +sort_batch: descending +multiple_iterator: false +chunk_length: 500 +chunk_shift_ratio: 0.5 +num_cache_chunks: 1024 +train_data_path_and_name_and_type: +- - dump/raw/tr_no_dev/text + - text + - text +- - data/durations/tr_no_dev/durations + - durations + - text_int +- - dump/raw/tr_no_dev/wav.scp + - speech + - sound +valid_data_path_and_name_and_type: +- - dump/raw/dev/text + - text + - text +- - data/durations/dev/durations + - durations + - text_int +- - dump/raw/dev/wav.scp + - speech + - sound +allow_variable_data_keys: false +max_cache_size: 0.0 +max_cache_fd: 32 +valid_max_cache_size: null +optim: adam +optim_conf: + lr: 1.0 +scheduler: noamlr +scheduler_conf: + model_size: 384 + warmup_steps: 4000 +token_list: +- +- +- n +- '@' +- t +- _ +- s +- I +- r +- d +- l +- m +- i +- '{' +- z +- D +- w +- r= +- f +- v +- E1 +- b +- t_h +- h +- V +- u +- 
k +- I1 +- '{1' +- k_h +- N +- EI1 +- V1 +- O1 +- AI +- H +- S +- p_h +- '@U1' +- i1 +- g +- AI1 +- j +- O +- p +- u1 +- r=1 +- tS +- Or +- '4' +- A +- Or1 +- E +- dZ +- T +- aU1 +- U +- Er1 +- '@U' +- U1 +- Ar1 +- Er +- aU +- EI +- ir1 +- l= +- OI1 +- Ar +- Ur1 +- n= +- A1 +- Z +- '?' +- ir +- Ur +- OI +- +odim: null +model_conf: {} +use_preprocessor: true +token_type: phn +bpemodel: null +non_linguistic_symbols: null +cleaner: null +g2p: null +feats_extract: fbank +feats_extract_conf: + fs: 22050 + fmin: 80 + fmax: 7600 + n_mels: 80 + hop_length: 256 + n_fft: 1024 + win_length: null +normalize: global_mvn +normalize_conf: + stats_file: feats_stats.npz +tts: fastespeech +tts_conf: + adim: 128 + aheads: 2 + elayers: 4 + eunits: 1536 + dlayers: 4 + dunits: 1536 + positionwise_layer_type: conv1d + positionwise_conv_kernel_size: 3 + duration_predictor_layers: 2 + duration_predictor_chans: 128 + duration_predictor_kernel_size: 3 + duration_predictor_dropout_rate: 0.2 + postnet_layers: 5 + postnet_filts: 5 + postnet_chans: 256 + use_masking: true + use_scaled_pos_enc: true + encoder_normalize_before: true + decoder_normalize_before: true + reduction_factor: 1 + init_type: xavier_uniform + init_enc_alpha: 1.0 + init_dec_alpha: 1.0 + transformer_enc_dropout_rate: 0.2 + transformer_enc_positional_dropout_rate: 0.2 + transformer_enc_attn_dropout_rate: 0.2 + transformer_dec_dropout_rate: 0.2 + transformer_dec_positional_dropout_rate: 0.2 + transformer_dec_attn_dropout_rate: 0.2 + ref_enc_conv_layers: 2 + ref_enc_conv_kernel_size: 3 + ref_enc_conv_stride: 2 + ref_enc_gru_layers: 1 + ref_enc_gru_units: 32 + ref_emb_integration_type: add + prosody_num_embs: 32 + prosody_hidden_dim: 3 + prosody_emb_integration_type: add +pitch_extract: null +pitch_extract_conf: {} +pitch_normalize: null +pitch_normalize_conf: {} +energy_extract: null +energy_extract_conf: {} +energy_normalize: null +energy_normalize_conf: {} +required: +- output_dir +- token_list +version: 0.9.9 +distributed: false diff --git a/espnet/__init__.py b/espnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e230942d4039bc0310b7c44d9306091d8bd75abb --- /dev/null +++ b/espnet/__init__.py @@ -0,0 +1,8 @@ +"""Initialize espnet package.""" + +import os + +dirname = os.path.dirname(__file__) +version_file = os.path.join(dirname, "version.txt") +with open(version_file, "r") as f: + __version__ = f.read().strip() diff --git a/espnet/asr/__init__.py b/espnet/asr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/asr/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/asr/asr_mix_utils.py b/espnet/asr/asr_mix_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7a84f54cbed7c3fe701d8b85f12522d2695b42 --- /dev/null +++ b/espnet/asr/asr_mix_utils.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 + +""" +This script is used to provide utility functions designed for multi-speaker ASR. + +Copyright 2017 Johns Hopkins University (Shinji Watanabe) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +Most functions can be directly used as in asr_utils.py: + CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load, + torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf. 
+ +""" + +import copy +import logging +import os + +from chainer.training import extension + +import matplotlib + +from espnet.asr.asr_utils import parse_hypothesis + + +matplotlib.use("Agg") + + +# * -------------------- chainer extension related -------------------- * +class PlotAttentionReport(extension.Extension): + """Plot attention reporter. + + Args: + att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions): + Function of attention visualization. + data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items. + outdir (str): Directory to save figures. + converter (espnet.asr.*_backend.asr.CustomConverter): + CustomConverter object. Function to convert data. + device (torch.device): The destination device to send tensor. + reverse (bool): If True, input and output length are reversed. + + """ + + def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False): + """Initialize PlotAttentionReport.""" + self.att_vis_fn = att_vis_fn + self.data = copy.deepcopy(data) + self.outdir = outdir + self.converter = converter + self.device = device + self.reverse = reverse + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + def __call__(self, trainer): + """Plot and save imaged matrix of att_ws.""" + att_ws_sd = self.get_attention_weights() + for ns, att_ws in enumerate(att_ws_sd): + for idx, att_w in enumerate(att_ws): + filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % ( + self.outdir, + self.data[idx][0], + ns + 1, + ) + att_w = self.get_attention_weight(idx, att_w, ns) + self._plot_and_save_attention(att_w, filename.format(trainer)) + + def log_attentions(self, logger, step): + """Add image files of attention matrix to tensorboard.""" + att_ws_sd = self.get_attention_weights() + for ns, att_ws in enumerate(att_ws_sd): + for idx, att_w in enumerate(att_ws): + att_w = self.get_attention_weight(idx, att_w, ns) + plot = self.draw_attention_plot(att_w) + logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step) + plot.clf() + + def get_attention_weights(self): + """Return attention weights. + + Returns: + arr_ws_sd (numpy.ndarray): attention weights. It's shape would be + differ from bachend.dtype=float + * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax). 2) + other case => (B, Lmax, Tmax). + * chainer-> attention weights (B, Lmax, Tmax). + + """ + batch = self.converter([self.converter.transform(self.data)], self.device) + att_ws_sd = self.att_vis_fn(*batch) + return att_ws_sd + + def get_attention_weight(self, idx, att_w, spkr_idx): + """Transform attention weight in regard to self.reverse.""" + if self.reverse: + dec_len = int(self.data[idx][1]["input"][0]["shape"][0]) + enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0]) + else: + dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0]) + enc_len = int(self.data[idx][1]["input"][0]["shape"][0]) + if len(att_w.shape) == 3: + att_w = att_w[:, :dec_len, :enc_len] + else: + att_w = att_w[:dec_len, :enc_len] + return att_w + + def draw_attention_plot(self, att_w): + """Visualize attention weights matrix. + + Args: + att_w(Tensor): Attention weight matrix. + + Returns: + matplotlib.pyplot: pyplot object with attention matrix image. 
+ + """ + import matplotlib.pyplot as plt + + if len(att_w.shape) == 3: + for h, aw in enumerate(att_w, 1): + plt.subplot(1, len(att_w), h) + plt.imshow(aw, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + else: + plt.imshow(att_w, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + plt.tight_layout() + return plt + + def _plot_and_save_attention(self, att_w, filename): + plt = self.draw_attention_plot(att_w) + plt.savefig(filename) + plt.close() + + +def add_results_to_json(js, nbest_hyps_sd, char_list): + """Add N-best results to json. + + Args: + js (dict[str, Any]): Groundtruth utterance dict. + nbest_hyps_sd (list[dict[str, Any]]): + List of hypothesis for multi_speakers (# Utts x # Spkrs). + char_list (list[str]): List of characters. + + Returns: + dict[str, Any]: N-best results added utterance dict. + + """ + # copy old json info + new_js = dict() + new_js["utt2spk"] = js["utt2spk"] + num_spkrs = len(nbest_hyps_sd) + new_js["output"] = [] + + for ns in range(num_spkrs): + tmp_js = [] + nbest_hyps = nbest_hyps_sd[ns] + + for n, hyp in enumerate(nbest_hyps, 1): + # parse hypothesis + rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list) + + # copy ground-truth + out_dic = dict(js["output"][ns].items()) + + # update name + out_dic["name"] += "[%d]" % n + + # add recognition results + out_dic["rec_text"] = rec_text + out_dic["rec_token"] = rec_token + out_dic["rec_tokenid"] = rec_tokenid + out_dic["score"] = score + + # add to list of N-best result dicts + tmp_js.append(out_dic) + + # show 1-best result + if n == 1: + logging.info("groundtruth: %s" % out_dic["text"]) + logging.info("prediction : %s" % out_dic["rec_text"]) + + new_js["output"].append(tmp_js) + return new_js diff --git a/espnet/asr/asr_utils.py b/espnet/asr/asr_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..99f6c8c84f23d551a0d5cded109c60db874a9521 --- /dev/null +++ b/espnet/asr/asr_utils.py @@ -0,0 +1,1024 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import argparse +import copy +import json +import logging +import os +import shutil +import tempfile + +import numpy as np +import torch + + +# * -------------------- training iterator related -------------------- * + + +class CompareValueTrigger(object): + """Trigger invoked when key value getting bigger or lower than before. + + Args: + key (str) : Key of value. + compare_fn ((float, float) -> bool) : Function to compare the values. + trigger (tuple(int, str)) : Trigger that decide the comparison interval. 
+ + """ + + def __init__(self, key, compare_fn, trigger=(1, "epoch")): + from chainer import training + + self._key = key + self._best_value = None + self._interval_trigger = training.util.get_trigger(trigger) + self._init_summary() + self._compare_fn = compare_fn + + def __call__(self, trainer): + """Get value related to the key and compare with current value.""" + observation = trainer.observation + summary = self._summary + key = self._key + if key in observation: + summary.add({key: observation[key]}) + + if not self._interval_trigger(trainer): + return False + + stats = summary.compute_mean() + value = float(stats[key]) # copy to CPU + self._init_summary() + + if self._best_value is None: + # initialize best value + self._best_value = value + return False + elif self._compare_fn(self._best_value, value): + return True + else: + self._best_value = value + return False + + def _init_summary(self): + import chainer + + self._summary = chainer.reporter.DictSummary() + + +try: + from chainer.training import extension +except ImportError: + PlotAttentionReport = None +else: + + class PlotAttentionReport(extension.Extension): + """Plot attention reporter. + + Args: + att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions): + Function of attention visualization. + data (list[tuple(str, dict[str, list[Any]])]): List json utt key items. + outdir (str): Directory to save figures. + converter (espnet.asr.*_backend.asr.CustomConverter): + Function to convert data. + device (int | torch.device): Device. + reverse (bool): If True, input and output length are reversed. + ikey (str): Key to access input + (for ASR/ST ikey="input", for MT ikey="output".) + iaxis (int): Dimension to access input + (for ASR/ST iaxis=0, for MT iaxis=1.) + okey (str): Key to access output + (for ASR/ST okey="input", MT okay="output".) + oaxis (int): Dimension to access output + (for ASR/ST oaxis=0, for MT oaxis=0.) 
+ subsampling_factor (int): subsampling factor in encoder + + """ + + def __init__( + self, + att_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, + ): + self.att_vis_fn = att_vis_fn + self.data = copy.deepcopy(data) + self.data_dict = {k: v for k, v in copy.deepcopy(data)} + # key is utterance ID + self.outdir = outdir + self.converter = converter + self.transform = transform + self.device = device + self.reverse = reverse + self.ikey = ikey + self.iaxis = iaxis + self.okey = okey + self.oaxis = oaxis + self.factor = subsampling_factor + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + def __call__(self, trainer): + """Plot and save image file of att_ws matrix.""" + att_ws, uttid_list = self.get_attention_weights() + if isinstance(att_ws, list): # multi-encoder case + num_encs = len(att_ws) - 1 + # atts + for i in range(num_encs): + for idx, att_w in enumerate(att_ws[i]): + filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % ( + self.outdir, + uttid_list[idx], + i + 1, + ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % ( + self.outdir, + uttid_list[idx], + i + 1, + ) + np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention(att_w, filename.format(trainer)) + # han + for idx, att_w in enumerate(att_ws[num_encs]): + filename = "%s/%s.ep.{.updater.epoch}.han.png" % ( + self.outdir, + uttid_list[idx], + ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % ( + self.outdir, + uttid_list[idx], + ) + np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention( + att_w, filename.format(trainer), han_mode=True + ) + else: + for idx, att_w in enumerate(att_ws): + filename = "%s/%s.ep.{.updater.epoch}.png" % ( + self.outdir, + uttid_list[idx], + ) + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( + self.outdir, + uttid_list[idx], + ) + np.save(np_filename.format(trainer), att_w) + self._plot_and_save_attention(att_w, filename.format(trainer)) + + def log_attentions(self, logger, step): + """Add image files of att_ws matrix to the tensorboard.""" + att_ws, uttid_list = self.get_attention_weights() + if isinstance(att_ws, list): # multi-encoder case + num_encs = len(att_ws) - 1 + # atts + for i in range(num_encs): + for idx, att_w in enumerate(att_ws[i]): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_attention_plot(att_w) + logger.add_figure( + "%s_att%d" % (uttid_list[idx], i + 1), + plot.gcf(), + step, + ) + # han + for idx, att_w in enumerate(att_ws[num_encs]): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_han_plot(att_w) + logger.add_figure( + "%s_han" % (uttid_list[idx]), + plot.gcf(), + step, + ) + else: + for idx, att_w in enumerate(att_ws): + att_w = self.trim_attention_weight(uttid_list[idx], att_w) + plot = self.draw_attention_plot(att_w) + logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step) + + def get_attention_weights(self): + """Return attention weights. + + Returns: + numpy.ndarray: attention weights. float. Its shape would be + differ from backend. + * pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2) + other case => (B, Lmax, Tmax). 
+ * chainer-> (B, Lmax, Tmax) + + """ + return_batch, uttid_list = self.transform(self.data, return_uttid=True) + batch = self.converter([return_batch], self.device) + if isinstance(batch, tuple): + att_ws = self.att_vis_fn(*batch) + else: + att_ws = self.att_vis_fn(**batch) + return att_ws, uttid_list + + def trim_attention_weight(self, uttid, att_w): + """Transform attention matrix with regard to self.reverse.""" + if self.reverse: + enc_key, enc_axis = self.okey, self.oaxis + dec_key, dec_axis = self.ikey, self.iaxis + else: + enc_key, enc_axis = self.ikey, self.iaxis + dec_key, dec_axis = self.okey, self.oaxis + dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0]) + enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0]) + if self.factor > 1: + enc_len //= self.factor + if len(att_w.shape) == 3: + att_w = att_w[:, :dec_len, :enc_len] + else: + att_w = att_w[:dec_len, :enc_len] + return att_w + + def draw_attention_plot(self, att_w): + """Plot the att_w matrix. + + Returns: + matplotlib.pyplot: pyplot object with attention matrix image. + + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + plt.clf() + att_w = att_w.astype(np.float32) + if len(att_w.shape) == 3: + for h, aw in enumerate(att_w, 1): + plt.subplot(1, len(att_w), h) + plt.imshow(aw, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + else: + plt.imshow(att_w, aspect="auto") + plt.xlabel("Encoder Index") + plt.ylabel("Decoder Index") + plt.tight_layout() + return plt + + def draw_han_plot(self, att_w): + """Plot the att_w matrix for hierarchical attention. + + Returns: + matplotlib.pyplot: pyplot object with attention matrix image. + + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + plt.clf() + if len(att_w.shape) == 3: + for h, aw in enumerate(att_w, 1): + legends = [] + plt.subplot(1, len(att_w), h) + for i in range(aw.shape[1]): + plt.plot(aw[:, i]) + legends.append("Att{}".format(i)) + plt.ylim([0, 1.0]) + plt.xlim([0, aw.shape[0]]) + plt.grid(True) + plt.ylabel("Attention Weight") + plt.xlabel("Decoder Index") + plt.legend(legends) + else: + legends = [] + for i in range(att_w.shape[1]): + plt.plot(att_w[:, i]) + legends.append("Att{}".format(i)) + plt.ylim([0, 1.0]) + plt.xlim([0, att_w.shape[0]]) + plt.grid(True) + plt.ylabel("Attention Weight") + plt.xlabel("Decoder Index") + plt.legend(legends) + plt.tight_layout() + return plt + + def _plot_and_save_attention(self, att_w, filename, han_mode=False): + if han_mode: + plt = self.draw_han_plot(att_w) + else: + plt = self.draw_attention_plot(att_w) + plt.savefig(filename) + plt.close() + + +try: + from chainer.training import extension +except ImportError: + PlotCTCReport = None +else: + + class PlotCTCReport(extension.Extension): + """Plot CTC reporter. + + Args: + ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs): + Function of CTC visualization. + data (list[tuple(str, dict[str, list[Any]])]): List json utt key items. + outdir (str): Directory to save figures. + converter (espnet.asr.*_backend.asr.CustomConverter): + Function to convert data. + device (int | torch.device): Device. + reverse (bool): If True, input and output length are reversed. + ikey (str): Key to access input + (for ASR/ST ikey="input", for MT ikey="output".) + iaxis (int): Dimension to access input + (for ASR/ST iaxis=0, for MT iaxis=1.) + okey (str): Key to access output + (for ASR/ST okey="input", MT okay="output".) 
+ oaxis (int): Dimension to access output + (for ASR/ST oaxis=0, for MT oaxis=0.) + subsampling_factor (int): subsampling factor in encoder + + """ + + def __init__( + self, + ctc_vis_fn, + data, + outdir, + converter, + transform, + device, + reverse=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=1, + ): + self.ctc_vis_fn = ctc_vis_fn + self.data = copy.deepcopy(data) + self.data_dict = {k: v for k, v in copy.deepcopy(data)} + # key is utterance ID + self.outdir = outdir + self.converter = converter + self.transform = transform + self.device = device + self.reverse = reverse + self.ikey = ikey + self.iaxis = iaxis + self.okey = okey + self.oaxis = oaxis + self.factor = subsampling_factor + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + def __call__(self, trainer): + """Plot and save image file of ctc prob.""" + ctc_probs, uttid_list = self.get_ctc_probs() + if isinstance(ctc_probs, list): # multi-encoder case + num_encs = len(ctc_probs) - 1 + for i in range(num_encs): + for idx, ctc_prob in enumerate(ctc_probs[i]): + filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % ( + self.outdir, + uttid_list[idx], + i + 1, + ) + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % ( + self.outdir, + uttid_list[idx], + i + 1, + ) + np.save(np_filename.format(trainer), ctc_prob) + self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) + else: + for idx, ctc_prob in enumerate(ctc_probs): + filename = "%s/%s.ep.{.updater.epoch}.png" % ( + self.outdir, + uttid_list[idx], + ) + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + np_filename = "%s/%s.ep.{.updater.epoch}.npy" % ( + self.outdir, + uttid_list[idx], + ) + np.save(np_filename.format(trainer), ctc_prob) + self._plot_and_save_ctc(ctc_prob, filename.format(trainer)) + + def log_ctc_probs(self, logger, step): + """Add image files of ctc probs to the tensorboard.""" + ctc_probs, uttid_list = self.get_ctc_probs() + if isinstance(ctc_probs, list): # multi-encoder case + num_encs = len(ctc_probs) - 1 + for i in range(num_encs): + for idx, ctc_prob in enumerate(ctc_probs[i]): + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + plot = self.draw_ctc_plot(ctc_prob) + logger.add_figure( + "%s_ctc%d" % (uttid_list[idx], i + 1), + plot.gcf(), + step, + ) + else: + for idx, ctc_prob in enumerate(ctc_probs): + ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob) + plot = self.draw_ctc_plot(ctc_prob) + logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step) + + def get_ctc_probs(self): + """Return CTC probs. + + Returns: + numpy.ndarray: CTC probs. float. Its shape would be + differ from backend. (B, Tmax, vocab). + + """ + return_batch, uttid_list = self.transform(self.data, return_uttid=True) + batch = self.converter([return_batch], self.device) + if isinstance(batch, tuple): + probs = self.ctc_vis_fn(*batch) + else: + probs = self.ctc_vis_fn(**batch) + return probs, uttid_list + + def trim_ctc_prob(self, uttid, prob): + """Trim CTC posteriors accoding to input lengths.""" + enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0]) + if self.factor > 1: + enc_len //= self.factor + prob = prob[:enc_len] + return prob + + def draw_ctc_plot(self, ctc_prob): + """Plot the ctc_prob matrix. + + Returns: + matplotlib.pyplot: pyplot object with CTC prob matrix image. 
+ + """ + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + ctc_prob = ctc_prob.astype(np.float32) + + plt.clf() + topk_ids = np.argsort(ctc_prob, axis=1) + n_frames, vocab = ctc_prob.shape + times_probs = np.arange(n_frames) + + plt.figure(figsize=(20, 8)) + + # NOTE: index 0 is reserved for blank + for idx in set(topk_ids.reshape(-1).tolist()): + if idx == 0: + plt.plot( + times_probs, ctc_prob[:, 0], ":", label="", color="grey" + ) + else: + plt.plot(times_probs, ctc_prob[:, idx]) + plt.xlabel(u"Input [frame]", fontsize=12) + plt.ylabel("Posteriors", fontsize=12) + plt.xticks(list(range(0, int(n_frames) + 1, 10))) + plt.yticks(list(range(0, 2, 1))) + plt.tight_layout() + return plt + + def _plot_and_save_ctc(self, ctc_prob, filename): + plt = self.draw_ctc_plot(ctc_prob) + plt.savefig(filename) + plt.close() + + +def restore_snapshot(model, snapshot, load_fn=None): + """Extension to restore snapshot. + + Returns: + An extension function. + + """ + import chainer + from chainer import training + + if load_fn is None: + load_fn = chainer.serializers.load_npz + + @training.make_extension(trigger=(1, "epoch")) + def restore_snapshot(trainer): + _restore_snapshot(model, snapshot, load_fn) + + return restore_snapshot + + +def _restore_snapshot(model, snapshot, load_fn=None): + if load_fn is None: + import chainer + + load_fn = chainer.serializers.load_npz + + load_fn(snapshot, model) + logging.info("restored from " + str(snapshot)) + + +def adadelta_eps_decay(eps_decay): + """Extension to perform adadelta eps decay. + + Args: + eps_decay (float): Decay rate of eps. + + Returns: + An extension function. + + """ + from chainer import training + + @training.make_extension(trigger=(1, "epoch")) + def adadelta_eps_decay(trainer): + _adadelta_eps_decay(trainer, eps_decay) + + return adadelta_eps_decay + + +def _adadelta_eps_decay(trainer, eps_decay): + optimizer = trainer.updater.get_optimizer("main") + # for chainer + if hasattr(optimizer, "eps"): + current_eps = optimizer.eps + setattr(optimizer, "eps", current_eps * eps_decay) + logging.info("adadelta eps decayed to " + str(optimizer.eps)) + # pytorch + else: + for p in optimizer.param_groups: + p["eps"] *= eps_decay + logging.info("adadelta eps decayed to " + str(p["eps"])) + + +def adam_lr_decay(eps_decay): + """Extension to perform adam lr decay. + + Args: + eps_decay (float): Decay rate of lr. + + Returns: + An extension function. + + """ + from chainer import training + + @training.make_extension(trigger=(1, "epoch")) + def adam_lr_decay(trainer): + _adam_lr_decay(trainer, eps_decay) + + return adam_lr_decay + + +def _adam_lr_decay(trainer, eps_decay): + optimizer = trainer.updater.get_optimizer("main") + # for chainer + if hasattr(optimizer, "lr"): + current_lr = optimizer.lr + setattr(optimizer, "lr", current_lr * eps_decay) + logging.info("adam lr decayed to " + str(optimizer.lr)) + # pytorch + else: + for p in optimizer.param_groups: + p["lr"] *= eps_decay + logging.info("adam lr decayed to " + str(p["lr"])) + + +def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"): + """Extension to take snapshot of the trainer for pytorch. + + Returns: + An extension function. 
+ + """ + from chainer.training import extension + + @extension.make_extension(trigger=(1, "epoch"), priority=-100) + def torch_snapshot(trainer): + _torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun) + + return torch_snapshot + + +def _torch_snapshot_object(trainer, target, filename, savefun): + from chainer.serializers import DictionarySerializer + + # make snapshot_dict dictionary + s = DictionarySerializer() + s.save(trainer) + if hasattr(trainer.updater.model, "model"): + # (for TTS) + if hasattr(trainer.updater.model.model, "module"): + model_state_dict = trainer.updater.model.model.module.state_dict() + else: + model_state_dict = trainer.updater.model.model.state_dict() + else: + # (for ASR) + if hasattr(trainer.updater.model, "module"): + model_state_dict = trainer.updater.model.module.state_dict() + else: + model_state_dict = trainer.updater.model.state_dict() + snapshot_dict = { + "trainer": s.target, + "model": model_state_dict, + "optimizer": trainer.updater.get_optimizer("main").state_dict(), + } + + # save snapshot dictionary + fn = filename.format(trainer) + prefix = "tmp" + fn + tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out) + tmppath = os.path.join(tmpdir, fn) + try: + savefun(snapshot_dict, tmppath) + shutil.move(tmppath, os.path.join(trainer.out, fn)) + finally: + shutil.rmtree(tmpdir) + + +def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55): + """Adds noise from a standard normal distribution to the gradients. + + The standard deviation (`sigma`) is controlled by the three hyper-parameters below. + `sigma` goes to zero (no noise) with more iterations. + + Args: + model (torch.nn.model): Model. + iteration (int): Number of iterations. + duration (int) {100, 1000}: + Number of durations to control the interval of the `sigma` change. + eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`. + scale_factor (float) {0.55}: The scale of `sigma`. + """ + interval = (iteration // duration) + 1 + sigma = eta / interval ** scale_factor + for param in model.parameters(): + if param.grad is not None: + _shape = param.grad.size() + noise = sigma * torch.randn(_shape).to(param.device) + param.grad += noise + + +# * -------------------- general -------------------- * +def get_model_conf(model_path, conf_path=None): + """Get model config information by reading a model config file (model.json). + + Args: + model_path (str): Model path. + conf_path (str): Optional model config path. + + Returns: + list[int, int, dict[str, Any]]: Config information loaded from json file. + + """ + if conf_path is None: + model_conf = os.path.dirname(model_path) + "/model.json" + else: + model_conf = conf_path + with open(model_conf, "rb") as f: + logging.info("reading a config file from " + model_conf) + confs = json.load(f) + if isinstance(confs, dict): + # for lm + args = confs + return argparse.Namespace(**args) + else: + # for asr, tts, mt + idim, odim, args = confs + return idim, odim, argparse.Namespace(**args) + + +def chainer_load(path, model): + """Load chainer model parameters. + + Args: + path (str): Model path or snapshot file path to be loaded. + model (chainer.Chain): Chainer model. + + """ + import chainer + + if "snapshot" in os.path.basename(path): + chainer.serializers.load_npz(path, model, path="updater/model:main/") + else: + chainer.serializers.load_npz(path, model) + + +def torch_save(path, model): + """Save torch model states. + + Args: + path (str): Model path to be saved. + model (torch.nn.Module): Torch model. 
+ + """ + if hasattr(model, "module"): + torch.save(model.module.state_dict(), path) + else: + torch.save(model.state_dict(), path) + + +def snapshot_object(target, filename): + """Returns a trainer extension to take snapshots of a given object. + + Args: + target (model): Object to serialize. + filename (str): Name of the file into which the object is serialized.It can + be a format string, where the trainer object is passed to + the :meth: `str.format` method. For example, + ``'snapshot_{.updater.iteration}'`` is converted to + ``'snapshot_10000'`` at the 10,000th iteration. + + Returns: + An extension function. + + """ + from chainer.training import extension + + @extension.make_extension(trigger=(1, "epoch"), priority=-100) + def snapshot_object(trainer): + torch_save(os.path.join(trainer.out, filename.format(trainer)), target) + + return snapshot_object + + +def torch_load(path, model): + """Load torch model states. + + Args: + path (str): Model path or snapshot file path to be loaded. + model (torch.nn.Module): Torch model. + + """ + if "snapshot" in os.path.basename(path): + model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[ + "model" + ] + else: + model_state_dict = torch.load(path, map_location=lambda storage, loc: storage) + + if hasattr(model, "module"): + model.module.load_state_dict(model_state_dict) + else: + model.load_state_dict(model_state_dict) + + del model_state_dict + + +def torch_resume(snapshot_path, trainer): + """Resume from snapshot for pytorch. + + Args: + snapshot_path (str): Snapshot file path. + trainer (chainer.training.Trainer): Chainer's trainer instance. + + """ + from chainer.serializers import NpzDeserializer + + # load snapshot + snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage) + + # restore trainer states + d = NpzDeserializer(snapshot_dict["trainer"]) + d.load(trainer) + + # restore model states + if hasattr(trainer.updater.model, "model"): + # (for TTS model) + if hasattr(trainer.updater.model.model, "module"): + trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"]) + else: + trainer.updater.model.model.load_state_dict(snapshot_dict["model"]) + else: + # (for ASR model) + if hasattr(trainer.updater.model, "module"): + trainer.updater.model.module.load_state_dict(snapshot_dict["model"]) + else: + trainer.updater.model.load_state_dict(snapshot_dict["model"]) + + # retore optimizer states + trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"]) + + # delete opened snapshot + del snapshot_dict + + +# * ------------------ recognition related ------------------ * +def parse_hypothesis(hyp, char_list): + """Parse hypothesis. + + Args: + hyp (list[dict[str, Any]]): Recognition hypothesis. + char_list (list[str]): List of characters. + + Returns: + tuple(str, str, str, float) + + """ + # remove sos and get results + tokenid_as_list = list(map(int, hyp["yseq"][1:])) + token_as_list = [char_list[idx] for idx in tokenid_as_list] + score = float(hyp["score"]) + + # convert to string + tokenid = " ".join([str(idx) for idx in tokenid_as_list]) + token = " ".join(token_as_list) + text = "".join(token_as_list).replace("", " ") + + return text, token, tokenid, score + + +def add_results_to_json(js, nbest_hyps, char_list): + """Add N-best results to json. + + Args: + js (dict[str, Any]): Groundtruth utterance dict. + nbest_hyps_sd (list[dict[str, Any]]): + List of hypothesis for multi_speakers: nutts x nspkrs. + char_list (list[str]): List of characters. 
+ + Returns: + dict[str, Any]: N-best results added utterance dict. + + """ + # copy old json info + new_js = dict() + new_js["utt2spk"] = js["utt2spk"] + new_js["output"] = [] + + for n, hyp in enumerate(nbest_hyps, 1): + # parse hypothesis + rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list) + + # copy ground-truth + if len(js["output"]) > 0: + out_dic = dict(js["output"][0].items()) + else: + # for no reference case (e.g., speech translation) + out_dic = {"name": ""} + + # update name + out_dic["name"] += "[%d]" % n + + # add recognition results + out_dic["rec_text"] = rec_text + out_dic["rec_token"] = rec_token + out_dic["rec_tokenid"] = rec_tokenid + out_dic["score"] = score + + # add to list of N-best result dicts + new_js["output"].append(out_dic) + + # show 1-best result + if n == 1: + if "text" in out_dic.keys(): + logging.info("groundtruth: %s" % out_dic["text"]) + logging.info("prediction : %s" % out_dic["rec_text"]) + + return new_js + + +def plot_spectrogram( + plt, + spec, + mode="db", + fs=None, + frame_shift=None, + bottom=True, + left=True, + right=True, + top=False, + labelbottom=True, + labelleft=True, + labelright=True, + labeltop=False, + cmap="inferno", +): + """Plot spectrogram using matplotlib. + + Args: + plt (matplotlib.pyplot): pyplot object. + spec (numpy.ndarray): Input stft (Freq, Time) + mode (str): db or linear. + fs (int): Sample frequency. To convert y-axis to kHz unit. + frame_shift (int): The frame shift of stft. To convert x-axis to second unit. + bottom (bool):Whether to draw the respective ticks. + left (bool): + right (bool): + top (bool): + labelbottom (bool):Whether to draw the respective tick labels. + labelleft (bool): + labelright (bool): + labeltop (bool): + cmap (str): Colormap defined in matplotlib. + + """ + spec = np.abs(spec) + if mode == "db": + x = 20 * np.log10(spec + np.finfo(spec.dtype).eps) + elif mode == "linear": + x = spec + else: + raise ValueError(mode) + + if fs is not None: + ytop = fs / 2000 + ylabel = "kHz" + else: + ytop = x.shape[0] + ylabel = "bin" + + if frame_shift is not None and fs is not None: + xtop = x.shape[1] * frame_shift / fs + xlabel = "s" + else: + xtop = x.shape[1] + xlabel = "frame" + + extent = (0, xtop, 0, ytop) + plt.imshow(x[::-1], cmap=cmap, extent=extent) + + if labelbottom: + plt.xlabel("time [{}]".format(xlabel)) + if labelleft: + plt.ylabel("freq [{}]".format(ylabel)) + plt.colorbar().set_label("{}".format(mode)) + + plt.tick_params( + bottom=bottom, + left=left, + right=right, + top=top, + labelbottom=labelbottom, + labelleft=labelleft, + labelright=labelright, + labeltop=labeltop, + ) + plt.axis("auto") + + +# * ------------------ recognition related ------------------ * +def format_mulenc_args(args): + """Format args for multi-encoder setup. + + It deals with following situations: (when args.num_encs=2): + 1. args.elayers = None -> args.elayers = [4, 4]; + 2. args.elayers = 4 -> args.elayers = [4, 4]; + 3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4]. + + """ + # default values when None is assigned. 
+ default_dict = { + "etype": "blstmp", + "elayers": 4, + "eunits": 300, + "subsample": "1", + "dropout_rate": 0.0, + "atype": "dot", + "adim": 320, + "awin": 5, + "aheads": 4, + "aconv_chans": -1, + "aconv_filts": 100, + } + for k in default_dict.keys(): + if isinstance(vars(args)[k], list): + if len(vars(args)[k]) != args.num_encs: + logging.warning( + "Length mismatch {}: Convert {} to {}.".format( + k, vars(args)[k], vars(args)[k][: args.num_encs] + ) + ) + vars(args)[k] = vars(args)[k][: args.num_encs] + else: + if not vars(args)[k]: + # assign default value if it is None + vars(args)[k] = default_dict[k] + logging.warning( + "{} is not specified, use default value {}.".format( + k, default_dict[k] + ) + ) + # duplicate + logging.warning( + "Type mismatch {}: Convert {} to {}.".format( + k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)] + ) + ) + vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)] + return args diff --git a/espnet/asr/chainer_backend/__init__.py b/espnet/asr/chainer_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/asr/chainer_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/asr/chainer_backend/asr.py b/espnet/asr/chainer_backend/asr.py new file mode 100644 index 0000000000000000000000000000000000000000..54b16fc1066d9655ce87dd1166a33f41a107a6e7 --- /dev/null +++ b/espnet/asr/chainer_backend/asr.py @@ -0,0 +1,575 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Training/decoding definition for the speech recognition task.""" + +import json +import logging +import os +import six + +# chainer related +import chainer + +from chainer import training + +from chainer.datasets import TransformDataset +from chainer.training import extensions + +# espnet related +from espnet.asr.asr_utils import adadelta_eps_decay +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import chainer_load +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import restore_snapshot +from espnet.nets.asr_interface import ASRInterface +from espnet.utils.deterministic_utils import set_deterministic_chainer +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.evaluator import BaseEvaluator +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.iterators import ToggleableShufflingMultiprocessIterator +from espnet.utils.training.iterators import ToggleableShufflingSerialIterator +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +# rnnlm +import espnet.lm.chainer_backend.extlm as extlm_chainer +import espnet.lm.chainer_backend.lm as lm_chainer + +# numpy related +import matplotlib + +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from tensorboardX import SummaryWriter + +matplotlib.use("Agg") + + +def train(args): + """Train with the given args. + + Args: + args (namespace): The program arguments. 
+ + """ + # display chainer version + logging.info("chainer version = " + chainer.__version__) + + set_deterministic_chainer(args) + + # check cuda and cudnn availability + if not chainer.cuda.available: + logging.warning("cuda is not available") + if not chainer.cuda.cudnn_enabled: + logging.warning("cudnn is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + idim = int(valid_json[utts[0]]["input"][0]["shape"][1]) + odim = int(valid_json[utts[0]]["output"][0]["shape"][1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # specify attention, CTC, hybrid mode + if args.mtlalpha == 1.0: + mtl_mode = "ctc" + logging.info("Pure CTC mode") + elif args.mtlalpha == 0.0: + mtl_mode = "att" + logging.info("Pure attention mode") + else: + mtl_mode = "mtl" + logging.info("Multitask learning mode") + + # specify model architecture + logging.info("import model module: " + args.model_module) + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args, flag_return=False) + assert isinstance(model, ASRInterface) + total_subsampling_factor = model.get_total_subsampling_factor() + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + # Set gpu + ngpu = args.ngpu + if ngpu == 1: + gpu_id = 0 + # Make a specified GPU current + chainer.cuda.get_device_from_id(gpu_id).use() + model.to_gpu() # Copy the model to the GPU + logging.info("single gpu calculation.") + elif ngpu > 1: + gpu_id = 0 + devices = {"main": gpu_id} + for gid in six.moves.xrange(1, ngpu): + devices["sub_%d" % gid] = gid + logging.info("multi gpu calculation (#gpus = %d)." 
% ngpu) + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + else: + gpu_id = -1 + logging.info("cpu calculation") + + # Setup an optimizer + if args.opt == "adadelta": + optimizer = chainer.optimizers.AdaDelta(eps=args.eps) + elif args.opt == "adam": + optimizer = chainer.optimizers.Adam() + elif args.opt == "noam": + optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9) + else: + raise NotImplementedError("args.opt={}".format(args.opt)) + + optimizer.setup(model) + optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip)) + + # Setup a converter + converter = model.custom_converter(subsampling_factor=model.subsample[0]) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + # set up training iterator and updater + load_tr = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + ) + load_cv = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + accum_grad = args.accum_grad + if ngpu <= 1: + # make minibatch list (variable length) + train = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + # hack to make batchsize argument as 1 + # actual batchsize is included in a list + if args.n_iter_processes > 0: + train_iters = [ + ToggleableShufflingMultiprocessIterator( + TransformDataset(train, load_tr), + batch_size=1, + n_processes=args.n_iter_processes, + n_prefetch=8, + maxtasksperchild=20, + shuffle=not use_sortagrad, + ) + ] + else: + train_iters = [ + ToggleableShufflingSerialIterator( + TransformDataset(train, load_tr), + batch_size=1, + shuffle=not use_sortagrad, + ) + ] + + # set up updater + updater = model.custom_updater( + train_iters[0], + optimizer, + converter=converter, + device=gpu_id, + accum_grad=accum_grad, + ) + else: + if args.batch_count not in ("auto", "seq") and args.batch_size == 0: + raise NotImplementedError( + "--batch-count 'bin' and 'frame' are not implemented " + "in chainer multi gpu" + ) + # set up minibatches + train_subsets = [] + for gid in six.moves.xrange(ngpu): + # make subset + train_json_subset = { + k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid + } + # make minibatch list (variable length) + train_subsets += [ + make_batchset( + train_json_subset, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + ) + ] + + # each subset must have same length for MultiprocessParallelUpdater + maxlen = max([len(train_subset) for train_subset in train_subsets]) + for train_subset in train_subsets: + if maxlen != len(train_subset): + for i in six.moves.xrange(maxlen - len(train_subset)): + train_subset += [train_subset[i]] + + # hack to make batchsize argument as 1 + # actual batchsize is included in a list + if args.n_iter_processes > 0: + train_iters = [ 
+ ToggleableShufflingMultiprocessIterator( + TransformDataset(train_subsets[gid], load_tr), + batch_size=1, + n_processes=args.n_iter_processes, + n_prefetch=8, + maxtasksperchild=20, + shuffle=not use_sortagrad, + ) + for gid in six.moves.xrange(ngpu) + ] + else: + train_iters = [ + ToggleableShufflingSerialIterator( + TransformDataset(train_subsets[gid], load_tr), + batch_size=1, + shuffle=not use_sortagrad, + ) + for gid in six.moves.xrange(ngpu) + ] + + # set up updater + updater = model.custom_parallel_updater( + train_iters, optimizer, converter=converter, devices=devices + ) + + # Set up a trainer + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler(train_iters), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + if args.opt == "noam": + from espnet.nets.chainer_backend.transformer.training import VaswaniRule + + trainer.extend( + VaswaniRule( + "alpha", + d=args.adim, + warmup_steps=args.transformer_warmup_steps, + scale=args.transformer_lr, + ), + trigger=(1, "iteration"), + ) + # Resume from a snapshot + if args.resume: + chainer.serializers.load_npz(args.resume, trainer) + + # set up validation iterator + valid = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + + if args.n_iter_processes > 0: + valid_iter = chainer.iterators.MultiprocessIterator( + TransformDataset(valid, load_cv), + batch_size=1, + repeat=False, + shuffle=False, + n_processes=args.n_iter_processes, + n_prefetch=8, + maxtasksperchild=20, + ) + else: + valid_iter = chainer.iterators.SerialIterator( + TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False + ) + + # Evaluate the model with the test dataset for each epoch + trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id)) + + # Save attention weight each epoch + if args.num_save_attention > 0 and args.mtlalpha != 1.0: + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["input"][0]["shape"][1]), + reverse=True, + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + logging.info("Using custom PlotAttentionReport") + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=gpu_id, + subsampling_factor=total_subsampling_factor, + ) + trainer.extend(att_reporter, trigger=(1, "epoch")) + else: + att_reporter = None + + # Take a snapshot for each specified epoch + trainer.extend( + extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"), + trigger=(1, "epoch"), + ) + + # Make a plot for training and validation values + trainer.extend( + extensions.PlotReport( + [ + "main/loss", + "validation/main/loss", + "main/loss_ctc", + "validation/main/loss_ctc", + "main/loss_att", + "validation/main/loss_att", + ], + "epoch", + file_name="loss.png", + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png" + ) + ) + + # Save best models + 
trainer.extend( + extensions.snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger("validation/main/loss"), + ) + if mtl_mode != "ctc": + trainer.extend( + extensions.snapshot_object(model, "model.acc.best"), + trigger=training.triggers.MaxValueTrigger("validation/main/acc"), + ) + + # epsilon decay in the optimizer + if args.opt == "adadelta": + if args.criterion == "acc" and mtl_mode != "ctc": + trainer.extend( + restore_snapshot(model, args.outdir + "/model.acc.best"), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot(model, args.outdir + "/model.loss.best"), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend( + extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) + ) + report_keys = [ + "epoch", + "iteration", + "main/loss", + "main/loss_ctc", + "main/loss_att", + "validation/main/loss", + "validation/main/loss_ctc", + "validation/main/loss_att", + "main/acc", + "validation/main/acc", + "elapsed_time", + ] + if args.opt == "adadelta": + trainer.extend( + extensions.observe_value( + "eps", lambda trainer: trainer.updater.get_optimizer("main").eps + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("eps") + trainer.extend( + extensions.PrintReport(report_keys), + trigger=(args.report_interval_iters, "iteration"), + ) + + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + + set_early_stop(trainer, args) + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend( + TensorboardLogger(writer, att_reporter), + trigger=(args.report_interval_iters, "iteration"), + ) + + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +def recog(args): + """Decode with the given args. + + Args: + args (namespace): The program arguments. 
+ + """ + # display chainer version + logging.info("chainer version = " + chainer.__version__) + + set_deterministic_chainer(args) + + # read training config + idim, odim, train_args = get_model_conf(args.model, args.model_conf) + + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + # specify model architecture + logging.info("reading model parameters from " + args.model) + # To be compatible with v.0.3.0 models + if hasattr(train_args, "model_module"): + model_module = train_args.model_module + else: + model_module = "espnet.nets.chainer_backend.e2e_asr:E2E" + model_class = dynamic_import(model_module) + model = model_class(idim, odim, train_args) + assert isinstance(model, ASRInterface) + chainer_load(args.model, model) + + # read rnnlm + if args.rnnlm: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + rnnlm = lm_chainer.ClassifierWithState( + lm_chainer.RNNLM( + len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit + ) + ) + chainer_load(args.rnnlm, rnnlm) + else: + rnnlm = None + + if args.word_rnnlm: + rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) + word_dict = rnnlm_args.char_list_dict + char_dict = {x: i for i, x in enumerate(train_args.char_list)} + word_rnnlm = lm_chainer.ClassifierWithState( + lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit) + ) + chainer_load(args.word_rnnlm, word_rnnlm) + + if rnnlm is not None: + rnnlm = lm_chainer.ClassifierWithState( + extlm_chainer.MultiLevelLM( + word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict + ) + ) + else: + rnnlm = lm_chainer.ClassifierWithState( + extlm_chainer.LookAheadWordLM( + word_rnnlm.predictor, word_dict, char_dict + ) + ) + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f)["utts"] + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + + # decode each utterance + new_js = {} + with chainer.no_backprop_mode(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/asr/pytorch_backend/__init__.py b/espnet/asr/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/asr/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/asr/pytorch_backend/asr.py b/espnet/asr/pytorch_backend/asr.py new file mode 100644 index 0000000000000000000000000000000000000000..32c7c5b7d180717557ca2e68814e92fa7685fcda --- /dev/null +++ b/espnet/asr/pytorch_backend/asr.py @@ -0,0 +1,1500 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Training/decoding definition for the speech recognition task.""" + +import copy +import json +import logging +import math +import os +import sys + +from 
chainer import reporter as reporter_module +from chainer import training +from chainer.training import extensions +from chainer.training.updater import StandardUpdater +import numpy as np +from tensorboardX import SummaryWriter +import torch +from torch.nn.parallel import data_parallel + +from espnet.asr.asr_utils import adadelta_eps_decay +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import format_mulenc_args +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import plot_spectrogram +from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.pytorch_backend.asr_init import freeze_modules +from espnet.asr.pytorch_backend.asr_init import load_trained_model +from espnet.asr.pytorch_backend.asr_init import load_trained_modules +import espnet.lm.pytorch_backend.extlm as extlm_pytorch +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.beam_search_transducer import BeamSearchTransducer +from espnet.nets.pytorch_backend.e2e_asr import pad_list +import espnet.nets.pytorch_backend.lm.default as lm_pytorch +from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E +from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E +from espnet.transform.spectrogram import IStft +from espnet.transform.transformation import Transformation +from espnet.utils.cli_writers import file_writer_helper +from espnet.utils.dataset import ChainerDataLoader +from espnet.utils.dataset import TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.evaluator import BaseEvaluator +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +import matplotlib + +matplotlib.use("Agg") + +if sys.version_info[0] == 2: + from itertools import izip_longest as zip_longest +else: + from itertools import zip_longest as zip_longest + + +def _recursive_to(xs, device): + if torch.is_tensor(xs): + return xs.to(device) + if isinstance(xs, tuple): + return tuple(_recursive_to(x, device) for x in xs) + return xs + + +class CustomEvaluator(BaseEvaluator): + """Custom Evaluator for Pytorch. + + Args: + model (torch.nn.Module): The model to evaluate. + iterator (chainer.dataset.Iterator) : The train iterator. + + target (link | dict[str, link]) :Link object or a dictionary of + links to evaluate. If this is just a link object, the link is + registered by the name ``'main'``. + + device (torch.device): The device used. + ngpu (int): The number of GPUs. 
+ + """ + + def __init__(self, model, iterator, target, device, ngpu=None): + super(CustomEvaluator, self).__init__(iterator, target) + self.model = model + self.device = device + if ngpu is not None: + self.ngpu = ngpu + elif device.type == "cpu": + self.ngpu = 0 + else: + self.ngpu = 1 + + # The core part of the update routine can be customized by overriding + def evaluate(self): + """Main evaluate routine for CustomEvaluator.""" + iterator = self._iterators["main"] + + if self.eval_hook: + self.eval_hook(self) + + if hasattr(iterator, "reset"): + iterator.reset() + it = iterator + else: + it = copy.copy(iterator) + + summary = reporter_module.DictSummary() + + self.model.eval() + with torch.no_grad(): + for batch in it: + x = _recursive_to(batch, self.device) + observation = {} + with reporter_module.report_scope(observation): + # read scp files + # x: original json with loaded features + # will be converted to chainer variable later + if self.ngpu == 0: + self.model(*x) + else: + # apex does not support torch.nn.DataParallel + data_parallel(self.model, x, range(self.ngpu)) + + summary.add(observation) + self.model.train() + + return summary.compute_mean() + + +class CustomUpdater(StandardUpdater): + """Custom Updater for Pytorch. + + Args: + model (torch.nn.Module): The model to update. + grad_clip_threshold (float): The gradient clipping value to use. + train_iter (chainer.dataset.Iterator): The training iterator. + optimizer (torch.optim.optimizer): The training optimizer. + + device (torch.device): The device to use. + ngpu (int): The number of gpus to use. + use_apex (bool): The flag to use Apex in backprop. + + """ + + def __init__( + self, + model, + grad_clip_threshold, + train_iter, + optimizer, + device, + ngpu, + grad_noise=False, + accum_grad=1, + use_apex=False, + ): + super(CustomUpdater, self).__init__(train_iter, optimizer) + self.model = model + self.grad_clip_threshold = grad_clip_threshold + self.device = device + self.ngpu = ngpu + self.accum_grad = accum_grad + self.forward_count = 0 + self.grad_noise = grad_noise + self.iteration = 0 + self.use_apex = use_apex + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Main update routine of the CustomUpdater.""" + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + epoch = train_iter.epoch + + # Get the next batch (a list of json files) + batch = train_iter.next() + # self.iteration += 1 # Increase may result in early report, + # which is done in other place automatically. + x = _recursive_to(batch, self.device) + is_new_epoch = train_iter.epoch != epoch + # When the last minibatch in the current epoch is given, + # gradient accumulation is turned off in order to evaluate the model + # on the validation set in every epoch. 
+ # see details in https://github.com/espnet/espnet/pull/1388 + + # Compute the loss at this time step and accumulate it + if self.ngpu == 0: + loss = self.model(*x).mean() / self.accum_grad + else: + # apex does not support torch.nn.DataParallel + loss = ( + data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad + ) + if self.use_apex: + from apex import amp + + # NOTE: for a compatibility with noam optimizer + opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer + with amp.scale_loss(loss, opt) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + # gradient noise injection + if self.grad_noise: + from espnet.asr.asr_utils import add_gradient_noise + + add_gradient_noise( + self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55 + ) + + # update parameters + self.forward_count += 1 + if not is_new_epoch and self.forward_count != self.accum_grad: + return + self.forward_count = 0 + # compute the gradient norm to check if it is normal or not + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.grad_clip_threshold + ) + logging.info("grad norm={}".format(grad_norm)) + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.step() + optimizer.zero_grad() + + def update(self): + self.update_core() + # #iterations with accum_grad > 1 + # Ref.: https://github.com/espnet/espnet/issues/777 + if self.forward_count == 0: + self.iteration += 1 + + +class CustomConverter(object): + """Custom batch converter for Pytorch. + + Args: + subsampling_factor (int): The subsampling factor. + dtype (torch.dtype): Data type to convert. + + """ + + def __init__(self, subsampling_factor=1, dtype=torch.float32): + """Construct a CustomConverter object.""" + self.subsampling_factor = subsampling_factor + self.ignore_id = -1 + self.dtype = dtype + + def __call__(self, batch, device=torch.device("cpu")): + """Transform a batch and send it to a device. + + Args: + batch (list): The batch to transform. + device (torch.device): The device to send to. + + Returns: + tuple(torch.Tensor, torch.Tensor, torch.Tensor) + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0] + + # perform subsampling + if self.subsampling_factor > 1: + xs = [x[:: self.subsampling_factor, :] for x in xs] + + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs]) + + # perform padding and convert to tensor + # currently only support real number + if xs[0].dtype.kind == "c": + xs_pad_real = pad_list( + [torch.from_numpy(x.real).float() for x in xs], 0 + ).to(device, dtype=self.dtype) + xs_pad_imag = pad_list( + [torch.from_numpy(x.imag).float() for x in xs], 0 + ).to(device, dtype=self.dtype) + # Note(kamo): + # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E. + # Don't create ComplexTensor and give it E2E here + # because torch.nn.DataParellel can't handle it. + xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag} + else: + xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to( + device, dtype=self.dtype + ) + + ilens = torch.from_numpy(ilens).to(device) + # NOTE: this is for multi-output (e.g., speech translation) + ys_pad = pad_list( + [ + torch.from_numpy( + np.array(y[0][:]) if isinstance(y, tuple) else y + ).long() + for y in ys + ], + self.ignore_id, + ).to(device) + + return xs_pad, ilens, ys_pad + + +class CustomConverterMulEnc(object): + """Custom batch converter for Pytorch in multi-encoder case. 
+
+    Args:
+        subsampling_factors (list): List of subsampling factors for each encoder.
+        dtype (torch.dtype): Data type to convert.
+
+    """
+
+    def __init__(self, subsampling_factors=[1, 1], dtype=torch.float32):
+        """Initialize the converter."""
+        self.subsampling_factors = subsampling_factors
+        self.ignore_id = -1
+        self.dtype = dtype
+        self.num_encs = len(subsampling_factors)
+
+    def __call__(self, batch, device=torch.device("cpu")):
+        """Transform a batch and send it to a device.
+
+        Args:
+            batch (list): The batch to transform.
+            device (torch.device): The device to send to.
+
+        Returns:
+            tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor)
+
+        """
+        # batch should be located in list
+        assert len(batch) == 1
+        xs_list = batch[0][: self.num_encs]
+        ys = batch[0][-1]
+
+        # perform subsampling
+        if np.sum(self.subsampling_factors) > self.num_encs:
+            xs_list = [
+                [x[:: self.subsampling_factors[i], :] for x in xs_list[i]]
+                for i in range(self.num_encs)
+            ]
+
+        # get batch of lengths of input sequences
+        ilens_list = [
+            np.array([x.shape[0] for x in xs_list[i]]) for i in range(self.num_encs)
+        ]
+
+        # perform padding and convert to tensor
+        # currently only support real number
+        xs_list_pad = [
+            pad_list([torch.from_numpy(x).float() for x in xs_list[i]], 0).to(
+                device, dtype=self.dtype
+            )
+            for i in range(self.num_encs)
+        ]
+
+        ilens_list = [
+            torch.from_numpy(ilens_list[i]).to(device) for i in range(self.num_encs)
+        ]
+        # NOTE: this is for multi-task learning (e.g., speech translation)
+        ys_pad = pad_list(
+            [
+                torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
+                for y in ys
+            ],
+            self.ignore_id,
+        ).to(device)
+
+        return xs_list_pad, ilens_list, ys_pad
+
+
+def train(args):
+    """Train with the given args.
+
+    Args:
+        args (namespace): The program arguments.
+ + """ + set_deterministic_pytorch(args) + if args.num_encs > 1: + args = format_mulenc_args(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + idim_list = [ + int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs) + ] + odim = int(valid_json[utts[0]]["output"][0]["shape"][-1]) + for i in range(args.num_encs): + logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i])) + logging.info("#output dims: " + str(odim)) + + # specify attention, CTC, hybrid mode + if "transducer" in args.model_module: + if ( + getattr(args, "etype", False) == "custom" + or getattr(args, "dtype", False) == "custom" + ): + mtl_mode = "custom_transducer" + else: + mtl_mode = "transducer" + logging.info("Pure transducer mode") + elif args.mtlalpha == 1.0: + mtl_mode = "ctc" + logging.info("Pure CTC mode") + elif args.mtlalpha == 0.0: + mtl_mode = "att" + logging.info("Pure attention mode") + else: + mtl_mode = "mtl" + logging.info("Multitask learning mode") + + if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1: + model = load_trained_modules(idim_list[0], odim, args) + else: + model_class = dynamic_import(args.model_module) + model = model_class( + idim_list[0] if args.num_encs == 1 else idim_list, odim, args + ) + assert isinstance(model, ASRInterface) + total_subsampling_factor = model.get_total_subsampling_factor() + + logging.info( + " Total parameter of the model = " + + str(sum(p.numel() for p in model.parameters())) + ) + + if args.rnnlm is not None: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit) + ) + torch_load(args.rnnlm, rnnlm) + model.rnnlm = rnnlm + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps( + (idim_list[0] if args.num_encs == 1 else idim_list, odim, vars(args)), + indent=4, + ensure_ascii=False, + sort_keys=True, + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + if args.num_encs > 1: + # TODO(ruizhili): implement data parallel for multi-encoder setup. + raise NotImplementedError( + "Data parallel is not supported for multi-encoder setup." + ) + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + if args.train_dtype in ("float16", "float32", "float64"): + dtype = getattr(torch, args.train_dtype) + else: + dtype = torch.float32 + model = model.to(device=device, dtype=dtype) + + if args.freeze_mods: + model, model_params = freeze_modules(model, args.freeze_mods) + else: + model_params = model.parameters() + + logging.warning( + "num. model params: {:,} (num. 
trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adadelta": + optimizer = torch.optim.Adadelta( + model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "adam": + optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + # For transformer-transducer, adim declaration is within the block definition. + # Thus, we need retrieve the most dominant value (d_hidden) for Noam scheduler. + if hasattr(args, "enc_block_arch") or hasattr(args, "dec_block_arch"): + adim = model.most_dom_dim + else: + adim = args.adim + + optimizer = get_std_opt( + model_params, adim, args.transformer_warmup_steps, args.transformer_lr + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # setup apex.amp + if args.train_dtype in ("O0", "O1", "O2", "O3"): + try: + from apex import amp + except ImportError as e: + logging.error( + f"You need to install apex for --train-dtype {args.train_dtype}. " + "See https://github.com/NVIDIA/apex#linux" + ) + raise e + if args.opt == "noam": + model, optimizer.optimizer = amp.initialize( + model, optimizer.optimizer, opt_level=args.train_dtype + ) + else: + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.train_dtype + ) + use_apex = True + + from espnet.nets.pytorch_backend.ctc import CTC + + amp.register_float_function(CTC, "loss_fn") + amp.init() + logging.warning("register ctc as float function") + else: + use_apex = False + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # Setup a converter + if args.num_encs == 1: + converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype) + else: + converter = CustomConverterMulEnc( + [i[0] for i in model.subsample_list], dtype=dtype + ) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # make minibatch list (variable length) + train = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + valid = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + + load_tr = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + ) + load_cv = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, 
+ preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + # default collate function converts numpy array to pytorch tensor + # we used an empty collate function instead which returns list + train_iter = ChainerDataLoader( + dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), + batch_size=1, + num_workers=args.n_iter_processes, + shuffle=not use_sortagrad, + collate_fn=lambda x: x[0], + ) + valid_iter = ChainerDataLoader( + dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.n_iter_processes, + ) + + # Set up a trainer + updater = CustomUpdater( + model, + args.grad_clip, + {"main": train_iter}, + optimizer, + device, + args.ngpu, + args.grad_noise, + args.accum_grad, + use_apex=use_apex, + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Resume from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # Evaluate the model with the test dataset for each epoch + if args.save_interval_iters > 0: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), + trigger=(args.save_interval_iters, "iteration"), + ) + else: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu) + ) + + # Save attention weight each epoch + is_attn_plot = ( + "transformer" in args.model_module + or "conformer" in args.model_module + or mtl_mode in ["att", "mtl", "custom_transducer"] + ) + + if args.num_save_attention > 0 and is_attn_plot: + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["input"][0]["shape"][1]), + reverse=True, + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + subsampling_factor=total_subsampling_factor, + ) + trainer.extend(att_reporter, trigger=(1, "epoch")) + else: + att_reporter = None + + # Save CTC prob at each epoch + if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0: + # NOTE: sort it by output lengths + data = sorted( + list(valid_json.items())[: args.num_save_ctc], + key=lambda x: int(x[1]["output"][0]["shape"][0]), + reverse=True, + ) + if hasattr(model, "module"): + ctc_vis_fn = model.module.calculate_all_ctc_probs + plot_class = model.module.ctc_plot_class + else: + ctc_vis_fn = model.calculate_all_ctc_probs + plot_class = model.ctc_plot_class + ctc_reporter = plot_class( + ctc_vis_fn, + data, + args.outdir + "/ctc_prob", + converter=converter, + transform=load_cv, + device=device, + subsampling_factor=total_subsampling_factor, + ) + trainer.extend(ctc_reporter, trigger=(1, "epoch")) + else: + ctc_reporter = None + + # Make a plot for training and validation values + if args.num_encs > 1: + report_keys_loss_ctc = [ + "main/loss_ctc{}".format(i + 1) for i in range(model.num_encs) + ] + ["validation/main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)] + 
report_keys_cer_ctc = [ + "main/cer_ctc{}".format(i + 1) for i in range(model.num_encs) + ] + ["validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)] + + if hasattr(model, "is_rnnt"): + trainer.extend( + extensions.PlotReport( + [ + "main/loss", + "validation/main/loss", + "main/loss_trans", + "validation/main/loss_trans", + "main/loss_ctc", + "validation/main/loss_ctc", + "main/loss_lm", + "validation/main/loss_lm", + "main/loss_aux_trans", + "validation/main/loss_aux_trans", + "main/loss_aux_symm_kl", + "validation/main/loss_aux_symm_kl", + ], + "epoch", + file_name="loss.png", + ) + ) + else: + trainer.extend( + extensions.PlotReport( + [ + "main/loss", + "validation/main/loss", + "main/loss_ctc", + "validation/main/loss_ctc", + "main/loss_att", + "validation/main/loss_att", + ] + + ([] if args.num_encs == 1 else report_keys_loss_ctc), + "epoch", + file_name="loss.png", + ) + ) + + trainer.extend( + extensions.PlotReport( + ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png" + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/cer_ctc", "validation/main/cer_ctc"] + + ([] if args.num_encs == 1 else report_keys_loss_ctc), + "epoch", + file_name="cer.png", + ) + ) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger("validation/main/loss"), + ) + if mtl_mode not in ["ctc", "transducer", "custom_transducer"]: + trainer.extend( + snapshot_object(model, "model.acc.best"), + trigger=training.triggers.MaxValueTrigger("validation/main/acc"), + ) + + # save snapshot which contains model and optimizer states + if args.save_interval_iters > 0: + trainer.extend( + torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), + trigger=(args.save_interval_iters, "iteration"), + ) + + # save snapshot at every epoch - for model averaging + trainer.extend(torch_snapshot(), trigger=(1, "epoch")) + + # epsilon decay in the optimizer + if args.opt == "adadelta": + if args.criterion == "acc" and mtl_mode != "ctc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + # NOTE: In some cases, it may take more than one epoch for the model's loss + # to escape from a local minimum. + # Thus, restore_snapshot extension is not used here. 
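+        # (In that branch only adadelta_eps_decay is attached: each time the current
+        #  validation loss is worse than the best value seen so far, the Adadelta eps
+        #  is scaled by args.eps_decay, without restoring any earlier snapshot.)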
+ # see details in https://github.com/espnet/espnet/pull/2171 + elif args.criterion == "loss_eps_decay_only": + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend( + extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) + ) + + if hasattr(model, "is_rnnt"): + report_keys = [ + "epoch", + "iteration", + "main/loss", + "main/loss_trans", + "main/loss_ctc", + "main/loss_lm", + "main/loss_aux_trans", + "main/loss_aux_symm_kl", + "validation/main/loss", + "validation/main/loss_trans", + "validation/main/loss_ctc", + "validation/main/loss_lm", + "validation/main/loss_aux_trans", + "validation/main/loss_aux_symm_kl", + "elapsed_time", + ] + else: + report_keys = [ + "epoch", + "iteration", + "main/loss", + "main/loss_ctc", + "main/loss_att", + "validation/main/loss", + "validation/main/loss_ctc", + "validation/main/loss_att", + "main/acc", + "validation/main/acc", + "main/cer_ctc", + "validation/main/cer_ctc", + "elapsed_time", + ] + ([] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc) + + if args.opt == "adadelta": + trainer.extend( + extensions.observe_value( + "eps", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "eps" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("eps") + if args.report_cer: + report_keys.append("validation/main/cer") + if args.report_wer: + report_keys.append("validation/main/wer") + trainer.extend( + extensions.PrintReport(report_keys), + trigger=(args.report_interval_iters, "iteration"), + ) + + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + set_early_stop(trainer, args) + + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + trainer.extend( + TensorboardLogger( + SummaryWriter(args.tensorboard_dir), + att_reporter=att_reporter, + ctc_reporter=ctc_reporter, + ), + trigger=(args.report_interval_iters, "iteration"), + ) + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +def recog(args): + """Decode with the given args. + + Args: + args (namespace): The program arguments. 
+ + """ + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model, training=False) + assert isinstance(model, ASRInterface) + model.recog_args = args + + if args.streaming_mode and "transformer" in train_args.model_module: + raise NotImplementedError("streaming mode for transformer is not implemented") + logging.info( + " Total parameter of the model = " + + str(sum(p.numel() for p in model.parameters())) + ) + + # read rnnlm + if args.rnnlm: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + if getattr(rnnlm_args, "model_module", "default") != "default": + raise ValueError( + "use '--api v2' option to decode with non-default language model" + ) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(train_args.char_list), + rnnlm_args.layer, + rnnlm_args.unit, + getattr(rnnlm_args, "embed_unit", None), # for backward compatibility + ) + ) + torch_load(args.rnnlm, rnnlm) + rnnlm.eval() + else: + rnnlm = None + + if args.word_rnnlm: + rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) + word_dict = rnnlm_args.char_list_dict + char_dict = {x: i for i, x in enumerate(train_args.char_list)} + word_rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(word_dict), + rnnlm_args.layer, + rnnlm_args.unit, + getattr(rnnlm_args, "embed_unit", None), # for backward compatibility + ) + ) + torch_load(args.word_rnnlm, word_rnnlm) + word_rnnlm.eval() + + if rnnlm is not None: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.MultiLevelLM( + word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict + ) + ) + else: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.LookAheadWordLM( + word_rnnlm.predictor, word_dict, char_dict + ) + ) + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info("gpu id: " + str(gpu_id)) + model.cuda() + if rnnlm: + rnnlm.cuda() + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + # load transducer beam search + if hasattr(model, "is_rnnt"): + if hasattr(model, "dec"): + trans_decoder = model.dec + else: + trans_decoder = model.decoder + joint_network = model.joint_network + + beam_search_transducer = BeamSearchTransducer( + decoder=trans_decoder, + joint_network=joint_network, + beam_size=args.beam_size, + nbest=args.nbest, + lm=rnnlm, + lm_weight=args.lm_weight, + search_type=args.search_type, + max_sym_exp=args.max_sym_exp, + u_max=args.u_max, + nstep=args.nstep, + prefix_alpha=args.prefix_alpha, + score_norm=args.score_norm, + ) + + if args.batchsize == 0: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch) + feat = ( + feat[0][0] + if args.num_encs == 1 + else [feat[idx][0] for idx in range(model.num_encs)] + ) + if args.streaming_mode == "window" and args.num_encs == 1: + logging.info( + "Using streaming recognizer with window size %d frames", + args.streaming_window, + ) + se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) + for i in range(0, feat.shape[0], args.streaming_window): + logging.info( + "Feeding frames %d - %d", i, i + args.streaming_window + ) + 
se2e.accept_input(feat[i : i + args.streaming_window]) + logging.info("Running offline attention decoder") + se2e.decode_with_attention_offline() + logging.info("Offline attention decoder finished") + nbest_hyps = se2e.retrieve_recognition() + elif args.streaming_mode == "segment" and args.num_encs == 1: + logging.info( + "Using streaming recognizer with threshold value %d", + args.streaming_min_blank_dur, + ) + nbest_hyps = [] + for n in range(args.nbest): + nbest_hyps.append({"yseq": [], "score": 0.0}) + se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) + r = np.prod(model.subsample) + for i in range(0, feat.shape[0], r): + hyps = se2e.accept_input(feat[i : i + r]) + if hyps is not None: + text = "".join( + [ + train_args.char_list[int(x)] + for x in hyps[0]["yseq"][1:-1] + if int(x) != -1 + ] + ) + text = text.replace( + "\u2581", " " + ).strip() # for SentencePiece + text = text.replace(model.space, " ") + text = text.replace(model.blank, "") + logging.info(text) + for n in range(args.nbest): + nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) + nbest_hyps[n]["score"] += hyps[n]["score"] + elif hasattr(model, "is_rnnt"): + nbest_hyps = model.recognize(feat, beam_search_transducer) + else: + nbest_hyps = model.recognize( + feat, args, train_args.char_list, rnnlm + ) + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + else: + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + # sort data if batchsize > 1 + keys = list(js.keys()) + if args.batchsize > 1: + feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + with torch.no_grad(): + for names in grouper(args.batchsize, keys, None): + names = [name for name in names if name] + batch = [(name, js[name]) for name in names] + feats = ( + load_inputs_and_targets(batch)[0] + if args.num_encs == 1 + else load_inputs_and_targets(batch) + ) + if args.streaming_mode == "window" and args.num_encs == 1: + raise NotImplementedError + elif args.streaming_mode == "segment" and args.num_encs == 1: + if args.batchsize > 1: + raise NotImplementedError + feat = feats[0] + nbest_hyps = [] + for n in range(args.nbest): + nbest_hyps.append({"yseq": [], "score": 0.0}) + se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm) + r = np.prod(model.subsample) + for i in range(0, feat.shape[0], r): + hyps = se2e.accept_input(feat[i : i + r]) + if hyps is not None: + text = "".join( + [ + train_args.char_list[int(x)] + for x in hyps[0]["yseq"][1:-1] + if int(x) != -1 + ] + ) + text = text.replace( + "\u2581", " " + ).strip() # for SentencePiece + text = text.replace(model.space, " ") + text = text.replace(model.blank, "") + logging.info(text) + for n in range(args.nbest): + nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"]) + nbest_hyps[n]["score"] += hyps[n]["score"] + nbest_hyps = [nbest_hyps] + else: + nbest_hyps = model.recognize_batch( + feats, args, train_args.char_list, rnnlm=rnnlm + ) + + for i, nbest_hyp in enumerate(nbest_hyps): + name = names[i] + new_js[name] = add_results_to_json( + js[name], nbest_hyp, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + + +def enhance(args): + """Dumping enhanced speech and mask. 
+ + Args: + args (namespace): The program arguments. + """ + set_deterministic_pytorch(args) + # read training config + idim, odim, train_args = get_model_conf(args.model, args.model_conf) + + # TODO(ruizhili): implement enhance for multi-encoder model + assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format( + args.num_encs + ) + + # load trained model parameters + logging.info("reading model parameters from " + args.model) + model_class = dynamic_import(train_args.model_module) + model = model_class(idim, odim, train_args) + assert isinstance(model, ASRInterface) + torch_load(args.model, model) + model.recog_args = args + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info("gpu id: " + str(gpu_id)) + model.cuda() + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f)["utts"] + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=None, # Apply pre_process in outer func + ) + if args.batchsize == 0: + args.batchsize = 1 + + # Creates writers for outputs from the network + if args.enh_wspecifier is not None: + enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype) + else: + enh_writer = None + + # Creates a Transformation instance + preprocess_conf = ( + train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf + ) + if preprocess_conf is not None: + logging.info(f"Use preprocessing: {preprocess_conf}") + transform = Transformation(preprocess_conf) + else: + transform = None + + # Creates a IStft instance + istft = None + frame_shift = args.istft_n_shift # Used for plot the spectrogram + if args.apply_istft: + if preprocess_conf is not None: + # Read the conffile and find stft setting + with open(preprocess_conf) as f: + # Json format: e.g. + # {"process": [{"type": "stft", + # "win_length": 400, + # "n_fft": 512, "n_shift": 160, + # "window": "han"}, + # {"type": "foo", ...}, ...]} + conf = json.load(f) + assert "process" in conf, conf + # Find stft setting + for p in conf["process"]: + if p["type"] == "stft": + istft = IStft( + win_length=p["win_length"], + n_shift=p["n_shift"], + window=p.get("window", "hann"), + ) + logging.info( + "stft is found in {}. 
" + "Setting istft config from it\n{}".format( + preprocess_conf, istft + ) + ) + frame_shift = p["n_shift"] + break + if istft is None: + # Set from command line arguments + istft = IStft( + win_length=args.istft_win_length, + n_shift=args.istft_n_shift, + window=args.istft_window, + ) + logging.info( + "Setting istft config from the command line args\n{}".format(istft) + ) + + # sort data + keys = list(js.keys()) + feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + num_images = 0 + if not os.path.exists(args.image_dir): + os.makedirs(args.image_dir) + + for names in grouper(args.batchsize, keys, None): + batch = [(name, js[name]) for name in names] + + # May be in time region: (Batch, [Time, Channel]) + org_feats = load_inputs_and_targets(batch)[0] + if transform is not None: + # May be in time-freq region: : (Batch, [Time, Channel, Freq]) + feats = transform(org_feats, train=False) + else: + feats = org_feats + + with torch.no_grad(): + enhanced, mask, ilens = model.enhance(feats) + + for idx, name in enumerate(names): + # Assuming mask, feats : [Batch, Time, Channel. Freq] + # enhanced : [Batch, Time, Freq] + enh = enhanced[idx][: ilens[idx]] + mas = mask[idx][: ilens[idx]] + feat = feats[idx] + + # Plot spectrogram + if args.image_dir is not None and num_images < args.num_images: + import matplotlib.pyplot as plt + + num_images += 1 + ref_ch = 0 + + plt.figure(figsize=(20, 10)) + plt.subplot(4, 1, 1) + plt.title("Mask [ref={}ch]".format(ref_ch)) + plot_spectrogram( + plt, + mas[:, ref_ch].T, + fs=args.fs, + mode="linear", + frame_shift=frame_shift, + bottom=False, + labelbottom=False, + ) + + plt.subplot(4, 1, 2) + plt.title("Noisy speech [ref={}ch]".format(ref_ch)) + plot_spectrogram( + plt, + feat[:, ref_ch].T, + fs=args.fs, + mode="db", + frame_shift=frame_shift, + bottom=False, + labelbottom=False, + ) + + plt.subplot(4, 1, 3) + plt.title("Masked speech [ref={}ch]".format(ref_ch)) + plot_spectrogram( + plt, + (feat[:, ref_ch] * mas[:, ref_ch]).T, + frame_shift=frame_shift, + fs=args.fs, + mode="db", + bottom=False, + labelbottom=False, + ) + + plt.subplot(4, 1, 4) + plt.title("Enhanced speech") + plot_spectrogram( + plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift + ) + + plt.savefig(os.path.join(args.image_dir, name + ".png")) + plt.clf() + + # Write enhanced wave files + if enh_writer is not None: + if istft is not None: + enh = istft(enh) + else: + enh = enh + + if args.keep_length: + if len(org_feats[idx]) < len(enh): + # Truncate the frames added by stft padding + enh = enh[: len(org_feats[idx])] + elif len(org_feats) > len(enh): + padwidth = [(0, (len(org_feats[idx]) - len(enh)))] + [ + (0, 0) + ] * (enh.ndim - 1) + enh = np.pad(enh, padwidth, mode="constant") + + if args.enh_filetype in ("sound", "sound.hdf5"): + enh_writer[name] = (args.fs, enh) + else: + # Hint: To dump stft_signal, mask or etc, + # enh_filetype='hdf5' might be convenient. + enh_writer[name] = enh + + if num_images >= args.num_images and enh_writer is None: + logging.info("Breaking the process.") + break + + +def ctc_align(args): + """CTC forced alignments with the given args. + + Args: + args (namespace): The program arguments. + """ + + def add_alignment_to_json(js, alignment, char_list): + """Add N-best results to json. 
+ + Args: + js (dict[str, Any]): Groundtruth utterance dict. + alignment (list[int]): List of alignment. + char_list (list[str]): List of characters. + + Returns: + dict[str, Any]: N-best results added utterance dict. + + """ + # copy old json info + new_js = dict() + new_js["ctc_alignment"] = [] + + alignment_tokens = [] + for idx, a in enumerate(alignment): + alignment_tokens.append(char_list[a]) + alignment_tokens = " ".join(alignment_tokens) + + new_js["ctc_alignment"] = alignment_tokens + + return new_js + + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + assert isinstance(model, ASRInterface) + model.eval() + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=True, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + if args.ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + if args.ngpu == 1: + device = "cuda" + else: + device = "cpu" + dtype = getattr(torch, args.dtype) + logging.info(f"Decoding device={device}, dtype={dtype}") + model.to(device=device, dtype=dtype).eval() + + # read json data + with open(args.align_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + if args.batchsize == 0: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) aligning " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat, label = load_inputs_and_targets(batch) + feat = feat[0] + label = label[0] + enc = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0) + alignment = model.ctc.forced_align(enc, label) + new_js[name] = add_alignment_to_json( + js[name], alignment, train_args.char_list + ) + else: + raise NotImplementedError("Align_batch is not implemented.") + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/asr/pytorch_backend/asr_init.py b/espnet/asr/pytorch_backend/asr_init.py new file mode 100644 index 0000000000000000000000000000000000000000..5831abde090c17967d80cc0047d3c719ba1a0b51 --- /dev/null +++ b/espnet/asr/pytorch_backend/asr_init.py @@ -0,0 +1,282 @@ +"""Finetuning methods.""" + +import logging +import os +import torch + +from collections import OrderedDict + +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import torch_load +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.mt_interface import MTInterface +from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.dynamic_import import dynamic_import + + +def freeze_modules(model, modules): + """Freeze model parameters according to modules list. 
+ + Args: + model (torch.nn.Module): main model to update + modules (list): specified module list for freezing + + Return: + model (torch.nn.Module): updated model + model_params (filter): filtered model parameters + + """ + for mod, param in model.named_parameters(): + if any(mod.startswith(m) for m in modules): + logging.info(f"freezing {mod}, it will not be updated.") + param.requires_grad = False + + model_params = filter(lambda x: x.requires_grad, model.parameters()) + + return model, model_params + + +def transfer_verification(model_state_dict, partial_state_dict, modules): + """Verify tuples (key, shape) for input model modules match specified modules. + + Args: + model_state_dict (OrderedDict): the initial model state_dict + partial_state_dict (OrderedDict): the trained model state_dict + modules (list): specified module list for transfer + + Return: + (boolean): allow transfer + + """ + modules_model = [] + partial_modules = [] + + for key_p, value_p in partial_state_dict.items(): + if any(key_p.startswith(m) for m in modules): + partial_modules += [(key_p, value_p.shape)] + + for key_m, value_m in model_state_dict.items(): + if any(key_m.startswith(m) for m in modules): + modules_model += [(key_m, value_m.shape)] + + len_match = len(modules_model) == len(partial_modules) + + module_match = sorted(modules_model, key=lambda x: (x[0], x[1])) == sorted( + partial_modules, key=lambda x: (x[0], x[1]) + ) + + return len_match and module_match + + +def get_partial_state_dict(model_state_dict, modules): + """Create state_dict with specified modules matching input model modules. + + Note that get_partial_lm_state_dict is used if a LM specified. + + Args: + model_state_dict (OrderedDict): trained model state_dict + modules (list): specified module list for transfer + + Return: + new_state_dict (OrderedDict): the updated state_dict + + """ + new_state_dict = OrderedDict() + + for key, value in model_state_dict.items(): + if any(key.startswith(m) for m in modules): + new_state_dict[key] = value + + return new_state_dict + + +def get_lm_state_dict(lm_state_dict): + """Create compatible ASR decoder state dict from LM state dict. + + Args: + lm_state_dict (OrderedDict): pre-trained LM state_dict + + Return: + new_state_dict (OrderedDict): LM state_dict with updated keys + + """ + new_state_dict = OrderedDict() + + for key, value in list(lm_state_dict.items()): + if key == "predictor.embed.weight": + new_state_dict["dec.embed.weight"] = value + elif key.startswith("predictor.rnn."): + _split = key.split(".") + + new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0" + new_state_dict[new_key] = value + + return new_state_dict + + +def filter_modules(model_state_dict, modules): + """Filter non-matched modules in module_state_dict. 
+ + Args: + model_state_dict (OrderedDict): trained model state_dict + modules (list): specified module list for transfer + + Return: + new_mods (list): the update module list + + """ + new_mods = [] + incorrect_mods = [] + + mods_model = list(model_state_dict.keys()) + for mod in modules: + if any(key.startswith(mod) for key in mods_model): + new_mods += [mod] + else: + incorrect_mods += [mod] + + if incorrect_mods: + logging.warning( + "module(s) %s don't match or (partially match) " + "available modules in model.", + incorrect_mods, + ) + logging.warning("for information, the existing modules in model are:") + logging.warning("%s", mods_model) + + return new_mods + + +def load_trained_model(model_path, training=True): + """Load the trained model for recognition. + + Args: + model_path (str): Path to model.***.best + + """ + idim, odim, train_args = get_model_conf( + model_path, os.path.join(os.path.dirname(model_path), "model.json") + ) + + logging.warning("reading model parameters from " + model_path) + + if hasattr(train_args, "model_module"): + model_module = train_args.model_module + else: + model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E" + # CTC Loss is not needed, default to builtin to prevent import errors + if hasattr(train_args, "ctc_type"): + train_args.ctc_type = "builtin" + + model_class = dynamic_import(model_module) + + if "transducer" in model_module: + model = model_class(idim, odim, train_args, training=training) + custom_torch_load(model_path, model, training=training) + else: + model = model_class(idim, odim, train_args) + torch_load(model_path, model) + + return model, train_args + + +def get_trained_model_state_dict(model_path): + """Extract the trained model state dict for pre-initialization. + + Args: + model_path (str): Path to model.***.best + + Return: + model.state_dict() (OrderedDict): the loaded model state_dict + (bool): Boolean defining whether the model is an LM + + """ + conf_path = os.path.join(os.path.dirname(model_path), "model.json") + if "rnnlm" in model_path: + logging.warning("reading model parameters from %s", model_path) + + return get_lm_state_dict(torch.load(model_path)) + + idim, odim, args = get_model_conf(model_path, conf_path) + + logging.warning("reading model parameters from " + model_path) + + if hasattr(args, "model_module"): + model_module = args.model_module + else: + model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E" + + model_class = dynamic_import(model_module) + model = model_class(idim, odim, args) + torch_load(model_path, model) + assert ( + isinstance(model, MTInterface) + or isinstance(model, ASRInterface) + or isinstance(model, TTSInterface) + ) + + return model.state_dict() + + +def load_trained_modules(idim, odim, args, interface=ASRInterface): + """Load model encoder or/and decoder modules with ESPNET pre-trained model(s). + + Args: + idim (int): initial input dimension. + odim (int): initial output dimension. + args (Namespace): The initial model arguments. + interface (Interface): ASRInterface or STInterface or TTSInterface. + + Return: + model (torch.nn.Module): The model with pretrained modules. 
+ + """ + + def print_new_keys(state_dict, modules, model_path): + logging.warning("loading %s from model: %s", modules, model_path) + + for k in state_dict.keys(): + logging.warning("override %s" % k) + + enc_model_path = args.enc_init + dec_model_path = args.dec_init + enc_modules = args.enc_init_mods + dec_modules = args.dec_init_mods + + model_class = dynamic_import(args.model_module) + main_model = model_class(idim, odim, args) + assert isinstance(main_model, interface) + + main_state_dict = main_model.state_dict() + + logging.warning("model(s) found for pre-initialization") + for model_path, modules in [ + (enc_model_path, enc_modules), + (dec_model_path, dec_modules), + ]: + if model_path is not None: + if os.path.isfile(model_path): + model_state_dict = get_trained_model_state_dict(model_path) + + modules = filter_modules(model_state_dict, modules) + + partial_state_dict = get_partial_state_dict(model_state_dict, modules) + + if partial_state_dict: + if transfer_verification( + main_state_dict, partial_state_dict, modules + ): + print_new_keys(partial_state_dict, modules, model_path) + main_state_dict.update(partial_state_dict) + else: + logging.warning( + f"modules {modules} in model {model_path} " + f"don't match your training config", + ) + else: + logging.warning("model was not found : %s", model_path) + + main_model.load_state_dict(main_state_dict) + + return main_model diff --git a/espnet/asr/pytorch_backend/asr_mix.py b/espnet/asr/pytorch_backend/asr_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9ce6d9110042c4226f310c86535e0c63eea04d --- /dev/null +++ b/espnet/asr/pytorch_backend/asr_mix.py @@ -0,0 +1,654 @@ +#!/usr/bin/env python3 + +""" +This script is used for multi-speaker speech recognition. + +Copyright 2017 Johns Hopkins University (Shinji Watanabe) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" +import json +import logging +import os + +# chainer related +from chainer import training +from chainer.training import extensions +from itertools import zip_longest as zip_longest +import numpy as np +from tensorboardX import SummaryWriter +import torch + +from espnet.asr.asr_mix_utils import add_results_to_json +from espnet.asr.asr_utils import adadelta_eps_decay + +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.pytorch_backend.asr import CustomEvaluator +from espnet.asr.pytorch_backend.asr import CustomUpdater +from espnet.asr.pytorch_backend.asr import load_trained_model +import espnet.lm.pytorch_backend.extlm as extlm_pytorch +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list +import espnet.nets.pytorch_backend.lm.default as lm_pytorch +from espnet.utils.dataset import ChainerDataLoader +from espnet.utils.dataset import TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import 
check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +import matplotlib + +matplotlib.use("Agg") + + +class CustomConverter(object): + """Custom batch converter for Pytorch. + + Args: + subsampling_factor (int): The subsampling factor. + dtype (torch.dtype): Data type to convert. + + """ + + def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkrs=2): + """Initialize the converter.""" + self.subsampling_factor = subsampling_factor + self.ignore_id = -1 + self.dtype = dtype + self.num_spkrs = num_spkrs + + def __call__(self, batch, device=torch.device("cpu")): + """Transform a batch and send it to a device. + + Args: + batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to transform. + device (torch.device): The device to send to. + + Returns: + tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch. + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0][0], batch[0][-self.num_spkrs :] + + # perform subsampling + if self.subsampling_factor > 1: + xs = [x[:: self.subsampling_factor, :] for x in xs] + + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs]) + + # perform padding and convert to tensor + # currently only support real number + if xs[0].dtype.kind == "c": + xs_pad_real = pad_list( + [torch.from_numpy(x.real).float() for x in xs], 0 + ).to(device, dtype=self.dtype) + xs_pad_imag = pad_list( + [torch.from_numpy(x.imag).float() for x in xs], 0 + ).to(device, dtype=self.dtype) + # Note(kamo): + # {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E. + # Don't create ComplexTensor and give it to E2E here + # because torch.nn.DataParallel can't handle it. + xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag} + else: + xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to( + device, dtype=self.dtype + ) + + ilens = torch.from_numpy(ilens).to(device) + if not isinstance(ys[0], np.ndarray): + ys_pad = [] + for i in range(len(ys)): # speakers + ys_pad += [torch.from_numpy(y).long() for y in ys[i]] + ys_pad = pad_list(ys_pad, self.ignore_id) + ys_pad = ( + ys_pad.view(self.num_spkrs, -1, ys_pad.size(1)) + .transpose(0, 1) + .to(device) + ) # (B, num_spkrs, Tmax) + else: + ys_pad = pad_list( + [torch.from_numpy(y).long() for y in ys], self.ignore_id + ).to(device) + + return xs_pad, ilens, ys_pad + + +def train(args): + """Train with the given args. + + Args: + args (namespace): The program arguments. 
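+
+    Note:
+        A minimal sketch of the ``--valid-json`` structure read below to infer
+        ``idim``/``odim`` (keys follow the code; the utterance id, the shapes, and
+        the second per-speaker output entry are hypothetical)::
+
+            {"utts": {"utt_0001": {"input":  [{"shape": [1066, 83]}],
+                                   "output": [{"shape": [21, 52]},
+                                              {"shape": [24, 52]}]}}}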
+ + """ + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + idim = int(valid_json[utts[0]]["input"][0]["shape"][-1]) + odim = int(valid_json[utts[0]]["output"][0]["shape"][-1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # specify attention, CTC, hybrid mode + if args.mtlalpha == 1.0: + mtl_mode = "ctc" + logging.info("Pure CTC mode") + elif args.mtlalpha == 0.0: + mtl_mode = "att" + logging.info("Pure attention mode") + else: + mtl_mode = "mtl" + logging.info("Multitask learning mode") + + # specify model architecture + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, ASRInterface) + subsampling_factor = model.subsample[0] + + if args.rnnlm is not None: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(args.char_list), + rnnlm_args.layer, + rnnlm_args.unit, + getattr(rnnlm_args, "embed_unit", None), # for backward compatibility + ) + ) + torch.load(args.rnnlm, rnnlm) + model.rnnlm = rnnlm + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + if args.train_dtype in ("float16", "float32", "float64"): + dtype = getattr(torch, args.train_dtype) + else: + dtype = torch.float32 + model = model.to(device=device, dtype=dtype) + + logging.warning( + "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adadelta": + optimizer = torch.optim.Adadelta( + model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "adam": + optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + optimizer = get_std_opt( + model.parameters(), + args.adim, + args.transformer_warmup_steps, + args.transformer_lr, + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # setup apex.amp + if args.train_dtype in ("O0", "O1", "O2", "O3"): + try: + from apex import amp + except ImportError as e: + logging.error( + f"You need to install apex for --train-dtype {args.train_dtype}. 
" + "See https://github.com/NVIDIA/apex#linux" + ) + raise e + if args.opt == "noam": + model, optimizer.optimizer = amp.initialize( + model, optimizer.optimizer, opt_level=args.train_dtype + ) + else: + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.train_dtype + ) + use_apex = True + else: + use_apex = False + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # Setup a converter + converter = CustomConverter( + subsampling_factor=subsampling_factor, dtype=dtype, num_spkrs=args.num_spkrs + ) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # make minibatch list (variable length) + train = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=-1, + ) + valid = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=-1, + ) + + load_tr = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + ) + load_cv = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + # default collate function converts numpy array to pytorch tensor + # we used an empty collate function instead which returns list + train_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), + batch_size=1, + num_workers=args.n_iter_processes, + shuffle=True, + collate_fn=lambda x: x[0], + ) + } + valid_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.n_iter_processes, + ) + } + + # Set up a trainer + updater = CustomUpdater( + model, + args.grad_clip, + train_iter, + optimizer, + device, + args.ngpu, + args.grad_noise, + args.accum_grad, + use_apex=use_apex, + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Resume from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # Evaluate the model with the test dataset for each epoch + trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu)) + + # Save attention weight each epoch + if args.num_save_attention > 0 and args.mtlalpha != 1.0: + data = sorted( + 
list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["input"][0]["shape"][1]), + reverse=True, + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + ) + trainer.extend(att_reporter, trigger=(1, "epoch")) + else: + att_reporter = None + + # Make a plot for training and validation values + trainer.extend( + extensions.PlotReport( + [ + "main/loss", + "validation/main/loss", + "main/loss_ctc", + "validation/main/loss_ctc", + "main/loss_att", + "validation/main/loss_att", + ], + "epoch", + file_name="loss.png", + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png" + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png" + ) + ) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger("validation/main/loss"), + ) + if mtl_mode != "ctc": + trainer.extend( + snapshot_object(model, "model.acc.best"), + trigger=training.triggers.MaxValueTrigger("validation/main/acc"), + ) + + # save snapshot which contains model and optimizer states + trainer.extend(torch_snapshot(), trigger=(1, "epoch")) + + # epsilon decay in the optimizer + if args.opt == "adadelta": + if args.criterion == "acc" and mtl_mode != "ctc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend( + extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) + ) + report_keys = [ + "epoch", + "iteration", + "main/loss", + "main/loss_ctc", + "main/loss_att", + "validation/main/loss", + "validation/main/loss_ctc", + "validation/main/loss_att", + "main/acc", + "validation/main/acc", + "main/cer_ctc", + "validation/main/cer_ctc", + "elapsed_time", + ] + if args.opt == "adadelta": + trainer.extend( + extensions.observe_value( + "eps", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "eps" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("eps") + if args.report_cer: + report_keys.append("validation/main/cer") + if args.report_wer: + report_keys.append("validation/main/wer") + trainer.extend( + extensions.PrintReport(report_keys), + trigger=(args.report_interval_iters, "iteration"), + ) + + 
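+    # NOTE: the Adadelta branch above implements a simple reduce-on-plateau
+    # schedule: when the monitored validation metric stops improving,
+    # restore_snapshot() rolls the model back to the best checkpoint and
+    # adadelta_eps_decay(args.eps_decay) scales the optimizer's eps.
+    # Rough sketch of the equivalent update (illustrative only, not executed here):
+    #     for group in trainer.updater.get_optimizer("main").param_groups:
+    #         group["eps"] *= args.eps_decay
+    # The observe_value() extension above then adds the current eps to the
+    # report, so the decay can be followed in the training log.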
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + set_early_stop(trainer, args) + + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + trainer.extend( + TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), + trigger=(args.report_interval_iters, "iteration"), + ) + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +def recog(args): + """Decode with the given args. + + Args: + args (namespace): The program arguments. + + """ + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + assert isinstance(model, ASRInterface) + model.recog_args = args + + # read rnnlm + if args.rnnlm: + rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + if getattr(rnnlm_args, "model_module", "default") != "default": + raise ValueError( + "use '--api v2' option to decode with non-default language model" + ) + rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM( + len(train_args.char_list), + rnnlm_args.layer, + rnnlm_args.unit, + getattr(rnnlm_args, "embed_unit", None), # for backward compatibility + ) + ) + torch_load(args.rnnlm, rnnlm) + rnnlm.eval() + else: + rnnlm = None + + if args.word_rnnlm: + rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf) + word_dict = rnnlm_args.char_list_dict + char_dict = {x: i for i, x in enumerate(train_args.char_list)} + word_rnnlm = lm_pytorch.ClassifierWithState( + lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit) + ) + torch_load(args.word_rnnlm, word_rnnlm) + word_rnnlm.eval() + + if rnnlm is not None: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.MultiLevelLM( + word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict + ) + ) + else: + rnnlm = lm_pytorch.ClassifierWithState( + extlm_pytorch.LookAheadWordLM( + word_rnnlm.predictor, word_dict, char_dict + ) + ) + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info("gpu id: " + str(gpu_id)) + model.cuda() + if rnnlm: + rnnlm.cuda() + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + if args.batchsize == 0: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm) + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + else: + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + # sort data if batchsize > 1 + keys = list(js.keys()) + if args.batchsize > 1: + feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + with torch.no_grad(): + for names in grouper(args.batchsize, keys, None): + names = [name for name in names if name] + batch = [(name, js[name]) for name in names] + feats = load_inputs_and_targets(batch)[0] + nbest_hyps = model.recognize_batch( + feats, args, train_args.char_list, rnnlm=rnnlm + ) + + for i, name in enumerate(names): + 
nbest_hyp = [hyp[i] for hyp in nbest_hyps] + new_js[name] = add_results_to_json( + js[name], nbest_hyp, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/asr/pytorch_backend/recog.py b/espnet/asr/pytorch_backend/recog.py new file mode 100644 index 0000000000000000000000000000000000000000..f4299dcf9241f24ca35c0714355d21e37afeaf56 --- /dev/null +++ b/espnet/asr/pytorch_backend/recog.py @@ -0,0 +1,152 @@ +"""V2 backend for `asr_recog.py` using py:class:`espnet.nets.beam_search.BeamSearch`.""" + +import json +import logging + +import torch + +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import torch_load +from espnet.asr.pytorch_backend.asr import load_trained_model +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.beam_search import BeamSearch +from espnet.nets.lm_interface import dynamic_import_lm +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.io_utils import LoadInputsAndTargets + + +def recog_v2(args): + """Decode with custom models that implements ScorerInterface. + + Notes: + The previous backend espnet.asr.pytorch_backend.asr.recog + only supports E2E and RNNLM + + Args: + args (namespace): The program arguments. + See py:func:`espnet.bin.asr_recog.get_parser` for details + + """ + logging.warning("experimental API for custom LMs is selected by --api v2") + if args.batchsize > 1: + raise NotImplementedError("multi-utt batch decoding is not implemented") + if args.streaming_mode is not None: + raise NotImplementedError("streaming mode is not implemented") + if args.word_rnnlm: + raise NotImplementedError("word LM is not implemented") + + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + assert isinstance(model, ASRInterface) + model.eval() + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + if args.rnnlm: + lm_args = get_model_conf(args.rnnlm, args.rnnlm_conf) + # NOTE: for a compatibility with less than 0.5.0 version models + lm_model_module = getattr(lm_args, "model_module", "default") + lm_class = dynamic_import_lm(lm_model_module, lm_args.backend) + lm = lm_class(len(train_args.char_list), lm_args) + torch_load(args.rnnlm, lm) + lm.eval() + else: + lm = None + + if args.ngram_model: + from espnet.nets.scorers.ngram import NgramFullScorer + from espnet.nets.scorers.ngram import NgramPartScorer + + if args.ngram_scorer == "full": + ngram = NgramFullScorer(args.ngram_model, train_args.char_list) + else: + ngram = NgramPartScorer(args.ngram_model, train_args.char_list) + else: + ngram = None + + scorers = model.scorers() + scorers["lm"] = lm + scorers["ngram"] = ngram + scorers["length_bonus"] = LengthBonus(len(train_args.char_list)) + weights = dict( + decoder=1.0 - args.ctc_weight, + ctc=args.ctc_weight, + lm=args.lm_weight, + ngram=args.ngram_weight, + length_bonus=args.penalty, + ) + beam_search = BeamSearch( + beam_size=args.beam_size, + 
vocab_size=len(train_args.char_list), + weights=weights, + scorers=scorers, + sos=model.sos, + eos=model.eos, + token_list=train_args.char_list, + pre_beam_score_key=None if args.ctc_weight == 1.0 else "full", + ) + # TODO(karita): make all scorers batchfied + if args.batchsize == 1: + non_batch = [ + k + for k, v in beam_search.full_scorers.items() + if not isinstance(v, BatchScorerInterface) + ] + if len(non_batch) == 0: + beam_search.__class__ = BatchBeamSearch + logging.info("BatchBeamSearch implementation is selected.") + else: + logging.warning( + f"As non-batch scorers {non_batch} are found, " + f"fall back to non-batch implementation." + ) + + if args.ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + if args.ngpu == 1: + device = "cuda" + else: + device = "cpu" + dtype = getattr(torch, args.dtype) + logging.info(f"Decoding device={device}, dtype={dtype}") + model.to(device=device, dtype=dtype).eval() + beam_search.to(device=device, dtype=dtype).eval() + + # read json data + with open(args.recog_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + enc = model.encode(torch.as_tensor(feat).to(device=device, dtype=dtype)) + nbest_hyps = beam_search( + x=enc, maxlenratio=args.maxlenratio, minlenratio=args.minlenratio + ) + nbest_hyps = [ + h.asdict() for h in nbest_hyps[: min(len(nbest_hyps), args.nbest)] + ] + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/bin/__init__.py b/espnet/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/bin/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/bin/asr_align.py b/espnet/bin/asr_align.py new file mode 100644 index 0000000000000000000000000000000000000000..b26a275cdcdf487a06f8b108fdf4e85d44f98a56 --- /dev/null +++ b/espnet/bin/asr_align.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2020 Johns Hopkins University (Xuankai Chang) +# 2020, Technische Universität München; Dominik Winkelbauer, Ludwig Kürzinger +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +This program performs CTC segmentation to align utterances within audio files. + +Inputs: + `--data-json`: + A json containing list of utterances and audio files + `--model`: + An already trained ASR model + +Output: + `--output`: + A plain `segments` file with utterance positions in the audio files. + +Selected parameters: + `--min-window-size`: + Minimum window size considered for a single utterance. The current default value + should be OK in most cases. Larger values might give better results; too large + values cause IndexErrors. + `--subsampling-factor`: + If the encoder sub-samples its input, the number of frames at the CTC layer is + reduced by this factor. + `--frame-duration`: + This is the non-overlapping duration of a single frame in milliseconds (the + inverse of frames per millisecond). + `--set-blank`: + In the rare case that the blank token has not the index 0 in the character + dictionary, this parameter sets the index of the blank token. 
+ `--gratis-blank`: + Sets the transition cost for blank tokens to zero. Useful if there are longer + unrelated segments between segments. + `--replace-spaces-with-blanks`: + Spaces are replaced with blanks. Helps to model pauses between words. May + increase length of ground truth. May lead to misaligned segments when combined + with the option `--gratis-blank`. +""" + +import configargparse +import logging +import os +import sys + +# imports for inference +from espnet.asr.pytorch_backend.asr_init import load_trained_model +from espnet.nets.asr_interface import ASRInterface +from espnet.utils.io_utils import LoadInputsAndTargets +import json +import torch + +# imports for CTC segmentation +from ctc_segmentation import ctc_segmentation +from ctc_segmentation import CtcSegmentationParameters +from ctc_segmentation import determine_utterance_segments +from ctc_segmentation import prepare_text + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="Align text to audio using CTC segmentation." + "using a pre-trained speech recognition model.", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="Decoding config file path.") + parser.add_argument( + "--ngpu", type=int, default=0, help="Number of GPUs (max. 1 is supported)" + ) + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", + ) + parser.add_argument( + "--backend", + type=str, + default="pytorch", + choices=["pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") + parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + # task related + parser.add_argument( + "--data-json", type=str, help="Json of recognition data for audio and text" + ) + parser.add_argument("--utt-text", type=str, help="Text separated into utterances") + # model (parameter) related + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + parser.add_argument( + "--num-encs", default=1, type=int, help="Number of encoders in the model." + ) + # ctc-segmentation related + parser.add_argument( + "--subsampling-factor", + type=int, + default=None, + help="Subsampling factor." + " If the encoder sub-samples its input, the number of frames at the CTC layer" + " is reduced by this factor. 
For example, a BLSTMP with subsampling 1_2_2_1_1" + " has a subsampling factor of 4.", + ) + parser.add_argument( + "--frame-duration", + type=int, + default=None, + help="Non-overlapping duration of a single frame in milliseconds.", + ) + parser.add_argument( + "--min-window-size", + type=int, + default=None, + help="Minimum window size considered for utterance.", + ) + parser.add_argument( + "--max-window-size", + type=int, + default=None, + help="Maximum window size considered for utterance.", + ) + parser.add_argument( + "--use-dict-blank", + type=int, + default=None, + help="DEPRECATED.", + ) + parser.add_argument( + "--set-blank", + type=int, + default=None, + help="Index of model dictionary for blank token (default: 0).", + ) + parser.add_argument( + "--gratis-blank", + type=int, + default=None, + help="Set the transition cost of the blank token to zero. Audio sections" + " labeled with blank tokens can then be skipped without penalty. Useful" + " if there are unrelated audio segments between utterances.", + ) + parser.add_argument( + "--replace-spaces-with-blanks", + type=int, + default=None, + help="Fill blanks in between words to better model pauses between words." + " Segments can be misaligned if this option is combined with --gratis-blank." + " May increase length of ground truth.", + ) + parser.add_argument( + "--scoring-length", + type=int, + default=None, + help="Changes partitioning length L for calculation of the confidence score.", + ) + parser.add_argument( + "--output", + type=configargparse.FileType("w"), + required=True, + help="Output segments file", + ) + return parser + + +def main(args): + """Run the main decoding function.""" + parser = get_parser() + args, extra = parser.parse_known_args(args) + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + if args.ngpu == 0 and args.dtype == "float16": + raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.") + # check CUDA_VISIBLE_DEVICES + device = "cpu" + if args.ngpu == 1: + device = "cuda" + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu > 1: + logging.error("Decoding only supports ngpu=1.") + sys.exit(1) + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + # recog + logging.info("backend = " + args.backend) + if args.backend == "pytorch": + ctc_align(args, device) + else: + raise ValueError("Only pytorch is supported.") + sys.exit(0) + + +def ctc_align(args, device): + """ESPnet-specific interface for CTC segmentation. + + Parses configuration, infers the CTC posterior probabilities, + and then aligns start and end of utterances using CTC segmentation. + Results are written to the output file given in the args. 
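+
+    Each output line follows the Kaldi-style ``segments`` convention produced at
+    the end of this function,
+    ``<utterance name> <audio name> <start time> <end time> <confidence>``,
+    for example (values are hypothetical)::
+
+        audio_a_0001 audio_a 0.26 4.12 1.000000000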
+ + :param args: given configuration + :param device: for inference; one of ['cuda', 'cpu'] + :return: 0 on success + """ + model, train_args = load_trained_model(args.model) + assert isinstance(model, ASRInterface) + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=True, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + logging.info(f"Decoding device={device}") + # Warn for nets with high memory consumption on long audio files + if hasattr(model, "enc"): + encoder_module = model.enc.__class__.__module__ + elif hasattr(model, "encoder"): + encoder_module = model.encoder.__class__.__module__ + else: + encoder_module = "Unknown" + logging.info(f"Encoder module: {encoder_module}") + logging.info(f"CTC module: {model.ctc.__class__.__module__}") + if "rnn" not in encoder_module: + logging.warning("No BLSTM model detected; memory consumption may be high.") + model.to(device=device).eval() + # read audio and text json data + with open(args.data_json, "rb") as f: + js = json.load(f)["utts"] + with open(args.utt_text, "r", encoding="utf-8") as f: + lines = f.readlines() + i = 0 + text = {} + segment_names = {} + for name in js.keys(): + text_per_audio = [] + segment_names_per_audio = [] + while i < len(lines) and lines[i].startswith(name): + text_per_audio.append(lines[i][lines[i].find(" ") + 1 :]) + segment_names_per_audio.append(lines[i][: lines[i].find(" ")]) + i += 1 + text[name] = text_per_audio + segment_names[name] = segment_names_per_audio + # apply configuration + config = CtcSegmentationParameters() + if args.subsampling_factor is not None: + config.subsampling_factor = args.subsampling_factor + if args.frame_duration is not None: + config.frame_duration_ms = args.frame_duration + if args.min_window_size is not None: + config.min_window_size = args.min_window_size + if args.max_window_size is not None: + config.max_window_size = args.max_window_size + config.char_list = train_args.char_list + if args.use_dict_blank is not None: + logging.warning( + "The option --use-dict-blank is deprecated. If needed," + " use --set-blank instead." + ) + if args.set_blank is not None: + config.blank = args.set_blank + if args.replace_spaces_with_blanks is not None: + if args.replace_spaces_with_blanks: + config.replace_spaces_with_blanks = True + else: + config.replace_spaces_with_blanks = False + if args.gratis_blank: + config.blank_transition_cost_zero = True + if config.blank_transition_cost_zero and args.replace_spaces_with_blanks: + logging.error( + "Blanks are inserted between words, and also the transition cost of blank" + " is zero. This configuration may lead to misalignments!" 
+ ) + if args.scoring_length is not None: + config.score_min_mean_over_L = args.scoring_length + logging.info( + f"Frame timings: {config.frame_duration_ms}ms * {config.subsampling_factor}" + ) + # Iterate over audio files to decode and align + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) Aligning " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat, label = load_inputs_and_targets(batch) + feat = feat[0] + with torch.no_grad(): + # Encode input frames + enc_output = model.encode(torch.as_tensor(feat).to(device)).unsqueeze(0) + # Apply ctc layer to obtain log character probabilities + lpz = model.ctc.log_softmax(enc_output)[0].cpu().numpy() + # Prepare the text for aligning + ground_truth_mat, utt_begin_indices = prepare_text(config, text[name]) + # Align using CTC segmentation + timings, char_probs, state_list = ctc_segmentation( + config, lpz, ground_truth_mat + ) + logging.debug(f"state_list = {state_list}") + # Obtain list of utterances with time intervals and confidence score + segments = determine_utterance_segments( + config, utt_begin_indices, char_probs, timings, text[name] + ) + # Write to "segments" file + for i, boundary in enumerate(segments): + utt_segment = ( + f"{segment_names[name][i]} {name} {boundary[0]:.2f}" + f" {boundary[1]:.2f} {boundary[2]:.9f}\n" + ) + args.output.write(utt_segment) + return 0 + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/asr_enhance.py b/espnet/bin/asr_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..98f0d693caa47fd752c0a9cd8e577dacf47f74b9 --- /dev/null +++ b/espnet/bin/asr_enhance.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +import configargparse +from distutils.util import strtobool +import logging +import os +import random +import sys + +import numpy as np + +from espnet.asr.pytorch_backend.asr import enhance + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + parser = configargparse.ArgumentParser( + description="Enhance noisy speech for speech recognition", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs") + parser.add_argument( + "--backend", + default="chainer", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--verbose", "-V", default=1, type=int, help="Verbose option") + parser.add_argument( + "--batchsize", + default=1, + type=int, + help="Batch size for beam search (0: means no batch processing)", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + # task related + parser.add_argument( + "--recog-json", type=str, help="Filename of recognition data (json)" + ) + # model (parameter) related + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + 
parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + + # Outputs configuration + parser.add_argument( + "--enh-wspecifier", + type=str, + default=None, + help="Specify the output way for enhanced speech." + "e.g. ark,scp:outdir,wav.scp", + ) + parser.add_argument( + "--enh-filetype", + type=str, + default="sound", + choices=["mat", "hdf5", "sound.hdf5", "sound"], + help="Specify the file format for enhanced speech. " + '"mat" is the matrix format in kaldi', + ) + parser.add_argument("--fs", type=int, default=16000, help="The sample frequency") + parser.add_argument( + "--keep-length", + type=strtobool, + default=True, + help="Adjust the output length to match " "with the input for enhanced speech", + ) + parser.add_argument( + "--image-dir", type=str, default=None, help="The directory saving the images." + ) + parser.add_argument( + "--num-images", + type=int, + default=20, + help="The number of images files to be saved. " + "If negative, all samples are to be saved.", + ) + + # IStft + parser.add_argument( + "--apply-istft", + type=strtobool, + default=True, + help="Apply istft to the output from the network", + ) + parser.add_argument( + "--istft-win-length", + type=int, + default=512, + help="The window length for istft. " + "This option is ignored " + "if stft is found in the preprocess-conf", + ) + parser.add_argument( + "--istft-n-shift", + type=str, + default=256, + help="The window type for istft. " + "This option is ignored " + "if stft is found in the preprocess-conf", + ) + parser.add_argument( + "--istft-window", + type=str, + default="hann", + help="The window type for istft. " + "This option is ignored " + "if stft is found in the preprocess-conf", + ) + return parser + + +def main(args): + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(kamo): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info("set random seed = %d" % args.seed) + + # recog + logging.info("backend = " + args.backend) + if args.backend == "pytorch": + enhance(args) + else: + raise ValueError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/asr_recog.py b/espnet/bin/asr_recog.py new file mode 100644 index 0000000000000000000000000000000000000000..dc7c64a76f187ff3b132076fc102e9bac67e311f --- /dev/null +++ b/espnet/bin/asr_recog.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 
(http://www.apache.org/licenses/LICENSE-2.0) + +"""End-to-end speech recognition model decoding script.""" + +import configargparse +import logging +import os +import random +import sys + +import numpy as np + +from espnet.utils.cli_utils import strtobool + +# NOTE: you need this func to generate our sphinx doc + + +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="Transcribe text from speech using " + "a speech recognition model on one CPU or GPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="Config file path") + parser.add( + "--config2", + is_config_file=True, + help="Second config file path that overwrites the settings in `--config`", + ) + parser.add( + "--config3", + is_config_file=True, + help="Third config file path that overwrites the settings " + "in `--config` and `--config2`", + ) + + parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", + ) + parser.add_argument( + "--backend", + type=str, + default="chainer", + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") + parser.add_argument( + "--batchsize", + type=int, + default=1, + help="Batch size for beam search (0: means no batch processing)", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "--api", + default="v1", + choices=["v1", "v2"], + help="Beam search APIs " + "v1: Default API. It only supports the ASRInterface.recognize method " + "and DefaultRNNLM. " + "v2: Experimental API. It supports any models that implements ScorerInterface.", + ) + # task related + parser.add_argument( + "--recog-json", type=str, help="Filename of recognition data (json)" + ) + parser.add_argument( + "--result-label", + type=str, + required=True, + help="Filename of result label data (json)", + ) + # model (parameter) related + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + parser.add_argument( + "--num-spkrs", + type=int, + default=1, + choices=[1, 2], + help="Number of speakers in the speech", + ) + parser.add_argument( + "--num-encs", default=1, type=int, help="Number of encoders in the model." + ) + # search related + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=1, help="Beam size") + parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="""Input length ratio to obtain max output length. 
+                If maxlenratio=0.0 (default), it uses an end-detect function
+                to automatically find maximum hypothesis lengths""",
+    )
+    parser.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    parser.add_argument(
+        "--ctc-weight", type=float, default=0.0, help="CTC weight in joint decoding"
+    )
+    parser.add_argument(
+        "--weights-ctc-dec",
+        type=float,
+        action="append",
+        help="CTC weight assigned to each encoder during decoding. "
+        "[in multi-encoder mode only]",
+    )
+    parser.add_argument(
+        "--ctc-window-margin",
+        type=int,
+        default=0,
+        help="""Use CTC window with margin parameter to accelerate
+                CTC/attention decoding especially on GPU. Smaller margin
+                makes decoding faster, but may increase search errors.
+                If margin=0 (default), this function is disabled""",
+    )
+    # transducer related
+    parser.add_argument(
+        "--search-type",
+        type=str,
+        default="default",
+        choices=["default", "nsc", "tsd", "alsd"],
+        help="""Type of beam search implementation to use during inference.
+                Can be either: default beam search, n-step constrained beam search ("nsc"),
+                time-synchronous decoding ("tsd") or alignment-length synchronous decoding
+                ("alsd").
+                Additional associated parameters: "nstep" + "prefix-alpha" (for nsc),
+                "max-sym-exp" (for tsd) and "u-max" (for alsd)""",
+    )
+    parser.add_argument(
+        "--nstep",
+        type=int,
+        default=1,
+        help="Number of expansion steps allowed in NSC beam search.",
+    )
+    parser.add_argument(
+        "--prefix-alpha",
+        type=int,
+        default=2,
+        help="Length prefix difference allowed in NSC beam search.",
+    )
+    parser.add_argument(
+        "--max-sym-exp",
+        type=int,
+        default=2,
+        help="Number of symbol expansions allowed in TSD decoding.",
+    )
+    parser.add_argument(
+        "--u-max",
+        type=int,
+        default=400,
+        help="Length prefix difference allowed in ALSD beam search.",
+    )
+    parser.add_argument(
+        "--score-norm",
+        type=strtobool,
+        nargs="?",
+        default=True,
+        help="Normalize transducer scores by length",
+    )
+    # rnnlm related
+    parser.add_argument(
+        "--rnnlm", type=str, default=None, help="RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm", type=str, default=None, help="Word RNNLM model file to read"
+    )
+    parser.add_argument(
+        "--word-rnnlm-conf",
+        type=str,
+        default=None,
+        help="Word RNNLM model config file to read",
+    )
+    parser.add_argument("--word-dict", type=str, default=None, help="Word list to read")
+    parser.add_argument("--lm-weight", type=float, default=0.1, help="RNNLM weight")
+    # ngram related
+    parser.add_argument(
+        "--ngram-model", type=str, default=None, help="ngram model file to read"
+    )
+    parser.add_argument("--ngram-weight", type=float, default=0.1, help="ngram weight")
+    parser.add_argument(
+        "--ngram-scorer",
+        type=str,
+        default="part",
+        choices=("full", "part"),
+        help="""If the ngram is set as a partial scorer (similar to the CTC scorer),
+                it only scores the top-K hypotheses.
+                If it is set as a full scorer, it scores all hypotheses.
+                The decoding speed of the partial scorer is much faster than the full one""",
+    )
+    # streaming related
+    parser.add_argument(
+        "--streaming-mode",
+        type=str,
+        default=None,
+        choices=["window", "segment"],
+        help="""Use streaming recognizer for inference.
+ `--batchsize` must be set to 0 to enable this mode""", + ) + parser.add_argument("--streaming-window", type=int, default=10, help="Window size") + parser.add_argument( + "--streaming-min-blank-dur", + type=int, + default=10, + help="Minimum blank duration threshold", + ) + parser.add_argument( + "--streaming-onset-margin", type=int, default=1, help="Onset margin" + ) + parser.add_argument( + "--streaming-offset-margin", type=int, default=1, help="Offset margin" + ) + # non-autoregressive related + # Mask CTC related. See https://arxiv.org/abs/2005.08700 for the detail. + parser.add_argument( + "--maskctc-n-iterations", + type=int, + default=10, + help="Number of decoding iterations." + "For Mask CTC, set 0 to predict 1 mask/iter.", + ) + parser.add_argument( + "--maskctc-probability-threshold", + type=float, + default=0.999, + help="Threshold probability for CTC output", + ) + + return parser + + +def main(args): + """Run the main decoding function.""" + parser = get_parser() + args = parser.parse_args(args) + + if args.ngpu == 0 and args.dtype == "float16": + raise ValueError(f"--dtype {args.dtype} does not support the CPU backend.") + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(mn5k): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info("set random seed = %d" % args.seed) + + # validate rnn options + if args.rnnlm is not None and args.word_rnnlm is not None: + logging.error( + "It seems that both --rnnlm and --word-rnnlm are specified. " + "Please use either option." 
+ ) + sys.exit(1) + + # recog + logging.info("backend = " + args.backend) + if args.num_spkrs == 1: + if args.backend == "chainer": + from espnet.asr.chainer_backend.asr import recog + + recog(args) + elif args.backend == "pytorch": + if args.num_encs == 1: + # Experimental API that supports custom LMs + if args.api == "v2": + from espnet.asr.pytorch_backend.recog import recog_v2 + + recog_v2(args) + else: + from espnet.asr.pytorch_backend.asr import recog + + if args.dtype != "float32": + raise NotImplementedError( + f"`--dtype {args.dtype}` is only available with `--api v2`" + ) + recog(args) + else: + if args.api == "v2": + raise NotImplementedError( + f"--num-encs {args.num_encs} > 1 is not supported in --api v2" + ) + else: + from espnet.asr.pytorch_backend.asr import recog + + recog(args) + else: + raise ValueError("Only chainer and pytorch are supported.") + elif args.num_spkrs == 2: + if args.backend == "pytorch": + from espnet.asr.pytorch_backend.asr_mix import recog + + recog(args) + else: + raise ValueError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/asr_train.py b/espnet/bin/asr_train.py new file mode 100644 index 0000000000000000000000000000000000000000..89b7272fe8938a90dacab636e9b462f84f527769 --- /dev/null +++ b/espnet/bin/asr_train.py @@ -0,0 +1,644 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2017 Tomoki Hayashi (Nagoya University) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Automatic speech recognition model training script.""" + +import logging +import os +import random +import subprocess +import sys + +from distutils.version import LooseVersion + +import configargparse +import numpy as np +import torch + +from espnet import __version__ +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2") + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(parser=None, required=True): + """Get default arguments.""" + if parser is None: + parser = configargparse.ArgumentParser( + description="Train an automatic speech recognition (ASR) model on one CPU, " + "one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings in " + "`--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. If not given, use all visible devices", + ) + parser.add_argument( + "--train-dtype", + default="float32", + choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"], + help="Data type for training (only pytorch backend). " + "O0,O1,.. flags require apex. 
" + "See https://nvidia.github.io/apex/amp.html#opt-levels", + ) + parser.add_argument( + "--backend", + default="chainer", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument( + "--outdir", type=str, required=required, help="Output directory" + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--dict", required=required, help="Dictionary") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--debugdir", type=str, help="Output directory for debugging") + parser.add_argument( + "--resume", + "-r", + default="", + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument( + "--minibatches", + "-N", + type=int, + default="-1", + help="Process only N minibatches (for debug)", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log dir path", + ) + parser.add_argument( + "--report-interval-iters", + default=100, + type=int, + help="Report interval iterations", + ) + parser.add_argument( + "--save-interval-iters", + default=0, + type=int, + help="Save snapshot interval iterations", + ) + # task related + parser.add_argument( + "--train-json", + type=str, + default=None, + help="Filename of train label data (json)", + ) + parser.add_argument( + "--valid-json", + type=str, + default=None, + help="Filename of validation label data (json)", + ) + # network architecture + parser.add_argument( + "--model-module", + type=str, + default=None, + help="model defined module (default: espnet.nets.xxx_backend.e2e_asr:E2E)", + ) + # encoder + parser.add_argument( + "--num-encs", default=1, type=int, help="Number of encoders in the model." + ) + # loss related + parser.add_argument( + "--ctc_type", + default="warpctc", + type=str, + choices=["builtin", "warpctc", "gtnctc", "cudnnctc"], + help="Type of CTC implementation to calculate loss.", + ) + parser.add_argument( + "--mtlalpha", + default=0.5, + type=float, + help="Multitask learning coefficient, " + "alpha: alpha*ctc_loss + (1-alpha)*att_loss ", + ) + parser.add_argument( + "--lsm-weight", default=0.0, type=float, help="Label smoothing weight" + ) + # recognition options to compute CER/WER + parser.add_argument( + "--report-cer", + default=False, + action="store_true", + help="Compute CER on development set", + ) + parser.add_argument( + "--report-wer", + default=False, + action="store_true", + help="Compute WER on development set", + ) + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=4, help="Beam size") + parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + default=0.0, + type=float, + help="""Input length ratio to obtain max output length. 
+ If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths""", + ) + parser.add_argument( + "--minlenratio", + default=0.0, + type=float, + help="Input length ratio to obtain min output length", + ) + parser.add_argument( + "--ctc-weight", default=0.3, type=float, help="CTC weight in joint decoding" + ) + parser.add_argument( + "--rnnlm", type=str, default=None, help="RNNLM model file to read" + ) + parser.add_argument( + "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" + ) + parser.add_argument("--lm-weight", default=0.1, type=float, help="RNNLM weight.") + parser.add_argument("--sym-space", default="", type=str, help="Space symbol") + parser.add_argument("--sym-blank", default="", type=str, help="Blank symbol") + # minibatch related + parser.add_argument( + "--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batch-count", + default="auto", + choices=BATCH_COUNT_CHOICES, + help="How to count batch_size. " + "The default (auto) will find how to count by args.", + ) + parser.add_argument( + "--batch-size", + "--batch-seqs", + "-b", + default=0, + type=int, + help="Maximum seqs in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-bins", + default=0, + type=int, + help="Maximum bins in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-in", + default=0, + type=int, + help="Maximum input frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-out", + default=0, + type=int, + help="Maximum output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-inout", + default=0, + type=int, + help="Maximum input+output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--maxlen-in", + "--batch-seq-maxlen-in", + default=800, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the input sequence length > ML.", + ) + parser.add_argument( + "--maxlen-out", + "--batch-seq-maxlen-out", + default=150, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the output sequence length > ML", + ) + parser.add_argument( + "--n-iter-processes", + default=0, + type=int, + help="Number of processes of iterator", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + nargs="?", + help="The configuration file for the pre-processing", + ) + # optimization related + parser.add_argument( + "--opt", + default="adadelta", + type=str, + choices=["adadelta", "adam", "noam"], + help="Optimizer", + ) + parser.add_argument( + "--accum-grad", default=1, type=int, help="Number of gradient accumuration" + ) + parser.add_argument( + "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer" + ) + parser.add_argument( + "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon" + ) + parser.add_argument( + "--weight-decay", default=0.0, type=float, help="Weight decay ratio" + ) + parser.add_argument( + "--criterion", + default="acc", + type=str, + choices=["loss", "loss_eps_decay_only", "acc"], + help="Criterion to perform epsilon decay", + ) + parser.add_argument( + "--threshold", default=1e-4, type=float, help="Threshold to stop iteration" + ) + parser.add_argument( + "--epochs", "-e", default=30, type=int, help="Maximum number of epochs" + ) + parser.add_argument( + "--early-stop-criterion", + 
default="validation/main/acc", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs to wait without improvement " + "before stopping the training", + ) + parser.add_argument( + "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip" + ) + parser.add_argument( + "--num-save-attention", + default=3, + type=int, + help="Number of samples of attention to be saved", + ) + parser.add_argument( + "--num-save-ctc", + default=3, + type=int, + help="Number of samples of CTC probability to be saved", + ) + parser.add_argument( + "--grad-noise", + type=strtobool, + default=False, + help="The flag to switch to use noise injection to gradients during training", + ) + # asr_mix related + parser.add_argument( + "--num-spkrs", + default=1, + type=int, + choices=[1, 2], + help="Number of speakers in the speech.", + ) + # decoder related + parser.add_argument( + "--context-residual", + default=False, + type=strtobool, + nargs="?", + help="The flag to switch to use context vector residual in the decoder network", + ) + # finetuning related + parser.add_argument( + "--enc-init", + default=None, + type=str, + help="Pre-trained ASR model to initialize encoder.", + ) + parser.add_argument( + "--enc-init-mods", + default="enc.enc.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--dec-init", + default=None, + type=str, + help="Pre-trained ASR, MT or LM model to initialize decoder.", + ) + parser.add_argument( + "--dec-init-mods", + default="att.,dec.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of decoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--freeze-mods", + default=None, + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of modules to freeze, separated by a comma.", + ) + # front end related + parser.add_argument( + "--use-frontend", + type=strtobool, + default=False, + help="The flag to switch to use frontend system.", + ) + + # WPE related + parser.add_argument( + "--use-wpe", + type=strtobool, + default=False, + help="Apply Weighted Prediction Error", + ) + parser.add_argument( + "--wtype", + default="blstmp", + type=str, + choices=[ + "lstm", + "blstm", + "lstmp", + "blstmp", + "vgglstmp", + "vggblstmp", + "vgglstm", + "vggblstm", + "gru", + "bgru", + "grup", + "bgrup", + "vgggrup", + "vggbgrup", + "vgggru", + "vggbgru", + ], + help="Type of encoder network architecture " + "of the mask estimator for WPE. " + "", + ) + parser.add_argument("--wlayers", type=int, default=2, help="") + parser.add_argument("--wunits", type=int, default=300, help="") + parser.add_argument("--wprojs", type=int, default=300, help="") + parser.add_argument("--wdropout-rate", type=float, default=0.0, help="") + parser.add_argument("--wpe-taps", type=int, default=5, help="") + parser.add_argument("--wpe-delay", type=int, default=3, help="") + parser.add_argument( + "--use-dnn-mask-for-wpe", + type=strtobool, + default=False, + help="Use DNN to estimate the power spectrogram. 
" + "This option is experimental.", + ) + # Beamformer related + parser.add_argument("--use-beamformer", type=strtobool, default=True, help="") + parser.add_argument( + "--btype", + default="blstmp", + type=str, + choices=[ + "lstm", + "blstm", + "lstmp", + "blstmp", + "vgglstmp", + "vggblstmp", + "vgglstm", + "vggblstm", + "gru", + "bgru", + "grup", + "bgrup", + "vgggrup", + "vggbgrup", + "vgggru", + "vggbgru", + ], + help="Type of encoder network architecture " + "of the mask estimator for Beamformer.", + ) + parser.add_argument("--blayers", type=int, default=2, help="") + parser.add_argument("--bunits", type=int, default=300, help="") + parser.add_argument("--bprojs", type=int, default=300, help="") + parser.add_argument("--badim", type=int, default=320, help="") + parser.add_argument( + "--bnmask", + type=int, + default=2, + help="Number of beamforming masks, " "default is 2 for [speech, noise].", + ) + parser.add_argument( + "--ref-channel", + type=int, + default=-1, + help="The reference channel used for beamformer. " + "By default, the channel is estimated by DNN.", + ) + parser.add_argument("--bdropout-rate", type=float, default=0.0, help="") + # Feature transform: Normalization + parser.add_argument( + "--stats-file", + type=str, + default=None, + help="The stats file for the feature normalization", + ) + parser.add_argument( + "--apply-uttmvn", + type=strtobool, + default=True, + help="Apply utterance level mean " "variance normalization.", + ) + parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="") + parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="") + # Feature transform: Fbank + parser.add_argument( + "--fbank-fs", + type=int, + default=16000, + help="The sample frequency used for " "the mel-fbank creation.", + ) + parser.add_argument( + "--n-mels", type=int, default=80, help="The number of mel-frequency bins." + ) + parser.add_argument("--fbank-fmin", type=float, default=0.0, help="") + parser.add_argument("--fbank-fmax", type=float, default=None, help="") + return parser + + +def main(cmd_args): + """Run the main training function.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + if args.backend == "chainer" and args.train_dtype != "float32": + raise NotImplementedError( + f"chainer backend does not support --train-dtype {args.train_dtype}." + "Use --dtype float32." + ) + if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"): + raise ValueError( + f"--train-dtype {args.train_dtype} does not support the CPU backend." + ) + + from espnet.utils.dynamic_import import dynamic_import + + if args.model_module is None: + if args.num_spkrs == 1: + model_module = "espnet.nets." + args.backend + "_backend.e2e_asr:E2E" + else: + model_module = "espnet.nets." 
+ args.backend + "_backend.e2e_asr_mix:E2E" + else: + model_module = args.model_module + model_class = dynamic_import(model_module) + model_class.add_arguments(parser) + + args = parser.parse_args(cmd_args) + args.model_module = model_module + if "chainer_backend" in args.model_module: + args.backend = "chainer" + if "pytorch_backend" in args.model_module: + args.backend = "pytorch" + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + else: + if is_torch_1_2_plus and args.ngpu != 1: + logging.debug( + "There are some bugs with multi-GPU processing in PyTorch 1.2+" + + " (see https://github.com/pytorch/pytorch/issues/21108)" + ) + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # set random seed + logging.info("random seed = %d" % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + # load dictionary for debug log + if args.dict is not None: + with open(args.dict, "rb") as f: + dictionary = f.readlines() + char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary] + char_list.insert(0, "") + char_list.append("") + # for non-autoregressive maskctc model + if "maskctc" in args.model_module: + char_list.append("") + args.char_list = char_list + else: + args.char_list = None + + # train + logging.info("backend = " + args.backend) + + if args.num_spkrs == 1: + if args.backend == "chainer": + from espnet.asr.chainer_backend.asr import train + + train(args) + elif args.backend == "pytorch": + from espnet.asr.pytorch_backend.asr import train + + train(args) + else: + raise ValueError("Only chainer and pytorch are supported.") + else: + # FIXME(kamo): Support --model-module + if args.backend == "pytorch": + from espnet.asr.pytorch_backend.asr_mix import train + + train(args) + else: + raise ValueError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/lm_train.py b/espnet/bin/lm_train.py new file mode 100644 index 0000000000000000000000000000000000000000..39adf1878ea3c6e430cab9ef5885d71ca89f4ce9 --- /dev/null +++ b/espnet/bin/lm_train.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +# This code is ported from the following implementation written in Torch. 
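# Minimal sketch of the two-pass CLI parsing that asr_train.py's main() above
# relies on: a first parse_known_args() pass reads only the generic options,
# the model class named by --model-module (or derived from --backend) then
# registers its own options, and a strict parse_args() pass finally sees every
# flag. Plain argparse and a toy model class stand in for configargparse and
# espnet.utils.dynamic_import here; they are illustrative, not from the patch.
import argparse


class ToyE2E:
    """Stand-in for an espnet.nets.*_backend.e2e_asr:E2E model class."""

    @staticmethod
    def add_arguments(parser):
        # Model-specific options only become known once the class is resolved.
        parser.add_argument("--elayers", default=4, type=int)
        parser.add_argument("--eunits", default=320, type=int)
        return parser


def parse(cmd_args):
    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", default="pytorch", choices=["chainer", "pytorch"])
    parser.add_argument("--model-module", default=None)

    # Pass 1: tolerate unknown (model-specific) flags.
    args, _ = parser.parse_known_args(cmd_args)

    # Resolve the model class; the real scripts use dynamic_import(model_module).
    model_class = ToyE2E
    model_class.add_arguments(parser)

    # Pass 2: every option is now registered, so strict parsing succeeds.
    return parser.parse_args(cmd_args)


print(parse(["--backend", "pytorch", "--elayers", "6"]))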
+# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py + +"""Language model training script.""" + +import logging +import os +import random +import subprocess +import sys + +import configargparse +import numpy as np + +from espnet import __version__ +from espnet.nets.lm_interface import dynamic_import_lm +from espnet.optimizer.factory import dynamic_import_optimizer +from espnet.scheduler.scheduler import dynamic_import_scheduler + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(parser=None, required=True): + """Get parser.""" + if parser is None: + parser = configargparse.ArgumentParser( + description="Train a new language model on one CPU or one GPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. If not given, use all visible devices", + ) + parser.add_argument( + "--train-dtype", + default="float32", + choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"], + help="Data type for training (only pytorch backend). " + "O0,O1,.. flags require apex. " + "See https://nvidia.github.io/apex/amp.html#opt-levels", + ) + parser.add_argument( + "--backend", + default="chainer", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument( + "--outdir", type=str, required=required, help="Output directory" + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--dict", type=str, required=required, help="Dictionary") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument( + "--resume", + "-r", + default="", + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log dir path", + ) + parser.add_argument( + "--report-interval-iters", + default=100, + type=int, + help="Report interval iterations", + ) + # task related + parser.add_argument( + "--train-label", + type=str, + required=required, + help="Filename of train label data", + ) + parser.add_argument( + "--valid-label", + type=str, + required=required, + help="Filename of validation label data", + ) + parser.add_argument("--test-label", type=str, help="Filename of test label data") + parser.add_argument( + "--dump-hdf5-path", + type=str, + default=None, + help="Path to dump a preprocessed dataset as hdf5", + ) + # training configuration + parser.add_argument("--opt", default="sgd", type=str, help="Optimizer") + parser.add_argument( + "--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 
0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batchsize", + "-b", + type=int, + default=300, + help="Number of examples in each mini-batch", + ) + parser.add_argument( + "--accum-grad", type=int, default=1, help="Number of gradient accumueration" + ) + parser.add_argument( + "--epoch", + "-e", + type=int, + default=20, + help="Number of sweeps over the dataset to train", + ) + parser.add_argument( + "--early-stop-criterion", + default="validation/main/loss", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs " + "to wait without improvement before stopping the training", + ) + parser.add_argument( + "--schedulers", + default=None, + action="append", + type=lambda kv: kv.split("="), + help="optimizer schedulers, you can configure params like:" + " --" + ' e.g., "--schedulers lr=noam --lr-noam-warmup 1000".', + ) + parser.add_argument( + "--gradclip", + "-c", + type=float, + default=5, + help="Gradient norm threshold to clip", + ) + parser.add_argument( + "--maxlen", + type=int, + default=40, + help="Batch size is reduced if the input sequence > ML", + ) + parser.add_argument( + "--model-module", + type=str, + default="default", + help="model defined module " + "(default: espnet.nets.xxx_backend.lm.default:DefaultRNNLM)", + ) + return parser + + +def main(cmd_args): + """Train LM.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + if args.backend == "chainer" and args.train_dtype != "float32": + raise NotImplementedError( + f"chainer backend does not support --train-dtype {args.train_dtype}." + "Use --dtype float32." + ) + if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"): + raise ValueError( + f"--train-dtype {args.train_dtype} does not support the CPU backend." + ) + + # parse arguments dynamically + model_class = dynamic_import_lm(args.model_module, args.backend) + model_class.add_arguments(parser) + if args.schedulers is not None: + for k, v in args.schedulers: + scheduler_class = dynamic_import_scheduler(v) + scheduler_class.add_arguments(k, parser) + + opt_class = dynamic_import_optimizer(args.opt, args.backend) + opt_class.add_arguments(parser) + + args = parser.parse_args(cmd_args) + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. 
else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + args.ngpu = ngpu + else: + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + nseed = args.seed + random.seed(nseed) + np.random.seed(nseed) + + # load dictionary + with open(args.dict, "rb") as f: + dictionary = f.readlines() + char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary] + char_list.insert(0, "") + char_list.append("") + args.char_list_dict = {x: i for i, x in enumerate(char_list)} + args.n_vocab = len(char_list) + + # train + logging.info("backend = " + args.backend) + if args.backend == "chainer": + from espnet.lm.chainer_backend.lm import train + + train(args) + elif args.backend == "pytorch": + from espnet.lm.pytorch_backend.lm import train + + train(args) + else: + raise ValueError("Only chainer and pytorch are supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/mt_train.py b/espnet/bin/mt_train.py new file mode 100644 index 0000000000000000000000000000000000000000..7251617e0985565d8026ca8ab1d6c937c974289f --- /dev/null +++ b/espnet/bin/mt_train.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Neural machine translation model training script.""" + +import logging +import os +import random +import subprocess +import sys + +from distutils.version import LooseVersion + +import configargparse +import numpy as np +import torch + +from espnet import __version__ +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2") + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(parser=None, required=True): + """Get default arguments.""" + if parser is None: + parser = configargparse.ArgumentParser( + description="Train a neural machine translation (NMT) model on one CPU, " + "one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. If not given, use all visible devices", + ) + parser.add_argument( + "--train-dtype", + default="float32", + choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"], + help="Data type for training (only pytorch backend). " + "O0,O1,.. flags require apex. 
" + "See https://nvidia.github.io/apex/amp.html#opt-levels", + ) + parser.add_argument( + "--backend", + default="chainer", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument( + "--outdir", type=str, required=required, help="Output directory" + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument( + "--dict", required=required, help="Dictionary for source/target languages" + ) + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--debugdir", type=str, help="Output directory for debugging") + parser.add_argument( + "--resume", + "-r", + default="", + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument( + "--minibatches", + "-N", + type=int, + default="-1", + help="Process only N minibatches (for debug)", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log dir path", + ) + parser.add_argument( + "--report-interval-iters", + default=100, + type=int, + help="Report interval iterations", + ) + parser.add_argument( + "--save-interval-iters", + default=0, + type=int, + help="Save snapshot interval iterations", + ) + # task related + parser.add_argument( + "--train-json", + type=str, + default=None, + help="Filename of train label data (json)", + ) + parser.add_argument( + "--valid-json", + type=str, + default=None, + help="Filename of validation label data (json)", + ) + # network architecture + parser.add_argument( + "--model-module", + type=str, + default=None, + help="model defined module (default: espnet.nets.xxx_backend.e2e_mt:E2E)", + ) + # loss related + parser.add_argument( + "--lsm-weight", default=0.0, type=float, help="Label smoothing weight" + ) + # translations options to compute BLEU + parser.add_argument( + "--report-bleu", + default=True, + action="store_true", + help="Compute BLEU on development set", + ) + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=4, help="Beam size") + parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + default=0.0, + type=float, + help="""Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths""", + ) + parser.add_argument( + "--minlenratio", + default=0.0, + type=float, + help="Input length ratio to obtain min output length", + ) + parser.add_argument( + "--rnnlm", type=str, default=None, help="RNNLM model file to read" + ) + parser.add_argument( + "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" + ) + parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.") + parser.add_argument("--sym-space", default="", type=str, help="Space symbol") + parser.add_argument("--sym-blank", default="", type=str, help="Blank symbol") + # minibatch related + parser.add_argument( + "--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batch-count", + default="auto", + choices=BATCH_COUNT_CHOICES, + help="How to count batch_size. 
" + "The default (auto) will find how to count by args.", + ) + parser.add_argument( + "--batch-size", + "--batch-seqs", + "-b", + default=0, + type=int, + help="Maximum seqs in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-bins", + default=0, + type=int, + help="Maximum bins in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-in", + default=0, + type=int, + help="Maximum input frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-out", + default=0, + type=int, + help="Maximum output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-inout", + default=0, + type=int, + help="Maximum input+output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--maxlen-in", + "--batch-seq-maxlen-in", + default=100, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the input sequence length > ML.", + ) + parser.add_argument( + "--maxlen-out", + "--batch-seq-maxlen-out", + default=100, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the output sequence length > ML", + ) + parser.add_argument( + "--n-iter-processes", + default=0, + type=int, + help="Number of processes of iterator", + ) + # optimization related + parser.add_argument( + "--opt", + default="adadelta", + type=str, + choices=["adadelta", "adam", "noam"], + help="Optimizer", + ) + parser.add_argument( + "--accum-grad", default=1, type=int, help="Number of gradient accumuration" + ) + parser.add_argument( + "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer" + ) + parser.add_argument( + "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon" + ) + parser.add_argument( + "--lr", default=1e-3, type=float, help="Learning rate for optimizer" + ) + parser.add_argument( + "--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate" + ) + parser.add_argument( + "--weight-decay", default=0.0, type=float, help="Weight decay ratio" + ) + parser.add_argument( + "--criterion", + default="acc", + type=str, + choices=["loss", "acc"], + help="Criterion to perform epsilon decay", + ) + parser.add_argument( + "--threshold", default=1e-4, type=float, help="Threshold to stop iteration" + ) + parser.add_argument( + "--epochs", "-e", default=30, type=int, help="Maximum number of epochs" + ) + parser.add_argument( + "--early-stop-criterion", + default="validation/main/acc", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs to wait " + "without improvement before stopping the training", + ) + parser.add_argument( + "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip" + ) + parser.add_argument( + "--num-save-attention", + default=3, + type=int, + help="Number of samples of attention to be saved", + ) + # decoder related + parser.add_argument( + "--context-residual", + default=False, + type=strtobool, + nargs="?", + help="The flag to switch to use context vector residual in the decoder network", + ) + parser.add_argument( + "--tie-src-tgt-embedding", + default=False, + type=strtobool, + nargs="?", + help="Tie parameters of source embedding and target embedding.", + ) + parser.add_argument( + "--tie-classifier", + default=False, + type=strtobool, + nargs="?", + help="Tie parameters of target embedding and output projection layer.", + ) + # 
finetuning related + parser.add_argument( + "--enc-init", + default=None, + type=str, + nargs="?", + help="Pre-trained ASR model to initialize encoder.", + ) + parser.add_argument( + "--enc-init-mods", + default="enc.enc.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--dec-init", + default=None, + type=str, + nargs="?", + help="Pre-trained ASR, MT or LM model to initialize decoder.", + ) + parser.add_argument( + "--dec-init-mods", + default="att., dec.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of decoder modules to initialize, separated by a comma.", + ) + # multilingual related + parser.add_argument( + "--multilingual", + default=False, + type=strtobool, + help="Prepend target language ID to the source sentence. " + "Both source/target language IDs must be prepend in the pre-processing stage.", + ) + parser.add_argument( + "--replace-sos", + default=False, + type=strtobool, + help="Replace in the decoder with a target language ID " + "(the first token in the target sequence)", + ) + + return parser + + +def main(cmd_args): + """Run the main training function.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + if args.backend == "chainer" and args.train_dtype != "float32": + raise NotImplementedError( + f"chainer backend does not support --train-dtype {args.train_dtype}." + "Use --dtype float32." + ) + if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"): + raise ValueError( + f"--train-dtype {args.train_dtype} does not support the CPU backend." + ) + + from espnet.utils.dynamic_import import dynamic_import + + if args.model_module is None: + model_module = "espnet.nets." + args.backend + "_backend.e2e_mt:E2E" + else: + model_module = args.model_module + model_class = dynamic_import(model_module) + model_class.add_arguments(parser) + + args = parser.parse_args(cmd_args) + args.model_module = model_module + if "chainer_backend" in args.model_module: + args.backend = "chainer" + if "pytorch_backend" in args.model_module: + args.backend = "pytorch" + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. 
else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + args.ngpu = ngpu + else: + if is_torch_1_2_plus and args.ngpu != 1: + logging.debug( + "There are some bugs with multi-GPU processing in PyTorch 1.2+" + + " (see https://github.com/pytorch/pytorch/issues/21108)" + ) + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # set random seed + logging.info("random seed = %d" % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + # load dictionary for debug log + if args.dict is not None: + with open(args.dict, "rb") as f: + dictionary = f.readlines() + char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary] + char_list.insert(0, "") + char_list.append("") + args.char_list = char_list + else: + args.char_list = None + + # train + logging.info("backend = " + args.backend) + + if args.backend == "pytorch": + from espnet.mt.pytorch_backend.mt import train + + train(args) + else: + raise ValueError("Only pytorch are supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/mt_trans.py b/espnet/bin/mt_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..c229f16d79fcb32325337ad614229344328628fe --- /dev/null +++ b/espnet/bin/mt_trans.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Neural machine translation model decoding script.""" + +import configargparse +import logging +import os +import random +import sys + +import numpy as np + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="Translate text from speech " + "using a speech translation model on one CPU or GPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="Config file path") + parser.add( + "--config2", + is_config_file=True, + help="Second config file path that overwrites the settings in `--config`", + ) + parser.add( + "--config3", + is_config_file=True, + help="Third config file path " + "that overwrites the settings in `--config` and `--config2`", + ) + + parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", + ) + parser.add_argument( + "--backend", + type=str, + default="chainer", + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") + parser.add_argument( + "--batchsize", + type=int, + default=1, + help="Batch size for beam search (0: means no batch processing)", 
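# Sketch of the GPU-count fallback used by the main() functions above when
# --ngpu is omitted: prefer CUDA_VISIBLE_DEVICES, otherwise count the devices
# listed by `nvidia-smi -L`, otherwise run on CPU. The scripts above count
# lines on p.stderr; this sketch reads stdout, where `nvidia-smi -L` normally
# prints its "GPU n: ..." listing, so treat that difference as an assumption.
import os
import subprocess


def detect_ngpu():
    cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
    if cvd is not None:
        return len(cvd.split(","))
    try:
        p = subprocess.run(
            ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
    except FileNotFoundError:
        return 0  # nvidia-smi not installed -> no usable GPU
    return len(p.stdout.decode().strip().splitlines())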
+ ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "--api", + default="v1", + choices=["v1", "v2"], + help="Beam search APIs " + "v1: Default API. It only supports " + "the ASRInterface.recognize method and DefaultRNNLM. " + "v2: Experimental API. " + "It supports any models that implements ScorerInterface.", + ) + # task related + parser.add_argument( + "--trans-json", type=str, help="Filename of translation data (json)" + ) + parser.add_argument( + "--result-label", + type=str, + required=True, + help="Filename of result label data (json)", + ) + # model (parameter) related + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + # search related + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=1, help="Beam size") + parser.add_argument("--penalty", type=float, default=0.1, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + type=float, + default=3.0, + help="""Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths""", + ) + parser.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + # multilingual related + parser.add_argument( + "--tgt-lang", + default=False, + type=str, + help="target language ID (e.g., , , and etc.)", + ) + return parser + + +def main(args): + """Run the main decoding function.""" + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(mn5k): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info("set random seed = %d" % args.seed) + + # trans + logging.info("backend = " + args.backend) + if args.backend == "pytorch": + # Experimental API that supports custom LMs + from espnet.mt.pytorch_backend.mt import trans + + if args.dtype != "float32": + raise NotImplementedError( + f"`--dtype {args.dtype}` is only available with `--api v2`" + ) + trans(args) + else: + raise ValueError("Only pytorch are supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/st_train.py b/espnet/bin/st_train.py new file mode 100644 index 
0000000000000000000000000000000000000000..4398d6aaa0c68db56dd66dacf089cd10ee189465 --- /dev/null +++ b/espnet/bin/st_train.py @@ -0,0 +1,550 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""End-to-end speech translation model training script.""" + +from distutils.version import LooseVersion +import logging +import os +import random +import subprocess +import sys + +import configargparse +import numpy as np +import torch + +from espnet import __version__ +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2") + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(parser=None, required=True): + """Get default arguments.""" + if parser is None: + parser = configargparse.ArgumentParser( + description="Train a speech translation (ST) model on one CPU, " + "one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. If not given, use all visible devices", + ) + parser.add_argument( + "--train-dtype", + default="float32", + choices=["float16", "float32", "float64", "O0", "O1", "O2", "O3"], + help="Data type for training (only pytorch backend). " + "O0,O1,.. flags require apex. 
" + "See https://nvidia.github.io/apex/amp.html#opt-levels", + ) + parser.add_argument( + "--backend", + default="chainer", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument( + "--outdir", type=str, required=required, help="Output directory" + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--dict", required=required, help="Dictionary") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--debugdir", type=str, help="Output directory for debugging") + parser.add_argument( + "--resume", + "-r", + default="", + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument( + "--minibatches", + "-N", + type=int, + default="-1", + help="Process only N minibatches (for debug)", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log dir path", + ) + parser.add_argument( + "--report-interval-iters", + default=100, + type=int, + help="Report interval iterations", + ) + parser.add_argument( + "--save-interval-iters", + default=0, + type=int, + help="Save snapshot interval iterations", + ) + # task related + parser.add_argument( + "--train-json", + type=str, + default=None, + help="Filename of train label data (json)", + ) + parser.add_argument( + "--valid-json", + type=str, + default=None, + help="Filename of validation label data (json)", + ) + # network architecture + parser.add_argument( + "--model-module", + type=str, + default=None, + help="model defined module (default: espnet.nets.xxx_backend.e2e_st:E2E)", + ) + # loss related + parser.add_argument( + "--ctc_type", + default="warpctc", + type=str, + choices=["builtin", "warpctc", "gtnctc", "cudnnctc"], + help="Type of CTC implementation to calculate loss.", + ) + parser.add_argument( + "--mtlalpha", + default=0.0, + type=float, + help="Multitask learning coefficient, alpha: \ + alpha*ctc_loss + (1-alpha)*att_loss", + ) + parser.add_argument( + "--asr-weight", + default=0.0, + type=float, + help="Multitask learning coefficient for ASR task, weight: " + " asr_weight*(alpha*ctc_loss + (1-alpha)*att_loss)" + " + (1-asr_weight-mt_weight)*st_loss", + ) + parser.add_argument( + "--mt-weight", + default=0.0, + type=float, + help="Multitask learning coefficient for MT task, weight: \ + mt_weight*mt_loss + (1-mt_weight-asr_weight)*st_loss", + ) + parser.add_argument( + "--lsm-weight", default=0.0, type=float, help="Label smoothing weight" + ) + # recognition options to compute CER/WER + parser.add_argument( + "--report-cer", + default=False, + action="store_true", + help="Compute CER on development set", + ) + parser.add_argument( + "--report-wer", + default=False, + action="store_true", + help="Compute WER on development set", + ) + # translations options to compute BLEU + parser.add_argument( + "--report-bleu", + default=True, + action="store_true", + help="Compute BLEU on development set", + ) + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=4, help="Beam size") + parser.add_argument("--penalty", default=0.0, type=float, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + default=0.0, + type=float, + help="""Input length ratio to obtain max output length. 
+ If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths""", + ) + parser.add_argument( + "--minlenratio", + default=0.0, + type=float, + help="Input length ratio to obtain min output length", + ) + parser.add_argument( + "--rnnlm", type=str, default=None, help="RNNLM model file to read" + ) + parser.add_argument( + "--rnnlm-conf", type=str, default=None, help="RNNLM model config file to read" + ) + parser.add_argument("--lm-weight", default=0.0, type=float, help="RNNLM weight.") + parser.add_argument("--sym-space", default="", type=str, help="Space symbol") + parser.add_argument("--sym-blank", default="", type=str, help="Blank symbol") + # minibatch related + parser.add_argument( + "--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batch-count", + default="auto", + choices=BATCH_COUNT_CHOICES, + help="How to count batch_size. " + "The default (auto) will find how to count by args.", + ) + parser.add_argument( + "--batch-size", + "--batch-seqs", + "-b", + default=0, + type=int, + help="Maximum seqs in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-bins", + default=0, + type=int, + help="Maximum bins in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-in", + default=0, + type=int, + help="Maximum input frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-out", + default=0, + type=int, + help="Maximum output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-inout", + default=0, + type=int, + help="Maximum input+output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--maxlen-in", + "--batch-seq-maxlen-in", + default=800, + type=int, + metavar="ML", + help="When --batch-count=seq, batch size is reduced " + "if the input sequence length > ML.", + ) + parser.add_argument( + "--maxlen-out", + "--batch-seq-maxlen-out", + default=150, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the output sequence length > ML", + ) + parser.add_argument( + "--n-iter-processes", + default=0, + type=int, + help="Number of processes of iterator", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + nargs="?", + help="The configuration file for the pre-processing", + ) + # optimization related + parser.add_argument( + "--opt", + default="adadelta", + type=str, + choices=["adadelta", "adam", "noam"], + help="Optimizer", + ) + parser.add_argument( + "--accum-grad", default=1, type=int, help="Number of gradient accumuration" + ) + parser.add_argument( + "--eps", default=1e-8, type=float, help="Epsilon constant for optimizer" + ) + parser.add_argument( + "--eps-decay", default=0.01, type=float, help="Decaying ratio of epsilon" + ) + parser.add_argument( + "--lr", default=1e-3, type=float, help="Learning rate for optimizer" + ) + parser.add_argument( + "--lr-decay", default=1.0, type=float, help="Decaying ratio of learning rate" + ) + parser.add_argument( + "--weight-decay", default=0.0, type=float, help="Weight decay ratio" + ) + parser.add_argument( + "--criterion", + default="acc", + type=str, + choices=["loss", "acc"], + help="Criterion to perform epsilon decay", + ) + parser.add_argument( + "--threshold", default=1e-4, type=float, help="Threshold to stop iteration" + ) + parser.add_argument( + "--epochs", "-e", default=30, type=int, help="Maximum 
number of epochs" + ) + parser.add_argument( + "--early-stop-criterion", + default="validation/main/acc", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs to wait " + "without improvement before stopping the training", + ) + parser.add_argument( + "--grad-clip", default=5, type=float, help="Gradient norm threshold to clip" + ) + parser.add_argument( + "--num-save-attention", + default=3, + type=int, + help="Number of samples of attention to be saved", + ) + parser.add_argument( + "--num-save-ctc", + default=3, + type=int, + help="Number of samples of CTC probability to be saved", + ) + parser.add_argument( + "--grad-noise", + type=strtobool, + default=False, + help="The flag to switch to use noise injection to gradients during training", + ) + # speech translation related + parser.add_argument( + "--context-residual", + default=False, + type=strtobool, + nargs="?", + help="The flag to switch to use context vector residual in the decoder network", + ) + # finetuning related + parser.add_argument( + "--enc-init", + default=None, + type=str, + nargs="?", + help="Pre-trained ASR model to initialize encoder.", + ) + parser.add_argument( + "--enc-init-mods", + default="enc.enc.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--dec-init", + default=None, + type=str, + nargs="?", + help="Pre-trained ASR, MT or LM model to initialize decoder.", + ) + parser.add_argument( + "--dec-init-mods", + default="att., dec.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of decoder modules to initialize, separated by a comma.", + ) + # multilingual related + parser.add_argument( + "--multilingual", + default=False, + type=strtobool, + help="Prepend target language ID to the source sentence. " + " Both source/target language IDs must be prepend in the pre-processing stage.", + ) + parser.add_argument( + "--replace-sos", + default=False, + type=strtobool, + help="Replace in the decoder with a target language ID \ + (the first token in the target sequence)", + ) + # Feature transform: Normalization + parser.add_argument( + "--stats-file", + type=str, + default=None, + help="The stats file for the feature normalization", + ) + parser.add_argument( + "--apply-uttmvn", + type=strtobool, + default=True, + help="Apply utterance level mean " "variance normalization.", + ) + parser.add_argument("--uttmvn-norm-means", type=strtobool, default=True, help="") + parser.add_argument("--uttmvn-norm-vars", type=strtobool, default=False, help="") + # Feature transform: Fbank + parser.add_argument( + "--fbank-fs", + type=int, + default=16000, + help="The sample frequency used for " "the mel-fbank creation.", + ) + parser.add_argument( + "--n-mels", type=int, default=80, help="The number of mel-frequency bins." + ) + parser.add_argument("--fbank-fmin", type=float, default=0.0, help="") + parser.add_argument("--fbank-fmax", type=float, default=None, help="") + return parser + + +def main(cmd_args): + """Run the main training function.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + if args.backend == "chainer" and args.train_dtype != "float32": + raise NotImplementedError( + f"chainer backend does not support --train-dtype {args.train_dtype}." + "Use --dtype float32." 
+ ) + if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"): + raise ValueError( + f"--train-dtype {args.train_dtype} does not support the CPU backend." + ) + + from espnet.utils.dynamic_import import dynamic_import + + if args.model_module is None: + model_module = "espnet.nets." + args.backend + "_backend.e2e_st:E2E" + else: + model_module = args.model_module + model_class = dynamic_import(model_module) + model_class.add_arguments(parser) + + args = parser.parse_args(cmd_args) + args.model_module = model_module + if "chainer_backend" in args.model_module: + args.backend = "chainer" + if "pytorch_backend" in args.model_module: + args.backend = "pytorch" + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + args.ngpu = ngpu + else: + if is_torch_1_2_plus and args.ngpu != 1: + logging.debug( + "There are some bugs with multi-GPU processing in PyTorch 1.2+" + + " (see https://github.com/pytorch/pytorch/issues/21108)" + ) + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # set random seed + logging.info("random seed = %d" % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + # load dictionary for debug log + if args.dict is not None: + with open(args.dict, "rb") as f: + dictionary = f.readlines() + char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary] + char_list.insert(0, "") + char_list.append("") + args.char_list = char_list + else: + args.char_list = None + + # train + logging.info("backend = " + args.backend) + + if args.backend == "pytorch": + from espnet.st.pytorch_backend.st import train + + train(args) + else: + raise ValueError("Only pytorch are supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/st_trans.py b/espnet/bin/st_trans.py new file mode 100644 index 0000000000000000000000000000000000000000..227a10b0a222f95335d0db67119b55a39aff5ad4 --- /dev/null +++ b/espnet/bin/st_trans.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""End-to-end speech translation model decoding script.""" + +import logging +import os +import random +import sys + +import configargparse +import numpy as np + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get default arguments.""" + parser = configargparse.ArgumentParser( + description="Translate text from speech using a speech translation " + "model on one CPU or GPU", + 
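# Worked illustration of the multitask weighting described by the --mtlalpha,
# --asr-weight and --mt-weight help strings in st_train.py above. The loss
# values are invented; only the combination formula comes from the help text,
# and the model's internal bookkeeping may differ in detail.
def st_multitask_loss(st_loss, ctc_loss, att_loss, mt_loss,
                      mtlalpha=0.0, asr_weight=0.0, mt_weight=0.0):
    asr_loss = mtlalpha * ctc_loss + (1.0 - mtlalpha) * att_loss
    return (
        asr_weight * asr_loss
        + mt_weight * mt_loss
        + (1.0 - asr_weight - mt_weight) * st_loss
    )


# With the defaults (all weights 0.0) only the ST loss remains:
assert st_multitask_loss(2.0, 5.0, 3.0, 4.0) == 2.0
# With --asr-weight 0.3 --mtlalpha 0.5: 0.3 * (0.5*5 + 0.5*3) + 0.7 * 2 = 2.6
assert abs(st_multitask_loss(2.0, 5.0, 3.0, 4.0, 0.5, 0.3, 0.0) - 2.6) < 1e-9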
config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="Config file path") + parser.add( + "--config2", + is_config_file=True, + help="Second config file path that overwrites the settings in `--config`", + ) + parser.add( + "--config3", + is_config_file=True, + help="Third config file path that overwrites " + "the settings in `--config` and `--config2`", + ) + + parser.add_argument("--ngpu", type=int, default=0, help="Number of GPUs") + parser.add_argument( + "--dtype", + choices=("float16", "float32", "float64"), + default="float32", + help="Float precision (only available in --api v2)", + ) + parser.add_argument( + "--backend", + type=str, + default="chainer", + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", type=int, default=1, help="Debugmode") + parser.add_argument("--seed", type=int, default=1, help="Random seed") + parser.add_argument("--verbose", "-V", type=int, default=1, help="Verbose option") + parser.add_argument( + "--batchsize", + type=int, + default=1, + help="Batch size for beam search (0: means no batch processing)", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "--api", + default="v1", + choices=["v1", "v2"], + help="Beam search APIs " + "v1: Default API. " + "It only supports the ASRInterface.recognize method and DefaultRNNLM. " + "v2: Experimental API. " + "It supports any models that implements ScorerInterface.", + ) + # task related + parser.add_argument( + "--trans-json", type=str, help="Filename of translation data (json)" + ) + parser.add_argument( + "--result-label", + type=str, + required=True, + help="Filename of result label data (json)", + ) + # model (parameter) related + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + # search related + parser.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + parser.add_argument("--beam-size", type=int, default=1, help="Beam size") + parser.add_argument("--penalty", type=float, default=0.0, help="Incertion penalty") + parser.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="""Input length ratio to obtain max output length. 
+ If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths""", + ) + parser.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + # multilingual related + parser.add_argument( + "--tgt-lang", + default=False, + type=str, + help="target language ID (e.g., , , and etc.)", + ) + return parser + + +def main(args): + """Run the main decoding function.""" + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose == 1: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + elif args.verbose == 2: + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # TODO(mn5k): support of multiple GPUs + if args.ngpu > 1: + logging.error("The program only supports ngpu=1.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # seed setting + random.seed(args.seed) + np.random.seed(args.seed) + logging.info("set random seed = %d" % args.seed) + + # trans + logging.info("backend = " + args.backend) + if args.backend == "pytorch": + # Experimental API that supports custom LMs + from espnet.st.pytorch_backend.st import trans + + if args.dtype != "float32": + raise NotImplementedError( + f"`--dtype {args.dtype}` is only available with `--api v2`" + ) + trans(args) + else: + raise ValueError("Only pytorch are supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/tts_decode.py b/espnet/bin/tts_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..8c04b1024587e6c99458a37754f062d33ec381f3 --- /dev/null +++ b/espnet/bin/tts_decode.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""TTS decoding script.""" + +import configargparse +import logging +import os +import platform +import subprocess +import sys + +from espnet.utils.cli_utils import strtobool + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get parser of decoding arguments.""" + parser = configargparse.ArgumentParser( + description="Synthesize speech from text using a TTS model on one CPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites " + "the settings in `--config` and `--config2`.", + ) + + parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs") + parser.add_argument( + "--backend", + default="pytorch", 
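# Hypothetical usage sketch for the decoding entry points (st_trans.py and
# mt_trans.py above): because they expose get_parser(), their defaults can be
# inspected without running any decoding. Assumes this package is importable;
# the file paths are placeholders, not files from this patch.
from espnet.bin.st_trans import get_parser

args = get_parser().parse_args(
    [
        "--result-label", "exp/decode/result.json",      # placeholder path
        "--model", "exp/train/results/model.acc.best",   # placeholder path
    ]
)
print(args.backend, args.api, args.beam_size)  # -> chainer v1 1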
+ type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--out", type=str, required=True, help="Output filename") + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + # task related + parser.add_argument( + "--json", type=str, required=True, help="Filename of train label data (json)" + ) + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + # decoding related + parser.add_argument( + "--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding" + ) + parser.add_argument( + "--minlenratio", type=float, default=0, help="Minimum length ratio in decoding" + ) + parser.add_argument( + "--threshold", type=float, default=0.5, help="Threshold value in decoding" + ) + parser.add_argument( + "--use-att-constraint", + type=strtobool, + default=False, + help="Whether to use the attention constraint", + ) + parser.add_argument( + "--backward-window", + type=int, + default=1, + help="Backward window size in the attention constraint", + ) + parser.add_argument( + "--forward-window", + type=int, + default=3, + help="Forward window size in the attention constraint", + ) + parser.add_argument( + "--fastspeech-alpha", + type=float, + default=1.0, + help="Alpha to change the speed for FastSpeech", + ) + # save related + parser.add_argument( + "--save-durations", + default=False, + type=strtobool, + help="Whether to save durations converted from attentions", + ) + parser.add_argument( + "--save-focus-rates", + default=False, + type=strtobool, + help="Whether to save focus rates of attentions", + ) + return parser + + +def main(args): + """Run deocding.""" + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + # python 2 case + if platform.python_version_tuple()[0] == "2": + if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]): + cvd = subprocess.check_output( + ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)] + ).strip() + logging.info("CLSP: use gpu" + cvd) + os.environ["CUDA_VISIBLE_DEVICES"] = cvd + # python 3 case + else: + if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode(): + cvd = ( + subprocess.check_output( + ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)] + ) + .decode() + .strip() + ) + logging.info("CLSP: use gpu" + cvd) + os.environ["CUDA_VISIBLE_DEVICES"] = cvd + + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # extract + logging.info("backend = " + args.backend) + 
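+ # The parser accepts "chainer" for --backend, but only the pytorch backend
+ # implements TTS decoding; any other value is rejected below. A typical
+ # invocation (paths are illustrative) supplies the required --json, --model
+ # and --out options, e.g.:
+ #   tts_decode.py --json data.json --model model.loss.best --out feats.ark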
if args.backend == "pytorch": + from espnet.tts.pytorch_backend.tts import decode + + decode(args) + else: + raise NotImplementedError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/tts_train.py b/espnet/bin/tts_train.py new file mode 100644 index 0000000000000000000000000000000000000000..930f2583bb414327c0e0c946a7578318b48d41f4 --- /dev/null +++ b/espnet/bin/tts_train.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Text-to-speech model training script.""" + +import logging +import os +import random +import subprocess +import sys + +import configargparse +import numpy as np + +from espnet import __version__ +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get parser of training arguments.""" + parser = configargparse.ArgumentParser( + description="Train a new text-to-speech (TTS) model on one CPU, " + "one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites " + "the settings in `--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. 
If not given, use all visible devices", + ) + parser.add_argument( + "--backend", + default="pytorch", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--outdir", type=str, required=True, help="Output directory") + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument( + "--resume", + "-r", + default="", + type=str, + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument( + "--minibatches", + "-N", + type=int, + default="-1", + help="Process only N minibatches (for debug)", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log directory path", + ) + parser.add_argument( + "--eval-interval-epochs", default=1, type=int, help="Evaluation interval epochs" + ) + parser.add_argument( + "--save-interval-epochs", default=1, type=int, help="Save interval epochs" + ) + parser.add_argument( + "--report-interval-iters", + default=100, + type=int, + help="Report interval iterations", + ) + # task related + parser.add_argument( + "--train-json", type=str, required=True, help="Filename of training json" + ) + parser.add_argument( + "--valid-json", type=str, required=True, help="Filename of validation json" + ) + # network architecture + parser.add_argument( + "--model-module", + type=str, + default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2", + help="model defined module", + ) + # minibatch related + parser.add_argument( + "--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batch-sort-key", + default="shuffle", + type=str, + choices=["shuffle", "output", "input"], + nargs="?", + help='Batch sorting key. "shuffle" only work with --batch-count "seq".', + ) + parser.add_argument( + "--batch-count", + default="auto", + choices=BATCH_COUNT_CHOICES, + help="How to count batch_size. 
" + "The default (auto) will find how to count by args.", + ) + parser.add_argument( + "--batch-size", + "--batch-seqs", + "-b", + default=0, + type=int, + help="Maximum seqs in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-bins", + default=0, + type=int, + help="Maximum bins in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-in", + default=0, + type=int, + help="Maximum input frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-out", + default=0, + type=int, + help="Maximum output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-inout", + default=0, + type=int, + help="Maximum input+output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--maxlen-in", + "--batch-seq-maxlen-in", + default=100, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the input sequence length > ML.", + ) + parser.add_argument( + "--maxlen-out", + "--batch-seq-maxlen-out", + default=200, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the output sequence length > ML", + ) + parser.add_argument( + "--num-iter-processes", + default=0, + type=int, + help="Number of processes of iterator", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "--use-speaker-embedding", + default=False, + type=strtobool, + help="Whether to use speaker embedding", + ) + parser.add_argument( + "--use-second-target", + default=False, + type=strtobool, + help="Whether to use second target", + ) + # optimization related + parser.add_argument( + "--opt", default="adam", type=str, choices=["adam", "noam"], help="Optimizer" + ) + parser.add_argument( + "--accum-grad", default=1, type=int, help="Number of gradient accumuration" + ) + parser.add_argument( + "--lr", default=1e-3, type=float, help="Learning rate for optimizer" + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer") + parser.add_argument( + "--weight-decay", + default=1e-6, + type=float, + help="Weight decay coefficient for optimizer", + ) + parser.add_argument( + "--epochs", "-e", default=30, type=int, help="Number of maximum epochs" + ) + parser.add_argument( + "--early-stop-criterion", + default="validation/main/loss", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs to wait " + "without improvement before stopping the training", + ) + parser.add_argument( + "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip" + ) + parser.add_argument( + "--num-save-attention", + default=5, + type=int, + help="Number of samples of attention to be saved", + ) + parser.add_argument( + "--keep-all-data-on-mem", + default=False, + type=strtobool, + help="Whether to keep all data on memory", + ) + # finetuning related + parser.add_argument( + "--enc-init", + default=None, + type=str, + help="Pre-trained TTS model path to initialize encoder.", + ) + parser.add_argument( + "--enc-init-mods", + default="enc.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--dec-init", + default=None, + type=str, + help="Pre-trained TTS model path to initialize decoder.", + ) + 
parser.add_argument( + "--dec-init-mods", + default="dec.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of decoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--freeze-mods", + default=None, + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of modules to freeze (not to train), separated by a comma.", + ) + + return parser + + +def main(cmd_args): + """Run training.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + + from espnet.utils.dynamic_import import dynamic_import + + model_class = dynamic_import(args.model_module) + assert issubclass(model_class, TTSInterface) + model_class.add_arguments(parser) + args = parser.parse_args(cmd_args) + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + args.ngpu = ngpu + else: + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # set random seed + logging.info("random seed = %d" % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + if args.backend == "pytorch": + from espnet.tts.pytorch_backend.tts import train + + train(args) + else: + raise NotImplementedError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/vc_decode.py b/espnet/bin/vc_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..b45f5efde6827d5eb9a30010cfe65610ed27b175 --- /dev/null +++ b/espnet/bin/vc_decode.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Nagoya University (Wen-Chin Huang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""VC decoding script.""" + +import configargparse +import logging +import os +import platform +import subprocess +import sys + +from espnet.utils.cli_utils import strtobool + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get parser of decoding arguments.""" + parser = configargparse.ArgumentParser( + description="Converting speech using a VC model on one CPU", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument("--ngpu", default=0, type=int, help="Number of GPUs") + parser.add_argument( + 
"--backend", + default="pytorch", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument("--out", type=str, required=True, help="Output filename") + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + # task related + parser.add_argument( + "--json", type=str, required=True, help="Filename of train label data (json)" + ) + parser.add_argument( + "--model", type=str, required=True, help="Model file parameters to read" + ) + parser.add_argument( + "--model-conf", type=str, default=None, help="Model config file" + ) + # decoding related + parser.add_argument( + "--maxlenratio", type=float, default=5, help="Maximum length ratio in decoding" + ) + parser.add_argument( + "--minlenratio", type=float, default=0, help="Minimum length ratio in decoding" + ) + parser.add_argument( + "--threshold", type=float, default=0.5, help="Threshold value in decoding" + ) + parser.add_argument( + "--use-att-constraint", + type=strtobool, + default=False, + help="Whether to use the attention constraint", + ) + parser.add_argument( + "--backward-window", + type=int, + default=1, + help="Backward window size in the attention constraint", + ) + parser.add_argument( + "--forward-window", + type=int, + default=3, + help="Forward window size in the attention constraint", + ) + # save related + parser.add_argument( + "--save-durations", + default=False, + type=strtobool, + help="Whether to save durations converted from attentions", + ) + parser.add_argument( + "--save-focus-rates", + default=False, + type=strtobool, + help="Whether to save focus rates of attentions", + ) + return parser + + +def main(args): + """Run deocding.""" + parser = get_parser() + args = parser.parse_args(args) + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # check CUDA_VISIBLE_DEVICES + if args.ngpu > 0: + # python 2 case + if platform.python_version_tuple()[0] == "2": + if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]): + cvd = subprocess.check_output( + ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)] + ).strip() + logging.info("CLSP: use gpu" + cvd) + os.environ["CUDA_VISIBLE_DEVICES"] = cvd + # python 3 case + else: + if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]).decode(): + cvd = ( + subprocess.check_output( + ["/usr/local/bin/free-gpu", "-n", str(args.ngpu)] + ) + .decode() + .strip() + ) + logging.info("CLSP: use gpu" + cvd) + os.environ["CUDA_VISIBLE_DEVICES"] = cvd + + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is None: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + elif args.ngpu != len(cvd.split(",")): + logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.") + sys.exit(1) + + # display PYTHONPATH + logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)")) + + # extract + logging.info("backend = " + args.backend) + if args.backend == "pytorch": + from espnet.vc.pytorch_backend.vc import decode + + decode(args) 
+ else: + raise NotImplementedError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/bin/vc_train.py b/espnet/bin/vc_train.py new file mode 100644 index 0000000000000000000000000000000000000000..30ecfb6ac2f7adde70f28bbdd5a43c248e1c2101 --- /dev/null +++ b/espnet/bin/vc_train.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Nagoya University (Wen-Chin Huang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Voice conversion model training script.""" + +import logging +import os +import random +import subprocess +import sys + +import configargparse +import numpy as np + +from espnet import __version__ +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.training.batchfy import BATCH_COUNT_CHOICES + + +# NOTE: you need this func to generate our sphinx doc +def get_parser(): + """Get parser of training arguments.""" + parser = configargparse.ArgumentParser( + description="Train a new voice conversion (VC) model on one CPU, " + "one or multiple GPUs", + config_file_parser_class=configargparse.YAMLConfigFileParser, + formatter_class=configargparse.ArgumentDefaultsHelpFormatter, + ) + + # general configuration + parser.add("--config", is_config_file=True, help="config file path") + parser.add( + "--config2", + is_config_file=True, + help="second config file path that overwrites the settings in `--config`.", + ) + parser.add( + "--config3", + is_config_file=True, + help="third config file path that overwrites the settings " + "in `--config` and `--config2`.", + ) + + parser.add_argument( + "--ngpu", + default=None, + type=int, + help="Number of GPUs. If not given, use all visible devices", + ) + parser.add_argument( + "--backend", + default="pytorch", + type=str, + choices=["chainer", "pytorch"], + help="Backend library", + ) + parser.add_argument("--outdir", type=str, required=True, help="Output directory") + parser.add_argument("--debugmode", default=1, type=int, help="Debugmode") + parser.add_argument("--seed", default=1, type=int, help="Random seed") + parser.add_argument( + "--resume", + "-r", + default="", + type=str, + nargs="?", + help="Resume the training from snapshot", + ) + parser.add_argument( + "--minibatches", + "-N", + type=int, + default="-1", + help="Process only N minibatches (for debug)", + ) + parser.add_argument("--verbose", "-V", default=0, type=int, help="Verbose option") + parser.add_argument( + "--tensorboard-dir", + default=None, + type=str, + nargs="?", + help="Tensorboard log directory path", + ) + parser.add_argument( + "--eval-interval-epochs", + default=100, + type=int, + help="Evaluation interval epochs", + ) + parser.add_argument( + "--save-interval-epochs", default=1, type=int, help="Save interval epochs" + ) + parser.add_argument( + "--report-interval-iters", + default=10, + type=int, + help="Report interval iterations", + ) + # task related + parser.add_argument("--srcspk", type=str, help="Source speaker") + parser.add_argument("--trgspk", type=str, help="Target speaker") + parser.add_argument( + "--train-json", type=str, required=True, help="Filename of training json" + ) + parser.add_argument( + "--valid-json", type=str, required=True, help="Filename of validation json" + ) + + # network architecture + parser.add_argument( + "--model-module", + type=str, + default="espnet.nets.pytorch_backend.e2e_tts_tacotron2:Tacotron2", + help="model defined module", + ) + # minibatch related + parser.add_argument( + 
"--sortagrad", + default=0, + type=int, + nargs="?", + help="How many epochs to use sortagrad for. 0 = deactivated, -1 = all epochs", + ) + parser.add_argument( + "--batch-sort-key", + default="shuffle", + type=str, + choices=["shuffle", "output", "input"], + nargs="?", + help='Batch sorting key. "shuffle" only work with --batch-count "seq".', + ) + parser.add_argument( + "--batch-count", + default="auto", + choices=BATCH_COUNT_CHOICES, + help="How to count batch_size. " + "The default (auto) will find how to count by args.", + ) + parser.add_argument( + "--batch-size", + "--batch-seqs", + "-b", + default=0, + type=int, + help="Maximum seqs in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-bins", + default=0, + type=int, + help="Maximum bins in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-in", + default=0, + type=int, + help="Maximum input frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-out", + default=0, + type=int, + help="Maximum output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--batch-frames-inout", + default=0, + type=int, + help="Maximum input+output frames in a minibatch (0 to disable)", + ) + parser.add_argument( + "--maxlen-in", + "--batch-seq-maxlen-in", + default=100, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the input sequence length > ML.", + ) + parser.add_argument( + "--maxlen-out", + "--batch-seq-maxlen-out", + default=200, + type=int, + metavar="ML", + help="When --batch-count=seq, " + "batch size is reduced if the output sequence length > ML", + ) + parser.add_argument( + "--num-iter-processes", + default=0, + type=int, + help="Number of processes of iterator", + ) + parser.add_argument( + "--preprocess-conf", + type=str, + default=None, + help="The configuration file for the pre-processing", + ) + parser.add_argument( + "--use-speaker-embedding", + default=False, + type=strtobool, + help="Whether to use speaker embedding", + ) + parser.add_argument( + "--use-second-target", + default=False, + type=strtobool, + help="Whether to use second target", + ) + # optimization related + parser.add_argument( + "--opt", + default="adam", + type=str, + choices=["adam", "noam", "lamb"], + help="Optimizer", + ) + parser.add_argument( + "--accum-grad", default=1, type=int, help="Number of gradient accumuration" + ) + parser.add_argument( + "--lr", default=1e-3, type=float, help="Learning rate for optimizer" + ) + parser.add_argument("--eps", default=1e-6, type=float, help="Epsilon for optimizer") + parser.add_argument( + "--weight-decay", + default=1e-6, + type=float, + help="Weight decay coefficient for optimizer", + ) + parser.add_argument( + "--epochs", "-e", default=30, type=int, help="Number of maximum epochs" + ) + parser.add_argument( + "--early-stop-criterion", + default="validation/main/loss", + type=str, + nargs="?", + help="Value to monitor to trigger an early stopping of the training", + ) + parser.add_argument( + "--patience", + default=3, + type=int, + nargs="?", + help="Number of epochs to wait without improvement " + "before stopping the training", + ) + parser.add_argument( + "--grad-clip", default=1, type=float, help="Gradient norm threshold to clip" + ) + parser.add_argument( + "--num-save-attention", + default=5, + type=int, + help="Number of samples of attention to be saved", + ) + parser.add_argument( + "--keep-all-data-on-mem", + default=False, + type=strtobool, + help="Whether to keep all data on memory", 
+ ) + + parser.add_argument( + "--enc-init", + default=None, + type=str, + help="Pre-trained model path to initialize encoder.", + ) + parser.add_argument( + "--enc-init-mods", + default="enc.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of encoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--dec-init", + default=None, + type=str, + help="Pre-trained model path to initialize decoder.", + ) + parser.add_argument( + "--dec-init-mods", + default="dec.", + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of decoder modules to initialize, separated by a comma.", + ) + parser.add_argument( + "--freeze-mods", + default=None, + type=lambda s: [str(mod) for mod in s.split(",") if s != ""], + help="List of modules to freeze (not to train), separated by a comma.", + ) + + return parser + + +def main(cmd_args): + """Run training.""" + parser = get_parser() + args, _ = parser.parse_known_args(cmd_args) + + from espnet.utils.dynamic_import import dynamic_import + + model_class = dynamic_import(args.model_module) + assert issubclass(model_class, TTSInterface) + model_class.add_arguments(parser) + args = parser.parse_args(cmd_args) + + # add version info in args + args.version = __version__ + + # logging info + if args.verbose > 0: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + logging.basicConfig( + level=logging.WARN, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + logging.warning("Skip DEBUG/INFO messages") + + # If --ngpu is not given, + # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices + # 2. if nvidia-smi exists, use all devices + # 3. else ngpu=0 + if args.ngpu is None: + cvd = os.environ.get("CUDA_VISIBLE_DEVICES") + if cvd is not None: + ngpu = len(cvd.split(",")) + else: + logging.warning("CUDA_VISIBLE_DEVICES is not set.") + try: + p = subprocess.run( + ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + except (subprocess.CalledProcessError, FileNotFoundError): + ngpu = 0 + else: + ngpu = len(p.stderr.decode().split("\n")) - 1 + else: + ngpu = args.ngpu + logging.info(f"ngpu: {ngpu}") + + # set random seed + logging.info("random seed = %d" % args.seed) + random.seed(args.seed) + np.random.seed(args.seed) + + if args.backend == "pytorch": + from espnet.vc.pytorch_backend.vc import train + + train(args) + else: + raise NotImplementedError("Only pytorch is supported.") + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/espnet/lm/__init__.py b/espnet/lm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/lm/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/lm/chainer_backend/__init__.py b/espnet/lm/chainer_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/lm/chainer_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/lm/chainer_backend/extlm.py b/espnet/lm/chainer_backend/extlm.py new file mode 100644 index 0000000000000000000000000000000000000000..711e878c1d8677eeb3937fd67ef113c68ed90c08 --- /dev/null +++ b/espnet/lm/chainer_backend/extlm.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori) +# Apache 2.0 
(http://www.apache.org/licenses/LICENSE-2.0) + + +import math + +import chainer +import chainer.functions as F +from espnet.lm.lm_utils import make_lexical_tree + + +# Definition of a multi-level (subword/word) language model +class MultiLevelLM(chainer.Chain): + logzero = -10000000000.0 + zero = 1.0e-10 + + def __init__( + self, + wordlm, + subwordlm, + word_dict, + subword_dict, + subwordlm_weight=0.8, + oov_penalty=1.0, + open_vocab=True, + ): + super(MultiLevelLM, self).__init__() + self.wordlm = wordlm + self.subwordlm = subwordlm + self.word_eos = word_dict[""] + self.word_unk = word_dict[""] + self.xp_word_eos = self.xp.full(1, self.word_eos, "i") + self.xp_word_unk = self.xp.full(1, self.word_unk, "i") + self.space = subword_dict[""] + self.eos = subword_dict[""] + self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk) + self.log_oov_penalty = math.log(oov_penalty) + self.open_vocab = open_vocab + self.subword_dict_size = len(subword_dict) + self.subwordlm_weight = subwordlm_weight + self.normalized = True + + def __call__(self, state, x): + # update state with input label x + if state is None: # make initial states and log-prob vectors + wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos) + wlm_logprobs = F.log_softmax(z_wlm).data + clm_state, z_clm = self.subwordlm(None, x) + log_y = F.log_softmax(z_clm).data * self.subwordlm_weight + new_node = self.lexroot + clm_logprob = 0.0 + xi = self.space + else: + clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state + xi = int(x) + if xi == self.space: # inter-word transition + if node is not None and node[1] >= 0: # check if the node is word end + w = self.xp.full(1, node[1], "i") + else: # this node is not a word end, which means + w = self.xp_word_unk + # update wordlm state and log-prob vector + wlm_state, z_wlm = self.wordlm(wlm_state, w) + wlm_logprobs = F.log_softmax(z_wlm).data + new_node = self.lexroot # move to the tree root + clm_logprob = 0.0 + elif node is not None and xi in node[0]: # intra-word transition + new_node = node[0][xi] + clm_logprob += log_y[0, xi] + elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode + new_node = None + clm_logprob += log_y[0, xi] + else: # if open_vocab flag is disabled, return 0 probabilities + log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f") + return (clm_state, wlm_state, None, log_y, 0.0), log_y + + clm_state, z_clm = self.subwordlm(clm_state, x) + log_y = F.log_softmax(z_clm).data * self.subwordlm_weight + + # apply word-level probabilies for and labels + if xi != self.space: + if new_node is not None and new_node[1] >= 0: # if new node is word end + wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob + else: + wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty + log_y[:, self.space] = wlm_logprob + log_y[:, self.eos] = wlm_logprob + else: + log_y[:, self.space] = self.logzero + log_y[:, self.eos] = self.logzero + + return (clm_state, wlm_state, wlm_logprobs, new_node, log_y, clm_logprob), log_y + + def final(self, state): + clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state + if node is not None and node[1] >= 0: # check if the node is word end + w = self.xp.full(1, node[1], "i") + else: # this node is not a word end, which means + w = self.xp_word_unk + wlm_state, z_wlm = self.wordlm(wlm_state, w) + return F.log_softmax(z_wlm).data[:, self.word_eos] + + +# Definition of a look-ahead word language model +class LookAheadWordLM(chainer.Chain): + logzero = -10000000000.0 + 
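+ # logzero stands in for log(0) when a transition is impossible; zero is a
+ # small probability floor used below to avoid taking the log of 0.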
zero = 1.0e-10 + + def __init__( + self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True + ): + super(LookAheadWordLM, self).__init__() + self.wordlm = wordlm + self.word_eos = word_dict[""] + self.word_unk = word_dict[""] + self.xp_word_eos = self.xp.full(1, self.word_eos, "i") + self.xp_word_unk = self.xp.full(1, self.word_unk, "i") + self.space = subword_dict[""] + self.eos = subword_dict[""] + self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk) + self.oov_penalty = oov_penalty + self.open_vocab = open_vocab + self.subword_dict_size = len(subword_dict) + self.normalized = True + + def __call__(self, state, x): + # update state with input label x + if state is None: # make initial states and cumlative probability vector + wlm_state, z_wlm = self.wordlm(None, self.xp_word_eos) + cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1) + new_node = self.lexroot + xi = self.space + else: + wlm_state, cumsum_probs, node = state + xi = int(x) + if xi == self.space: # inter-word transition + if node is not None and node[1] >= 0: # check if the node is word end + w = self.xp.full(1, node[1], "i") + else: # this node is not a word end, which means + w = self.xp_word_unk + # update wordlm state and cumlative probability vector + wlm_state, z_wlm = self.wordlm(wlm_state, w) + cumsum_probs = self.xp.cumsum(F.softmax(z_wlm).data, axis=1) + new_node = self.lexroot # move to the tree root + elif node is not None and xi in node[0]: # intra-word transition + new_node = node[0][xi] + elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode + new_node = None + else: # if open_vocab flag is disabled, return 0 probabilities + log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f") + return (wlm_state, None, None), log_y + + if new_node is not None: + succ, wid, wids = new_node + # compute parent node probability + sum_prob = ( + (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]) + if wids is not None + else 1.0 + ) + if sum_prob < self.zero: + log_y = self.xp.full((1, self.subword_dict_size), self.logzero, "f") + return (wlm_state, cumsum_probs, new_node), log_y + # set probability as a default value + unk_prob = ( + cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1] + ) + y = self.xp.full( + (1, self.subword_dict_size), unk_prob * self.oov_penalty, "f" + ) + # compute transition probabilities to child nodes + for cid, nd in succ.items(): + y[:, cid] = ( + cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]] + ) / sum_prob + # apply word-level probabilies for and labels + if wid >= 0: + wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob + y[:, self.space] = wlm_prob + y[:, self.eos] = wlm_prob + elif xi == self.space: + y[:, self.space] = self.zero + y[:, self.eos] = self.zero + log_y = self.xp.log( + self.xp.clip(y, self.zero, None) + ) # clip to avoid log(0) + else: # if no path in the tree, transition probability is one + log_y = self.xp.zeros((1, self.subword_dict_size), "f") + return (wlm_state, cumsum_probs, new_node), log_y + + def final(self, state): + wlm_state, cumsum_probs, node = state + if node is not None and node[1] >= 0: # check if the node is word end + w = self.xp.full(1, node[1], "i") + else: # this node is not a word end, which means + w = self.xp_word_unk + wlm_state, z_wlm = self.wordlm(wlm_state, w) + return F.log_softmax(z_wlm).data[:, self.word_eos] diff --git a/espnet/lm/chainer_backend/lm.py b/espnet/lm/chainer_backend/lm.py new file mode 100644 index 
0000000000000000000000000000000000000000..eb13f288b5a949fb23ecfd3b62810284a463ea8a --- /dev/null +++ b/espnet/lm/chainer_backend/lm.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +# This code is ported from the following implementation written in Torch. +# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py + + +import copy +import json +import logging +import numpy as np +import six + +import chainer +from chainer.dataset import convert +import chainer.functions as F +import chainer.links as L + +# for classifier link +from chainer.functions.loss import softmax_cross_entropy +from chainer import link +from chainer import reporter +from chainer import training +from chainer.training import extensions + +from espnet.lm.lm_utils import compute_perplexity +from espnet.lm.lm_utils import count_tokens +from espnet.lm.lm_utils import MakeSymlinkToBestModel +from espnet.lm.lm_utils import ParallelSentenceIterator +from espnet.lm.lm_utils import read_tokens + +import espnet.nets.chainer_backend.deterministic_embed_id as DL +from espnet.nets.lm_interface import LMInterface +from espnet.optimizer.factory import dynamic_import_optimizer +from espnet.scheduler.chainer import ChainerScheduler +from espnet.scheduler.scheduler import dynamic_import_scheduler + +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from tensorboardX import SummaryWriter + +from espnet.utils.deterministic_utils import set_deterministic_chainer +from espnet.utils.training.evaluator import BaseEvaluator +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + + +# TODO(karita): reimplement RNNLM with new interface +class DefaultRNNLM(LMInterface, link.Chain): + """Default RNNLM wrapper to compute reduce framewise loss values. + + Args: + n_vocab (int): The size of the vocabulary + args (argparse.Namespace): configurations. see `add_arguments` + """ + + @staticmethod + def add_arguments(parser): + parser.add_argument( + "--type", + type=str, + default="lstm", + nargs="?", + choices=["lstm", "gru"], + help="Which type of RNN to use", + ) + parser.add_argument( + "--layer", "-l", type=int, default=2, help="Number of hidden layers" + ) + parser.add_argument( + "--unit", "-u", type=int, default=650, help="Number of hidden units" + ) + return parser + + +class ClassifierWithState(link.Chain): + """A wrapper for a chainer RNNLM + + :param link.Chain predictor : The RNNLM + :param function lossfun: The loss function to use + :param int/str label_key: + """ + + def __init__( + self, + predictor, + lossfun=softmax_cross_entropy.softmax_cross_entropy, + label_key=-1, + ): + if not (isinstance(label_key, (int, str))): + raise TypeError("label_key must be int or str, but is %s" % type(label_key)) + + super(ClassifierWithState, self).__init__() + self.lossfun = lossfun + self.y = None + self.loss = None + self.label_key = label_key + + with self.init_scope(): + self.predictor = predictor + + def __call__(self, state, *args, **kwargs): + """Computes the loss value for an input and label pair. + + It also computes accuracy and stores it to the attribute. + When ``label_key`` is ``int``, the corresponding element in ``args`` + is treated as ground truth labels. And when it is ``str``, the + element in ``kwargs`` is used. 
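+ (With the default ``label_key=-1``, the last positional argument is taken
+ as the ground-truth label.)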
+ The all elements of ``args`` and ``kwargs`` except the groundtruth + labels are features. + It feeds features to the predictor and compare the result + with ground truth labels. + + :param state : The LM state + :param list[chainer.Variable] args : Input minibatch + :param dict[chainer.Variable] kwargs : Input minibatch + :return loss value + :rtype chainer.Variable + """ + + if isinstance(self.label_key, int): + if not (-len(args) <= self.label_key < len(args)): + msg = "Label key %d is out of bounds" % self.label_key + raise ValueError(msg) + t = args[self.label_key] + if self.label_key == -1: + args = args[:-1] + else: + args = args[: self.label_key] + args[self.label_key + 1 :] + elif isinstance(self.label_key, str): + if self.label_key not in kwargs: + msg = 'Label key "%s" is not found' % self.label_key + raise ValueError(msg) + t = kwargs[self.label_key] + del kwargs[self.label_key] + + self.y = None + self.loss = None + state, self.y = self.predictor(state, *args, **kwargs) + self.loss = self.lossfun(self.y, t) + return state, self.loss + + def predict(self, state, x): + """Predict log probabilities for given state and input x using the predictor + + :param state : the state + :param x : the input + :return a tuple (state, log prob vector) + :rtype cupy/numpy array + """ + if hasattr(self.predictor, "normalized") and self.predictor.normalized: + return self.predictor(state, x) + else: + state, z = self.predictor(state, x) + return state, F.log_softmax(z).data + + def final(self, state): + """Predict final log probabilities for given state using the predictor + + :param state : the state + :return log probability vector + :rtype cupy/numpy array + + """ + if hasattr(self.predictor, "final"): + return self.predictor.final(state) + else: + return 0.0 + + +# Definition of a recurrent net for language modeling +class RNNLM(chainer.Chain): + """A chainer RNNLM + + :param int n_vocab: The size of the vocabulary + :param int n_layers: The number of layers to create + :param int n_units: The number of units per layer + :param str type: The RNN type + """ + + def __init__(self, n_vocab, n_layers, n_units, typ="lstm"): + super(RNNLM, self).__init__() + with self.init_scope(): + self.embed = DL.EmbedID(n_vocab, n_units) + self.rnn = ( + chainer.ChainList( + *[L.StatelessLSTM(n_units, n_units) for _ in range(n_layers)] + ) + if typ == "lstm" + else chainer.ChainList( + *[L.StatelessGRU(n_units, n_units) for _ in range(n_layers)] + ) + ) + self.lo = L.Linear(n_units, n_vocab) + + for param in self.params(): + param.data[...] 
= np.random.uniform(-0.1, 0.1, param.data.shape) + self.n_layers = n_layers + self.n_units = n_units + self.typ = typ + + def __call__(self, state, x): + if state is None: + if self.typ == "lstm": + state = {"c": [None] * self.n_layers, "h": [None] * self.n_layers} + else: + state = {"h": [None] * self.n_layers} + + h = [None] * self.n_layers + emb = self.embed(x) + if self.typ == "lstm": + c = [None] * self.n_layers + c[0], h[0] = self.rnn[0](state["c"][0], state["h"][0], F.dropout(emb)) + for n in six.moves.range(1, self.n_layers): + c[n], h[n] = self.rnn[n]( + state["c"][n], state["h"][n], F.dropout(h[n - 1]) + ) + state = {"c": c, "h": h} + else: + if state["h"][0] is None: + xp = self.xp + with chainer.backends.cuda.get_device_from_id(self._device_id): + state["h"][0] = chainer.Variable( + xp.zeros((emb.shape[0], self.n_units), dtype=emb.dtype) + ) + h[0] = self.rnn[0](state["h"][0], F.dropout(emb)) + for n in six.moves.range(1, self.n_layers): + if state["h"][n] is None: + xp = self.xp + with chainer.backends.cuda.get_device_from_id(self._device_id): + state["h"][n] = chainer.Variable( + xp.zeros( + (h[n - 1].shape[0], self.n_units), dtype=h[n - 1].dtype + ) + ) + h[n] = self.rnn[n](state["h"][n], F.dropout(h[n - 1])) + state = {"h": h} + y = self.lo(F.dropout(h[-1])) + return state, y + + +class BPTTUpdater(training.updaters.StandardUpdater): + """An updater for a chainer LM + + :param chainer.dataset.Iterator train_iter : The train iterator + :param optimizer: + :param schedulers: + :param int device : The device id + :param int accum_grad : + """ + + def __init__(self, train_iter, optimizer, schedulers, device, accum_grad): + super(BPTTUpdater, self).__init__(train_iter, optimizer, device=device) + self.scheduler = ChainerScheduler(schedulers, optimizer) + self.accum_grad = accum_grad + + # The core part of the update routine can be customized by overriding. + def update_core(self): + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + + count = 0 + sum_loss = 0 + optimizer.target.cleargrads() # Clear the parameter gradients + for _ in range(self.accum_grad): + # Progress the dataset iterator for sentences at each iteration. 
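+ # Gradients from all accum_grad passes are accumulated and the summed loss
+ # is normalized by batch_size * accum_grad below, so the single
+ # optimizer.update() call behaves like one larger mini-batch.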
+ batch = train_iter.__next__() + x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1)) + # Concatenate the token IDs to matrices and send them to the device + # self.converter does this job + # (it is chainer.dataset.concat_examples by default) + xp = chainer.backends.cuda.get_array_module(x) + loss = 0 + state = None + batch_size, sequence_length = x.shape + for i in six.moves.range(sequence_length): + # Compute the loss at this time step and accumulate it + state, loss_batch = optimizer.target( + state, chainer.Variable(x[:, i]), chainer.Variable(t[:, i]) + ) + non_zeros = xp.count_nonzero(x[:, i]) + loss += loss_batch * non_zeros + count += int(non_zeros) + # backward + loss /= batch_size * self.accum_grad # normalized by batch size + sum_loss += float(loss.data) + loss.backward() # Backprop + loss.unchain_backward() # Truncate the graph + + reporter.report({"loss": sum_loss}, optimizer.target) + reporter.report({"count": count}, optimizer.target) + # update + optimizer.update() # Update the parameters + self.scheduler.step(self.iteration) + + +class LMEvaluator(BaseEvaluator): + """A custom evaluator for a chainer LM + + :param chainer.dataset.Iterator val_iter : The validation iterator + :param eval_model : The model to evaluate + :param int device : The device id to use + """ + + def __init__(self, val_iter, eval_model, device): + super(LMEvaluator, self).__init__(val_iter, eval_model, device=device) + + def evaluate(self): + val_iter = self.get_iterator("main") + target = self.get_target("main") + loss = 0 + count = 0 + for batch in copy.copy(val_iter): + x, t = convert.concat_examples(batch, device=self.device, padding=(0, -1)) + xp = chainer.backends.cuda.get_array_module(x) + state = None + for i in six.moves.range(len(x[0])): + state, loss_batch = target(state, x[:, i], t[:, i]) + non_zeros = xp.count_nonzero(x[:, i]) + loss += loss_batch.data * non_zeros + count += int(non_zeros) + # report validation loss + observation = {} + with reporter.report_scope(observation): + reporter.report({"loss": float(loss / count)}, target) + return observation + + +def train(args): + """Train with the given args + + :param Namespace args: The program arguments + """ + # TODO(karita): support this + if args.model_module != "default": + raise NotImplementedError("chainer backend does not support --model-module") + + # display chainer version + logging.info("chainer version = " + chainer.__version__) + + set_deterministic_chainer(args) + + # check cuda and cudnn availability + if not chainer.cuda.available: + logging.warning("cuda is not available") + if not chainer.cuda.cudnn_enabled: + logging.warning("cudnn is not available") + + # get special label ids + unk = args.char_list_dict[""] + eos = args.char_list_dict[""] + # read tokens as a sequence of sentences + train = read_tokens(args.train_label, args.char_list_dict) + val = read_tokens(args.valid_label, args.char_list_dict) + # count tokens + n_train_tokens, n_train_oovs = count_tokens(train, unk) + n_val_tokens, n_val_oovs = count_tokens(val, unk) + logging.info("#vocab = " + str(args.n_vocab)) + logging.info("#sentences in the training data = " + str(len(train))) + logging.info("#tokens in the training data = " + str(n_train_tokens)) + logging.info( + "oov rate in the training data = %.2f %%" + % (n_train_oovs / n_train_tokens * 100) + ) + logging.info("#sentences in the validation data = " + str(len(val))) + logging.info("#tokens in the validation data = " + str(n_val_tokens)) + logging.info( + "oov rate in the 
validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100) + ) + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + + # Create the dataset iterators + train_iter = ParallelSentenceIterator( + train, + args.batchsize, + max_length=args.maxlen, + sos=eos, + eos=eos, + shuffle=not use_sortagrad, + ) + val_iter = ParallelSentenceIterator( + val, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False + ) + epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad) + logging.info("#iterations per epoch = %d" % epoch_iters) + logging.info("#total iterations = " + str(args.epoch * epoch_iters)) + # Prepare an RNNLM model + rnn = RNNLM(args.n_vocab, args.layer, args.unit, args.type) + model = ClassifierWithState(rnn) + if args.ngpu > 1: + logging.warning("currently, multi-gpu is not supported. use single gpu.") + if args.ngpu > 0: + # Make the specified GPU current + gpu_id = 0 + chainer.cuda.get_device_from_id(gpu_id).use() + model.to_gpu() + else: + gpu_id = -1 + + # Save model conf to json + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode( + "utf_8" + ) + ) + + # Set up an optimizer + opt_class = dynamic_import_optimizer(args.opt, args.backend) + optimizer = opt_class.from_args(model, args) + if args.schedulers is None: + schedulers = [] + else: + schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers] + + optimizer.setup(model) + optimizer.add_hook(chainer.optimizer.GradientClipping(args.gradclip)) + + updater = BPTTUpdater(train_iter, optimizer, schedulers, gpu_id, args.accum_grad) + trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir) + trainer.extend(LMEvaluator(val_iter, model, device=gpu_id)) + trainer.extend( + extensions.LogReport( + postprocess=compute_perplexity, + trigger=(args.report_interval_iters, "iteration"), + ) + ) + trainer.extend( + extensions.PrintReport( + ["epoch", "iteration", "perplexity", "val_perplexity", "elapsed_time"] + ), + trigger=(args.report_interval_iters, "iteration"), + ) + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + trainer.extend(extensions.snapshot(filename="snapshot.ep.{.updater.epoch}")) + trainer.extend(extensions.snapshot_object(model, "rnnlm.model.{.updater.epoch}")) + # MEMO(Hori): wants to use MinValueTrigger, but it seems to fail in resuming + trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model")) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"), + ) + + if args.resume: + logging.info("resumed from %s" % args.resume) + chainer.serializers.load_npz(args.resume, trainer) + + set_early_stop(trainer, args, is_lm=True) + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend( + TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration") + ) + + trainer.run() + check_early_stop(trainer, args.epoch) + + # compute perplexity for test set + if args.test_label: + logging.info("test the best model") + chainer.serializers.load_npz(args.outdir + "/rnnlm.model.best", model) + test = read_tokens(args.test_label, args.char_list_dict) + n_test_tokens, n_test_oovs = count_tokens(test, unk) + logging.info("#sentences in the test data = " + 
str(len(test))) + logging.info("#tokens in the test data = " + str(n_test_tokens)) + logging.info( + "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100) + ) + test_iter = ParallelSentenceIterator( + test, args.batchsize, max_length=args.maxlen, sos=eos, eos=eos, repeat=False + ) + evaluator = LMEvaluator(test_iter, model, device=gpu_id) + with chainer.using_config("train", False): + result = evaluator() + logging.info("test perplexity: " + str(np.exp(float(result["main/loss"])))) diff --git a/espnet/lm/lm_utils.py b/espnet/lm/lm_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bb43e5de0e7ac83cf889d5e536d51c853058013e --- /dev/null +++ b/espnet/lm/lm_utils.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +# This code is ported from the following implementation written in Torch. +# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py + +import chainer +import h5py +import logging +import numpy as np +import os +import random +import six +from tqdm import tqdm + +from chainer.training import extension + + +def load_dataset(path, label_dict, outdir=None): + """Load and save HDF5 that contains a dataset and stats for LM + + Args: + path (str): The path of an input text dataset file + label_dict (dict[str, int]): + dictionary that maps token label string to its ID number + outdir (str): The path of an output dir + + Returns: + tuple[list[np.ndarray], int, int]: Tuple of + token IDs in np.int32 converted by `read_tokens` + the number of tokens by `count_tokens`, + and the number of OOVs by `count_tokens` + """ + if outdir is not None: + os.makedirs(outdir, exist_ok=True) + filename = outdir + "/" + os.path.basename(path) + ".h5" + if os.path.exists(filename): + logging.info(f"loading binary dataset: {filename}") + f = h5py.File(filename, "r") + return f["data"][:], f["n_tokens"][()], f["n_oovs"][()] + else: + logging.info("skip dump/load HDF5 because the output dir is not specified") + logging.info(f"reading text dataset: {path}") + ret = read_tokens(path, label_dict) + n_tokens, n_oovs = count_tokens(ret, label_dict[""]) + if outdir is not None: + logging.info(f"saving binary dataset: {filename}") + with h5py.File(filename, "w") as f: + # http://docs.h5py.org/en/stable/special.html#arbitrary-vlen-data + data = f.create_dataset( + "data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32) + ) + data[:] = ret + f["n_tokens"] = n_tokens + f["n_oovs"] = n_oovs + return ret, n_tokens, n_oovs + + +def read_tokens(filename, label_dict): + """Read tokens as a sequence of sentences + + :param str filename : The name of the input file + :param dict label_dict : dictionary that maps token label string to its ID number + :return list of ID sequences + :rtype list + """ + + data = [] + unk = label_dict[""] + for ln in tqdm(open(filename, "r", encoding="utf-8")): + data.append( + np.array( + [label_dict.get(label, unk) for label in ln.split()], dtype=np.int32 + ) + ) + return data + + +def count_tokens(data, unk_id=None): + """Count tokens and oovs in token ID sequences. 
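+ For example, ``count_tokens([np.array([3, 0, 5])], unk_id=0)`` returns
+ ``(3, 1)``: three tokens in total, one of which is the unknown id.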
+ + Args: + data (list[np.ndarray]): list of token ID sequences + unk_id (int): ID of unknown token + + Returns: + tuple: tuple of number of token occurrences and number of oov tokens + + """ + + n_tokens = 0 + n_oovs = 0 + for sentence in data: + n_tokens += len(sentence) + if unk_id is not None: + n_oovs += np.count_nonzero(sentence == unk_id) + return n_tokens, n_oovs + + +def compute_perplexity(result): + """Computes and add the perplexity to the LogReport + + :param dict result: The current observations + """ + # Routine to rewrite the result dictionary of LogReport to add perplexity values + result["perplexity"] = np.exp(result["main/loss"] / result["main/count"]) + if "validation/main/loss" in result: + result["val_perplexity"] = np.exp(result["validation/main/loss"]) + + +class ParallelSentenceIterator(chainer.dataset.Iterator): + """Dataset iterator to create a batch of sentences. + + This iterator returns a pair of sentences, where one token is shifted + between the sentences like ' w1 w2 w3' and 'w1 w2 w3 ' + Sentence batches are made in order of longer sentences, and then + randomly shuffled. + """ + + def __init__( + self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True + ): + self.dataset = dataset + self.batch_size = batch_size # batch size + # Number of completed sweeps over the dataset. In this case, it is + # incremented if every word is visited at least once after the last + # increment. + self.epoch = 0 + # True if the epoch is incremented at the last iteration. + self.is_new_epoch = False + self.repeat = repeat + length = len(dataset) + self.batch_indices = [] + # make mini-batches + if batch_size > 1: + indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i])) + bs = 0 + while bs < length: + be = min(bs + batch_size, length) + # batch size is automatically reduced if the sentence length + # is larger than max_length + if max_length > 0: + sent_length = len(dataset[indices[bs]]) + be = min( + be, bs + max(batch_size // (sent_length // max_length + 1), 1) + ) + self.batch_indices.append(np.array(indices[bs:be])) + bs = be + if shuffle: + # shuffle batches + random.shuffle(self.batch_indices) + else: + self.batch_indices = [np.array([i]) for i in six.moves.range(length)] + + # NOTE: this is not a count of parameter updates. It is just a count of + # calls of ``__next__``. + self.iteration = 0 + self.sos = sos + self.eos = eos + # use -1 instead of None internally + self._previous_epoch_detail = -1.0 + + def __next__(self): + # This iterator returns a list representing a mini-batch. Each item + # indicates a sentence pair like ' w1 w2 w3' and 'w1 w2 w3 ' + # represented by token IDs. + n_batches = len(self.batch_indices) + if not self.repeat and self.iteration >= n_batches: + # If not self.repeat, this iterator stops at the end of the first + # epoch (i.e., when all words are visited once). + raise StopIteration + + batch = [] + for idx in self.batch_indices[self.iteration % n_batches]: + batch.append( + ( + np.append([self.sos], self.dataset[idx]), + np.append(self.dataset[idx], [self.eos]), + ) + ) + + self._previous_epoch_detail = self.epoch_detail + self.iteration += 1 + + epoch = self.iteration // n_batches + self.is_new_epoch = self.epoch < epoch + if self.is_new_epoch: + self.epoch = epoch + + return batch + + def start_shuffle(self): + random.shuffle(self.batch_indices) + + @property + def epoch_detail(self): + # Floating point version of epoch. 
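+ # e.g. a value of 1.5 means the iterator is halfway through its second
+ # sweep over the batches.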
+ return self.iteration / len(self.batch_indices) + + @property + def previous_epoch_detail(self): + if self._previous_epoch_detail < 0: + return None + return self._previous_epoch_detail + + def serialize(self, serializer): + # It is important to serialize the state to be recovered on resume. + self.iteration = serializer("iteration", self.iteration) + self.epoch = serializer("epoch", self.epoch) + try: + self._previous_epoch_detail = serializer( + "previous_epoch_detail", self._previous_epoch_detail + ) + except KeyError: + # guess previous_epoch_detail for older version + self._previous_epoch_detail = self.epoch + ( + self.current_position - 1 + ) / len(self.batch_indices) + if self.epoch_detail > 0: + self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0) + else: + self._previous_epoch_detail = -1.0 + + +class MakeSymlinkToBestModel(extension.Extension): + """Extension that makes a symbolic link to the best model + + :param str key: Key of value + :param str prefix: Prefix of model files and link target + :param str suffix: Suffix of link target + """ + + def __init__(self, key, prefix="model", suffix="best"): + super(MakeSymlinkToBestModel, self).__init__() + self.best_model = -1 + self.min_loss = 0.0 + self.key = key + self.prefix = prefix + self.suffix = suffix + + def __call__(self, trainer): + observation = trainer.observation + if self.key in observation: + loss = observation[self.key] + if self.best_model == -1 or loss < self.min_loss: + self.min_loss = loss + self.best_model = trainer.updater.epoch + src = "%s.%d" % (self.prefix, self.best_model) + dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix)) + if os.path.lexists(dest): + os.remove(dest) + os.symlink(src, dest) + logging.info("best model is " + src) + + def serialize(self, serializer): + if isinstance(serializer, chainer.serializer.Serializer): + serializer("_best_model", self.best_model) + serializer("_min_loss", self.min_loss) + serializer("_key", self.key) + serializer("_prefix", self.prefix) + serializer("_suffix", self.suffix) + else: + self.best_model = serializer("_best_model", -1) + self.min_loss = serializer("_min_loss", 0.0) + self.key = serializer("_key", "") + self.prefix = serializer("_prefix", "model") + self.suffix = serializer("_suffix", "best") + + +# TODO(Hori): currently it only works with character-word level LM. +# need to consider any types of subwords-to-word mapping. 
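# The next function builds the prefix (lexical) tree consumed by the
# word-level LMs in espnet/lm/pytorch_backend/extlm.py. Each node is a
# 3-element list: [children dict(subword_id -> node), word_id (or -1 if the
# node is not a word end), (start-1, end) range of word ids covered by the
# subtree]. A minimal usage sketch with hypothetical toy dictionaries
# (illustration only, kept in comments so nothing runs at import time):
#
#     word_dict = {"<eos>": 0, "<unk>": 1, "ab": 2, "ac": 3}
#     subword_dict = {"a": 0, "b": 1, "c": 2}
#     root = make_lexical_tree(word_dict, subword_dict, word_unk=1)
#     node_a = root[0][subword_dict["a"]]
#     assert node_a[1] == -1 and node_a[2] == (1, 3)  # subtree spans "ab", "ac"
#     assert node_a[0][subword_dict["b"]][1] == 2     # word id of "ab"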
+def make_lexical_tree(word_dict, subword_dict, word_unk): + """Make a lexical tree to compute word-level probabilities""" + # node [dict(subword_id -> node), word_id, word_set[start-1, end]] + root = [{}, -1, None] + for w, wid in word_dict.items(): + if wid > 0 and wid != word_unk: # skip and + if True in [c not in subword_dict for c in w]: # skip unknown subword + continue + succ = root[0] # get successors from root node + for i, c in enumerate(w): + cid = subword_dict[c] + if cid not in succ: # if next node does not exist, make a new node + succ[cid] = [{}, -1, (wid - 1, wid)] + else: + prev = succ[cid][2] + succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid)) + if i == len(w) - 1: # if word end, set word id + succ[cid][1] = wid + succ = succ[cid][0] # move to the child successors + return root diff --git a/espnet/lm/pytorch_backend/__init__.py b/espnet/lm/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/lm/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/lm/pytorch_backend/extlm.py b/espnet/lm/pytorch_backend/extlm.py new file mode 100644 index 0000000000000000000000000000000000000000..caf6b495df528863e46f8e424cbf828089a009d7 --- /dev/null +++ b/espnet/lm/pytorch_backend/extlm.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from espnet.lm.lm_utils import make_lexical_tree +from espnet.nets.pytorch_backend.nets_utils import to_device + + +# Definition of a multi-level (subword/word) language model +class MultiLevelLM(nn.Module): + logzero = -10000000000.0 + zero = 1.0e-10 + + def __init__( + self, + wordlm, + subwordlm, + word_dict, + subword_dict, + subwordlm_weight=0.8, + oov_penalty=1.0, + open_vocab=True, + ): + super(MultiLevelLM, self).__init__() + self.wordlm = wordlm + self.subwordlm = subwordlm + self.word_eos = word_dict[""] + self.word_unk = word_dict[""] + self.var_word_eos = torch.LongTensor([self.word_eos]) + self.var_word_unk = torch.LongTensor([self.word_unk]) + self.space = subword_dict[""] + self.eos = subword_dict[""] + self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk) + self.log_oov_penalty = math.log(oov_penalty) + self.open_vocab = open_vocab + self.subword_dict_size = len(subword_dict) + self.subwordlm_weight = subwordlm_weight + self.normalized = True + + def forward(self, state, x): + # update state with input label x + if state is None: # make initial states and log-prob vectors + self.var_word_eos = to_device(x, self.var_word_eos) + self.var_word_unk = to_device(x, self.var_word_eos) + wlm_state, z_wlm = self.wordlm(None, self.var_word_eos) + wlm_logprobs = F.log_softmax(z_wlm, dim=1) + clm_state, z_clm = self.subwordlm(None, x) + log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight + new_node = self.lexroot + clm_logprob = 0.0 + xi = self.space + else: + clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state + xi = int(x) + if xi == self.space: # inter-word transition + if node is not None and node[1] >= 0: # check if the node is word end + w = to_device(x, torch.LongTensor([node[1]])) + else: # this node is not a word end, which means + w = self.var_word_unk + # update wordlm state and log-prob vector + wlm_state, z_wlm = 
self.wordlm(wlm_state, w) + wlm_logprobs = F.log_softmax(z_wlm, dim=1) + new_node = self.lexroot # move to the tree root + clm_logprob = 0.0 + elif node is not None and xi in node[0]: # intra-word transition + new_node = node[0][xi] + clm_logprob += log_y[0, xi] + elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode + new_node = None + clm_logprob += log_y[0, xi] + else: # if open_vocab flag is disabled, return 0 probabilities + log_y = to_device( + x, torch.full((1, self.subword_dict_size), self.logzero) + ) + return (clm_state, wlm_state, wlm_logprobs, None, log_y, 0.0), log_y + + clm_state, z_clm = self.subwordlm(clm_state, x) + log_y = F.log_softmax(z_clm, dim=1) * self.subwordlm_weight + + # apply word-level probabilies for and labels + if xi != self.space: + if new_node is not None and new_node[1] >= 0: # if new node is word end + wlm_logprob = wlm_logprobs[:, new_node[1]] - clm_logprob + else: + wlm_logprob = wlm_logprobs[:, self.word_unk] + self.log_oov_penalty + log_y[:, self.space] = wlm_logprob + log_y[:, self.eos] = wlm_logprob + else: + log_y[:, self.space] = self.logzero + log_y[:, self.eos] = self.logzero + + return ( + (clm_state, wlm_state, wlm_logprobs, new_node, log_y, float(clm_logprob)), + log_y, + ) + + def final(self, state): + clm_state, wlm_state, wlm_logprobs, node, log_y, clm_logprob = state + if node is not None and node[1] >= 0: # check if the node is word end + w = to_device(wlm_logprobs, torch.LongTensor([node[1]])) + else: # this node is not a word end, which means + w = self.var_word_unk + wlm_state, z_wlm = self.wordlm(wlm_state, w) + return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos]) + + +# Definition of a look-ahead word language model +class LookAheadWordLM(nn.Module): + logzero = -10000000000.0 + zero = 1.0e-10 + + def __init__( + self, wordlm, word_dict, subword_dict, oov_penalty=0.0001, open_vocab=True + ): + super(LookAheadWordLM, self).__init__() + self.wordlm = wordlm + self.word_eos = word_dict[""] + self.word_unk = word_dict[""] + self.var_word_eos = torch.LongTensor([self.word_eos]) + self.var_word_unk = torch.LongTensor([self.word_unk]) + self.space = subword_dict[""] + self.eos = subword_dict[""] + self.lexroot = make_lexical_tree(word_dict, subword_dict, self.word_unk) + self.oov_penalty = oov_penalty + self.open_vocab = open_vocab + self.subword_dict_size = len(subword_dict) + self.zero_tensor = torch.FloatTensor([self.zero]) + self.normalized = True + + def forward(self, state, x): + # update state with input label x + if state is None: # make initial states and cumlative probability vector + self.var_word_eos = to_device(x, self.var_word_eos) + self.var_word_unk = to_device(x, self.var_word_eos) + self.zero_tensor = to_device(x, self.zero_tensor) + wlm_state, z_wlm = self.wordlm(None, self.var_word_eos) + cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1) + new_node = self.lexroot + xi = self.space + else: + wlm_state, cumsum_probs, node = state + xi = int(x) + if xi == self.space: # inter-word transition + if node is not None and node[1] >= 0: # check if the node is word end + w = to_device(x, torch.LongTensor([node[1]])) + else: # this node is not a word end, which means + w = self.var_word_unk + # update wordlm state and cumlative probability vector + wlm_state, z_wlm = self.wordlm(wlm_state, w) + cumsum_probs = torch.cumsum(F.softmax(z_wlm, dim=1), dim=1) + new_node = self.lexroot # move to the tree root + elif node is not None and xi in node[0]: # intra-word transition + new_node = node[0][xi] 
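# The open-vocabulary branch below leaves the lexical tree (new_node = None);
# until the next word boundary the `else` branch further down then returns
# log-probability zeros, i.e. the look-ahead word LM stops constraining the
# hypothesis rather than discarding it.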
+ elif self.open_vocab: # if no path in the tree, enter open-vocabulary mode + new_node = None + else: # if open_vocab flag is disabled, return 0 probabilities + log_y = to_device( + x, torch.full((1, self.subword_dict_size), self.logzero) + ) + return (wlm_state, None, None), log_y + + if new_node is not None: + succ, wid, wids = new_node + # compute parent node probability + sum_prob = ( + (cumsum_probs[:, wids[1]] - cumsum_probs[:, wids[0]]) + if wids is not None + else 1.0 + ) + if sum_prob < self.zero: + log_y = to_device( + x, torch.full((1, self.subword_dict_size), self.logzero) + ) + return (wlm_state, cumsum_probs, new_node), log_y + # set probability as a default value + unk_prob = ( + cumsum_probs[:, self.word_unk] - cumsum_probs[:, self.word_unk - 1] + ) + y = to_device( + x, + torch.full( + (1, self.subword_dict_size), float(unk_prob) * self.oov_penalty + ), + ) + # compute transition probabilities to child nodes + for cid, nd in succ.items(): + y[:, cid] = ( + cumsum_probs[:, nd[2][1]] - cumsum_probs[:, nd[2][0]] + ) / sum_prob + # apply word-level probabilies for and labels + if wid >= 0: + wlm_prob = (cumsum_probs[:, wid] - cumsum_probs[:, wid - 1]) / sum_prob + y[:, self.space] = wlm_prob + y[:, self.eos] = wlm_prob + elif xi == self.space: + y[:, self.space] = self.zero + y[:, self.eos] = self.zero + log_y = torch.log(torch.max(y, self.zero_tensor)) # clip to avoid log(0) + else: # if no path in the tree, transition probability is one + log_y = to_device(x, torch.zeros(1, self.subword_dict_size)) + return (wlm_state, cumsum_probs, new_node), log_y + + def final(self, state): + wlm_state, cumsum_probs, node = state + if node is not None and node[1] >= 0: # check if the node is word end + w = to_device(cumsum_probs, torch.LongTensor([node[1]])) + else: # this node is not a word end, which means + w = self.var_word_unk + wlm_state, z_wlm = self.wordlm(wlm_state, w) + return float(F.log_softmax(z_wlm, dim=1)[:, self.word_eos]) diff --git a/espnet/lm/pytorch_backend/lm.py b/espnet/lm/pytorch_backend/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0249527fa2184afd9a6ae89eafb923de7a2b51 --- /dev/null +++ b/espnet/lm/pytorch_backend/lm.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +# This code is ported from the following implementation written in Torch. 
+# https://github.com/chainer/chainer/blob/master/examples/ptb/train_ptb_custom_loop.py + +"""LM training in pytorch.""" + +import copy +import json +import logging +import numpy as np + +import torch +import torch.nn as nn +from torch.nn.parallel import data_parallel + +from chainer import Chain +from chainer.dataset import convert +from chainer import reporter +from chainer import training +from chainer.training import extensions + +from espnet.lm.lm_utils import count_tokens +from espnet.lm.lm_utils import load_dataset +from espnet.lm.lm_utils import MakeSymlinkToBestModel +from espnet.lm.lm_utils import ParallelSentenceIterator +from espnet.lm.lm_utils import read_tokens +from espnet.nets.lm_interface import dynamic_import_lm +from espnet.nets.lm_interface import LMInterface +from espnet.optimizer.factory import dynamic_import_optimizer +from espnet.scheduler.pytorch import PyTorchScheduler +from espnet.scheduler.scheduler import dynamic_import_scheduler + +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot + +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from tensorboardX import SummaryWriter + +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.training.evaluator import BaseEvaluator +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + + +def compute_perplexity(result): + """Compute and add the perplexity to the LogReport. + + :param dict result: The current observations + """ + # Routine to rewrite the result dictionary of LogReport to add perplexity values + result["perplexity"] = np.exp(result["main/nll"] / result["main/count"]) + if "validation/main/nll" in result: + result["val_perplexity"] = np.exp( + result["validation/main/nll"] / result["validation/main/count"] + ) + + +class Reporter(Chain): + """Dummy module to use chainer's trainer.""" + + def report(self, loss): + """Report nothing.""" + pass + + +def concat_examples(batch, device=None, padding=None): + """Concat examples in minibatch. + + :param np.ndarray batch: The batch to concatenate + :param int device: The device to send to + :param Tuple[int,int] padding: The padding to use + :return: (inputs, targets) + :rtype (torch.Tensor, torch.Tensor) + """ + x, t = convert.concat_examples(batch, padding=padding) + x = torch.from_numpy(x) + t = torch.from_numpy(t) + if device is not None and device >= 0: + x = x.cuda(device) + t = t.cuda(device) + return x, t + + +class BPTTUpdater(training.StandardUpdater): + """An updater for a pytorch LM.""" + + def __init__( + self, + train_iter, + model, + optimizer, + schedulers, + device, + gradclip=None, + use_apex=False, + accum_grad=1, + ): + """Initialize class. + + Args: + train_iter (chainer.dataset.Iterator): The train iterator + model (LMInterface) : The model to update + optimizer (torch.optim.Optimizer): The optimizer for training + schedulers (espnet.scheduler.scheduler.SchedulerInterface): + The schedulers of `optimizer` + device (int): The device id + gradclip (float): The gradient clipping value to use + use_apex (bool): The flag to use Apex in backprop. + accum_grad (int): The number of gradient accumulation. 
+ + """ + super(BPTTUpdater, self).__init__(train_iter, optimizer) + self.model = model + self.device = device + self.gradclip = gradclip + self.use_apex = use_apex + self.scheduler = PyTorchScheduler(schedulers, optimizer) + self.accum_grad = accum_grad + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Update the model.""" + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + # Progress the dataset iterator for sentences at each iteration. + self.model.zero_grad() # Clear the parameter gradients + accum = {"loss": 0.0, "nll": 0.0, "count": 0} + for _ in range(self.accum_grad): + batch = train_iter.__next__() + # Concatenate the token IDs to matrices and send them to the device + # self.converter does this job + # (it is chainer.dataset.concat_examples by default) + x, t = concat_examples(batch, device=self.device[0], padding=(0, -100)) + if self.device[0] == -1: + loss, nll, count = self.model(x, t) + else: + # apex does not support torch.nn.DataParallel + loss, nll, count = data_parallel(self.model, (x, t), self.device) + + # backward + loss = loss.mean() / self.accum_grad + if self.use_apex: + from apex import amp + + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() # Backprop + # accumulate stats + accum["loss"] += float(loss) + accum["nll"] += float(nll.sum()) + accum["count"] += int(count.sum()) + + for k, v in accum.items(): + reporter.report({k: v}, optimizer.target) + if self.gradclip is not None: + nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip) + optimizer.step() # Update the parameters + self.scheduler.step(n_iter=self.iteration) + + +class LMEvaluator(BaseEvaluator): + """A custom evaluator for a pytorch LM.""" + + def __init__(self, val_iter, eval_model, reporter, device): + """Initialize class. + + :param chainer.dataset.Iterator val_iter : The validation iterator + :param LMInterface eval_model : The model to evaluate + :param chainer.Reporter reporter : The observations reporter + :param int device : The device id to use + + """ + super(LMEvaluator, self).__init__(val_iter, reporter, device=-1) + self.model = eval_model + self.device = device + + def evaluate(self): + """Evaluate the model.""" + val_iter = self.get_iterator("main") + loss = 0 + nll = 0 + count = 0 + self.model.eval() + with torch.no_grad(): + for batch in copy.copy(val_iter): + x, t = concat_examples(batch, device=self.device[0], padding=(0, -100)) + if self.device[0] == -1: + l, n, c = self.model(x, t) + else: + # apex does not support torch.nn.DataParallel + l, n, c = data_parallel(self.model, (x, t), self.device) + loss += float(l.sum()) + nll += float(n.sum()) + count += int(c.sum()) + self.model.train() + # report validation loss + observation = {} + with reporter.report_scope(observation): + reporter.report({"loss": loss}, self.model.reporter) + reporter.report({"nll": nll}, self.model.reporter) + reporter.report({"count": count}, self.model.reporter) + return observation + + +def train(args): + """Train with the given args. 
+ + :param Namespace args: The program arguments + :param type model_class: LMInterface class for training + """ + model_class = dynamic_import_lm(args.model_module, args.backend) + assert issubclass(model_class, LMInterface), "model should implement LMInterface" + # display torch version + logging.info("torch version = " + torch.__version__) + + set_deterministic_pytorch(args) + + # check cuda and cudnn availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get special label ids + unk = args.char_list_dict[""] + eos = args.char_list_dict[""] + # read tokens as a sequence of sentences + val, n_val_tokens, n_val_oovs = load_dataset( + args.valid_label, args.char_list_dict, args.dump_hdf5_path + ) + train, n_train_tokens, n_train_oovs = load_dataset( + args.train_label, args.char_list_dict, args.dump_hdf5_path + ) + logging.info("#vocab = " + str(args.n_vocab)) + logging.info("#sentences in the training data = " + str(len(train))) + logging.info("#tokens in the training data = " + str(n_train_tokens)) + logging.info( + "oov rate in the training data = %.2f %%" + % (n_train_oovs / n_train_tokens * 100) + ) + logging.info("#sentences in the validation data = " + str(len(val))) + logging.info("#tokens in the validation data = " + str(n_val_tokens)) + logging.info( + "oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100) + ) + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # Create the dataset iterators + batch_size = args.batchsize * max(args.ngpu, 1) + if batch_size * args.accum_grad > args.batchsize: + logging.info( + f"batch size is automatically increased " + f"({args.batchsize} -> {batch_size * args.accum_grad})" + ) + train_iter = ParallelSentenceIterator( + train, + batch_size, + max_length=args.maxlen, + sos=eos, + eos=eos, + shuffle=not use_sortagrad, + ) + val_iter = ParallelSentenceIterator( + val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False + ) + epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad) + logging.info("#iterations per epoch = %d" % epoch_iters) + logging.info("#total iterations = " + str(args.epoch * epoch_iters)) + # Prepare an RNNLM model + if args.train_dtype in ("float16", "float32", "float64"): + dtype = getattr(torch, args.train_dtype) + else: + dtype = torch.float32 + model = model_class(args.n_vocab, args).to(dtype=dtype) + if args.ngpu > 0: + model.to("cuda") + gpu_id = list(range(args.ngpu)) + else: + gpu_id = [-1] + + # Save model conf to json + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode( + "utf_8" + ) + ) + + logging.warning( + "num. model params: {:,} (num. 
trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Set up an optimizer + opt_class = dynamic_import_optimizer(args.opt, args.backend) + optimizer = opt_class.from_args(model.parameters(), args) + if args.schedulers is None: + schedulers = [] + else: + schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers] + + # setup apex.amp + if args.train_dtype in ("O0", "O1", "O2", "O3"): + try: + from apex import amp + except ImportError as e: + logging.error( + f"You need to install apex for --train-dtype {args.train_dtype}. " + "See https://github.com/NVIDIA/apex#linux" + ) + raise e + model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype) + use_apex = True + else: + use_apex = False + + # FIXME: TOO DIRTY HACK + reporter = Reporter() + setattr(model, "reporter", reporter) + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + updater = BPTTUpdater( + train_iter, + model, + optimizer, + schedulers, + gpu_id, + gradclip=args.gradclip, + use_apex=use_apex, + accum_grad=args.accum_grad, + ) + trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir) + trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id)) + trainer.extend( + extensions.LogReport( + postprocess=compute_perplexity, + trigger=(args.report_interval_iters, "iteration"), + ) + ) + trainer.extend( + extensions.PrintReport( + [ + "epoch", + "iteration", + "main/loss", + "perplexity", + "val_perplexity", + "elapsed_time", + ] + ), + trigger=(args.report_interval_iters, "iteration"), + ) + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + # Save best models + trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}")) + trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}")) + # T.Hori: MinValueTrigger should be used, but it fails when resuming + trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model")) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"), + ) + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + set_early_stop(trainer, args, is_lm=True) + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend( + TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration") + ) + + trainer.run() + check_early_stop(trainer, args.epoch) + + # compute perplexity for test set + if args.test_label: + logging.info("test the best model") + torch_load(args.outdir + "/rnnlm.model.best", model) + test = read_tokens(args.test_label, args.char_list_dict) + n_test_tokens, n_test_oovs = count_tokens(test, unk) + logging.info("#sentences in the test data = " + str(len(test))) + logging.info("#tokens in the test data = " + str(n_test_tokens)) + logging.info( + "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100) + ) + test_iter = ParallelSentenceIterator( + test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False + ) + evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id) + result = evaluator() + compute_perplexity(result) + 
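# compute_perplexity() rewrites `result` in place, adding
# result["perplexity"] = np.exp(result["main/nll"] / result["main/count"]),
# i.e. the exponentiated average per-token negative log-likelihood
# accumulated over the test set by LMEvaluator above.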
logging.info(f"test perplexity: {result['perplexity']}") diff --git a/espnet/mt/__init__.py b/espnet/mt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/mt/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/mt/mt_utils.py b/espnet/mt/mt_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50aa792ba3846c71fa185e7d454a1985e7702ab7 --- /dev/null +++ b/espnet/mt/mt_utils.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Utility funcitons for the text translation task.""" + +import logging + + +# * ------------------ recognition related ------------------ * +def parse_hypothesis(hyp, char_list): + """Parse hypothesis. + + :param list hyp: recognition hypothesis + :param list char_list: list of characters + :return: recognition text string + :return: recognition token string + :return: recognition tokenid string + """ + # remove sos and get results + tokenid_as_list = list(map(int, hyp["yseq"][1:])) + token_as_list = [char_list[idx] for idx in tokenid_as_list] + score = float(hyp["score"]) + + # convert to string + tokenid = " ".join([str(idx) for idx in tokenid_as_list]) + token = " ".join(token_as_list) + text = "".join(token_as_list).replace("", " ") + + return text, token, tokenid, score + + +def add_results_to_json(js, nbest_hyps, char_list): + """Add N-best results to json. + + :param dict js: groundtruth utterance dict + :param list nbest_hyps: list of hypothesis + :param list char_list: list of characters + :return: N-best results added utterance dict + """ + # copy old json info + new_js = dict() + if "utt2spk" in js.keys(): + new_js["utt2spk"] = js["utt2spk"] + new_js["output"] = [] + + for n, hyp in enumerate(nbest_hyps, 1): + # parse hypothesis + rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list) + + # copy ground-truth + if len(js["output"]) > 0: + out_dic = dict(js["output"][0].items()) + else: + out_dic = {"name": ""} + + # update name + out_dic["name"] += "[%d]" % n + + # add recognition results + out_dic["rec_text"] = rec_text + out_dic["rec_token"] = rec_token + out_dic["rec_tokenid"] = rec_tokenid + out_dic["score"] = score + + # add source reference + out_dic["text_src"] = js["output"][1]["text"] + out_dic["token_src"] = js["output"][1]["token"] + out_dic["tokenid_src"] = js["output"][1]["tokenid"] + + # add to list of N-best result dicts + new_js["output"].append(out_dic) + + # show 1-best result + if n == 1: + if "text" in out_dic.keys(): + logging.info("groundtruth: %s" % out_dic["text"]) + logging.info("prediction : %s" % out_dic["rec_text"]) + logging.info("source : %s" % out_dic["token_src"]) + + return new_js diff --git a/espnet/mt/pytorch_backend/__init__.py b/espnet/mt/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/mt/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/mt/pytorch_backend/mt.py b/espnet/mt/pytorch_backend/mt.py new file mode 100644 index 0000000000000000000000000000000000000000..88474c944ed507c65733d498a50bcb0d46d1be8d --- /dev/null +++ b/espnet/mt/pytorch_backend/mt.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# 
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Training/decoding definition for the text translation task.""" + +import json +import logging +import os +import sys + +from chainer import training +from chainer.training import extensions +import numpy as np +from tensorboardX import SummaryWriter +import torch + +from espnet.asr.asr_utils import adadelta_eps_decay +from espnet.asr.asr_utils import adam_lr_decay +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.nets.mt_interface import MTInterface +from espnet.nets.pytorch_backend.e2e_asr import pad_list +from espnet.utils.dataset import ChainerDataLoader +from espnet.utils.dataset import TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.iterators import ShufflingEnabler +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +from espnet.asr.pytorch_backend.asr import CustomEvaluator +from espnet.asr.pytorch_backend.asr import CustomUpdater +from espnet.asr.pytorch_backend.asr import load_trained_model + +import matplotlib + +matplotlib.use("Agg") + +if sys.version_info[0] == 2: + from itertools import izip_longest as zip_longest +else: + from itertools import zip_longest as zip_longest + + +class CustomConverter(object): + """Custom batch converter for Pytorch.""" + + def __init__(self): + """Construct a CustomConverter object.""" + self.ignore_id = -1 + self.pad = 0 + # NOTE: we reserve index:0 for although this is reserved for a blank class + # in ASR. However, + # blank labels are not used in NMT. To keep the vocabulary size, + # we use index:0 for padding instead of adding one more class. + + def __call__(self, batch, device=torch.device("cpu")): + """Transform a batch and send it to a device. + + Args: + batch (list): The batch to transform. + device (torch.device): The device to send to. + + Returns: + tuple(torch.Tensor, torch.Tensor, torch.Tensor) + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0] + + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs]) + + # perform padding and convert to tensor + xs_pad = pad_list([torch.from_numpy(x).long() for x in xs], self.pad).to(device) + ilens = torch.from_numpy(ilens).to(device) + ys_pad = pad_list([torch.from_numpy(y).long() for y in ys], self.ignore_id).to( + device + ) + + return xs_pad, ilens, ys_pad + + +def train(args): + """Train with the given args. + + Args: + args (namespace): The program arguments. 
+ + """ + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + idim = int(valid_json[utts[0]]["output"][1]["shape"][1]) + odim = int(valid_json[utts[0]]["output"][0]["shape"][1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # specify model architecture + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, MTInterface) + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + if args.train_dtype in ("float16", "float32", "float64"): + dtype = getattr(torch, args.train_dtype) + else: + dtype = torch.float32 + model = model.to(device=device, dtype=dtype) + + logging.warning( + "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adadelta": + optimizer = torch.optim.Adadelta( + model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "adam": + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + optimizer = get_std_opt( + model.parameters(), + args.adim, + args.transformer_warmup_steps, + args.transformer_lr, + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # setup apex.amp + if args.train_dtype in ("O0", "O1", "O2", "O3"): + try: + from apex import amp + except ImportError as e: + logging.error( + f"You need to install apex for --train-dtype {args.train_dtype}. 
" + "See https://github.com/NVIDIA/apex#linux" + ) + raise e + if args.opt == "noam": + model, optimizer.optimizer = amp.initialize( + model, optimizer.optimizer, opt_level=args.train_dtype + ) + else: + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.train_dtype + ) + use_apex = True + else: + use_apex = False + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # Setup a converter + converter = CustomConverter() + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # make minibatch list (variable length) + train = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + mt=True, + iaxis=1, + oaxis=0, + ) + valid = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + mt=True, + iaxis=1, + oaxis=0, + ) + + load_tr = LoadInputsAndTargets(mode="mt", load_output=True) + load_cv = LoadInputsAndTargets(mode="mt", load_output=True) + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + # default collate function converts numpy array to pytorch tensor + # we used an empty collate function instead which returns list + train_iter = ChainerDataLoader( + dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), + batch_size=1, + num_workers=args.n_iter_processes, + shuffle=not use_sortagrad, + collate_fn=lambda x: x[0], + ) + valid_iter = ChainerDataLoader( + dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.n_iter_processes, + ) + + # Set up a trainer + updater = CustomUpdater( + model, + args.grad_clip, + {"main": train_iter}, + optimizer, + device, + args.ngpu, + False, + args.accum_grad, + use_apex=use_apex, + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Resume from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # Evaluate the model with the test dataset for each epoch + if args.save_interval_iters > 0: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), + trigger=(args.save_interval_iters, "iteration"), + ) + else: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu) + ) + + # Save attention weight each epoch + if args.num_save_attention > 0: + # NOTE: sort it by output lengths + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["output"][0]["shape"][0]), + reverse=True, 
+ ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + ikey="output", + iaxis=1, + ) + trainer.extend(att_reporter, trigger=(1, "epoch")) + else: + att_reporter = None + + # Make a plot for training and validation values + trainer.extend( + extensions.PlotReport( + ["main/loss", "validation/main/loss"], "epoch", file_name="loss.png" + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png" + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/ppl", "validation/main/ppl"], "epoch", file_name="ppl.png" + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png" + ) + ) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger("validation/main/loss"), + ) + trainer.extend( + snapshot_object(model, "model.acc.best"), + trigger=training.triggers.MaxValueTrigger("validation/main/acc"), + ) + + # save snapshot which contains model and optimizer states + if args.save_interval_iters > 0: + trainer.extend( + torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), + trigger=(args.save_interval_iters, "iteration"), + ) + else: + trainer.extend(torch_snapshot(), trigger=(1, "epoch")) + + # epsilon decay in the optimizer + if args.opt == "adadelta": + if args.criterion == "acc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + elif args.opt == "adam": + if args.criterion == "acc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adam_lr_decay(args.lr_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adam_lr_decay(args.lr_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + + # Write a log of evaluation statistics for each 
epoch + trainer.extend( + extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) + ) + report_keys = [ + "epoch", + "iteration", + "main/loss", + "validation/main/loss", + "main/acc", + "validation/main/acc", + "main/ppl", + "validation/main/ppl", + "elapsed_time", + ] + if args.opt == "adadelta": + trainer.extend( + extensions.observe_value( + "eps", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "eps" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("eps") + elif args.opt in ["adam", "noam"]: + trainer.extend( + extensions.observe_value( + "lr", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "lr" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("lr") + if args.report_bleu: + report_keys.append("main/bleu") + report_keys.append("validation/main/bleu") + trainer.extend( + extensions.PrintReport(report_keys), + trigger=(args.report_interval_iters, "iteration"), + ) + + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + set_early_stop(trainer, args) + + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + trainer.extend( + TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter), + trigger=(args.report_interval_iters, "iteration"), + ) + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +def trans(args): + """Decode with the given args. + + Args: + args (namespace): The program arguments. + + """ + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + assert isinstance(model, MTInterface) + model.trans_args = args + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info("gpu id: " + str(gpu_id)) + model.cuda() + + # read json data + with open(args.trans_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + + # remove enmpy utterances + if train_args.multilingual: + js = { + k: v + for k, v in js.items() + if v["output"][0]["shape"][0] > 1 and v["output"][1]["shape"][0] > 1 + } + else: + js = { + k: v + for k, v in js.items() + if v["output"][0]["shape"][0] > 0 and v["output"][1]["shape"][0] > 0 + } + + if args.batchsize == 0: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + feat = [js[name]["output"][1]["tokenid"].split()] + nbest_hyps = model.translate(feat, args, train_args.char_list) + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + else: + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + # sort data + keys = list(js.keys()) + feat_lens = [js[key]["output"][1]["shape"][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + with torch.no_grad(): + for names in grouper(args.batchsize, keys, None): + names = [name for name in names if name] + feats = [ + np.fromiter( + map(int, js[name]["output"][1]["tokenid"].split()), + dtype=np.int64, + ) + for name in names + ] + nbest_hyps = model.translate_batch( + feats, + args, + train_args.char_list, + ) + + for i, nbest_hyp in enumerate(nbest_hyps): + name = names[i] + new_js[name] = add_results_to_json( + js[name], nbest_hyp, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, 
ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/nets/__init__.py b/espnet/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/asr_interface.py b/espnet/nets/asr_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..eba4ef0e722faa38b069304a8402544d6f49f393 --- /dev/null +++ b/espnet/nets/asr_interface.py @@ -0,0 +1,172 @@ +"""ASR Interface module.""" +import argparse + +from espnet.bin.asr_train import get_parser +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.fill_missing_args import fill_missing_args + + +class ASRInterface: + """ASR Interface for ESPnet model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to parser.""" + return parser + + @classmethod + def build(cls, idim: int, odim: int, **kwargs): + """Initialize this class with python-level args. + + Args: + idim (int): The number of an input feature dim. + odim (int): The number of output vocab. + + Returns: + ASRinterface: A new instance of ASRInterface. + + """ + + def wrap(parser): + return get_parser(parser, required=False) + + args = argparse.Namespace(**kwargs) + args = fill_missing_args(args, wrap) + args = fill_missing_args(args, cls.add_arguments) + return cls(idim, odim, args) + + def forward(self, xs, ilens, ys): + """Compute loss for training. + + :param xs: + For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim) + For chainer, list of source sequences chainer.Variable + :param ilens: batch of lengths of source sequences (B) + For pytorch, torch.Tensor + For chainer, list of int + :param ys: + For pytorch, batch of padded source sequences torch.Tensor (B, Lmax) + For chainer, list of source sequences chainer.Variable + :return: loss value + :rtype: torch.Tensor for pytorch, chainer.Variable for chainer + """ + raise NotImplementedError("forward method is not implemented") + + def recognize(self, x, recog_args, char_list=None, rnnlm=None): + """Recognize x for evaluation. + + :param ndarray x: input acouctic feature (B, T, D) or (T, D) + :param namespace recog_args: argment namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("recognize method is not implemented") + + def recognize_batch(self, x, recog_args, char_list=None, rnnlm=None): + """Beam search implementation for batch. + + :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc) + :param namespace recog_args: argument namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("Batch decoding is not supported yet.") + + def calculate_all_attentions(self, xs, ilens, ys): + """Caluculate attention. + + :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...] + :param ndarray ilens: batch of lengths of input sequences (B) + :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] 
+ :return: attention weights (B, Lmax, Tmax) + :rtype: float ndarray + """ + raise NotImplementedError("calculate_all_attentions method is not implemented") + + def calculate_all_ctc_probs(self, xs, ilens, ys): + """Caluculate CTC probability. + + :param list xs_pad: list of padded input sequences [(T1, idim), (T2, idim), ...] + :param ndarray ilens: batch of lengths of input sequences (B) + :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] + :return: CTC probabilities (B, Tmax, vocab) + :rtype: float ndarray + """ + raise NotImplementedError("calculate_all_ctc_probs method is not implemented") + + @property + def attention_plot_class(self): + """Get attention plot class.""" + from espnet.asr.asr_utils import PlotAttentionReport + + return PlotAttentionReport + + @property + def ctc_plot_class(self): + """Get CTC plot class.""" + from espnet.asr.asr_utils import PlotCTCReport + + return PlotCTCReport + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + raise NotImplementedError( + "get_total_subsampling_factor method is not implemented" + ) + + def encode(self, feat): + """Encode feature in `beam_search` (optional). + + Args: + x (numpy.ndarray): input feature (T, D) + Returns: + torch.Tensor for pytorch, chainer.Variable for chainer: + encoded feature (T, D) + + """ + raise NotImplementedError("encode method is not implemented") + + def scorers(self): + """Get scorers for `beam_search` (optional). + + Returns: + dict[str, ScorerInterface]: dict of `ScorerInterface` objects + + """ + raise NotImplementedError("decoders method is not implemented") + + +predefined_asr = { + "pytorch": { + "rnn": "espnet.nets.pytorch_backend.e2e_asr:E2E", + "transducer": "espnet.nets.pytorch_backend.e2e_asr_transducer:E2E", + "transformer": "espnet.nets.pytorch_backend.e2e_asr_transformer:E2E", + "conformer": "espnet.nets.pytorch_backend.e2e_asr_conformer:E2E", + }, + "chainer": { + "rnn": "espnet.nets.chainer_backend.e2e_asr:E2E", + "transformer": "espnet.nets.chainer_backend.e2e_asr_transformer:E2E", + }, +} + + +def dynamic_import_asr(module, backend): + """Import ASR models dynamically. + + Args: + module (str): module_name:class_name or alias in `predefined_asr` + backend (str): NN backend. 
e.g., pytorch, chainer + + Returns: + type: ASR class + + """ + model_class = dynamic_import(module, predefined_asr.get(backend, dict())) + assert issubclass( + model_class, ASRInterface + ), f"{module} does not implement ASRInterface" + return model_class diff --git a/espnet/nets/batch_beam_search.py b/espnet/nets/batch_beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..ba861f3f154258a1708185fdd792f8a17c29f585 --- /dev/null +++ b/espnet/nets/batch_beam_search.py @@ -0,0 +1,348 @@ +"""Parallel beam search module.""" + +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple + +import torch +from torch.nn.utils.rnn import pad_sequence + +from espnet.nets.beam_search import BeamSearch +from espnet.nets.beam_search import Hypothesis + + +class BatchHypothesis(NamedTuple): + """Batchfied/Vectorized hypothesis data type.""" + + yseq: torch.Tensor = torch.tensor([]) # (batch, maxlen) + score: torch.Tensor = torch.tensor([]) # (batch,) + length: torch.Tensor = torch.tensor([]) # (batch,) + scores: Dict[str, torch.Tensor] = dict() # values: (batch,) + states: Dict[str, Dict] = dict() + + def __len__(self) -> int: + """Return a batch size.""" + return len(self.length) + + +class BatchBeamSearch(BeamSearch): + """Batch beam search implementation.""" + + def batchfy(self, hyps: List[Hypothesis]) -> BatchHypothesis: + """Convert list to batch.""" + if len(hyps) == 0: + return BatchHypothesis() + return BatchHypothesis( + yseq=pad_sequence( + [h.yseq for h in hyps], batch_first=True, padding_value=self.eos + ), + length=torch.tensor([len(h.yseq) for h in hyps], dtype=torch.int64), + score=torch.tensor([h.score for h in hyps]), + scores={k: torch.tensor([h.scores[k] for h in hyps]) for k in self.scorers}, + states={k: [h.states[k] for h in hyps] for k in self.scorers}, + ) + + def _batch_select(self, hyps: BatchHypothesis, ids: List[int]) -> BatchHypothesis: + return BatchHypothesis( + yseq=hyps.yseq[ids], + score=hyps.score[ids], + length=hyps.length[ids], + scores={k: v[ids] for k, v in hyps.scores.items()}, + states={ + k: [self.scorers[k].select_state(v, i) for i in ids] + for k, v in hyps.states.items() + }, + ) + + def _select(self, hyps: BatchHypothesis, i: int) -> Hypothesis: + return Hypothesis( + yseq=hyps.yseq[i, : hyps.length[i]], + score=hyps.score[i], + scores={k: v[i] for k, v in hyps.scores.items()}, + states={ + k: self.scorers[k].select_state(v, i) for k, v in hyps.states.items() + }, + ) + + def unbatchfy(self, batch_hyps: BatchHypothesis) -> List[Hypothesis]: + """Revert batch to list.""" + return [ + Hypothesis( + yseq=batch_hyps.yseq[i][: batch_hyps.length[i]], + score=batch_hyps.score[i], + scores={k: batch_hyps.scores[k][i] for k in self.scorers}, + states={ + k: v.select_state(batch_hyps.states[k], i) + for k, v in self.scorers.items() + }, + ) + for i in range(len(batch_hyps.length)) + ] + + def batch_beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Batch-compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(n_beam, self.vocab_size)`. + ids (torch.Tensor): The partial token ids to compute topk. + Its shape is `(n_beam, self.pre_beam_size)`. 
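Note (illustrative example): with `self.beam_size = 2` and `self.n_vocab = 5`,
flattening a `(2, 5)` score matrix and taking its top-2 flat indices, say
`[7, 3]`, recovers `prev_hyp_ids = [1, 0]` and `new_token_ids = [2, 3]` via
integer division and modulo with `n_vocab`, as done in the body below.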
+ + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + The topk full (prev_hyp, new_token) ids + and partial (prev_hyp, new_token) ids. + Their shapes are all `(self.beam_size,)` + + """ + top_ids = weighted_scores.view(-1).topk(self.beam_size)[1] + # Because of the flatten above, `top_ids` is organized as: + # [hyp1 * V + token1, hyp2 * V + token2, ..., hypK * V + tokenK], + # where V is `self.n_vocab` and K is `self.beam_size` + prev_hyp_ids = top_ids // self.n_vocab + new_token_ids = top_ids % self.n_vocab + return prev_hyp_ids, new_token_ids, prev_hyp_ids, new_token_ids + + def init_hyp(self, x: torch.Tensor) -> BatchHypothesis: + """Get an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.batch_init_state(x) + init_scores[k] = 0.0 + return self.batchfy( + [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + ) + ] + ) + + def score_full( + self, hyp: BatchHypothesis, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + scores[k], states[k] = d.batch_score(hyp.yseq, hyp.states[k], x) + return scores, states + + def score_partial( + self, hyp: BatchHypothesis, ids: torch.Tensor, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 2D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + scores[k], states[k] = d.batch_score_partial( + hyp.yseq, ids, hyp.states[k], x + ) + return scores, states + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. + + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, v in part_states.items(): + new_states[k] = v + return new_states + + def search(self, running_hyps: BatchHypothesis, x: torch.Tensor) -> BatchHypothesis: + """Search new tokens for running hypotheses and encoded speech x. 
+ + Args: + running_hyps (BatchHypothesis): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + BatchHypothesis: Best sorted hypotheses + + """ + n_batch = len(running_hyps) + part_ids = None # no pre-beam + # batch scoring + weighted_scores = torch.zeros( + n_batch, self.n_vocab, dtype=x.dtype, device=x.device + ) + scores, states = self.score_full(running_hyps, x.expand(n_batch, *x.shape)) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size, dim=-1)[1] + # NOTE(takaaki-hori): Unlike BeamSearch, we assume that score_partial returns + # full-size score matrices, which has non-zero scores for part_ids and zeros + # for others. + part_scores, part_states = self.score_partial(running_hyps, part_ids, x) + for k in self.part_scorers: + weighted_scores += self.weights[k] * part_scores[k] + # add previous hyp scores + weighted_scores += running_hyps.score.to( + dtype=x.dtype, device=x.device + ).unsqueeze(1) + + # TODO(karita): do not use list. use batch instead + # see also https://github.com/espnet/espnet/pull/1402#discussion_r354561029 + # update hyps + best_hyps = [] + prev_hyps = self.unbatchfy(running_hyps) + for ( + full_prev_hyp_id, + full_new_token_id, + part_prev_hyp_id, + part_new_token_id, + ) in zip(*self.batch_beam(weighted_scores, part_ids)): + prev_hyp = prev_hyps[full_prev_hyp_id] + best_hyps.append( + Hypothesis( + score=weighted_scores[full_prev_hyp_id, full_new_token_id], + yseq=self.append_token(prev_hyp.yseq, full_new_token_id), + scores=self.merge_scores( + prev_hyp.scores, + {k: v[full_prev_hyp_id] for k, v in scores.items()}, + full_new_token_id, + {k: v[part_prev_hyp_id] for k, v in part_scores.items()}, + part_new_token_id, + ), + states=self.merge_states( + { + k: self.full_scorers[k].select_state(v, full_prev_hyp_id) + for k, v in states.items() + }, + { + k: self.part_scorers[k].select_state( + v, part_prev_hyp_id, part_new_token_id + ) + for k, v in part_states.items() + }, + part_new_token_id, + ), + ) + ) + return self.batchfy(best_hyps) + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: BatchHypothesis, + ended_hyps: List[Hypothesis], + ) -> BatchHypothesis: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (BatchHypothesis): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + BatchHypothesis: The new running hypotheses. 
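+
+        Note:
+            A toy version (not part of the API) of the eos test applied
+            below, assuming ``self.eos == 2`` and a batch of two hypotheses:
+
+            >>> import torch
+            >>> yseq = torch.tensor([[1, 5, 2], [1, 7, 4]])
+            >>> length = torch.tensor([3, 3])
+            >>> yseq[torch.arange(2), length - 1] == 2
+            tensor([ True, False])
+
+            Hypotheses whose last token equals ``eos`` are moved to
+            ``ended_hyps``; the remaining ones stay in the running batch.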
+ + """ + n_batch = running_hyps.yseq.shape[0] + logging.debug(f"the number of running hypothes: {n_batch}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join( + [ + self.token_list[x] + for x in running_hyps.yseq[0, 1 : running_hyps.length[0]] + ] + ) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + yseq_eos = torch.cat( + ( + running_hyps.yseq, + torch.full( + (n_batch, 1), + self.eos, + device=running_hyps.yseq.device, + dtype=torch.int64, + ), + ), + 1, + ) + running_hyps.yseq.resize_as_(yseq_eos) + running_hyps.yseq[:] = yseq_eos + running_hyps.length[:] = yseq_eos.shape[1] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a probmlem, number of hyps < beam) + is_eos = ( + running_hyps.yseq[torch.arange(n_batch), running_hyps.length - 1] + == self.eos + ) + for b in torch.nonzero(is_eos).view(-1): + hyp = self._select(running_hyps, b) + ended_hyps.append(hyp) + remained_ids = torch.nonzero(is_eos == 0).view(-1) + return self._batch_select(running_hyps, remained_ids) diff --git a/espnet/nets/batch_beam_search_online_sim.py b/espnet/nets/batch_beam_search_online_sim.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b348654ed51da54c38cf9d93420b69f0790fd0 --- /dev/null +++ b/espnet/nets/batch_beam_search_online_sim.py @@ -0,0 +1,270 @@ +"""Parallel beam search module for online simulation.""" + +import logging +from pathlib import Path +from typing import List + +import yaml + +import torch + +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.beam_search import Hypothesis +from espnet.nets.e2e_asr_common import end_detect + + +class BatchBeamSearchOnlineSim(BatchBeamSearch): + """Online beam search implementation. + + This simulates streaming decoding. + It requires encoded features of entire utterance and + extracts block by block from it as it shoud be done + in streaming processing. + This is based on Tsunoo et al, "STREAMING TRANSFORMER ASR + WITH BLOCKWISE SYNCHRONOUS BEAM SEARCH" + (https://arxiv.org/abs/2006.14941). + """ + + def set_streaming_config(self, asr_config: str): + """Set config file for streaming decoding. + + Args: + asr_config (str): The config file for asr training + + """ + train_config_file = Path(asr_config) + self.block_size = None + self.hop_size = None + self.look_ahead = None + config = None + with train_config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + if "encoder_conf" in args.keys(): + if "block_size" in args["encoder_conf"].keys(): + self.block_size = args["encoder_conf"]["block_size"] + if "hop_size" in args["encoder_conf"].keys(): + self.hop_size = args["encoder_conf"]["hop_size"] + if "look_ahead" in args["encoder_conf"].keys(): + self.look_ahead = args["encoder_conf"]["look_ahead"] + elif "config" in args.keys(): + config = args["config"] + if config is None: + logging.info( + "Cannot find config file for streaming decoding: " + + "apply batch beam search instead." 
+ ) + return + if ( + self.block_size is None or self.hop_size is None or self.look_ahead is None + ) and config is not None: + config_file = Path(config) + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + if "encoder_conf" in args.keys(): + enc_args = args["encoder_conf"] + if enc_args and "block_size" in enc_args: + self.block_size = enc_args["block_size"] + if enc_args and "hop_size" in enc_args: + self.hop_size = enc_args["hop_size"] + if enc_args and "look_ahead" in enc_args: + self.look_ahead = enc_args["look_ahead"] + + def set_block_size(self, block_size: int): + """Set block size for streaming decoding. + + Args: + block_size (int): The block size of encoder + """ + self.block_size = block_size + + def set_hop_size(self, hop_size: int): + """Set hop size for streaming decoding. + + Args: + hop_size (int): The hop size of encoder + """ + self.hop_size = hop_size + + def set_look_ahead(self, look_ahead: int): + """Set look ahead size for streaming decoding. + + Args: + look_ahead (int): The look ahead size of encoder + """ + self.look_ahead = look_ahead + + def forward( + self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + + Returns: + list[Hypothesis]: N-best decoding results + + """ + self.conservative = True # always true + + if self.block_size and self.hop_size and self.look_ahead: + cur_end_frame = int(self.block_size - self.look_ahead) + else: + cur_end_frame = x.shape[0] + process_idx = 0 + if cur_end_frame < x.shape[0]: + h = x.narrow(0, 0, cur_end_frame) + else: + h = x + + # set length bounds + if maxlenratio == 0: + maxlen = x.shape[0] + else: + maxlen = max(1, int(maxlenratio * x.size(0))) + minlen = int(minlenratio * x.size(0)) + logging.info("decoder input length: " + str(x.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + running_hyps = self.init_hyp(h) + prev_hyps = [] + ended_hyps = [] + prev_repeat = False + + continue_decode = True + + while continue_decode: + move_to_next_block = False + if cur_end_frame < x.shape[0]: + h = x.narrow(0, 0, cur_end_frame) + else: + h = x + + # extend states for ctc + self.extend(h, running_hyps) + + while process_idx < maxlen: + logging.debug("position " + str(process_idx)) + best = self.search(running_hyps, h) + + if process_idx == maxlen - 1: + # end decoding + running_hyps = self.post_process( + process_idx, maxlen, maxlenratio, best, ended_hyps + ) + n_batch = best.yseq.shape[0] + local_ended_hyps = [] + is_local_eos = ( + best.yseq[torch.arange(n_batch), best.length - 1] == self.eos + ) + for i in range(is_local_eos.shape[0]): + if is_local_eos[i]: + hyp = self._select(best, i) + local_ended_hyps.append(hyp) + # NOTE(tsunoo): check repetitions here + # This is a implicit implementation of + # Eq (11) in https://arxiv.org/abs/2006.14941 + # A flag prev_repeat is used instead of using set + elif ( + not prev_repeat + and best.yseq[i, -1] in best.yseq[i, :-1] + and cur_end_frame < x.shape[0] + ): + move_to_next_block = True + prev_repeat = True + if maxlenratio == 0.0 and end_detect( + [lh.asdict() for lh in 
local_ended_hyps], process_idx + ): + logging.info(f"end detected at {process_idx}") + continue_decode = False + break + if len(local_ended_hyps) > 0 and cur_end_frame < x.shape[0]: + move_to_next_block = True + + if move_to_next_block: + if ( + self.hop_size + and cur_end_frame + int(self.hop_size) + int(self.look_ahead) + < x.shape[0] + ): + cur_end_frame += int(self.hop_size) + else: + cur_end_frame = x.shape[0] + logging.debug("Going to next block: %d", cur_end_frame) + if process_idx > 1 and len(prev_hyps) > 0 and self.conservative: + running_hyps = prev_hyps + process_idx -= 1 + prev_hyps = [] + break + + prev_repeat = False + prev_hyps = running_hyps + running_hyps = self.post_process( + process_idx, maxlen, maxlenratio, best, ended_hyps + ) + + if cur_end_frame >= x.shape[0]: + for hyp in local_ended_hyps: + ended_hyps.append(hyp) + + if len(running_hyps) == 0: + logging.info("no hypothesis. Finish decoding.") + continue_decode = False + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + # increment number + process_idx += 1 + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info( + f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + ) + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + + "\n" + ) + return nbest_hyps + + def extend(self, x: torch.Tensor, hyps: Hypothesis) -> List[Hypothesis]: + """Extend probabilities and states with more encoded chunks. 
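+
+        When the simulated streaming decoder moves on to a new block of
+        encoder frames, scorers that precompute frame-wise probabilities
+        (e.g. a CTC prefix scorer) must see those frames as well: any scorer
+        exposing ``extend_prob`` / ``extend_state`` hooks is updated here.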
+ + Args: + x (torch.Tensor): The extended encoder output feature + hyps (Hypothesis): Current list of hypothesis + + Returns: + Hypothesis: The exxtended hypothesis + + """ + for k, d in self.scorers.items(): + if hasattr(d, "extend_prob"): + d.extend_prob(x) + if hasattr(d, "extend_state"): + hyps.states[k] = d.extend_state(hyps.states[k]) diff --git a/espnet/nets/beam_search.py b/espnet/nets/beam_search.py new file mode 100644 index 0000000000000000000000000000000000000000..fa41753c948621dae51794f7c111188f39bddd49 --- /dev/null +++ b/espnet/nets/beam_search.py @@ -0,0 +1,512 @@ +"""Beam search module.""" + +from itertools import chain +import logging +from typing import Any +from typing import Dict +from typing import List +from typing import NamedTuple +from typing import Tuple +from typing import Union + +import torch + +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.scorer_interface import PartialScorerInterface +from espnet.nets.scorer_interface import ScorerInterface + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: torch.Tensor + score: Union[float, torch.Tensor] = 0 + scores: Dict[str, Union[float, torch.Tensor]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +class BeamSearch(torch.nn.Module): + """Beam search implementation.""" + + def __init__( + self, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + beam_size: int, + vocab_size: int, + sos: int, + eos: int, + token_list: List[str] = None, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = None, + ): + """Initialize beam search. 
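+
+        Scorers implementing ``PartialScorerInterface`` are registered as
+        partial scorers and, when pre-beam search is active, are evaluated
+        only on the pre-selected candidate tokens; all other scorers are
+        treated as full scorers and score the whole vocabulary at every step.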
+ + Args: + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + sos (int): Start of sequence id + eos (int): End of sequence id + token_list (list[str]): List of tokens for debug log + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + """ + super().__init__() + # set scorers + self.weights = weights + self.scorers = dict() + self.full_scorers = dict() + self.part_scorers = dict() + # this module dict is required for recursive cast + # `self.to(device, dtype)` in `recog.py` + self.nn_dict = torch.nn.ModuleDict() + for k, v in scorers.items(): + w = weights.get(k, 0) + if w == 0 or v is None: + continue + assert isinstance( + v, ScorerInterface + ), f"{k} ({type(v)}) does not implement ScorerInterface" + self.scorers[k] = v + if isinstance(v, PartialScorerInterface): + self.part_scorers[k] = v + else: + self.full_scorers[k] = v + if isinstance(v, torch.nn.Module): + self.nn_dict[k] = v + + # set configurations + self.sos = sos + self.eos = eos + self.token_list = token_list + self.pre_beam_size = int(pre_beam_ratio * beam_size) + self.beam_size = beam_size + self.n_vocab = vocab_size + if ( + pre_beam_score_key is not None + and pre_beam_score_key != "full" + and pre_beam_score_key not in self.full_scorers + ): + raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}") + self.pre_beam_score_key = pre_beam_score_key + self.do_pre_beam = ( + self.pre_beam_score_key is not None + and self.pre_beam_size < self.n_vocab + and len(self.part_scorers) > 0 + ) + + def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]: + """Get an initial hypothesis data. + + Args: + x (torch.Tensor): The encoder output feature + + Returns: + Hypothesis: The initial hypothesis. + + """ + init_states = dict() + init_scores = dict() + for k, d in self.scorers.items(): + init_states[k] = d.init_state(x) + init_scores[k] = 0.0 + return [ + Hypothesis( + score=0.0, + scores=init_scores, + states=init_states, + yseq=torch.tensor([self.sos], device=x.device), + ) + ] + + @staticmethod + def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: + """Append new token to prefix tokens. + + Args: + xs (torch.Tensor): The prefix token + x (int): The new token to append + + Returns: + torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device + + """ + x = torch.tensor([x], dtype=xs.dtype, device=xs.device) + return torch.cat((xs, x)) + + def score_full( + self, hyp: Hypothesis, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.full_scorers`. 
+ + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.full_scorers` + and tensor score values of shape: `(self.n_vocab,)`, + and state dict that has string keys + and state values of `self.full_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.full_scorers.items(): + scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x) + return scores, states + + def score_partial( + self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: + """Score new hypothesis by `self.part_scorers`. + + Args: + hyp (Hypothesis): Hypothesis with prefix tokens to score + ids (torch.Tensor): 1D tensor of new partial tokens to score + x (torch.Tensor): Corresponding input feature + + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of + score dict of `hyp` that has string keys of `self.part_scorers` + and tensor score values of shape: `(len(ids),)`, + and state dict that has string keys + and state values of `self.part_scorers` + + """ + scores = dict() + states = dict() + for k, d in self.part_scorers.items(): + scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x) + return scores, states + + def beam( + self, weighted_scores: torch.Tensor, ids: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute topk full token ids and partial token ids. + + Args: + weighted_scores (torch.Tensor): The weighted sum scores for each tokens. + Its shape is `(self.n_vocab,)`. + ids (torch.Tensor): The partial token ids to compute topk + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + The topk full token ids and partial token ids. + Their shapes are `(self.beam_size,)` + + """ + # no pre beam performed + if weighted_scores.size(0) == ids.size(0): + top_ids = weighted_scores.topk(self.beam_size)[1] + return top_ids, top_ids + + # mask pruned in pre-beam not to select in topk + tmp = weighted_scores[ids] + weighted_scores[:] = -float("inf") + weighted_scores[ids] = tmp + top_ids = weighted_scores.topk(self.beam_size)[1] + local_ids = weighted_scores[ids].topk(self.beam_size)[1] + return top_ids, local_ids + + @staticmethod + def merge_scores( + prev_scores: Dict[str, float], + next_full_scores: Dict[str, torch.Tensor], + full_idx: int, + next_part_scores: Dict[str, torch.Tensor], + part_idx: int, + ) -> Dict[str, torch.Tensor]: + """Merge scores for new hypothesis. + + Args: + prev_scores (Dict[str, float]): + The previous hypothesis scores by `self.scorers` + next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` + full_idx (int): The next token id for `next_full_scores` + next_part_scores (Dict[str, torch.Tensor]): + scores of partial tokens by `self.part_scorers` + part_idx (int): The new token id for `next_part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are scalar tensors by the scorers. + + """ + new_scores = dict() + for k, v in next_full_scores.items(): + new_scores[k] = prev_scores[k] + v[full_idx] + for k, v in next_part_scores.items(): + new_scores[k] = prev_scores[k] + v[part_idx] + return new_scores + + def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any: + """Merge states for new hypothesis. 
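+
+        States of full scorers are carried over unchanged, while each partial
+        scorer reduces its batched state to the one matching the chosen token
+        via ``select_state(part_states[k], part_idx)``.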
+ + Args: + states: states of `self.full_scorers` + part_states: states of `self.part_scorers` + part_idx (int): The new token id for `part_scores` + + Returns: + Dict[str, torch.Tensor]: The new score dict. + Its keys are names of `self.full_scorers` and `self.part_scorers`. + Its values are states of the scorers. + + """ + new_states = dict() + for k, v in states.items(): + new_states[k] = v + for k, d in self.part_scorers.items(): + new_states[k] = d.select_state(part_states[k], part_idx) + return new_states + + def search( + self, running_hyps: List[Hypothesis], x: torch.Tensor + ) -> List[Hypothesis]: + """Search new tokens for running hypotheses and encoded speech x. + + Args: + running_hyps (List[Hypothesis]): Running hypotheses on beam + x (torch.Tensor): Encoded speech feature (T, D) + + Returns: + List[Hypotheses]: Best sorted hypotheses + + """ + best_hyps = [] + part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam + for hyp in running_hyps: + # scoring + weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device) + scores, states = self.score_full(hyp, x) + for k in self.full_scorers: + weighted_scores += self.weights[k] * scores[k] + # partial scoring + if self.do_pre_beam: + pre_beam_scores = ( + weighted_scores + if self.pre_beam_score_key == "full" + else scores[self.pre_beam_score_key] + ) + part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1] + part_scores, part_states = self.score_partial(hyp, part_ids, x) + for k in self.part_scorers: + weighted_scores[part_ids] += self.weights[k] * part_scores[k] + # add previous hyp score + weighted_scores += hyp.score + + # update hyps + for j, part_j in zip(*self.beam(weighted_scores, part_ids)): + # will be (2 x beam at most) + best_hyps.append( + Hypothesis( + score=weighted_scores[j], + yseq=self.append_token(hyp.yseq, j), + scores=self.merge_scores( + hyp.scores, scores, j, part_scores, part_j + ), + states=self.merge_states(states, part_states, part_j), + ) + ) + + # sort and prune 2 x beam -> beam + best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[ + : min(len(best_hyps), self.beam_size) + ] + return best_hyps + + def forward( + self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 + ) -> List[Hypothesis]: + """Perform beam search. + + Args: + x (torch.Tensor): Encoded speech feature (T, D) + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + + Returns: + list[Hypothesis]: N-best decoding results + + """ + # set length bounds + if maxlenratio == 0: + maxlen = x.shape[0] + else: + maxlen = max(1, int(maxlenratio * x.size(0))) + minlen = int(minlenratio * x.size(0)) + logging.info("decoder input length: " + str(x.shape[0])) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # main loop of prefix search + running_hyps = self.init_hyp(x) + ended_hyps = [] + for i in range(maxlen): + logging.debug("position " + str(i)) + best = self.search(running_hyps, x) + # post process of one iteration + running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps) + # end detection + if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i): + logging.info(f"end detected at {i}") + break + if len(running_hyps) == 0: + logging.info("no hypothesis. 
Finish decoding.") + break + else: + logging.debug(f"remained hypotheses: {len(running_hyps)}") + + nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True) + # check the number of hypotheses reaching to eos + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + return ( + [] + if minlenratio < 0.1 + else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1)) + ) + + # report the best result + best = nbest_hyps[0] + for k, v in best.scores.items(): + logging.info( + f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" + ) + logging.info(f"total log probability: {best.score:.2f}") + logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") + logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") + if self.token_list is not None: + logging.info( + "best hypo: " + + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + + "\n" + ) + return nbest_hyps + + def post_process( + self, + i: int, + maxlen: int, + maxlenratio: float, + running_hyps: List[Hypothesis], + ended_hyps: List[Hypothesis], + ) -> List[Hypothesis]: + """Perform post-processing of beam search iterations. + + Args: + i (int): The length of hypothesis tokens. + maxlen (int): The maximum length of tokens in beam search. + maxlenratio (int): The maximum length ratio in beam search. + running_hyps (List[Hypothesis]): The running hypotheses in beam search. + ended_hyps (List[Hypothesis]): The ended hypotheses in beam search. + + Returns: + List[Hypothesis]: The new running hypotheses. + + """ + logging.debug(f"the number of running hypotheses: {len(running_hyps)}") + if self.token_list is not None: + logging.debug( + "best hypo: " + + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]]) + ) + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + running_hyps = [ + h._replace(yseq=self.append_token(h.yseq, self.eos)) + for h in running_hyps + ] + + # add ended hypotheses to a final list, and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in running_hyps: + if hyp.yseq[-1] == self.eos: + # e.g., Word LM needs to add final score + for k, d in chain(self.full_scorers.items(), self.part_scorers.items()): + s = d.final_score(hyp.states[k]) + hyp.scores[k] += s + hyp = hyp._replace(score=hyp.score + self.weights[k] * s) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + return remained_hyps + + +def beam_search( + x: torch.Tensor, + sos: int, + eos: int, + beam_size: int, + vocab_size: int, + scorers: Dict[str, ScorerInterface], + weights: Dict[str, float], + token_list: List[str] = None, + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + pre_beam_ratio: float = 1.5, + pre_beam_score_key: str = "full", +) -> list: + """Perform beam search with scorers. 
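+
+    A minimal usage sketch (``enc_out``, ``sos_id``, ``eos_id``, ``decoder``,
+    ``ctc_scorer``, ``lm`` and ``token_list`` are hypothetical objects that
+    are not defined in this module; the scorers are assumed to implement
+    ``ScorerInterface``)::
+
+        results = beam_search(
+            x=enc_out,                 # (T, D) encoder output
+            sos=sos_id,
+            eos=eos_id,
+            beam_size=10,
+            vocab_size=len(token_list),
+            scorers={"decoder": decoder, "ctc": ctc_scorer, "lm": lm},
+            weights={"decoder": 0.7, "ctc": 0.3, "lm": 0.1},
+            token_list=token_list,
+        )
+        best = results[0]              # JSON-friendly dict (see Hypothesis.asdict)
+        best_token_ids = best["yseq"]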
+ + Args: + x (torch.Tensor): Encoded speech feature (T, D) + sos (int): Start of sequence id + eos (int): End of sequence id + beam_size (int): The number of hypotheses kept during search + vocab_size (int): The number of vocabulary + scorers (dict[str, ScorerInterface]): Dict of decoder modules + e.g., Decoder, CTCPrefixScorer, LM + The scorer will be ignored if it is `None` + weights (dict[str, float]): Dict of weights for each scorers + The scorer will be ignored if its weight is 0 + token_list (list[str]): List of tokens for debug log + maxlenratio (float): Input length ratio to obtain max output length. + If maxlenratio=0.0 (default), it uses a end-detect function + to automatically find maximum hypothesis lengths + minlenratio (float): Input length ratio to obtain min output length. + pre_beam_score_key (str): key of scores to perform pre-beam search + pre_beam_ratio (float): beam size in the pre-beam search + will be `int(pre_beam_ratio * beam_size)` + + Returns: + list: N-best decoding results + + """ + ret = BeamSearch( + scorers, + weights, + beam_size=beam_size, + vocab_size=vocab_size, + pre_beam_ratio=pre_beam_ratio, + pre_beam_score_key=pre_beam_score_key, + sos=sos, + eos=eos, + token_list=token_list, + ).forward(x=x, maxlenratio=maxlenratio, minlenratio=minlenratio) + return [h.asdict() for h in ret] diff --git a/espnet/nets/beam_search_transducer.py b/espnet/nets/beam_search_transducer.py new file mode 100644 index 0000000000000000000000000000000000000000..925374a163cfe50f85223d907456b4bab5e58ee6 --- /dev/null +++ b/espnet/nets/beam_search_transducer.py @@ -0,0 +1,629 @@ +"""Search algorithms for transducer models.""" + +from typing import List +from typing import Union + +import numpy as np +import torch + +from espnet.nets.pytorch_backend.transducer.utils import create_lm_batch_state +from espnet.nets.pytorch_backend.transducer.utils import init_lm_state +from espnet.nets.pytorch_backend.transducer.utils import is_prefix +from espnet.nets.pytorch_backend.transducer.utils import recombine_hyps +from espnet.nets.pytorch_backend.transducer.utils import select_lm_state +from espnet.nets.pytorch_backend.transducer.utils import substract +from espnet.nets.transducer_decoder_interface import Hypothesis +from espnet.nets.transducer_decoder_interface import NSCHypothesis +from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface + + +class BeamSearchTransducer: + """Beam search implementation for transducer.""" + + def __init__( + self, + decoder: Union[TransducerDecoderInterface, torch.nn.Module], + joint_network: torch.nn.Module, + beam_size: int, + lm: torch.nn.Module = None, + lm_weight: float = 0.1, + search_type: str = "default", + max_sym_exp: int = 2, + u_max: int = 50, + nstep: int = 1, + prefix_alpha: int = 1, + score_norm: bool = True, + nbest: int = 1, + ): + """Initialize transducer beam search. 
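+
+        ``search_type`` selects the decoding algorithm: "default" uses
+        ``default_beam_search``, "tsd" ``time_sync_decoding``, "alsd"
+        ``align_length_sync_decoding`` and "nsc" ``nsc_beam_search``;
+        when ``beam_size <= 1`` greedy search is used regardless of
+        ``search_type``.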
+ + Args: + decoder: Decoder class to use + joint_network: Joint Network class + beam_size: Number of hypotheses kept during search + lm: LM class to use + lm_weight: lm weight for soft fusion + search_type: type of algorithm to use for search + max_sym_exp: number of maximum symbol expansions at each time step ("tsd") + u_max: maximum output sequence length ("alsd") + nstep: number of maximum expansion steps at each time step ("nsc") + prefix_alpha: maximum prefix length in prefix search ("nsc") + score_norm: normalize final scores by length ("default") + nbest: number of returned final hypothesis + """ + self.decoder = decoder + self.joint_network = joint_network + + self.beam_size = beam_size + self.hidden_size = decoder.dunits + self.vocab_size = decoder.odim + self.blank = decoder.blank + + if self.beam_size <= 1: + self.search_algorithm = self.greedy_search + elif search_type == "default": + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + self.search_algorithm = self.time_sync_decoding + elif search_type == "alsd": + self.search_algorithm = self.align_length_sync_decoding + elif search_type == "nsc": + self.search_algorithm = self.nsc_beam_search + else: + raise NotImplementedError + + self.lm = lm + self.lm_weight = lm_weight + + if lm is not None: + self.use_lm = True + self.is_wordlm = True if hasattr(lm.predictor, "wordlm") else False + self.lm_predictor = lm.predictor.wordlm if self.is_wordlm else lm.predictor + self.lm_layers = len(self.lm_predictor.rnn) + else: + self.use_lm = False + + self.max_sym_exp = max_sym_exp + self.u_max = u_max + self.nstep = nstep + self.prefix_alpha = prefix_alpha + self.score_norm = score_norm + + self.nbest = nbest + + def __call__(self, h: torch.Tensor) -> Union[List[Hypothesis], List[NSCHypothesis]]: + """Perform beam search. + + Args: + h: Encoded speech features (T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + + """ + self.decoder.set_device(h.device) + + if not hasattr(self.decoder, "decoders"): + self.decoder.set_data_type(h.dtype) + + nbest_hyps = self.search_algorithm(h) + + return nbest_hyps + + def sort_nbest( + self, hyps: Union[List[Hypothesis], List[NSCHypothesis]] + ) -> Union[List[Hypothesis], List[NSCHypothesis]]: + """Sort hypotheses by score or score given sequence length. + + Args: + hyps: list of hypotheses + + Return: + hyps: sorted list of hypotheses + + """ + if self.score_norm: + hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True) + else: + hyps.sort(key=lambda x: x.score, reverse=True) + + return hyps[: self.nbest] + + def greedy_search(self, h: torch.Tensor) -> List[Hypothesis]: + """Greedy search implementation for transformer-transducer. + + Args: + h: Encoded speech features (T_max, D_enc) + + Returns: + hyp: 1-best decoding results + + """ + dec_state = self.decoder.init_state(1) + + hyp = Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state) + cache = {} + + y, state, _ = self.decoder.score(hyp, cache) + + for i, hi in enumerate(h): + ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1) + logp, pred = torch.max(ytu, dim=-1) + + if pred != self.blank: + hyp.yseq.append(int(pred)) + hyp.score += float(logp) + + hyp.dec_state = state + + y, state, _ = self.decoder.score(hyp, cache) + + return [hyp] + + def default_beam_search(self, h: torch.Tensor) -> List[Hypothesis]: + """Beam search implementation. 
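+
+        This closely follows the breadth-first transducer beam search of
+        Graves (2012): for each encoder frame the best queued hypothesis is
+        repeatedly expanded with a blank (kept) and with the top-k non-blank
+        tokens (re-queued) until ``beam_size`` kept hypotheses outscore
+        everything still in the queue.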
+ + Args: + x: Encoded speech features (T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + + """ + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + + dec_state = self.decoder.init_state(1) + + kept_hyps = [Hypothesis(score=0.0, yseq=[self.blank], dec_state=dec_state)] + cache = {} + + for hi in h: + hyps = kept_hyps + kept_hyps = [] + + while True: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + y, state, lm_tokens = self.decoder.score(max_hyp, cache) + + ytu = torch.log_softmax(self.joint_network(hi, y), dim=-1) + top_k = ytu[1:].topk(beam_k, dim=-1) + + kept_hyps.append( + Hypothesis( + score=(max_hyp.score + float(ytu[0:1])), + yseq=max_hyp.yseq[:], + dec_state=max_hyp.dec_state, + lm_state=max_hyp.lm_state, + ) + ) + + if self.use_lm: + lm_state, lm_scores = self.lm.predict(max_hyp.lm_state, lm_tokens) + else: + lm_state = max_hyp.lm_state + + for logp, k in zip(*top_k): + score = max_hyp.score + float(logp) + + if self.use_lm: + score += self.lm_weight * lm_scores[0][k + 1] + + hyps.append( + Hypothesis( + score=score, + yseq=max_hyp.yseq[:] + [int(k + 1)], + dec_state=state, + lm_state=lm_state, + ) + ) + + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + + return self.sort_nbest(kept_hyps) + + def time_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]: + """Time synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + + """ + beam = min(self.beam_size, self.vocab_size) + + beam_state = self.decoder.init_state(beam) + + B = [ + Hypothesis( + yseq=[self.blank], + score=0.0, + dec_state=self.decoder.select_state(beam_state, 0), + ) + ] + cache = {} + + if self.use_lm and not self.is_wordlm: + B[0].lm_state = init_lm_state(self.lm_predictor) + + for hi in h: + A = [] + C = B + + h_enc = hi.unsqueeze(0) + + for v in range(self.max_sym_exp): + D = [] + + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score( + C, + beam_state, + cache, + self.use_lm, + ) + + beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1) + beam_topk = beam_logp[:, 1:].topk(beam, dim=-1) + + seq_A = [h.yseq for h in A] + + for i, hyp in enumerate(C): + if hyp.yseq not in seq_A: + A.append( + Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + ) + else: + dict_pos = seq_A.index(hyp.yseq) + + A[dict_pos].score = np.logaddexp( + A[dict_pos].score, (hyp.score + float(beam_logp[i, 0])) + ) + + if v < (self.max_sym_exp - 1): + if self.use_lm: + beam_lm_states = create_lm_batch_state( + [c.lm_state for c in C], self.lm_layers, self.is_wordlm + ) + + beam_lm_states, beam_lm_scores = self.lm.buff_predict( + beam_lm_states, beam_lm_tokens, len(C) + ) + + for i, hyp in enumerate(C): + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + + new_hyp.lm_state = select_lm_state( + beam_lm_states, i, self.lm_layers, self.is_wordlm + ) + + 
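+                            # keep the non-blank expansion; D is later pruned
+                            # to the beam and becomes C for expansion step v+1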
D.append(new_hyp) + + C = sorted(D, key=lambda x: x.score, reverse=True)[:beam] + + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + + return self.sort_nbest(B) + + def align_length_sync_decoding(self, h: torch.Tensor) -> List[Hypothesis]: + """Alignment-length synchronous beam search implementation. + + Based on https://ieeexplore.ieee.org/document/9053040 + + Args: + h: Encoded speech features (T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + + """ + beam = min(self.beam_size, self.vocab_size) + + h_length = int(h.size(0)) + u_max = min(self.u_max, (h_length - 1)) + + beam_state = self.decoder.init_state(beam) + + B = [ + Hypothesis( + yseq=[self.blank], + score=0.0, + dec_state=self.decoder.select_state(beam_state, 0), + ) + ] + final = [] + cache = {} + + if self.use_lm and not self.is_wordlm: + B[0].lm_state = init_lm_state(self.lm_predictor) + + for i in range(h_length + u_max): + A = [] + + B_ = [] + h_states = [] + for hyp in B: + u = len(hyp.yseq) - 1 + t = i - u + 1 + + if t > (h_length - 1): + continue + + B_.append(hyp) + h_states.append((t, h[t])) + + if B_: + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score( + B_, + beam_state, + cache, + self.use_lm, + ) + + h_enc = torch.stack([h[1] for h in h_states]) + + beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1) + beam_topk = beam_logp[:, 1:].topk(beam, dim=-1) + + if self.use_lm: + beam_lm_states = create_lm_batch_state( + [b.lm_state for b in B_], self.lm_layers, self.is_wordlm + ) + + beam_lm_states, beam_lm_scores = self.lm.buff_predict( + beam_lm_states, beam_lm_tokens, len(B_) + ) + + for i, hyp in enumerate(B_): + new_hyp = Hypothesis( + score=(hyp.score + float(beam_logp[i, 0])), + yseq=hyp.yseq[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + ) + + A.append(new_hyp) + + if h_states[i][0] == (h_length - 1): + final.append(new_hyp) + + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + new_hyp = Hypothesis( + score=(hyp.score + float(logp)), + yseq=(hyp.yseq[:] + [int(k)]), + dec_state=self.decoder.select_state(beam_state, i), + lm_state=hyp.lm_state, + ) + + if self.use_lm: + new_hyp.score += self.lm_weight * beam_lm_scores[i, k] + + new_hyp.lm_state = select_lm_state( + beam_lm_states, i, self.lm_layers, self.is_wordlm + ) + + A.append(new_hyp) + + B = sorted(A, key=lambda x: x.score, reverse=True)[:beam] + B = recombine_hyps(B) + + if final: + return self.sort_nbest(final) + else: + return B + + def nsc_beam_search(self, h: torch.Tensor) -> List[NSCHypothesis]: + """N-step constrained beam search implementation. + + Based and modified from https://arxiv.org/pdf/2002.03577.pdf. + Please reference ESPnet (b-flo, PR #2444) for any usage outside ESPnet + until further modifications. + + Note: the algorithm is not in his "complete" form but works almost as + intended. 
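+
+        At each encoder frame, scores of hypotheses that are prefixes of
+        other hypotheses are merged first, and every surviving hypothesis is
+        then expanded for at most ``nstep`` steps before pruning to the beam.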
+ + Args: + h: Encoded speech features (T_max, D_enc) + + Returns: + nbest_hyps: N-best decoding results + + """ + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + + beam_state = self.decoder.init_state(beam) + + init_tokens = [ + NSCHypothesis( + yseq=[self.blank], + score=0.0, + dec_state=self.decoder.select_state(beam_state, 0), + ) + ] + + cache = {} + + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score( + init_tokens, + beam_state, + cache, + self.use_lm, + ) + + state = self.decoder.select_state(beam_state, 0) + + if self.use_lm: + beam_lm_states, beam_lm_scores = self.lm.buff_predict( + None, beam_lm_tokens, 1 + ) + lm_state = select_lm_state( + beam_lm_states, 0, self.lm_layers, self.is_wordlm + ) + lm_scores = beam_lm_scores[0] + else: + lm_state = None + lm_scores = None + + kept_hyps = [ + NSCHypothesis( + yseq=[self.blank], + score=0.0, + dec_state=state, + y=[beam_y[0]], + lm_state=lm_state, + lm_scores=lm_scores, + ) + ] + + for hi in h: + hyps = sorted(kept_hyps, key=lambda x: len(x.yseq), reverse=True) + kept_hyps = [] + + h_enc = hi.unsqueeze(0) + + for j, hyp_j in enumerate(hyps[:-1]): + for hyp_i in hyps[(j + 1) :]: + curr_id = len(hyp_j.yseq) + next_id = len(hyp_i.yseq) + + if ( + is_prefix(hyp_j.yseq, hyp_i.yseq) + and (curr_id - next_id) <= self.prefix_alpha + ): + ytu = torch.log_softmax( + self.joint_network(hi, hyp_i.y[-1]), dim=-1 + ) + + curr_score = hyp_i.score + float(ytu[hyp_j.yseq[next_id]]) + + for k in range(next_id, (curr_id - 1)): + ytu = torch.log_softmax( + self.joint_network(hi, hyp_j.y[k]), dim=-1 + ) + + curr_score += float(ytu[hyp_j.yseq[k + 1]]) + + hyp_j.score = np.logaddexp(hyp_j.score, curr_score) + + S = [] + V = [] + for n in range(self.nstep): + beam_y = torch.stack([hyp.y[-1] for hyp in hyps]) + + beam_logp = torch.log_softmax(self.joint_network(h_enc, beam_y), dim=-1) + beam_topk = beam_logp[:, 1:].topk(beam_k, dim=-1) + + for i, hyp in enumerate(hyps): + S.append( + NSCHypothesis( + yseq=hyp.yseq[:], + score=hyp.score + float(beam_logp[i, 0:1]), + y=hyp.y[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + lm_scores=hyp.lm_scores, + ) + ) + + for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): + score = hyp.score + float(logp) + + if self.use_lm: + score += self.lm_weight * float(hyp.lm_scores[k]) + + V.append( + NSCHypothesis( + yseq=hyp.yseq[:] + [int(k)], + score=score, + y=hyp.y[:], + dec_state=hyp.dec_state, + lm_state=hyp.lm_state, + lm_scores=hyp.lm_scores, + ) + ) + + V.sort(key=lambda x: x.score, reverse=True) + V = substract(V, hyps)[:beam] + + beam_state = self.decoder.create_batch_states( + beam_state, + [v.dec_state for v in V], + [v.yseq for v in V], + ) + beam_y, beam_state, beam_lm_tokens = self.decoder.batch_score( + V, + beam_state, + cache, + self.use_lm, + ) + + if self.use_lm: + beam_lm_states = create_lm_batch_state( + [v.lm_state for v in V], self.lm_layers, self.is_wordlm + ) + beam_lm_states, beam_lm_scores = self.lm.buff_predict( + beam_lm_states, beam_lm_tokens, len(V) + ) + + if n < (self.nstep - 1): + for i, v in enumerate(V): + v.y.append(beam_y[i]) + + v.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + v.lm_state = select_lm_state( + beam_lm_states, i, self.lm_layers, self.is_wordlm + ) + v.lm_scores = beam_lm_scores[i] + + hyps = V[:] + else: + beam_logp = torch.log_softmax( + self.joint_network(h_enc, beam_y), dim=-1 + ) + + for i, v in enumerate(V): + if self.nstep != 1: + v.score += float(beam_logp[i, 0]) + + 
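+                        # final expansion step: attach the decoder output,
+                        # decoder state and (optionally) LM state to each
+                        # surviving hypothesis before pruning to the beam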
v.y.append(beam_y[i]) + + v.dec_state = self.decoder.select_state(beam_state, i) + + if self.use_lm: + v.lm_state = select_lm_state( + beam_lm_states, i, self.lm_layers, self.is_wordlm + ) + v.lm_scores = beam_lm_scores[i] + + kept_hyps = sorted((S + V), key=lambda x: x.score, reverse=True)[:beam] + + return self.sort_nbest(kept_hyps) diff --git a/espnet/nets/chainer_backend/__init__.py b/espnet/nets/chainer_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/chainer_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/chainer_backend/asr_interface.py b/espnet/nets/chainer_backend/asr_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a570bee6b02bb712b8d64aa81986b40bdcac9904 --- /dev/null +++ b/espnet/nets/chainer_backend/asr_interface.py @@ -0,0 +1,29 @@ +"""ASR Interface module.""" +import chainer + +from espnet.nets.asr_interface import ASRInterface + + +class ChainerASRInterface(ASRInterface, chainer.Chain): + """ASR Interface for ESPnet model implementation.""" + + @staticmethod + def custom_converter(*args, **kw): + """Get customconverter of the model (Chainer only).""" + raise NotImplementedError("custom converter method is not implemented") + + @staticmethod + def custom_updater(*args, **kw): + """Get custom_updater of the model (Chainer only).""" + raise NotImplementedError("custom updater method is not implemented") + + @staticmethod + def custom_parallel_updater(*args, **kw): + """Get custom_parallel_updater of the model (Chainer only).""" + raise NotImplementedError("custom parallel updater method is not implemented") + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + raise NotImplementedError( + "get_total_subsampling_factor method is not implemented" + ) diff --git a/espnet/nets/chainer_backend/ctc.py b/espnet/nets/chainer_backend/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..222ae0c3d9364e518954ae20667ec355961073ef --- /dev/null +++ b/espnet/nets/chainer_backend/ctc.py @@ -0,0 +1,184 @@ +import logging + +import chainer +from chainer import cuda +import chainer.functions as F +import chainer.links as L +import numpy as np + + +class CTC(chainer.Chain): + """Chainer implementation of ctc layer. + + Args: + odim (int): The output dimension. + eprojs (int | None): Dimension of input vectors from encoder. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, odim, eprojs, dropout_rate): + super(CTC, self).__init__() + self.dropout_rate = dropout_rate + self.loss = None + + with self.init_scope(): + self.ctc_lo = L.Linear(eprojs, odim) + + def __call__(self, hs, ys): + """CTC forward. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + ys (list of chainer.Variable | N-dimension array): + Input variable of decoder. + + Returns: + chainer.Variable: A variable holding a scalar value of the CTC loss. 
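+
+        Note:
+            ``hs`` is a list of per-utterance encoder outputs of shape
+            ``(T_i, eprojs)`` and ``ys`` a list of label id sequences; both
+            are padded internally, and the loss is computed with
+            ``F.connectionist_temporal_classification`` using blank id 0.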
+ + """ + self.loss = None + ilens = [x.shape[0] for x in hs] + olens = [x.shape[0] for x in ys] + + # zero padding for hs + y_hat = self.ctc_lo( + F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2 + ) + y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim + + # zero padding for ys + y_true = F.pad_sequence(ys, padding=-1) # batch x olen + + # get length info + input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32)) + label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32)) + logging.info( + self.__class__.__name__ + " input lengths: " + str(input_length.data) + ) + logging.info( + self.__class__.__name__ + " output lengths: " + str(label_length.data) + ) + + # get ctc loss + self.loss = F.connectionist_temporal_classification( + y_hat, y_true, 0, input_length, label_length + ) + logging.info("ctc loss:" + str(self.loss.data)) + + return self.loss + + def log_softmax(self, hs): + """Log_softmax of frame activations. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + + Returns: + chainer.Variable: A n-dimension float array. + + """ + y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2) + return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape) + + +class WarpCTC(chainer.Chain): + """Chainer implementation of warp-ctc layer. + + Args: + odim (int): The output dimension. + eproj (int | None): Dimension of input vector from encoder. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, odim, eprojs, dropout_rate): + super(WarpCTC, self).__init__() + self.dropout_rate = dropout_rate + self.loss = None + + with self.init_scope(): + self.ctc_lo = L.Linear(eprojs, odim) + + def __call__(self, hs, ys): + """Core function of the Warp-CTC layer. + + Args: + hs (iterable of chainer.Variable | N-dimention array): + Input variable from encoder. + ys (iterable of chainer.Variable | N-dimension array): + Input variable of decoder. + + Returns: + chainer.Variable: A variable holding a scalar value of the CTC loss. + + """ + self.loss = None + ilens = [x.shape[0] for x in hs] + olens = [x.shape[0] for x in ys] + + # zero padding for hs + y_hat = self.ctc_lo( + F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2 + ) + y_hat = y_hat.transpose(1, 0, 2) # batch x frames x hdim + + # get length info + logging.info(self.__class__.__name__ + " input lengths: " + str(ilens)) + logging.info(self.__class__.__name__ + " output lengths: " + str(olens)) + + # get ctc loss + from chainer_ctc.warpctc import ctc as warp_ctc + + self.loss = warp_ctc(y_hat, ilens, [cuda.to_cpu(y.data) for y in ys])[0] + logging.info("ctc loss:" + str(self.loss.data)) + + return self.loss + + def log_softmax(self, hs): + """Log_softmax of frame activations. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + + Returns: + chainer.Variable: A n-dimension float array. + + """ + y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2) + return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape) + + def argmax(self, hs_pad): + """argmax of frame activations + + :param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs) + :return: argmax applied 2d tensor (B, Tmax) + :rtype: chainer.Variable + """ + return F.argmax(self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2), axis=-1) + + +def ctc_for(args, odim): + """Return the CTC layer corresponding to the args. + + Args: + args (Namespace): The program arguments. 
+ odim (int): The output dimension. + + Returns: + The CTC module. + + """ + ctc_type = args.ctc_type + if ctc_type == "builtin": + logging.info("Using chainer CTC implementation") + ctc = CTC(odim, args.eprojs, args.dropout_rate) + elif ctc_type == "warpctc": + logging.info("Using warpctc CTC implementation") + ctc = WarpCTC(odim, args.eprojs, args.dropout_rate) + else: + raise ValueError('ctc_type must be "builtin" or "warpctc": {}'.format(ctc_type)) + return ctc diff --git a/espnet/nets/chainer_backend/deterministic_embed_id.py b/espnet/nets/chainer_backend/deterministic_embed_id.py new file mode 100644 index 0000000000000000000000000000000000000000..22bc3e3b3ae92f57e3759de91f2595bbb9b9ac8e --- /dev/null +++ b/espnet/nets/chainer_backend/deterministic_embed_id.py @@ -0,0 +1,253 @@ +import numpy +import six + +import chainer +from chainer import cuda +from chainer import function_node +from chainer.initializers import normal + +# from chainer.functions.connection import embed_id +from chainer import link +from chainer.utils import type_check +from chainer import variable + +"""Deterministic EmbedID link and function + + copied from chainer/links/connection/embed_id.py + and chainer/functions/connection/embed_id.py, + and modified not to use atomicAdd operation +""" + + +class EmbedIDFunction(function_node.FunctionNode): + def __init__(self, ignore_label=None): + self.ignore_label = ignore_label + self._w_shape = None + + def check_type_forward(self, in_types): + type_check.expect(in_types.size() == 2) + x_type, w_type = in_types + type_check.expect( + x_type.dtype.kind == "i", + x_type.ndim >= 1, + ) + type_check.expect(w_type.dtype == numpy.float32, w_type.ndim == 2) + + def forward(self, inputs): + self.retain_inputs((0,)) + x, W = inputs + self._w_shape = W.shape + + if not type_check.same_types(*inputs): + raise ValueError( + "numpy and cupy must not be used together\n" + "type(W): {0}, type(x): {1}".format(type(W), type(x)) + ) + + xp = cuda.get_array_module(*inputs) + if chainer.is_debug(): + valid_x = xp.logical_and(0 <= x, x < len(W)) + if self.ignore_label is not None: + valid_x = xp.logical_or(valid_x, x == self.ignore_label) + if not valid_x.all(): + raise ValueError( + "Each not ignored `x` value need to satisfy" "`0 <= x < len(W)`" + ) + + if self.ignore_label is not None: + mask = x == self.ignore_label + return (xp.where(mask[..., None], 0, W[xp.where(mask, 0, x)]),) + + return (W[x],) + + def backward(self, indexes, grad_outputs): + inputs = self.get_retained_inputs() + gW = EmbedIDGrad(self._w_shape, self.ignore_label).apply(inputs + grad_outputs)[ + 0 + ] + return None, gW + + +class EmbedIDGrad(function_node.FunctionNode): + def __init__(self, w_shape, ignore_label=None): + self.w_shape = w_shape + self.ignore_label = ignore_label + self._gy_shape = None + + def forward(self, inputs): + self.retain_inputs((0,)) + xp = cuda.get_array_module(*inputs) + x, gy = inputs + self._gy_shape = gy.shape + gW = xp.zeros(self.w_shape, dtype=gy.dtype) + + if xp is numpy: + # It is equivalent to `numpy.add.at(gW, x, gy)` but ufunc.at is + # too slow. 
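+            # accumulate the gradient rows one ID at a time on CPU;
+            # positions equal to ignore_label contribute nothing to gW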
+ for ix, igy in six.moves.zip(x.ravel(), gy.reshape(x.size, -1)): + if ix == self.ignore_label: + continue + gW[ix] += igy + else: + """ + # original code based on cuda elementwise method + if self.ignore_label is None: + cuda.elementwise( + 'T gy, S x, S n_out', 'raw T gW', + 'ptrdiff_t w_ind[] = {x, i % n_out};' + 'atomicAdd(&gW[w_ind], gy)', + 'embed_id_bwd')( + gy, xp.expand_dims(x, -1), gW.shape[1], gW) + else: + cuda.elementwise( + 'T gy, S x, S n_out, S ignore', 'raw T gW', + ''' + if (x != ignore) { + ptrdiff_t w_ind[] = {x, i % n_out}; + atomicAdd(&gW[w_ind], gy); + } + ''', + 'embed_id_bwd_ignore_label')( + gy, xp.expand_dims(x, -1), gW.shape[1], + self.ignore_label, gW) + """ + # EmbedID gradient alternative without atomicAdd, which simply + # creates a one-hot vector and applies dot product + xi = xp.zeros((x.size, len(gW)), dtype=numpy.float32) + idx = xp.arange(x.size, dtype=numpy.int32) * len(gW) + x.ravel() + xi.ravel()[idx] = 1.0 + if self.ignore_label is not None: + xi[:, self.ignore_label] = 0.0 + gW = xi.T.dot(gy.reshape(x.size, -1)).astype(gW.dtype, copy=False) + + return (gW,) + + def backward(self, indexes, grads): + xp = cuda.get_array_module(*grads) + x = self.get_retained_inputs()[0].data + ggW = grads[0] + + if self.ignore_label is not None: + mask = x == self.ignore_label + # To prevent index out of bounds, we need to check if ignore_label + # is inside of W. + if not (0 <= self.ignore_label < self.w_shape[1]): + x = xp.where(mask, 0, x) + + ggy = ggW[x] + + if self.ignore_label is not None: + mask, zero, _ = xp.broadcast_arrays( + mask[..., None], xp.zeros((), "f"), ggy.data + ) + ggy = chainer.functions.where(mask, zero, ggy) + return None, ggy + + +def embed_id(x, W, ignore_label=None): + r"""Efficient linear function for one-hot input. + + This function implements so called *word embeddings*. It takes two + arguments: a set of IDs (words) ``x`` in :math:`B` dimensional integer + vector, and a set of all ID (word) embeddings ``W`` in :math:`V \\times d` + float32 matrix. It outputs :math:`B \\times d` matrix whose ``i``-th + column is the ``x[i]``-th column of ``W``. + This function is only differentiable on the input ``W``. + + Args: + x (chainer.Variable | np.ndarray): Batch vectors of IDs. Each + element must be signed integer. + W (chainer.Variable | np.ndarray): Distributed representation + of each ID (a.k.a. word embeddings). + ignore_label (int): If ignore_label is an int value, i-th column + of return value is filled with 0. + + Returns: + chainer.Variable: Embedded variable. + + + .. rubric:: :class:`~chainer.links.EmbedID` + + Examples: + + >>> x = np.array([2, 1]).astype('i') + >>> x + array([2, 1], dtype=int32) + >>> W = np.array([[0, 0, 0], + ... [1, 1, 1], + ... [2, 2, 2]]).astype('f') + >>> W + array([[ 0., 0., 0.], + [ 1., 1., 1.], + [ 2., 2., 2.]], dtype=float32) + >>> F.embed_id(x, W).data + array([[ 2., 2., 2.], + [ 1., 1., 1.]], dtype=float32) + >>> F.embed_id(x, W, ignore_label=1).data + array([[ 2., 2., 2.], + [ 0., 0., 0.]], dtype=float32) + + """ + return EmbedIDFunction(ignore_label=ignore_label).apply((x, W))[0] + + +class EmbedID(link.Link): + """Efficient linear layer for one-hot input. + + This is a link that wraps the :func:`~chainer.functions.embed_id` function. + This link holds the ID (word) embedding matrix ``W`` as a parameter. + + Args: + in_size (int): Number of different identifiers (a.k.a. vocabulary size). + out_size (int): Output dimension. + initialW (Initializer): Initializer to initialize the weight. 
+ ignore_label (int): If `ignore_label` is an int value, i-th column of + return value is filled with 0. + + .. rubric:: :func:`~chainer.functions.embed_id` + + Attributes: + W (~chainer.Variable): Embedding parameter matrix. + + Examples: + + >>> W = np.array([[0, 0, 0], + ... [1, 1, 1], + ... [2, 2, 2]]).astype('f') + >>> W + array([[ 0., 0., 0.], + [ 1., 1., 1.], + [ 2., 2., 2.]], dtype=float32) + >>> l = L.EmbedID(W.shape[0], W.shape[1], initialW=W) + >>> x = np.array([2, 1]).astype('i') + >>> x + array([2, 1], dtype=int32) + >>> y = l(x) + >>> y.data + array([[ 2., 2., 2.], + [ 1., 1., 1.]], dtype=float32) + + """ + + ignore_label = None + + def __init__(self, in_size, out_size, initialW=None, ignore_label=None): + super(EmbedID, self).__init__() + self.ignore_label = ignore_label + + with self.init_scope(): + if initialW is None: + initialW = normal.Normal(1.0) + self.W = variable.Parameter(initialW, (in_size, out_size)) + + def __call__(self, x): + """Extracts the word embedding of given IDs. + + Args: + x (chainer.Variable): Batch vectors of IDs. + + Returns: + chainer.Variable: Batch of corresponding embeddings. + + """ + return embed_id(x, self.W, ignore_label=self.ignore_label) diff --git a/espnet/nets/chainer_backend/e2e_asr.py b/espnet/nets/chainer_backend/e2e_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..dc589ef1a1280ad86c42bad64dbe95573a226dc8 --- /dev/null +++ b/espnet/nets/chainer_backend/e2e_asr.py @@ -0,0 +1,226 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""RNN sequence-to-sequence speech recognition model (chainer).""" + +import logging +import math + +import chainer +from chainer import reporter +import numpy as np + +from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface +from espnet.nets.chainer_backend.ctc import ctc_for +from espnet.nets.chainer_backend.rnn.attentions import att_for +from espnet.nets.chainer_backend.rnn.decoders import decoder_for +from espnet.nets.chainer_backend.rnn.encoders import encoder_for +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.pytorch_backend.e2e_asr import E2E as E2E_pytorch +from espnet.nets.pytorch_backend.nets_utils import get_subsample + +CTC_LOSS_THRESHOLD = 10000 + + +class E2E(ChainerASRInterface): + """E2E module for chainer backend. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (parser.args): Training config. + flag_return (bool): If True, train() would return + additional metrics in addition to the training + loss. + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + return E2E_pytorch.add_arguments(parser) + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.enc.conv_subsampling_factor * int(np.prod(self.subsample)) + + def __init__(self, idim, odim, args, flag_return=True): + """Construct an E2E object. 
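+
+        The multitask weight ``args.mtlalpha`` must lie in [0, 1]:
+        1 trains with the CTC loss only, 0 with the attention loss only,
+        and intermediate values mix the two as
+        ``alpha * loss_ctc + (1 - alpha) * loss_att``.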
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + chainer.Chain.__init__(self) + self.mtlalpha = args.mtlalpha + assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]" + self.etype = args.etype + self.verbose = args.verbose + self.char_list = args.char_list + self.outdir = args.outdir + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + + # subsample info + self.subsample = get_subsample(args, mode="asr", arch="rnn") + + # label smoothing info + if args.lsm_type: + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + with self.init_scope(): + # encoder + self.enc = encoder_for(args, idim, self.subsample) + # ctc + self.ctc = ctc_for(args, odim) + # attention + self.att = att_for(args) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + self.acc = None + self.loss = None + self.flag_return = flag_return + + def forward(self, xs, ilens, ys): + """E2E forward propagation. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.Variable): Batch of length of each input batch. (B,) + ys (chainer.Variable): Batch of padded target features. (B, Lmax, odim) + + Returns: + float: Loss that calculated by attention and ctc loss. + float (optional): Ctc loss. + float (optional): Attention loss. + float (optional): Accuracy. + + """ + # 1. encoder + hs, ilens = self.enc(xs, ilens) + + # 3. CTC loss + if self.mtlalpha == 0: + loss_ctc = None + else: + loss_ctc = self.ctc(hs, ys) + + # 4. attention loss + if self.mtlalpha == 1: + loss_att = None + acc = None + else: + loss_att, acc = self.dec(hs, ys) + + self.acc = acc + alpha = self.mtlalpha + if alpha == 0: + self.loss = loss_att + elif alpha == 1: + self.loss = loss_ctc + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + + if self.loss.data < CTC_LOSS_THRESHOLD and not math.isnan(self.loss.data): + reporter.report({"loss_ctc": loss_ctc}, self) + reporter.report({"loss_att": loss_att}, self) + reporter.report({"acc": acc}, self) + + logging.info("mtl loss:" + str(self.loss.data)) + reporter.report({"loss": self.loss}, self) + else: + logging.warning("loss (=%f) is not correct", self.loss.data) + if self.flag_return: + return self.loss, loss_ctc, loss_att, acc + else: + return self.loss + + def recognize(self, x, recog_args, char_list, rnnlm=None): + """E2E greedy/beam search. + + Args: + x (chainer.Variable): Input tensor for recognition. + recog_args (parser.args): Arguments of config file. + char_list (List[str]): List of Charactors. + rnnlm (Module): RNNLM module defined at `espnet.lm.chainer_backend.lm`. + + Returns: + List[Dict[str, Any]]: Result of recognition. + + """ + # subsample frame + x = x[:: self.subsample[0], :] + ilen = self.xp.array(x.shape[0], dtype=np.int32) + h = chainer.Variable(self.xp.array(x, dtype=np.float32)) + + with chainer.no_backprop_mode(), chainer.using_config("train", False): + # 1. encoder + # make a utt list (1) to use the same interface for encoder + h, _ = self.enc([h], [ilen]) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(h).data[0] + else: + lpz = None + + # 2. 
decoder + # decode the first utterance + y = self.dec.recognize_beam(h[0], lpz, recog_args, char_list, rnnlm) + + return y + + def calculate_all_attentions(self, xs, ilens, ys): + """E2E attention calculation. + + Args: + xs (List): List of padded input sequences. [(T1, idim), (T2, idim), ...] + ilens (np.ndarray): Batch of lengths of input sequences. (B) + ys (List): List of character id sequence tensor. [(L1), (L2), (L3), ...] + + Returns: + float np.ndarray: Attention weights. (B, Lmax, Tmax) + + """ + hs, ilens = self.enc(xs, ilens) + att_ws = self.dec.calculate_all_attentions(hs, ys) + + return att_ws + + @staticmethod + def custom_converter(subsampling_factor=0): + """Get customconverter of the model.""" + from espnet.nets.chainer_backend.rnn.training import CustomConverter + + return CustomConverter(subsampling_factor=subsampling_factor) + + @staticmethod + def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1): + """Get custom_updater of the model.""" + from espnet.nets.chainer_backend.rnn.training import CustomUpdater + + return CustomUpdater( + iters, optimizer, converter=converter, device=device, accum_grad=accum_grad + ) + + @staticmethod + def custom_parallel_updater(iters, optimizer, converter, devices, accum_grad=1): + """Get custom_parallel_updater of the model.""" + from espnet.nets.chainer_backend.rnn.training import CustomParallelUpdater + + return CustomParallelUpdater( + iters, + optimizer, + converter=converter, + devices=devices, + accum_grad=accum_grad, + ) diff --git a/espnet/nets/chainer_backend/e2e_asr_transformer.py b/espnet/nets/chainer_backend/e2e_asr_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..67f05c71e146505499c7080502f9ff7c54658403 --- /dev/null +++ b/espnet/nets/chainer_backend/e2e_asr_transformer.py @@ -0,0 +1,630 @@ +# encoding: utf-8 +"""Transformer-based model for End-to-end ASR.""" + +from argparse import Namespace +from distutils.util import strtobool +import logging +import math + +import chainer +import chainer.functions as F +from chainer import reporter +import numpy as np +import six + +from espnet.nets.chainer_backend.asr_interface import ChainerASRInterface +from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention +from espnet.nets.chainer_backend.transformer import ctc +from espnet.nets.chainer_backend.transformer.decoder import Decoder +from espnet.nets.chainer_backend.transformer.encoder import Encoder +from espnet.nets.chainer_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet.nets.chainer_backend.transformer.training import CustomConverter +from espnet.nets.chainer_backend.transformer.training import CustomUpdater +from espnet.nets.chainer_backend.transformer.training import ( + CustomParallelUpdater, # noqa: H301 +) +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport + + +CTC_SCORING_RATIO = 1.5 +MAX_DECODER_OUTPUT = 5 + + +class E2E(ChainerASRInterface): + """E2E module. + + Args: + idim (int): Input dimmensions. + odim (int): Output dimmensions. + args (Namespace): Training config. + ignore_id (int, optional): Id for ignoring a character. + flag_return (bool, optional): If true, return a list with (loss, + loss_ctc, loss_att, acc) in forward. 
Otherwise, return loss. + + """ + + @staticmethod + def add_arguments(parser): + """Customize flags for transformer setup. + + Args: + parser (Namespace): Training config. + + """ + group = parser.add_argument_group("transformer model setting") + group.add_argument( + "--transformer-init", + type=str, + default="pytorch", + help="how to initialize transformer parameters", + ) + group.add_argument( + "--transformer-input-layer", + type=str, + default="conv2d", + choices=["conv2d", "linear", "embed"], + help="transformer input layer type", + ) + group.add_argument( + "--transformer-attn-dropout-rate", + default=None, + type=float, + help="dropout in transformer attention. use --dropout-rate if None is set", + ) + group.add_argument( + "--transformer-lr", + default=10.0, + type=float, + help="Initial value of learning rate", + ) + group.add_argument( + "--transformer-warmup-steps", + default=25000, + type=int, + help="optimizer warmup steps", + ) + group.add_argument( + "--transformer-length-normalized-loss", + default=True, + type=strtobool, + help="normalize loss by length", + ) + + group.add_argument( + "--dropout-rate", + default=0.0, + type=float, + help="Dropout rate for the encoder", + ) + # Encoder + group.add_argument( + "--elayers", + default=4, + type=int, + help="Number of encoder layers (for shared recognition part " + "in multi-speaker asr mode)", + ) + group.add_argument( + "--eunits", + "-u", + default=300, + type=int, + help="Number of encoder hidden units", + ) + # Attention + group.add_argument( + "--adim", + default=320, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + # Decoder + group.add_argument( + "--dlayers", default=1, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=320, type=int, help="Number of decoder hidden units" + ) + return parser + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.encoder.conv_subsampling_factor * int(np.prod(self.subsample)) + + def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True): + """Initialize the transformer.""" + chainer.Chain.__init__(self) + self.mtlalpha = args.mtlalpha + assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]" + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + self.use_label_smoothing = False + self.char_list = args.char_list + self.space = args.sym_space + self.blank = args.sym_blank + self.scale_emb = args.adim ** 0.5 + self.sos = odim - 1 + self.eos = odim - 1 + self.subsample = get_subsample(args, mode="asr", arch="transformer") + self.ignore_id = ignore_id + self.reset_parameters(args) + with self.init_scope(): + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + initialW=self.initialW, + initial_bias=self.initialB, + ) + self.decoder = Decoder( + odim, args, initialW=self.initialW, initial_bias=self.initialB + ) + self.criterion = LabelSmoothingLoss( + args.lsm_weight, + len(args.char_list), + args.transformer_length_normalized_loss, + ) + if args.mtlalpha > 0.0: + if args.ctc_type == "builtin": + logging.info("Using chainer CTC implementation") + self.ctc = 
ctc.CTC(odim, args.adim, args.dropout_rate) + elif args.ctc_type == "warpctc": + logging.info("Using warpctc CTC implementation") + self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate) + else: + raise ValueError( + 'ctc_type must be "builtin" or "warpctc": {}'.format( + args.ctc_type + ) + ) + else: + self.ctc = None + self.dims = args.adim + self.odim = odim + self.flag_return = flag_return + if args.report_cer or args.report_wer: + self.error_calculator = ErrorCalculator( + args.char_list, + args.sym_space, + args.sym_blank, + args.report_cer, + args.report_wer, + ) + else: + self.error_calculator = None + if "Namespace" in str(type(args)): + self.verbose = 0 if "verbose" not in args else args.verbose + else: + self.verbose = 0 if args.verbose is None else args.verbose + + def reset_parameters(self, args): + """Initialize the Weight according to the give initialize-type. + + Args: + args (Namespace): Transformer config. + + """ + type_init = args.transformer_init + if type_init == "lecun_uniform": + logging.info("Using LeCunUniform as Parameter initializer") + self.initialW = chainer.initializers.LeCunUniform + elif type_init == "lecun_normal": + logging.info("Using LeCunNormal as Parameter initializer") + self.initialW = chainer.initializers.LeCunNormal + elif type_init == "gorot_uniform": + logging.info("Using GlorotUniform as Parameter initializer") + self.initialW = chainer.initializers.GlorotUniform + elif type_init == "gorot_normal": + logging.info("Using GlorotNormal as Parameter initializer") + self.initialW = chainer.initializers.GlorotNormal + elif type_init == "he_uniform": + logging.info("Using HeUniform as Parameter initializer") + self.initialW = chainer.initializers.HeUniform + elif type_init == "he_normal": + logging.info("Using HeNormal as Parameter initializer") + self.initialW = chainer.initializers.HeNormal + elif type_init == "pytorch": + logging.info("Using Pytorch initializer") + self.initialW = chainer.initializers.Uniform + else: + logging.info("Using Chainer default as Parameter initializer") + self.initialW = chainer.initializers.Uniform + self.initialB = chainer.initializers.Uniform + + def forward(self, xs, ilens, ys_pad, calculate_attentions=False): + """E2E forward propagation. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.Variable): Batch of length of each input batch. (B,) + ys (chainer.Variable): Batch of padded target features. (B, Lmax, odim) + calculate_attentions (bool): If true, return value is the output of encoder. + + Returns: + float: Training loss. + float (optional): Training loss for ctc. + float (optional): Training loss for attention. + float (optional): Accuracy. + chainer.Variable (Optional): Output of the encoder. + + """ + alpha = self.mtlalpha + + # 1. Encoder + xs, x_mask, ilens = self.encoder(xs, ilens) + + # 2. CTC loss + cer_ctc = None + if alpha == 0.0: + loss_ctc = None + else: + _ys = [y.astype(np.int32) for y in ys_pad] + loss_ctc = self.ctc(xs, _ys) + if self.error_calculator is not None: + with chainer.no_backprop_mode(): + ys_hat = chainer.backends.cuda.to_cpu(self.ctc.argmax(xs).data) + cer_ctc = self.error_calculator(ys_hat, ys_pad, is_ctc=True) + + # 3. Decoder + if calculate_attentions: + self.calculate_attentions(xs, x_mask, ys_pad) + ys = self.decoder(ys_pad, xs, x_mask) + + # 4. 
Attention Loss + cer, wer = None, None + if alpha == 1: + loss_att = None + acc = None + else: + # Make target + eos = np.array([self.eos], "i") + with chainer.no_backprop_mode(): + ys_pad_out = [np.concatenate([y, eos], axis=0) for y in ys_pad] + ys_pad_out = F.pad_sequence(ys_pad_out, padding=-1).data + ys_pad_out = self.xp.array(ys_pad_out) + + loss_att = self.criterion(ys, ys_pad_out) + acc = F.accuracy( + ys.reshape(-1, self.odim), ys_pad_out.reshape(-1), ignore_label=-1 + ) + if (not chainer.config.train) and (self.error_calculator is not None): + cer, wer = self.error_calculator(ys, ys_pad) + + if alpha == 0.0: + self.loss = loss_att + loss_att_data = loss_att.data + loss_ctc_data = None + elif alpha == 1.0: + self.loss = loss_ctc + loss_att_data = None + loss_ctc_data = loss_ctc.data + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + loss_att_data = loss_att.data + loss_ctc_data = loss_ctc.data + loss_data = self.loss.data + + if not math.isnan(loss_data): + reporter.report({"loss_ctc": loss_ctc_data}, self) + reporter.report({"loss_att": loss_att_data}, self) + reporter.report({"acc": acc}, self) + + reporter.report({"cer_ctc": cer_ctc}, self) + reporter.report({"cer": cer}, self) + reporter.report({"wer": wer}, self) + + logging.info("mtl loss:" + str(loss_data)) + reporter.report({"loss": loss_data}, self) + else: + logging.warning("loss (=%f) is not correct", loss_data) + + if self.flag_return: + loss_ctc = None + return self.loss, loss_ctc, loss_att, acc + else: + return self.loss + + def calculate_attentions(self, xs, x_mask, ys_pad): + """Calculate Attentions.""" + self.decoder(ys_pad, xs, x_mask) + + def recognize(self, x_block, recog_args, char_list=None, rnnlm=None): + """E2E recognition function. + + Args: + x (ndarray): Input acouctic feature (B, T, D) or (T, D). + recog_args (Namespace): Argment namespace contraining options. + char_list (List[str]): List of characters. + rnnlm (chainer.Chain): Language model module defined at + `espnet.lm.chainer_backend.lm`. + + Returns: + List: N-best decoding results. + + """ + with chainer.no_backprop_mode(), chainer.using_config("train", False): + # 1. encoder + ilens = [x_block.shape[0]] + batch = len(ilens) + xs, _, _ = self.encoder(x_block[None, :, :], ilens) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(xs.reshape(batch, -1, self.dims)).data[0] + else: + lpz = None + # 2. decoder + if recog_args.lm_weight == 0.0: + rnnlm = None + y = self.recognize_beam(xs, lpz, recog_args, char_list, rnnlm) + + return y + + def recognize_beam(self, h, lpz, recog_args, char_list=None, rnnlm=None): + """E2E beam search. + + Args: + h (ndarray): Encoder ouput features (B, T, D) or (T, D). + lpz (ndarray): Log probabilities from CTC. + recog_args (Namespace): Argment namespace contraining options. + char_list (List[str]): List of characters. + rnnlm (chainer.Chain): Language model module defined at + `espnet.lm.chainer_backend.lm`. + + Returns: + List: N-best decoding results. 
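The branch weighting in forward above follows the usual hybrid CTC/attention recipe; a small standalone sketch of the same combination (plain Python, illustrative name):

def combine_hybrid_loss(loss_ctc, loss_att, alpha):
    # L = alpha * L_ctc + (1 - alpha) * L_att, with the two pure cases handled first
    if alpha == 0.0:     # attention-only training
        return loss_att
    if alpha == 1.0:     # CTC-only training
        return loss_ctc
    return alpha * loss_ctc + (1 - alpha) * loss_att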
+ + """ + logging.info("input lengths: " + str(h.shape[1])) + + # initialization + n_len = h.shape[1] + xp = self.xp + h_mask = xp.ones((1, n_len)) + + # search parms + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = recog_args.ctc_weight + + # prepare sos + y = self.sos + if recog_args.maxlenratio == 0: + maxlen = n_len + else: + maxlen = max(1, int(recog_args.maxlenratio * n_len)) + minlen = int(recog_args.minlenratio * n_len) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + if rnnlm: + hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} + else: + hyp = {"score": 0.0, "yseq": [y]} + + if lpz is not None: + ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp) + hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() + hyp["ctc_score_prev"] = 0.0 + if ctc_weight != 1.0: + # pre-pruning based on attention scores + ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) + else: + ctc_beam = lpz.shape[-1] + + hyps = [hyp] + ended_hyps = [] + + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + hyps_best_kept = [] + for hyp in hyps: + ys = F.expand_dims(xp.array(hyp["yseq"]), axis=0).data + out = self.decoder(ys, h, h_mask) + + # get nbest local scores and their ids + local_att_scores = F.log_softmax(out[:, -1], axis=-1).data + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.predict( + hyp["rnnlm_prev"], hyp["yseq"][i] + ) + local_scores = ( + local_att_scores + recog_args.lm_weight * local_lm_scores + ) + else: + local_scores = local_att_scores + + if lpz is not None: + local_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][ + :ctc_beam + ] + ctc_scores, ctc_states = ctc_prefix_score( + hyp["yseq"], local_best_ids, hyp["ctc_state_prev"] + ) + local_scores = (1.0 - ctc_weight) * local_att_scores[ + :, local_best_ids + ] + ctc_weight * (ctc_scores - hyp["ctc_score_prev"]) + if rnnlm: + local_scores += ( + recog_args.lm_weight * local_lm_scores[:, local_best_ids] + ) + joint_best_ids = xp.argsort(local_scores, axis=1)[0, ::-1][:beam] + local_best_scores = local_scores[:, joint_best_ids] + local_best_ids = local_best_ids[joint_best_ids] + else: + local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][ + :beam + ] + local_best_scores = local_scores[:, local_best_ids] + + for j in six.moves.range(beam): + new_hyp = {} + new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[j]) + if rnnlm: + new_hyp["rnnlm_prev"] = rnnlm_state + if lpz is not None: + new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[j]] + new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[j]] + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypothesis: " + str(len(hyps))) + if char_list is not None: + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + + " score: " + + str(hyps[0]["score"]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last postion in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypothes to a final list, and removed them from current hypothes + # (this will be a 
probmlem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + if rnnlm: # Word LM needs to add final score + hyp["score"] += recog_args.lm_weight * rnnlm.final( + hyp["rnnlm_prev"] + ) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remained hypothes: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + if char_list is not None: + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypothes: " + str(len(ended_hyps))) + + nbest_hyps = sorted( + ended_hyps, key=lambda x: x["score"], reverse=True + ) # [:min(len(ended_hyps), recog_args.nbest)] + + logging.debug(nbest_hyps) + # check number of hypotheis + if len(nbest_hyps) == 0: + logging.warn( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + # should copy becasuse Namespace will be overwritten globally + recog_args = Namespace(**vars(recog_args)) + recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) + return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + # remove sos + return nbest_hyps + + def calculate_all_attentions(self, xs, ilens, ys): + """E2E attention calculation. + + Args: + xs (List[tuple()]): List of padded input sequences. + [(T1, idim), (T2, idim), ...] + ilens (ndarray): Batch of lengths of input sequences. (B) + ys (List): List of character id sequence tensor. [(L1), (L2), (L3), ...] + + Returns: + float ndarray: Attention weights. (B, Lmax, Tmax) + + """ + with chainer.no_backprop_mode(): + self(xs, ilens, ys, calculate_attentions=True) + ret = dict() + for name, m in self.namedlinks(): + if isinstance(m, MultiHeadAttention): + var = m.attn + var.to_cpu() + _name = name[1:].replace("/", "_") + ret[_name] = var.data + return ret + + @property + def attention_plot_class(self): + """Attention plot function. 
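During the beam search above, attention, CTC prefix, and language-model scores are interpolated for every expanded label; a compact sketch of that interpolation (illustrative names, not the actual API):

def joint_label_score(att_logp, ctc_new, ctc_prev, lm_logp, ctc_weight, lm_weight):
    # att_logp: attention decoder log-prob of the candidate label
    # ctc_new - ctc_prev: incremental CTC prefix score for appending that label
    # lm_logp: LM log-prob of the label (use 0.0 when no RNNLM is attached)
    score = (1.0 - ctc_weight) * att_logp + ctc_weight * (ctc_new - ctc_prev)
    return score + lm_weight * lm_logp

# each kept hypothesis accumulates: hyp["score"] += joint_label_score(...)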
+ + Redirects to PlotAttentionReport + + Returns: + PlotAttentionReport + + """ + return PlotAttentionReport + + @staticmethod + def custom_converter(subsampling_factor=0): + """Get customconverter of the model.""" + return CustomConverter() + + @staticmethod + def custom_updater(iters, optimizer, converter, device=-1, accum_grad=1): + """Get custom_updater of the model.""" + return CustomUpdater( + iters, optimizer, converter=converter, device=device, accum_grad=accum_grad + ) + + @staticmethod + def custom_parallel_updater(iters, optimizer, converter, devices, accum_grad=1): + """Get custom_parallel_updater of the model.""" + return CustomParallelUpdater( + iters, + optimizer, + converter=converter, + devices=devices, + accum_grad=accum_grad, + ) diff --git a/espnet/nets/chainer_backend/nets_utils.py b/espnet/nets/chainer_backend/nets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4919abb49f2985192149452ed59663b4caf8bf --- /dev/null +++ b/espnet/nets/chainer_backend/nets_utils.py @@ -0,0 +1,7 @@ +import chainer.functions as F + + +def _subsamplex(x, n): + x = [F.get_item(xx, (slice(None, None, n), slice(None))) for xx in x] + ilens = [xx.shape[0] for xx in x] + return x, ilens diff --git a/espnet/nets/chainer_backend/rnn/__init__.py b/espnet/nets/chainer_backend/rnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/chainer_backend/rnn/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/chainer_backend/rnn/attentions.py b/espnet/nets/chainer_backend/rnn/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a776e5b2ee67150fc778eefe3cb21a8c5582c3 --- /dev/null +++ b/espnet/nets/chainer_backend/rnn/attentions.py @@ -0,0 +1,280 @@ +import chainer +import chainer.functions as F +import chainer.links as L + +import numpy as np + + +# dot product based attention +class AttDot(chainer.Chain): + """Compute attention based on dot product. + + Args: + eprojs (int | None): Dimension of input vectors from encoder. + dunits (int | None): Dimension of input vectors for decoder. + att_dim (int): Dimension of input vectors for attention. + + """ + + def __init__(self, eprojs, dunits, att_dim): + super(AttDot, self).__init__() + with self.init_scope(): + self.mlp_enc = L.Linear(eprojs, att_dim) + self.mlp_dec = L.Linear(dunits, att_dim) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + + def reset(self): + """Reset states.""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + + def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0): + """Compute AttDot forward layer. + + Args: + enc_hs (chainer.Variable | N-dimensional array): + Input variable from encoder. + dec_z (chainer.Variable | N-dimensional array): Input variable of decoder. + scaling (float): Scaling weight to make attention sharp. + + Returns: + chainer.Variable: Weighted sum over flames. + chainer.Variable: Attention weight. 
+ + """ + batch = len(enc_hs) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = F.pad_sequence(enc_hs) # utt x frame x hdim + self.h_length = self.enc_h.shape[1] + # utt x frame x att_dim + self.pre_compute_enc_h = F.tanh(self.mlp_enc(self.enc_h, n_batch_axes=2)) + + if dec_z is None: + dec_z = chainer.Variable( + self.xp.zeros((batch, self.dunits), dtype=np.float32) + ) + else: + dec_z = dec_z.reshape(batch, self.dunits) + + # for all t + u = F.broadcast_to( + F.expand_dims(F.tanh(self.mlp_dec(dec_z)), 1), self.pre_compute_enc_h.shape + ) + e = F.sum(self.pre_compute_enc_h * u, axis=2) # utt x frame + # Applying a minus-large-number filter + # to make a probability value zero for a padded area + # simply degrades the performance, and I gave up this implementation + # Apply a scaling to make an attention sharp + w = F.softmax(scaling * e) + # weighted sum over flames + # utt x hdim + c = F.sum( + self.enc_h * F.broadcast_to(F.expand_dims(w, 2), self.enc_h.shape), axis=1 + ) + + return c, w + + +# location based attention +class AttLoc(chainer.Chain): + """Compute location-based attention. + + Args: + eprojs (int | None): Dimension of input vectors from encoder. + dunits (int | None): Dimension of input vectors for decoder. + att_dim (int): Dimension of input vectors for attention. + aconv_chans (int): Number of channels of output arrays from convolutional layer. + aconv_filts (int): Size of filters of convolutional layer. + + """ + + def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts): + super(AttLoc, self).__init__() + with self.init_scope(): + self.mlp_enc = L.Linear(eprojs, att_dim) + self.mlp_dec = L.Linear(dunits, att_dim, nobias=True) + self.mlp_att = L.Linear(aconv_chans, att_dim, nobias=True) + self.loc_conv = L.Convolution2D( + 1, aconv_chans, ksize=(1, 2 * aconv_filts + 1), pad=(0, aconv_filts) + ) + self.gvec = L.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.aconv_chans = aconv_chans + + def reset(self): + """Reset states.""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + + def __call__(self, enc_hs, dec_z, att_prev, scaling=2.0): + """Compute AttLoc forward layer. + + Args: + enc_hs (chainer.Variable | N-dimensional array): + Input variable from encoders. + dec_z (chainer.Variable | N-dimensional array): Input variable of decoder. + att_prev (chainer.Variable | None): Attention weight. + scaling (float): Scaling weight to make attention sharp. + + Returns: + chainer.Variable: Weighted sum over flames. + chainer.Variable: Attention weight. + + """ + batch = len(enc_hs) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = F.pad_sequence(enc_hs) # utt x frame x hdim + self.h_length = self.enc_h.shape[1] + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h, n_batch_axes=2) + + if dec_z is None: + dec_z = chainer.Variable( + self.xp.zeros((batch, self.dunits), dtype=np.float32) + ) + else: + dec_z = dec_z.reshape(batch, self.dunits) + + # initialize attention weight with uniform dist. 
+ if att_prev is None: + att_prev = [ + self.xp.full(hh.shape[0], 1.0 / hh.shape[0], dtype=np.float32) + for hh in enc_hs + ] + att_prev = [chainer.Variable(att) for att in att_prev] + att_prev = F.pad_sequence(att_prev) + + # att_prev: utt x frame -> utt x 1 x 1 x frame + # -> utt x att_conv_chans x 1 x frame + att_conv = self.loc_conv(att_prev.reshape(batch, 1, 1, self.h_length)) + # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans + att_conv = F.swapaxes(F.squeeze(att_conv, axis=2), 1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv, n_batch_axes=2) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = F.broadcast_to( + F.expand_dims(self.mlp_dec(dec_z), 1), self.pre_compute_enc_h.shape + ) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + # TODO(watanabe) use batch_matmul + e = F.squeeze( + self.gvec( + F.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled), n_batch_axes=2 + ), + axis=2, + ) + # Applying a minus-large-number filter + # to make a probability value zero for a padded area + # simply degrades the performance, and I gave up this implementation + # Apply a scaling to make an attention sharp + w = F.softmax(scaling * e) + + # weighted sum over flames + # utt x hdim + c = F.sum( + self.enc_h * F.broadcast_to(F.expand_dims(w, 2), self.enc_h.shape), axis=1 + ) + + return c, w + + +class NoAtt(chainer.Chain): + """Compute non-attention layer. + + This layer is a dummy attention layer to be compatible with other + attention-based models. + + """ + + def __init__(self): + super(NoAtt, self).__init__() + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.c = None + + def reset(self): + """Reset states.""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.c = None + + def __call__(self, enc_hs, dec_z, att_prev): + """Compute NoAtt forward layer. + + Args: + enc_hs (chainer.Variable | N-dimensional array): + Input variable from encoders. + dec_z: Dummy. + att_prev (chainer.Variable | None): Attention weight. + + Returns: + chainer.Variable: Sum over flames. + chainer.Variable: Attention weight. + + """ + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = F.pad_sequence(enc_hs) # utt x frame x hdim + self.h_length = self.enc_h.shape[1] + + # initialize attention weight with uniform dist. + if att_prev is None: + att_prev = [ + self.xp.full(hh.shape[0], 1.0 / hh.shape[0], dtype=np.float32) + for hh in enc_hs + ] + att_prev = [chainer.Variable(att) for att in att_prev] + att_prev = F.pad_sequence(att_prev) + self.c = F.sum( + self.enc_h + * F.broadcast_to(F.expand_dims(att_prev, 2), self.enc_h.shape), + axis=1, + ) + + return self.c, att_prev + + +def att_for(args): + """Returns an attention layer given the program arguments. + + Args: + args (Namespace): The arguments. + + Returns: + chainer.Chain: The corresponding attention module. + + """ + if args.atype == "dot": + att = AttDot(args.eprojs, args.dunits, args.adim) + elif args.atype == "location": + att = AttLoc( + args.eprojs, args.dunits, args.adim, args.aconv_chans, args.aconv_filts + ) + elif args.atype == "noatt": + att = NoAtt() + else: + raise NotImplementedError( + "chainer supports only noatt, dot, and location attention." 
+ ) + return att diff --git a/espnet/nets/chainer_backend/rnn/decoders.py b/espnet/nets/chainer_backend/rnn/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a94a33dd243c98441bc29bce9ed7c7876b3459 --- /dev/null +++ b/espnet/nets/chainer_backend/rnn/decoders.py @@ -0,0 +1,528 @@ +import logging +import random +import six + +import chainer +import chainer.functions as F +import chainer.links as L +import numpy as np + +import espnet.nets.chainer_backend.deterministic_embed_id as DL + +from argparse import Namespace + +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.e2e_asr_common import end_detect + +CTC_SCORING_RATIO = 1.5 +MAX_DECODER_OUTPUT = 5 + + +class Decoder(chainer.Chain): + """Decoder layer. + + Args: + eprojs (int): Dimension of input variables from encoder. + odim (int): The output dimension. + dtype (str): Decoder type. + dlayers (int): Number of layers for decoder. + dunits (int): Dimension of input vector of decoder. + sos (int): Number to indicate the start of sequences. + eos (int): Number to indicate the end of sequences. + att (Module): Attention module defined at + `espnet.espnet.nets.chainer_backend.attentions`. + verbose (int): Verbosity level. + char_list (List[str]): List of all charactors. + labeldist (numpy.array): Distributed array of counted transcript length. + lsm_weight (float): Weight to use when calculating the training loss. + sampling_probability (float): Threshold for scheduled sampling. + + """ + + def __init__( + self, + eprojs, + odim, + dtype, + dlayers, + dunits, + sos, + eos, + att, + verbose=0, + char_list=None, + labeldist=None, + lsm_weight=0.0, + sampling_probability=0.0, + ): + super(Decoder, self).__init__() + with self.init_scope(): + self.embed = DL.EmbedID(odim, dunits) + self.rnn0 = ( + L.StatelessLSTM(dunits + eprojs, dunits) + if dtype == "lstm" + else L.StatelessGRU(dunits + eprojs, dunits) + ) + for i in six.moves.range(1, dlayers): + setattr( + self, + "rnn%d" % i, + L.StatelessLSTM(dunits, dunits) + if dtype == "lstm" + else L.StatelessGRU(dunits, dunits), + ) + self.output = L.Linear(dunits, odim) + self.dtype = dtype + self.loss = None + self.att = att + self.dlayers = dlayers + self.dunits = dunits + self.sos = sos + self.eos = eos + self.verbose = verbose + self.char_list = char_list + # for label smoothing + self.labeldist = labeldist + self.vlabeldist = None + self.lsm_weight = lsm_weight + self.sampling_probability = sampling_probability + + def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): + if self.dtype == "lstm": + c_list[0], z_list[0] = self.rnn0(c_prev[0], z_prev[0], ey) + for i in six.moves.range(1, self.dlayers): + c_list[i], z_list[i] = self["rnn%d" % i]( + c_prev[i], z_prev[i], z_list[i - 1] + ) + else: + if z_prev[0] is None: + xp = self.xp + with chainer.backends.cuda.get_device_from_id(self._device_id): + z_prev[0] = chainer.Variable( + xp.zeros((ey.shape[0], self.dunits), dtype=ey.dtype) + ) + z_list[0] = self.rnn0(z_prev[0], ey) + for i in six.moves.range(1, self.dlayers): + if z_prev[i] is None: + xp = self.xp + with chainer.backends.cuda.get_device_from_id(self._device_id): + z_prev[i] = chainer.Variable( + xp.zeros( + (z_list[i - 1].shape[0], self.dunits), + dtype=z_list[i - 1].dtype, + ) + ) + z_list[i] = self["rnn%d" % i](z_prev[i], z_list[i - 1]) + return z_list, c_list + + def __call__(self, hs, ys): + """Core function of Decoder layer. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. 
+ ys (list of chainer.Variable | N-dimension array): + Input variable of decoder. + + Returns: + chainer.Variable: A variable holding a scalar array of the training loss. + chainer.Variable: A variable holding a scalar array of the accuracy. + + """ + self.loss = None + # prepare input and output word sequences with sos/eos IDs + eos = self.xp.array([self.eos], "i") + sos = self.xp.array([self.sos], "i") + ys_in = [F.concat([sos, y], axis=0) for y in ys] + ys_out = [F.concat([y, eos], axis=0) for y in ys] + + # padding for ys with -1 + # pys: utt x olen + pad_ys_in = F.pad_sequence(ys_in, padding=self.eos) + pad_ys_out = F.pad_sequence(ys_out, padding=-1) + + # get dim, length info + batch = pad_ys_out.shape[0] + olength = pad_ys_out.shape[1] + logging.info( + self.__class__.__name__ + + " input lengths: " + + str(self.xp.array([h.shape[0] for h in hs])) + ) + logging.info( + self.__class__.__name__ + + " output lengths: " + + str(self.xp.array([y.shape[0] for y in ys_out])) + ) + + # initialization + c_list = [None] # list of cell state of each layer + z_list = [None] # list of hidden state of each layer + for _ in six.moves.range(1, self.dlayers): + c_list.append(None) + z_list.append(None) + att_w = None + z_all = [] + self.att.reset() # reset pre-computation of h + + # pre-computation of embedding + eys = self.embed(pad_ys_in) # utt x olen x zdim + eys = F.separate(eys, axis=1) + + # loop for an output sequence + for i in six.moves.range(olength): + att_c, att_w = self.att(hs, z_list[0], att_w) + if i > 0 and random.random() < self.sampling_probability: + logging.info(" scheduled sampling ") + z_out = self.output(z_all[-1]) + z_out = F.argmax(F.log_softmax(z_out), axis=1) + z_out = self.embed(z_out) + ey = F.hstack((z_out, att_c)) # utt x (zdim + hdim) + else: + ey = F.hstack((eys[i], att_c)) # utt x (zdim + hdim) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + z_all.append(z_list[-1]) + + z_all = F.stack(z_all, axis=1).reshape(batch * olength, self.dunits) + # compute loss + y_all = self.output(z_all) + self.loss = F.softmax_cross_entropy(y_all, F.flatten(pad_ys_out)) + # -1: eos, which is removed in the loss computation + self.loss *= np.mean([len(x) for x in ys_in]) - 1 + acc = F.accuracy(y_all, F.flatten(pad_ys_out), ignore_label=-1) + logging.info("att loss:" + str(self.loss.data)) + + # show predicted character sequence for debug + if self.verbose > 0 and self.char_list is not None: + y_hat = y_all.reshape(batch, olength, -1) + y_true = pad_ys_out + for (i, y_hat_), y_true_ in zip(enumerate(y_hat.data), y_true.data): + if i == MAX_DECODER_OUTPUT: + break + idx_hat = self.xp.argmax(y_hat_[y_true_ != -1], axis=1) + idx_true = y_true_[y_true_ != -1] + seq_hat = [self.char_list[int(idx)] for idx in idx_hat] + seq_true = [self.char_list[int(idx)] for idx in idx_true] + seq_hat = "".join(seq_hat).replace("", " ") + seq_true = "".join(seq_true).replace("", " ") + logging.info("groundtruth[%d]: " % i + seq_true) + logging.info("prediction [%d]: " % i + seq_hat) + + if self.labeldist is not None: + if self.vlabeldist is None: + self.vlabeldist = chainer.Variable(self.xp.asarray(self.labeldist)) + loss_reg = -F.sum( + F.scale(F.log_softmax(y_all), self.vlabeldist, axis=1) + ) / len(ys_in) + self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg + + return self.loss, acc + + def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None): + """Beam search implementation. + + Args: + h (chainer.Variable): One of the output from the encoder. 
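The training loop above applies scheduled sampling: with probability sampling_probability, the embedding of the previous predicted token is fed back instead of the ground-truth one. A framework-free sketch of that decision (all names below are illustrative):

import random

def choose_decoder_input(step, gold_embedding, prev_logits, embed, argmax,
                         sampling_probability):
    # from the second output step on, sometimes feed back the model's own prediction
    if step > 0 and random.random() < sampling_probability:
        prev_token = argmax(prev_logits)   # greedy pick of the previous output
        return embed(prev_token)           # embedding of the predicted token
    return gold_embedding                  # teacher forcing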
+ lpz (chainer.Variable | None): Result of net propagation. + recog_args (Namespace): The argument. + char_list (List[str]): List of all charactors. + rnnlm (Module): RNNLM module. Defined at `espnet.lm.chainer_backend.lm` + + Returns: + List[Dict[str,Any]]: Result of recognition. + + """ + logging.info("input lengths: " + str(h.shape[0])) + # initialization + c_list = [None] # list of cell state of each layer + z_list = [None] # list of hidden state of each layer + for _ in six.moves.range(1, self.dlayers): + c_list.append(None) + z_list.append(None) + a = None + self.att.reset() # reset pre-computation of h + + # search parms + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = recog_args.ctc_weight + + # preprate sos + y = self.xp.full(1, self.sos, "i") + if recog_args.maxlenratio == 0: + maxlen = h.shape[0] + else: + # maxlen >= 1 + maxlen = max(1, int(recog_args.maxlenratio * h.shape[0])) + minlen = int(recog_args.minlenratio * h.shape[0]) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + if rnnlm: + hyp = { + "score": 0.0, + "yseq": [y], + "c_prev": c_list, + "z_prev": z_list, + "a_prev": a, + "rnnlm_prev": None, + } + else: + hyp = { + "score": 0.0, + "yseq": [y], + "c_prev": c_list, + "z_prev": z_list, + "a_prev": a, + } + if lpz is not None: + ctc_prefix_score = CTCPrefixScore(lpz, 0, self.eos, self.xp) + hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() + hyp["ctc_score_prev"] = 0.0 + if ctc_weight != 1.0: + # pre-pruning based on attention scores + ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) + else: + ctc_beam = lpz.shape[-1] + hyps = [hyp] + ended_hyps = [] + + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + hyps_best_kept = [] + for hyp in hyps: + ey = self.embed(hyp["yseq"][i]) # utt list (1) x zdim + att_c, att_w = self.att([h], hyp["z_prev"][0], hyp["a_prev"]) + ey = F.hstack((ey, att_c)) # utt(1) x (zdim + hdim) + + z_list, c_list = self.rnn_forward( + ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"] + ) + + # get nbest local scores and their ids + local_att_scores = F.log_softmax(self.output(z_list[-1])).data + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.predict( + hyp["rnnlm_prev"], hyp["yseq"][i] + ) + local_scores = ( + local_att_scores + recog_args.lm_weight * local_lm_scores + ) + else: + local_scores = local_att_scores + + if lpz is not None: + local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][ + :ctc_beam + ] + ctc_scores, ctc_states = ctc_prefix_score( + hyp["yseq"], local_best_ids, hyp["ctc_state_prev"] + ) + local_scores = (1.0 - ctc_weight) * local_att_scores[ + :, local_best_ids + ] + ctc_weight * (ctc_scores - hyp["ctc_score_prev"]) + if rnnlm: + local_scores += ( + recog_args.lm_weight * local_lm_scores[:, local_best_ids] + ) + joint_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][ + :beam + ] + local_best_scores = local_scores[:, joint_best_ids] + local_best_ids = local_best_ids[joint_best_ids] + else: + local_best_ids = self.xp.argsort(local_scores, axis=1)[0, ::-1][ + :beam + ] + local_best_scores = local_scores[:, local_best_ids] + + for j in six.moves.range(beam): + new_hyp = {} + # do not copy {z,c}_list directly + new_hyp["z_prev"] = z_list[:] + new_hyp["c_prev"] = c_list[:] + new_hyp["a_prev"] = att_w + new_hyp["score"] = hyp["score"] + local_best_scores[0, j] + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + 
new_hyp["yseq"][len(hyp["yseq"])] = self.xp.full( + 1, local_best_ids[j], "i" + ) + if rnnlm: + new_hyp["rnnlm_prev"] = rnnlm_state + if lpz is not None: + new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[j]] + new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[j]] + # will be (2 x beam) hyps at most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypotheses: " + str(len(hyps))) + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]).replace( + "", " " + ) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + for hyp in hyps: + hyp["yseq"].append(self.xp.full(1, self.eos, "i")) + + # add ended hypotheses to a final list, + # and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + if rnnlm: # Word LM needs to add final score + hyp["score"] += recog_args.lm_weight * rnnlm.final( + hyp["rnnlm_prev"] + ) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remaining hypotheses: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + + for hyp in hyps: + logging.debug( + "hypo: " + + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]).replace( + "", " " + ) + ) + + logging.debug("number of ended hypotheses: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), recog_args.nbest) + ] + + # check number of hypotheses + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, " + "perform recognition again with smaller minlenratio." + ) + # should copy because Namespace will be overwritten globally + recog_args = Namespace(**vars(recog_args)) + recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) + return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + + return nbest_hyps + + def calculate_all_attentions(self, hs, ys): + """Calculate all of attentions. + + Args: + hs (list of chainer.Variable | N-dimensional array): + Input variable from encoder. + ys (list of chainer.Variable | N-dimensional array): + Input variable of decoder. + + Returns: + chainer.Variable: List of attention weights. 
+ + """ + # prepare input and output word sequences with sos/eos IDs + eos = self.xp.array([self.eos], "i") + sos = self.xp.array([self.sos], "i") + ys_in = [F.concat([sos, y], axis=0) for y in ys] + ys_out = [F.concat([y, eos], axis=0) for y in ys] + + # padding for ys with -1 + # pys: utt x olen + pad_ys_in = F.pad_sequence(ys_in, padding=self.eos) + pad_ys_out = F.pad_sequence(ys_out, padding=-1) + + # get length info + olength = pad_ys_out.shape[1] + + # initialization + c_list = [None] # list of cell state of each layer + z_list = [None] # list of hidden state of each layer + for _ in six.moves.range(1, self.dlayers): + c_list.append(None) + z_list.append(None) + att_w = None + att_ws = [] + self.att.reset() # reset pre-computation of h + + # pre-computation of embedding + eys = self.embed(pad_ys_in) # utt x olen x zdim + eys = F.separate(eys, axis=1) + + # loop for an output sequence + for i in six.moves.range(olength): + att_c, att_w = self.att(hs, z_list[0], att_w) + ey = F.hstack((eys[i], att_c)) # utt x (zdim + hdim) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + att_ws.append(att_w) # for debugging + + att_ws = F.stack(att_ws, axis=1) + att_ws.to_cpu() + + return att_ws.data + + +def decoder_for(args, odim, sos, eos, att, labeldist): + """Return the decoding layer corresponding to the args. + + Args: + args (Namespace): The program arguments. + odim (int): The output dimension. + sos (int): Number to indicate the start of sequences. + eos (int) Number to indicate the end of sequences. + att (Module): + Attention module defined at `espnet.nets.chainer_backend.attentions`. + labeldist (numpy.array): Distributed array of length od transcript. + + Returns: + chainer.Chain: The decoder module. + + """ + return Decoder( + args.eprojs, + odim, + args.dtype, + args.dlayers, + args.dunits, + sos, + eos, + att, + args.verbose, + args.char_list, + labeldist, + args.lsm_weight, + args.sampling_probability, + ) diff --git a/espnet/nets/chainer_backend/rnn/encoders.py b/espnet/nets/chainer_backend/rnn/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..e534c144860688963c5106d0147348511f38cdb4 --- /dev/null +++ b/espnet/nets/chainer_backend/rnn/encoders.py @@ -0,0 +1,329 @@ +import logging +import six + +import chainer +import chainer.functions as F +import chainer.links as L +import numpy as np + +from chainer import cuda + +from espnet.nets.chainer_backend.nets_utils import _subsamplex +from espnet.nets.e2e_asr_common import get_vgg2l_odim + + +# TODO(watanabe) explanation of BLSTMP +class RNNP(chainer.Chain): + """RNN with projection layer module. + + Args: + idim (int): Dimension of inputs. + elayers (int): Number of encoder layers. + cdim (int): Number of rnn units. (resulted in cdim * 2 if bidirectional) + hdim (int): Number of projection units. + subsample (np.ndarray): List to use sabsample the input array. + dropout (float): Dropout rate. + typ (str): The RNN type. 
+ + """ + + def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"): + super(RNNP, self).__init__() + bidir = typ[0] == "b" + if bidir: + rnn = L.NStepBiLSTM if "lstm" in typ else L.NStepBiGRU + else: + rnn = L.NStepLSTM if "lstm" in typ else L.NStepGRU + rnn_label = "birnn" if bidir else "rnn" + with self.init_scope(): + for i in six.moves.range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = hdim + _cdim = 2 * cdim if bidir else cdim + # bottleneck layer to merge + setattr( + self, "{}{:d}".format(rnn_label, i), rnn(1, inputdim, cdim, dropout) + ) + setattr(self, "bt%d" % i, L.Linear(_cdim, hdim)) + + self.elayers = elayers + self.rnn_label = rnn_label + self.cdim = cdim + self.subsample = subsample + self.typ = typ + self.bidir = bidir + + def __call__(self, xs, ilens): + """RNNP forward. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.Variable): Batch of length of each input batch. (B,) + + Returns: + xs (chainer.Variable):subsampled vector of xs. + chainer.Variable: Subsampled vector of ilens. + + """ + logging.info(self.__class__.__name__ + " input lengths: " + str(ilens)) + + for layer in six.moves.range(self.elayers): + if "lstm" in self.typ: + _, _, ys = self[self.rnn_label + str(layer)](None, None, xs) + else: + _, ys = self[self.rnn_label + str(layer)](None, xs) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + # TODO(watanabe) replace subsample and FC layer with CNN + ys, ilens = _subsamplex(ys, self.subsample[layer + 1]) + # (sum _utt frame_utt) x dim + ys = self["bt" + str(layer)](F.vstack(ys)) + xs = F.split_axis(ys, np.cumsum(ilens[:-1]), axis=0) + + # final tanh operation + xs = F.split_axis(F.tanh(F.vstack(xs)), np.cumsum(ilens[:-1]), axis=0) + + # 1 utterance case, it becomes an array, so need to make a utt tuple + if not isinstance(xs, tuple): + xs = [xs] + + return xs, ilens # x: utt list of frame x dim + + +class RNN(chainer.Chain): + """RNN Module. + + Args: + idim (int): Dimension of the imput. + elayers (int): Number of encoder layers. + cdim (int): Number of rnn units. + hdim (int): Number of projection units. + dropout (float): Dropout rate. + typ (str): Rnn type. + + """ + + def __init__(self, idim, elayers, cdim, hdim, dropout, typ="lstm"): + super(RNN, self).__init__() + bidir = typ[0] == "b" + if bidir: + rnn = L.NStepBiLSTM if "lstm" in typ else L.NStepBiGRU + else: + rnn = L.NStepLSTM if "lstm" in typ else L.NStepGRU + _cdim = 2 * cdim if bidir else cdim + with self.init_scope(): + self.nbrnn = rnn(elayers, idim, cdim, dropout) + self.l_last = L.Linear(_cdim, hdim) + self.typ = typ + self.bidir = bidir + + def __call__(self, xs, ilens): + """BRNN forward propagation. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.Variable): Batch of length of each input batch. (B,) + + Returns: + tuple(chainer.Variable): Tuple of `chainer.Variable` objects. + chainer.Variable: `ilens` . 
+ + """ + logging.info(self.__class__.__name__ + " input lengths: " + str(ilens)) + # need to move ilens to cpu + ilens = cuda.to_cpu(ilens) + + if "lstm" in self.typ: + _, _, ys = self.nbrnn(None, None, xs) + else: + _, ys = self.nbrnn(None, xs) + ys = self.l_last(F.vstack(ys)) # (sum _utt frame_utt) x dim + xs = F.split_axis(ys, np.cumsum(ilens[:-1]), axis=0) + + # final tanh operation + xs = F.split_axis(F.tanh(F.vstack(xs)), np.cumsum(ilens[:-1]), axis=0) + + # 1 utterance case, it becomes an array, so need to make a utt tuple + if not isinstance(xs, tuple): + xs = [xs] + + return xs, ilens # x: utt list of frame x dim + + +# TODO(watanabe) explanation of VGG2L, VGG2B (Block) might be better +class VGG2L(chainer.Chain): + """VGG motibated cnn layers. + + Args: + in_channel (int): Number of channels. + + """ + + def __init__(self, in_channel=1): + super(VGG2L, self).__init__() + with self.init_scope(): + # CNN layer (VGG motivated) + self.conv1_1 = L.Convolution2D(in_channel, 64, 3, stride=1, pad=1) + self.conv1_2 = L.Convolution2D(64, 64, 3, stride=1, pad=1) + self.conv2_1 = L.Convolution2D(64, 128, 3, stride=1, pad=1) + self.conv2_2 = L.Convolution2D(128, 128, 3, stride=1, pad=1) + + self.in_channel = in_channel + + def __call__(self, xs, ilens): + """VGG2L forward propagation. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.Variable): Batch of length of each features. (B,) + + Returns: + chainer.Variable: Subsampled vector of xs. + chainer.Variable: Subsampled vector of ilens. + + """ + logging.info(self.__class__.__name__ + " input lengths: " + str(ilens)) + + # x: utt x frame x dim + xs = F.pad_sequence(xs) + + # x: utt x 1 (input channel num) x frame x dim + xs = F.swapaxes( + xs.reshape( + xs.shape[0], + xs.shape[1], + self.in_channel, + xs.shape[2] // self.in_channel, + ), + 1, + 2, + ) + + xs = F.relu(self.conv1_1(xs)) + xs = F.relu(self.conv1_2(xs)) + xs = F.max_pooling_2d(xs, 2, stride=2) + + xs = F.relu(self.conv2_1(xs)) + xs = F.relu(self.conv2_2(xs)) + xs = F.max_pooling_2d(xs, 2, stride=2) + + # change ilens accordingly + ilens = self.xp.array( + self.xp.ceil(self.xp.array(ilens, dtype=np.float32) / 2), dtype=np.int32 + ) + ilens = self.xp.array( + self.xp.ceil(self.xp.array(ilens, dtype=np.float32) / 2), dtype=np.int32 + ) + + # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim) + xs = F.swapaxes(xs, 1, 2) + xs = xs.reshape(xs.shape[0], xs.shape[1], xs.shape[2] * xs.shape[3]) + xs = [xs[i, : ilens[i], :] for i in range(len(ilens))] + + return xs, ilens + + +class Encoder(chainer.Chain): + """Encoder network class. + + Args: + etype (str): Type of encoder network. + idim (int): Number of dimensions of encoder network. + elayers (int): Number of layers of encoder network. + eunits (int): Number of lstm units of encoder network. + eprojs (int): Number of projection units of encoder network. + subsample (np.array): Subsampling number. e.g. 1_2_2_2_1 + dropout (float): Dropout rate. 
+ + """ + + def __init__( + self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1 + ): + super(Encoder, self).__init__() + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ["lstm", "gru", "blstm", "bgru"]: + logging.error("Error: need to specify an appropriate encoder architecture") + with self.init_scope(): + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc = chainer.Sequential( + VGG2L(in_channel), + RNNP( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + subsample, + dropout, + typ=typ, + ), + ) + logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder") + else: + self.enc = chainer.Sequential( + VGG2L(in_channel), + RNN( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + dropout, + typ=typ, + ), + ) + logging.info("Use CNN-VGG + " + typ.upper() + " for encoder") + self.conv_subsampling_factor = 4 + else: + if etype[-1] == "p": + self.enc = chainer.Sequential( + RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ) + ) + logging.info( + typ.upper() + " with every-layer projection for encoder" + ) + else: + self.enc = chainer.Sequential( + RNN(idim, elayers, eunits, eprojs, dropout, typ=typ) + ) + logging.info(typ.upper() + " without projection for encoder") + self.conv_subsampling_factor = 1 + + def __call__(self, xs, ilens): + """Encoder forward. + + Args: + xs (chainer.Variable): Batch of padded charactor ids. (B, Tmax) + ilens (chainer.variable): Batch of length of each features. (B,) + + Returns: + chainer.Variable: Output of the encoder. + chainer.Variable: (Subsampled) vector of ilens. + + """ + xs, ilens = self.enc(xs, ilens) + + return xs, ilens + + +def encoder_for(args, idim, subsample): + """Return the Encoder module. + + Args: + idim (int): Dimension of input array. + subsample (numpy.array): Subsample number. egs).1_2_2_2_1 + + Return + chainer.nn.Module: Encoder module. + + """ + return Encoder( + args.etype, + idim, + args.elayers, + args.eunits, + args.eprojs, + subsample, + args.dropout_rate, + ) diff --git a/espnet/nets/chainer_backend/rnn/training.py b/espnet/nets/chainer_backend/rnn/training.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc37d681a19622662e3b4b20ec9225098eb25f6 --- /dev/null +++ b/espnet/nets/chainer_backend/rnn/training.py @@ -0,0 +1,261 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +import collections +import logging +import math +import six + +# chainer related +from chainer import cuda +from chainer import training +from chainer import Variable + +from chainer.training.updaters.multiprocess_parallel_updater import gather_grads +from chainer.training.updaters.multiprocess_parallel_updater import gather_params +from chainer.training.updaters.multiprocess_parallel_updater import scatter_grads + +import numpy as np + + +# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py +def sum_sqnorm(arr): + """Calculate the norm of the array. + + Args: + arr (numpy.ndarray) + + Returns: + Float: Sum of the norm calculated from the given array. + + """ + sq_sum = collections.defaultdict(float) + for x in arr: + with cuda.get_device_from_array(x) as dev: + if x is not None: + x = x.ravel() + s = x.dot(x) + sq_sum[int(dev)] += s + return sum([float(i) for i in six.itervalues(sq_sum)]) + + +class CustomUpdater(training.StandardUpdater): + """Custom updater for chainer. 
+ + Args: + train_iter (iterator | dict[str, iterator]): Dataset iterator for the + training dataset. It can also be a dictionary that maps strings to + iterators. If this is just an iterator, then the iterator is + registered by the name ``'main'``. + optimizer (optimizer | dict[str, optimizer]): Optimizer to update + parameters. It can also be a dictionary that maps strings to + optimizers. If this is just an optimizer, then the optimizer is + registered by the name ``'main'``. + converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. + :func:`chainer.dataset.concat_examples` is used by default. + device (int or dict): The destination device info to send variables. In the + case of cpu or single gpu, `device=-1 or 0`, respectively. + In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`. + accum_grad (int):The number of gradient accumulation. if set to 2, the network + parameters will be updated once in twice, + i.e. actual batchsize will be doubled. + + """ + + def __init__(self, train_iter, optimizer, converter, device, accum_grad=1): + super(CustomUpdater, self).__init__( + train_iter, optimizer, converter=converter, device=device + ) + self.forward_count = 0 + self.accum_grad = accum_grad + self.start = True + # To solve #1091, it is required to set the variable inside this class. + self.device = device + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Main update routine for Custom Updater.""" + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + + # Get batch and convert into variables + batch = train_iter.next() + x = self.converter(batch, self.device) + if self.start: + optimizer.target.cleargrads() + self.start = False + + # Compute the loss at this time step and accumulate it + loss = optimizer.target(*x) / self.accum_grad + loss.backward() # Backprop + loss.unchain_backward() # Truncate the graph + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + # compute the gradient norm to check if it is normal or not + grad_norm = np.sqrt( + sum_sqnorm([p.grad for p in optimizer.target.params(False)]) + ) + logging.info("grad norm={}".format(grad_norm)) + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.update() + optimizer.target.cleargrads() # Clear the parameter gradients + + def update(self): + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater): + """Custom Parallel Updater for chainer. + + Defines the main update routine. + + Args: + train_iter (iterator | dict[str, iterator]): Dataset iterator for the + training dataset. It can also be a dictionary that maps strings to + iterators. If this is just an iterator, then the iterator is + registered by the name ``'main'``. + optimizer (optimizer | dict[str, optimizer]): Optimizer to update + parameters. It can also be a dictionary that maps strings to + optimizers. If this is just an optimizer, then the optimizer is + registered by the name ``'main'``. + converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. 
+ :func:`chainer.dataset.concat_examples` is used by default. + device (torch.device): Device to which the training data is sent. + Negative value + indicates the host memory (CPU). + accum_grad (int):The number of gradient accumulation. if set to 2, + the network parameters will be updated once in twice, + i.e. actual batchsize will be doubled. + + """ + + def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1): + super(CustomParallelUpdater, self).__init__( + train_iters, optimizer, converter=converter, devices=devices + ) + from cupy.cuda import nccl + + self.accum_grad = accum_grad + self.forward_count = 0 + self.nccl = nccl + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Main Update routine of the custom parallel updater.""" + self.setup_workers() + + self._send_message(("update", None)) + with cuda.Device(self._devices[0]): + # For reducing memory + + optimizer = self.get_optimizer("main") + batch = self.get_iterator("main").next() + x = self.converter(batch, self._devices[0]) + + loss = self._master(*x) / self.accum_grad + loss.backward() + loss.unchain_backward() + + # NCCL: reduce grads + null_stream = cuda.Stream.null + if self.comm is not None: + gg = gather_grads(self._master) + self.comm.reduce( + gg.data.ptr, + gg.data.ptr, + gg.size, + self.nccl.NCCL_FLOAT, + self.nccl.NCCL_SUM, + 0, + null_stream.ptr, + ) + scatter_grads(self._master, gg) + del gg + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + # check gradient value + grad_norm = np.sqrt( + sum_sqnorm([p.grad for p in optimizer.target.params(False)]) + ) + logging.info("grad norm={}".format(grad_norm)) + + # update + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.update() + self._master.cleargrads() + + if self.comm is not None: + gp = gather_params(self._master) + self.comm.bcast( + gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr + ) + + def update(self): + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class CustomConverter(object): + """Custom Converter. + + Args: + subsampling_factor (int): The subsampling factor. + + """ + + def __init__(self, subsampling_factor=1): + self.subsampling_factor = subsampling_factor + + def __call__(self, batch, device): + """Perform sabsampling. + + Args: + batch (list): Batch that will be sabsampled. + device (device): GPU device. + + Returns: + chainer.Variable: xp.array that sabsampled from batch. + xp.array: xp.array of the length of the mini-batches. + chainer.Variable: xp.array that sabsampled from batch. 
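+
+        Example (illustrative sketch; the shapes below are assumptions):
+
+            >>> import numpy as np
+            >>> converter = CustomConverter(subsampling_factor=2)
+            >>> xs = [np.random.randn(10, 83).astype(np.float32)]
+            >>> ys = [np.array([1, 2, 3], dtype=np.int32)]
+            >>> xs_v, ilens, ys_v = converter([(xs, ys)], device=-1)
+            >>> int(ilens[0])
+            5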
+ + """ + # set device + xp = cuda.cupy if device != -1 else np + + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0] + + # perform subsampling + if self.subsampling_factor > 1: + xs = [x[:: self.subsampling_factor, :] for x in xs] + + # get batch made of lengths of input sequences + ilens = [x.shape[0] for x in xs] + + # convert to Variable + xs = [Variable(xp.array(x, dtype=xp.float32)) for x in xs] + ilens = xp.array(ilens, dtype=xp.int32) + ys = [Variable(xp.array(y, dtype=xp.int32)) for y in ys] + + return xs, ilens, ys diff --git a/espnet/nets/chainer_backend/transformer/__init__.py b/espnet/nets/chainer_backend/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/chainer_backend/transformer/attention.py b/espnet/nets/chainer_backend/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..d26d82fb10feca49e41cea4cd3b2107f12e64cc9 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/attention.py @@ -0,0 +1,98 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Attention.""" + +import chainer + +import chainer.functions as F +import chainer.links as L + +import numpy as np + +MIN_VALUE = float(np.finfo(np.float32).min) + + +class MultiHeadAttention(chainer.Chain): + """Multi Head Attention Layer. + + Args: + n_units (int): Number of input units. + h (int): Number of attention heads. + dropout (float): Dropout rate. + initialW: Initializer to initialize the weight. + initial_bias: Initializer to initialize the bias. + + :param int h: the number of heads + :param int n_units: the number of features + :param float dropout_rate: dropout rate + + """ + + def __init__(self, n_units, h=8, dropout=0.1, initialW=None, initial_bias=None): + """Initialize MultiHeadAttention.""" + super(MultiHeadAttention, self).__init__() + assert n_units % h == 0 + stvd = 1.0 / np.sqrt(n_units) + with self.init_scope(): + self.linear_q = L.Linear( + n_units, + n_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.linear_k = L.Linear( + n_units, + n_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.linear_v = L.Linear( + n_units, + n_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.linear_out = L.Linear( + n_units, + n_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.d_k = n_units // h + self.h = h + self.dropout = dropout + self.attn = None + + def forward(self, e_var, s_var=None, mask=None, batch=1): + """Core function of the Multi-head attention layer. + + Args: + e_var (chainer.Variable): Variable of input array. + s_var (chainer.Variable): Variable of source array from encoder. + mask (chainer.Variable): Attention mask. + batch (int): Batch size. + + Returns: + chainer.Variable: Outout of multi-head attention layer. 
+ + """ + xp = self.xp + if s_var is None: + # batch, head, time1/2, d_k) + Q = self.linear_q(e_var).reshape(batch, -1, self.h, self.d_k) + K = self.linear_k(e_var).reshape(batch, -1, self.h, self.d_k) + V = self.linear_v(e_var).reshape(batch, -1, self.h, self.d_k) + else: + Q = self.linear_q(e_var).reshape(batch, -1, self.h, self.d_k) + K = self.linear_k(s_var).reshape(batch, -1, self.h, self.d_k) + V = self.linear_v(s_var).reshape(batch, -1, self.h, self.d_k) + scores = F.matmul(F.swapaxes(Q, 1, 2), K.transpose(0, 2, 3, 1)) / np.sqrt( + self.d_k + ) + if mask is not None: + mask = xp.stack([mask] * self.h, axis=1) + scores = F.where(mask, scores, xp.full(scores.shape, MIN_VALUE, "f")) + self.attn = F.softmax(scores, axis=-1) + p_attn = F.dropout(self.attn, self.dropout) + x = F.matmul(p_attn, F.swapaxes(V, 1, 2)) + x = F.swapaxes(x, 1, 2).reshape(-1, self.h * self.d_k) + return self.linear_out(x) diff --git a/espnet/nets/chainer_backend/transformer/ctc.py b/espnet/nets/chainer_backend/transformer/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..caac6545655adec203596b566917a90e5e27a333 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/ctc.py @@ -0,0 +1,169 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's CTC.""" +import logging + +import chainer +import chainer.functions as F +import chainer.links as L +import numpy as np + + +# TODO(nelson): Merge chainer_backend/transformer/ctc.py in chainer_backend/ctc.py +class CTC(chainer.Chain): + """Chainer implementation of ctc layer. + + Args: + odim (int): The output dimension. + eprojs (int | None): Dimension of input vectors from encoder. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, odim, eprojs, dropout_rate): + """Initialize CTC.""" + super(CTC, self).__init__() + self.dropout_rate = dropout_rate + self.loss = None + + with self.init_scope(): + self.ctc_lo = L.Linear(eprojs, odim) + + def __call__(self, hs, ys): + """CTC forward. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + ys (list of chainer.Variable | N-dimension array): + Input variable of decoder. + + Returns: + chainer.Variable: A variable holding a scalar value of the CTC loss. + + """ + self.loss = None + ilens = [x.shape[0] for x in hs] + olens = [x.shape[0] for x in ys] + + # zero padding for hs + y_hat = self.ctc_lo( + F.dropout(F.pad_sequence(hs), ratio=self.dropout_rate), n_batch_axes=2 + ) + y_hat = F.separate(y_hat, axis=1) # ilen list of batch x hdim + + # zero padding for ys + y_true = F.pad_sequence(ys, padding=-1) # batch x olen + + # get length info + input_length = chainer.Variable(self.xp.array(ilens, dtype=np.int32)) + label_length = chainer.Variable(self.xp.array(olens, dtype=np.int32)) + logging.info( + self.__class__.__name__ + " input lengths: " + str(input_length.data) + ) + logging.info( + self.__class__.__name__ + " output lengths: " + str(label_length.data) + ) + + # get ctc loss + self.loss = F.connectionist_temporal_classification( + y_hat, y_true, 0, input_length, label_length + ) + logging.info("ctc loss:" + str(self.loss.data)) + + return self.loss + + def log_softmax(self, hs): + """Log_softmax of frame activations. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + + Returns: + chainer.Variable: A n-dimension float array. 
+ + """ + y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2) + return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape) + + +class WarpCTC(chainer.Chain): + """Chainer implementation of warp-ctc layer. + + Args: + odim (int): The output dimension. + eproj (int | None): Dimension of input vector from encoder. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, odim, eprojs, dropout_rate): + """Initialize WarpCTC.""" + super(WarpCTC, self).__init__() + # The main difference between the ctc for transformer and + # the rnn is because the target (ys) is already a list of + # arrays located in the cpu, while in rnn routine the target is + # a list of variables located in cpu/gpu. If the target of rnn becomes + # a list of cpu arrays then this file would be no longer required. + from chainer_ctc.warpctc import ctc as warp_ctc + + self.ctc = warp_ctc + self.dropout_rate = dropout_rate + self.loss = None + + with self.init_scope(): + self.ctc_lo = L.Linear(eprojs, odim) + + def forward(self, hs, ys): + """Core function of the Warp-CTC layer. + + Args: + hs (iterable of chainer.Variable | N-dimention array): + Input variable from encoder. + ys (iterable of N-dimension array): Input variable of decoder. + + Returns: + chainer.Variable: A variable holding a scalar value of the CTC loss. + + """ + self.loss = None + ilens = [hs.shape[1]] * hs.shape[0] + olens = [x.shape[0] for x in ys] + + # zero padding for hs + # output batch x frames x hdim > frames x batch x hdim + y_hat = self.ctc_lo( + F.dropout(hs, ratio=self.dropout_rate), n_batch_axes=2 + ).transpose(1, 0, 2) + + # get length info + logging.info(self.__class__.__name__ + " input lengths: " + str(ilens)) + logging.info(self.__class__.__name__ + " output lengths: " + str(olens)) + + # get ctc loss + self.loss = self.ctc(y_hat, ilens, ys)[0] + logging.info("ctc loss:" + str(self.loss.data)) + return self.loss + + def log_softmax(self, hs): + """Log_softmax of frame activations. + + Args: + hs (list of chainer.Variable | N-dimension array): + Input variable from encoder. + + Returns: + chainer.Variable: A n-dimension float array. + + """ + y_hat = self.ctc_lo(F.pad_sequence(hs), n_batch_axes=2) + return F.log_softmax(y_hat.reshape(-1, y_hat.shape[-1])).reshape(y_hat.shape) + + def argmax(self, hs_pad): + """Argmax of frame activations. + + :param chainer variable hs_pad: 3d tensor (B, Tmax, eprojs) + :return: argmax applied 2d tensor (B, Tmax) + :rtype: chainer.Variable. + """ + return F.argmax(self.ctc_lo(F.pad_sequence(hs_pad), n_batch_axes=2), axis=-1) diff --git a/espnet/nets/chainer_backend/transformer/decoder.py b/espnet/nets/chainer_backend/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c216de8e51aaed13af41ae13b495e08a41061eba --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/decoder.py @@ -0,0 +1,115 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Decoder.""" + +import chainer + +import chainer.functions as F +import chainer.links as L + +from espnet.nets.chainer_backend.transformer.decoder_layer import DecoderLayer +from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding +from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm +from espnet.nets.chainer_backend.transformer.mask import make_history_mask + +import numpy as np + + +class Decoder(chainer.Chain): + """Decoder layer. + + Args: + odim (int): The output dimension. + n_layers (int): Number of ecoder layers. 
+ n_units (int): Number of attention units. + d_units (int): Dimension of input vector of decoder. + h (int): Number of attention heads. + dropout (float): Dropout rate. + initialW (Initializer): Initializer to initialize the weight. + initial_bias (Initializer): Initializer to initialize teh bias. + + """ + + def __init__(self, odim, args, initialW=None, initial_bias=None): + """Initialize Decoder.""" + super(Decoder, self).__init__() + self.sos = odim - 1 + self.eos = odim - 1 + initialW = chainer.initializers.Uniform if initialW is None else initialW + initial_bias = ( + chainer.initializers.Uniform if initial_bias is None else initial_bias + ) + with self.init_scope(): + self.output_norm = LayerNorm(args.adim) + self.pe = PositionalEncoding(args.adim, args.dropout_rate) + stvd = 1.0 / np.sqrt(args.adim) + self.output_layer = L.Linear( + args.adim, + odim, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.embed = L.EmbedID( + odim, + args.adim, + ignore_label=-1, + initialW=chainer.initializers.Normal(scale=1.0), + ) + for i in range(args.dlayers): + name = "decoders." + str(i) + layer = DecoderLayer( + args.adim, + d_units=args.dunits, + h=args.aheads, + dropout=args.dropout_rate, + initialW=initialW, + initial_bias=initial_bias, + ) + self.add_link(name, layer) + self.n_layers = args.dlayers + + def make_attention_mask(self, source_block, target_block): + """Prepare the attention mask. + + Args: + source_block (ndarray): Source block with dimensions: (B x S). + target_block (ndarray): Target block with dimensions: (B x T). + Returns: + ndarray: Mask with dimensions (B, S, T). + + """ + mask = (target_block[:, None, :] >= 0) * (source_block[:, :, None] >= 0) + # (batch, source_length, target_length) + return mask + + def forward(self, ys_pad, source, x_mask): + """Forward decoder. + + :param xp.array e: input token ids, int64 (batch, maxlen_out) + :param xp.array yy_mask: input token mask, uint8 (batch, maxlen_out) + :param xp.array source: encoded memory, float32 (batch, maxlen_in, feat) + :param xp.array xy_mask: encoded memory mask, uint8 (batch, maxlen_in) + :return e: decoded token score before softmax (batch, maxlen_out, token) + :rtype: chainer.Variable + """ + xp = self.xp + sos = np.array([self.sos], np.int32) + ys = [np.concatenate([sos, y], axis=0) for y in ys_pad] + e = F.pad_sequence(ys, padding=self.eos).data + e = xp.array(e) + # mask preparation + xy_mask = self.make_attention_mask(e, xp.array(x_mask)) + yy_mask = self.make_attention_mask(e, e) + yy_mask *= make_history_mask(xp, e) + + e = self.pe(self.embed(e)) + batch, length, dims = e.shape + e = e.reshape(-1, dims) + source = source.reshape(-1, dims) + for i in range(self.n_layers): + e = self["decoders." 
+ str(i)](e, source, xy_mask, yy_mask, batch) + return self.output_layer(self.output_norm(e)).reshape(batch, length, -1) + + def recognize(self, e, yy_mask, source): + """Process recognition function.""" + e = self.forward(e, source, yy_mask) + return F.log_softmax(e, axis=-1) diff --git a/espnet/nets/chainer_backend/transformer/decoder_layer.py b/espnet/nets/chainer_backend/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..933290049c2d3c97ac366792bfd629a970b4d398 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/decoder_layer.py @@ -0,0 +1,80 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Decoder Block.""" + +import chainer + +import chainer.functions as F + +from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention +from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm +from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) + + +class DecoderLayer(chainer.Chain): + """Single decoder layer module. + + Args: + n_units (int): Number of input/output dimension of a FeedForward layer. + d_units (int): Number of units of hidden layer in a FeedForward layer. + h (int): Number of attention heads. + dropout (float): Dropout rate + + """ + + def __init__( + self, n_units, d_units=0, h=8, dropout=0.1, initialW=None, initial_bias=None + ): + """Initialize DecoderLayer.""" + super(DecoderLayer, self).__init__() + with self.init_scope(): + self.self_attn = MultiHeadAttention( + n_units, + h, + dropout=dropout, + initialW=initialW, + initial_bias=initial_bias, + ) + self.src_attn = MultiHeadAttention( + n_units, + h, + dropout=dropout, + initialW=initialW, + initial_bias=initial_bias, + ) + self.feed_forward = PositionwiseFeedForward( + n_units, + d_units=d_units, + dropout=dropout, + initialW=initialW, + initial_bias=initial_bias, + ) + self.norm1 = LayerNorm(n_units) + self.norm2 = LayerNorm(n_units) + self.norm3 = LayerNorm(n_units) + self.dropout = dropout + + def forward(self, e, s, xy_mask, yy_mask, batch): + """Compute Encoder layer. + + Args: + e (chainer.Variable): Batch of padded features. (B, Lmax) + s (chainer.Variable): Batch of padded character. (B, Tmax) + + Returns: + chainer.Variable: Computed variable of decoder. + + """ + n_e = self.norm1(e) + n_e = self.self_attn(n_e, mask=yy_mask, batch=batch) + e = e + F.dropout(n_e, self.dropout) + + n_e = self.norm2(e) + n_e = self.src_attn(n_e, s_var=s, mask=xy_mask, batch=batch) + e = e + F.dropout(n_e, self.dropout) + + n_e = self.norm3(e) + n_e = self.feed_forward(n_e) + e = e + F.dropout(n_e, self.dropout) + return e diff --git a/espnet/nets/chainer_backend/transformer/embedding.py b/espnet/nets/chainer_backend/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..d838c085dad3426e3c9fa68ef12e4f44828a1a3d --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/embedding.py @@ -0,0 +1,37 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Positional Encoding.""" + +import chainer +import chainer.functions as F + +import numpy as np + + +class PositionalEncoding(chainer.Chain): + """Positional encoding module. 
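+
+    The encoding table follows the sinusoidal scheme of "Attention Is All You
+    Need" (shown here for reference):
+
+        PE(pos, 2i)   = sin(pos / 10000^(2i / n_units))
+        PE(pos, 2i+1) = cos(pos / 10000^(2i / n_units))
+
+    The embedded input is scaled by sqrt(n_units) before the table is added.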
+ + :param int n_units: embedding dim + :param float dropout: dropout rate + :param int length: maximum input length + + """ + + def __init__(self, n_units, dropout=0.1, length=5000): + """Initialize Positional Encoding.""" + # Implementation described in the paper + super(PositionalEncoding, self).__init__() + self.dropout = dropout + posi_block = np.arange(0, length, dtype=np.float32)[:, None] + unit_block = np.exp( + np.arange(0, n_units, 2, dtype=np.float32) * -(np.log(10000.0) / n_units) + ) + self.pe = np.zeros((length, n_units), dtype=np.float32) + self.pe[:, ::2] = np.sin(posi_block * unit_block) + self.pe[:, 1::2] = np.cos(posi_block * unit_block) + self.scale = np.sqrt(n_units) + + def forward(self, e): + """Forward Positional Encoding.""" + length = e.shape[1] + e = e * self.scale + self.xp.array(self.pe[:length]) + return F.dropout(e, self.dropout) diff --git a/espnet/nets/chainer_backend/transformer/encoder.py b/espnet/nets/chainer_backend/transformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..af4e98d0504aa3f0e970936e20a2cb22c6f93e09 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/encoder.py @@ -0,0 +1,134 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Encoder.""" + +import chainer + +from chainer import links as L + +from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding +from espnet.nets.chainer_backend.transformer.encoder_layer import EncoderLayer +from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm +from espnet.nets.chainer_backend.transformer.mask import make_history_mask +from espnet.nets.chainer_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.chainer_backend.transformer.subsampling import LinearSampling + +import logging +import numpy as np + + +class Encoder(chainer.Chain): + """Encoder. + + Args: + input_type(str): + Sampling type. `input_type` must be `conv2d` or 'linear' currently. + idim (int): Dimension of inputs. + n_layers (int): Number of encoder layers. + n_units (int): Number of input/output dimension of a FeedForward layer. + d_units (int): Number of units of hidden layer in a FeedForward layer. + h (int): Number of attention heads. + dropout (float): Dropout rate + + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + pos_enc_class=PositionalEncoding, + initialW=None, + initial_bias=None, + ): + """Initialize Encoder. + + Args: + idim (int): Input dimension. + args (Namespace): Training config. + initialW (int, optional): Initializer to initialize the weight. + initial_bias (bool, optional): Initializer to initialize the bias. 
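+
+        Example (illustrative sketch; the shapes below are assumptions):
+
+            >>> import numpy as np
+            >>> enc = Encoder(idim=83)  # conv2d front-end, 256-dim attention
+            >>> xs = np.random.randn(2, 100, 83).astype(np.float32)
+            >>> ilens = np.array([100, 80], dtype=np.int32)
+            >>> e, x_mask, olens = enc(xs, ilens)  # e: (2, 25, 256) after 4x subsampling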
+ + """ + super(Encoder, self).__init__() + initialW = chainer.initializers.Uniform if initialW is None else initialW + initial_bias = ( + chainer.initializers.Uniform if initial_bias is None else initial_bias + ) + self.do_history_mask = False + with self.init_scope(): + self.conv_subsampling_factor = 1 + channels = 64 # Based in paper + if input_layer == "conv2d": + idim = int(np.ceil(np.ceil(idim / 2) / 2)) * channels + self.input_layer = Conv2dSubsampling( + channels, + idim, + attention_dim, + dropout=dropout_rate, + initialW=initialW, + initial_bias=initial_bias, + ) + self.conv_subsampling_factor = 4 + elif input_layer == "linear": + self.input_layer = LinearSampling( + idim, attention_dim, initialW=initialW, initial_bias=initial_bias + ) + elif input_layer == "embed": + self.input_layer = chainer.Sequential( + L.EmbedID(idim, attention_dim, ignore_label=-1), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.do_history_mask = True + else: + raise ValueError("unknown input_layer: " + input_layer) + self.norm = LayerNorm(attention_dim) + for i in range(num_blocks): + name = "encoders." + str(i) + layer = EncoderLayer( + attention_dim, + d_units=linear_units, + h=attention_heads, + dropout=attention_dropout_rate, + initialW=initialW, + initial_bias=initial_bias, + ) + self.add_link(name, layer) + self.n_layers = num_blocks + + def forward(self, e, ilens): + """Compute Encoder layer. + + Args: + e (chainer.Variable): Batch of padded charactor. (B, Tmax) + ilens (chainer.Variable): Batch of length of each input batch. (B,) + + Returns: + chainer.Variable: Computed variable of encoder. + numpy.array: Mask. + chainer.Variable: Batch of lengths of each encoder outputs. + + """ + if isinstance(self.input_layer, Conv2dSubsampling): + e, ilens = self.input_layer(e, ilens) + else: + e = self.input_layer(e) + batch, length, dims = e.shape + x_mask = np.ones([batch, length]) + for j in range(batch): + x_mask[j, ilens[j] :] = -1 + xx_mask = (x_mask[:, None, :] >= 0) * (x_mask[:, :, None] >= 0) + xx_mask = self.xp.array(xx_mask) + if self.do_history_mask: + history_mask = make_history_mask(self.xp, x_mask) + xx_mask *= history_mask + logging.debug("encoders size: " + str(e.shape)) + e = e.reshape(-1, dims) + for i in range(self.n_layers): + e = self["encoders." + str(i)](e, xx_mask, batch) + return self.norm(e).reshape(batch, length, -1), x_mask, ilens diff --git a/espnet/nets/chainer_backend/transformer/encoder_layer.py b/espnet/nets/chainer_backend/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..b742ef34ec36e1097a85e51313f723be3bba25ef --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/encoder_layer.py @@ -0,0 +1,60 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Encoder Block.""" + +import chainer + +import chainer.functions as F + +from espnet.nets.chainer_backend.transformer.attention import MultiHeadAttention +from espnet.nets.chainer_backend.transformer.layer_norm import LayerNorm +from espnet.nets.chainer_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) + + +class EncoderLayer(chainer.Chain): + """Single encoder layer module. + + Args: + n_units (int): Number of input/output dimension of a FeedForward layer. + d_units (int): Number of units of hidden layer in a FeedForward layer. + h (int): Number of attention heads. 
+ dropout (float): Dropout rate + + """ + + def __init__( + self, n_units, d_units=0, h=8, dropout=0.1, initialW=None, initial_bias=None + ): + """Initialize EncoderLayer.""" + super(EncoderLayer, self).__init__() + with self.init_scope(): + self.self_attn = MultiHeadAttention( + n_units, + h, + dropout=dropout, + initialW=initialW, + initial_bias=initial_bias, + ) + self.feed_forward = PositionwiseFeedForward( + n_units, + d_units=d_units, + dropout=dropout, + initialW=initialW, + initial_bias=initial_bias, + ) + self.norm1 = LayerNorm(n_units) + self.norm2 = LayerNorm(n_units) + self.dropout = dropout + self.n_units = n_units + + def forward(self, e, xx_mask, batch): + """Forward Positional Encoding.""" + n_e = self.norm1(e) + n_e = self.self_attn(n_e, mask=xx_mask, batch=batch) + e = e + F.dropout(n_e, self.dropout) + + n_e = self.norm2(e) + n_e = self.feed_forward(n_e) + e = e + F.dropout(n_e, self.dropout) + return e diff --git a/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py b/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..5aebc58a625d9b7b0652fcb603a28870bdac7470 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/label_smoothing_loss.py @@ -0,0 +1,71 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Label Smootion loss.""" + +import logging + +import chainer + +import chainer.functions as F + + +class LabelSmoothingLoss(chainer.Chain): + """Label Smoothing Loss. + + Args: + smoothing (float): smoothing rate (0.0 means the conventional CE). + n_target_vocab (int): number of classes. + normalize_length (bool): normalize loss by sequence length if True. + + """ + + def __init__(self, smoothing, n_target_vocab, normalize_length=False, ignore_id=-1): + """Initialize Loss.""" + super(LabelSmoothingLoss, self).__init__() + self.use_label_smoothing = False + if smoothing > 0.0: + logging.info("Use label smoothing") + self.smoothing = smoothing + self.confidence = 1.0 - smoothing + self.use_label_smoothing = True + self.n_target_vocab = n_target_vocab + self.normalize_length = normalize_length + self.ignore_id = ignore_id + self.acc = None + + def forward(self, ys_block, ys_pad): + """Forward Loss. + + Args: + ys_block (chainer.Variable): Predicted labels. + ys_pad (chainer.Variable): Target (true) labels. + + Returns: + float: Training loss. 
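+
+        Note (illustrative summary of the computation below): with smoothing
+        rate e, the returned loss is approximately
+
+            (1 - e) * NLL(prediction, target)
+                + e * mean over the vocabulary of (-log p(class)),
+
+        normalized by the number of non-ignored tokens when
+        ``normalize_length`` is True and by the batch size otherwise.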
+ + """ + # Output (all together at once for efficiency) + batch, length, dims = ys_block.shape + concat_logit_block = ys_block.reshape(-1, dims) + + # Target reshape + concat_t_block = ys_pad.reshape((batch * length)) + ignore_mask = concat_t_block >= 0 + n_token = ignore_mask.sum() + normalizer = n_token if self.normalize_length else batch + + if not self.use_label_smoothing: + loss = F.softmax_cross_entropy(concat_logit_block, concat_t_block) + loss = loss * n_token / normalizer + else: + log_prob = F.log_softmax(concat_logit_block) + broad_ignore_mask = self.xp.broadcast_to( + ignore_mask[:, None], concat_logit_block.shape + ) + pre_loss = ( + ignore_mask * log_prob[self.xp.arange(batch * length), concat_t_block] + ) + loss = -F.sum(pre_loss) / normalizer + label_smoothing = broad_ignore_mask * -1.0 / self.n_target_vocab * log_prob + label_smoothing = F.sum(label_smoothing) / normalizer + loss = self.confidence * loss + self.smoothing * label_smoothing + return loss diff --git a/espnet/nets/chainer_backend/transformer/layer_norm.py b/espnet/nets/chainer_backend/transformer/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..338aab75a3ae22b4bd38e893a06f75e7eef6bfc0 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/layer_norm.py @@ -0,0 +1,16 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Label Smootion loss.""" + +import chainer.links as L + + +class LayerNorm(L.LayerNormalization): + """Redirect to L.LayerNormalization.""" + + def __init__(self, dims, eps=1e-12): + """Initialize LayerNorm.""" + super(LayerNorm, self).__init__(size=dims, eps=eps) + + def __call__(self, e): + """Forward LayerNorm.""" + return super(LayerNorm, self).__call__(e) diff --git a/espnet/nets/chainer_backend/transformer/mask.py b/espnet/nets/chainer_backend/transformer/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..e83a1b1f5cf0b5d28992c1fe79f96304eef98dd9 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/mask.py @@ -0,0 +1,19 @@ +"""Create mask for subsequent steps.""" + + +def make_history_mask(xp, block): + """Prepare the history mask. + + Args: + block (ndarray): Block with dimensions: (B x S). + Returns: + ndarray, np.ndarray: History mask with dimensions (B, S, S). + + """ + batch, length = block.shape + arange = xp.arange(length) + history_mask = (arange[None] <= arange[:, None])[ + None, + ] + history_mask = xp.broadcast_to(history_mask, (batch, length, length)) + return history_mask diff --git a/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py b/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..f6d5a7c1a46e906cc8a3c47a013f2630ca2be2bd --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/positionwise_feed_forward.py @@ -0,0 +1,66 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Positionwise Feedforward.""" + +import chainer + +import chainer.functions as F +import chainer.links as L + +import numpy as np + + +class PositionwiseFeedForward(chainer.Chain): + """Positionwise feed forward. + + Args: + :param int idim: input dimenstion + :param int hidden_units: number of hidden units + :param float dropout_rate: dropout rate + + """ + + def __init__( + self, n_units, d_units=0, dropout=0.1, initialW=None, initial_bias=None + ): + """Initialize PositionwiseFeedForward. + + Args: + n_units (int): Input dimension. + d_units (int, optional): Output dimension of hidden layer. 
+ dropout (float, optional): Dropout ratio. + initialW (int, optional): Initializer to initialize the weight. + initial_bias (bool, optional): Initializer to initialize the bias. + + """ + super(PositionwiseFeedForward, self).__init__() + n_inner_units = d_units if d_units > 0 else n_units * 4 + with self.init_scope(): + stvd = 1.0 / np.sqrt(n_units) + self.w_1 = L.Linear( + n_units, + n_inner_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + stvd = 1.0 / np.sqrt(n_inner_units) + self.w_2 = L.Linear( + n_inner_units, + n_units, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.act = F.relu + self.dropout = dropout + + def __call__(self, e): + """Initialize PositionwiseFeedForward. + + Args: + e (chainer.Variable): Input variable. + + Return: + chainer.Variable: Output variable. + + """ + e = F.dropout(self.act(self.w_1(e)), self.dropout) + return self.w_2(e) diff --git a/espnet/nets/chainer_backend/transformer/subsampling.py b/espnet/nets/chainer_backend/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba486c871fdab8076b375274c3f10baf8716d81 --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/subsampling.py @@ -0,0 +1,116 @@ +# encoding: utf-8 +"""Class Declaration of Transformer's Input layers.""" + +import chainer + +import chainer.functions as F +import chainer.links as L + +from espnet.nets.chainer_backend.transformer.embedding import PositionalEncoding + +import logging +import numpy as np + + +class Conv2dSubsampling(chainer.Chain): + """Convolutional 2D subsampling (to 1/4 length). + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__( + self, channels, idim, dims, dropout=0.1, initialW=None, initial_bias=None + ): + """Initialize Conv2dSubsampling.""" + super(Conv2dSubsampling, self).__init__() + self.dropout = dropout + with self.init_scope(): + # Standard deviation for Conv2D with 1 channel and kernel 3 x 3. + n = 1 * 3 * 3 + stvd = 1.0 / np.sqrt(n) + self.conv1 = L.Convolution2D( + 1, + channels, + 3, + stride=2, + pad=1, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + n = channels * 3 * 3 + stvd = 1.0 / np.sqrt(n) + self.conv2 = L.Convolution2D( + channels, + channels, + 3, + stride=2, + pad=1, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + stvd = 1.0 / np.sqrt(dims) + self.out = L.Linear( + idim, + dims, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.pe = PositionalEncoding(dims, dropout) + + def forward(self, xs, ilens): + """Subsample x. + + :param chainer.Variable x: input tensor + :return: subsampled x and mask + + """ + xs = self.xp.array(xs[:, None]) + xs = F.relu(self.conv1(xs)) + xs = F.relu(self.conv2(xs)) + batch, _, length, _ = xs.shape + xs = self.out(F.swapaxes(xs, 1, 2).reshape(batch * length, -1)) + xs = self.pe(xs.reshape(batch, length, -1)) + # change ilens accordingly + ilens = np.ceil(np.array(ilens, dtype=np.float32) / 2).astype(np.int) + ilens = np.ceil(np.array(ilens, dtype=np.float32) / 2).astype(np.int) + return xs, ilens + + +class LinearSampling(chainer.Chain): + """Linear 1D subsampling. 
+ + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__(self, idim, dims, dropout=0.1, initialW=None, initial_bias=None): + """Initialize LinearSampling.""" + super(LinearSampling, self).__init__() + stvd = 1.0 / np.sqrt(dims) + self.dropout = dropout + with self.init_scope(): + self.linear = L.Linear( + idim, + dims, + initialW=initialW(scale=stvd), + initial_bias=initial_bias(scale=stvd), + ) + self.pe = PositionalEncoding(dims, dropout) + + def forward(self, xs, ilens): + """Subsample x. + + :param chainer.Variable x: input tensor + :return: subsampled x and mask + + """ + logging.info(xs.shape) + xs = self.linear(xs, n_batch_axes=2) + logging.info(xs.shape) + xs = self.pe(xs) + return xs, ilens diff --git a/espnet/nets/chainer_backend/transformer/training.py b/espnet/nets/chainer_backend/transformer/training.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a98651f36e099836a40af6086c6ebb6988e22a --- /dev/null +++ b/espnet/nets/chainer_backend/transformer/training.py @@ -0,0 +1,320 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +"""Class Declaration of Transformer's Training Subprocess.""" +import collections +import logging +import math +import six + +from chainer import cuda +from chainer import functions as F +from chainer import training +from chainer.training import extension +from chainer.training.updaters.multiprocess_parallel_updater import gather_grads +from chainer.training.updaters.multiprocess_parallel_updater import gather_params +from chainer.training.updaters.multiprocess_parallel_updater import scatter_grads +import numpy as np + + +# copied from https://github.com/chainer/chainer/blob/master/chainer/optimizer.py +def sum_sqnorm(arr): + """Calculate the norm of the array. + + Args: + arr (numpy.ndarray) + + Returns: + Float: Sum of the norm calculated from the given array. + + """ + sq_sum = collections.defaultdict(float) + for x in arr: + with cuda.get_device_from_array(x) as dev: + if x is not None: + x = x.ravel() + s = x.dot(x) + sq_sum[int(dev)] += s + return sum([float(i) for i in six.itervalues(sq_sum)]) + + +class CustomUpdater(training.StandardUpdater): + """Custom updater for chainer. + + Args: + train_iter (iterator | dict[str, iterator]): Dataset iterator for the + training dataset. It can also be a dictionary that maps strings to + iterators. If this is just an iterator, then the iterator is + registered by the name ``'main'``. + optimizer (optimizer | dict[str, optimizer]): Optimizer to update + parameters. It can also be a dictionary that maps strings to + optimizers. If this is just an optimizer, then the optimizer is + registered by the name ``'main'``. + converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. + :func:`chainer.dataset.concat_examples` is used by default. + device (int or dict): The destination device info to send variables. In the + case of cpu or single gpu, `device=-1 or 0`, respectively. + In the case of multi-gpu, `device={"main":0, "sub_1": 1, ...}`. + accum_grad (int):The number of gradient accumulation. if set to 2, the network + parameters will be updated once in twice, + i.e. actual batchsize will be doubled. 
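+
+    Example (illustrative sketch; ``train_iter``, ``optimizer`` and
+    ``converter`` are assumed to exist):
+
+        updater = CustomUpdater(train_iter, optimizer, converter,
+                                device=-1, accum_grad=2)
+        # gradients of two consecutive mini-batches are summed before each
+        # optimizer step, so the effective batch size is doubled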
+ + """ + + def __init__(self, train_iter, optimizer, converter, device, accum_grad=1): + """Initialize Custom Updater.""" + super(CustomUpdater, self).__init__( + train_iter, optimizer, converter=converter, device=device + ) + self.accum_grad = accum_grad + self.forward_count = 0 + self.start = True + self.device = device + logging.debug("using custom converter for transformer") + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Process main update routine for Custom Updater.""" + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + + # Get batch and convert into variables + batch = train_iter.next() + x = self.converter(batch, self.device) + if self.start: + optimizer.target.cleargrads() + self.start = False + + # Compute the loss at this time step and accumulate it + loss = optimizer.target(*x) / self.accum_grad + loss.backward() # Backprop + + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + # compute the gradient norm to check if it is normal or not + grad_norm = np.sqrt( + sum_sqnorm([p.grad for p in optimizer.target.params(False)]) + ) + logging.info("grad norm={}".format(grad_norm)) + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.update() + optimizer.target.cleargrads() # Clear the parameter gradients + + def update(self): + """Update step for Custom Updater.""" + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class CustomParallelUpdater(training.updaters.MultiprocessParallelUpdater): + """Custom Parallel Updater for chainer. + + Defines the main update routine. + + Args: + train_iter (iterator | dict[str, iterator]): Dataset iterator for the + training dataset. It can also be a dictionary that maps strings to + iterators. If this is just an iterator, then the iterator is + registered by the name ``'main'``. + optimizer (optimizer | dict[str, optimizer]): Optimizer to update + parameters. It can also be a dictionary that maps strings to + optimizers. If this is just an optimizer, then the optimizer is + registered by the name ``'main'``. + converter (espnet.asr.chainer_backend.asr.CustomConverter): Converter + function to build input arrays. Each batch extracted by the main + iterator and the ``device`` option are passed to this function. + :func:`chainer.dataset.concat_examples` is used by default. + device (torch.device): Device to which the training data is sent. Negative value + indicates the host memory (CPU). + accum_grad (int):The number of gradient accumulation. if set to 2, the network + parameters will be updated once in twice, + i.e. actual batchsize will be doubled. + + """ + + def __init__(self, train_iters, optimizer, converter, devices, accum_grad=1): + """Initialize custom parallel updater.""" + from cupy.cuda import nccl + + super(CustomParallelUpdater, self).__init__( + train_iters, optimizer, converter=converter, devices=devices + ) + self.accum_grad = accum_grad + self.forward_count = 0 + self.nccl = nccl + logging.debug("using custom parallel updater for transformer") + + # The core part of the update routine can be customized by overriding. 
+ def update_core(self): + """Process main update routine for Custom Parallel Updater.""" + self.setup_workers() + + self._send_message(("update", None)) + with cuda.Device(self._devices[0]): + # For reducing memory + optimizer = self.get_optimizer("main") + batch = self.get_iterator("main").next() + x = self.converter(batch, self._devices[0]) + + loss = self._master(*x) / self.accum_grad + loss.backward() + + # NCCL: reduce grads + null_stream = cuda.Stream.null + if self.comm is not None: + gg = gather_grads(self._master) + self.comm.reduce( + gg.data.ptr, + gg.data.ptr, + gg.size, + self.nccl.NCCL_FLOAT, + self.nccl.NCCL_SUM, + 0, + null_stream.ptr, + ) + scatter_grads(self._master, gg) + del gg + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + # check gradient value + grad_norm = np.sqrt( + sum_sqnorm([p.grad for p in optimizer.target.params(False)]) + ) + logging.info("grad norm={}".format(grad_norm)) + + # update + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.update() + self._master.cleargrads() + + if self.comm is not None: + gp = gather_params(self._master) + self.comm.bcast( + gp.data.ptr, gp.size, self.nccl.NCCL_FLOAT, 0, null_stream.ptr + ) + + def update(self): + """Update step for Custom Parallel Updater.""" + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class VaswaniRule(extension.Extension): + """Trainer extension to shift an optimizer attribute magically by Vaswani. + + Args: + attr (str): Name of the attribute to shift. + rate (float): Rate of the exponential shift. This value is multiplied + to the attribute at each call. + init (float): Initial value of the attribute. If it is ``None``, the + extension extracts the attribute at the first call and uses it as + the initial value. + target (float): Target value of the attribute. If the attribute reaches + this value, the shift stops. + optimizer (~chainer.Optimizer): Target optimizer to adjust the + attribute. If it is ``None``, the main optimizer of the updater is + used. 
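+
+    The scheduled value follows the Transformer ("Noam") learning-rate rule
+    (illustrative formula):
+
+        value(t) = scale * d ** -0.5 * min(t ** -0.5, t * warmup_steps ** -1.5)
+
+    Example (illustrative; the attribute name is an assumption):
+
+        >>> rule = VaswaniRule("alpha", d=256, warmup_steps=4000)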
+ + """ + + def __init__( + self, + attr, + d, + warmup_steps=4000, + init=None, + target=None, + optimizer=None, + scale=1.0, + ): + """Initialize Vaswani rule extension.""" + self._attr = attr + self._d_inv05 = d ** (-0.5) * scale + self._warmup_steps_inv15 = warmup_steps ** (-1.5) + self._init = init + self._target = target + self._optimizer = optimizer + self._t = 0 + self._last_value = None + + def initialize(self, trainer): + """Initialize Optimizer values.""" + optimizer = self._get_optimizer(trainer) + # ensure that _init is set + if self._init is None: + self._init = self._d_inv05 * (1.0 * self._warmup_steps_inv15) + if self._last_value is not None: # resuming from a snapshot + self._update_value(optimizer, self._last_value) + else: + self._update_value(optimizer, self._init) + + def __call__(self, trainer): + """Forward extension.""" + self._t += 1 + optimizer = self._get_optimizer(trainer) + value = self._d_inv05 * min( + self._t ** (-0.5), self._t * self._warmup_steps_inv15 + ) + self._update_value(optimizer, value) + + def serialize(self, serializer): + """Serialize extension.""" + self._t = serializer("_t", self._t) + self._last_value = serializer("_last_value", self._last_value) + + def _get_optimizer(self, trainer): + """Obtain optimizer from trainer.""" + return self._optimizer or trainer.updater.get_optimizer("main") + + def _update_value(self, optimizer, value): + """Update requested variable values.""" + setattr(optimizer, self._attr, value) + self._last_value = value + + +class CustomConverter(object): + """Custom Converter. + + Args: + subsampling_factor (int): The subsampling factor. + + """ + + def __init__(self): + """Initialize subsampling.""" + pass + + def __call__(self, batch, device): + """Perform subsampling. + + Args: + batch (list): Batch that will be sabsampled. + device (chainer.backend.Device): CPU or GPU device. + + Returns: + chainer.Variable: xp.array that are padded and subsampled from batch. + xp.array: xp.array of the length of the mini-batches. + chainer.Variable: xp.array that are padded and subsampled from batch. + + """ + # For transformer, data is processed in CPU. + # batch should be located in list + assert len(batch) == 1 + xs, ys = batch[0] + xs = F.pad_sequence(xs, padding=-1).data + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs], dtype=np.int32) + return xs, ilens, ys diff --git a/espnet/nets/ctc_prefix_score.py b/espnet/nets/ctc_prefix_score.py new file mode 100644 index 0000000000000000000000000000000000000000..ede03285164afa7f40b7de35517c051006ddc49a --- /dev/null +++ b/espnet/nets/ctc_prefix_score.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python3 + +# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import torch + +import numpy as np +import six + + +class CTCPrefixScoreTH(object): + """Batch processing of CTCPrefixScore + + which is based on Algorithm 2 in WATANABE et al. + "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION," + but extended to efficiently compute the label probablities for multiple + hypotheses simultaneously + See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based + Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019. 
+ """ + + def __init__(self, x, xlens, blank, eos, margin=0): + """Construct CTC prefix scorer + + :param torch.Tensor x: input label posterior sequences (B, T, O) + :param torch.Tensor xlens: input lengths (B,) + :param int blank: blank label id + :param int eos: end-of-sequence id + :param int margin: margin parameter for windowing (0 means no windowing) + """ + # In the comment lines, + # we assume T: input_length, B: batch size, W: beam width, O: output dim. + self.logzero = -10000000000.0 + self.blank = blank + self.eos = eos + self.batch = x.size(0) + self.input_length = x.size(1) + self.odim = x.size(2) + self.dtype = x.dtype + self.device = ( + torch.device("cuda:%d" % x.get_device()) + if x.is_cuda + else torch.device("cpu") + ) + # Pad the rest of posteriors in the batch + # TODO(takaaki-hori): need a better way without for-loops + for i, l in enumerate(xlens): + if l < self.input_length: + x[i, l:, :] = self.logzero + x[i, l:, blank] = 0 + # Reshape input x + xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O) + xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) + self.x = torch.stack([xn, xb]) # (2, T, B, O) + self.end_frames = torch.as_tensor(xlens) - 1 + + # Setup CTC windowing + self.margin = margin + if margin > 0: + self.frame_ids = torch.arange( + self.input_length, dtype=self.dtype, device=self.device + ) + # Base indices for index conversion + self.idx_bh = None + self.idx_b = torch.arange(self.batch, device=self.device) + self.idx_bo = (self.idx_b * self.odim).unsqueeze(1) + + def __call__(self, y, state, scoring_ids=None, att_w=None): + """Compute CTC prefix scores for next labels + + :param list y: prefix label sequences + :param tuple state: previous CTC state + :param torch.Tensor pre_scores: scores for pre-selection of hypotheses (BW, O) + :param torch.Tensor att_w: attention weights to decide CTC window + :return new_state, ctc_local_scores (BW, O) + """ + output_length = len(y[0]) - 1 # ignore sos + last_ids = [yi[-1] for yi in y] # last output label ids + n_bh = len(last_ids) # batch * hyps + n_hyps = n_bh // self.batch # assuming each utterance has the same # of hyps + self.scoring_num = scoring_ids.size(-1) if scoring_ids is not None else 0 + # prepare state info + if state is None: + r_prev = torch.full( + (self.input_length, 2, self.batch, n_hyps), + self.logzero, + dtype=self.dtype, + device=self.device, + ) + r_prev[:, 1] = torch.cumsum(self.x[0, :, :, self.blank], 0).unsqueeze(2) + r_prev = r_prev.view(-1, 2, n_bh) + s_prev = 0.0 + f_min_prev = 0 + f_max_prev = 1 + else: + r_prev, s_prev, f_min_prev, f_max_prev = state + + # select input dimensions for scoring + if self.scoring_num > 0: + scoring_idmap = torch.full( + (n_bh, self.odim), -1, dtype=torch.long, device=self.device + ) + snum = self.scoring_num + if self.idx_bh is None or n_bh > len(self.idx_bh): + self.idx_bh = torch.arange(n_bh, device=self.device).view(-1, 1) + scoring_idmap[self.idx_bh[:n_bh], scoring_ids] = torch.arange( + snum, device=self.device + ) + scoring_idx = ( + scoring_ids + self.idx_bo.repeat(1, n_hyps).view(-1, 1) + ).view(-1) + x_ = torch.index_select( + self.x.view(2, -1, self.batch * self.odim), 2, scoring_idx + ).view(2, -1, n_bh, snum) + else: + scoring_ids = None + scoring_idmap = None + snum = self.odim + x_ = self.x.unsqueeze(3).repeat(1, 1, 1, n_hyps, 1).view(2, -1, n_bh, snum) + + # new CTC forward probs are prepared as a (T x 2 x BW x S) tensor + # that corresponds to r_t^n(h) and r_t^b(h) in a batch. 
+ r = torch.full( + (self.input_length, 2, n_bh, snum), + self.logzero, + dtype=self.dtype, + device=self.device, + ) + if output_length == 0: + r[0, 0] = x_[0, 0] + + r_sum = torch.logsumexp(r_prev, 1) + log_phi = r_sum.unsqueeze(2).repeat(1, 1, snum) + if scoring_ids is not None: + for idx in range(n_bh): + pos = scoring_idmap[idx, last_ids[idx]] + if pos >= 0: + log_phi[:, idx, pos] = r_prev[:, 1, idx] + else: + for idx in range(n_bh): + log_phi[:, idx, last_ids[idx]] = r_prev[:, 1, idx] + + # decide start and end frames based on attention weights + if att_w is not None and self.margin > 0: + f_arg = torch.matmul(att_w, self.frame_ids) + f_min = max(int(f_arg.min().cpu()), f_min_prev) + f_max = max(int(f_arg.max().cpu()), f_max_prev) + start = min(f_max_prev, max(f_min - self.margin, output_length, 1)) + end = min(f_max + self.margin, self.input_length) + else: + f_min = f_max = 0 + start = max(output_length, 1) + end = self.input_length + + # compute forward probabilities log(r_t^n(h)) and log(r_t^b(h)) + for t in range(start, end): + rp = r[t - 1] + rr = torch.stack([rp[0], log_phi[t - 1], rp[0], rp[1]]).view( + 2, 2, n_bh, snum + ) + r[t] = torch.logsumexp(rr, 1) + x_[:, t] + + # compute log prefix probabilites log(psi) + log_phi_x = torch.cat((log_phi[0].unsqueeze(0), log_phi[:-1]), dim=0) + x_[0] + if scoring_ids is not None: + log_psi = torch.full( + (n_bh, self.odim), self.logzero, dtype=self.dtype, device=self.device + ) + log_psi_ = torch.logsumexp( + torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), + dim=0, + ) + for si in range(n_bh): + log_psi[si, scoring_ids[si]] = log_psi_[si] + else: + log_psi = torch.logsumexp( + torch.cat((log_phi_x[start:end], r[start - 1, 0].unsqueeze(0)), dim=0), + dim=0, + ) + + for si in range(n_bh): + log_psi[si, self.eos] = r_sum[self.end_frames[si // n_hyps], si] + + # exclude blank probs + log_psi[:, self.blank] = self.logzero + + return (log_psi - s_prev), (r, log_psi, f_min, f_max, scoring_idmap) + + def index_select_state(self, state, best_ids): + """Select CTC states according to best ids + + :param state : CTC state + :param best_ids : index numbers selected by beam pruning (B, W) + :return selected_state + """ + r, s, f_min, f_max, scoring_idmap = state + # convert ids to BHO space + n_bh = len(s) + n_hyps = n_bh // self.batch + vidx = (best_ids + (self.idx_b * (n_hyps * self.odim)).view(-1, 1)).view(-1) + # select hypothesis scores + s_new = torch.index_select(s.view(-1), 0, vidx) + s_new = s_new.view(-1, 1).repeat(1, self.odim).view(n_bh, self.odim) + # convert ids to BHS space (S: scoring_num) + if scoring_idmap is not None: + snum = self.scoring_num + hyp_idx = (best_ids // self.odim + (self.idx_b * n_hyps).view(-1, 1)).view( + -1 + ) + label_ids = torch.fmod(best_ids, self.odim).view(-1) + score_idx = scoring_idmap[hyp_idx, label_ids] + score_idx[score_idx == -1] = 0 + vidx = score_idx + hyp_idx * snum + else: + snum = self.odim + # select forward probabilities + r_new = torch.index_select(r.view(-1, 2, n_bh * snum), 2, vidx).view( + -1, 2, n_bh + ) + return r_new, s_new, f_min, f_max + + def extend_prob(self, x): + """Extend CTC prob. 
+ + :param torch.Tensor x: input label posterior sequences (B, T, O) + """ + + if self.x.shape[1] < x.shape[1]: # self.x (2,T,B,O); x (B,T,O) + # Pad the rest of posteriors in the batch + # TODO(takaaki-hori): need a better way without for-loops + xlens = [x.size(1)] + for i, l in enumerate(xlens): + if l < self.input_length: + x[i, l:, :] = self.logzero + x[i, l:, self.blank] = 0 + tmp_x = self.x + xn = x.transpose(0, 1) # (B, T, O) -> (T, B, O) + xb = xn[:, :, self.blank].unsqueeze(2).expand(-1, -1, self.odim) + self.x = torch.stack([xn, xb]) # (2, T, B, O) + self.x[:, : tmp_x.shape[1], :, :] = tmp_x + self.input_length = x.size(1) + self.end_frames = torch.as_tensor(xlens) - 1 + + def extend_state(self, state): + """Compute CTC prefix state. + + + :param state : CTC state + :return ctc_state + """ + + if state is None: + # nothing to do + return state + else: + r_prev, s_prev, f_min_prev, f_max_prev = state + + r_prev_new = torch.full( + (self.input_length, 2), + self.logzero, + dtype=self.dtype, + device=self.device, + ) + start = max(r_prev.shape[0], 1) + r_prev_new[0:start] = r_prev + for t in six.moves.range(start, self.input_length): + r_prev_new[t, 1] = r_prev_new[t - 1, 1] + self.x[0, t, :, self.blank] + + return (r_prev_new, s_prev, f_min_prev, f_max_prev) + + +class CTCPrefixScore(object): + """Compute CTC label sequence scores + + which is based on Algorithm 2 in WATANABE et al. + "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION," + but extended to efficiently compute the probablities of multiple labels + simultaneously + """ + + def __init__(self, x, blank, eos, xp): + self.xp = xp + self.logzero = -10000000000.0 + self.blank = blank + self.eos = eos + self.input_length = len(x) + self.x = x + + def initial_state(self): + """Obtain an initial CTC state + + :return: CTC state + """ + # initial CTC state is made of a frame x 2 tensor that corresponds to + # r_t^n() and r_t^b(), where 0 and 1 of axis=1 represent + # superscripts n and b (non-blank and blank), respectively. + r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32) + r[0, 1] = self.x[0, self.blank] + for i in six.moves.range(1, self.input_length): + r[i, 1] = r[i - 1, 1] + self.x[i, self.blank] + return r + + def __call__(self, y, cs, r_prev): + """Compute CTC prefix scores for next labels + + :param y : prefix label sequence + :param cs : array of next labels + :param r_prev: previous CTC state + :return ctc_scores, ctc_states + """ + # initialize CTC states + output_length = len(y) - 1 # ignore sos + # new CTC states are prepared as a frame x (n or b) x n_labels tensor + # that corresponds to r_t^n(h) and r_t^b(h). 
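+        # The prefix probability is accumulated alongside the recursion
+        # (illustrative, in log space):
+        #   psi = r_{start-1}^n(g) + sum_t phi_{t-1}(g, c) * x_t(c)
+        # which is what ``log_psi`` holds at the end of the loop below.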
+ r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32) + xs = self.x[:, cs] + if output_length == 0: + r[0, 0] = xs[0] + r[0, 1] = self.logzero + else: + r[output_length - 1] = self.logzero + + # prepare forward probabilities for the last label + r_sum = self.xp.logaddexp( + r_prev[:, 0], r_prev[:, 1] + ) # log(r_t^n(g) + r_t^b(g)) + last = y[-1] + if output_length > 0 and last in cs: + log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32) + for i in six.moves.range(len(cs)): + log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1] + else: + log_phi = r_sum + + # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)), + # and log prefix probabilites log(psi) + start = max(output_length, 1) + log_psi = r[start - 1, 0] + for t in six.moves.range(start, self.input_length): + r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t] + r[t, 1] = ( + self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank] + ) + log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t]) + + # get P(...eos|X) that ends with the prefix itself + eos_pos = self.xp.where(cs == self.eos)[0] + if len(eos_pos) > 0: + log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g)) + + # exclude blank probs + blank_pos = self.xp.where(cs == self.blank)[0] + if len(blank_pos) > 0: + log_psi[blank_pos] = self.logzero + + # return the log prefix probability and CTC states, where the label axis + # of the CTC states is moved to the first axis to slice it easily + return log_psi, self.xp.rollaxis(r, 2) diff --git a/espnet/nets/e2e_asr_common.py b/espnet/nets/e2e_asr_common.py new file mode 100644 index 0000000000000000000000000000000000000000..17d2349afb02e2b3c5c6b715757801dc18b8101c --- /dev/null +++ b/espnet/nets/e2e_asr_common.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Common functions for ASR.""" + +import json +import logging +import sys + +import editdistance +from itertools import groupby +import numpy as np +import six + + +def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): + """End detection. + + described in Eq. (50) of S. Watanabe et al + "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" + + :param ended_hyps: + :param i: + :param M: + :param D_end: + :return: + """ + if len(ended_hyps) == 0: + return False + count = 0 + best_hyp = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[0] + for m in six.moves.range(M): + # get ended_hyps with their length is i - m + hyp_length = i - m + hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length] + if len(hyps_same_length) > 0: + best_hyp_same_length = sorted( + hyps_same_length, key=lambda x: x["score"], reverse=True + )[0] + if best_hyp_same_length["score"] - best_hyp["score"] < D_end: + count += 1 + + if count == M: + return True + else: + return False + + +# TODO(takaaki-hori): add different smoothing methods +def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): + """Obtain label distribution for loss smoothing. 
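+    Currently only the "unigram" type is supported: label frequencies are
+    counted from the training transcript JSON, floored for unseen labels,
+    zeroed for the blank symbol, and normalized into a distribution.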
+ + :param odim: + :param lsm_type: + :param blank: + :param transcript: + :return: + """ + if transcript is not None: + with open(transcript, "rb") as f: + trans_json = json.load(f)["utts"] + + if lsm_type == "unigram": + assert transcript is not None, ( + "transcript is required for %s label smoothing" % lsm_type + ) + labelcount = np.zeros(odim) + for k, v in trans_json.items(): + ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()]) + # to avoid an error when there is no text in an uttrance + if len(ids) > 0: + labelcount[ids] += 1 + labelcount[odim - 1] = len(transcript) # count + labelcount[labelcount == 0] = 1 # flooring + labelcount[blank] = 0 # remove counts for blank + labeldist = labelcount.astype(np.float32) / np.sum(labelcount) + else: + logging.error("Error: unexpected label smoothing type: %s" % lsm_type) + sys.exit() + + return labeldist + + +def get_vgg2l_odim(idim, in_channel=3, out_channel=128): + """Return the output size of the VGG frontend. + + :param in_channel: input channel size + :param out_channel: output channel size + :return: output size + :rtype int + """ + idim = idim / in_channel + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling + return int(idim) * out_channel # numer of channels + + +class ErrorCalculator(object): + """Calculate CER and WER for E2E_ASR and CTC models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: + :param sym_space: + :param sym_blank: + :return: + """ + + def __init__( + self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False + ): + """Construct an ErrorCalculator object.""" + super(ErrorCalculator, self).__init__() + + self.report_cer = report_cer + self.report_wer = report_wer + + self.char_list = char_list + self.space = sym_space + self.blank = sym_blank + self.idx_blank = self.char_list.index(self.blank) + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad, is_ctc=False): + """Calculate sentence-level WER/CER score. + + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :param bool is_ctc: calculate CER score for CTC + :return: sentence-level WER score + :rtype float + :return: sentence-level CER score + :rtype float + """ + cer, wer = None, None + if is_ctc: + return self.calculate_cer_ctc(ys_hat, ys_pad) + elif not self.report_cer and not self.report_wer: + return cer, wer + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + return cer, wer + + def calculate_cer_ctc(self, ys_hat, ys_pad): + """Calculate sentence-level CER score for CTC. 
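+        Repeated frame-level predictions are collapsed and blank/space indices
+        are removed before the character edit distance against the reference is
+        accumulated and normalized by the total reference length.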
+ + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :return: average sentence-level CER score + :rtype float + """ + cers, char_ref_lens = [], [] + for i, y in enumerate(ys_hat): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + seq_hat, seq_true = [], [] + for idx in y_hat: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_hat.append(self.char_list[int(idx)]) + + for idx in y_true: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_true.append(self.char_list[int(idx)]) + + hyp_chars = "".join(seq_hat) + ref_chars = "".join(seq_true) + if len(ref_chars) > 0: + cers.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None + return cer_ctc + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + :param torch.Tensor seqs_hat: prediction (batch, seqlen) + :param torch.Tensor seqs_true: reference (batch, seqlen) + :return: token list of prediction + :rtype list + :return: token list of reference + :rtype list + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) + # NOTE: padding index (-1) in y_true is used to pad y_hat + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level CER score + :rtype float + """ + char_eds, char_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score. 
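+        Hypotheses and references are split on whitespace, and the summed word
+        edit distance is normalized by the total number of reference words.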
+ + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level WER score + :rtype float + """ + word_eds, word_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + return float(sum(word_eds)) / sum(word_ref_lens) diff --git a/espnet/nets/e2e_mt_common.py b/espnet/nets/e2e_mt_common.py new file mode 100644 index 0000000000000000000000000000000000000000..0ffd296a469690d5e5a07b60446ad757680288d5 --- /dev/null +++ b/espnet/nets/e2e_mt_common.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Common functions for ST and MT.""" + +import nltk +import numpy as np + + +class ErrorCalculator(object): + """Calculate BLEU for ST and MT models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: vocabulary list + :param sym_space: space symbol + :param sym_pad: pad symbol + :param report_bleu: report BLUE score if True + """ + + def __init__(self, char_list, sym_space, sym_pad, report_bleu=False): + """Construct an ErrorCalculator object.""" + super(ErrorCalculator, self).__init__() + self.char_list = char_list + self.space = sym_space + self.pad = sym_pad + self.report_bleu = report_bleu + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad): + """Calculate corpus-level BLEU score. + + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :return: corpus-level BLEU score in a mini-batch + :rtype float + """ + bleu = None + if not self.report_bleu: + return bleu + + bleu = self.calculate_corpus_bleu(ys_hat, ys_pad) + return bleu + + def calculate_corpus_bleu(self, ys_hat, ys_pad): + """Calculate corpus-level BLEU score in a mini-batch. 
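+        Padded ids (-1) and pad symbols are stripped, indices are mapped back
+        to tokens, and nltk's corpus_bleu is applied to the resulting
+        sentences; the score is returned on a 0-100 scale.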
+ + :param torch.Tensor seqs_hat: prediction (batch, seqlen) + :param torch.Tensor seqs_true: reference (batch, seqlen) + :return: corpus-level BLEU score + :rtype float + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + ymax = eos_true[0] if len(eos_true) > 0 else len(y_true) + # NOTE: padding index (-1) in y_true is used to pad y_hat + # because y_hats is not padded with -1 + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:ymax]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.pad, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + bleu = nltk.bleu_score.corpus_bleu([[ref] for ref in seqs_true], seqs_hat) + return bleu * 100 diff --git a/espnet/nets/lm_interface.py b/espnet/nets/lm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1e751c4d8c945c8bae3fc4356a4d380fc1e023 --- /dev/null +++ b/espnet/nets/lm_interface.py @@ -0,0 +1,86 @@ +"""Language model interface.""" + +import argparse + +from espnet.nets.scorer_interface import ScorerInterface +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.fill_missing_args import fill_missing_args + + +class LMInterface(ScorerInterface): + """LM Interface for ESPnet model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to command line argument parser.""" + return parser + + @classmethod + def build(cls, n_vocab: int, **kwargs): + """Initialize this class with python-level args. + + Args: + idim (int): The number of vocabulary. + + Returns: + LMinterface: A new instance of LMInterface. + + """ + # local import to avoid cyclic import in lm_train + from espnet.bin.lm_train import get_parser + + def wrap(parser): + return get_parser(parser, required=False) + + args = argparse.Namespace(**kwargs) + args = fill_missing_args(args, wrap) + args = fill_missing_args(args, cls.add_arguments) + return cls(n_vocab, args) + + def forward(self, x, t): + """Compute LM loss value from buffer sequences. + + Args: + x (torch.Tensor): Input ids. (batch, len) + t (torch.Tensor): Target ids. (batch, len) + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + raise NotImplementedError("forward method is not implemented") + + +predefined_lms = { + "pytorch": { + "default": "espnet.nets.pytorch_backend.lm.default:DefaultRNNLM", + "seq_rnn": "espnet.nets.pytorch_backend.lm.seq_rnn:SequentialRNNLM", + "transformer": "espnet.nets.pytorch_backend.lm.transformer:TransformerLM", + }, + "chainer": {"default": "espnet.lm.chainer_backend.lm:DefaultRNNLM"}, +} + + +def dynamic_import_lm(module, backend): + """Import LM class dynamically. + + Args: + module (str): module_name:class_name or alias in `predefined_lms` + backend (str): NN backend. 
e.g., pytorch, chainer + + Returns: + type: LM class + + """ + model_class = dynamic_import(module, predefined_lms.get(backend, dict())) + assert issubclass( + model_class, LMInterface + ), f"{module} does not implement LMInterface" + return model_class diff --git a/espnet/nets/mt_interface.py b/espnet/nets/mt_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..0e1e61d0d5bb33293486e94d9f05f4cb659dcd36 --- /dev/null +++ b/espnet/nets/mt_interface.py @@ -0,0 +1,94 @@ +"""MT Interface module.""" +import argparse + +from espnet.bin.asr_train import get_parser +from espnet.utils.fill_missing_args import fill_missing_args + + +class MTInterface: + """MT Interface for ESPnet model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to parser.""" + return parser + + @classmethod + def build(cls, idim: int, odim: int, **kwargs): + """Initialize this class with python-level args. + + Args: + idim (int): The number of an input feature dim. + odim (int): The number of output vocab. + + Returns: + ASRinterface: A new instance of ASRInterface. + + """ + + def wrap(parser): + return get_parser(parser, required=False) + + args = argparse.Namespace(**kwargs) + args = fill_missing_args(args, wrap) + args = fill_missing_args(args, cls.add_arguments) + return cls(idim, odim, args) + + def forward(self, xs, ilens, ys): + """Compute loss for training. + + :param xs: + For pytorch, batch of padded source sequences torch.Tensor (B, Tmax, idim) + For chainer, list of source sequences chainer.Variable + :param ilens: batch of lengths of source sequences (B) + For pytorch, torch.Tensor + For chainer, list of int + :param ys: + For pytorch, batch of padded source sequences torch.Tensor (B, Lmax) + For chainer, list of source sequences chainer.Variable + :return: loss value + :rtype: torch.Tensor for pytorch, chainer.Variable for chainer + """ + raise NotImplementedError("forward method is not implemented") + + def translate(self, x, trans_args, char_list=None, rnnlm=None): + """Translate x for evaluation. + + :param ndarray x: input acouctic feature (B, T, D) or (T, D) + :param namespace trans_args: argment namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("translate method is not implemented") + + def translate_batch(self, x, trans_args, char_list=None, rnnlm=None): + """Beam search implementation for batch. + + :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc) + :param namespace trans_args: argument namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("Batch decoding is not supported yet.") + + def calculate_all_attentions(self, xs, ilens, ys): + """Caluculate attention. + + :param list xs: list of padded input sequences [(T1, idim), (T2, idim), ...] + :param ndarray ilens: batch of lengths of input sequences (B) + :param list ys: list of character id sequence tensor [(L1), (L2), (L3), ...] 
+ :return: attention weights (B, Lmax, Tmax) + :rtype: float ndarray + """ + raise NotImplementedError("calculate_all_attentions method is not implemented") + + @property + def attention_plot_class(self): + """Get attention plot class.""" + from espnet.asr.asr_utils import PlotAttentionReport + + return PlotAttentionReport diff --git a/espnet/nets/pytorch_backend/__init__.py b/espnet/nets/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/conformer/__init__.py b/espnet/nets/pytorch_backend/conformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/conformer/argument.py b/espnet/nets/pytorch_backend/conformer/argument.py new file mode 100644 index 0000000000000000000000000000000000000000..d5681565256125941daaeff61e050141fcafbeb1 --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/argument.py @@ -0,0 +1,87 @@ +# Copyright 2020 Hirofumi Inaguma +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Conformer common arguments.""" + + +from distutils.util import strtobool +import logging + + +def add_arguments_conformer_common(group): + """Add Transformer common arguments.""" + group.add_argument( + "--transformer-encoder-pos-enc-layer-type", + type=str, + default="abs_pos", + choices=["abs_pos", "scaled_abs_pos", "rel_pos"], + help="Transformer encoder positional encoding layer type", + ) + group.add_argument( + "--transformer-encoder-activation-type", + type=str, + default="swish", + choices=["relu", "hardtanh", "selu", "swish"], + help="Transformer encoder activation function type", + ) + group.add_argument( + "--macaron-style", + default=False, + type=strtobool, + help="Whether to use macaron style for positionwise layer", + ) + # Attention + group.add_argument( + "--zero-triu", + default=False, + type=strtobool, + help="If true, zero the uppper triangular part of attention matrix.", + ) + # Relative positional encoding + group.add_argument( + "--rel-pos-type", + type=str, + default="legacy", + choices=["legacy", "latest"], + help="Whether to use the latest relative positional encoding or the legacy one." + "The legacy relative positional encoding will be deprecated in the future." + "More Details can be found in https://github.com/espnet/espnet/pull/2816.", + ) + # CNN module + group.add_argument( + "--use-cnn-module", + default=False, + type=strtobool, + help="Use convolution module or not", + ) + group.add_argument( + "--cnn-module-kernel", + default=31, + type=int, + help="Kernel size of convolution module.", + ) + return group + + +def verify_rel_pos_type(args): + """Verify the relative positional encoding type for compatibility. + + Args: + args (Namespace): original arguments + Returns: + args (Namespace): modified arguments + """ + rel_pos_type = getattr(args, "rel_pos_type", None) + if rel_pos_type is None or rel_pos_type == "legacy": + if args.transformer_encoder_pos_enc_layer_type == "rel_pos": + args.transformer_encoder_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." 
+ ) + if args.transformer_encoder_selfattn_layer_type == "rel_selfattn": + args.transformer_encoder_selfattn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." + ) + + return args diff --git a/espnet/nets/pytorch_backend/conformer/convolution.py b/espnet/nets/pytorch_backend/conformer/convolution.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5d2c30c313e73fa2097bc28721be00aeb6910f --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/convolution.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""ConvolutionModule definition.""" + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward(self, x): + """Compute convolution module. + + Args: + x (torch.Tensor): Input tensor (#batch, time, channels). + + Returns: + torch.Tensor: Output tensor (#batch, time, channels). 
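+
+        Note:
+            Internally the tensor is transposed to (#batch, channels, time)
+            and passed through pointwise conv -> GLU -> depthwise conv ->
+            BatchNorm -> activation -> pointwise conv.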
+ + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/espnet/nets/pytorch_backend/conformer/encoder.py b/espnet/nets/pytorch_backend/conformer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..980d15a18b8b5c87e14c7edc808a6b19ba999dee --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/encoder.py @@ -0,0 +1,245 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch + +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import get_activation +from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 + LegacyRelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 + LegacyRelPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +class Encoder(torch.nn.Module): + """Conformer encoder module. + + Args: + idim (int): Input dimension. + attention_dim (int): Dimention of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + attention_dropout_rate (float): Dropout rate in attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + macaron_style (bool): Whether to use macaron style for positionwise layer. + pos_enc_layer_type (str): Encoder positional encoding layer type. + selfattention_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. 
+ use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + zero_triu=False, + cnn_module_kernel=31, + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Encoder, self).__init__() + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + pos_enc_class = LegacyRelPositionalEncoding + assert selfattention_layer_type == "legacy_rel_selfattn" + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.conv_subsampling_factor = 4 + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + self.conv_subsampling_factor = 4 + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + + # self-attention module definition + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + logging.info("encoder self-attention layer type = relative self-attention") + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + zero_triu, + ) + 
else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + # feed-forward module definition + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + # convolution module definition + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks): + """Encode input sequence. + + Args: + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/espnet/nets/pytorch_backend/conformer/encoder_layer.py b/espnet/nets/pytorch_backend/conformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e8571e01eee2e126fcd0ce64524ae60f433ade2a --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/encoder_layer.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + conv_module (torch.nn.Module): Convolution module instance. + `ConvlutionModule` instance can be used as the argument. + dropout_rate (float): Dropout rate. 
+ normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """Compute encoded features. + + Args: + x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. + - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. + - w/o pos emb: Tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). 
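+
+        Note:
+            The block applies an optional macaron feed-forward, self-attention,
+            the convolution module, and the main feed-forward, each with a
+            residual connection; when the macaron branch is present both
+            feed-forward outputs are scaled by 0.5.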
+ + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/espnet/nets/pytorch_backend/conformer/swish.py b/espnet/nets/pytorch_backend/conformer/swish.py new file mode 100644 index 0000000000000000000000000000000000000000..c53a7a98bfc6d983c3a308c4b40f81e315aa7875 --- /dev/null +++ b/espnet/nets/pytorch_backend/conformer/swish.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Swish() activation function for Conformer.""" + +import torch + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x): + """Return Swich activation function.""" + return x * torch.sigmoid(x) diff --git a/espnet/nets/pytorch_backend/ctc.py b/espnet/nets/pytorch_backend/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..46f762bccffa82fecb018128e55bf67b4f8e6c37 --- /dev/null +++ b/espnet/nets/pytorch_backend/ctc.py @@ -0,0 +1,291 @@ +from distutils.version import LooseVersion +import logging + +import numpy as np +import six +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.nets_utils import to_device + + +class CTC(torch.nn.Module): + """CTC module + + :param int odim: dimension of outputs + :param int eprojs: number of encoder projection units + :param float dropout_rate: dropout rate (0.0 ~ 1.0) + :param str ctc_type: builtin or warpctc + :param bool reduce: reduce the CTC loss into a scalar + """ + + def __init__(self, odim, eprojs, dropout_rate, ctc_type="warpctc", reduce=True): + super().__init__() + self.dropout_rate = dropout_rate + self.loss = None + self.ctc_lo = torch.nn.Linear(eprojs, odim) + self.probs = None # for visualization + + # In case of Pytorch >= 1.7.0, CTC will be 
always builtin + self.ctc_type = ( + ctc_type + if LooseVersion(torch.__version__) < LooseVersion("1.7.0") + else "builtin" + ) + + # ctc_type = buitin not support Pytorch=1.0.1 + if self.ctc_type == "builtin" and ( + LooseVersion(torch.__version__) < LooseVersion("1.1.0") + ): + self.ctc_type = "cudnnctc" + + if ctc_type != self.ctc_type: + logging.warning(f"CTC was set to {self.ctc_type} due to PyTorch version.") + + if self.ctc_type == "builtin": + reduction_type = "sum" if reduce else "none" + self.ctc_loss = torch.nn.CTCLoss( + reduction=reduction_type, zero_infinity=True + ) + elif self.ctc_type == "cudnnctc": + reduction_type = "sum" if reduce else "none" + self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type) + elif self.ctc_type == "warpctc": + import warpctc_pytorch as warp_ctc + + self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) + elif self.ctc_type == "gtnctc": + from espnet.nets.pytorch_backend.gtn_ctc import GTNCTCLossFunction + + self.ctc_loss = GTNCTCLossFunction.apply + else: + raise ValueError( + 'ctc_type must be "builtin" or "warpctc": {}'.format(self.ctc_type) + ) + + self.ignore_id = -1 + self.reduce = reduce + + def loss_fn(self, th_pred, th_target, th_ilen, th_olen): + if self.ctc_type in ["builtin", "cudnnctc"]: + th_pred = th_pred.log_softmax(2) + # Use the deterministic CuDNN implementation of CTC loss to avoid + # [issue#17798](https://github.com/pytorch/pytorch/issues/17798) + with torch.backends.cudnn.flags(deterministic=True): + loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) + # Batch-size average + loss = loss / th_pred.size(1) + return loss + elif self.ctc_type == "warpctc": + return self.ctc_loss(th_pred, th_target, th_ilen, th_olen) + elif self.ctc_type == "gtnctc": + targets = [t.tolist() for t in th_target] + log_probs = torch.nn.functional.log_softmax(th_pred, dim=2) + return self.ctc_loss(log_probs, targets, 0, "none") + else: + raise NotImplementedError + + def forward(self, hs_pad, hlens, ys_pad): + """CTC forward + + :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) + :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, Lmax) + :return: ctc loss value + :rtype: torch.Tensor + """ + # TODO(kan-bayashi): need to make more smart way + ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys + + # zero padding for hs + ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) + if self.ctc_type != "gtnctc": + ys_hat = ys_hat.transpose(0, 1) + + if self.ctc_type == "builtin": + olens = to_device(ys_hat, torch.LongTensor([len(s) for s in ys])) + hlens = hlens.long() + ys_pad = torch.cat(ys) # without this the code breaks for asr_mix + self.loss = self.loss_fn(ys_hat, ys_pad, hlens, olens) + else: + self.loss = None + hlens = torch.from_numpy(np.fromiter(hlens, dtype=np.int32)) + olens = torch.from_numpy( + np.fromiter((x.size(0) for x in ys), dtype=np.int32) + ) + # zero padding for ys + ys_true = torch.cat(ys).cpu().int() # batch x olen + # get ctc loss + # expected shape of seqLength x batchSize x alphabet_size + dtype = ys_hat.dtype + if self.ctc_type == "warpctc" or dtype == torch.float16: + # warpctc only supports float32 + # torch.ctc does not support float16 (#1751) + ys_hat = ys_hat.to(dtype=torch.float32) + if self.ctc_type == "cudnnctc": + # use GPU when using the cuDNN implementation + ys_true = to_device(hs_pad, ys_true) + if self.ctc_type == "gtnctc": + # keep as list 
for gtn + ys_true = ys + self.loss = to_device( + hs_pad, self.loss_fn(ys_hat, ys_true, hlens, olens) + ).to(dtype=dtype) + + # get length info + logging.info( + self.__class__.__name__ + + " input lengths: " + + "".join(str(hlens).split("\n")) + ) + logging.info( + self.__class__.__name__ + + " output lengths: " + + "".join(str(olens).split("\n")) + ) + + if self.reduce: + # NOTE: sum() is needed to keep consistency + # since warpctc return as tensor w/ shape (1,) + # but builtin return as tensor w/o shape (scalar). + self.loss = self.loss.sum() + logging.info("ctc loss:" + str(float(self.loss))) + + return self.loss + + def softmax(self, hs_pad): + """softmax of frame activations + + :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + :return: log softmax applied 3d tensor (B, Tmax, odim) + :rtype: torch.Tensor + """ + self.probs = F.softmax(self.ctc_lo(hs_pad), dim=2) + return self.probs + + def log_softmax(self, hs_pad): + """log_softmax of frame activations + + :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + :return: log softmax applied 3d tensor (B, Tmax, odim) + :rtype: torch.Tensor + """ + return F.log_softmax(self.ctc_lo(hs_pad), dim=2) + + def argmax(self, hs_pad): + """argmax of frame activations + + :param torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + :return: argmax applied 2d tensor (B, Tmax) + :rtype: torch.Tensor + """ + return torch.argmax(self.ctc_lo(hs_pad), dim=2) + + def forced_align(self, h, y, blank_id=0): + """forced alignment. + + :param torch.Tensor h: hidden state sequence, 2d tensor (T, D) + :param torch.Tensor y: id sequence tensor 1d tensor (L) + :param int y: blank symbol index + :return: best alignment results + :rtype: list + """ + + def interpolate_blank(label, blank_id=0): + """Insert blank token between every two label token.""" + label = np.expand_dims(label, 1) + blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id + label = np.concatenate([blanks, label], axis=1) + label = label.reshape(-1) + label = np.append(label, label[0]) + return label + + lpz = self.log_softmax(h) + lpz = lpz.squeeze(0) + + y_int = interpolate_blank(y, blank_id) + + logdelta = np.zeros((lpz.size(0), len(y_int))) - 100000000000.0 # log of zero + state_path = ( + np.zeros((lpz.size(0), len(y_int)), dtype=np.int16) - 1 + ) # state path + + logdelta[0, 0] = lpz[0][y_int[0]] + logdelta[0, 1] = lpz[0][y_int[1]] + + for t in six.moves.range(1, lpz.size(0)): + for s in six.moves.range(len(y_int)): + if y_int[s] == blank_id or s < 2 or y_int[s] == y_int[s - 2]: + candidates = np.array([logdelta[t - 1, s], logdelta[t - 1, s - 1]]) + prev_state = [s, s - 1] + else: + candidates = np.array( + [ + logdelta[t - 1, s], + logdelta[t - 1, s - 1], + logdelta[t - 1, s - 2], + ] + ) + prev_state = [s, s - 1, s - 2] + logdelta[t, s] = np.max(candidates) + lpz[t][y_int[s]] + state_path[t, s] = prev_state[np.argmax(candidates)] + + state_seq = -1 * np.ones((lpz.size(0), 1), dtype=np.int16) + + candidates = np.array( + [logdelta[-1, len(y_int) - 1], logdelta[-1, len(y_int) - 2]] + ) + prev_state = [len(y_int) - 1, len(y_int) - 2] + state_seq[-1] = prev_state[np.argmax(candidates)] + for t in six.moves.range(lpz.size(0) - 2, -1, -1): + state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] + + output_state_seq = [] + for t in six.moves.range(0, lpz.size(0)): + output_state_seq.append(y_int[state_seq[t, 0]]) + + return output_state_seq + + +def ctc_for(args, odim, reduce=True): + """Returns the CTC module for the given args and output dimension + + :param Namespace args: 
the program args + :param int odim : The output dimension + :param bool reduce : return the CTC loss in a scalar + :return: the corresponding CTC module + """ + num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility + if num_encs == 1: + # compatible with single encoder asr mode + return CTC( + odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=reduce + ) + elif num_encs >= 1: + ctcs_list = torch.nn.ModuleList() + if args.share_ctc: + # use dropout_rate of the first encoder + ctc = CTC( + odim, + args.eprojs, + args.dropout_rate[0], + ctc_type=args.ctc_type, + reduce=reduce, + ) + ctcs_list.append(ctc) + else: + for idx in range(num_encs): + ctc = CTC( + odim, + args.eprojs, + args.dropout_rate[idx], + ctc_type=args.ctc_type, + reduce=reduce, + ) + ctcs_list.append(ctc) + return ctcs_list + else: + raise ValueError( + "Number of encoders needs to be more than one. {}".format(num_encs) + ) diff --git a/espnet/nets/pytorch_backend/e2e_asr.py b/espnet/nets/pytorch_backend/e2e_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..5644b99e3eec1ffaf178283899afa1917db5dc45 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr.py @@ -0,0 +1,541 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""RNN sequence-to-sequence speech recognition model (pytorch).""" + +import argparse +from itertools import groupby +import logging +import math +import os + +import chainer +from chainer import reporter +import editdistance +import numpy as np +import six +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.pytorch_backend.ctc import ctc_for +from espnet.nets.pytorch_backend.frontends.feature_transform import ( + feature_transform_for, # noqa: H301 +) +from espnet.nets.pytorch_backend.frontends.frontend import frontend_for +from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters +from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor +from espnet.nets.pytorch_backend.rnn.argument import ( + add_arguments_rnn_encoder_common, # noqa: H301 + add_arguments_rnn_decoder_common, # noqa: H301 + add_arguments_rnn_attention_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders import decoder_for +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.utils.fill_missing_args import fill_missing_args + +CTC_LOSS_THRESHOLD = 10000 + + +class Reporter(chainer.Chain): + """A chainer reporter wrapper.""" + + def report(self, loss_ctc, loss_att, acc, cer_ctc, cer, wer, mtl_loss): + """Report at every step.""" + reporter.report({"loss_ctc": loss_ctc}, self) + reporter.report({"loss_att": loss_att}, self) + reporter.report({"acc": acc}, self) + reporter.report({"cer_ctc": cer_ctc}, self) + reporter.report({"cer": cer}, self) + reporter.report({"wer": wer}, self) + logging.info("mtl loss:" + str(mtl_loss)) + reporter.report({"loss": mtl_loss}, self) + + +class E2E(ASRInterface, torch.nn.Module): + """E2E module. 
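+
+    This is the RNN-based hybrid CTC/attention model: the shared encoder output
+    is scored by both a CTC branch and an attention decoder, and the two losses
+    are interpolated with the weight mtlalpha.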
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2E.encoder_add_arguments(parser) + E2E.attention_add_arguments(parser) + E2E.decoder_add_arguments(parser) + return parser + + @staticmethod + def encoder_add_arguments(parser): + """Add arguments for the encoder.""" + group = parser.add_argument_group("E2E encoder setting") + group = add_arguments_rnn_encoder_common(group) + return parser + + @staticmethod + def attention_add_arguments(parser): + """Add arguments for the attention.""" + group = parser.add_argument_group("E2E attention setting") + group = add_arguments_rnn_attention_common(group) + return parser + + @staticmethod + def decoder_add_arguments(parser): + """Add arguments for the decoder.""" + group = parser.add_argument_group("E2E decoder setting") + group = add_arguments_rnn_decoder_common(group) + return parser + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + if isinstance(self.enc, torch.nn.ModuleList): + return self.enc[0].conv_subsampling_factor * int(np.prod(self.subsample)) + else: + return self.enc.conv_subsampling_factor * int(np.prod(self.subsample)) + + def __init__(self, idim, odim, args): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + self.mtlalpha = args.mtlalpha + assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" + self.etype = args.etype + self.verbose = args.verbose + # NOTE: for self.build method + args.char_list = getattr(args, "char_list", None) + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + + # subsample info + self.subsample = get_subsample(args, mode="asr", arch="rnn") + + # label smoothing info + if args.lsm_type and os.path.isfile(args.train_json): + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + if getattr(args, "use_frontend", False): # use getattr to keep compatibility + self.frontend = frontend_for(args, idim) + self.feature_transform = feature_transform_for(args, (idim - 1) * 2) + idim = args.n_mels + else: + self.frontend = None + + # encoder + self.enc = encoder_for(args, idim, self.subsample) + # ctc + self.ctc = ctc_for(args, odim) + # attention + self.att = att_for(args) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + # weight initialization + self.init_like_chainer() + + # options for beam search + if args.report_cer or args.report_wer: + recog_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": args.ctc_weight, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + } + + self.recog_args = argparse.Namespace(**recog_args) + 
self.report_cer = args.report_cer + self.report_wer = args.report_wer + else: + self.report_cer = False + self.report_wer = False + self.rnnlm = None + + self.logzero = -10000000000.0 + self.loss = None + self.acc = None + + def init_like_chainer(self): + """Initialize weight like chainer. + + chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 + pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) + however, there are two exceptions as far as I know. + - EmbedID.W ~ Normal(0, 1) + - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) + """ + lecun_normal_init_parameters(self) + # exceptions + # embed weight ~ Normal(0, 1) + self.dec.embed.weight.data.normal_(0, 1) + # forget-bias = 1.0 + # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745 + for i in six.moves.range(len(self.dec.decoder)): + set_forget_bias_to_one(self.dec.decoder[i].bias_ih) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: loss value + :rtype: torch.Tensor + """ + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + + # 2. CTC loss + if self.mtlalpha == 0: + self.loss_ctc = None + else: + self.loss_ctc = self.ctc(hs_pad, hlens, ys_pad) + + # 3. attention loss + if self.mtlalpha == 1: + self.loss_att, acc = None, None + else: + self.loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad) + self.acc = acc + + # 4. compute cer without beam search + if self.mtlalpha == 0 or self.char_list is None: + cer_ctc = None + else: + cers = [] + + y_hats = self.ctc.argmax(hs_pad).data + for i, y in enumerate(y_hats): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + if len(ref_chars) > 0: + cers.append( + editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) + ) + + cer_ctc = sum(cers) / len(cers) if cers else None + + # 5. 
compute cer/wer + if self.training or not (self.report_cer or self.report_wer): + cer, wer = 0.0, 0.0 + # oracle_cer, oracle_wer = 0.0, 0.0 + else: + if self.recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(hs_pad).data + else: + lpz = None + + word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] + nbest_hyps = self.dec.recognize_beam_batch( + hs_pad, + torch.tensor(hlens), + lpz, + self.recog_args, + self.char_list, + self.rnnlm, + ) + # remove and + y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") + seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ") + + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + wer = ( + 0.0 + if not self.report_wer + else float(sum(word_eds)) / sum(word_ref_lens) + ) + cer = ( + 0.0 + if not self.report_cer + else float(sum(char_eds)) / sum(char_ref_lens) + ) + + alpha = self.mtlalpha + if alpha == 0: + self.loss = self.loss_att + loss_att_data = float(self.loss_att) + loss_ctc_data = None + elif alpha == 1: + self.loss = self.loss_ctc + loss_att_data = None + loss_ctc_data = float(self.loss_ctc) + else: + self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att + loss_att_data = float(self.loss_att) + loss_ctc_data = float(self.loss_ctc) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data, loss_att_data, acc, cer_ctc, cer, wer, loss_data + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def scorers(self): + """Scorers.""" + return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos)) + + def encode(self, x): + """Encode acoustic features. + + :param ndarray x: input acoustic feature (T, D) + :return: encoder outputs + :rtype: torch.Tensor + """ + self.eval() + ilens = [x.shape[0]] + + # subsample frame + x = x[:: self.subsample[0], :] + p = next(self.parameters()) + h = torch.as_tensor(x, device=p.device, dtype=p.dtype) + # make a utt list (1) to use the same interface for encoder + hs = h.contiguous().unsqueeze(0) + + # 0. Frontend + if self.frontend is not None: + enhanced, hlens, mask = self.frontend(hs, ilens) + hs, hlens = self.feature_transform(enhanced, hlens) + else: + hs, hlens = hs, ilens + + # 1. encoder + hs, _, _ = self.enc(hs, hlens) + return hs.squeeze(0) + + def recognize(self, x, recog_args, char_list, rnnlm=None): + """E2E beam search. + + :param ndarray x: input acoustic feature (T, D) + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + hs = self.encode(x).unsqueeze(0) + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(hs)[0] + else: + lpz = None + + # 2. 
Decoder + # decode the first utterance + y = self.dec.recognize_beam(hs[0], lpz, recog_args, char_list, rnnlm) + return y + + def recognize_batch(self, xs, recog_args, char_list, rnnlm=None): + """E2E batch beam search. + + :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[:: self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + + # 0. Frontend + if self.frontend is not None: + enhanced, hlens, mask = self.frontend(xs_pad, ilens) + hs_pad, hlens = self.feature_transform(enhanced, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(hs_pad) + normalize_score = False + else: + lpz = None + normalize_score = True + + # 2. Decoder + hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor + y = self.dec.recognize_beam_batch( + hs_pad, + hlens, + lpz, + recog_args, + char_list, + rnnlm, + normalize_score=normalize_score, + ) + + if prev: + self.train() + return y + + def enhance(self, xs): + """Forward only in the frontend stage. + + :param ndarray xs: input acoustic feature (T, C, F) + :return: enhaned feature + :rtype: torch.Tensor + """ + if self.frontend is None: + raise RuntimeError("Frontend does't exist") + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[:: self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + enhanced, hlensm, mask = self.frontend(xs_pad, ilens) + if prev: + self.train() + return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). + :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + hpad, hlens, _ = self.enc(hs_pad, hlens) + + # 2. Decoder + att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad) + self.train() + return att_ws + + def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad): + """E2E CTC probability calculation. 
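+
+        Illustrative sketch (toy values, not part of the model code): the greedy
+        CTC collapse and character error rate used in ``forward`` above reduce to
+        merging repeated frame labels, dropping blanks, and normalizing an edit
+        distance by the reference length::
+
+            from itertools import groupby
+
+            import editdistance
+
+            frame_ids = [0, 3, 3, 0, 0, 5, 5, 5, 0, 7]      # per-frame CTC argmax (0 = blank)
+            collapsed = [k for k, _ in groupby(frame_ids)]  # merge repeats -> [0, 3, 0, 5, 0, 7]
+            hyp = [i for i in collapsed if i != 0]          # drop blanks   -> [3, 5, 7]
+            ref = [3, 5, 9]
+            cer = editdistance.eval(hyp, ref) / len(ref)    # 1 edit / 3 ref tokens ~= 0.33
+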
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: CTC probability (B, Tmax, vocab) + :rtype: float ndarray + """ + probs = None + if self.mtlalpha == 0: + return probs + + self.eval() + with torch.no_grad(): + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + hpad, hlens, _ = self.enc(hs_pad, hlens) + + # 2. CTC probs + probs = self.ctc.softmax(hpad).cpu().numpy() + self.train() + return probs + + def subsample_frames(self, x): + """Subsample speeh frames in the encoder.""" + # subsample frame + x = x[:: self.subsample[0], :] + ilen = [x.shape[0]] + h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32))) + h.contiguous() + return h, ilen diff --git a/espnet/nets/pytorch_backend/e2e_asr_conformer.py b/espnet/nets/pytorch_backend/e2e_asr_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..16cd2418ab77930dd1b0756d7f676cc57a1c71eb --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_conformer.py @@ -0,0 +1,76 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +Conformer speech recognition model (pytorch). + +It is a fusion of `e2e_asr_transformer.py` +Refer to: https://arxiv.org/abs/2005.08100 + +""" + +from espnet.nets.pytorch_backend.conformer.encoder import Encoder +from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer +from espnet.nets.pytorch_backend.conformer.argument import ( + add_arguments_conformer_common, # noqa: H301 + verify_rel_pos_type, # noqa: H301 +) + + +class E2E(E2ETransformer): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2ETransformer.add_arguments(parser) + E2E.add_conformer_arguments(parser) + return parser + + @staticmethod + def add_conformer_arguments(parser): + """Add arguments for conformer model.""" + group = parser.add_argument_group("conformer model specific setting") + group = add_arguments_conformer_common(group) + return parser + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. 
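+
+        Illustrative sketch (made-up sizes): the frame subsampling used in
+        ``subsample_frames``/``encode`` above is plain strided slicing along the
+        time axis::
+
+            import numpy as np
+
+            x = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)  # (T, D) features
+            factor = 4                                               # e.g. self.subsample[0]
+            x_sub = x[::factor, :]                                   # keeps frames 0, 4, 8 -> (3, 3)
+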
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super().__init__(idim, odim, args, ignore_id) + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + + # Check the relative positional encoding type + args = verify_rel_pos_type(args) + + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + activation_type=args.transformer_encoder_activation_type, + macaron_style=args.macaron_style, + use_cnn_module=args.use_cnn_module, + zero_triu=args.zero_triu, + cnn_module_kernel=args.cnn_module_kernel, + ) + self.reset_parameters(args) diff --git a/espnet/nets/pytorch_backend/e2e_asr_maskctc.py b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py new file mode 100644 index 0000000000000000000000000000000000000000..c283f7de5bbee736e6b95cccc52d6b67bf83307d --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_maskctc.py @@ -0,0 +1,249 @@ +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Waseda University (Yosuke Higuchi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +Mask CTC based non-autoregressive speech recognition model (pytorch). + +See https://arxiv.org/abs/2005.08700 for the detail. + +""" + +from itertools import groupby +import logging +import math + +from distutils.util import strtobool +import numpy +import torch + +from espnet.nets.pytorch_backend.conformer.encoder import Encoder +from espnet.nets.pytorch_backend.conformer.argument import ( + add_arguments_conformer_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD +from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2ETransformer +from espnet.nets.pytorch_backend.maskctc.add_mask_token import mask_uniform +from espnet.nets.pytorch_backend.maskctc.mask import square_mask +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import th_accuracy + + +class E2E(E2ETransformer): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2ETransformer.add_arguments(parser) + E2E.add_maskctc_arguments(parser) + + return parser + + @staticmethod + def add_maskctc_arguments(parser): + """Add arguments for maskctc model.""" + group = parser.add_argument_group("maskctc specific setting") + + group.add_argument( + "--maskctc-use-conformer-encoder", + default=False, + type=strtobool, + ) + group = add_arguments_conformer_common(group) + + return parser + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. 
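+
+        Rough standalone sketch of the idea behind ``mask_uniform`` used in
+        ``forward`` below (an approximation with made-up ids, not the actual
+        espnet implementation): a random subset of target tokens is replaced by
+        an extra <mask> id appended to the vocabulary::
+
+            import torch
+
+            ys = torch.tensor([7, 2, 9, 4, 5])                   # target token ids
+            mask_token = 100                                     # appended id (odim - 1)
+            n_mask = torch.randint(1, len(ys) + 1, (1,)).item()  # how many tokens to mask
+            idx = torch.randperm(len(ys))[:n_mask]               # positions chosen uniformly
+            ys_in = ys.clone()
+            ys_in[idx] = mask_token                              # decoder input with <mask>
+            # the prediction targets are the original tokens at the masked positions;
+            # the other positions can be ignored in the loss (e.g. set to ignore_id)
+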
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + odim += 1 # for the mask token + + super().__init__(idim, odim, args, ignore_id) + assert 0.0 <= self.mtlalpha < 1.0, "mtlalpha should be [0.0, 1.0)" + + self.mask_token = odim - 1 + self.sos = odim - 2 + self.eos = odim - 2 + self.odim = odim + + if args.maskctc_use_conformer_encoder: + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.conformer_dropout_rate + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + activation_type=args.transformer_encoder_activation_type, + macaron_style=args.macaron_style, + use_cnn_module=args.use_cnn_module, + cnn_module_kernel=args.cnn_module_kernel, + ) + self.reset_parameters(args) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of source sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :return: ctc loss value + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 1. forward encoder + xs_pad = xs_pad[:, : max(ilens)] # for data parallel + src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) + self.hs_pad = hs_pad + + # 2. forward decoder + ys_in_pad, ys_out_pad = mask_uniform( + ys_pad, self.mask_token, self.eos, self.ignore_id + ) + ys_mask = square_mask(ys_in_pad, self.eos) + pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + self.pred_pad = pred_pad + + # 3. compute attention loss + loss_att = self.criterion(pred_pad, ys_out_pad) + self.acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + + # 4. compute ctc loss + loss_ctc, cer_ctc = None, None + if self.mtlalpha > 0: + batch_size = xs_pad.size(0) + hs_len = hs_mask.view(batch_size, -1).sum(1) + loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) + if self.error_calculator is not None: + ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + # for visualization + if not self.training: + self.ctc.softmax(hs_pad) + + # 5. 
compute cer/wer + if self.training or self.error_calculator is None or self.decoder is None: + cer, wer = None, None + else: + ys_hat = pred_pad.argmax(dim=-1) + cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + alpha = self.mtlalpha + if alpha == 0: + self.loss = loss_att + loss_att_data = float(loss_att) + loss_ctc_data = None + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + loss_att_data = float(loss_att) + loss_ctc_data = float(loss_ctc) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def recognize(self, x, recog_args, char_list=None, rnnlm=None): + """Recognize input speech. + + :param ndnarray x: input acoustic feature (B, T, D) or (T, D) + :param Namespace recog_args: argment Namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: decoding result + :rtype: list + """ + + def num2str(char_list, mask_token, mask_char="_"): + def f(yl): + cl = [char_list[y] if y != mask_token else mask_char for y in yl] + return "".join(cl).replace("", " ") + + return f + + n2s = num2str(char_list, self.mask_token) + + self.eval() + h = self.encode(x).unsqueeze(0) + + # greedy ctc outputs + ctc_probs, ctc_ids = torch.exp(self.ctc.log_softmax(h)).max(dim=-1) + y_hat = torch.stack([x[0] for x in groupby(ctc_ids[0])]) + y_idx = torch.nonzero(y_hat != 0).squeeze(-1) + + # calculate token-level ctc probabilities by taking + # the maximum probability of consecutive frames with + # the same ctc symbols + probs_hat = [] + cnt = 0 + for i, y in enumerate(y_hat.tolist()): + probs_hat.append(-1) + while cnt < ctc_ids.shape[1] and y == ctc_ids[0][cnt]: + if probs_hat[i] < ctc_probs[0][cnt]: + probs_hat[i] = ctc_probs[0][cnt].item() + cnt += 1 + probs_hat = torch.from_numpy(numpy.array(probs_hat)) + + # mask ctc outputs based on ctc probabilities + p_thres = recog_args.maskctc_probability_threshold + mask_idx = torch.nonzero(probs_hat[y_idx] < p_thres).squeeze(-1) + confident_idx = torch.nonzero(probs_hat[y_idx] >= p_thres).squeeze(-1) + mask_num = len(mask_idx) + + y_in = torch.zeros(1, len(y_idx), dtype=torch.long) + self.mask_token + y_in[0][confident_idx] = y_hat[y_idx][confident_idx] + + logging.info("ctc:{}".format(n2s(y_in[0].tolist()))) + + # iterative decoding + if not mask_num == 0: + K = recog_args.maskctc_n_iterations + num_iter = K if mask_num >= K and K > 0 else mask_num + + for t in range(num_iter - 1): + pred, _ = self.decoder(y_in, None, h, None) + pred_score, pred_id = pred[0][mask_idx].max(dim=-1) + cand = torch.topk(pred_score, mask_num // num_iter, -1)[1] + y_in[0][mask_idx[cand]] = pred_id[cand] + mask_idx = torch.nonzero(y_in[0] == self.mask_token).squeeze(-1) + + logging.info("msk:{}".format(n2s(y_in[0].tolist()))) + + # predict leftover masks (|masks| < mask_num // num_iter) + pred, pred_mask = self.decoder(y_in, None, h, None) + y_in[0][mask_idx] = pred[0][mask_idx].argmax(dim=-1) + + logging.info("msk:{}".format(n2s(y_in[0].tolist()))) + + ret = y_in.tolist()[0] + hyp = {"score": 0.0, "yseq": [self.sos] + ret + [self.eos]} + + return [hyp] diff --git a/espnet/nets/pytorch_backend/e2e_asr_mix.py b/espnet/nets/pytorch_backend/e2e_asr_mix.py new file mode 100644 index 
0000000000000000000000000000000000000000..1615f7e275e315e214fdb48024751af8ce6f3cd3 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_mix.py @@ -0,0 +1,827 @@ +#!/usr/bin/env python3 + +""" +This script is used to construct End-to-End models of multi-speaker ASR. + +Copyright 2017 Johns Hopkins University (Shinji Watanabe) + Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) +""" + +import argparse +from itertools import groupby +import logging +import math +import os +import sys + +import editdistance +import numpy as np +import six +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.e2e_asr_common import get_vgg2l_odim +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.pytorch_backend.ctc import ctc_for +from espnet.nets.pytorch_backend.e2e_asr import E2E as E2EASR +from espnet.nets.pytorch_backend.e2e_asr import Reporter +from espnet.nets.pytorch_backend.frontends.feature_transform import ( + feature_transform_for, # noqa: H301 +) +from espnet.nets.pytorch_backend.frontends.frontend import frontend_for +from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters +from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders import decoder_for +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for as encoder_for_single +from espnet.nets.pytorch_backend.rnn.encoders import RNNP +from espnet.nets.pytorch_backend.rnn.encoders import VGG2L + +CTC_LOSS_THRESHOLD = 10000 + + +class PIT(object): + """Permutation Invariant Training (PIT) module. + + :parameter int num_spkrs: number of speakers for PIT process (2 or 3) + """ + + def __init__(self, num_spkrs): + """Initialize PIT module.""" + self.num_spkrs = num_spkrs + + # [[0, 1], [1, 0]] or + # [[0, 1, 2], [0, 2, 1], [1, 0, 2], [1, 2, 0], [2, 1, 0], [2, 0, 1]] + self.perm_choices = [] + initial_seq = np.linspace(0, num_spkrs - 1, num_spkrs, dtype=np.int64) + self.permutationDFS(initial_seq, 0) + + # [[0, 3], [1, 2]] or + # [[0, 4, 8], [0, 5, 7], [1, 3, 8], [1, 5, 6], [2, 4, 6], [2, 3, 7]] + self.loss_perm_idx = np.linspace( + 0, num_spkrs * (num_spkrs - 1), num_spkrs, dtype=np.int64 + ).reshape(1, num_spkrs) + self.loss_perm_idx = (self.loss_perm_idx + np.array(self.perm_choices)).tolist() + + def min_pit_sample(self, loss): + """Compute the PIT loss for each sample. + + :param 1-D torch.Tensor loss: list of losses for one sample, + including [h1r1, h1r2, h2r1, h2r2] or + [h1r1, h1r2, h1r3, h2r1, h2r2, h2r3, h3r1, h3r2, h3r3] + :return minimum loss of best permutation + :rtype torch.Tensor (1) + :return the best permutation + :rtype List: len=2 + + """ + score_perms = ( + torch.stack( + [torch.sum(loss[loss_perm_idx]) for loss_perm_idx in self.loss_perm_idx] + ) + / self.num_spkrs + ) + perm_loss, min_idx = torch.min(score_perms, 0) + permutation = self.perm_choices[min_idx] + return perm_loss, permutation + + def pit_process(self, losses): + """Compute the PIT loss for a batch. 
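+
+        Self-contained sketch with made-up numbers: for two speakers the search in
+        ``min_pit_sample`` above is equivalent to brute-forcing all speaker
+        permutations of the pairwise losses and keeping the cheapest assignment::
+
+            from itertools import permutations
+
+            import torch
+
+            # pair_loss[h, r] = loss of hypothesis stream h against reference r
+            pair_loss = torch.tensor([[0.2, 1.5],
+                                      [1.3, 0.4]])
+            num_spkrs = 2
+            best_loss, best_perm = min(
+                (float(sum(pair_loss[h, r] for h, r in enumerate(perm))) / num_spkrs, perm)
+                for perm in permutations(range(num_spkrs))
+            )
+            # best_loss ~= 0.3, best_perm == (0, 1): hyp 0 -> ref 0, hyp 1 -> ref 1
+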
+ + :param torch.Tensor losses: losses (B, 1|4|9) + :return minimum losses of a batch with best permutation + :rtype torch.Tensor (B) + :return the best permutation + :rtype torch.LongTensor (B, 1|2|3) + + """ + bs = losses.size(0) + ret = [self.min_pit_sample(losses[i]) for i in range(bs)] + + loss_perm = torch.stack([r[0] for r in ret], dim=0).to(losses.device) # (B) + permutation = torch.tensor([r[1] for r in ret]).long().to(losses.device) + return torch.mean(loss_perm), permutation + + def permutationDFS(self, source, start): + """Get permutations with DFS. + + The final result is all permutations of the 'source' sequence. + e.g. [[1, 2], [2, 1]] or + [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 2, 1], [3, 1, 2]] + + :param np.ndarray source: (num_spkrs, 1), e.g. [1, 2, ..., N] + :param int start: the start point to permute + + """ + if start == len(source) - 1: # reach final state + self.perm_choices.append(source.tolist()) + for i in range(start, len(source)): + # swap values at position start and i + source[start], source[i] = source[i], source[start] + self.permutationDFS(source, start + 1) + # reverse the swap + source[start], source[i] = source[i], source[start] + + +class E2E(ASRInterface, torch.nn.Module): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2EASR.encoder_add_arguments(parser) + E2E.encoder_mix_add_arguments(parser) + E2EASR.attention_add_arguments(parser) + E2EASR.decoder_add_arguments(parser) + return parser + + @staticmethod + def encoder_mix_add_arguments(parser): + """Add arguments for multi-speaker encoder.""" + group = parser.add_argument_group("E2E encoder setting for multi-speaker") + # asr-mix encoder + group.add_argument( + "--spa", + action="store_true", + help="Enable speaker parallel attention " + "for multi-speaker speech recognition task.", + ) + group.add_argument( + "--elayers-sd", + default=4, + type=int, + help="Number of speaker differentiate encoder layers" + "for multi-speaker speech recognition task.", + ) + return parser + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.enc.conv_subsampling_factor * int(np.prod(self.subsample)) + + def __init__(self, idim, odim, args): + """Initialize multi-speaker E2E module.""" + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + self.mtlalpha = args.mtlalpha + assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" + self.etype = args.etype + self.verbose = args.verbose + # NOTE: for self.build method + args.char_list = getattr(args, "char_list", None) + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + self.num_spkrs = args.num_spkrs + self.spa = args.spa + self.pit = PIT(self.num_spkrs) + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + + # subsample info + self.subsample = get_subsample(args, mode="asr", arch="rnn_mix") + + # label smoothing info + if args.lsm_type and os.path.isfile(args.train_json): + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + if getattr(args, "use_frontend", False): # use getattr to keep compatibility + 
self.frontend = frontend_for(args, idim) + self.feature_transform = feature_transform_for(args, (idim - 1) * 2) + idim = args.n_mels + else: + self.frontend = None + + # encoder + self.enc = encoder_for(args, idim, self.subsample) + # ctc + self.ctc = ctc_for(args, odim, reduce=False) + # attention + num_att = self.num_spkrs if args.spa else 1 + self.att = att_for(args, num_att) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + # weight initialization + self.init_like_chainer() + + # options for beam search + if "report_cer" in vars(args) and (args.report_cer or args.report_wer): + recog_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": args.ctc_weight, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + } + + self.recog_args = argparse.Namespace(**recog_args) + self.report_cer = args.report_cer + self.report_wer = args.report_wer + else: + self.report_cer = False + self.report_wer = False + self.rnnlm = None + + self.logzero = -10000000000.0 + self.loss = None + self.acc = None + + def init_like_chainer(self): + """Initialize weight like chainer. + + chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 + pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) + + however, there are two exceptions as far as I know. + - EmbedID.W ~ Normal(0, 1) + - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) + """ + lecun_normal_init_parameters(self) + # exceptions + # embed weight ~ Normal(0, 1) + self.dec.embed.weight.data.normal_(0, 1) + # forget-bias = 1.0 + # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745 + for i in six.moves.range(len(self.dec.decoder)): + set_forget_bias_to_one(self.dec.decoder[i].bias_ih) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, num_spkrs, Lmax) + :return: ctc loss value + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + if isinstance(hs_pad, list): + hlens_n = [None] * self.num_spkrs + for i in range(self.num_spkrs): + hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens) + hlens = hlens_n + else: + hs_pad, hlens = self.feature_transform(hs_pad, hlens) + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + if not isinstance( + hs_pad, list + ): # single-channel input xs_pad (single- or multi-speaker) + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + else: # multi-channel multi-speaker input xs_pad + for i in range(self.num_spkrs): + hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) + + # 2. 
CTC loss + if self.mtlalpha == 0: + loss_ctc, min_perm = None, None + else: + if not isinstance(hs_pad, list): # single-speaker input xs_pad + loss_ctc = torch.mean(self.ctc(hs_pad, hlens, ys_pad)) + else: # multi-speaker input xs_pad + ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax) + loss_ctc_perm = torch.stack( + [ + self.ctc( + hs_pad[i // self.num_spkrs], + hlens[i // self.num_spkrs], + ys_pad[i % self.num_spkrs], + ) + for i in range(self.num_spkrs ** 2) + ], + dim=1, + ) # (B, num_spkrs^2) + loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm) + logging.info("ctc loss:" + str(float(loss_ctc))) + + # 3. attention loss + if self.mtlalpha == 1: + loss_att = None + acc = None + else: + if not isinstance(hs_pad, list): # single-speaker input xs_pad + loss_att, acc, _ = self.dec(hs_pad, hlens, ys_pad) + else: + for i in range(ys_pad.size(1)): # B + ys_pad[:, i] = ys_pad[min_perm[i], i] + rslt = [ + self.dec(hs_pad[i], hlens[i], ys_pad[i], strm_idx=i) + for i in range(self.num_spkrs) + ] + loss_att = sum([r[0] for r in rslt]) / float(len(rslt)) + acc = sum([r[1] for r in rslt]) / float(len(rslt)) + self.acc = acc + + # 4. compute cer without beam search + if self.mtlalpha == 0 or self.char_list is None: + cer_ctc = None + else: + cers = [] + for ns in range(self.num_spkrs): + y_hats = self.ctc.argmax(hs_pad[ns]).data + for i, y in enumerate(y_hats): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[ns][i] + + seq_hat = [ + self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 + ] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + if len(ref_chars) > 0: + cers.append( + editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) + ) + + cer_ctc = sum(cers) / len(cers) if cers else None + + # 5. 
compute cer/wer + if ( + self.training + or not (self.report_cer or self.report_wer) + or not isinstance(hs_pad, list) + ): + cer, wer = 0.0, 0.0 + else: + if self.recog_args.ctc_weight > 0.0: + lpz = [ + self.ctc.log_softmax(hs_pad[i]).data for i in range(self.num_spkrs) + ] + else: + lpz = None + + word_eds, char_eds, word_ref_lens, char_ref_lens = [], [], [], [] + nbest_hyps = [ + self.dec.recognize_beam_batch( + hs_pad[i], + torch.tensor(hlens[i]), + lpz[i], + self.recog_args, + self.char_list, + self.rnnlm, + strm_idx=i, + ) + for i in range(self.num_spkrs) + ] + # remove and + y_hats = [ + [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps[i]] + for i in range(self.num_spkrs) + ] + for i in range(len(y_hats[0])): + hyp_words = [] + hyp_chars = [] + ref_words = [] + ref_chars = [] + for ns in range(self.num_spkrs): + y_hat = y_hats[ns][i] + y_true = ys_pad[ns][i] + + seq_hat = [ + self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 + ] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") + seq_true_text = "".join(seq_true).replace( + self.recog_args.space, " " + ) + + hyp_words.append(seq_hat_text.split()) + ref_words.append(seq_true_text.split()) + hyp_chars.append(seq_hat_text.replace(" ", "")) + ref_chars.append(seq_true_text.replace(" ", "")) + + tmp_word_ed = [ + editdistance.eval( + hyp_words[ns // self.num_spkrs], ref_words[ns % self.num_spkrs] + ) + for ns in range(self.num_spkrs ** 2) + ] # h1r1,h1r2,h2r1,h2r2 + tmp_char_ed = [ + editdistance.eval( + hyp_chars[ns // self.num_spkrs], ref_chars[ns % self.num_spkrs] + ) + for ns in range(self.num_spkrs ** 2) + ] # h1r1,h1r2,h2r1,h2r2 + + word_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_word_ed))[0]) + word_ref_lens.append(len(sum(ref_words, []))) + char_eds.append(self.pit.min_pit_sample(torch.tensor(tmp_char_ed))[0]) + char_ref_lens.append(len("".join(ref_chars))) + + wer = ( + 0.0 + if not self.report_wer + else float(sum(word_eds)) / sum(word_ref_lens) + ) + cer = ( + 0.0 + if not self.report_cer + else float(sum(char_eds)) / sum(char_ref_lens) + ) + + alpha = self.mtlalpha + if alpha == 0: + self.loss = loss_att + loss_att_data = float(loss_att) + loss_ctc_data = None + elif alpha == 1: + self.loss = loss_ctc + loss_att_data = None + loss_ctc_data = float(loss_ctc) + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + loss_att_data = float(loss_att) + loss_ctc_data = float(loss_ctc) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def recognize(self, x, recog_args, char_list, rnnlm=None): + """E2E beam search. + + :param ndarray x: input acoustic feature (T, D) + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + ilens = [x.shape[0]] + + # subsample frame + x = x[:: self.subsample[0], :] + h = to_device(self, to_torch_tensor(x).float()) + # make a utt list (1) to use the same interface for encoder + hs = h.contiguous().unsqueeze(0) + + # 0. 
Frontend + if self.frontend is not None: + hs, hlens, mask = self.frontend(hs, ilens) + hlens_n = [None] * self.num_spkrs + for i in range(self.num_spkrs): + hs[i], hlens_n[i] = self.feature_transform(hs[i], hlens) + hlens = hlens_n + else: + hs, hlens = hs, ilens + + # 1. Encoder + if not isinstance(hs, list): # single-channel multi-speaker input x + hs, hlens, _ = self.enc(hs, hlens) + else: # multi-channel multi-speaker input x + for i in range(self.num_spkrs): + hs[i], hlens[i], _ = self.enc(hs[i], hlens[i]) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = [self.ctc.log_softmax(i)[0] for i in hs] + else: + lpz = None + + # 2. decoder + # decode the first utterance + y = [ + self.dec.recognize_beam( + hs[i][0], lpz[i], recog_args, char_list, rnnlm, strm_idx=i + ) + for i in range(self.num_spkrs) + ] + + if prev: + self.train() + return y + + def recognize_batch(self, xs, recog_args, char_list, rnnlm=None): + """E2E beam search. + + :param ndarray xs: input acoustic feature (T, D) + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[:: self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(xs_pad, ilens) + hlens_n = [None] * self.num_spkrs + for i in range(self.num_spkrs): + hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens) + hlens = hlens_n + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + if not isinstance(hs_pad, list): # single-channel multi-speaker input x + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + else: # multi-channel multi-speaker input x + for i in range(self.num_spkrs): + hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + lpz = [self.ctc.log_softmax(hs_pad[i]) for i in range(self.num_spkrs)] + normalize_score = False + else: + lpz = None + normalize_score = True + + # 2. decoder + y = [ + self.dec.recognize_beam_batch( + hs_pad[i], + hlens[i], + lpz[i], + recog_args, + char_list, + rnnlm, + normalize_score=normalize_score, + strm_idx=i, + ) + for i in range(self.num_spkrs) + ] + + if prev: + self.train() + return y + + def enhance(self, xs): + """Forward only the frontend stage. + + :param ndarray xs: input acoustic feature (T, C, F) + """ + if self.frontend is None: + raise RuntimeError("Frontend doesn't exist") + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[:: self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + enhanced, hlensm, mask = self.frontend(xs_pad, ilens) + if prev: + self.train() + + if isinstance(enhanced, (tuple, list)): + enhanced = list(enhanced) + mask = list(mask) + for idx in range(len(enhanced)): # number of speakers + enhanced[idx] = enhanced[idx].cpu().numpy() + mask[idx] = mask[idx].cpu().numpy() + return enhanced, mask, ilens + return enhanced.cpu().numpy(), mask.cpu().numpy(), ilens + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. 
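+
+        Toy-sized sketch (made-up shapes): the per-sample label permutation applied
+        after PIT in ``forward`` above simply reorders the speaker axis of the
+        padded references::
+
+            import torch
+
+            # ys: (num_spkrs, B, Lmax) references, min_perm: (B, num_spkrs) best permutations
+            ys = torch.arange(2 * 3 * 4).view(2, 3, 4)
+            min_perm = torch.tensor([[0, 1], [1, 0], [1, 0]])
+            for b in range(ys.size(1)):
+                ys[:, b] = ys[min_perm[b], b]
+            # batch items 1 and 2 now have their two speakers' references swapped
+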
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, num_spkrs, Lmax) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). + :rtype: float ndarray + """ + with torch.no_grad(): + # 0. Frontend + if self.frontend is not None: + hs_pad, hlens, mask = self.frontend(to_torch_tensor(xs_pad), ilens) + hlens_n = [None] * self.num_spkrs + for i in range(self.num_spkrs): + hs_pad[i], hlens_n[i] = self.feature_transform(hs_pad[i], hlens) + hlens = hlens_n + else: + hs_pad, hlens = xs_pad, ilens + + # 1. Encoder + if not isinstance(hs_pad, list): # single-channel multi-speaker input x + hs_pad, hlens, _ = self.enc(hs_pad, hlens) + else: # multi-channel multi-speaker input x + for i in range(self.num_spkrs): + hs_pad[i], hlens[i], _ = self.enc(hs_pad[i], hlens[i]) + + # Permutation + ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax) + if self.num_spkrs <= 3: + loss_ctc = torch.stack( + [ + self.ctc( + hs_pad[i // self.num_spkrs], + hlens[i // self.num_spkrs], + ys_pad[i % self.num_spkrs], + ) + for i in range(self.num_spkrs ** 2) + ], + 1, + ) # (B, num_spkrs^2) + loss_ctc, min_perm = self.pit.pit_process(loss_ctc) + for i in range(ys_pad.size(1)): # B + ys_pad[:, i] = ys_pad[min_perm[i], i] + + # 2. Decoder + att_ws = [ + self.dec.calculate_all_attentions( + hs_pad[i], hlens[i], ys_pad[i], strm_idx=i + ) + for i in range(self.num_spkrs) + ] + + return att_ws + + +class EncoderMix(torch.nn.Module): + """Encoder module for the case of multi-speaker mixture speech. + + :param str etype: type of encoder network + :param int idim: number of dimensions of encoder network + :param int elayers_sd: + number of layers of speaker differentiate part in encoder network + :param int elayers_rec: + number of layers of shared recognition part in encoder network + :param int eunits: number of lstm units of encoder network + :param int eprojs: number of projection units of encoder network + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param int in_channel: number of input channels + :param int num_spkrs: number of number of speakers + """ + + def __init__( + self, + etype, + idim, + elayers_sd, + elayers_rec, + eunits, + eprojs, + subsample, + dropout, + num_spkrs=2, + in_channel=1, + ): + """Initialize the encoder of single-channel multi-speaker ASR.""" + super(EncoderMix, self).__init__() + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ["lstm", "gru", "blstm", "bgru"]: + logging.error("Error: need to specify an appropriate encoder architecture") + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc_mix = torch.nn.ModuleList([VGG2L(in_channel)]) + self.enc_sd = torch.nn.ModuleList( + [ + torch.nn.ModuleList( + [ + RNNP( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers_sd, + eunits, + eprojs, + subsample[: elayers_sd + 1], + dropout, + typ=typ, + ) + ] + ) + for i in range(num_spkrs) + ] + ) + self.enc_rec = torch.nn.ModuleList( + [ + RNNP( + eprojs, + elayers_rec, + eunits, + eprojs, + subsample[elayers_sd:], + dropout, + typ=typ, + ) + ] + ) + logging.info("Use CNN-VGG + B" + typ.upper() + "P for encoder") + else: + logging.error( + f"Error: need to specify an appropriate encoder architecture. 
" + f"Illegal name {etype}" + ) + sys.exit() + else: + logging.error( + f"Error: need to specify an appropriate encoder architecture. " + f"Illegal name {etype}" + ) + sys.exit() + + self.num_spkrs = num_spkrs + + def forward(self, xs_pad, ilens): + """Encodermix forward. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :return: list: batch of hidden state sequences [num_spkrs x (B, Tmax, eprojs)] + :rtype: torch.Tensor + """ + # mixture encoder + for module in self.enc_mix: + xs_pad, ilens, _ = module(xs_pad, ilens) + + # SD and Rec encoder + xs_pad_sd = [xs_pad for i in range(self.num_spkrs)] + ilens_sd = [ilens for i in range(self.num_spkrs)] + for ns in range(self.num_spkrs): + # Encoder_SD: speaker differentiate encoder + for module in self.enc_sd[ns]: + xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns]) + # Encoder_Rec: recognition encoder + for module in self.enc_rec: + xs_pad_sd[ns], ilens_sd[ns], _ = module(xs_pad_sd[ns], ilens_sd[ns]) + + # make mask to remove bias value in padded part + mask = to_device(xs_pad, make_pad_mask(ilens_sd[0]).unsqueeze(-1)) + + return [x.masked_fill(mask, 0.0) for x in xs_pad_sd], ilens_sd, None + + +def encoder_for(args, idim, subsample): + """Construct the encoder.""" + if getattr(args, "use_frontend", False): # use getattr to keep compatibility + # with frontend, the mixed speech are separated as streams for each speaker + return encoder_for_single(args, idim, subsample) + else: + return EncoderMix( + args.etype, + idim, + args.elayers_sd, + args.elayers, + args.eunits, + args.eprojs, + subsample, + args.dropout_rate, + args.num_spkrs, + ) diff --git a/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..176c504966137da475caf0928983b16a949f5c41 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_mix_transformer.py @@ -0,0 +1,462 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +# Copyright 2020 Johns Hopkins University (Xuankai Chang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +Transformer speech recognition model for single-channel multi-speaker mixture speech. + +It is a fusion of `e2e_asr_mix.py` and `e2e_asr_transformer.py`. Refer to: + https://arxiv.org/pdf/2002.03921.pdf +1. The Transformer-based Encoder now consists of three stages: + (a): Enc_mix: encoding input mixture speech; + (b): Enc_SD: separating mixed speech representations; + (c): Enc_rec: transforming each separated speech representation. +2. PIT is used in CTC to determine the permutation with minimum loss. 
+""" + +from argparse import Namespace +import logging +import math + +import numpy +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.pytorch_backend.ctc import CTC +from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD +from espnet.nets.pytorch_backend.e2e_asr_mix import E2E as E2EASRMIX +from espnet.nets.pytorch_backend.e2e_asr_mix import PIT +from espnet.nets.pytorch_backend.e2e_asr_transformer import E2E as E2EASR +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.encoder_mix import EncoderMix +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.mask import target_mask + + +class E2E(E2EASR, ASRInterface, torch.nn.Module): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2EASR.add_arguments(parser) + E2EASRMIX.encoder_mix_add_arguments(parser) + return parser + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super(E2E, self).__init__(idim, odim, args, ignore_id=-1) + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + self.encoder = EncoderMix( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks_sd=args.elayers_sd, + num_blocks_rec=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + num_spkrs=args.num_spkrs, + ) + + if args.mtlalpha > 0.0: + self.ctc = CTC( + odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=False + ) + else: + self.ctc = None + + self.num_spkrs = args.num_spkrs + self.pit = PIT(self.num_spkrs) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of source sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences + (B, num_spkrs, Lmax) + :return: ctc loass value + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 1. forward encoder + xs_pad = xs_pad[:, : max(ilens)] # for data parallel + src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) # list: speaker differentiate + self.hs_pad = hs_pad + + # 2. 
ctc + # TODO(karita) show predicted text + # TODO(karita) calculate these stats + cer_ctc = None + assert self.mtlalpha > 0.0 + batch_size = xs_pad.size(0) + ys_pad = ys_pad.transpose(0, 1) # (num_spkrs, B, Lmax) + hs_len = [hs_mask[i].view(batch_size, -1).sum(1) for i in range(self.num_spkrs)] + loss_ctc_perm = torch.stack( + [ + self.ctc( + hs_pad[i // self.num_spkrs].view(batch_size, -1, self.adim), + hs_len[i // self.num_spkrs], + ys_pad[i % self.num_spkrs], + ) + for i in range(self.num_spkrs ** 2) + ], + dim=1, + ) # (B, num_spkrs^2) + loss_ctc, min_perm = self.pit.pit_process(loss_ctc_perm) + logging.info("ctc loss:" + str(float(loss_ctc))) + + # Permute the labels according to loss + for b in range(batch_size): # B + ys_pad[:, b] = ys_pad[min_perm[b], b] # (num_spkrs, B, Lmax) + ys_out_len = [ + float(torch.sum(ys_pad[i] != self.ignore_id)) for i in range(self.num_spkrs) + ] + + # TODO(karita) show predicted text + # TODO(karita) calculate these stats + if self.error_calculator is not None: + cer_ctc = [] + for i in range(self.num_spkrs): + ys_hat = self.ctc.argmax(hs_pad[i].view(batch_size, -1, self.adim)).data + cer_ctc.append( + self.error_calculator(ys_hat.cpu(), ys_pad[i].cpu(), is_ctc=True) + ) + cer_ctc = sum(map(lambda x: x[0] * x[1], zip(cer_ctc, ys_out_len))) / sum( + ys_out_len + ) + else: + cer_ctc = None + + # 3. forward decoder + if self.mtlalpha == 1.0: + loss_att, self.acc, cer, wer = None, None, None, None + else: + pred_pad, pred_mask = [None] * self.num_spkrs, [None] * self.num_spkrs + loss_att, acc = [None] * self.num_spkrs, [None] * self.num_spkrs + for i in range(self.num_spkrs): + ( + pred_pad[i], + pred_mask[i], + loss_att[i], + acc[i], + ) = self.decoder_and_attention( + hs_pad[i], hs_mask[i], ys_pad[i], batch_size + ) + + # 4. compute attention loss + # The following is just an approximation + loss_att = sum(map(lambda x: x[0] * x[1], zip(loss_att, ys_out_len))) / sum( + ys_out_len + ) + self.acc = sum(map(lambda x: x[0] * x[1], zip(acc, ys_out_len))) / sum( + ys_out_len + ) + + # 5. compute cer/wer + if self.training or self.error_calculator is None: + cer, wer = None, None + else: + ys_hat = pred_pad.argmax(dim=-1) + cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + # copyied from e2e_asr + alpha = self.mtlalpha + if alpha == 0: + self.loss = loss_att + loss_att_data = float(loss_att) + loss_ctc_data = None + elif alpha == 1: + self.loss = loss_ctc + loss_att_data = None + loss_ctc_data = float(loss_ctc) + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + loss_att_data = float(loss_att) + loss_ctc_data = float(loss_ctc) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def decoder_and_attention(self, hs_pad, hs_mask, ys_pad, batch_size): + """Forward decoder and attention loss.""" + # forward decoder + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_mask = target_mask(ys_in_pad, self.ignore_id) + pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + + # compute attention loss + loss_att = self.criterion(pred_pad, ys_out_pad) + acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + return pred_pad, pred_mask, loss_att, acc + + def encode(self, x): + """Encode acoustic features. 
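+
+        Illustrative sketch with made-up numbers: the per-speaker attention losses
+        in ``forward`` above are combined by weighting each one with its number of
+        (non-ignored) target tokens::
+
+            loss_att = [1.2, 0.8]            # per-speaker losses
+            ys_out_len = [40.0, 25.0]        # target tokens per speaker
+            weighted = sum(l * n for l, n in zip(loss_att, ys_out_len)) / sum(ys_out_len)
+            # (1.2 * 40 + 0.8 * 25) / 65 ~= 1.05
+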
+ + :param ndarray x: source acoustic feature (T, D) + :return: encoder outputs + :rtype: torch.Tensor + """ + self.eval() + x = torch.as_tensor(x).unsqueeze(0) + enc_output, _ = self.encoder(x, None) + return enc_output + + def recog(self, enc_output, recog_args, char_list=None, rnnlm=None, use_jit=False): + """Recognize input speech of each speaker. + + :param ndnarray enc_output: encoder outputs (B, T, D) or (T, D) + :param Namespace recog_args: argment Namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + if recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(enc_output) + lpz = lpz.squeeze(0) + else: + lpz = None + + h = enc_output.squeeze(0) + + logging.info("input lengths: " + str(h.size(0))) + # search parms + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = recog_args.ctc_weight + + # preprare sos + y = self.sos + vy = h.new_zeros(1).long() + + if recog_args.maxlenratio == 0: + maxlen = h.shape[0] + else: + # maxlen >= 1 + maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) + minlen = int(recog_args.minlenratio * h.size(0)) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + if rnnlm: + hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} + else: + hyp = {"score": 0.0, "yseq": [y]} + if lpz is not None: + ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) + hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() + hyp["ctc_score_prev"] = 0.0 + if ctc_weight != 1.0: + # pre-pruning based on attention scores + ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) + else: + ctc_beam = lpz.shape[-1] + hyps = [hyp] + ended_hyps = [] + + import six + + traced_decoder = None + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + hyps_best_kept = [] + for hyp in hyps: + vy[0] = hyp["yseq"][i] + + # get nbest local scores and their ids + ys_mask = subsequent_mask(i + 1).unsqueeze(0) + ys = torch.tensor(hyp["yseq"]).unsqueeze(0) + # FIXME: jit does not match non-jit result + if use_jit: + if traced_decoder is None: + traced_decoder = torch.jit.trace( + self.decoder.forward_one_step, (ys, ys_mask, enc_output) + ) + local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] + else: + local_att_scores = self.decoder.forward_one_step( + ys, ys_mask, enc_output + )[0] + + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy) + local_scores = ( + local_att_scores + recog_args.lm_weight * local_lm_scores + ) + else: + local_scores = local_att_scores + + if lpz is not None: + local_best_scores, local_best_ids = torch.topk( + local_att_scores, ctc_beam, dim=1 + ) + ctc_scores, ctc_states = ctc_prefix_score( + hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"] + ) + local_scores = (1.0 - ctc_weight) * local_att_scores[ + :, local_best_ids[0] + ] + ctc_weight * torch.from_numpy( + ctc_scores - hyp["ctc_score_prev"] + ) + if rnnlm: + local_scores += ( + recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] + ) + local_best_scores, joint_best_ids = torch.topk( + local_scores, beam, dim=1 + ) + local_best_ids = local_best_ids[:, joint_best_ids[0]] + else: + local_best_scores, local_best_ids = torch.topk( + local_scores, beam, dim=1 + ) + + for j in six.moves.range(beam): + new_hyp = {} + new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) + 
new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) + if rnnlm: + new_hyp["rnnlm_prev"] = rnnlm_state + if lpz is not None: + new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]] + new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]] + # will be (2 x beam) hyps at most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypothes: " + str(len(hyps))) + if char_list is not None: + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last postion in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypothes to a final list, and removed them from current hypothes + # (this will be a probmlem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + if rnnlm: # Word LM needs to add final score + hyp["score"] += recog_args.lm_weight * rnnlm.final( + hyp["rnnlm_prev"] + ) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + + if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remeined hypothes: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + + if char_list is not None: + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypothes: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), recog_args.nbest) + ] + + # check number of hypotheis + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + # should copy becasuse Namespace will be overwritten globally + recog_args = Namespace(**vars(recog_args)) + recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) + return self.recog(enc_output, recog_args, char_list, rnnlm) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + return nbest_hyps + + def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): + """Recognize input speech of each speaker. 
+ + :param ndnarray x: input acoustic feature (B, T, D) or (T, D) + :param Namespace recog_args: argment Namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + # Encoder + enc_output = self.encode(x) + + # Decoder + nbest_hyps = [] + for enc_out in enc_output: + nbest_hyps.append( + self.recog(enc_out, recog_args, char_list, rnnlm, use_jit) + ) + return nbest_hyps diff --git a/espnet/nets/pytorch_backend/e2e_asr_mulenc.py b/espnet/nets/pytorch_backend/e2e_asr_mulenc.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4d2d70314cb083af7d745a0601f90c0bf1134c --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_mulenc.py @@ -0,0 +1,887 @@ +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Copyright 2017 Johns Hopkins University (Ruizhi Li) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Define e2e module for multi-encoder network. https://arxiv.org/pdf/1811.04903.pdf.""" + +import argparse +from itertools import groupby +import logging +import math +import os + +import chainer +from chainer import reporter +import editdistance +import numpy as np +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.pytorch_backend.ctc import ctc_for +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders import decoder_for +from espnet.nets.pytorch_backend.rnn.encoders import Encoder +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.utils.cli_utils import strtobool + +CTC_LOSS_THRESHOLD = 10000 + + +class Reporter(chainer.Chain): + """Define a chainer reporter wrapper.""" + + def report(self, loss_ctc_list, loss_att, acc, cer_ctc_list, cer, wer, mtl_loss): + """Define a chainer reporter function.""" + # loss_ctc_list = [weighted CTC, CTC1, CTC2, ... CTCN] + # cer_ctc_list = [weighted cer_ctc, cer_ctc_1, cer_ctc_2, ... cer_ctc_N] + num_encs = len(loss_ctc_list) - 1 + reporter.report({"loss_ctc": loss_ctc_list[0]}, self) + for i in range(num_encs): + reporter.report({"loss_ctc{}".format(i + 1): loss_ctc_list[i + 1]}, self) + reporter.report({"loss_att": loss_att}, self) + reporter.report({"acc": acc}, self) + reporter.report({"cer_ctc": cer_ctc_list[0]}, self) + for i in range(num_encs): + reporter.report({"cer_ctc{}".format(i + 1): cer_ctc_list[i + 1]}, self) + reporter.report({"cer": cer}, self) + reporter.report({"wer": wer}, self) + logging.info("mtl loss:" + str(mtl_loss)) + reporter.report({"loss": mtl_loss}, self) + + +class E2E(ASRInterface, torch.nn.Module): + """E2E module. 
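+
+    Illustrative sketch (made-up weights; the exact combination used during
+    training is defined elsewhere in this file): the ``Reporter`` above logs a
+    weighted CTC loss followed by the per-encoder CTC losses, and the weighting
+    is essentially a weighted sum::
+
+        import torch
+
+        loss_ctc = [torch.tensor(2.0), torch.tensor(3.0)]  # one CTC loss per encoder
+        weights = [0.7, 0.3]                               # cf. --weights-ctc-train
+        weighted = sum(w * l for w, l in zip(weights, loss_ctc))
+        # reported list: [weighted, CTC1, CTC2] = [2.3, 2.0, 3.0]
+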
+ + :param List idims: List of dimensions of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments for multi-encoder setting.""" + E2E.encoder_add_arguments(parser) + E2E.attention_add_arguments(parser) + E2E.decoder_add_arguments(parser) + E2E.ctc_add_arguments(parser) + return parser + + @staticmethod + def encoder_add_arguments(parser): + """Add arguments for encoders in multi-encoder setting.""" + group = parser.add_argument_group("E2E encoder setting") + group.add_argument( + "--etype", + action="append", + type=str, + choices=[ + "lstm", + "blstm", + "lstmp", + "blstmp", + "vgglstmp", + "vggblstmp", + "vgglstm", + "vggblstm", + "gru", + "bgru", + "grup", + "bgrup", + "vgggrup", + "vggbgrup", + "vgggru", + "vggbgru", + ], + help="Type of encoder network architecture", + ) + group.add_argument( + "--elayers", + type=int, + action="append", + help="Number of encoder layers " + "(for shared recognition part in multi-speaker asr mode)", + ) + group.add_argument( + "--eunits", + "-u", + type=int, + action="append", + help="Number of encoder hidden units", + ) + group.add_argument( + "--eprojs", default=320, type=int, help="Number of encoder projection units" + ) + group.add_argument( + "--subsample", + type=str, + action="append", + help="Subsample input frames x_y_z means " + "subsample every x frame at 1st layer, " + "every y frame at 2nd layer etc.", + ) + return parser + + @staticmethod + def attention_add_arguments(parser): + """Add arguments for attentions in multi-encoder setting.""" + group = parser.add_argument_group("E2E attention setting") + # attention + group.add_argument( + "--atype", + type=str, + action="append", + choices=[ + "noatt", + "dot", + "add", + "location", + "coverage", + "coverage_location", + "location2d", + "location_recurrent", + "multi_head_dot", + "multi_head_add", + "multi_head_loc", + "multi_head_multi_res_loc", + ], + help="Type of attention architecture", + ) + group.add_argument( + "--adim", + type=int, + action="append", + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--awin", + type=int, + action="append", + help="Window size for location2d attention", + ) + group.add_argument( + "--aheads", + type=int, + action="append", + help="Number of heads for multi head attention", + ) + group.add_argument( + "--aconv-chans", + type=int, + action="append", + help="Number of attention convolution channels \ + (negative value indicates no location-aware attention)", + ) + group.add_argument( + "--aconv-filts", + type=int, + action="append", + help="Number of attention convolution filters \ + (negative value indicates no location-aware attention)", + ) + group.add_argument( + "--dropout-rate", + type=float, + action="append", + help="Dropout rate for the encoder", + ) + # hierarchical attention network (HAN) + group.add_argument( + "--han-type", + default="dot", + type=str, + choices=[ + "noatt", + "dot", + "add", + "location", + "coverage", + "coverage_location", + "location2d", + "location_recurrent", + "multi_head_dot", + "multi_head_add", + "multi_head_loc", + "multi_head_multi_res_loc", + ], + help="Type of attention architecture (multi-encoder asr mode only)", + ) + group.add_argument( + "--han-dim", + default=320, + type=int, + help="Number of attention transformation dimensions in HAN", + ) + group.add_argument( + "--han-win", + default=5, + type=int, + help="Window size for location2d 
attention in HAN", + ) + group.add_argument( + "--han-heads", + default=4, + type=int, + help="Number of heads for multi head attention in HAN", + ) + group.add_argument( + "--han-conv-chans", + default=-1, + type=int, + help="Number of attention convolution channels in HAN \ + (negative value indicates no location-aware attention)", + ) + group.add_argument( + "--han-conv-filts", + default=100, + type=int, + help="Number of attention convolution filters in HAN \ + (negative value indicates no location-aware attention)", + ) + return parser + + @staticmethod + def decoder_add_arguments(parser): + """Add arguments for decoder in multi-encoder setting.""" + group = parser.add_argument_group("E2E decoder setting") + group.add_argument( + "--dtype", + default="lstm", + type=str, + choices=["lstm", "gru"], + help="Type of decoder network architecture", + ) + group.add_argument( + "--dlayers", default=1, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=320, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--dropout-rate-decoder", + default=0.0, + type=float, + help="Dropout rate for the decoder", + ) + group.add_argument( + "--sampling-probability", + default=0.0, + type=float, + help="Ratio of predicted labels fed back to decoder", + ) + group.add_argument( + "--lsm-type", + const="", + default="", + type=str, + nargs="?", + choices=["", "unigram"], + help="Apply label smoothing with a specified distribution type", + ) + return parser + + @staticmethod + def ctc_add_arguments(parser): + """Add arguments for ctc in multi-encoder setting.""" + group = parser.add_argument_group("E2E multi-ctc setting") + group.add_argument( + "--share-ctc", + type=strtobool, + default=False, + help="The flag to switch to share ctc across multiple encoders " + "(multi-encoder asr mode only).", + ) + group.add_argument( + "--weights-ctc-train", + type=float, + action="append", + help="ctc weight assigned to each encoder during training.", + ) + group.add_argument( + "--weights-ctc-dec", + type=float, + action="append", + help="ctc weight assigned to each encoder during decoding.", + ) + return parser + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + if isinstance(self.enc, Encoder): + return self.enc.conv_subsampling_factor * int( + np.prod(self.subsample_list[0]) + ) + else: + return self.enc[0].conv_subsampling_factor * int( + np.prod(self.subsample_list[0]) + ) + + def __init__(self, idims, odim, args): + """Initialize this class with python-level args. + + Args: + idims (list): list of the number of an input feature dim. + odim (int): The number of output vocab. 
+ args (Namespace): arguments + + """ + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + self.mtlalpha = args.mtlalpha + assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" + self.verbose = args.verbose + # NOTE: for self.build method + args.char_list = getattr(args, "char_list", None) + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + self.num_encs = args.num_encs + self.share_ctc = args.share_ctc + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + + # subsample info + self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc") + + # label smoothing info + if args.lsm_type and os.path.isfile(args.train_json): + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + # speech translation related + self.replace_sos = getattr( + args, "replace_sos", False + ) # use getattr to keep compatibility + + self.frontend = None + + # encoder + self.enc = encoder_for(args, idims, self.subsample_list) + # ctc + self.ctc = ctc_for(args, odim) + # attention + self.att = att_for(args) + # hierarchical attention network + han = att_for(args, han_mode=True) + self.att.append(han) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + if args.mtlalpha > 0 and self.num_encs > 1: + # weights-ctc, + # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss + self.weights_ctc_train = args.weights_ctc_train / np.sum( + args.weights_ctc_train + ) # normalize + self.weights_ctc_dec = args.weights_ctc_dec / np.sum( + args.weights_ctc_dec + ) # normalize + logging.info( + "ctc weights (training during training): " + + " ".join([str(x) for x in self.weights_ctc_train]) + ) + logging.info( + "ctc weights (decoding during training): " + + " ".join([str(x) for x in self.weights_ctc_dec]) + ) + else: + self.weights_ctc_dec = [1.0] + self.weights_ctc_train = [1.0] + + # weight initialization + self.init_like_chainer() + + # options for beam search + if args.report_cer or args.report_wer: + recog_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": args.ctc_weight, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + "tgt_lang": False, + "ctc_weights_dec": self.weights_ctc_dec, + } + + self.recog_args = argparse.Namespace(**recog_args) + self.report_cer = args.report_cer + self.report_wer = args.report_wer + else: + self.report_cer = False + self.report_wer = False + self.rnnlm = None + + self.logzero = -10000000000.0 + self.loss = None + self.acc = None + + def init_like_chainer(self): + """Initialize weight like chainer. + + chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 + pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) + + however, there are two exceptions as far as I know. 
+ - EmbedID.W ~ Normal(0, 1) + - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) + """ + + def lecun_normal_init_parameters(module): + for p in module.parameters(): + data = p.data + if data.dim() == 1: + # bias + data.zero_() + elif data.dim() == 2: + # linear weight + n = data.size(1) + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() in (3, 4): + # conv weight + n = data.size(1) + for k in data.size()[2:]: + n *= k + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + else: + raise NotImplementedError + + def set_forget_bias_to_one(bias): + n = bias.size(0) + start, end = n // 4, n // 2 + bias.data[start:end].fill_(1.0) + + lecun_normal_init_parameters(self) + # exceptions + # embed weight ~ Normal(0, 1) + self.dec.embed.weight.data.normal_(0, 1) + # forget-bias = 1.0 + # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745 + for i in range(len(self.dec.decoder)): + set_forget_bias_to_one(self.dec.decoder[i].bias_ih) + + def forward(self, xs_pad_list, ilens_list, ys_pad): + """E2E forward. + + :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences + [(B, Tmax_1, idim), (B, Tmax_2, idim),..] + :param List ilens_list: + list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, Lmax) + :return: loss value + :rtype: torch.Tensor + """ + if self.replace_sos: + tgt_lang_ids = ys_pad[:, 0:1] + ys_pad = ys_pad[:, 1:] # remove target language ID in the beginning + else: + tgt_lang_ids = None + + hs_pad_list, hlens_list, self.loss_ctc_list = [], [], [] + for idx in range(self.num_encs): + # 1. Encoder + hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) + + # 2. CTC loss + if self.mtlalpha == 0: + self.loss_ctc_list.append(None) + else: + ctc_idx = 0 if self.share_ctc else idx + loss_ctc = self.ctc[ctc_idx](hs_pad, hlens, ys_pad) + self.loss_ctc_list.append(loss_ctc) + hs_pad_list.append(hs_pad) + hlens_list.append(hlens) + + # 3. attention loss + if self.mtlalpha == 1: + self.loss_att, acc = None, None + else: + self.loss_att, acc, _ = self.dec( + hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids + ) + self.acc = acc + + # 4. compute cer without beam search + if self.mtlalpha == 0 or self.char_list is None: + cer_ctc_list = [None] * (self.num_encs + 1) + else: + cer_ctc_list = [] + for ind in range(self.num_encs): + cers = [] + ctc_idx = 0 if self.share_ctc else ind + y_hats = self.ctc[ctc_idx].argmax(hs_pad_list[ind]).data + for i, y in enumerate(y_hats): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + + seq_hat = [ + self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 + ] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + if len(ref_chars) > 0: + cers.append( + editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) + ) + + cer_ctc = sum(cers) / len(cers) if cers else None + cer_ctc_list.append(cer_ctc) + cer_ctc_weighted = np.sum( + [ + item * self.weights_ctc_train[i] + for i, item in enumerate(cer_ctc_list) + ] + ) + cer_ctc_list = [float(cer_ctc_weighted)] + [ + float(item) for item in cer_ctc_list + ] + + # 5. 
compute cer/wer + if self.training or not (self.report_cer or self.report_wer): + cer, wer = 0.0, 0.0 + # oracle_cer, oracle_wer = 0.0, 0.0 + else: + if self.recog_args.ctc_weight > 0.0: + lpz_list = [] + for idx in range(self.num_encs): + ctc_idx = 0 if self.share_ctc else idx + lpz = self.ctc[ctc_idx].log_softmax(hs_pad_list[idx]).data + lpz_list.append(lpz) + else: + lpz_list = None + + word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] + nbest_hyps = self.dec.recognize_beam_batch( + hs_pad_list, + hlens_list, + lpz_list, + self.recog_args, + self.char_list, + self.rnnlm, + lang_ids=tgt_lang_ids.squeeze(1).tolist() if self.replace_sos else None, + ) + # remove and + y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") + seq_true_text = "".join(seq_true).replace(self.recog_args.space, " ") + + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + wer = ( + 0.0 + if not self.report_wer + else float(sum(word_eds)) / sum(word_ref_lens) + ) + cer = ( + 0.0 + if not self.report_cer + else float(sum(char_eds)) / sum(char_ref_lens) + ) + + alpha = self.mtlalpha + if alpha == 0: + self.loss = self.loss_att + loss_att_data = float(self.loss_att) + loss_ctc_data_list = [None] * (self.num_encs + 1) + elif alpha == 1: + self.loss = torch.sum( + torch.cat( + [ + (item * self.weights_ctc_train[i]).unsqueeze(0) + for i, item in enumerate(self.loss_ctc_list) + ] + ) + ) + loss_att_data = None + loss_ctc_data_list = [float(self.loss)] + [ + float(item) for item in self.loss_ctc_list + ] + else: + self.loss_ctc = torch.sum( + torch.cat( + [ + (item * self.weights_ctc_train[i]).unsqueeze(0) + for i, item in enumerate(self.loss_ctc_list) + ] + ) + ) + self.loss = alpha * self.loss_ctc + (1 - alpha) * self.loss_att + loss_att_data = float(self.loss_att) + loss_ctc_data_list = [float(self.loss_ctc)] + [ + float(item) for item in self.loss_ctc_list + ] + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data_list, + loss_att_data, + acc, + cer_ctc_list, + cer, + wer, + loss_data, + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def scorers(self): + """Get scorers for `beam_search` (optional). + + Returns: + dict[str, ScorerInterface]: dict of `ScorerInterface` objects + + """ + return dict(decoder=self.dec, ctc=CTCPrefixScorer(self.ctc, self.eos)) + + def encode(self, x_list): + """Encode feature. + + Args: + x_list (list): input feature [(T1, D), (T2, D), ... ] + Returns: + list + encoded feature [(T1, D), (T2, D), ... 
] + + """ + self.eval() + ilens_list = [[x_list[idx].shape[0]] for idx in range(self.num_encs)] + + # subsample frame + x_list = [ + x_list[idx][:: self.subsample_list[idx][0], :] + for idx in range(self.num_encs) + ] + p = next(self.parameters()) + x_list = [ + torch.as_tensor(x_list[idx], device=p.device, dtype=p.dtype) + for idx in range(self.num_encs) + ] + # make a utt list (1) to use the same interface for encoder + xs_list = [ + x_list[idx].contiguous().unsqueeze(0) for idx in range(self.num_encs) + ] + + # 1. encoder + hs_list = [] + for idx in range(self.num_encs): + hs, _, _ = self.enc[idx](xs_list[idx], ilens_list[idx]) + hs_list.append(hs[0]) + return hs_list + + def recognize(self, x_list, recog_args, char_list, rnnlm=None): + """E2E beam search. + + :param list of ndarray x: list of input acoustic feature [(T1, D), (T2,D),...] + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + hs_list = self.encode(x_list) + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + if self.share_ctc: + lpz_list = [ + self.ctc[0].log_softmax(hs_list[idx].unsqueeze(0))[0] + for idx in range(self.num_encs) + ] + else: + lpz_list = [ + self.ctc[idx].log_softmax(hs_list[idx].unsqueeze(0))[0] + for idx in range(self.num_encs) + ] + else: + lpz_list = None + + # 2. Decoder + # decode the first utterance + y = self.dec.recognize_beam(hs_list, lpz_list, recog_args, char_list, rnnlm) + return y + + def recognize_batch(self, xs_list, recog_args, char_list, rnnlm=None): + """E2E beam search. + + :param list xs_list: list of list of input acoustic feature arrays + [[(T1_1, D), (T1_2, D), ...],[(T2_1, D), (T2_2, D), ...], ...] + :param Namespace recog_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + ilens_list = [ + np.fromiter((xx.shape[0] for xx in xs_list[idx]), dtype=np.int64) + for idx in range(self.num_encs) + ] + + # subsample frame + xs_list = [ + [xx[:: self.subsample_list[idx][0], :] for xx in xs_list[idx]] + for idx in range(self.num_encs) + ] + + xs_list = [ + [to_device(self, to_torch_tensor(xx).float()) for xx in xs_list[idx]] + for idx in range(self.num_encs) + ] + xs_pad_list = [pad_list(xs_list[idx], 0.0) for idx in range(self.num_encs)] + + # 1. Encoder + hs_pad_list, hlens_list = [], [] + for idx in range(self.num_encs): + hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) + hs_pad_list.append(hs_pad) + hlens_list.append(hlens) + + # calculate log P(z_t|X) for CTC scores + if recog_args.ctc_weight > 0.0: + if self.share_ctc: + lpz_list = [ + self.ctc[0].log_softmax(hs_pad_list[idx]) + for idx in range(self.num_encs) + ] + else: + lpz_list = [ + self.ctc[idx].log_softmax(hs_pad_list[idx]) + for idx in range(self.num_encs) + ] + normalize_score = False + else: + lpz_list = None + normalize_score = True + + # 2. 
Decoder + hlens_list = [ + torch.tensor(list(map(int, hlens_list[idx]))) + for idx in range(self.num_encs) + ] # make sure hlens is tensor + y = self.dec.recognize_beam_batch( + hs_pad_list, + hlens_list, + lpz_list, + recog_args, + char_list, + rnnlm, + normalize_score=normalize_score, + ) + + if prev: + self.train() + return y + + def calculate_all_attentions(self, xs_pad_list, ilens_list, ys_pad): + """E2E attention calculation. + + :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences + [(B, Tmax_1, idim), (B, Tmax_2, idim),..] + :param List ilens_list: + list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, Lmax) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) multi-encoder case + => [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)] + 3) other case => attention weights (B, Lmax, Tmax). + :rtype: float ndarray or list + """ + self.eval() + with torch.no_grad(): + # 1. Encoder + if self.replace_sos: + tgt_lang_ids = ys_pad[:, 0:1] + ys_pad = ys_pad[:, 1:] # remove target language ID in the beggining + else: + tgt_lang_ids = None + + hs_pad_list, hlens_list = [], [] + for idx in range(self.num_encs): + hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) + hs_pad_list.append(hs_pad) + hlens_list.append(hlens) + + # 2. Decoder + att_ws = self.dec.calculate_all_attentions( + hs_pad_list, hlens_list, ys_pad, lang_ids=tgt_lang_ids + ) + self.train() + return att_ws + + def calculate_all_ctc_probs(self, xs_pad_list, ilens_list, ys_pad): + """E2E CTC probability calculation. + + :param List xs_pad_list: list of batch (torch.Tensor) of padded input sequences + [(B, Tmax_1, idim), (B, Tmax_2, idim),..] + :param List ilens_list: + list of batch (torch.Tensor) of lengths of input sequences [(B), (B), ..] + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, Lmax) + :return: CTC probability (B, Tmax, vocab) + :rtype: float ndarray or list + """ + probs_list = [None] + if self.mtlalpha == 0: + return probs_list + + self.eval() + probs_list = [] + with torch.no_grad(): + # 1. Encoder + for idx in range(self.num_encs): + hs_pad, hlens, _ = self.enc[idx](xs_pad_list[idx], ilens_list[idx]) + + # 2. 
CTC loss + ctc_idx = 0 if self.share_ctc else idx + probs = self.ctc[ctc_idx].softmax(hs_pad).cpu().numpy() + probs_list.append(probs) + self.train() + return probs_list diff --git a/espnet/nets/pytorch_backend/e2e_asr_transducer.py b/espnet/nets/pytorch_backend/e2e_asr_transducer.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0d038276efed90a96a38a5ec235257f624a70d --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_transducer.py @@ -0,0 +1,576 @@ +"""Transducer speech recognition model (pytorch).""" + +from argparse import Namespace +from collections import Counter +from dataclasses import asdict +import logging +import math +import numpy + +import chainer +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.pytorch_backend.ctc import ctc_for +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.transducer.arguments import ( + add_encoder_general_arguments, # noqa: H301 + add_rnn_encoder_arguments, # noqa: H301 + add_custom_encoder_arguments, # noqa: H301 + add_decoder_general_arguments, # noqa: H301 + add_rnn_decoder_arguments, # noqa: H301 + add_custom_decoder_arguments, # noqa: H301 + add_custom_training_arguments, # noqa: H301 + add_transducer_arguments, # noqa: H301 + add_auxiliary_task_arguments, # noqa: H301 +) +from espnet.nets.pytorch_backend.transducer.auxiliary_task import AuxiliaryTask +from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder +from espnet.nets.pytorch_backend.transducer.custom_encoder import CustomEncoder +from espnet.nets.pytorch_backend.transducer.error_calculator import ErrorCalculator +from espnet.nets.pytorch_backend.transducer.initializer import initializer +from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork +from espnet.nets.pytorch_backend.transducer.loss import TransLoss +from espnet.nets.pytorch_backend.transducer.rnn_decoder import DecoderRNNT +from espnet.nets.pytorch_backend.transducer.rnn_encoder import encoder_for +from espnet.nets.pytorch_backend.transducer.utils import prepare_loss_inputs +from espnet.nets.pytorch_backend.transducer.utils import valid_aux_task_layer_list +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport +from espnet.utils.fill_missing_args import fill_missing_args + + +class Reporter(chainer.Chain): + """A chainer reporter wrapper for transducer models.""" + + def report( + self, + loss, + loss_trans, + loss_ctc, + loss_lm, + loss_aux_trans, + loss_aux_symm_kl, + cer, + wer, + ): + """Instantiate reporter attributes.""" + chainer.reporter.report({"loss": loss}, self) + chainer.reporter.report({"loss_trans": loss_trans}, self) + chainer.reporter.report({"loss_ctc": loss_ctc}, self) + chainer.reporter.report({"loss_lm": loss_lm}, self) + chainer.reporter.report({"loss_aux_trans": loss_aux_trans}, self) + chainer.reporter.report({"loss_aux_symm_kl": loss_aux_symm_kl}, self) + chainer.reporter.report({"cer": cer}, self) + chainer.reporter.report({"wer": wer}, self) + + logging.info("loss:" + str(loss)) + + +class E2E(ASRInterface, 
torch.nn.Module): + """E2E module for transducer models. + + Args: + idim (int): dimension of inputs + odim (int): dimension of outputs + args (Namespace): argument Namespace containing options + ignore_id (int): padding symbol id + blank_id (int): blank symbol id + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments for transducer model.""" + E2E.encoder_add_general_arguments(parser) + E2E.encoder_add_rnn_arguments(parser) + E2E.encoder_add_custom_arguments(parser) + + E2E.decoder_add_general_arguments(parser) + E2E.decoder_add_rnn_arguments(parser) + E2E.decoder_add_custom_arguments(parser) + + E2E.training_add_custom_arguments(parser) + E2E.transducer_add_arguments(parser) + E2E.auxiliary_task_add_arguments(parser) + + return parser + + @staticmethod + def encoder_add_general_arguments(parser): + """Add general arguments for encoder.""" + group = parser.add_argument_group("Encoder general arguments") + group = add_encoder_general_arguments(group) + + return parser + + @staticmethod + def encoder_add_rnn_arguments(parser): + """Add arguments for RNN encoder.""" + group = parser.add_argument_group("RNN encoder arguments") + group = add_rnn_encoder_arguments(group) + + return parser + + @staticmethod + def encoder_add_custom_arguments(parser): + """Add arguments for Custom encoder.""" + group = parser.add_argument_group("Custom encoder arguments") + group = add_custom_encoder_arguments(group) + + return parser + + @staticmethod + def decoder_add_general_arguments(parser): + """Add general arguments for decoder.""" + group = parser.add_argument_group("Decoder general arguments") + group = add_decoder_general_arguments(group) + + return parser + + @staticmethod + def decoder_add_rnn_arguments(parser): + """Add arguments for RNN decoder.""" + group = parser.add_argument_group("RNN decoder arguments") + group = add_rnn_decoder_arguments(group) + + return parser + + @staticmethod + def decoder_add_custom_arguments(parser): + """Add arguments for Custom decoder.""" + group = parser.add_argument_group("Custom decoder arguments") + group = add_custom_decoder_arguments(group) + + return parser + + @staticmethod + def training_add_custom_arguments(parser): + """Add arguments for Custom architecture training.""" + group = parser.add_argument_group("Training arguments for custom archictecture") + group = add_custom_training_arguments(group) + + return parser + + @staticmethod + def transducer_add_arguments(parser): + """Add arguments for transducer model.""" + group = parser.add_argument_group("Transducer model arguments") + group = add_transducer_arguments(group) + + return parser + + @staticmethod + def auxiliary_task_add_arguments(parser): + """Add arguments for auxiliary task.""" + group = parser.add_argument_group("Auxiliary task arguments") + group = add_auxiliary_task_arguments(group) + + return parser + + @property + def attention_plot_class(self): + """Get attention plot class.""" + return PlotAttentionReport + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + if self.etype == "custom": + return self.encoder.conv_subsampling_factor * int( + numpy.prod(self.subsample) + ) + else: + return self.enc.conv_subsampling_factor * int(numpy.prod(self.subsample)) + + def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True): + """Construct an E2E object for transducer model.""" + torch.nn.Module.__init__(self) + + args = fill_missing_args(args, self.add_arguments) + + self.is_rnnt = True + self.transducer_weight = 
args.transducer_weight + + self.use_aux_task = ( + True if (args.aux_task_type is not None and training) else False + ) + + self.use_aux_ctc = args.aux_ctc and training + self.aux_ctc_weight = args.aux_ctc_weight + + self.use_aux_cross_entropy = args.aux_cross_entropy and training + self.aux_cross_entropy_weight = args.aux_cross_entropy_weight + + if self.use_aux_task: + n_layers = ( + (len(args.enc_block_arch) * args.enc_block_repeat - 1) + if args.enc_block_arch is not None + else (args.elayers - 1) + ) + + aux_task_layer_list = valid_aux_task_layer_list( + args.aux_task_layer_list, + n_layers, + ) + else: + aux_task_layer_list = [] + + if "custom" in args.etype: + if args.enc_block_arch is None: + raise ValueError( + "When specifying custom encoder type, --enc-block-arch" + "should also be specified in training config. See" + "egs/vivos/asr1/conf/transducer/train_*.yaml for more info." + ) + + self.subsample = get_subsample(args, mode="asr", arch="transformer") + + self.encoder = CustomEncoder( + idim, + args.enc_block_arch, + input_layer=args.custom_enc_input_layer, + repeat_block=args.enc_block_repeat, + self_attn_type=args.custom_enc_self_attn_type, + positional_encoding_type=args.custom_enc_positional_encoding_type, + positionwise_activation_type=args.custom_enc_pw_activation_type, + conv_mod_activation_type=args.custom_enc_conv_mod_activation_type, + aux_task_layer_list=aux_task_layer_list, + ) + encoder_out = self.encoder.enc_out + + self.most_dom_list = args.enc_block_arch[:] + else: + self.subsample = get_subsample(args, mode="asr", arch="rnn-t") + + self.enc = encoder_for( + args, + idim, + self.subsample, + aux_task_layer_list=aux_task_layer_list, + ) + encoder_out = args.eprojs + + if "custom" in args.dtype: + if args.dec_block_arch is None: + raise ValueError( + "When specifying custom decoder type, --dec-block-arch" + "should also be specified in training config. See" + "egs/vivos/asr1/conf/transducer/train_*.yaml for more info." 
+ ) + + self.decoder = CustomDecoder( + odim, + args.dec_block_arch, + input_layer=args.custom_dec_input_layer, + repeat_block=args.dec_block_repeat, + positionwise_activation_type=args.custom_dec_pw_activation_type, + dropout_rate_embed=args.dropout_rate_embed_decoder, + ) + decoder_out = self.decoder.dunits + + if "custom" in args.etype: + self.most_dom_list += args.dec_block_arch[:] + else: + self.most_dom_list = args.dec_block_arch[:] + else: + self.dec = DecoderRNNT( + odim, + args.dtype, + args.dlayers, + args.dunits, + blank_id, + args.dec_embed_dim, + args.dropout_rate_decoder, + args.dropout_rate_embed_decoder, + ) + decoder_out = args.dunits + + self.joint_network = JointNetwork( + odim, encoder_out, decoder_out, args.joint_dim, args.joint_activation_type + ) + + if hasattr(self, "most_dom_list"): + self.most_dom_dim = sorted( + Counter( + d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d + ).most_common(), + key=lambda x: x[0], + reverse=True, + )[0][0] + + self.etype = args.etype + self.dtype = args.dtype + + self.sos = odim - 1 + self.eos = odim - 1 + self.blank_id = blank_id + self.ignore_id = ignore_id + + self.space = args.sym_space + self.blank = args.sym_blank + + self.odim = odim + + self.reporter = Reporter() + + self.error_calculator = None + + self.default_parameters(args) + + if training: + self.criterion = TransLoss(args.trans_type, self.blank_id) + + decoder = self.decoder if self.dtype == "custom" else self.dec + + if args.report_cer or args.report_wer: + self.error_calculator = ErrorCalculator( + decoder, + self.joint_network, + args.char_list, + args.sym_space, + args.sym_blank, + args.report_cer, + args.report_wer, + ) + + if self.use_aux_task: + self.auxiliary_task = AuxiliaryTask( + decoder, + self.joint_network, + self.criterion, + args.aux_task_type, + args.aux_task_weight, + encoder_out, + args.joint_dim, + ) + + if self.use_aux_ctc: + self.aux_ctc = ctc_for( + Namespace( + num_encs=1, + eprojs=encoder_out, + dropout_rate=args.aux_ctc_dropout_rate, + ctc_type="warpctc", + ), + odim, + ) + + if self.use_aux_cross_entropy: + self.aux_decoder_output = torch.nn.Linear(decoder_out, odim) + + self.aux_cross_entropy = LabelSmoothingLoss( + odim, ignore_id, args.aux_cross_entropy_smoothing + ) + + self.loss = None + self.rnnlm = None + + def default_parameters(self, args): + """Initialize/reset parameters for transducer. + + Args: + args (Namespace): argument Namespace containing options + + """ + initializer(self, args) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + Args: + xs_pad (torch.Tensor): batch of padded source sequences (B, Tmax, idim) + ilens (torch.Tensor): batch of lengths of input sequences (B) + ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) + + Returns: + loss (torch.Tensor): transducer loss value + + """ + # 1. encoder + xs_pad = xs_pad[:, : max(ilens)] + + if "custom" in self.etype: + src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) + + _hs_pad, hs_mask = self.encoder(xs_pad, src_mask) + else: + _hs_pad, hs_mask, _ = self.enc(xs_pad, ilens) + + if self.use_aux_task: + hs_pad, aux_hs_pad = _hs_pad[0], _hs_pad[1] + else: + hs_pad, aux_hs_pad = _hs_pad, None + + # 1.5. transducer preparation related + ys_in_pad, ys_out_pad, target, pred_len, target_len = prepare_loss_inputs( + ys_pad, hs_mask + ) + + # 2. 
decoder + if "custom" in self.dtype: + ys_mask = target_mask(ys_in_pad, self.blank_id) + pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad) + else: + pred_pad = self.dec(hs_pad, ys_in_pad) + + z = self.joint_network(hs_pad.unsqueeze(2), pred_pad.unsqueeze(1)) + + # 3. loss computation + loss_trans = self.criterion(z, target, pred_len, target_len) + + if self.use_aux_task and aux_hs_pad is not None: + loss_aux_trans, loss_aux_symm_kl = self.auxiliary_task( + aux_hs_pad, pred_pad, z, target, pred_len, target_len + ) + else: + loss_aux_trans, loss_aux_symm_kl = 0.0, 0.0 + + if self.use_aux_ctc: + if "custom" in self.etype: + hs_mask = torch.IntTensor( + [h.size(1) for h in hs_mask], + ).to(hs_mask.device) + + loss_ctc = self.aux_ctc_weight * self.aux_ctc(hs_pad, hs_mask, ys_pad) + else: + loss_ctc = 0.0 + + if self.use_aux_cross_entropy: + loss_lm = self.aux_cross_entropy_weight * self.aux_cross_entropy( + self.aux_decoder_output(pred_pad), ys_out_pad + ) + else: + loss_lm = 0.0 + + loss = ( + loss_trans + + self.transducer_weight * (loss_aux_trans + loss_aux_symm_kl) + + loss_ctc + + loss_lm + ) + + self.loss = loss + loss_data = float(loss) + + # 4. compute cer/wer + if self.training or self.error_calculator is None: + cer, wer = None, None + else: + cer, wer = self.error_calculator(hs_pad, ys_pad) + + if not math.isnan(loss_data): + self.reporter.report( + loss_data, + float(loss_trans), + float(loss_ctc), + float(loss_lm), + float(loss_aux_trans), + float(loss_aux_symm_kl), + cer, + wer, + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + + return self.loss + + def encode_custom(self, x): + """Encode acoustic features. + + Args: + x (ndarray): input acoustic feature (T, D) + + Returns: + x (torch.Tensor): encoded features (T, D_enc) + + """ + x = torch.as_tensor(x).unsqueeze(0) + enc_output, _ = self.encoder(x, None) + + return enc_output.squeeze(0) + + def encode_rnn(self, x): + """Encode acoustic features. + + Args: + x (ndarray): input acoustic feature (T, D) + + Returns: + x (torch.Tensor): encoded features (T, D_enc) + + """ + p = next(self.parameters()) + + ilens = [x.shape[0]] + x = x[:: self.subsample[0], :] + + h = torch.as_tensor(x, device=p.device, dtype=p.dtype) + hs = h.contiguous().unsqueeze(0) + + hs, _, _ = self.enc(hs, ilens) + + return hs.squeeze(0) + + def recognize(self, x, beam_search): + """Recognize input features. + + Args: + x (ndarray): input acoustic feature (T, D) + beam_search (class): beam search class + + Returns: + nbest_hyps (list): n-best decoding results + + """ + self.eval() + + if "custom" in self.etype: + h = self.encode_custom(x) + else: + h = self.encode_rnn(x) + + nbest_hyps = beam_search(h) + + return [asdict(n) for n in nbest_hyps] + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. + + Args: + xs_pad (torch.Tensor): batch of padded input sequences (B, Tmax, idim) + ilens (torch.Tensor): batch of lengths of input sequences (B) + ys_pad (torch.Tensor): + batch of padded character id sequence tensor (B, Lmax) + + Returns: + ret (ndarray): attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). 
+ + """ + self.eval() + + if "custom" not in self.etype and "custom" not in self.dtype: + return [] + else: + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad) + + ret = dict() + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention) or isinstance( + m, RelPositionMultiHeadedAttention + ): + ret[name] = m.attn.cpu().numpy() + + self.train() + + return ret diff --git a/espnet/nets/pytorch_backend/e2e_asr_transformer.py b/espnet/nets/pytorch_backend/e2e_asr_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b0cd0931dc95999796958d9749ea7344bff216cb --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_asr_transformer.py @@ -0,0 +1,529 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer speech recognition model (pytorch).""" + +from argparse import Namespace +import logging +import math + +import numpy +import torch + +from espnet.nets.asr_interface import ASRInterface +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.ctc import CTC +from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD +from espnet.nets.pytorch_backend.e2e_asr import Reporter +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.rnn.decoders import CTC_SCORING_RATIO +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.argument import ( + add_arguments_transformer_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.utils.fill_missing_args import fill_missing_args + + +class E2E(ASRInterface, torch.nn.Module): + """E2E module. 
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + group = parser.add_argument_group("transformer model setting") + + group = add_arguments_transformer_common(group) + + return parser + + @property + def attention_plot_class(self): + """Return PlotAttentionReport.""" + return PlotAttentionReport + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.encoder.conv_subsampling_factor * int(numpy.prod(self.subsample)) + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + self.encoder = Encoder( + idim=idim, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_encoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + if args.mtlalpha < 1: + self.decoder = Decoder( + odim=odim, + selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_decoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + self_attention_dropout_rate=args.transformer_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + self.criterion = LabelSmoothingLoss( + odim, + ignore_id, + args.lsm_weight, + args.transformer_length_normalized_loss, + ) + else: + self.decoder = None + self.criterion = None + self.blank = 0 + self.sos = odim - 1 + self.eos = odim - 1 + self.odim = odim + self.ignore_id = ignore_id + self.subsample = get_subsample(args, mode="asr", arch="transformer") + self.reporter = Reporter() + + self.reset_parameters(args) + self.adim = args.adim # used for CTC (equal to d_model) + self.mtlalpha = args.mtlalpha + if args.mtlalpha > 0.0: + self.ctc = CTC( + odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True + ) + else: + self.ctc = None + + if args.report_cer or args.report_wer: + self.error_calculator = ErrorCalculator( + args.char_list, + args.sym_space, + args.sym_blank, + args.report_cer, + args.report_wer, + ) + else: + self.error_calculator = None + self.rnnlm = None + + def reset_parameters(self, args): + """Initialize parameters.""" + # initialize parameters + initialize(self, args.transformer_init) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. 
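A minimal training-step sketch (random tensors stand in for a real batch; ``model`` is assumed to be an instance of this class built with idim=83 and odim=52):

    >>> import torch
    >>> xs_pad = torch.randn(4, 200, 83)                 # (B, Tmax, idim)
    >>> ilens = torch.tensor([200, 180, 160, 150])       # (B,)
    >>> ys_pad = torch.randint(1, 52, (4, 20))           # (B, Lmax) token ids
    >>> loss = model(xs_pad, ilens, ys_pad)
    >>> loss.backward()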
+ + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of source sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :return: ctc loss value + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 1. forward encoder + xs_pad = xs_pad[:, : max(ilens)] # for data parallel + src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) + self.hs_pad = hs_pad + + # 2. forward decoder + if self.decoder is not None: + ys_in_pad, ys_out_pad = add_sos_eos( + ys_pad, self.sos, self.eos, self.ignore_id + ) + ys_mask = target_mask(ys_in_pad, self.ignore_id) + pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + self.pred_pad = pred_pad + + # 3. compute attention loss + loss_att = self.criterion(pred_pad, ys_out_pad) + self.acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + else: + loss_att = None + self.acc = None + + # TODO(karita) show predicted text + # TODO(karita) calculate these stats + cer_ctc = None + if self.mtlalpha == 0.0: + loss_ctc = None + else: + batch_size = xs_pad.size(0) + hs_len = hs_mask.view(batch_size, -1).sum(1) + loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(hs_pad.view(batch_size, -1, self.adim)).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + # for visualization + if not self.training: + self.ctc.softmax(hs_pad) + + # 5. compute cer/wer + if self.training or self.error_calculator is None or self.decoder is None: + cer, wer = None, None + else: + ys_hat = pred_pad.argmax(dim=-1) + cer, wer = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + # copied from e2e_asr + alpha = self.mtlalpha + if alpha == 0: + self.loss = loss_att + loss_att_data = float(loss_att) + loss_ctc_data = None + elif alpha == 1: + self.loss = loss_ctc + loss_att_data = None + loss_ctc_data = float(loss_ctc) + else: + self.loss = alpha * loss_ctc + (1 - alpha) * loss_att + loss_att_data = float(loss_att) + loss_ctc_data = float(loss_ctc) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_ctc_data, loss_att_data, self.acc, cer_ctc, cer, wer, loss_data + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def scorers(self): + """Scorers.""" + return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) + + def encode(self, x): + """Encode acoustic features. + + :param ndarray x: source acoustic feature (T, D) + :return: encoder outputs + :rtype: torch.Tensor + """ + self.eval() + x = torch.as_tensor(x).unsqueeze(0) + enc_output, _ = self.encoder(x, None) + return enc_output.squeeze(0) + + def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): + """Recognize input speech. 
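Decoding sketch (``model``, ``recog_args`` and ``char_list`` are assumed to exist; ``recog_args`` carries the beam-search options used below, e.g. ``beam_size``, ``ctc_weight``, ``maxlenratio``, ``nbest``):

    >>> import numpy as np
    >>> x = np.random.randn(250, 83)                     # (T, D) acoustic features
    >>> nbest = model.recognize(x, recog_args, char_list)
    >>> "".join(char_list[int(t)] for t in nbest[0]["yseq"][1:])   # drop the leading sos id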
+ + :param ndnarray x: input acoustic feature (B, T, D) or (T, D) + :param Namespace recog_args: argment Namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + enc_output = self.encode(x).unsqueeze(0) + if self.mtlalpha == 1.0: + recog_args.ctc_weight = 1.0 + logging.info("Set to pure CTC decoding mode.") + + if self.mtlalpha > 0 and recog_args.ctc_weight == 1.0: + from itertools import groupby + + lpz = self.ctc.argmax(enc_output) + collapsed_indices = [x[0] for x in groupby(lpz[0])] + hyp = [x for x in filter(lambda x: x != self.blank, collapsed_indices)] + nbest_hyps = [{"score": 0.0, "yseq": [self.sos] + hyp}] + if recog_args.beam_size > 1: + raise NotImplementedError("Pure CTC beam search is not implemented.") + # TODO(hirofumi0810): Implement beam search + return nbest_hyps + elif self.mtlalpha > 0 and recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(enc_output) + lpz = lpz.squeeze(0) + else: + lpz = None + + h = enc_output.squeeze(0) + + logging.info("input lengths: " + str(h.size(0))) + # search parms + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = recog_args.ctc_weight + + # preprare sos + y = self.sos + vy = h.new_zeros(1).long() + + if recog_args.maxlenratio == 0: + maxlen = h.shape[0] + else: + # maxlen >= 1 + maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) + minlen = int(recog_args.minlenratio * h.size(0)) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + if rnnlm: + hyp = {"score": 0.0, "yseq": [y], "rnnlm_prev": None} + else: + hyp = {"score": 0.0, "yseq": [y]} + if lpz is not None: + ctc_prefix_score = CTCPrefixScore(lpz.detach().numpy(), 0, self.eos, numpy) + hyp["ctc_state_prev"] = ctc_prefix_score.initial_state() + hyp["ctc_score_prev"] = 0.0 + if ctc_weight != 1.0: + # pre-pruning based on attention scores + ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) + else: + ctc_beam = lpz.shape[-1] + hyps = [hyp] + ended_hyps = [] + + import six + + traced_decoder = None + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + hyps_best_kept = [] + for hyp in hyps: + vy[0] = hyp["yseq"][i] + + # get nbest local scores and their ids + ys_mask = subsequent_mask(i + 1).unsqueeze(0) + ys = torch.tensor(hyp["yseq"]).unsqueeze(0) + # FIXME: jit does not match non-jit result + if use_jit: + if traced_decoder is None: + traced_decoder = torch.jit.trace( + self.decoder.forward_one_step, (ys, ys_mask, enc_output) + ) + local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] + else: + local_att_scores = self.decoder.forward_one_step( + ys, ys_mask, enc_output + )[0] + + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy) + local_scores = ( + local_att_scores + recog_args.lm_weight * local_lm_scores + ) + else: + local_scores = local_att_scores + + if lpz is not None: + local_best_scores, local_best_ids = torch.topk( + local_att_scores, ctc_beam, dim=1 + ) + ctc_scores, ctc_states = ctc_prefix_score( + hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"] + ) + local_scores = (1.0 - ctc_weight) * local_att_scores[ + :, local_best_ids[0] + ] + ctc_weight * torch.from_numpy( + ctc_scores - hyp["ctc_score_prev"] + ) + if rnnlm: + local_scores += ( + recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] + ) + local_best_scores, joint_best_ids = torch.topk( + 
local_scores, beam, dim=1 + ) + local_best_ids = local_best_ids[:, joint_best_ids[0]] + else: + local_best_scores, local_best_ids = torch.topk( + local_scores, beam, dim=1 + ) + + for j in six.moves.range(beam): + new_hyp = {} + new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) + if rnnlm: + new_hyp["rnnlm_prev"] = rnnlm_state + if lpz is not None: + new_hyp["ctc_state_prev"] = ctc_states[joint_best_ids[0, j]] + new_hyp["ctc_score_prev"] = ctc_scores[joint_best_ids[0, j]] + # will be (2 x beam) hyps at most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypothes: " + str(len(hyps))) + if char_list is not None: + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last postion in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypothes to a final list, and removed them from current hypothes + # (this will be a probmlem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + if rnnlm: # Word LM needs to add final score + hyp["score"] += recog_args.lm_weight * rnnlm.final( + hyp["rnnlm_prev"] + ) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remeined hypothes: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + + if char_list is not None: + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypothes: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), recog_args.nbest) + ] + + # check number of hypotheis + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform recognition " + "again with smaller minlenratio." + ) + # should copy becasuse Namespace will be overwritten globally + recog_args = Namespace(**vars(recog_args)) + recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) + return self.recognize(x, recog_args, char_list, rnnlm) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + return nbest_hyps + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. 
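Inspection sketch (reuses the hypothetical batch tensors from the ``forward`` example above):

    >>> att_ws = model.calculate_all_attentions(xs_pad, ilens, ys_pad)
    >>> for name, w in att_ws.items():
    ...     print(name, w.shape)        # multi-head entries have shape (B, H, Lmax, Tmax)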
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: attention weights (B, H, Lmax, Tmax) + :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad) + ret = dict() + for name, m in self.named_modules(): + if ( + isinstance(m, MultiHeadedAttention) + or isinstance(m, DynamicConvolution) + or isinstance(m, RelPositionMultiHeadedAttention) + ): + ret[name] = m.attn.cpu().numpy() + if isinstance(m, DynamicConvolution2D): + ret[name + "_time"] = m.attn_t.cpu().numpy() + ret[name + "_freq"] = m.attn_f.cpu().numpy() + self.train() + return ret + + def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad): + """E2E CTC probability calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: CTC probability (B, Tmax, vocab) + :rtype: float ndarray + """ + ret = None + if self.mtlalpha == 0: + return ret + + self.eval() + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad) + for name, m in self.named_modules(): + if isinstance(m, CTC) and m.probs is not None: + ret = m.probs.cpu().numpy() + self.train() + return ret diff --git a/espnet/nets/pytorch_backend/e2e_mt.py b/espnet/nets/pytorch_backend/e2e_mt.py new file mode 100644 index 0000000000000000000000000000000000000000..71f819a7ef393d7f52764b3ebe697dbbfe0752c7 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_mt.py @@ -0,0 +1,373 @@ +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""RNN sequence-to-sequence text translation model (pytorch).""" + +import argparse +import logging +import math +import os + +import chainer +from chainer import reporter +import nltk +import numpy as np +import torch + +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.mt_interface import MTInterface +from espnet.nets.pytorch_backend.initialization import uniform_init_parameters +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.rnn.argument import ( + add_arguments_rnn_encoder_common, # noqa: H301 + add_arguments_rnn_decoder_common, # noqa: H301 + add_arguments_rnn_attention_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders import decoder_for +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.utils.fill_missing_args import fill_missing_args + + +class Reporter(chainer.Chain): + """A chainer reporter wrapper.""" + + def report(self, loss, acc, ppl, bleu): + """Report at every step.""" + reporter.report({"loss": loss}, self) + reporter.report({"acc": acc}, self) + reporter.report({"ppl": ppl}, self) + reporter.report({"bleu": bleu}, self) + + +class E2E(MTInterface, torch.nn.Module): + """E2E module. 
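Construction sketch (illustrative vocabulary size; ``args`` is assumed to be the populated MT training Namespace built from the options registered by ``add_arguments`` below, and ``xs_pad``/``ilens``/``ys_pad`` a padded token-id batch):

    >>> model = E2E(idim=7000, odim=7000, args=args)     # shared source/target vocabulary
    >>> loss = model(xs_pad, ilens, ys_pad)              # attention loss; acc/ppl/bleu reported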
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2E.encoder_add_arguments(parser) + E2E.attention_add_arguments(parser) + E2E.decoder_add_arguments(parser) + return parser + + @staticmethod + def encoder_add_arguments(parser): + """Add arguments for the encoder.""" + group = parser.add_argument_group("E2E encoder setting") + group = add_arguments_rnn_encoder_common(group) + return parser + + @staticmethod + def attention_add_arguments(parser): + """Add arguments for the attention.""" + group = parser.add_argument_group("E2E attention setting") + group = add_arguments_rnn_attention_common(group) + return parser + + @staticmethod + def decoder_add_arguments(parser): + """Add arguments for the decoder.""" + group = parser.add_argument_group("E2E decoder setting") + group = add_arguments_rnn_decoder_common(group) + return parser + + def __init__(self, idim, odim, args): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + self.etype = args.etype + self.verbose = args.verbose + # NOTE: for self.build method + args.char_list = getattr(args, "char_list", None) + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + self.pad = 0 + # NOTE: we reserve index:0 for although this is reserved for a blank class + # in ASR. However, blank labels are not used in MT. + # To keep the vocabulary size, + # we use index:0 for padding instead of adding one more class. + + # subsample info + self.subsample = get_subsample(args, mode="mt", arch="rnn") + + # label smoothing info + if args.lsm_type and os.path.isfile(args.train_json): + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + # multilingual related + self.multilingual = getattr(args, "multilingual", False) + self.replace_sos = getattr(args, "replace_sos", False) + + # encoder + self.embed = torch.nn.Embedding(idim, args.eunits, padding_idx=self.pad) + self.dropout = torch.nn.Dropout(p=args.dropout_rate) + self.enc = encoder_for(args, args.eunits, self.subsample) + # attention + self.att = att_for(args) + # decoder + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + # tie source and target emeddings + if args.tie_src_tgt_embedding: + if idim != odim: + raise ValueError( + "When using tie_src_tgt_embedding, idim and odim must be equal." + ) + if args.eunits != args.dunits: + raise ValueError( + "When using tie_src_tgt_embedding, eunits and dunits must be equal." + ) + self.embed.weight = self.dec.embed.weight + + # tie emeddings and the classfier + if args.tie_classifier: + if args.context_residual: + raise ValueError( + "When using tie_classifier, context_residual must be turned off." 
+ ) + self.dec.output.weight = self.dec.embed.weight + + # weight initialization + self.init_like_fairseq() + + # options for beam search + if args.report_bleu: + trans_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": 0, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + "tgt_lang": False, + } + + self.trans_args = argparse.Namespace(**trans_args) + self.report_bleu = args.report_bleu + else: + self.report_bleu = False + self.rnnlm = None + + self.logzero = -10000000000.0 + self.loss = None + self.acc = None + + def init_like_fairseq(self): + """Initialize weight like Fairseq. + + Fairseq basically uses W, b, EmbedID.W ~ Uniform(-0.1, 0.1), + """ + uniform_init_parameters(self) + # exceptions + # embed weight ~ Normal(-0.1, 0.1) + torch.nn.init.uniform_(self.embed.weight, -0.1, 0.1) + torch.nn.init.constant_(self.embed.weight[self.pad], 0) + torch.nn.init.uniform_(self.dec.embed.weight, -0.1, 0.1) + torch.nn.init.constant_(self.dec.embed.weight[self.pad], 0) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: loss value + :rtype: torch.Tensor + """ + # 1. Encoder + xs_pad, ys_pad = self.target_language_biasing(xs_pad, ilens, ys_pad) + hs_pad, hlens, _ = self.enc(self.dropout(self.embed(xs_pad)), ilens) + + # 3. attention loss + self.loss, self.acc, self.ppl = self.dec(hs_pad, hlens, ys_pad) + + # 4. compute bleu + if self.training or not self.report_bleu: + self.bleu = 0.0 + else: + lpz = None + + nbest_hyps = self.dec.recognize_beam_batch( + hs_pad, + torch.tensor(hlens), + lpz, + self.trans_args, + self.char_list, + self.rnnlm, + ) + # remove and + list_of_refs = [] + hyps = [] + y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "") + seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ") + + hyps += [seq_hat_text.split(" ")] + list_of_refs += [[seq_true_text.split(" ")]] + + self.bleu = nltk.bleu_score.corpus_bleu(list_of_refs, hyps) * 100 + + loss_data = float(self.loss) + if not math.isnan(loss_data): + self.reporter.report(loss_data, self.acc, self.ppl, self.bleu) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def target_language_biasing(self, xs_pad, ilens, ys_pad): + """Prepend target language IDs to source sentences for multilingual MT. + + These tags are prepended in source/target sentences as pre-processing. 
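+
+        Illustrative example (token IDs made up): for a source row
+        [src_lang_id, w1, w2] and a target row [tgt_lang_id, t1, t2], this
+        returns ([tgt_lang_id, w1, w2], [t1, t2]): the source language ID is
+        dropped and the target language ID is prepended to the source.
+        When multilingual is disabled, the inputs are returned unchanged.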
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :return: source text without language IDs + :rtype: torch.Tensor + :return: target text without language IDs + :rtype: torch.Tensor + :return: target language IDs + :rtype: torch.Tensor (B, 1) + """ + if self.multilingual: + # remove language ID in the beggining + tgt_lang_ids = ys_pad[:, 0].unsqueeze(1) + xs_pad = xs_pad[:, 1:] # remove source language IDs here + ys_pad = ys_pad[:, 1:] + + # prepend target language ID to source sentences + xs_pad = torch.cat([tgt_lang_ids, xs_pad], dim=1) + return xs_pad, ys_pad + + def translate(self, x, trans_args, char_list, rnnlm=None): + """E2E beam search. + + :param ndarray x: input source text feature (B, T, D) + :param Namespace trans_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + + # 1. encoder + # make a utt list (1) to use the same interface for encoder + if self.multilingual: + ilen = [len(x[0][1:])] + h = to_device( + self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)) + ) + else: + ilen = [len(x[0])] + h = to_device( + self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)) + ) + hs, _, _ = self.enc(self.dropout(self.embed(h.unsqueeze(0))), ilen) + + # 2. decoder + # decode the first utterance + y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm) + + if prev: + self.train() + return y + + def translate_batch(self, xs, trans_args, char_list, rnnlm=None): + """E2E batch beam search. + + :param list xs: + list of input source text feature arrays [(T_1, D), (T_2, D), ...] + :param Namespace trans_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + + # 1. Encoder + if self.multilingual: + ilens = np.fromiter((len(xx[1:]) for xx in xs), dtype=np.int64) + hs = [to_device(self, torch.from_numpy(xx[1:])) for xx in xs] + else: + ilens = np.fromiter((len(xx) for xx in xs), dtype=np.int64) + hs = [to_device(self, torch.from_numpy(xx)) for xx in xs] + xpad = pad_list(hs, self.pad) + hs_pad, hlens, _ = self.enc(self.dropout(self.embed(xpad)), ilens) + + # 2. Decoder + hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor + y = self.dec.recognize_beam_batch( + hs_pad, hlens, None, trans_args, char_list, rnnlm + ) + + if prev: + self.train() + return y + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). + :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + # 1. Encoder + xs_pad, ys_pad = self.target_language_biasing(xs_pad, ilens, ys_pad) + hpad, hlens, _ = self.enc(self.dropout(self.embed(xs_pad)), ilens) + + # 2. 
Decoder + att_ws = self.dec.calculate_all_attentions(hpad, hlens, ys_pad) + self.train() + return att_ws diff --git a/espnet/nets/pytorch_backend/e2e_mt_transformer.py b/espnet/nets/pytorch_backend/e2e_mt_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..c7422ada4b8a5fea0afdc8d1aa8e99b665e0949c --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_mt_transformer.py @@ -0,0 +1,410 @@ +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer text translation model (pytorch).""" + +from argparse import Namespace +import logging +import math + +import numpy as np +import torch + +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.e2e_mt_common import ErrorCalculator +from espnet.nets.mt_interface import MTInterface +from espnet.nets.pytorch_backend.e2e_mt import Reporter +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.argument import ( + add_arguments_transformer_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport +from espnet.utils.fill_missing_args import fill_missing_args + + +class E2E(MTInterface, torch.nn.Module): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + group = parser.add_argument_group("transformer model setting") + group = add_arguments_transformer_common(group) + return parser + + @property + def attention_plot_class(self): + """Return PlotAttentionReport.""" + return PlotAttentionReport + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. 
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + self.encoder = Encoder( + idim=idim, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_encoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer="embed", + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + self.decoder = Decoder( + odim=odim, + selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_decoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + self_attention_dropout_rate=args.transformer_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + self.pad = 0 # use for padding + self.sos = odim - 1 + self.eos = odim - 1 + self.odim = odim + self.ignore_id = ignore_id + self.subsample = get_subsample(args, mode="mt", arch="transformer") + self.reporter = Reporter() + + # tie source and target emeddings + if args.tie_src_tgt_embedding: + if idim != odim: + raise ValueError( + "When using tie_src_tgt_embedding, idim and odim must be equal." + ) + self.encoder.embed[0].weight = self.decoder.embed[0].weight + + # tie emeddings and the classfier + if args.tie_classifier: + self.decoder.output_layer.weight = self.decoder.embed[0].weight + + self.criterion = LabelSmoothingLoss( + self.odim, + self.ignore_id, + args.lsm_weight, + args.transformer_length_normalized_loss, + ) + self.normalize_length = args.transformer_length_normalized_loss # for PPL + self.reset_parameters(args) + self.adim = args.adim + self.error_calculator = ErrorCalculator( + args.char_list, args.sym_space, args.sym_blank, args.report_bleu + ) + self.rnnlm = None + + # multilingual MT related + self.multilingual = args.multilingual + + def reset_parameters(self, args): + """Initialize parameters.""" + initialize(self, args.transformer_init) + torch.nn.init.normal_( + self.encoder.embed[0].weight, mean=0, std=args.adim ** -0.5 + ) + torch.nn.init.constant_(self.encoder.embed[0].weight[self.pad], 0) + torch.nn.init.normal_( + self.decoder.embed[0].weight, mean=0, std=args.adim ** -0.5 + ) + torch.nn.init.constant_(self.decoder.embed[0].weight[self.pad], 0) + + def forward(self, xs_pad, ilens, ys_pad): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of source sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 1. 
forward encoder + xs_pad = xs_pad[:, : max(ilens)] # for data parallel + src_mask = (~make_pad_mask(ilens.tolist())).to(xs_pad.device).unsqueeze(-2) + xs_pad, ys_pad = self.target_forcing(xs_pad, ys_pad) + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) + + # 2. forward decoder + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_mask = target_mask(ys_in_pad, self.ignore_id) + pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + + # 3. compute attention loss + self.loss = self.criterion(pred_pad, ys_out_pad) + self.acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + + # 4. compute corpus-level bleu in a mini-batch + if self.training: + self.bleu = None + else: + ys_hat = pred_pad.argmax(dim=-1) + self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + loss_data = float(self.loss) + if self.normalize_length: + self.ppl = np.exp(loss_data) + else: + batch_size = ys_out_pad.size(0) + ys_out_pad = ys_out_pad.view(-1) + ignore = ys_out_pad == self.ignore_id # (B*T,) + total_n_tokens = len(ys_out_pad) - ignore.sum().item() + self.ppl = np.exp(loss_data * batch_size / total_n_tokens) + if not math.isnan(loss_data): + self.reporter.report(loss_data, self.acc, self.ppl, self.bleu) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def scorers(self): + """Scorers.""" + return dict(decoder=self.decoder) + + def encode(self, xs): + """Encode source sentences.""" + self.eval() + xs = torch.as_tensor(xs).unsqueeze(0) + enc_output, _ = self.encoder(xs, None) + return enc_output.squeeze(0) + + def target_forcing(self, xs_pad, ys_pad=None, tgt_lang=None): + """Prepend target language IDs to source sentences for multilingual MT. + + These tags are prepended in source/target sentences as pre-processing. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :return: source text without language IDs + :rtype: torch.Tensor + :return: target text without language IDs + :rtype: torch.Tensor + :return: target language IDs + :rtype: torch.Tensor (B, 1) + """ + if self.multilingual: + xs_pad = xs_pad[:, 1:] # remove source language IDs here + if ys_pad is not None: + # remove language ID in the beginning + lang_ids = ys_pad[:, 0].unsqueeze(1) + ys_pad = ys_pad[:, 1:] + elif tgt_lang is not None: + lang_ids = xs_pad.new_zeros(xs_pad.size(0), 1).fill_(tgt_lang) + else: + raise ValueError("Set ys_pad or tgt_lang.") + + # prepend target language ID to source sentences + xs_pad = torch.cat([lang_ids, xs_pad], dim=1) + return xs_pad, ys_pad + + def translate(self, x, trans_args, char_list=None): + """Translate source text. 
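+
+        Each result is a dict with the keys used in the beam search below,
+        e.g. {"score": -3.2, "yseq": [sos_id, tok_1, ..., eos_id]} (values
+        illustrative only), sorted by score in descending order.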
+ + :param list x: input source text feature (T,) + :param Namespace trans_args: argment Namespace contraining options + :param list char_list: list of characters + :return: N-best decoding results + :rtype: list + """ + self.eval() # NOTE: this is important because self.encode() is not used + assert isinstance(x, list) + + # make a utt list (1) to use the same interface for encoder + if self.multilingual: + x = to_device( + self, torch.from_numpy(np.fromiter(map(int, x[0][1:]), dtype=np.int64)) + ) + else: + x = to_device( + self, torch.from_numpy(np.fromiter(map(int, x[0]), dtype=np.int64)) + ) + + logging.info("input lengths: " + str(x.size(0))) + xs_pad = x.unsqueeze(0) + tgt_lang = None + if trans_args.tgt_lang: + tgt_lang = char_list.index(trans_args.tgt_lang) + xs_pad, _ = self.target_forcing(xs_pad, tgt_lang=tgt_lang) + h, _ = self.encoder(xs_pad, None) + logging.info("encoder output lengths: " + str(h.size(1))) + + # search parms + beam = trans_args.beam_size + penalty = trans_args.penalty + + if trans_args.maxlenratio == 0: + maxlen = h.size(1) + else: + # maxlen >= 1 + maxlen = max(1, int(trans_args.maxlenratio * h.size(1))) + minlen = int(trans_args.minlenratio * h.size(1)) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + hyp = {"score": 0.0, "yseq": [self.sos]} + hyps = [hyp] + ended_hyps = [] + + for i in range(maxlen): + logging.debug("position " + str(i)) + + # batchfy + ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64) + for j, hyp in enumerate(hyps): + ys[j, :] = torch.tensor(hyp["yseq"]) + ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device) + + local_scores = self.decoder.forward_one_step( + ys, ys_mask, h.repeat([len(hyps), 1, 1]) + )[0] + + hyps_best_kept = [] + for j, hyp in enumerate(hyps): + local_best_scores, local_best_ids = torch.topk( + local_scores[j : j + 1], beam, dim=1 + ) + + for j in range(beam): + new_hyp = {} + new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) + # will be (2 x beam) hyps at most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypothes: " + str(len(hyps))) + if char_list is not None: + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last postion in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypothes to a final list, and removed them from current hypothes + # (this will be a probmlem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remeined hypothes: " + str(len(hyps))) + else: + logging.info("no hypothesis. 
Finish decoding.") + break + + if char_list is not None: + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypothes: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), trans_args.nbest) + ] + + # check number of hypotheis + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform translation " + "again with smaller minlenratio." + ) + # should copy becasuse Namespace will be overwritten globally + trans_args = Namespace(**vars(trans_args)) + trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) + return self.translate(x, trans_args, char_list) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + return nbest_hyps + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad): + """E2E attention calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: attention weights (B, H, Lmax, Tmax) + :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad) + ret = dict() + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention) and m.attn is not None: + ret[name] = m.attn.cpu().numpy() + self.train() + return ret diff --git a/espnet/nets/pytorch_backend/e2e_st.py b/espnet/nets/pytorch_backend/e2e_st.py new file mode 100644 index 0000000000000000000000000000000000000000..f64e786f36180dcce6308cf5c859301fe14cceed --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_st.py @@ -0,0 +1,665 @@ +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""RNN sequence-to-sequence speech translation model (pytorch).""" + +import argparse +import copy +import logging +import math +import os + +import editdistance +import nltk + +import chainer +import numpy as np +import six +import torch + +from itertools import groupby + +from chainer import reporter + +from espnet.nets.e2e_asr_common import label_smoothing_dist +from espnet.nets.pytorch_backend.ctc import CTC +from espnet.nets.pytorch_backend.initialization import lecun_normal_init_parameters +from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.nets_utils import to_torch_tensor +from espnet.nets.pytorch_backend.rnn.argument import ( + add_arguments_rnn_encoder_common, # noqa: H301 + add_arguments_rnn_decoder_common, # noqa: H301 + add_arguments_rnn_attention_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.rnn.attentions import att_for +from espnet.nets.pytorch_backend.rnn.decoders import decoder_for +from espnet.nets.pytorch_backend.rnn.encoders import encoder_for +from espnet.nets.st_interface import STInterface +from espnet.utils.fill_missing_args import fill_missing_args + +CTC_LOSS_THRESHOLD = 10000 + + +class Reporter(chainer.Chain): + """A chainer reporter wrapper.""" + + def report( + self, + loss_asr, + loss_mt, + loss_st, + acc_asr, + acc_mt, + acc, + 
cer_ctc, + cer, + wer, + bleu, + mtl_loss, + ): + """Report at every step.""" + reporter.report({"loss_asr": loss_asr}, self) + reporter.report({"loss_mt": loss_mt}, self) + reporter.report({"loss_st": loss_st}, self) + reporter.report({"acc_asr": acc_asr}, self) + reporter.report({"acc_mt": acc_mt}, self) + reporter.report({"acc": acc}, self) + reporter.report({"cer_ctc": cer_ctc}, self) + reporter.report({"cer": cer}, self) + reporter.report({"wer": wer}, self) + reporter.report({"bleu": bleu}, self) + logging.info("mtl loss:" + str(mtl_loss)) + reporter.report({"loss": mtl_loss}, self) + + +class E2E(STInterface, torch.nn.Module): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2E.encoder_add_arguments(parser) + E2E.attention_add_arguments(parser) + E2E.decoder_add_arguments(parser) + return parser + + @staticmethod + def encoder_add_arguments(parser): + """Add arguments for the encoder.""" + group = parser.add_argument_group("E2E encoder setting") + group = add_arguments_rnn_encoder_common(group) + return parser + + @staticmethod + def attention_add_arguments(parser): + """Add arguments for the attention.""" + group = parser.add_argument_group("E2E attention setting") + group = add_arguments_rnn_attention_common(group) + return parser + + @staticmethod + def decoder_add_arguments(parser): + """Add arguments for the decoder.""" + group = parser.add_argument_group("E2E decoder setting") + group = add_arguments_rnn_decoder_common(group) + return parser + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.enc.conv_subsampling_factor * int(np.prod(self.subsample)) + + def __init__(self, idim, odim, args): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super(E2E, self).__init__() + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + self.asr_weight = args.asr_weight + self.mt_weight = args.mt_weight + self.mtlalpha = args.mtlalpha + assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)" + assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)" + assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" + self.etype = args.etype + self.verbose = args.verbose + # NOTE: for self.build method + args.char_list = getattr(args, "char_list", None) + self.char_list = args.char_list + self.outdir = args.outdir + self.space = args.sym_space + self.blank = args.sym_blank + self.reporter = Reporter() + + # below means the last number becomes eos/sos ID + # note that sos/eos IDs are identical + self.sos = odim - 1 + self.eos = odim - 1 + self.pad = 0 + # NOTE: we reserve index:0 for although this is reserved for a blank class + # in ASR. However, blank labels are not used in MT. + # To keep the vocabulary size, + # we use index:0 for padding instead of adding one more class. 
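+        # Illustrative example: with odim == 5, self.sos == self.eos == 4
+        # (the last vocabulary index) and self.pad == 0, i.e. index 0 doubles
+        # as the padding index as described in the NOTE above.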
+ + # subsample info + self.subsample = get_subsample(args, mode="st", arch="rnn") + + # label smoothing info + if args.lsm_type and os.path.isfile(args.train_json): + logging.info("Use label smoothing with " + args.lsm_type) + labeldist = label_smoothing_dist( + odim, args.lsm_type, transcript=args.train_json + ) + else: + labeldist = None + + # multilingual related + self.multilingual = getattr(args, "multilingual", False) + self.replace_sos = getattr(args, "replace_sos", False) + + # encoder + self.enc = encoder_for(args, idim, self.subsample) + # attention (ST) + self.att = att_for(args) + # decoder (ST) + self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) + + # submodule for ASR task + self.ctc = None + self.att_asr = None + self.dec_asr = None + if self.asr_weight > 0: + if self.mtlalpha > 0.0: + self.ctc = CTC( + odim, + args.eprojs, + args.dropout_rate, + ctc_type=args.ctc_type, + reduce=True, + ) + if self.mtlalpha < 1.0: + # attention (asr) + self.att_asr = att_for(args) + # decoder (asr) + args_asr = copy.deepcopy(args) + args_asr.atype = "location" # TODO(hirofumi0810): make this option + self.dec_asr = decoder_for( + args_asr, odim, self.sos, self.eos, self.att_asr, labeldist + ) + + # submodule for MT task + if self.mt_weight > 0: + self.embed_mt = torch.nn.Embedding(odim, args.eunits, padding_idx=self.pad) + self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate) + self.enc_mt = encoder_for( + args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=np.int) + ) + + # weight initialization + self.init_like_chainer() + + # options for beam search + if self.asr_weight > 0 and args.report_cer or args.report_wer: + recog_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": args.ctc_weight, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + "tgt_lang": False, + } + + self.recog_args = argparse.Namespace(**recog_args) + self.report_cer = args.report_cer + self.report_wer = args.report_wer + else: + self.report_cer = False + self.report_wer = False + if args.report_bleu: + trans_args = { + "beam_size": args.beam_size, + "penalty": args.penalty, + "ctc_weight": 0, + "maxlenratio": args.maxlenratio, + "minlenratio": args.minlenratio, + "lm_weight": args.lm_weight, + "rnnlm": args.rnnlm, + "nbest": args.nbest, + "space": args.sym_space, + "blank": args.sym_blank, + "tgt_lang": False, + } + + self.trans_args = argparse.Namespace(**trans_args) + self.report_bleu = args.report_bleu + else: + self.report_bleu = False + self.rnnlm = None + + self.logzero = -10000000000.0 + self.loss = None + self.acc = None + + def init_like_chainer(self): + """Initialize weight like chainer. + + chainer basically uses LeCun way: W ~ Normal(0, fan_in ** -0.5), b = 0 + pytorch basically uses W, b ~ Uniform(-fan_in**-0.5, fan_in**-0.5) + however, there are two exceptions as far as I know. + - EmbedID.W ~ Normal(0, 1) + - LSTM.upward.b[forget_gate_range] = 1 (but not used in NStepLSTM) + """ + lecun_normal_init_parameters(self) + # exceptions + # embed weight ~ Normal(0, 1) + self.dec.embed.weight.data.normal_(0, 1) + # forget-bias = 1.0 + # https://discuss.pytorch.org/t/set-forget-gate-bias-of-lstm/1745 + for i in six.moves.range(len(self.dec.decoder)): + set_forget_bias_to_one(self.dec.decoder[i].bias_ih) + + def forward(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E forward. 
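+
+        The returned loss combines the three tasks, roughly:
+        (1 - asr_weight - mt_weight) * loss_st
+        + asr_weight * (mtlalpha * loss_asr_ctc + (1 - mtlalpha) * loss_asr_att)
+        + mt_weight * loss_mt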
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :return: loss value + :rtype: torch.Tensor + """ + # 0. Extract target language ID + if self.multilingual: + tgt_lang_ids = ys_pad[:, 0:1] + ys_pad = ys_pad[:, 1:] # remove target language ID in the beggining + else: + tgt_lang_ids = None + + # 1. Encoder + hs_pad, hlens, _ = self.enc(xs_pad, ilens) + + # 2. ST attention loss + self.loss_st, self.acc, _ = self.dec( + hs_pad, hlens, ys_pad, lang_ids=tgt_lang_ids + ) + + # 3. ASR loss + ( + self.loss_asr_att, + acc_asr, + self.loss_asr_ctc, + cer_ctc, + cer, + wer, + ) = self.forward_asr(hs_pad, hlens, ys_pad_src) + + # 4. MT attention loss + self.loss_mt, acc_mt = self.forward_mt(ys_pad, ys_pad_src) + + # 5. Compute BLEU + if self.training or not self.report_bleu: + self.bleu = 0.0 + else: + lpz = None + + nbest_hyps = self.dec.recognize_beam_batch( + hs_pad, + torch.tensor(hlens), + lpz, + self.trans_args, + self.char_list, + self.rnnlm, + lang_ids=tgt_lang_ids.squeeze(1).tolist() + if self.multilingual + else None, + ) + # remove and + list_of_refs = [] + hyps = [] + y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps] + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [self.char_list[int(idx)] for idx in y_hat if int(idx) != -1] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.trans_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.trans_args.blank, "") + seq_true_text = "".join(seq_true).replace(self.trans_args.space, " ") + + hyps += [seq_hat_text.split(" ")] + list_of_refs += [[seq_true_text.split(" ")]] + + self.bleu = nltk.bleu_score.corpus_bleu(list_of_refs, hyps) * 100 + + asr_ctc_weight = self.mtlalpha + self.loss_asr = ( + asr_ctc_weight * self.loss_asr_ctc + + (1 - asr_ctc_weight) * self.loss_asr_att + ) + self.loss = ( + (1 - self.asr_weight - self.mt_weight) * self.loss_st + + self.asr_weight * self.loss_asr + + self.mt_weight * self.loss_mt + ) + loss_st_data = float(self.loss_st) + loss_asr_data = float(self.loss_asr) + loss_mt_data = float(self.loss_mt) + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_asr_data, + loss_mt_data, + loss_st_data, + acc_asr, + acc_mt, + self.acc, + cer_ctc, + cer, + wer, + self.bleu, + loss_data, + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def forward_asr(self, hs_pad, hlens, ys_pad): + """Forward pass in the auxiliary ASR task. 
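+
+        When asr_weight == 0 this is a no-op: the losses are returned as 0.0
+        and the accuracy / error-rate values as None.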
+ + :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor hlens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :return: ASR attention loss value + :rtype: torch.Tensor + :return: accuracy in ASR attention decoder + :rtype: float + :return: ASR CTC loss value + :rtype: torch.Tensor + :return: character error rate from CTC prediction + :rtype: float + :return: character error rate from attetion decoder prediction + :rtype: float + :return: word error rate from attetion decoder prediction + :rtype: float + """ + loss_att, loss_ctc = 0.0, 0.0 + acc = None + cer, wer = None, None + cer_ctc = None + if self.asr_weight == 0: + return loss_att, acc, loss_ctc, cer_ctc, cer, wer + + # attention + if self.mtlalpha < 1: + loss_asr, acc_asr, _ = self.dec_asr(hs_pad, hlens, ys_pad) + + # Compute wer and cer + if not self.training and (self.report_cer or self.report_wer): + if self.mtlalpha > 0 and self.recog_args.ctc_weight > 0.0: + lpz = self.ctc.log_softmax(hs_pad).data + else: + lpz = None + + word_eds, word_ref_lens, char_eds, char_ref_lens = [], [], [], [] + nbest_hyps_asr = self.dec_asr.recognize_beam_batch( + hs_pad, + torch.tensor(hlens), + lpz, + self.recog_args, + self.char_list, + self.rnnlm, + ) + # remove and + y_hats = [nbest_hyp[0]["yseq"][1:-1] for nbest_hyp in nbest_hyps_asr] + for i, y_hat in enumerate(y_hats): + y_true = ys_pad[i] + + seq_hat = [ + self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 + ] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.recog_args.space, " ") + seq_hat_text = seq_hat_text.replace(self.recog_args.blank, "") + seq_true_text = "".join(seq_true).replace( + self.recog_args.space, " " + ) + + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + wer = ( + 0.0 + if not self.report_wer + else float(sum(word_eds)) / sum(word_ref_lens) + ) + cer = ( + 0.0 + if not self.report_cer + else float(sum(char_eds)) / sum(char_ref_lens) + ) + + # CTC + if self.mtlalpha > 0: + loss_ctc = self.ctc(hs_pad, hlens, ys_pad) + + # Compute cer with CTC prediction + if self.char_list is not None: + cers = [] + y_hats = self.ctc.argmax(hs_pad).data + for i, y in enumerate(y_hats): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + + seq_hat = [ + self.char_list[int(idx)] for idx in y_hat if int(idx) != -1 + ] + seq_true = [ + self.char_list[int(idx)] for idx in y_true if int(idx) != -1 + ] + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + if len(ref_chars) > 0: + cers.append( + editdistance.eval(hyp_chars, ref_chars) / len(ref_chars) + ) + cer_ctc = sum(cers) / len(cers) if cers else None + + return loss_att, acc, loss_ctc, cer_ctc, cer, wer + + def forward_mt(self, xs_pad, ys_pad): + """Forward pass in the auxiliary MT task. 
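+
+        Note that xs_pad here holds the source transcription padded with -1;
+        it is re-padded with self.pad (0) and embedded before being fed to the
+        auxiliary MT encoder. (0.0, 0.0) is returned when mt_weight == 0.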
+ + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :return: MT loss value + :rtype: torch.Tensor + :return: accuracy in MT decoder + :rtype: float + """ + loss = 0.0 + acc = 0.0 + if self.mt_weight == 0: + return loss, acc + + ilens = torch.sum(xs_pad != -1, dim=1).cpu().numpy() + # NOTE: xs_pad is padded with -1 + ys_src = [y[y != -1] for y in xs_pad] # parse padded ys_src + xs_zero_pad = pad_list(ys_src, self.pad) # re-pad with zero + hs_pad, hlens, _ = self.enc_mt( + self.dropout_mt(self.embed_mt(xs_zero_pad)), ilens + ) + loss, acc, _ = self.dec(hs_pad, hlens, ys_pad) + return loss, acc + + def scorers(self): + """Scorers.""" + return dict(decoder=self.dec) + + def encode(self, x): + """Encode acoustic features. + + :param ndarray x: input acoustic feature (T, D) + :return: encoder outputs + :rtype: torch.Tensor + """ + self.eval() + ilens = [x.shape[0]] + + # subsample frame + x = x[:: self.subsample[0], :] + p = next(self.parameters()) + h = torch.as_tensor(x, device=p.device, dtype=p.dtype) + # make a utt list (1) to use the same interface for encoder + hs = h.contiguous().unsqueeze(0) + + # 1. encoder + hs, _, _ = self.enc(hs, ilens) + return hs.squeeze(0) + + def translate(self, x, trans_args, char_list, rnnlm=None): + """E2E beam search. + + :param ndarray x: input acoustic feature (T, D) + :param Namespace trans_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + logging.info("input lengths: " + str(x.shape[0])) + hs = self.encode(x).unsqueeze(0) + logging.info("encoder output lengths: " + str(hs.size(1))) + + # 2. Decoder + # decode the first utterance + y = self.dec.recognize_beam(hs[0], None, trans_args, char_list, rnnlm) + return y + + def translate_batch(self, xs, trans_args, char_list, rnnlm=None): + """E2E batch beam search. + + :param list xs: list of input acoustic feature arrays [(T_1, D), (T_2, D), ...] + :param Namespace trans_args: argument Namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + prev = self.training + self.eval() + ilens = np.fromiter((xx.shape[0] for xx in xs), dtype=np.int64) + + # subsample frame + xs = [xx[:: self.subsample[0], :] for xx in xs] + xs = [to_device(self, to_torch_tensor(xx).float()) for xx in xs] + xs_pad = pad_list(xs, 0.0) + + # 1. Encoder + hs_pad, hlens, _ = self.enc(xs_pad, ilens) + + # 2. Decoder + hlens = torch.tensor(list(map(int, hlens))) # make sure hlens is tensor + y = self.dec.recognize_beam_batch( + hs_pad, hlens, None, trans_args, char_list, rnnlm + ) + + if prev: + self.train() + return y + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E attention calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :param torch.Tensor ys_pad_src: + batch of padded token id sequence tensor (B, Lmax) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) other case => attention weights (B, Lmax, Tmax). 
+ :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + # 1. Encoder + if self.multilingual: + tgt_lang_ids = ys_pad[:, 0:1] + ys_pad = ys_pad[:, 1:] # remove target language ID in the beggining + else: + tgt_lang_ids = None + hpad, hlens, _ = self.enc(xs_pad, ilens) + + # 2. Decoder + att_ws = self.dec.calculate_all_attentions( + hpad, hlens, ys_pad, lang_ids=tgt_lang_ids + ) + self.train() + return att_ws + + def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E CTC probability calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :param torch.Tensor + ys_pad_src: batch of padded token id sequence tensor (B, Lmax) + :return: CTC probability (B, Tmax, vocab) + :rtype: float ndarray + """ + probs = None + if self.asr_weight == 0 or self.mtlalpha == 0: + return probs + + self.eval() + with torch.no_grad(): + # 1. Encoder + hpad, hlens, _ = self.enc(xs_pad, ilens) + + # 2. CTC probs + probs = self.ctc.softmax(hpad).cpu().numpy() + self.train() + return probs + + def subsample_frames(self, x): + """Subsample speeh frames in the encoder.""" + # subsample frame + x = x[:: self.subsample[0], :] + ilen = [x.shape[0]] + h = to_device(self, torch.from_numpy(np.array(x, dtype=np.float32))) + h.contiguous() + return h, ilen diff --git a/espnet/nets/pytorch_backend/e2e_st_conformer.py b/espnet/nets/pytorch_backend/e2e_st_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f34bb1f598a3ffe81e6f0a8b64a4ff89194ff653 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_st_conformer.py @@ -0,0 +1,74 @@ +# Copyright 2020 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +""" +Conformer speech translation model (pytorch). + +It is a fusion of `e2e_st_transformer.py` +Refer to: https://arxiv.org/abs/2005.08100 + +""" + +from espnet.nets.pytorch_backend.conformer.encoder import Encoder +from espnet.nets.pytorch_backend.e2e_st_transformer import E2E as E2ETransformer +from espnet.nets.pytorch_backend.conformer.argument import ( + add_arguments_conformer_common, # noqa: H301 + verify_rel_pos_type, # noqa: H301 +) + + +class E2E(E2ETransformer): + """E2E module. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + E2ETransformer.add_arguments(parser) + E2E.add_conformer_arguments(parser) + return parser + + @staticmethod + def add_conformer_arguments(parser): + """Add arguments for conformer model.""" + group = parser.add_argument_group("conformer model specific setting") + group = add_arguments_conformer_common(group) + return parser + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. 
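+
+        Compared with the Transformer ST model, only the speech encoder is
+        replaced by a Conformer encoder (built below); the decoder and the
+        auxiliary ASR/MT branches are inherited from E2ETransformer.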
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + super().__init__(idim, odim, args, ignore_id) + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + + # Check the relative positional encoding type + args = verify_rel_pos_type(args) + + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + pos_enc_layer_type=args.transformer_encoder_pos_enc_layer_type, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + activation_type=args.transformer_encoder_activation_type, + macaron_style=args.macaron_style, + use_cnn_module=args.use_cnn_module, + cnn_module_kernel=args.cnn_module_kernel, + ) + self.reset_parameters(args) diff --git a/espnet/nets/pytorch_backend/e2e_st_transformer.py b/espnet/nets/pytorch_backend/e2e_st_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..64c3b7dcc471b579553d50a24a212984faf9b743 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_st_transformer.py @@ -0,0 +1,586 @@ +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer speech recognition model (pytorch).""" + +from argparse import Namespace +import logging +import math +import numpy + +import torch + +from espnet.nets.e2e_asr_common import end_detect +from espnet.nets.e2e_asr_common import ErrorCalculator as ASRErrorCalculator +from espnet.nets.e2e_mt_common import ErrorCalculator as MTErrorCalculator +from espnet.nets.pytorch_backend.ctc import CTC +from espnet.nets.pytorch_backend.e2e_asr import CTC_LOSS_THRESHOLD +from espnet.nets.pytorch_backend.e2e_st import Reporter +from espnet.nets.pytorch_backend.nets_utils import get_subsample +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.argument import ( + add_arguments_transformer_common, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.mask import target_mask +from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport +from espnet.nets.st_interface import STInterface +from espnet.utils.fill_missing_args import fill_missing_args + + +class E2E(STInterface, torch.nn.Module): + """E2E module. 
+ + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments.""" + group = parser.add_argument_group("transformer model setting") + group = add_arguments_transformer_common(group) + return parser + + @property + def attention_plot_class(self): + """Return PlotAttentionReport.""" + return PlotAttentionReport + + def get_total_subsampling_factor(self): + """Get total subsampling factor.""" + return self.encoder.conv_subsampling_factor * int(numpy.prod(self.subsample)) + + def __init__(self, idim, odim, args, ignore_id=-1): + """Construct an E2E object. + + :param int idim: dimension of inputs + :param int odim: dimension of outputs + :param Namespace args: argument Namespace containing options + """ + torch.nn.Module.__init__(self) + + # fill missing arguments for compatibility + args = fill_missing_args(args, self.add_arguments) + + if args.transformer_attn_dropout_rate is None: + args.transformer_attn_dropout_rate = args.dropout_rate + self.encoder = Encoder( + idim=idim, + selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_encoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=args.transformer_input_layer, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + self.decoder = Decoder( + odim=odim, + selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, + attention_dim=args.adim, + attention_heads=args.aheads, + conv_wshare=args.wshare, + conv_kernel_length=args.ldconv_decoder_kernel_length, + conv_usebias=args.ldconv_usebias, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + self_attention_dropout_rate=args.transformer_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + self.pad = 0 # use for padding + self.sos = odim - 1 + self.eos = odim - 1 + self.odim = odim + self.ignore_id = ignore_id + self.subsample = get_subsample(args, mode="st", arch="transformer") + self.reporter = Reporter() + + self.criterion = LabelSmoothingLoss( + self.odim, + self.ignore_id, + args.lsm_weight, + args.transformer_length_normalized_loss, + ) + # submodule for ASR task + self.mtlalpha = args.mtlalpha + self.asr_weight = args.asr_weight + if self.asr_weight > 0 and args.mtlalpha < 1: + self.decoder_asr = Decoder( + odim=odim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + self_attention_dropout_rate=args.transformer_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_attn_dropout_rate, + ) + + # submodule for MT task + self.mt_weight = args.mt_weight + if self.mt_weight > 0: + self.encoder_mt = Encoder( + idim=odim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.dunits, + num_blocks=args.dlayers, + input_layer="embed", + dropout_rate=args.dropout_rate, + positional_dropout_rate=args.dropout_rate, + attention_dropout_rate=args.transformer_attn_dropout_rate, + padding_idx=0, + ) + self.reset_parameters(args) # NOTE: place after the 
submodule initialization + self.adim = args.adim # used for CTC (equal to d_model) + if self.asr_weight > 0 and args.mtlalpha > 0.0: + self.ctc = CTC( + odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True + ) + else: + self.ctc = None + + # translation error calculator + self.error_calculator = MTErrorCalculator( + args.char_list, args.sym_space, args.sym_blank, args.report_bleu + ) + + # recognition error calculator + self.error_calculator_asr = ASRErrorCalculator( + args.char_list, + args.sym_space, + args.sym_blank, + args.report_cer, + args.report_wer, + ) + self.rnnlm = None + + # multilingual E2E-ST related + self.multilingual = getattr(args, "multilingual", False) + self.replace_sos = getattr(args, "replace_sos", False) + + def reset_parameters(self, args): + """Initialize parameters.""" + initialize(self, args.transformer_init) + if self.mt_weight > 0: + torch.nn.init.normal_( + self.encoder_mt.embed[0].weight, mean=0, std=args.adim ** -0.5 + ) + torch.nn.init.constant_(self.encoder_mt.embed[0].weight[self.pad], 0) + + def forward(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E forward. + + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of source sequences (B) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :param torch.Tensor ys_pad_src: batch of padded target sequences (B, Lmax) + :return: ctc loss value + :rtype: torch.Tensor + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy in attention decoder + :rtype: float + """ + # 0. Extract target language ID + tgt_lang_ids = None + if self.multilingual: + tgt_lang_ids = ys_pad[:, 0:1] + ys_pad = ys_pad[:, 1:] # remove target language ID in the beggining + + # 1. forward encoder + xs_pad = xs_pad[:, : max(ilens)] # for data parallel + src_mask = make_non_pad_mask(ilens.tolist()).to(xs_pad.device).unsqueeze(-2) + hs_pad, hs_mask = self.encoder(xs_pad, src_mask) + + # 2. forward decoder + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + # replace with target language ID + if self.replace_sos: + ys_in_pad = torch.cat([tgt_lang_ids, ys_in_pad[:, 1:]], dim=1) + ys_mask = target_mask(ys_in_pad, self.ignore_id) + pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + + # 3. compute ST loss + loss_att = self.criterion(pred_pad, ys_out_pad) + + self.acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + + # 4. compute corpus-level bleu in a mini-batch + if self.training: + self.bleu = None + else: + ys_hat = pred_pad.argmax(dim=-1) + self.bleu = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + # 5. compute auxiliary ASR loss + loss_asr_att, acc_asr, loss_asr_ctc, cer_ctc, cer, wer = self.forward_asr( + hs_pad, hs_mask, ys_pad_src + ) + + # 6. 
compute auxiliary MT loss + loss_mt, acc_mt = 0.0, None + if self.mt_weight > 0: + loss_mt, acc_mt = self.forward_mt( + ys_pad_src, ys_in_pad, ys_out_pad, ys_mask + ) + + asr_ctc_weight = self.mtlalpha + self.loss = ( + (1 - self.asr_weight - self.mt_weight) * loss_att + + self.asr_weight + * (asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att) + + self.mt_weight * loss_mt + ) + loss_asr_data = float( + asr_ctc_weight * loss_asr_ctc + (1 - asr_ctc_weight) * loss_asr_att + ) + loss_mt_data = None if self.mt_weight == 0 else float(loss_mt) + loss_st_data = float(loss_att) + + loss_data = float(self.loss) + if loss_data < CTC_LOSS_THRESHOLD and not math.isnan(loss_data): + self.reporter.report( + loss_asr_data, + loss_mt_data, + loss_st_data, + acc_asr, + acc_mt, + self.acc, + cer_ctc, + cer, + wer, + self.bleu, + loss_data, + ) + else: + logging.warning("loss (=%f) is not correct", loss_data) + return self.loss + + def forward_asr(self, hs_pad, hs_mask, ys_pad): + """Forward pass in the auxiliary ASR task. + + :param torch.Tensor hs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor hs_mask: batch of input token mask (B, Lmax) + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :return: ASR attention loss value + :rtype: torch.Tensor + :return: accuracy in ASR attention decoder + :rtype: float + :return: ASR CTC loss value + :rtype: torch.Tensor + :return: character error rate from CTC prediction + :rtype: float + :return: character error rate from attetion decoder prediction + :rtype: float + :return: word error rate from attetion decoder prediction + :rtype: float + """ + loss_att, loss_ctc = 0.0, 0.0 + acc = None + cer, wer = None, None + cer_ctc = None + if self.asr_weight == 0: + return loss_att, acc, loss_ctc, cer_ctc, cer, wer + + # attention + if self.mtlalpha < 1: + ys_in_pad_asr, ys_out_pad_asr = add_sos_eos( + ys_pad, self.sos, self.eos, self.ignore_id + ) + ys_mask_asr = target_mask(ys_in_pad_asr, self.ignore_id) + pred_pad, _ = self.decoder_asr(ys_in_pad_asr, ys_mask_asr, hs_pad, hs_mask) + loss_att = self.criterion(pred_pad, ys_out_pad_asr) + + acc = th_accuracy( + pred_pad.view(-1, self.odim), + ys_out_pad_asr, + ignore_label=self.ignore_id, + ) + if not self.training: + ys_hat_asr = pred_pad.argmax(dim=-1) + cer, wer = self.error_calculator_asr(ys_hat_asr.cpu(), ys_pad.cpu()) + + # CTC + if self.mtlalpha > 0: + batch_size = hs_pad.size(0) + hs_len = hs_mask.view(batch_size, -1).sum(1) + loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) + if not self.training: + ys_hat_ctc = self.ctc.argmax( + hs_pad.view(batch_size, -1, self.adim) + ).data + cer_ctc = self.error_calculator_asr( + ys_hat_ctc.cpu(), ys_pad.cpu(), is_ctc=True + ) + # for visualization + self.ctc.softmax(hs_pad) + return loss_att, acc, loss_ctc, cer_ctc, cer, wer + + def forward_mt(self, xs_pad, ys_in_pad, ys_out_pad, ys_mask): + """Forward pass in the auxiliary MT task. 
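+
+        As in the RNN variant, xs_pad holds the source transcription padded
+        with ignore_id (-1) and is re-padded with self.pad before the MT
+        encoder; (0.0, None) is returned when mt_weight == 0.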
+ + :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) + :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax) + :param torch.Tensor ys_out_pad: batch of padded target sequences (B, Lmax) + :param torch.Tensor ys_mask: batch of input token mask (B, Lmax) + :return: MT loss value + :rtype: torch.Tensor + :return: accuracy in MT decoder + :rtype: float + """ + loss, acc = 0.0, None + if self.mt_weight == 0: + return loss, acc + + ilens = torch.sum(xs_pad != self.ignore_id, dim=1).cpu().numpy() + # NOTE: xs_pad is padded with -1 + xs = [x[x != self.ignore_id] for x in xs_pad] # parse padded xs + xs_zero_pad = pad_list(xs, self.pad) # re-pad with zero + xs_zero_pad = xs_zero_pad[:, : max(ilens)] # for data parallel + src_mask = ( + make_non_pad_mask(ilens.tolist()).to(xs_zero_pad.device).unsqueeze(-2) + ) + hs_pad, hs_mask = self.encoder_mt(xs_zero_pad, src_mask) + pred_pad, _ = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) + loss = self.criterion(pred_pad, ys_out_pad) + acc = th_accuracy( + pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id + ) + return loss, acc + + def scorers(self): + """Scorers.""" + return dict(decoder=self.decoder) + + def encode(self, x): + """Encode source acoustic features. + + :param ndarray x: source acoustic feature (T, D) + :return: encoder outputs + :rtype: torch.Tensor + """ + self.eval() + x = torch.as_tensor(x).unsqueeze(0) + enc_output, _ = self.encoder(x, None) + return enc_output.squeeze(0) + + def translate( + self, + x, + trans_args, + char_list=None, + ): + """Translate input speech. + + :param ndnarray x: input acoustic feature (B, T, D) or (T, D) + :param Namespace trans_args: argment Namespace contraining options + :param list char_list: list of characters + :return: N-best decoding results + :rtype: list + """ + # preprate sos + if getattr(trans_args, "tgt_lang", False): + if self.replace_sos: + y = char_list.index(trans_args.tgt_lang) + else: + y = self.sos + logging.info(" index: " + str(y)) + logging.info(" mark: " + char_list[y]) + logging.info("input lengths: " + str(x.shape[0])) + + enc_output = self.encode(x).unsqueeze(0) + + h = enc_output + + logging.info("encoder output lengths: " + str(h.size(1))) + # search parms + beam = trans_args.beam_size + penalty = trans_args.penalty + + if trans_args.maxlenratio == 0: + maxlen = h.size(1) + else: + # maxlen >= 1 + maxlen = max(1, int(trans_args.maxlenratio * h.size(1))) + minlen = int(trans_args.minlenratio * h.size(1)) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + hyp = {"score": 0.0, "yseq": [y]} + hyps = [hyp] + ended_hyps = [] + + for i in range(maxlen): + logging.debug("position " + str(i)) + + # batchfy + ys = h.new_zeros((len(hyps), i + 1), dtype=torch.int64) + for j, hyp in enumerate(hyps): + ys[j, :] = torch.tensor(hyp["yseq"]) + ys_mask = subsequent_mask(i + 1).unsqueeze(0).to(h.device) + + local_scores = self.decoder.forward_one_step( + ys, ys_mask, h.repeat([len(hyps), 1, 1]) + )[0] + + hyps_best_kept = [] + for j, hyp in enumerate(hyps): + local_best_scores, local_best_ids = torch.topk( + local_scores[j : j + 1], beam, dim=1 + ) + + for j in range(beam): + new_hyp = {} + new_hyp["score"] = hyp["score"] + float(local_best_scores[0, j]) + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) + # will be (2 x beam) hyps at 
most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypothes: " + str(len(hyps))) + if char_list is not None: + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last postion in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypothes to a final list, and removed them from current hypothes + # (this will be a probmlem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and trans_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remeined hypothes: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + + if char_list is not None: + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypothes: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), trans_args.nbest) + ] + + # check number of hypotheis + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, perform translation " + "again with smaller minlenratio." + ) + # should copy becasuse Namespace will be overwritten globally + trans_args = Namespace(**vars(trans_args)) + trans_args.minlenratio = max(0.0, trans_args.minlenratio - 0.1) + return self.translate(x, trans_args, char_list) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + return nbest_hyps + + def calculate_all_attentions(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E attention calculation. + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :param torch.Tensor ys_pad_src: + batch of padded token id sequence tensor (B, Lmax) + :return: attention weights (B, H, Lmax, Tmax) + :rtype: float ndarray + """ + self.eval() + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad, ys_pad_src) + ret = dict() + for name, m in self.named_modules(): + if ( + isinstance(m, MultiHeadedAttention) and m.attn is not None + ): # skip MHA for submodules + ret[name] = m.attn.cpu().numpy() + self.train() + return ret + + def calculate_all_ctc_probs(self, xs_pad, ilens, ys_pad, ys_pad_src): + """E2E CTC probability calculation. 
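The `translate` loop above expands every live hypothesis with its top-`beam` next tokens and then keeps only the best `beam` candidates before moving to the next position. A stripped-down sketch of that expand-and-prune step (illustrative only; the real loop also handles `<eos>`, the length penalty and end detection, and the `<sos>` id below is an assumption):

import torch

def beam_step(hyps, log_probs, beam):
    """hyps: list of {"score": float, "yseq": [token ids]};
    log_probs: (len(hyps), vocab) next-token log-probabilities."""
    candidates = []
    for j, hyp in enumerate(hyps):
        best_scores, best_ids = torch.topk(log_probs[j], beam)
        for score, idx in zip(best_scores.tolist(), best_ids.tolist()):
            candidates.append({"score": hyp["score"] + score,
                               "yseq": hyp["yseq"] + [idx]})
    # keep the best `beam` of the len(hyps) * beam expansions
    return sorted(candidates, key=lambda h: h["score"], reverse=True)[:beam]

hyps = [{"score": 0.0, "yseq": [1]}]                    # assumed <sos> id
log_probs = torch.log_softmax(torch.randn(len(hyps), 10), dim=-1)
hyps = beam_step(hyps, log_probs, beam=3)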
+ + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor ys_pad: batch of padded token id sequence tensor (B, Lmax) + :param torch.Tensor ys_pad_src: + batch of padded token id sequence tensor (B, Lmax) + :return: CTC probability (B, Tmax, vocab) + :rtype: float ndarray + """ + ret = None + if self.asr_weight == 0 or self.mtlalpha == 0: + return ret + + self.eval() + with torch.no_grad(): + self.forward(xs_pad, ilens, ys_pad, ys_pad_src) + ret = None + for name, m in self.named_modules(): + if isinstance(m, CTC) and m.probs is not None: + ret = m.probs.cpu().numpy() + self.train() + return ret diff --git a/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a3069e53c30cfa21cd93202c14dd2a6f4e31d6 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_tts_fastspeech.py @@ -0,0 +1,899 @@ +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""FastSpeech related modules.""" + +import logging + +import torch +import torch.nn.functional as F + +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import torch_load +from espnet.nets.pytorch_backend.fastspeech.duration_calculator import ( + DurationCalculator, # noqa: H301 +) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( + DurationPredictorLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.fill_missing_args import fill_missing_args + + +class FeedForwardTransformerLoss(torch.nn.Module): + """Loss function module for feed-forward Transformer.""" + + def __init__(self, use_masking=True, use_weighted_masking=False): + """Initialize feed-forward Transformer loss module. + + Args: + use_masking (bool): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool): + Whether to weighted masking in loss calculation. + + """ + super(FeedForwardTransformerLoss, self).__init__() + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.l1_criterion = torch.nn.L1Loss(reduction=reduction) + self.duration_criterion = DurationPredictorLoss(reduction=reduction) + + def forward(self, after_outs, before_outs, d_outs, ys, ds, ilens, olens): + """Calculate forward propagation. + + Args: + after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). 
+ before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). + d_outs (Tensor): Batch of outputs of duration predictor (B, Tmax). + ys (Tensor): Batch of target features (B, Lmax, odim). + ds (Tensor): Batch of durations (B, Tmax). + ilens (LongTensor): Batch of the lengths of each input (B,). + olens (LongTensor): Batch of the lengths of each target (B,). + + Returns: + Tensor: L1 loss value. + Tensor: Duration predictor loss value. + + """ + # apply mask to remove padded part + if self.use_masking: + duration_masks = make_non_pad_mask(ilens).to(ys.device) + d_outs = d_outs.masked_select(duration_masks) + ds = ds.masked_select(duration_masks) + out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + before_outs = before_outs.masked_select(out_masks) + after_outs = ( + after_outs.masked_select(out_masks) if after_outs is not None else None + ) + ys = ys.masked_select(out_masks) + + # calculate loss + l1_loss = self.l1_criterion(before_outs, ys) + if after_outs is not None: + l1_loss += self.l1_criterion(after_outs, ys) + duration_loss = self.duration_criterion(d_outs, ds) + + # make weighted mask and apply it + if self.use_weighted_masking: + out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() + out_weights /= ys.size(0) * ys.size(2) + duration_masks = make_non_pad_mask(ilens).to(ys.device) + duration_weights = ( + duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float() + ) + duration_weights /= ds.size(0) + + # apply weight + l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() + duration_loss = ( + duration_loss.mul(duration_weights).masked_select(duration_masks).sum() + ) + + return l1_loss, duration_loss + + +class FeedForwardTransformer(TTSInterface, torch.nn.Module): + """Feed Forward Transformer for TTS a.k.a. FastSpeech. + + This is a module of FastSpeech, + feed-forward Transformer with duration predictor described in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_, + which does not require any auto-regressive + processing during inference, + resulting in fast decoding compared with auto-regressive Transformer. + + .. 
_`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + @staticmethod + def add_arguments(parser): + """Add model-specific arguments to the parser.""" + group = parser.add_argument_group("feed-forward transformer model setting") + # network structure related + group.add_argument( + "--adim", + default=384, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + group.add_argument( + "--elayers", default=6, type=int, help="Number of encoder layers" + ) + group.add_argument( + "--eunits", default=1536, type=int, help="Number of encoder hidden units" + ) + group.add_argument( + "--dlayers", default=6, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=1536, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--positionwise-layer-type", + default="linear", + type=str, + choices=["linear", "conv1d", "conv1d-linear"], + help="Positionwise layer type.", + ) + group.add_argument( + "--positionwise-conv-kernel-size", + default=3, + type=int, + help="Kernel size of positionwise conv1d layer", + ) + group.add_argument( + "--postnet-layers", default=0, type=int, help="Number of postnet layers" + ) + group.add_argument( + "--postnet-chans", default=256, type=int, help="Number of postnet channels" + ) + group.add_argument( + "--postnet-filts", default=5, type=int, help="Filter size of postnet" + ) + group.add_argument( + "--use-batch-norm", + default=True, + type=strtobool, + help="Whether to use batch normalization", + ) + group.add_argument( + "--use-scaled-pos-enc", + default=True, + type=strtobool, + help="Use trainable scaled positional encoding " + "instead of the fixed scale one", + ) + group.add_argument( + "--encoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before encoder block", + ) + group.add_argument( + "--decoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before decoder block", + ) + group.add_argument( + "--encoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in encoder", + ) + group.add_argument( + "--decoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in decoder", + ) + group.add_argument( + "--duration-predictor-layers", + default=2, + type=int, + help="Number of layers in duration predictor", + ) + group.add_argument( + "--duration-predictor-chans", + default=384, + type=int, + help="Number of channels in duration predictor", + ) + group.add_argument( + "--duration-predictor-kernel-size", + default=3, + type=int, + help="Kernel size in duration predictor", + ) + group.add_argument( + "--teacher-model", + default=None, + type=str, + nargs="?", + help="Teacher model file path", + ) + group.add_argument( + "--reduction-factor", default=1, type=int, help="Reduction factor" + ) + group.add_argument( + "--spk-embed-dim", + default=None, + type=int, + help="Number of speaker embedding dimensions", + ) + group.add_argument( + "--spk-embed-integration-type", + type=str, + default="add", + choices=["add", "concat"], + help="How to integrate speaker embedding", + ) + # training related + group.add_argument( + "--transformer-init", + type=str, + default="pytorch", + choices=[ + "pytorch", + "xavier_uniform", + 
"xavier_normal", + "kaiming_uniform", + "kaiming_normal", + ], + help="How to initialize transformer parameters", + ) + group.add_argument( + "--initial-encoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in encoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--initial-decoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in decoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--transformer-lr", + default=1.0, + type=float, + help="Initial value of learning rate", + ) + group.add_argument( + "--transformer-warmup-steps", + default=4000, + type=int, + help="Optimizer warmup steps", + ) + group.add_argument( + "--transformer-enc-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder except for attention", + ) + group.add_argument( + "--transformer-enc-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder positional encoding", + ) + group.add_argument( + "--transformer-enc-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder self-attention", + ) + group.add_argument( + "--transformer-dec-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder except " + "for attention and pos encoding", + ) + group.add_argument( + "--transformer-dec-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder positional encoding", + ) + group.add_argument( + "--transformer-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder self-attention", + ) + group.add_argument( + "--transformer-enc-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder-decoder attention", + ) + group.add_argument( + "--duration-predictor-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for duration predictor", + ) + group.add_argument( + "--postnet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in postnet", + ) + group.add_argument( + "--transfer-encoder-from-teacher", + default=True, + type=strtobool, + help="Whether to transfer teacher's parameters", + ) + group.add_argument( + "--transferred-encoder-module", + default="all", + type=str, + choices=["all", "embed"], + help="Encoder modeules to be trasferred from teacher", + ) + # loss related + group.add_argument( + "--use-masking", + default=True, + type=strtobool, + help="Whether to use masking in calculation of loss", + ) + group.add_argument( + "--use-weighted-masking", + default=False, + type=strtobool, + help="Whether to use weighted masking in calculation of loss", + ) + return parser + + def __init__(self, idim, odim, args=None): + """Initialize feed-forward Transformer module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (Namespace, optional): + - elayers (int): Number of encoder layers. + - eunits (int): Number of encoder hidden units. + - adim (int): Number of attention transformation dimensions. + - aheads (int): Number of heads for multi head attention. + - dlayers (int): Number of decoder layers. + - dunits (int): Number of decoder hidden units. + - use_scaled_pos_enc (bool): + Whether to use trainable scaled positional encoding. + - encoder_normalize_before (bool): + Whether to perform layer normalization before encoder block. + - decoder_normalize_before (bool): + Whether to perform layer normalization before decoder block. 
+ - encoder_concat_after (bool): Whether to concatenate attention + layer's input and output in encoder. + - decoder_concat_after (bool): Whether to concatenate attention + layer's input and output in decoder. + - duration_predictor_layers (int): Number of duration predictor layers. + - duration_predictor_chans (int): Number of duration predictor channels. + - duration_predictor_kernel_size (int): + Kernel size of duration predictor. + - spk_embed_dim (int): Number of speaker embedding dimensions. + - spk_embed_integration_type: How to integrate speaker embedding. + - teacher_model (str): Teacher auto-regressive transformer model path. + - reduction_factor (int): Reduction factor. + - transformer_init (float): How to initialize transformer parameters. + - transformer_lr (float): Initial value of learning rate. + - transformer_warmup_steps (int): Optimizer warmup steps. + - transformer_enc_dropout_rate (float): + Dropout rate in encoder except attention & positional encoding. + - transformer_enc_positional_dropout_rate (float): + Dropout rate after encoder positional encoding. + - transformer_enc_attn_dropout_rate (float): + Dropout rate in encoder self-attention module. + - transformer_dec_dropout_rate (float): + Dropout rate in decoder except attention & positional encoding. + - transformer_dec_positional_dropout_rate (float): + Dropout rate after decoder positional encoding. + - transformer_dec_attn_dropout_rate (float): + Dropout rate in deocoder self-attention module. + - transformer_enc_dec_attn_dropout_rate (float): + Dropout rate in encoder-deocoder attention module. + - use_masking (bool): + Whether to apply masking for padded part in loss calculation. + - use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. + - transfer_encoder_from_teacher: + Whether to transfer encoder using teacher encoder parameters. + - transferred_encoder_module: + Encoder module to be initialized using teacher parameters. 
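Since `fill_missing_args` backfills any option that is absent from `args` with the `add_arguments` defaults, the module can be built directly from an argparse Namespace. A minimal construction sketch, assuming an ESPnet installation; the `idim`/`odim` values (token-vocabulary size and number of mel bins) are placeholders:

import argparse
from espnet.nets.pytorch_backend.e2e_tts_fastspeech import FeedForwardTransformer

parser = argparse.ArgumentParser()
FeedForwardTransformer.add_arguments(parser)
args = parser.parse_args([])                    # all defaults, no teacher model
model = FeedForwardTransformer(idim=78, odim=80, args=args)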
+ + """ + # initialize base classes + TTSInterface.__init__(self) + torch.nn.Module.__init__(self) + + # fill missing arguments + args = fill_missing_args(args, self.add_arguments) + + # store hyperparameters + self.idim = idim + self.odim = odim + self.reduction_factor = args.reduction_factor + self.use_scaled_pos_enc = args.use_scaled_pos_enc + self.spk_embed_dim = args.spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = args.spk_embed_integration_type + + # use idx 0 as padding idx + padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # define encoder + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx + ) + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=encoder_input_layer, + dropout_rate=args.transformer_enc_dropout_rate, + positional_dropout_rate=args.transformer_enc_positional_dropout_rate, + attention_dropout_rate=args.transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=args.encoder_normalize_before, + concat_after=args.encoder_concat_after, + positionwise_layer_type=args.positionwise_layer_type, + positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, + ) + + # define additional projection for speaker embedding + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) + else: + self.projection = torch.nn.Linear( + args.adim + self.spk_embed_dim, args.adim + ) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=args.adim, + n_layers=args.duration_predictor_layers, + n_chans=args.duration_predictor_chans, + kernel_size=args.duration_predictor_kernel_size, + dropout_rate=args.duration_predictor_dropout_rate, + ) + + # define length regulator + self.length_regulator = LengthRegulator() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + self.decoder = Encoder( + idim=0, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.dunits, + num_blocks=args.dlayers, + input_layer=None, + dropout_rate=args.transformer_dec_dropout_rate, + positional_dropout_rate=args.transformer_dec_positional_dropout_rate, + attention_dropout_rate=args.transformer_dec_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=args.decoder_normalize_before, + concat_after=args.decoder_concat_after, + positionwise_layer_type=args.positionwise_layer_type, + positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, + ) + + # define final projection + self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) + + # define postnet + self.postnet = ( + None + if args.postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=args.postnet_layers, + n_chans=args.postnet_chans, + n_filts=args.postnet_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.postnet_dropout_rate, + ) + ) + + # initialize parameters + self._reset_parameters( + init_type=args.transformer_init, + init_enc_alpha=args.initial_encoder_alpha, + init_dec_alpha=args.initial_decoder_alpha, + ) + + # define teacher model + if args.teacher_model is not None: + self.teacher = self._load_teacher_model(args.teacher_model) + else: + 
self.teacher = None + + # define duration calculator + if self.teacher is not None: + self.duration_calculator = DurationCalculator(self.teacher) + else: + self.duration_calculator = None + + # transfer teacher parameters + if self.teacher is not None and args.transfer_encoder_from_teacher: + self._transfer_from_teacher(args.transferred_encoder_module) + + # define criterions + self.criterion = FeedForwardTransformerLoss( + use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking + ) + + def _forward( + self, + xs, + ilens, + ys=None, + olens=None, + spembs=None, + ds=None, + is_inference=False, + alpha=1.0, + ): + # forward encoder + x_masks = self._source_mask(ilens) + hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # forward duration predictor and length regulator + d_masks = make_pad_mask(ilens).to(xs.device) + if is_inference: + d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) + hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) + else: + if ds is None: + with torch.no_grad(): + ds = self.duration_calculator( + xs, ilens, ys, olens, spembs + ) # (B, Tmax) + d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax) + hs = self.length_regulator(hs, ds) # (B, Lmax, adim) + + # forward decoder + if olens is not None: + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + h_masks = self._source_mask(olens_in) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) + before_outs = self.feat_out(zs).view( + zs.size(0), -1, self.odim + ) # (B, Lmax, odim) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + if is_inference: + return before_outs, after_outs, d_outs + else: + return before_outs, after_outs, ds, d_outs + + def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1). + + Returns: + Tensor: Loss value. 
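Inside `_forward`, the length regulator is what turns the phoneme-rate encoder states into a frame-rate sequence: each state is repeated according to its predicted (or teacher-derived) duration. A conceptual stand-in for a single unbatched sequence, not the actual `LengthRegulator` implementation (which also handles batching and the `alpha` speed factor):

import torch

hs = torch.randn(4, 384)               # 4 phoneme-level encoder states, adim=384 (assumed)
ds = torch.tensor([3, 1, 0, 2])        # durations in frames; 0 drops the state entirely
frame_level = torch.repeat_interleave(hs, ds, dim=0)   # -> (3 + 1 + 0 + 2, 384) = (6, 384)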
+ + """ + # remove unnecessary padded part (for multi-gpus) + xs = xs[:, : max(ilens)] + ys = ys[:, : max(olens)] + if extras is not None: + extras = extras[:, : max(ilens)].squeeze(-1) + + # forward propagation + before_outs, after_outs, ds, d_outs = self._forward( + xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False + ) + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + + # calculate loss + if self.postnet is None: + l1_loss, duration_loss = self.criterion( + None, before_outs, d_outs, ys, ds, ilens, olens + ) + else: + l1_loss, duration_loss = self.criterion( + after_outs, before_outs, d_outs, ys, ds, ilens, olens + ) + loss = l1_loss + duration_loss + report_keys = [ + {"l1_loss": l1_loss.item()}, + {"duration_loss": duration_loss.item()}, + {"loss": loss.item()}, + ] + + # report extra information + if self.use_scaled_pos_enc: + report_keys += [ + {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()}, + {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()}, + ] + self.reporter.report(report_keys) + + return loss + + def calculate_all_attentions( + self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs + ): + """Calculate all of the attention weights. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1). + + Returns: + dict: Dict of attention weights and outputs. + + """ + with torch.no_grad(): + # remove unnecessary padded part (for multi-gpus) + xs = xs[:, : max(ilens)] + ys = ys[:, : max(olens)] + if extras is not None: + extras = extras[:, : max(ilens)].squeeze(-1) + + # forward propagation + outs = self._forward( + xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False + )[1] + + att_ws_dict = dict() + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention): + attn = m.attn.cpu().numpy() + if "encoder" in name: + attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] + elif "decoder" in name: + if "src" in name: + attn = [ + a[:, :ol, :il] + for a, il, ol in zip(attn, ilens.tolist(), olens.tolist()) + ] + elif "self" in name: + attn = [a[:, :l, :l] for a, l in zip(attn, olens.tolist())] + else: + logging.warning("unknown attention module: " + name) + else: + logging.warning("unknown attention module: " + name) + att_ws_dict[name] = attn + att_ws_dict["predicted_fbank"] = [ + m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist()) + ] + + return att_ws_dict + + def inference(self, x, inference_args, spemb=None, *args, **kwargs): + """Generate the sequence of features given the sequences of characters. + + Args: + x (Tensor): Input sequence of characters (T,). + inference_args (Namespace): Dummy for compatibility. + spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). + + Returns: + Tensor: Output sequence of features (L, odim). + None: Dummy for compatibility. + None: Dummy for compatibility. 
+ + """ + # setup batch axis + ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) + xs = x.unsqueeze(0) + if spemb is not None: + spembs = spemb.unsqueeze(0) + else: + spembs = None + + # get option + alpha = getattr(inference_args, "fastspeech_alpha", 1.0) + + # inference + _, outs, _ = self._forward( + xs, + ilens, + spembs=spembs, + is_inference=True, + alpha=alpha, + ) # (1, L, odim) + + return outs[0], None, None + + def _integrate_with_spk_embed(self, hs, spembs): + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens): + """Make masks for self-attention. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _load_teacher_model(self, model_path): + # get teacher model config + idim, odim, args = get_model_conf(model_path) + + # assert dimension is the same between teacher and studnet + assert idim == self.idim + assert odim == self.odim + assert args.reduction_factor == self.reduction_factor + + # load teacher model + from espnet.utils.dynamic_import import dynamic_import + + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + torch_load(model_path, model) + + # freeze teacher model parameters + for p in model.parameters(): + p.requires_grad = False + + return model + + def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): + # initialize parameters + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) + + def _transfer_from_teacher(self, transferred_encoder_module): + if transferred_encoder_module == "all": + for (n1, p1), (n2, p2) in zip( + self.encoder.named_parameters(), self.teacher.encoder.named_parameters() + ): + assert n1 == n2, "It seems that encoder structure is different." + assert p1.shape == p2.shape, "It seems that encoder size is different." + p1.data.copy_(p2.data) + elif transferred_encoder_module == "embed": + student_shape = self.encoder.embed[0].weight.data.shape + teacher_shape = self.teacher.encoder.embed[0].weight.data.shape + assert ( + student_shape == teacher_shape + ), "It seems that embed dimension is different." 
+ self.encoder.embed[0].weight.data.copy_( + self.teacher.encoder.embed[0].weight.data + ) + else: + raise NotImplementedError("Support only all or embed.") + + @property + def attention_plot_class(self): + """Return plot class for attention weight plot.""" + # Lazy import to avoid chainer dependency + from espnet.nets.pytorch_backend.e2e_tts_transformer import TTSPlot + + return TTSPlot + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + keys should match what `chainer.reporter` reports. + If you add the key `loss`, + the reporter will report `main/loss` and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list: List of strings which are base keys to plot during training. + + """ + plot_keys = ["loss", "l1_loss", "duration_loss"] + if self.use_scaled_pos_enc: + plot_keys += ["encoder_alpha", "decoder_alpha"] + + return plot_keys diff --git a/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py b/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..c41dd4262b3cc63807a451d2351b062294b12178 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_tts_tacotron2.py @@ -0,0 +1,893 @@ +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Tacotron 2 related modules.""" + +import logging + +import numpy as np +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.rnn.attentions import AttForward +from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA +from espnet.nets.pytorch_backend.rnn.attentions import AttLoc +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHGLoss +from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.fill_missing_args import fill_missing_args + + +class GuidedAttentionLoss(torch.nn.Module): + """Guided attention loss function module. + + This module calculates the guided attention loss described + in `Efficiently Trainable Text-to-Speech System Based + on Deep Convolutional Networks with Guided Attention`_, + which forces the attention to be diagonal. + + .. _`Efficiently Trainable Text-to-Speech System + Based on Deep Convolutional Networks with Guided Attention`: + https://arxiv.org/abs/1710.08969 + + """ + + def __init__(self, sigma=0.4, alpha=1.0, reset_always=True): + """Initialize guided attention loss module. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + + """ + super(GuidedAttentionLoss, self).__init__() + self.sigma = sigma + self.alpha = alpha + self.reset_always = reset_always + self.guided_attn_masks = None + self.masks = None + + def _reset_masks(self): + self.guided_attn_masks = None + self.masks = None + + def forward(self, att_ws, ilens, olens): + """Calculate forward propagation. + + Args: + att_ws (Tensor): Batch of attention weights (B, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). 
+ olens (LongTensor): Batch of output lenghts (B,). + + Returns: + Tensor: Guided attention loss value. + + """ + if self.guided_attn_masks is None: + self.guided_attn_masks = self._make_guided_attention_masks(ilens, olens).to( + att_ws.device + ) + if self.masks is None: + self.masks = self._make_masks(ilens, olens).to(att_ws.device) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + if self.reset_always: + self._reset_masks() + return self.alpha * loss + + def _make_guided_attention_masks(self, ilens, olens): + n_batches = len(ilens) + max_ilen = max(ilens) + max_olen = max(olens) + guided_attn_masks = torch.zeros((n_batches, max_olen, max_ilen)) + for idx, (ilen, olen) in enumerate(zip(ilens, olens)): + guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask( + ilen, olen, self.sigma + ) + return guided_attn_masks + + @staticmethod + def _make_guided_attention_mask(ilen, olen, sigma): + """Make guided attention mask. + + Examples: + >>> guided_attn_mask =_make_guided_attention(5, 5, 0.4) + >>> guided_attn_mask.shape + torch.Size([5, 5]) + >>> guided_attn_mask + tensor([[0.0000, 0.1175, 0.3935, 0.6753, 0.8647], + [0.1175, 0.0000, 0.1175, 0.3935, 0.6753], + [0.3935, 0.1175, 0.0000, 0.1175, 0.3935], + [0.6753, 0.3935, 0.1175, 0.0000, 0.1175], + [0.8647, 0.6753, 0.3935, 0.1175, 0.0000]]) + >>> guided_attn_mask =_make_guided_attention(3, 6, 0.4) + >>> guided_attn_mask.shape + torch.Size([6, 3]) + >>> guided_attn_mask + tensor([[0.0000, 0.2934, 0.7506], + [0.0831, 0.0831, 0.5422], + [0.2934, 0.0000, 0.2934], + [0.5422, 0.0831, 0.0831], + [0.7506, 0.2934, 0.0000], + [0.8858, 0.5422, 0.0831]]) + + """ + grid_x, grid_y = torch.meshgrid(torch.arange(olen), torch.arange(ilen)) + grid_x, grid_y = grid_x.float().to(olen.device), grid_y.float().to(ilen.device) + return 1.0 - torch.exp( + -((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)) + ) + + @staticmethod + def _make_masks(ilens, olens): + """Make masks indicating non-padded part. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + olens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor indicating non-padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens, olens = [5, 2], [8, 5] + >>> _make_mask(ilens, olens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1]], + [[1, 1, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + in_masks = make_non_pad_mask(ilens) # (B, T_in) + out_masks = make_non_pad_mask(olens) # (B, T_out) + return out_masks.unsqueeze(-1) & in_masks.unsqueeze(-2) # (B, T_out, T_in) + + +class Tacotron2Loss(torch.nn.Module): + """Loss function module for Tacotron2.""" + + def __init__( + self, use_masking=True, use_weighted_masking=False, bce_pos_weight=20.0 + ): + """Initialize Tactoron2 loss module. + + Args: + use_masking (bool): Whether to apply masking + for padded part in loss calculation. + use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. + bce_pos_weight (float): Weight of positive sample of stop token. 
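`bce_pos_weight` exists because an utterance contributes essentially one positive stop label among hundreds of frames, so the positive class is heavily up-weighted. The underlying PyTorch mechanism in isolation, with made-up logits:

import torch

logits = torch.tensor([-2.0, -1.5, -0.5, 3.0])   # stop-token logits for 4 frames
labels = torch.tensor([0.0, 0.0, 0.0, 1.0])      # only the last frame is a stop
bce = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor(20.0))
loss = bce(logits, labels)                       # the positive frame's term is scaled by 20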
+ + """ + super(Tacotron2Loss, self).__init__() + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.l1_criterion = torch.nn.L1Loss(reduction=reduction) + self.mse_criterion = torch.nn.MSELoss(reduction=reduction) + self.bce_criterion = torch.nn.BCEWithLogitsLoss( + reduction=reduction, pos_weight=torch.tensor(bce_pos_weight) + ) + + # NOTE(kan-bayashi): register pre hook function for the compatibility + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + + def forward(self, after_outs, before_outs, logits, ys, labels, olens): + """Calculate forward propagation. + + Args: + after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). + before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). + logits (Tensor): Batch of stop logits (B, Lmax). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax). + olens (LongTensor): Batch of the lengths of each target (B,). + + Returns: + Tensor: L1 loss value. + Tensor: Mean square error loss value. + Tensor: Binary cross entropy loss value. + + """ + # make mask and apply it + if self.use_masking: + masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + ys = ys.masked_select(masks) + after_outs = after_outs.masked_select(masks) + before_outs = before_outs.masked_select(masks) + labels = labels.masked_select(masks[:, :, 0]) + logits = logits.masked_select(masks[:, :, 0]) + + # calculate loss + l1_loss = self.l1_criterion(after_outs, ys) + self.l1_criterion(before_outs, ys) + mse_loss = self.mse_criterion(after_outs, ys) + self.mse_criterion( + before_outs, ys + ) + bce_loss = self.bce_criterion(logits, labels) + + # make weighted mask and apply it + if self.use_weighted_masking: + masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + weights = masks.float() / masks.sum(dim=1, keepdim=True).float() + out_weights = weights.div(ys.size(0) * ys.size(2)) + logit_weights = weights.div(ys.size(0)) + + # apply weight + l1_loss = l1_loss.mul(out_weights).masked_select(masks).sum() + mse_loss = mse_loss.mul(out_weights).masked_select(masks).sum() + bce_loss = ( + bce_loss.mul(logit_weights.squeeze(-1)) + .masked_select(masks.squeeze(-1)) + .sum() + ) + + return l1_loss, mse_loss, bce_loss + + def _load_state_dict_pre_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """Apply pre hook fucntion before loading state dict. + + From v.0.6.1 `bce_criterion.pos_weight` param is registered as a parameter but + old models do not include it and as a result, it causes missing key error when + loading old model parameter. This function solve the issue by adding param in + state dict before loading as a pre hook function + of the `load_state_dict` method. + + """ + key = prefix + "bce_criterion.pos_weight" + if key not in state_dict: + state_dict[key] = self.bce_criterion.pos_weight + + +class Tacotron2(TTSInterface, torch.nn.Module): + """Tacotron2 module for end-to-end text-to-speech (E2E-TTS). + + This is a module of Spectrogram prediction network in Tacotron2 described + in `Natural TTS Synthesis + by Conditioning WaveNet on Mel Spectrogram Predictions`_, + which converts the sequence of characters + into the sequence of Mel-filterbanks. + + .. 
_`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + """ + + @staticmethod + def add_arguments(parser): + """Add model-specific arguments to the parser.""" + group = parser.add_argument_group("tacotron 2 model setting") + # encoder + group.add_argument( + "--embed-dim", + default=512, + type=int, + help="Number of dimension of embedding", + ) + group.add_argument( + "--elayers", default=1, type=int, help="Number of encoder layers" + ) + group.add_argument( + "--eunits", + "-u", + default=512, + type=int, + help="Number of encoder hidden units", + ) + group.add_argument( + "--econv-layers", + default=3, + type=int, + help="Number of encoder convolution layers", + ) + group.add_argument( + "--econv-chans", + default=512, + type=int, + help="Number of encoder convolution channels", + ) + group.add_argument( + "--econv-filts", + default=5, + type=int, + help="Filter size of encoder convolution", + ) + # attention + group.add_argument( + "--atype", + default="location", + type=str, + choices=["forward_ta", "forward", "location"], + help="Type of attention mechanism", + ) + group.add_argument( + "--adim", + default=512, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aconv-chans", + default=32, + type=int, + help="Number of attention convolution channels", + ) + group.add_argument( + "--aconv-filts", + default=15, + type=int, + help="Filter size of attention convolution", + ) + group.add_argument( + "--cumulate-att-w", + default=True, + type=strtobool, + help="Whether or not to cumulate attention weights", + ) + # decoder + group.add_argument( + "--dlayers", default=2, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=1024, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--prenet-layers", default=2, type=int, help="Number of prenet layers" + ) + group.add_argument( + "--prenet-units", + default=256, + type=int, + help="Number of prenet hidden units", + ) + group.add_argument( + "--postnet-layers", default=5, type=int, help="Number of postnet layers" + ) + group.add_argument( + "--postnet-chans", default=512, type=int, help="Number of postnet channels" + ) + group.add_argument( + "--postnet-filts", default=5, type=int, help="Filter size of postnet" + ) + group.add_argument( + "--output-activation", + default=None, + type=str, + nargs="?", + help="Output activation function", + ) + # cbhg + group.add_argument( + "--use-cbhg", + default=False, + type=strtobool, + help="Whether to use CBHG module", + ) + group.add_argument( + "--cbhg-conv-bank-layers", + default=8, + type=int, + help="Number of convoluional bank layers in CBHG", + ) + group.add_argument( + "--cbhg-conv-bank-chans", + default=128, + type=int, + help="Number of convoluional bank channles in CBHG", + ) + group.add_argument( + "--cbhg-conv-proj-filts", + default=3, + type=int, + help="Filter size of convoluional projection layer in CBHG", + ) + group.add_argument( + "--cbhg-conv-proj-chans", + default=256, + type=int, + help="Number of convoluional projection channels in CBHG", + ) + group.add_argument( + "--cbhg-highway-layers", + default=4, + type=int, + help="Number of highway layers in CBHG", + ) + group.add_argument( + "--cbhg-highway-units", + default=128, + type=int, + help="Number of highway units in CBHG", + ) + group.add_argument( + "--cbhg-gru-units", + default=256, + type=int, + help="Number of GRU units in CBHG", + ) + # model (parameter) 
related + group.add_argument( + "--use-batch-norm", + default=True, + type=strtobool, + help="Whether to use batch normalization", + ) + group.add_argument( + "--use-concate", + default=True, + type=strtobool, + help="Whether to concatenate encoder embedding with decoder outputs", + ) + group.add_argument( + "--use-residual", + default=True, + type=strtobool, + help="Whether to use residual connection in conv layer", + ) + group.add_argument( + "--dropout-rate", default=0.5, type=float, help="Dropout rate" + ) + group.add_argument( + "--zoneout-rate", default=0.1, type=float, help="Zoneout rate" + ) + group.add_argument( + "--reduction-factor", default=1, type=int, help="Reduction factor" + ) + group.add_argument( + "--spk-embed-dim", + default=None, + type=int, + help="Number of speaker embedding dimensions", + ) + group.add_argument( + "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions" + ) + group.add_argument( + "--pretrained-model", default=None, type=str, help="Pretrained model path" + ) + # loss related + group.add_argument( + "--use-masking", + default=False, + type=strtobool, + help="Whether to use masking in calculation of loss", + ) + group.add_argument( + "--use-weighted-masking", + default=False, + type=strtobool, + help="Whether to use weighted masking in calculation of loss", + ) + group.add_argument( + "--bce-pos-weight", + default=20.0, + type=float, + help="Positive sample weight in BCE calculation " + "(only for use-masking=True)", + ) + group.add_argument( + "--use-guided-attn-loss", + default=False, + type=strtobool, + help="Whether to use guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-sigma", + default=0.4, + type=float, + help="Sigma in guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-lambda", + default=1.0, + type=float, + help="Lambda in guided attention loss", + ) + return parser + + def __init__(self, idim, odim, args=None): + """Initialize Tacotron2 module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (Namespace, optional): + - spk_embed_dim (int): Dimension of the speaker embedding. + - embed_dim (int): Dimension of character embedding. + - elayers (int): The number of encoder blstm layers. + - eunits (int): The number of encoder blstm units. + - econv_layers (int): The number of encoder conv layers. + - econv_filts (int): The number of encoder conv filter size. + - econv_chans (int): The number of encoder conv filter channels. + - dlayers (int): The number of decoder lstm layers. + - dunits (int): The number of decoder lstm units. + - prenet_layers (int): The number of prenet layers. + - prenet_units (int): The number of prenet units. + - postnet_layers (int): The number of postnet layers. + - postnet_filts (int): The number of postnet filter size. + - postnet_chans (int): The number of postnet filter channels. + - output_activation (int): The name of activation function for outputs. + - adim (int): The number of dimension of mlp in attention. + - aconv_chans (int): The number of attention conv filter channels. + - aconv_filts (int): The number of attention conv filter size. + - cumulate_att_w (bool): Whether to cumulate previous attention weight. + - use_batch_norm (bool): Whether to use batch normalization. + - use_concate (int): Whether to concatenate encoder embedding + with decoder lstm outputs. + - dropout_rate (float): Dropout rate. + - zoneout_rate (float): Zoneout rate. + - reduction_factor (int): Reduction factor. 
+ - spk_embed_dim (int): Number of speaker embedding dimenstions. + - spc_dim (int): Number of spectrogram embedding dimenstions + (only for use_cbhg=True). + - use_cbhg (bool): Whether to use CBHG module. + - cbhg_conv_bank_layers (int): The number of convoluional banks in CBHG. + - cbhg_conv_bank_chans (int): The number of channels of + convolutional bank in CBHG. + - cbhg_proj_filts (int): + The number of filter size of projection layeri in CBHG. + - cbhg_proj_chans (int): + The number of channels of projection layer in CBHG. + - cbhg_highway_layers (int): + The number of layers of highway network in CBHG. + - cbhg_highway_units (int): + The number of units of highway network in CBHG. + - cbhg_gru_units (int): The number of units of GRU in CBHG. + - use_masking (bool): + Whether to apply masking for padded part in loss calculation. + - use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. + - bce_pos_weight (float): + Weight of positive sample of stop token (only for use_masking=True). + - use-guided-attn-loss (bool): Whether to use guided attention loss. + - guided-attn-loss-sigma (float) Sigma in guided attention loss. + - guided-attn-loss-lamdba (float): Lambda in guided attention loss. + + """ + # initialize base classes + TTSInterface.__init__(self) + torch.nn.Module.__init__(self) + + # fill missing arguments + args = fill_missing_args(args, self.add_arguments) + + # store hyperparameters + self.idim = idim + self.odim = odim + self.spk_embed_dim = args.spk_embed_dim + self.cumulate_att_w = args.cumulate_att_w + self.reduction_factor = args.reduction_factor + self.use_cbhg = args.use_cbhg + self.use_guided_attn_loss = args.use_guided_attn_loss + + # define activation function for the final output + if args.output_activation is None: + self.output_activation_fn = None + elif hasattr(F, args.output_activation): + self.output_activation_fn = getattr(F, args.output_activation) + else: + raise ValueError( + "there is no such an activation function. (%s)" % args.output_activation + ) + + # set padding idx + padding_idx = 0 + + # define network modules + self.enc = Encoder( + idim=idim, + embed_dim=args.embed_dim, + elayers=args.elayers, + eunits=args.eunits, + econv_layers=args.econv_layers, + econv_chans=args.econv_chans, + econv_filts=args.econv_filts, + use_batch_norm=args.use_batch_norm, + use_residual=args.use_residual, + dropout_rate=args.dropout_rate, + padding_idx=padding_idx, + ) + dec_idim = ( + args.eunits + if args.spk_embed_dim is None + else args.eunits + args.spk_embed_dim + ) + if args.atype == "location": + att = AttLoc( + dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts + ) + elif args.atype == "forward": + att = AttForward( + dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts + ) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled in forward attention." + ) + self.cumulate_att_w = False + elif args.atype == "forward_ta": + att = AttForwardTA( + dec_idim, + args.dunits, + args.adim, + args.aconv_chans, + args.aconv_filts, + odim, + ) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled in forward attention." 
+ ) + self.cumulate_att_w = False + else: + raise NotImplementedError("Support only location or forward") + self.dec = Decoder( + idim=dec_idim, + odim=odim, + att=att, + dlayers=args.dlayers, + dunits=args.dunits, + prenet_layers=args.prenet_layers, + prenet_units=args.prenet_units, + postnet_layers=args.postnet_layers, + postnet_chans=args.postnet_chans, + postnet_filts=args.postnet_filts, + output_activation_fn=self.output_activation_fn, + cumulate_att_w=self.cumulate_att_w, + use_batch_norm=args.use_batch_norm, + use_concate=args.use_concate, + dropout_rate=args.dropout_rate, + zoneout_rate=args.zoneout_rate, + reduction_factor=args.reduction_factor, + ) + self.taco2_loss = Tacotron2Loss( + use_masking=args.use_masking, + use_weighted_masking=args.use_weighted_masking, + bce_pos_weight=args.bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.attn_loss = GuidedAttentionLoss( + sigma=args.guided_attn_loss_sigma, + alpha=args.guided_attn_loss_lambda, + ) + if self.use_cbhg: + self.cbhg = CBHG( + idim=odim, + odim=args.spc_dim, + conv_bank_layers=args.cbhg_conv_bank_layers, + conv_bank_chans=args.cbhg_conv_bank_chans, + conv_proj_filts=args.cbhg_conv_proj_filts, + conv_proj_chans=args.cbhg_conv_proj_chans, + highway_layers=args.cbhg_highway_layers, + highway_units=args.cbhg_highway_units, + gru_units=args.cbhg_gru_units, + ) + self.cbhg_loss = CBHGLoss(use_masking=args.use_masking) + + # load pretrained model + if args.pretrained_model is not None: + self.load_pretrained_model(args.pretrained_model) + + def forward( + self, xs, ilens, ys, labels, olens, spembs=None, extras=None, *args, **kwargs + ): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + extras (Tensor, optional): + Batch of groundtruth spectrograms (B, Lmax, spc_dim). + + Returns: + Tensor: Loss value. 
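With `use_guided_attn_loss` enabled, the decoder attention weights returned by `self.dec` are penalized for off-diagonal mass using the `GuidedAttentionLoss` module defined above. A small usage sketch with made-up shapes:

import torch
from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss

attn_loss = GuidedAttentionLoss(sigma=0.4, alpha=1.0)
att_ws = torch.rand(2, 7, 5)        # (B, T_out, T_in) attention weights from the decoder
ilens = torch.tensor([5, 3])        # input lengths
olens = torch.tensor([7, 6])        # output lengths (after reduction-factor adjustment)
loss = attn_loss(att_ws, ilens, olens)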
+ + """ + # remove unnecessary padded part (for multi-gpus) + max_in = max(ilens) + max_out = max(olens) + if max_in != xs.shape[1]: + xs = xs[:, :max_in] + if max_out != ys.shape[1]: + ys = ys[:, :max_out] + labels = labels[:, :max_out] + + # calculate tacotron2 outputs + hs, hlens = self.enc(xs, ilens) + if self.spk_embed_dim is not None: + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = torch.cat([hs, spembs], dim=-1) + after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys) + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_out = max(olens) + ys = ys[:, :max_out] + labels = labels[:, :max_out] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + + # caluculate taco2 loss + l1_loss, mse_loss, bce_loss = self.taco2_loss( + after_outs, before_outs, logits, ys, labels, olens + ) + loss = l1_loss + mse_loss + bce_loss + report_keys = [ + {"l1_loss": l1_loss.item()}, + {"mse_loss": mse_loss.item()}, + {"bce_loss": bce_loss.item()}, + ] + + # caluculate attention loss + if self.use_guided_attn_loss: + # NOTE(kan-bayashi): + # length of output for auto-regressive input will be changed when r > 1 + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + attn_loss = self.attn_loss(att_ws, ilens, olens_in) + loss = loss + attn_loss + report_keys += [ + {"attn_loss": attn_loss.item()}, + ] + + # caluculate cbhg loss + if self.use_cbhg: + # remove unnecessary padded part (for multi-gpus) + if max_out != extras.shape[1]: + extras = extras[:, :max_out] + + # caluculate cbhg outputs & loss and report them + cbhg_outs, _ = self.cbhg(after_outs, olens) + cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, extras, olens) + loss = loss + cbhg_l1_loss + cbhg_mse_loss + report_keys += [ + {"cbhg_l1_loss": cbhg_l1_loss.item()}, + {"cbhg_mse_loss": cbhg_mse_loss.item()}, + ] + + report_keys += [{"loss": loss.item()}] + self.reporter.report(report_keys) + + return loss + + def inference(self, x, inference_args, spemb=None, *args, **kwargs): + """Generate the sequence of features given the sequences of characters. + + Args: + x (Tensor): Input sequence of characters (T,). + inference_args (Namespace): + - threshold (float): Threshold in inference. + - minlenratio (float): Minimum length ratio in inference. + - maxlenratio (float): Maximum length ratio in inference. + spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Attention weights (L, T). 
+ + """ + # get options + threshold = inference_args.threshold + minlenratio = inference_args.minlenratio + maxlenratio = inference_args.maxlenratio + use_att_constraint = getattr( + inference_args, "use_att_constraint", False + ) # keep compatibility + backward_window = inference_args.backward_window if use_att_constraint else 0 + forward_window = inference_args.forward_window if use_att_constraint else 0 + + # inference + h = self.enc.inference(x) + if self.spk_embed_dim is not None: + spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1) + h = torch.cat([h, spemb], dim=-1) + outs, probs, att_ws = self.dec.inference( + h, + threshold, + minlenratio, + maxlenratio, + use_att_constraint=use_att_constraint, + backward_window=backward_window, + forward_window=forward_window, + ) + + if self.use_cbhg: + cbhg_outs = self.cbhg.inference(outs) + return cbhg_outs, probs, att_ws + else: + return outs, probs, att_ws + + def calculate_all_attentions( + self, xs, ilens, ys, spembs=None, keep_tensor=False, *args, **kwargs + ): + """Calculate all of the attention weights. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + keep_tensor (bool, optional): Whether to keep original tensor. + + Returns: + Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax). + + """ + # check ilens type (should be list of int) + if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray): + ilens = list(map(int, ilens)) + + self.eval() + with torch.no_grad(): + hs, hlens = self.enc(xs, ilens) + if self.spk_embed_dim is not None: + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = torch.cat([hs, spembs], dim=-1) + att_ws = self.dec.calculate_all_attentions(hs, hlens, ys) + self.train() + + if keep_tensor: + return att_ws + else: + return att_ws.cpu().numpy() + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + keys should match what `chainer.reporter` reports. + If you add the key `loss`, the reporter will report `main/loss` + and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list: List of strings which are base keys to plot during training. 
+ + """ + plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"] + if self.use_guided_attn_loss: + plot_keys += ["attn_loss"] + if self.use_cbhg: + plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"] + return plot_keys diff --git a/espnet/nets/pytorch_backend/e2e_tts_transformer.py b/espnet/nets/pytorch_backend/e2e_tts_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..3628e7df00ac4b83e5bb4e0bc6a37135be2190ff --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_tts_transformer.py @@ -0,0 +1,1153 @@ +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""TTS-Transformer related modules.""" + +import logging + +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( + Tacotron2Loss as TransformerLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.fill_missing_args import fill_missing_args + + +class GuidedMultiHeadAttentionLoss(GuidedAttentionLoss): + """Guided attention loss function module for multi head attention. + + Args: + sigma (float, optional): Standard deviation to control + how close attention to a diagonal. + alpha (float, optional): Scaling coefficient (lambda). + reset_always (bool, optional): Whether to always reset masks. + + """ + + def forward(self, att_ws, ilens, olens): + """Calculate forward propagation. + + Args: + att_ws (Tensor): + Batch of multi head attention weights (B, H, T_max_out, T_max_in). + ilens (LongTensor): Batch of input lenghts (B,). + olens (LongTensor): Batch of output lenghts (B,). + + Returns: + Tensor: Guided attention loss value. + + """ + if self.guided_attn_masks is None: + self.guided_attn_masks = ( + self._make_guided_attention_masks(ilens, olens) + .to(att_ws.device) + .unsqueeze(1) + ) + if self.masks is None: + self.masks = self._make_masks(ilens, olens).to(att_ws.device).unsqueeze(1) + losses = self.guided_attn_masks * att_ws + loss = torch.mean(losses.masked_select(self.masks)) + if self.reset_always: + self._reset_masks() + + return self.alpha * loss + + +try: + from espnet.nets.pytorch_backend.transformer.plot import PlotAttentionReport +except (ImportError, TypeError): + TTSPlot = None +else: + + class TTSPlot(PlotAttentionReport): + """Attention plot module for TTS-Transformer.""" + + def plotfn( + self, data_dict, uttid_list, attn_dict, outdir, suffix="png", savefn=None + ): + """Plot multi head attentions. + + Args: + data_dict (dict): Utts info from json file. + uttid_list (list): List of utt_id. 
+ attn_dict (dict): Multi head attention dict. + Values should be numpy.ndarray (H, L, T) + outdir (str): Directory name to save figures. + suffix (str): Filename suffix including image type (e.g., png). + savefn (function): Function to save figures. + + """ + import matplotlib.pyplot as plt + from espnet.nets.pytorch_backend.transformer.plot import ( + _plot_and_save_attention, # noqa: H301 + ) + + for name, att_ws in attn_dict.items(): + for utt_id, att_w in zip(uttid_list, att_ws): + filename = "%s/%s.%s.%s" % (outdir, utt_id, name, suffix) + if "fbank" in name: + fig = plt.Figure() + ax = fig.subplots(1, 1) + ax.imshow(att_w, aspect="auto") + ax.set_xlabel("frames") + ax.set_ylabel("fbank coeff") + fig.tight_layout() + else: + fig = _plot_and_save_attention(att_w, filename) + savefn(fig, filename) + + +class Transformer(TTSInterface, torch.nn.Module): + """Text-to-Speech Transformer module. + + This is a module of text-to-speech Transformer described + in `Neural Speech Synthesis with Transformer Network`_, + which convert the sequence of characters + or phonemes into the sequence of Mel-filterbanks. + + .. _`Neural Speech Synthesis with Transformer Network`: + https://arxiv.org/pdf/1809.08895.pdf + + """ + + @staticmethod + def add_arguments(parser): + """Add model-specific arguments to the parser.""" + group = parser.add_argument_group("transformer model setting") + # network structure related + group.add_argument( + "--embed-dim", + default=512, + type=int, + help="Dimension of character embedding in encoder prenet", + ) + group.add_argument( + "--eprenet-conv-layers", + default=3, + type=int, + help="Number of encoder prenet convolution layers", + ) + group.add_argument( + "--eprenet-conv-chans", + default=256, + type=int, + help="Number of encoder prenet convolution channels", + ) + group.add_argument( + "--eprenet-conv-filts", + default=5, + type=int, + help="Filter size of encoder prenet convolution", + ) + group.add_argument( + "--dprenet-layers", + default=2, + type=int, + help="Number of decoder prenet layers", + ) + group.add_argument( + "--dprenet-units", + default=256, + type=int, + help="Number of decoder prenet hidden units", + ) + group.add_argument( + "--elayers", default=3, type=int, help="Number of encoder layers" + ) + group.add_argument( + "--eunits", default=1536, type=int, help="Number of encoder hidden units" + ) + group.add_argument( + "--adim", + default=384, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + group.add_argument( + "--dlayers", default=3, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=1536, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--positionwise-layer-type", + default="linear", + type=str, + choices=["linear", "conv1d", "conv1d-linear"], + help="Positionwise layer type.", + ) + group.add_argument( + "--positionwise-conv-kernel-size", + default=1, + type=int, + help="Kernel size of positionwise conv1d layer", + ) + group.add_argument( + "--postnet-layers", default=5, type=int, help="Number of postnet layers" + ) + group.add_argument( + "--postnet-chans", default=256, type=int, help="Number of postnet channels" + ) + group.add_argument( + "--postnet-filts", default=5, type=int, help="Filter size of postnet" + ) + group.add_argument( + "--use-scaled-pos-enc", + default=True, + type=strtobool, + help="Use trainable scaled positional 
encoding " + "instead of the fixed scale one.", + ) + group.add_argument( + "--use-batch-norm", + default=True, + type=strtobool, + help="Whether to use batch normalization", + ) + group.add_argument( + "--encoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before encoder block", + ) + group.add_argument( + "--decoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before decoder block", + ) + group.add_argument( + "--encoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in encoder", + ) + group.add_argument( + "--decoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in decoder", + ) + group.add_argument( + "--reduction-factor", default=1, type=int, help="Reduction factor" + ) + group.add_argument( + "--spk-embed-dim", + default=None, + type=int, + help="Number of speaker embedding dimensions", + ) + group.add_argument( + "--spk-embed-integration-type", + type=str, + default="add", + choices=["add", "concat"], + help="How to integrate speaker embedding", + ) + # training related + group.add_argument( + "--transformer-init", + type=str, + default="pytorch", + choices=[ + "pytorch", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + ], + help="How to initialize transformer parameters", + ) + group.add_argument( + "--initial-encoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in encoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--initial-decoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in decoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--transformer-lr", + default=1.0, + type=float, + help="Initial value of learning rate", + ) + group.add_argument( + "--transformer-warmup-steps", + default=4000, + type=int, + help="Optimizer warmup steps", + ) + group.add_argument( + "--transformer-enc-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder except for attention", + ) + group.add_argument( + "--transformer-enc-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder positional encoding", + ) + group.add_argument( + "--transformer-enc-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder self-attention", + ) + group.add_argument( + "--transformer-dec-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder " + "except for attention and pos encoding", + ) + group.add_argument( + "--transformer-dec-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder positional encoding", + ) + group.add_argument( + "--transformer-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder self-attention", + ) + group.add_argument( + "--transformer-enc-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder-decoder attention", + ) + group.add_argument( + "--eprenet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in encoder prenet", + ) + group.add_argument( + "--dprenet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in decoder prenet", + ) + group.add_argument( + "--postnet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in postnet", + ) + group.add_argument( + 
"--pretrained-model", default=None, type=str, help="Pretrained model path" + ) + # loss related + group.add_argument( + "--use-masking", + default=True, + type=strtobool, + help="Whether to use masking in calculation of loss", + ) + group.add_argument( + "--use-weighted-masking", + default=False, + type=strtobool, + help="Whether to use weighted masking in calculation of loss", + ) + group.add_argument( + "--loss-type", + default="L1", + choices=["L1", "L2", "L1+L2"], + help="How to calc loss", + ) + group.add_argument( + "--bce-pos-weight", + default=5.0, + type=float, + help="Positive sample weight in BCE calculation " + "(only for use-masking=True)", + ) + group.add_argument( + "--use-guided-attn-loss", + default=False, + type=strtobool, + help="Whether to use guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-sigma", + default=0.4, + type=float, + help="Sigma in guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-lambda", + default=1.0, + type=float, + help="Lambda in guided attention loss", + ) + group.add_argument( + "--num-heads-applied-guided-attn", + default=2, + type=int, + help="Number of heads in each layer to be applied guided attention loss" + "if set -1, all of the heads will be applied.", + ) + group.add_argument( + "--num-layers-applied-guided-attn", + default=2, + type=int, + help="Number of layers to be applied guided attention loss" + "if set -1, all of the layers will be applied.", + ) + group.add_argument( + "--modules-applied-guided-attn", + type=str, + nargs="+", + default=["encoder-decoder"], + help="Module name list to be applied guided attention loss", + ) + return parser + + @property + def attention_plot_class(self): + """Return plot class for attention weight plot.""" + return TTSPlot + + def __init__(self, idim, odim, args=None): + """Initialize TTS-Transformer module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (Namespace, optional): + - embed_dim (int): Dimension of character embedding. + - eprenet_conv_layers (int): + Number of encoder prenet convolution layers. + - eprenet_conv_chans (int): + Number of encoder prenet convolution channels. + - eprenet_conv_filts (int): Filter size of encoder prenet convolution. + - dprenet_layers (int): Number of decoder prenet layers. + - dprenet_units (int): Number of decoder prenet hidden units. + - elayers (int): Number of encoder layers. + - eunits (int): Number of encoder hidden units. + - adim (int): Number of attention transformation dimensions. + - aheads (int): Number of heads for multi head attention. + - dlayers (int): Number of decoder layers. + - dunits (int): Number of decoder hidden units. + - postnet_layers (int): Number of postnet layers. + - postnet_chans (int): Number of postnet channels. + - postnet_filts (int): Filter size of postnet. + - use_scaled_pos_enc (bool): + Whether to use trainable scaled positional encoding. + - use_batch_norm (bool): + Whether to use batch normalization in encoder prenet. + - encoder_normalize_before (bool): + Whether to perform layer normalization before encoder block. + - decoder_normalize_before (bool): + Whether to perform layer normalization before decoder block. + - encoder_concat_after (bool): Whether to concatenate attention + layer's input and output in encoder. + - decoder_concat_after (bool): Whether to concatenate attention + layer's input and output in decoder. + - reduction_factor (int): Reduction factor. 
+            - spk_embed_dim (int): Number of speaker embedding dimensions.
+            - spk_embed_integration_type: How to integrate speaker embedding.
+            - transformer_init (float): How to initialize transformer parameters.
+            - transformer_lr (float): Initial value of learning rate.
+            - transformer_warmup_steps (int): Optimizer warmup steps.
+            - transformer_enc_dropout_rate (float):
+                Dropout rate in encoder except attention & positional encoding.
+            - transformer_enc_positional_dropout_rate (float):
+                Dropout rate after encoder positional encoding.
+            - transformer_enc_attn_dropout_rate (float):
+                Dropout rate in encoder self-attention module.
+            - transformer_dec_dropout_rate (float):
+                Dropout rate in decoder except attention & positional encoding.
+            - transformer_dec_positional_dropout_rate (float):
+                Dropout rate after decoder positional encoding.
+            - transformer_dec_attn_dropout_rate (float):
+                Dropout rate in decoder self-attention module.
+            - transformer_enc_dec_attn_dropout_rate (float):
+                Dropout rate in encoder-decoder attention module.
+            - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
+            - dprenet_dropout_rate (float): Dropout rate in decoder prenet.
+            - postnet_dropout_rate (float): Dropout rate in postnet.
+            - use_masking (bool):
+                Whether to apply masking for padded part in loss calculation.
+            - use_weighted_masking (bool):
+                Whether to apply weighted masking in loss calculation.
+            - bce_pos_weight (float): Positive sample weight in bce calculation
+                (only for use_masking=true).
+            - loss_type (str): How to calculate loss.
+            - use_guided_attn_loss (bool): Whether to use guided attention loss.
+            - num_heads_applied_guided_attn (int):
+                Number of heads in each layer to apply guided attention loss.
+            - num_layers_applied_guided_attn (int):
+                Number of layers to apply guided attention loss.
+            - modules_applied_guided_attn (list):
+                List of module names to apply guided attention loss.
+            - guided_attn_loss_sigma (float): Sigma in guided attention loss.
+            - guided_attn_loss_lambda (float): Lambda in guided attention loss.
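+
+        Example:
+            A sketch of building ``args`` from the parser defined in
+            ``add_arguments``; ``idim``/``odim`` and the option values here
+            are placeholders, not recommended settings:
+
+            >>> import argparse
+            >>> parser = argparse.ArgumentParser()
+            >>> parser = Transformer.add_arguments(parser)
+            >>> args = parser.parse_args(["--adim", "384", "--aheads", "4"])
+            >>> model = Transformer(idim=70, odim=80, args=args)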
+ + """ + # initialize base classes + TTSInterface.__init__(self) + torch.nn.Module.__init__(self) + + # fill missing arguments + args = fill_missing_args(args, self.add_arguments) + + # store hyperparameters + self.idim = idim + self.odim = odim + self.spk_embed_dim = args.spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = args.spk_embed_integration_type + self.use_scaled_pos_enc = args.use_scaled_pos_enc + self.reduction_factor = args.reduction_factor + self.loss_type = args.loss_type + self.use_guided_attn_loss = args.use_guided_attn_loss + if self.use_guided_attn_loss: + if args.num_layers_applied_guided_attn == -1: + self.num_layers_applied_guided_attn = args.elayers + else: + self.num_layers_applied_guided_attn = ( + args.num_layers_applied_guided_attn + ) + if args.num_heads_applied_guided_attn == -1: + self.num_heads_applied_guided_attn = args.aheads + else: + self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn + self.modules_applied_guided_attn = args.modules_applied_guided_attn + + # use idx 0 as padding idx + padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # define transformer encoder + if args.eprenet_conv_layers != 0: + # encoder prenet + encoder_input_layer = torch.nn.Sequential( + EncoderPrenet( + idim=idim, + embed_dim=args.embed_dim, + elayers=0, + econv_layers=args.eprenet_conv_layers, + econv_chans=args.eprenet_conv_chans, + econv_filts=args.eprenet_conv_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.eprenet_dropout_rate, + padding_idx=padding_idx, + ), + torch.nn.Linear(args.eprenet_conv_chans, args.adim), + ) + else: + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx + ) + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=encoder_input_layer, + dropout_rate=args.transformer_enc_dropout_rate, + positional_dropout_rate=args.transformer_enc_positional_dropout_rate, + attention_dropout_rate=args.transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=args.encoder_normalize_before, + concat_after=args.encoder_concat_after, + positionwise_layer_type=args.positionwise_layer_type, + positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, + ) + + # define projection layer + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) + else: + self.projection = torch.nn.Linear( + args.adim + self.spk_embed_dim, args.adim + ) + + # define transformer decoder + if args.dprenet_layers != 0: + # decoder prenet + decoder_input_layer = torch.nn.Sequential( + DecoderPrenet( + idim=odim, + n_layers=args.dprenet_layers, + n_units=args.dprenet_units, + dropout_rate=args.dprenet_dropout_rate, + ), + torch.nn.Linear(args.dprenet_units, args.adim), + ) + else: + decoder_input_layer = "linear" + self.decoder = Decoder( + odim=-1, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.transformer_dec_dropout_rate, + positional_dropout_rate=args.transformer_dec_positional_dropout_rate, + self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate, + 
input_layer=decoder_input_layer, + use_output_layer=False, + pos_enc_class=pos_enc_class, + normalize_before=args.decoder_normalize_before, + concat_after=args.decoder_concat_after, + ) + + # define final projection + self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) + self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) + + # define postnet + self.postnet = ( + None + if args.postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=args.postnet_layers, + n_chans=args.postnet_chans, + n_filts=args.postnet_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.postnet_dropout_rate, + ) + ) + + # define loss function + self.criterion = TransformerLoss( + use_masking=args.use_masking, + use_weighted_masking=args.use_weighted_masking, + bce_pos_weight=args.bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.attn_criterion = GuidedMultiHeadAttentionLoss( + sigma=args.guided_attn_loss_sigma, + alpha=args.guided_attn_loss_lambda, + ) + + # initialize parameters + self._reset_parameters( + init_type=args.transformer_init, + init_enc_alpha=args.initial_encoder_alpha, + init_dec_alpha=args.initial_decoder_alpha, + ) + + # load pretrained model + if args.pretrained_model is not None: + self.load_pretrained_model(args.pretrained_model) + + def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): + # initialize parameters + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) + + def _add_first_frame_and_remove_last_frame(self, ys): + ys_in = torch.cat( + [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 + ) + return ys_in + + def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + + Returns: + Tensor: Loss value. 
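+
+        Example:
+            A shape-only training-step sketch; the tensors below are random
+            placeholders standing in for a real batch from the data loader:
+
+            >>> xs = torch.randint(1, 70, (2, 10))   # (B, Tmax) character ids
+            >>> ilens = torch.tensor([10, 8])
+            >>> ys = torch.randn(2, 50, 80)           # (B, Lmax, odim)
+            >>> labels = torch.zeros(2, 50)
+            >>> labels[:, -1] = 1.0                   # stop label on the last frame
+            >>> olens = torch.tensor([50, 46])
+            >>> loss = model(xs, ilens, ys, labels, olens)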
+ + """ + # remove unnecessary padded part (for multi-gpus) + max_ilen = max(ilens) + max_olen = max(olens) + if max_ilen != xs.shape[1]: + xs = xs[:, :max_ilen] + if max_olen != ys.shape[1]: + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + + # forward encoder + x_masks = self._source_mask(ilens) + hs, h_masks = self.encoder(xs, x_masks) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + ys_in, olens_in = ys, olens + + # add first zero frame and remove last frame for auto-regressive + ys_in = self._add_first_frame_and_remove_last_frame(ys_in) + + # forward decoder + y_masks = self._target_mask(olens_in) + zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) + # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + # (B, Lmax//r, r) -> (B, Lmax//r * r) + logits = self.prob_out(zs).view(zs.size(0), -1) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + + # caluculate loss values + l1_loss, l2_loss, bce_loss = self.criterion( + after_outs, before_outs, logits, ys, labels, olens + ) + if self.loss_type == "L1": + loss = l1_loss + bce_loss + elif self.loss_type == "L2": + loss = l2_loss + bce_loss + elif self.loss_type == "L1+L2": + loss = l1_loss + l2_loss + bce_loss + else: + raise ValueError("unknown --loss-type " + self.loss_type) + report_keys = [ + {"l1_loss": l1_loss.item()}, + {"l2_loss": l2_loss.item()}, + {"bce_loss": bce_loss.item()}, + {"loss": loss.item()}, + ] + + # calculate guided attention loss + if self.use_guided_attn_loss: + # calculate for encoder + if "encoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.encoder.encoders))) + ): + att_ws += [ + self.encoder.encoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) + enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens) + loss = loss + enc_attn_loss + report_keys += [{"enc_attn_loss": enc_attn_loss.item()}] + # calculate for decoder + if "decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) + dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) + loss = loss + dec_attn_loss + report_keys += [{"dec_attn_loss": dec_attn_loss.item()}] + # calculate for encoder-decoder + if "encoder-decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, 
layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].src_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) + enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in) + loss = loss + enc_dec_attn_loss + report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}] + + # report extra information + if self.use_scaled_pos_enc: + report_keys += [ + {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()}, + {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()}, + ] + self.reporter.report(report_keys) + + return loss + + def inference(self, x, inference_args, spemb=None, *args, **kwargs): + """Generate the sequence of features given the sequences of characters. + + Args: + x (Tensor): Input sequence of characters (T,). + inference_args (Namespace): + - threshold (float): Threshold in inference. + - minlenratio (float): Minimum length ratio in inference. + - maxlenratio (float): Maximum length ratio in inference. + spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). + + """ + # get options + threshold = inference_args.threshold + minlenratio = inference_args.minlenratio + maxlenratio = inference_args.maxlenratio + use_att_constraint = getattr( + inference_args, "use_att_constraint", False + ) # keep compatibility + if use_att_constraint: + logging.warning( + "Attention constraint is not yet supported in Transformer. Not enabled." + ) + + # forward encoder + xs = x.unsqueeze(0) + hs, _ = self.encoder(xs, None) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + spembs = spemb.unsqueeze(0) + hs = self._integrate_with_spk_embed(hs, spembs) + + # set limits of length + maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) + minlen = int(hs.size(1) * minlenratio / self.reduction_factor) + + # initialize + idx = 0 + ys = hs.new_zeros(1, 1, self.odim) + outs, probs = [], [] + + # forward decoder step-by-step + z_cache = self.decoder.init_state(x) + while True: + # update index + idx += 1 + + # calculate output and stop prob at idx-th step + y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) + z, z_cache = self.decoder.forward_one_step( + ys, y_masks, hs, cache=z_cache + ) # (B, adim) + outs += [ + self.feat_out(z).view(self.reduction_factor, self.odim) + ] # [(r, odim), ...] + probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] + + # update next inputs + ys = torch.cat( + (ys, outs[-1][-1].view(1, 1, self.odim)), dim=1 + ) # (1, idx + 1, odim) + + # get attention weights + att_ws_ = [] + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention) and "src" in name: + att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...] + if idx == 1: + att_ws = att_ws_ + else: + # [(#heads, l, T), ...] 
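+                    # grow each per-layer source-attention tensor with the
+                    # weights of the frame generated at this step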
+ att_ws = [ + torch.cat([att_w, att_w_], dim=1) + for att_w, att_w_ in zip(att_ws, att_ws_) + ] + + # check whether to finish generation + if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: + # check mininum length + if idx < minlen: + continue + outs = ( + torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) + ) # (L, odim) -> (1, L, odim) -> (1, odim, L) + if self.postnet is not None: + outs = outs + self.postnet(outs) # (1, odim, L) + outs = outs.transpose(2, 1).squeeze(0) # (L, odim) + probs = torch.cat(probs, dim=0) + break + + # concatenate attention weights -> (#layers, #heads, L, T) + att_ws = torch.stack(att_ws, dim=0) + + return outs, probs, att_ws + + def calculate_all_attentions( + self, + xs, + ilens, + ys, + olens, + spembs=None, + skip_output=False, + keep_tensor=False, + *args, + **kwargs + ): + """Calculate all of the attention weights. + + Args: + xs (Tensor): Batch of padded character ids (B, Tmax). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + skip_output (bool, optional): Whether to skip calculate the final output. + keep_tensor (bool, optional): Whether to keep original tensor. + + Returns: + dict: Dict of attention weights and outputs. + + """ + self.eval() + with torch.no_grad(): + # forward encoder + x_masks = self._source_mask(ilens) + hs, h_masks = self.encoder(xs, x_masks) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # thin out frames for reduction factor + # (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + ys_in, olens_in = ys, olens + + # add first zero frame and remove last frame for auto-regressive + ys_in = self._add_first_frame_and_remove_last_frame(ys_in) + + # forward decoder + y_masks = self._target_mask(olens_in) + zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) + + # calculate final outputs + if not skip_output: + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + # modifiy mod part of output lengths due to reduction factor > 1 + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + + # store into dict + att_ws_dict = dict() + if keep_tensor: + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention): + att_ws_dict[name] = m.attn + if not skip_output: + att_ws_dict["before_postnet_fbank"] = before_outs + att_ws_dict["after_postnet_fbank"] = after_outs + else: + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention): + attn = m.attn.cpu().numpy() + if "encoder" in name: + attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] + elif "decoder" in name: + if "src" in name: + attn = [ + a[:, :ol, :il] + for a, il, ol in zip( + attn, ilens.tolist(), olens_in.tolist() + ) + ] + elif "self" in name: + attn = [ + a[:, :l, :l] for a, l in zip(attn, olens_in.tolist()) + ] + else: + logging.warning("unknown attention module: " + name) + else: + logging.warning("unknown attention module: " + name) + 
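+                    # collect the weights for this attention module under its
+                    # dotted module name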
att_ws_dict[name] = attn + if not skip_output: + before_outs = before_outs.cpu().numpy() + after_outs = after_outs.cpu().numpy() + att_ws_dict["before_postnet_fbank"] = [ + m[:l].T for m, l in zip(before_outs, olens.tolist()) + ] + att_ws_dict["after_postnet_fbank"] = [ + m[:l].T for m, l in zip(after_outs, olens.tolist()) + ] + self.train() + return att_ws_dict + + def _integrate_with_spk_embed(self, hs, spembs): + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens): + """Make masks for self-attention. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [[1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _target_mask(self, olens): + """Make masks for masked self-attention. + + Args: + olens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for masked self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> olens = [5, 3] + >>> self._target_mask(olens) + tensor([[[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 1]], + [[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) + s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) + return y_masks.unsqueeze(-2) & s_masks + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + keys should match what `chainer.reporter` reports. + If you add the key `loss`, the reporter will report `main/loss` + and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list: List of strings which are base keys to plot during training. 
+ + """ + plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"] + if self.use_scaled_pos_enc: + plot_keys += ["encoder_alpha", "decoder_alpha"] + if self.use_guided_attn_loss: + if "encoder" in self.modules_applied_guided_attn: + plot_keys += ["enc_attn_loss"] + if "decoder" in self.modules_applied_guided_attn: + plot_keys += ["dec_attn_loss"] + if "encoder-decoder" in self.modules_applied_guided_attn: + plot_keys += ["enc_dec_attn_loss"] + + return plot_keys diff --git a/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py b/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..e876c42cf7c6c07a58406dcff7a31fbd3649c10f --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_vc_tacotron2.py @@ -0,0 +1,782 @@ +# Copyright 2020 Nagoya University (Wen-Chin Huang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Tacotron2-VC related modules.""" + +import logging + +from distutils.util import strtobool + +import numpy as np +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.rnn.attentions import AttForward +from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA +from espnet.nets.pytorch_backend.rnn.attentions import AttLoc +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHG +from espnet.nets.pytorch_backend.tacotron2.cbhg import CBHGLoss +from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.fill_missing_args import fill_missing_args +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( + GuidedAttentionLoss, # noqa: H301 + Tacotron2Loss, # noqa: H301 +) + + +class Tacotron2(TTSInterface, torch.nn.Module): + """VC Tacotron2 module for VC. + + This is a module of Tacotron2-based VC model, + which convert the sequence of acoustic features + into the sequence of acoustic features. 
+ """ + + @staticmethod + def add_arguments(parser): + """Add model-specific arguments to the parser.""" + group = parser.add_argument_group("tacotron 2 model setting") + # encoder + group.add_argument( + "--elayers", default=1, type=int, help="Number of encoder layers" + ) + group.add_argument( + "--eunits", + "-u", + default=512, + type=int, + help="Number of encoder hidden units", + ) + group.add_argument( + "--econv-layers", + default=3, + type=int, + help="Number of encoder convolution layers", + ) + group.add_argument( + "--econv-chans", + default=512, + type=int, + help="Number of encoder convolution channels", + ) + group.add_argument( + "--econv-filts", + default=5, + type=int, + help="Filter size of encoder convolution", + ) + # attention + group.add_argument( + "--atype", + default="location", + type=str, + choices=["forward_ta", "forward", "location"], + help="Type of attention mechanism", + ) + group.add_argument( + "--adim", + default=512, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aconv-chans", + default=32, + type=int, + help="Number of attention convolution channels", + ) + group.add_argument( + "--aconv-filts", + default=15, + type=int, + help="Filter size of attention convolution", + ) + group.add_argument( + "--cumulate-att-w", + default=True, + type=strtobool, + help="Whether or not to cumulate attention weights", + ) + # decoder + group.add_argument( + "--dlayers", default=2, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=1024, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--prenet-layers", default=2, type=int, help="Number of prenet layers" + ) + group.add_argument( + "--prenet-units", + default=256, + type=int, + help="Number of prenet hidden units", + ) + group.add_argument( + "--postnet-layers", default=5, type=int, help="Number of postnet layers" + ) + group.add_argument( + "--postnet-chans", default=512, type=int, help="Number of postnet channels" + ) + group.add_argument( + "--postnet-filts", default=5, type=int, help="Filter size of postnet" + ) + group.add_argument( + "--output-activation", + default=None, + type=str, + nargs="?", + help="Output activation function", + ) + # cbhg + group.add_argument( + "--use-cbhg", + default=False, + type=strtobool, + help="Whether to use CBHG module", + ) + group.add_argument( + "--cbhg-conv-bank-layers", + default=8, + type=int, + help="Number of convoluional bank layers in CBHG", + ) + group.add_argument( + "--cbhg-conv-bank-chans", + default=128, + type=int, + help="Number of convoluional bank channles in CBHG", + ) + group.add_argument( + "--cbhg-conv-proj-filts", + default=3, + type=int, + help="Filter size of convoluional projection layer in CBHG", + ) + group.add_argument( + "--cbhg-conv-proj-chans", + default=256, + type=int, + help="Number of convoluional projection channels in CBHG", + ) + group.add_argument( + "--cbhg-highway-layers", + default=4, + type=int, + help="Number of highway layers in CBHG", + ) + group.add_argument( + "--cbhg-highway-units", + default=128, + type=int, + help="Number of highway units in CBHG", + ) + group.add_argument( + "--cbhg-gru-units", + default=256, + type=int, + help="Number of GRU units in CBHG", + ) + # model (parameter) related + group.add_argument( + "--use-batch-norm", + default=True, + type=strtobool, + help="Whether to use batch normalization", + ) + group.add_argument( + "--use-concate", + default=True, + type=strtobool, + help="Whether to 
concatenate encoder embedding with decoder outputs", + ) + group.add_argument( + "--use-residual", + default=True, + type=strtobool, + help="Whether to use residual connection in conv layer", + ) + group.add_argument( + "--dropout-rate", default=0.5, type=float, help="Dropout rate" + ) + group.add_argument( + "--zoneout-rate", default=0.1, type=float, help="Zoneout rate" + ) + group.add_argument( + "--reduction-factor", + default=1, + type=int, + help="Reduction factor (for decoder)", + ) + group.add_argument( + "--encoder-reduction-factor", + default=1, + type=int, + help="Reduction factor (for encoder)", + ) + group.add_argument( + "--spk-embed-dim", + default=None, + type=int, + help="Number of speaker embedding dimensions", + ) + group.add_argument( + "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions" + ) + group.add_argument( + "--pretrained-model", default=None, type=str, help="Pretrained model path" + ) + # loss related + group.add_argument( + "--use-masking", + default=False, + type=strtobool, + help="Whether to use masking in calculation of loss", + ) + group.add_argument( + "--bce-pos-weight", + default=20.0, + type=float, + help="Positive sample weight in BCE calculation " + "(only for use-masking=True)", + ) + group.add_argument( + "--use-guided-attn-loss", + default=False, + type=strtobool, + help="Whether to use guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-sigma", + default=0.4, + type=float, + help="Sigma in guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-lambda", + default=1.0, + type=float, + help="Lambda in guided attention loss", + ) + group.add_argument( + "--src-reconstruction-loss-lambda", + default=1.0, + type=float, + help="Lambda in source reconstruction loss", + ) + group.add_argument( + "--trg-reconstruction-loss-lambda", + default=1.0, + type=float, + help="Lambda in target reconstruction loss", + ) + return parser + + def __init__(self, idim, odim, args=None): + """Initialize Tacotron2 module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (Namespace, optional): + - spk_embed_dim (int): Dimension of the speaker embedding. + - elayers (int): The number of encoder blstm layers. + - eunits (int): The number of encoder blstm units. + - econv_layers (int): The number of encoder conv layers. + - econv_filts (int): The number of encoder conv filter size. + - econv_chans (int): The number of encoder conv filter channels. + - dlayers (int): The number of decoder lstm layers. + - dunits (int): The number of decoder lstm units. + - prenet_layers (int): The number of prenet layers. + - prenet_units (int): The number of prenet units. + - postnet_layers (int): The number of postnet layers. + - postnet_filts (int): The number of postnet filter size. + - postnet_chans (int): The number of postnet filter channels. + - output_activation (int): The name of activation function for outputs. + - adim (int): The number of dimension of mlp in attention. + - aconv_chans (int): The number of attention conv filter channels. + - aconv_filts (int): The number of attention conv filter size. + - cumulate_att_w (bool): Whether to cumulate previous attention weight. + - use_batch_norm (bool): Whether to use batch normalization. + - use_concate (int): + Whether to concatenate encoder embedding with decoder lstm outputs. + - dropout_rate (float): Dropout rate. + - zoneout_rate (float): Zoneout rate. + - reduction_factor (int): Reduction factor. 
+ - spk_embed_dim (int): Number of speaker embedding dimenstions. + - spc_dim (int): Number of spectrogram embedding dimenstions + (only for use_cbhg=True). + - use_cbhg (bool): Whether to use CBHG module. + - cbhg_conv_bank_layers (int): + The number of convoluional banks in CBHG. + - cbhg_conv_bank_chans (int): + The number of channels of convolutional bank in CBHG. + - cbhg_proj_filts (int): + The number of filter size of projection layeri in CBHG. + - cbhg_proj_chans (int): + The number of channels of projection layer in CBHG. + - cbhg_highway_layers (int): + The number of layers of highway network in CBHG. + - cbhg_highway_units (int): + The number of units of highway network in CBHG. + - cbhg_gru_units (int): The number of units of GRU in CBHG. + - use_masking (bool): Whether to mask padded part in loss calculation. + - bce_pos_weight (float): Weight of positive sample of stop token + (only for use_masking=True). + - use-guided-attn-loss (bool): Whether to use guided attention loss. + - guided-attn-loss-sigma (float) Sigma in guided attention loss. + - guided-attn-loss-lamdba (float): Lambda in guided attention loss. + + """ + # initialize base classes + TTSInterface.__init__(self) + torch.nn.Module.__init__(self) + + # fill missing arguments + args = fill_missing_args(args, self.add_arguments) + + # store hyperparameters + self.idim = idim + self.odim = odim + self.adim = args.adim + self.spk_embed_dim = args.spk_embed_dim + self.cumulate_att_w = args.cumulate_att_w + self.reduction_factor = args.reduction_factor + self.encoder_reduction_factor = args.encoder_reduction_factor + self.use_cbhg = args.use_cbhg + self.use_guided_attn_loss = args.use_guided_attn_loss + self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda + self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda + + # define activation function for the final output + if args.output_activation is None: + self.output_activation_fn = None + elif hasattr(F, args.output_activation): + self.output_activation_fn = getattr(F, args.output_activation) + else: + raise ValueError( + "there is no such an activation function. (%s)" % args.output_activation + ) + + # define network modules + self.enc = Encoder( + idim=idim * args.encoder_reduction_factor, + input_layer="linear", + elayers=args.elayers, + eunits=args.eunits, + econv_layers=args.econv_layers, + econv_chans=args.econv_chans, + econv_filts=args.econv_filts, + use_batch_norm=args.use_batch_norm, + use_residual=args.use_residual, + dropout_rate=args.dropout_rate, + ) + dec_idim = ( + args.eunits + if args.spk_embed_dim is None + else args.eunits + args.spk_embed_dim + ) + if args.atype == "location": + att = AttLoc( + dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts + ) + elif args.atype == "forward": + att = AttForward( + dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts + ) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled in forward attention." + ) + self.cumulate_att_w = False + elif args.atype == "forward_ta": + att = AttForwardTA( + dec_idim, + args.dunits, + args.adim, + args.aconv_chans, + args.aconv_filts, + odim, + ) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled in forward attention." 
+ ) + self.cumulate_att_w = False + else: + raise NotImplementedError("Support only location or forward") + self.dec = Decoder( + idim=dec_idim, + odim=odim, + att=att, + dlayers=args.dlayers, + dunits=args.dunits, + prenet_layers=args.prenet_layers, + prenet_units=args.prenet_units, + postnet_layers=args.postnet_layers, + postnet_chans=args.postnet_chans, + postnet_filts=args.postnet_filts, + output_activation_fn=self.output_activation_fn, + cumulate_att_w=self.cumulate_att_w, + use_batch_norm=args.use_batch_norm, + use_concate=args.use_concate, + dropout_rate=args.dropout_rate, + zoneout_rate=args.zoneout_rate, + reduction_factor=args.reduction_factor, + ) + self.taco2_loss = Tacotron2Loss( + use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight + ) + if self.use_guided_attn_loss: + self.attn_loss = GuidedAttentionLoss( + sigma=args.guided_attn_loss_sigma, + alpha=args.guided_attn_loss_lambda, + ) + if self.use_cbhg: + self.cbhg = CBHG( + idim=odim, + odim=args.spc_dim, + conv_bank_layers=args.cbhg_conv_bank_layers, + conv_bank_chans=args.cbhg_conv_bank_chans, + conv_proj_filts=args.cbhg_conv_proj_filts, + conv_proj_chans=args.cbhg_conv_proj_chans, + highway_layers=args.cbhg_highway_layers, + highway_units=args.cbhg_highway_units, + gru_units=args.cbhg_gru_units, + ) + self.cbhg_loss = CBHGLoss(use_masking=args.use_masking) + if self.src_reconstruction_loss_lambda > 0: + self.src_reconstructor = Encoder( + idim=dec_idim, + input_layer="linear", + elayers=args.elayers, + eunits=args.eunits, + econv_layers=args.econv_layers, + econv_chans=args.econv_chans, + econv_filts=args.econv_filts, + use_batch_norm=args.use_batch_norm, + use_residual=args.use_residual, + dropout_rate=args.dropout_rate, + ) + self.src_reconstructor_linear = torch.nn.Linear( + args.econv_chans, idim * args.encoder_reduction_factor + ) + + self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking) + if self.trg_reconstruction_loss_lambda > 0: + self.trg_reconstructor = Encoder( + idim=dec_idim, + input_layer="linear", + elayers=args.elayers, + eunits=args.eunits, + econv_layers=args.econv_layers, + econv_chans=args.econv_chans, + econv_filts=args.econv_filts, + use_batch_norm=args.use_batch_norm, + use_residual=args.use_residual, + dropout_rate=args.dropout_rate, + ) + self.trg_reconstructor_linear = torch.nn.Linear( + args.econv_chans, odim * args.reduction_factor + ) + self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking) + + # load pretrained model + if args.pretrained_model is not None: + self.load_pretrained_model(args.pretrained_model) + + def forward( + self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs + ): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + spcs (Tensor, optional): + Batch of groundtruth spectrograms (B, Lmax, spc_dim). + + Returns: + Tensor: Loss value. 
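+
+        Example:
+            A shape-only sketch with random placeholder tensors; the feature
+            dimension 80 is illustrative and must match ``idim``/``odim``:
+
+            >>> xs = torch.randn(2, 100, 80)   # (B, Tmax, idim) source features
+            >>> ilens = torch.tensor([100, 90])
+            >>> ys = torch.randn(2, 120, 80)   # (B, Lmax, odim) target features
+            >>> labels = torch.zeros(2, 120)
+            >>> labels[:, -1] = 1.0
+            >>> olens = torch.tensor([120, 112])
+            >>> loss = model(xs, ilens, ys, labels, olens)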
+ + """ + # remove unnecessary padded part (for multi-gpus) + max_in = max(ilens) + max_out = max(olens) + if max_in != xs.shape[1]: + xs = xs[:, :max_in] + if max_out != ys.shape[1]: + ys = ys[:, :max_out] + labels = labels[:, :max_out] + + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + B, Lmax, idim = xs.shape + if Lmax % self.encoder_reduction_factor != 0: + xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :] + xs_ds = xs.contiguous().view( + B, + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + ilens_ds = ilens.new( + [ilen // self.encoder_reduction_factor for ilen in ilens] + ) + else: + xs_ds, ilens_ds = xs, ilens + + # calculate tacotron2 outputs + hs, hlens = self.enc(xs_ds, ilens_ds) + if self.spk_embed_dim is not None: + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = torch.cat([hs, spembs], dim=-1) + after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys) + + # caluculate src reconstruction + if self.src_reconstruction_loss_lambda > 0: + B, _in_length, _adim = hs.shape + xt, xtlens = self.src_reconstructor(hs, hlens) + xt = self.src_reconstructor_linear(xt) + if self.encoder_reduction_factor > 1: + xt = xt.view(B, -1, self.idim) + + # caluculate trg reconstruction + if self.trg_reconstruction_loss_lambda > 0: + olens_trg_cp = olens.new( + sorted([olen // self.reduction_factor for olen in olens], reverse=True) + ) + B, _in_length, _adim = hs.shape + _, _out_length, _ = att_ws.shape + # att_R should be [B, out_length / r_d, adim] + att_R = torch.sum( + hs.view(B, 1, _in_length, _adim) + * att_ws.view(B, _out_length, _in_length, 1), + dim=2, + ) + yt, ytlens = self.trg_reconstructor( + att_R, olens_trg_cp + ) # is using olens correct? 
+ yt = self.trg_reconstructor_linear(yt) + if self.reduction_factor > 1: + yt = yt.view( + B, -1, self.odim + ) # now att_R should be [B, out_length, adim] + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_out = max(olens) + ys = ys[:, :max_out] + labels = labels[:, :max_out] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + if self.encoder_reduction_factor > 1: + ilens = ilens.new( + [ilen - ilen % self.encoder_reduction_factor for ilen in ilens] + ) + max_in = max(ilens) + xs = xs[:, :max_in] + + # caluculate taco2 loss + l1_loss, mse_loss, bce_loss = self.taco2_loss( + after_outs, before_outs, logits, ys, labels, olens + ) + loss = l1_loss + mse_loss + bce_loss + report_keys = [ + {"l1_loss": l1_loss.item()}, + {"mse_loss": mse_loss.item()}, + {"bce_loss": bce_loss.item()}, + ] + + # caluculate context_perservation loss + if self.src_reconstruction_loss_lambda > 0: + src_recon_l1_loss, src_recon_mse_loss = self.src_reconstruction_loss( + xt, xs, ilens + ) + loss = loss + src_recon_l1_loss + report_keys += [ + {"src_recon_l1_loss": src_recon_l1_loss.item()}, + {"src_recon_mse_loss": src_recon_mse_loss.item()}, + ] + if self.trg_reconstruction_loss_lambda > 0: + trg_recon_l1_loss, trg_recon_mse_loss = self.trg_reconstruction_loss( + yt, ys, olens + ) + loss = loss + trg_recon_l1_loss + report_keys += [ + {"trg_recon_l1_loss": trg_recon_l1_loss.item()}, + {"trg_recon_mse_loss": trg_recon_mse_loss.item()}, + ] + + # caluculate attention loss + if self.use_guided_attn_loss: + # NOTE(kan-bayashi): length of output for auto-regressive input + # will be changed when r > 1 + if self.encoder_reduction_factor > 1: + ilens_in = ilens.new( + [ilen // self.encoder_reduction_factor for ilen in ilens] + ) + else: + ilens_in = ilens + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + attn_loss = self.attn_loss(att_ws, ilens_in, olens_in) + loss = loss + attn_loss + report_keys += [ + {"attn_loss": attn_loss.item()}, + ] + + # caluculate cbhg loss + if self.use_cbhg: + # remove unnecessary padded part (for multi-gpus) + if max_out != spcs.shape[1]: + spcs = spcs[:, :max_out] + + # caluculate cbhg outputs & loss and report them + cbhg_outs, _ = self.cbhg(after_outs, olens) + cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens) + loss = loss + cbhg_l1_loss + cbhg_mse_loss + report_keys += [ + {"cbhg_l1_loss": cbhg_l1_loss.item()}, + {"cbhg_mse_loss": cbhg_mse_loss.item()}, + ] + + report_keys += [{"loss": loss.item()}] + self.reporter.report(report_keys) + + return loss + + def inference(self, x, inference_args, spemb=None, *args, **kwargs): + """Generate the sequence of features given the sequences of characters. + + Args: + x (Tensor): Input sequence of acoustic features (T, idim). + inference_args (Namespace): + - threshold (float): Threshold in inference. + - minlenratio (float): Minimum length ratio in inference. + - maxlenratio (float): Maximum length ratio in inference. + spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Attention weights (L, T). 
+ + """ + # get options + threshold = inference_args.threshold + minlenratio = inference_args.minlenratio + maxlenratio = inference_args.maxlenratio + + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + Lmax, idim = x.shape + if Lmax % self.encoder_reduction_factor != 0: + x = x[: -(Lmax % self.encoder_reduction_factor), :] + x_ds = x.contiguous().view( + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + else: + x_ds = x + + # inference + h = self.enc.inference(x_ds) + if self.spk_embed_dim is not None: + spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1) + h = torch.cat([h, spemb], dim=-1) + outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio) + + if self.use_cbhg: + cbhg_outs = self.cbhg.inference(outs) + return cbhg_outs, probs, att_ws + else: + return outs, probs, att_ws + + def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs): + """Calculate all of the attention weights. + + Args: + xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + + Returns: + numpy.ndarray: Batch of attention weights (B, Lmax, Tmax). + + """ + # check ilens type (should be list of int) + if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray): + ilens = list(map(int, ilens)) + + self.eval() + with torch.no_grad(): + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + B, Lmax, idim = xs.shape + if Lmax % self.encoder_reduction_factor != 0: + xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :] + xs_ds = xs.contiguous().view( + B, + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + ilens_ds = [ilen // self.encoder_reduction_factor for ilen in ilens] + else: + xs_ds, ilens_ds = xs, ilens + + hs, hlens = self.enc(xs_ds, ilens_ds) + if self.spk_embed_dim is not None: + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = torch.cat([hs, spembs], dim=-1) + att_ws = self.dec.calculate_all_attentions(hs, hlens, ys) + self.train() + + return att_ws.cpu().numpy() + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + keys should match what `chainer.reporter` reports. + If you add the key `loss`, the reporter will report `main/loss` + and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list: List of strings which are base keys to plot during training. 
+ + """ + plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"] + if self.use_guided_attn_loss: + plot_keys += ["attn_loss"] + if self.use_cbhg: + plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"] + if self.src_reconstruction_loss_lambda > 0: + plot_keys += ["src_recon_l1_loss", "src_recon_mse_loss"] + if self.trg_reconstruction_loss_lambda > 0: + plot_keys += ["trg_recon_l1_loss", "trg_recon_mse_loss"] + return plot_keys + + def _sort_by_length(self, xs, ilens): + sort_ilens, sort_idx = ilens.sort(0, descending=True) + return xs[sort_idx], ilens[sort_idx], sort_idx + + def _revert_sort_by_length(self, xs, ilens, sort_idx): + _, revert_idx = sort_idx.sort(0) + return xs[revert_idx], ilens[revert_idx] diff --git a/espnet/nets/pytorch_backend/e2e_vc_transformer.py b/espnet/nets/pytorch_backend/e2e_vc_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1b3c6913c1776be263f1a05e754ec6f1a997cc97 --- /dev/null +++ b/espnet/nets/pytorch_backend/e2e_vc_transformer.py @@ -0,0 +1,1155 @@ +# Copyright 2020 Nagoya University (Wen-Chin Huang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Voice Transformer Network (Transformer-VC) related modules.""" + +import logging + +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.e2e_asr_transformer import subsequent_mask +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import ( + Tacotron2Loss as TransformerLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.initializer import initialize +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.cli_utils import strtobool +from espnet.utils.fill_missing_args import fill_missing_args +from espnet.nets.pytorch_backend.e2e_tts_transformer import ( + GuidedMultiHeadAttentionLoss, # noqa: H301 + TTSPlot, # noqa: H301 +) + + +class Transformer(TTSInterface, torch.nn.Module): + """VC Transformer module. + + This is a module of the Voice Transformer Network + (a.k.a. VTN or Transformer-VC) described in + `Voice Transformer Network: Sequence-to-Sequence + Voice Conversion Using Transformer with + Text-to-Speech Pretraining`_, + which convert the sequence of acoustic features + into the sequence of acoustic features. + + .. 
_`Voice Transformer Network: Sequence-to-Sequence + Voice Conversion Using Transformer with + Text-to-Speech Pretraining`: + https://arxiv.org/pdf/1912.06813.pdf + + """ + + @staticmethod + def add_arguments(parser): + """Add model-specific arguments to the parser.""" + group = parser.add_argument_group("transformer model setting") + # network structure related + group.add_argument( + "--eprenet-conv-layers", + default=0, + type=int, + help="Number of encoder prenet convolution layers", + ) + group.add_argument( + "--eprenet-conv-chans", + default=0, + type=int, + help="Number of encoder prenet convolution channels", + ) + group.add_argument( + "--eprenet-conv-filts", + default=0, + type=int, + help="Filter size of encoder prenet convolution", + ) + group.add_argument( + "--transformer-input-layer", + default="linear", + type=str, + help="Type of input layer (linear or conv2d)", + ) + group.add_argument( + "--dprenet-layers", + default=2, + type=int, + help="Number of decoder prenet layers", + ) + group.add_argument( + "--dprenet-units", + default=256, + type=int, + help="Number of decoder prenet hidden units", + ) + group.add_argument( + "--elayers", default=3, type=int, help="Number of encoder layers" + ) + group.add_argument( + "--eunits", default=1536, type=int, help="Number of encoder hidden units" + ) + group.add_argument( + "--adim", + default=384, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + group.add_argument( + "--dlayers", default=3, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=1536, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--positionwise-layer-type", + default="linear", + type=str, + choices=["linear", "conv1d", "conv1d-linear"], + help="Positionwise layer type.", + ) + group.add_argument( + "--positionwise-conv-kernel-size", + default=1, + type=int, + help="Kernel size of positionwise conv1d layer", + ) + group.add_argument( + "--postnet-layers", default=5, type=int, help="Number of postnet layers" + ) + group.add_argument( + "--postnet-chans", default=256, type=int, help="Number of postnet channels" + ) + group.add_argument( + "--postnet-filts", default=5, type=int, help="Filter size of postnet" + ) + group.add_argument( + "--use-scaled-pos-enc", + default=True, + type=strtobool, + help="Use trainable scaled positional encoding" + "instead of the fixed scale one.", + ) + group.add_argument( + "--use-batch-norm", + default=True, + type=strtobool, + help="Whether to use batch normalization", + ) + group.add_argument( + "--encoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before encoder block", + ) + group.add_argument( + "--decoder-normalize-before", + default=False, + type=strtobool, + help="Whether to apply layer norm before decoder block", + ) + group.add_argument( + "--encoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in encoder", + ) + group.add_argument( + "--decoder-concat-after", + default=False, + type=strtobool, + help="Whether to concatenate attention layer's input and output in decoder", + ) + group.add_argument( + "--reduction-factor", + default=1, + type=int, + help="Reduction factor (for decoder)", + ) + group.add_argument( + "--encoder-reduction-factor", + default=1, + type=int, + help="Reduction factor (for encoder)", + 
) + group.add_argument( + "--spk-embed-dim", + default=None, + type=int, + help="Number of speaker embedding dimensions", + ) + group.add_argument( + "--spk-embed-integration-type", + type=str, + default="add", + choices=["add", "concat"], + help="How to integrate speaker embedding", + ) + # training related + group.add_argument( + "--transformer-init", + type=str, + default="pytorch", + choices=[ + "pytorch", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + ], + help="How to initialize transformer parameters", + ) + group.add_argument( + "--initial-encoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in encoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--initial-decoder-alpha", + type=float, + default=1.0, + help="Initial alpha value in decoder's ScaledPositionalEncoding", + ) + group.add_argument( + "--transformer-lr", + default=1.0, + type=float, + help="Initial value of learning rate", + ) + group.add_argument( + "--transformer-warmup-steps", + default=4000, + type=int, + help="Optimizer warmup steps", + ) + group.add_argument( + "--transformer-enc-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder except for attention", + ) + group.add_argument( + "--transformer-enc-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder positional encoding", + ) + group.add_argument( + "--transformer-enc-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder self-attention", + ) + group.add_argument( + "--transformer-dec-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder " + "except for attention and pos encoding", + ) + group.add_argument( + "--transformer-dec-positional-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder positional encoding", + ) + group.add_argument( + "--transformer-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer decoder self-attention", + ) + group.add_argument( + "--transformer-enc-dec-attn-dropout-rate", + default=0.1, + type=float, + help="Dropout rate for transformer encoder-decoder attention", + ) + group.add_argument( + "--eprenet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in encoder prenet", + ) + group.add_argument( + "--dprenet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in decoder prenet", + ) + group.add_argument( + "--postnet-dropout-rate", + default=0.5, + type=float, + help="Dropout rate in postnet", + ) + group.add_argument( + "--pretrained-model", default=None, type=str, help="Pretrained model path" + ) + + # loss related + group.add_argument( + "--use-masking", + default=True, + type=strtobool, + help="Whether to use masking in calculation of loss", + ) + group.add_argument( + "--use-weighted-masking", + default=False, + type=strtobool, + help="Whether to use weighted masking in calculation of loss", + ) + group.add_argument( + "--loss-type", + default="L1", + choices=["L1", "L2", "L1+L2"], + help="How to calc loss", + ) + group.add_argument( + "--bce-pos-weight", + default=5.0, + type=float, + help="Positive sample weight in BCE calculation " + "(only for use-masking=True)", + ) + group.add_argument( + "--use-guided-attn-loss", + default=False, + type=strtobool, + help="Whether to use guided attention loss", + ) + group.add_argument( + "--guided-attn-loss-sigma", + default=0.4, + type=float, + help="Sigma in guided 
attention loss", + ) + group.add_argument( + "--guided-attn-loss-lambda", + default=1.0, + type=float, + help="Lambda in guided attention loss", + ) + group.add_argument( + "--num-heads-applied-guided-attn", + default=2, + type=int, + help="Number of heads in each layer to be applied guided attention loss" + "if set -1, all of the heads will be applied.", + ) + group.add_argument( + "--num-layers-applied-guided-attn", + default=2, + type=int, + help="Number of layers to be applied guided attention loss" + "if set -1, all of the layers will be applied.", + ) + group.add_argument( + "--modules-applied-guided-attn", + type=str, + nargs="+", + default=["encoder-decoder"], + help="Module name list to be applied guided attention loss", + ) + return parser + + @property + def attention_plot_class(self): + """Return plot class for attention weight plot.""" + return TTSPlot + + def __init__(self, idim, odim, args=None): + """Initialize Transformer-VC module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + args (Namespace, optional): + - eprenet_conv_layers (int): + Number of encoder prenet convolution layers. + - eprenet_conv_chans (int): + Number of encoder prenet convolution channels. + - eprenet_conv_filts (int): + Filter size of encoder prenet convolution. + - transformer_input_layer (str): Input layer before the encoder. + - dprenet_layers (int): Number of decoder prenet layers. + - dprenet_units (int): Number of decoder prenet hidden units. + - elayers (int): Number of encoder layers. + - eunits (int): Number of encoder hidden units. + - adim (int): Number of attention transformation dimensions. + - aheads (int): Number of heads for multi head attention. + - dlayers (int): Number of decoder layers. + - dunits (int): Number of decoder hidden units. + - postnet_layers (int): Number of postnet layers. + - postnet_chans (int): Number of postnet channels. + - postnet_filts (int): Filter size of postnet. + - use_scaled_pos_enc (bool): + Whether to use trainable scaled positional encoding. + - use_batch_norm (bool): + Whether to use batch normalization in encoder prenet. + - encoder_normalize_before (bool): + Whether to perform layer normalization before encoder block. + - decoder_normalize_before (bool): + Whether to perform layer normalization before decoder block. + - encoder_concat_after (bool): Whether to concatenate + attention layer's input and output in encoder. + - decoder_concat_after (bool): Whether to concatenate + attention layer's input and output in decoder. + - reduction_factor (int): Reduction factor (for decoder). + - encoder_reduction_factor (int): Reduction factor (for encoder). + - spk_embed_dim (int): Number of speaker embedding dimenstions. + - spk_embed_integration_type: How to integrate speaker embedding. + - transformer_init (float): How to initialize transformer parameters. + - transformer_lr (float): Initial value of learning rate. + - transformer_warmup_steps (int): Optimizer warmup steps. + - transformer_enc_dropout_rate (float): + Dropout rate in encoder except attention & positional encoding. + - transformer_enc_positional_dropout_rate (float): + Dropout rate after encoder positional encoding. + - transformer_enc_attn_dropout_rate (float): + Dropout rate in encoder self-attention module. + - transformer_dec_dropout_rate (float): + Dropout rate in decoder except attention & positional encoding. + - transformer_dec_positional_dropout_rate (float): + Dropout rate after decoder positional encoding. 
+ - transformer_dec_attn_dropout_rate (float): + Dropout rate in deocoder self-attention module. + - transformer_enc_dec_attn_dropout_rate (float): + Dropout rate in encoder-deocoder attention module. + - eprenet_dropout_rate (float): Dropout rate in encoder prenet. + - dprenet_dropout_rate (float): Dropout rate in decoder prenet. + - postnet_dropout_rate (float): Dropout rate in postnet. + - use_masking (bool): + Whether to apply masking for padded part in loss calculation. + - use_weighted_masking (bool): + Whether to apply weighted masking in loss calculation. + - bce_pos_weight (float): Positive sample weight in bce calculation + (only for use_masking=true). + - loss_type (str): How to calculate loss. + - use_guided_attn_loss (bool): Whether to use guided attention loss. + - num_heads_applied_guided_attn (int): + Number of heads in each layer to apply guided attention loss. + - num_layers_applied_guided_attn (int): + Number of layers to apply guided attention loss. + - modules_applied_guided_attn (list): + List of module names to apply guided attention loss. + - guided-attn-loss-sigma (float) Sigma in guided attention loss. + - guided-attn-loss-lambda (float): Lambda in guided attention loss. + + """ + # initialize base classes + TTSInterface.__init__(self) + torch.nn.Module.__init__(self) + + # fill missing arguments + args = fill_missing_args(args, self.add_arguments) + + # store hyperparameters + self.idim = idim + self.odim = odim + self.spk_embed_dim = args.spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = args.spk_embed_integration_type + self.use_scaled_pos_enc = args.use_scaled_pos_enc + self.reduction_factor = args.reduction_factor + self.encoder_reduction_factor = args.encoder_reduction_factor + self.transformer_input_layer = args.transformer_input_layer + self.loss_type = args.loss_type + self.use_guided_attn_loss = args.use_guided_attn_loss + if self.use_guided_attn_loss: + if args.num_layers_applied_guided_attn == -1: + self.num_layers_applied_guided_attn = args.elayers + else: + self.num_layers_applied_guided_attn = ( + args.num_layers_applied_guided_attn + ) + if args.num_heads_applied_guided_attn == -1: + self.num_heads_applied_guided_attn = args.aheads + else: + self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn + self.modules_applied_guided_attn = args.modules_applied_guided_attn + + # use idx 0 as padding idx + padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # define transformer encoder + if args.eprenet_conv_layers != 0: + # encoder prenet + encoder_input_layer = torch.nn.Sequential( + EncoderPrenet( + idim=idim, + elayers=0, + econv_layers=args.eprenet_conv_layers, + econv_chans=args.eprenet_conv_chans, + econv_filts=args.eprenet_conv_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.eprenet_dropout_rate, + padding_idx=padding_idx, + input_layer=torch.nn.Linear( + idim * args.encoder_reduction_factor, idim + ), + ), + torch.nn.Linear(args.eprenet_conv_chans, args.adim), + ) + elif args.transformer_input_layer == "linear": + encoder_input_layer = torch.nn.Linear( + idim * args.encoder_reduction_factor, args.adim + ) + else: + encoder_input_layer = args.transformer_input_layer + self.encoder = Encoder( + idim=idim, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.eunits, + num_blocks=args.elayers, + input_layer=encoder_input_layer, + 
dropout_rate=args.transformer_enc_dropout_rate, + positional_dropout_rate=args.transformer_enc_positional_dropout_rate, + attention_dropout_rate=args.transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=args.encoder_normalize_before, + concat_after=args.encoder_concat_after, + positionwise_layer_type=args.positionwise_layer_type, + positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, + ) + + # define projection layer + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) + else: + self.projection = torch.nn.Linear( + args.adim + self.spk_embed_dim, args.adim + ) + + # define transformer decoder + if args.dprenet_layers != 0: + # decoder prenet + decoder_input_layer = torch.nn.Sequential( + DecoderPrenet( + idim=odim, + n_layers=args.dprenet_layers, + n_units=args.dprenet_units, + dropout_rate=args.dprenet_dropout_rate, + ), + torch.nn.Linear(args.dprenet_units, args.adim), + ) + else: + decoder_input_layer = "linear" + self.decoder = Decoder( + odim=-1, + attention_dim=args.adim, + attention_heads=args.aheads, + linear_units=args.dunits, + num_blocks=args.dlayers, + dropout_rate=args.transformer_dec_dropout_rate, + positional_dropout_rate=args.transformer_dec_positional_dropout_rate, + self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, + src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate, + input_layer=decoder_input_layer, + use_output_layer=False, + pos_enc_class=pos_enc_class, + normalize_before=args.decoder_normalize_before, + concat_after=args.decoder_concat_after, + ) + + # define final projection + self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) + self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) + + # define postnet + self.postnet = ( + None + if args.postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=args.postnet_layers, + n_chans=args.postnet_chans, + n_filts=args.postnet_filts, + use_batch_norm=args.use_batch_norm, + dropout_rate=args.postnet_dropout_rate, + ) + ) + + # define loss function + self.criterion = TransformerLoss( + use_masking=args.use_masking, + use_weighted_masking=args.use_weighted_masking, + bce_pos_weight=args.bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.attn_criterion = GuidedMultiHeadAttentionLoss( + sigma=args.guided_attn_loss_sigma, + alpha=args.guided_attn_loss_lambda, + ) + + # initialize parameters + self._reset_parameters( + init_type=args.transformer_init, + init_enc_alpha=args.initial_encoder_alpha, + init_dec_alpha=args.initial_decoder_alpha, + ) + + # load pretrained model + if args.pretrained_model is not None: + self.load_pretrained_model(args.pretrained_model) + + def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): + # initialize parameters + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) + + def _add_first_frame_and_remove_last_frame(self, ys): + ys_in = torch.cat( + [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 + ) + return ys_in + + def forward(self, xs, ilens, ys, labels, olens, spembs=None, *args, **kwargs): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). 
+ ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): Batch of speaker embedding vectors + (B, spk_embed_dim). + + Returns: + Tensor: Loss value. + + """ + # remove unnecessary padded part (for multi-gpus) + max_ilen = max(ilens) + max_olen = max(olens) + if max_ilen != xs.shape[1]: + xs = xs[:, :max_ilen] + if max_olen != ys.shape[1]: + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + B, Lmax, idim = xs.shape + if Lmax % self.encoder_reduction_factor != 0: + xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :] + xs_ds = xs.contiguous().view( + B, + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + ilens_ds = ilens.new( + [ilen // self.encoder_reduction_factor for ilen in ilens] + ) + else: + xs_ds, ilens_ds = xs, ilens + + # forward encoder + x_masks = self._source_mask(ilens_ds) + hs, hs_masks = self.encoder(xs_ds, x_masks) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs_int = self._integrate_with_spk_embed(hs, spembs) + else: + hs_int = hs + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + ys_in, olens_in = ys, olens + + # add first zero frame and remove last frame for auto-regressive + ys_in = self._add_first_frame_and_remove_last_frame(ys_in) + + # if conv2d, modify mask. 
Use ceiling division here + if "conv2d" in self.transformer_input_layer: + ilens_ds_st = ilens_ds.new( + [((ilen - 2 + 1) // 2 - 2 + 1) // 2 for ilen in ilens_ds] + ) + else: + ilens_ds_st = ilens_ds + + # forward decoder + y_masks = self._target_mask(olens_in) + zs, _ = self.decoder(ys_in, y_masks, hs_int, hs_masks) + # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + # (B, Lmax//r, r) -> (B, Lmax//r * r) + logits = self.prob_out(zs).view(zs.size(0), -1) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + + # caluculate loss values + l1_loss, l2_loss, bce_loss = self.criterion( + after_outs, before_outs, logits, ys, labels, olens + ) + if self.loss_type == "L1": + loss = l1_loss + bce_loss + elif self.loss_type == "L2": + loss = l2_loss + bce_loss + elif self.loss_type == "L1+L2": + loss = l1_loss + l2_loss + bce_loss + else: + raise ValueError("unknown --loss-type " + self.loss_type) + report_keys = [ + {"l1_loss": l1_loss.item()}, + {"l2_loss": l2_loss.item()}, + {"bce_loss": bce_loss.item()}, + {"loss": loss.item()}, + ] + + # calculate guided attention loss + if self.use_guided_attn_loss: + # calculate for encoder + if "encoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.encoder.encoders))) + ): + att_ws += [ + self.encoder.encoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) + enc_attn_loss = self.attn_criterion( + att_ws, ilens_ds_st, ilens_ds_st + ) # TODO(unilight): is changing to ilens_ds_st right? + loss = loss + enc_attn_loss + report_keys += [{"enc_attn_loss": enc_attn_loss.item()}] + # calculate for decoder + if "decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) + dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) + loss = loss + dec_attn_loss + report_keys += [{"dec_attn_loss": dec_attn_loss.item()}] + # calculate for encoder-decoder + if "encoder-decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].src_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) + enc_dec_attn_loss = self.attn_criterion( + att_ws, ilens_ds_st, olens_in + ) # TODO(unilight): is changing to ilens_ds_st right? 
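The guided-attention terms above penalize attention mass that strays from the diagonal. As a rough, self-contained illustration (not the `GuidedMultiHeadAttentionLoss` class itself), a common diagonal penalty is `1 - exp(-((n/N - t/T)^2) / (2*sigma^2))`; multiplying it with a head's attention map and averaging gives the penalty for that head:

import torch

def guided_attention_weight(t_out, t_in, sigma=0.4):
    # penalty matrix: ~0 near the diagonal, approaching 1 far away from it
    grid_y = torch.arange(t_out).float().unsqueeze(1) / t_out   # (t_out, 1)
    grid_x = torch.arange(t_in).float().unsqueeze(0) / t_in     # (1, t_in)
    return 1.0 - torch.exp(-((grid_x - grid_y) ** 2) / (2 * sigma ** 2))

att = torch.softmax(torch.randn(1, 4, 6, 5), dim=-1)   # toy (B, heads, T_out, T_in) attention
w = guided_attention_weight(6, 5)                       # (T_out, T_in) penalty
penalty = (att * w).mean()                              # grows when attention sits far from the diagonal
print(w.shape, float(penalty))

With the default sigma of 0.4 used above, near-diagonal attention is almost free, while far-off-diagonal mass is weighted close to 1.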
+ loss = loss + enc_dec_attn_loss + report_keys += [{"enc_dec_attn_loss": enc_dec_attn_loss.item()}] + + # report extra information + if self.use_scaled_pos_enc: + report_keys += [ + {"encoder_alpha": self.encoder.embed[-1].alpha.data.item()}, + {"decoder_alpha": self.decoder.embed[-1].alpha.data.item()}, + ] + self.reporter.report(report_keys) + + return loss + + def inference(self, x, inference_args, spemb=None, *args, **kwargs): + """Generate the sequence of features given the sequences of acoustic features. + + Args: + x (Tensor): Input sequence of acoustic features (T, idim). + inference_args (Namespace): + - threshold (float): Threshold in inference. + - minlenratio (float): Minimum length ratio in inference. + - maxlenratio (float): Maximum length ratio in inference. + spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). + + """ + # get options + threshold = inference_args.threshold + minlenratio = inference_args.minlenratio + maxlenratio = inference_args.maxlenratio + use_att_constraint = getattr( + inference_args, "use_att_constraint", False + ) # keep compatibility + if use_att_constraint: + logging.warning( + "Attention constraint is not yet supported in Transformer. Not enabled." + ) + + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + Lmax, idim = x.shape + if Lmax % self.encoder_reduction_factor != 0: + x = x[: -(Lmax % self.encoder_reduction_factor), :] + x_ds = x.contiguous().view( + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + else: + x_ds = x + + # forward encoder + x_ds = x_ds.unsqueeze(0) + hs, _ = self.encoder(x_ds, None) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + spembs = spemb.unsqueeze(0) + hs = self._integrate_with_spk_embed(hs, spembs) + + # set limits of length + maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) + minlen = int(hs.size(1) * minlenratio / self.reduction_factor) + + # initialize + idx = 0 + ys = hs.new_zeros(1, 1, self.odim) + outs, probs = [], [] + + # forward decoder step-by-step + z_cache = self.decoder.init_state(x) + while True: + # update index + idx += 1 + + # calculate output and stop prob at idx-th step + y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) + z, z_cache = self.decoder.forward_one_step( + ys, y_masks, hs, cache=z_cache + ) # (B, adim) + outs += [ + self.feat_out(z).view(self.reduction_factor, self.odim) + ] # [(r, odim), ...] + probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] + + # update next inputs + ys = torch.cat( + (ys, outs[-1][-1].view(1, 1, self.odim)), dim=1 + ) # (1, idx + 1, odim) + + # get attention weights + att_ws_ = [] + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention) and "src" in name: + att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...] + if idx == 1: + att_ws = att_ws_ + else: + # [(#heads, l, T), ...] 
+ att_ws = [ + torch.cat([att_w, att_w_], dim=1) + for att_w, att_w_ in zip(att_ws, att_ws_) + ] + + # check whether to finish generation + if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: + # check mininum length + if idx < minlen: + continue + outs = ( + torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) + ) # (L, odim) -> (1, L, odim) -> (1, odim, L) + if self.postnet is not None: + outs = outs + self.postnet(outs) # (1, odim, L) + outs = outs.transpose(2, 1).squeeze(0) # (L, odim) + probs = torch.cat(probs, dim=0) + break + + # concatenate attention weights -> (#layers, #heads, L, T) + att_ws = torch.stack(att_ws, dim=0) + + return outs, probs, att_ws + + def calculate_all_attentions( + self, + xs, + ilens, + ys, + olens, + spembs=None, + skip_output=False, + keep_tensor=False, + *args, + **kwargs + ): + """Calculate all of the attention weights. + + Args: + xs (Tensor): Batch of padded acoustic features (B, Tmax, idim). + ilens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): Batch of padded target features (B, Lmax, odim). + olens (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): Batch of speaker embedding vectors + (B, spk_embed_dim). + skip_output (bool, optional): Whether to skip calculate the final output. + keep_tensor (bool, optional): Whether to keep original tensor. + + Returns: + dict: Dict of attention weights and outputs. + + """ + with torch.no_grad(): + # thin out input frames for reduction factor + # (B, Lmax, idim) -> (B, Lmax // r, idim * r) + if self.encoder_reduction_factor > 1: + B, Lmax, idim = xs.shape + if Lmax % self.encoder_reduction_factor != 0: + xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :] + xs_ds = xs.contiguous().view( + B, + int(Lmax / self.encoder_reduction_factor), + idim * self.encoder_reduction_factor, + ) + ilens_ds = ilens.new( + [ilen // self.encoder_reduction_factor for ilen in ilens] + ) + else: + xs_ds, ilens_ds = xs, ilens + + # forward encoder + x_masks = self._source_mask(ilens_ds) + hs, hs_masks = self.encoder(xs_ds, x_masks) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # thin out frames for reduction factor + # (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + ys_in, olens_in = ys, olens + + # add first zero frame and remove last frame for auto-regressive + ys_in = self._add_first_frame_and_remove_last_frame(ys_in) + + # forward decoder + y_masks = self._target_mask(olens_in) + zs, _ = self.decoder(ys_in, y_masks, hs, hs_masks) + + # calculate final outputs + if not skip_output: + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + # modifiy mod part of output lengths due to reduction factor > 1 + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + + # store into dict + att_ws_dict = dict() + if keep_tensor: + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention): + att_ws_dict[name] = m.attn + if not skip_output: + att_ws_dict["before_postnet_fbank"] = before_outs + att_ws_dict["after_postnet_fbank"] = after_outs + else: + for name, m in self.named_modules(): + if 
isinstance(m, MultiHeadedAttention): + attn = m.attn.cpu().numpy() + if "encoder" in name: + attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] + elif "decoder" in name: + if "src" in name: + attn = [ + a[:, :ol, :il] + for a, il, ol in zip( + attn, ilens.tolist(), olens_in.tolist() + ) + ] + elif "self" in name: + attn = [ + a[:, :l, :l] for a, l in zip(attn, olens_in.tolist()) + ] + else: + logging.warning("unknown attention module: " + name) + else: + logging.warning("unknown attention module: " + name) + att_ws_dict[name] = attn + if not skip_output: + before_outs = before_outs.cpu().numpy() + after_outs = after_outs.cpu().numpy() + att_ws_dict["before_postnet_fbank"] = [ + m[:l].T for m, l in zip(before_outs, olens.tolist()) + ] + att_ws_dict["after_postnet_fbank"] = [ + m[:l].T for m, l in zip(after_outs, olens.tolist()) + ] + + return att_ws_dict + + def _integrate_with_spk_embed(self, hs, spembs): + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim) + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens): + """Make masks for self-attention. + + Args: + ilens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [[1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _target_mask(self, olens): + """Make masks for masked self-attention. + + Args: + olens (LongTensor or List): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for masked self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> olens = [5, 3] + >>> self._target_mask(olens) + tensor([[[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 1]], + [[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) + s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) + return y_masks.unsqueeze(-2) & s_masks + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + keys should match what `chainer.reporter` reports. + If you add the key `loss`, the reporter will report `main/loss` + and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list: List of strings which are base keys to plot during training. 
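`_source_mask` and `_target_mask` above combine a padding mask with a lower-triangular mask so decoder self-attention only sees valid, non-future positions. A self-contained sketch in plain torch (toy lengths, without the ESPnet `make_non_pad_mask`/`subsequent_mask` helpers) that reproduces the pattern shown in the docstring examples:

import torch

olens = [5, 3]                           # toy output lengths
max_len = max(olens)

# non-pad mask: True at valid (non-padded) positions, shape (B, Lmax)
non_pad = torch.arange(max_len).unsqueeze(0) < torch.tensor(olens).unsqueeze(1)

# subsequent mask: lower-triangular (Lmax, Lmax), True where attending is allowed
subsequent = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))

# combined target mask (B, Lmax, Lmax): causal AND not padded
y_masks = non_pad.unsqueeze(-2) & subsequent.unsqueeze(0)
print(y_masks.int())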
+ + """ + plot_keys = ["loss", "l1_loss", "l2_loss", "bce_loss"] + if self.use_scaled_pos_enc: + plot_keys += ["encoder_alpha", "decoder_alpha"] + if self.use_guided_attn_loss: + if "encoder" in self.modules_applied_guided_attn: + plot_keys += ["enc_attn_loss"] + if "decoder" in self.modules_applied_guided_attn: + plot_keys += ["dec_attn_loss"] + if "encoder-decoder" in self.modules_applied_guided_attn: + plot_keys += ["enc_dec_attn_loss"] + + return plot_keys diff --git a/espnet/nets/pytorch_backend/fastspeech/__init__.py b/espnet/nets/pytorch_backend/fastspeech/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/fastspeech/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/fastspeech/duration_calculator.py b/espnet/nets/pytorch_backend/fastspeech/duration_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..4a508ee1df9b3e7c94c07ed625b41e72ce7814f2 --- /dev/null +++ b/espnet/nets/pytorch_backend/fastspeech/duration_calculator.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Duration calculator related modules.""" + +import torch + +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2 +from espnet.nets.pytorch_backend.e2e_tts_transformer import Transformer +from espnet.nets.pytorch_backend.nets_utils import pad_list + + +class DurationCalculator(torch.nn.Module): + """Duration calculator module for FastSpeech. + + Todo: + * Fix the duplicated calculation of diagonal head decision + + """ + + def __init__(self, teacher_model): + """Initialize duration calculator module. + + Args: + teacher_model (e2e_tts_transformer.Transformer): + Pretrained auto-regressive Transformer. + + """ + super(DurationCalculator, self).__init__() + if isinstance(teacher_model, Transformer): + self.register_buffer("diag_head_idx", torch.tensor(-1)) + elif isinstance(teacher_model, Tacotron2): + pass + else: + raise ValueError( + "teacher model should be the instance of " + "e2e_tts_transformer.Transformer or e2e_tts_tacotron2.Tacotron2." + ) + self.teacher_model = teacher_model + + def forward(self, xs, ilens, ys, olens, spembs=None): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of the padded sequences of character ids (B, Tmax). + ilens (Tensor): Batch of lengths of each input sequence (B,). + ys (Tensor): + Batch of the padded sequence of target features (B, Lmax, odim). + olens (Tensor): Batch of lengths of each output sequence (B,). + spembs (Tensor, optional): + Batch of speaker embedding vectors (B, spk_embed_dim). + + Returns: + Tensor: Batch of durations (B, Tmax). + + """ + if isinstance(self.teacher_model, Transformer): + att_ws = self._calculate_encoder_decoder_attentions( + xs, ilens, ys, olens, spembs=spembs + ) + # TODO(kan-bayashi): fix this issue + # this does not work in multi-gpu case. registered buffer is not saved. 
+ if int(self.diag_head_idx) == -1: + self._init_diagonal_head(att_ws) + att_ws = att_ws[:, self.diag_head_idx] + else: + # NOTE(kan-bayashi): Here we assume that the teacher is tacotron 2 + att_ws = self.teacher_model.calculate_all_attentions( + xs, ilens, ys, spembs=spembs, keep_tensor=True + ) + durations = [ + self._calculate_duration(att_w, ilen, olen) + for att_w, ilen, olen in zip(att_ws, ilens, olens) + ] + + return pad_list(durations, 0) + + @staticmethod + def _calculate_duration(att_w, ilen, olen): + return torch.stack( + [att_w[:olen, :ilen].argmax(-1).eq(i).sum() for i in range(ilen)] + ) + + def _init_diagonal_head(self, att_ws): + diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1).mean(dim=0) # (H * L,) + self.register_buffer("diag_head_idx", diagonal_scores.argmax()) + + def _calculate_encoder_decoder_attentions(self, xs, ilens, ys, olens, spembs=None): + att_dict = self.teacher_model.calculate_all_attentions( + xs, ilens, ys, olens, spembs=spembs, skip_output=True, keep_tensor=True + ) + return torch.cat( + [att_dict[k] for k in att_dict.keys() if "src_attn" in k], dim=1 + ) # (B, H*L, Lmax, Tmax) diff --git a/espnet/nets/pytorch_backend/fastspeech/duration_predictor.py b/espnet/nets/pytorch_backend/fastspeech/duration_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..f95334fee2ffa7d707dcdf3d5693584895a53f05 --- /dev/null +++ b/espnet/nets/pytorch_backend/fastspeech/duration_predictor.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Duration predictor related modules.""" + +import torch + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class DurationPredictor(torch.nn.Module): + """Duration predictor module. + + This is a module of duration predictor described + in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The duration predictor predicts a duration of each frame in log domain + from the hidden embeddings of encoder. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + Note: + The calculation domain of outputs is different + between in `forward` and in `inference`. In `forward`, + the outputs are calculated in log domain but in `inference`, + those are calculated in linear domain. + + """ + + def __init__( + self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0 + ): + """Initilize duration predictor module. + + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + offset (float, optional): Offset value to avoid nan in log domain. 
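`_calculate_duration` above turns a teacher attention map into per-input durations by counting, for each input index, how many output frames attend to it most strongly. A toy check with a hypothetical 5x3 attention matrix:

import torch

# toy (olen, ilen) teacher attention: each output frame attends mostly to one input
att_w = torch.tensor(
    [
        [0.9, 0.1, 0.0],
        [0.7, 0.2, 0.1],
        [0.1, 0.8, 0.1],
        [0.0, 0.2, 0.8],
        [0.1, 0.1, 0.8],
    ]
)
ilen, olen = 3, 5

# duration[i] = number of output frames whose argmax input index equals i
durations = torch.stack(
    [att_w[:olen, :ilen].argmax(-1).eq(i).sum() for i in range(ilen)]
)
print(durations.tolist())                # [2, 1, 2], sums to olen

The durations always sum to the output length, which is what the FastSpeech length regulator expects downstream.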
+ + """ + super(DurationPredictor, self).__init__() + self.offset = offset + self.conv = torch.nn.ModuleList() + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chans, + n_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ), + torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), + torch.nn.Dropout(dropout_rate), + ) + ] + self.linear = torch.nn.Linear(n_chans, 1) + + def _forward(self, xs, x_masks=None, is_inference=False): + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + + # NOTE: calculate in log domain + xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax) + + if is_inference: + # NOTE: calculate in linear domain + xs = torch.clamp( + torch.round(xs.exp() - self.offset), min=0 + ).long() # avoid negative value + + if x_masks is not None: + xs = xs.masked_fill(x_masks, 0.0) + + return xs + + def forward(self, xs, x_masks=None): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + + """ + return self._forward(xs, x_masks, False) + + def inference(self, xs, x_masks=None): + """Inference duration. + + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): + Batch of masks indicating padded part (B, Tmax). + + Returns: + LongTensor: Batch of predicted durations in linear domain (B, Tmax). + + """ + return self._forward(xs, x_masks, True) + + +class DurationPredictorLoss(torch.nn.Module): + """Loss function module for duration predictor. + + The loss value is Calculated in log domain to make it Gaussian. + + """ + + def __init__(self, offset=1.0, reduction="mean"): + """Initilize duration predictor loss module. + + Args: + offset (float, optional): Offset value to avoid nan in log domain. + reduction (str): Reduction type in loss calculation. + + """ + super(DurationPredictorLoss, self).__init__() + self.criterion = torch.nn.MSELoss(reduction=reduction) + self.offset = offset + + def forward(self, outputs, targets): + """Calculate forward propagation. + + Args: + outputs (Tensor): Batch of prediction durations in log domain (B, T) + targets (LongTensor): Batch of groundtruth durations in linear domain (B, T) + + Returns: + Tensor: Mean squared error loss value. + + Note: + `outputs` is in log domain but `targets` is in linear domain. 
+ + """ + # NOTE: outputs is in log domain while targets in linear + targets = torch.log(targets.float() + self.offset) + loss = self.criterion(outputs, targets) + + return loss diff --git a/espnet/nets/pytorch_backend/fastspeech/length_regulator.py b/espnet/nets/pytorch_backend/fastspeech/length_regulator.py new file mode 100644 index 0000000000000000000000000000000000000000..4f14560a84d385bb0bb97d3b557c2460e7203839 --- /dev/null +++ b/espnet/nets/pytorch_backend/fastspeech/length_regulator.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Length regulator related modules.""" + +import logging + +from distutils.version import LooseVersion + +import torch + +from espnet.nets.pytorch_backend.nets_utils import pad_list + +is_torch_1_1_plus = LooseVersion(torch.__version__) >= LooseVersion("1.1") + + +class LengthRegulator(torch.nn.Module): + """Length regulator module for feed-forward Transformer. + + This is a module of length regulator described in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The length regulator expands char or + phoneme-level embedding features to frame-level by repeating each + feature based on the corresponding predicted durations. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, pad_value=0.0): + """Initilize length regulator module. + + Args: + pad_value (float, optional): Value used for padding. + + """ + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + if is_torch_1_1_plus: + self.repeat_fn = self._repeat_one_sequence + else: + self.repeat_fn = self._legacy_repeat_one_sequence + + def forward(self, xs, ds, alpha=1.0): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D). + ds (LongTensor): Batch of durations of each frame (B, T). + alpha (float, optional): Alpha value to control speed of speech. + + Returns: + Tensor: replicated input tensor based on durations (B, T*, D). + + """ + if alpha != 1.0: + assert alpha > 0 + ds = torch.round(ds.float() * alpha).long() + + if ds.sum() == 0: + logging.warning( + "predicted durations includes all 0 sequences. " + "fill the first element with 1." + ) + # NOTE(kan-bayashi): This case must not be happend in teacher forcing. + # It will be happened in inference with a bad duration predictor. + # So we do not need to care the padded sequence case here. + ds[ds.sum(dim=1).eq(0)] = 1 + + return pad_list([self.repeat_fn(x, d) for x, d in zip(xs, ds)], self.pad_value) + + def _repeat_one_sequence(self, x, d): + """Repeat each frame according to duration for torch 1.1+.""" + return torch.repeat_interleave(x, d, dim=0) + + def _legacy_repeat_one_sequence(self, x, d): + """Repeat each frame according to duration for torch 1.0. 
+ + Examples: + >>> x = torch.tensor([[1], [2], [3]]) + tensor([[1], + [2], + [3]]) + >>> d = torch.tensor([1, 2, 3]) + tensor([1, 2, 3]) + >>> self._repeat_one_sequence(x, d) + tensor([[1], + [2], + [2], + [3], + [3], + [3]]) + + """ + return torch.cat( + [x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0], dim=0 + ) diff --git a/espnet/nets/pytorch_backend/frontends/__init__.py b/espnet/nets/pytorch_backend/frontends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/frontends/beamformer.py b/espnet/nets/pytorch_backend/frontends/beamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3eccee4cf98b164f8eb9802bde3741ac23dc9dc --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/beamformer.py @@ -0,0 +1,84 @@ +import torch +from torch_complex import functional as FC +from torch_complex.tensor import ComplexTensor + + +def get_power_spectral_density_matrix( + xs: ComplexTensor, mask: torch.Tensor, normalization=True, eps: float = 1e-15 +) -> ComplexTensor: + """Return cross-channel power spectral density (PSD) matrix + + Args: + xs (ComplexTensor): (..., F, C, T) + mask (torch.Tensor): (..., F, C, T) + normalization (bool): + eps (float): + Returns + psd (ComplexTensor): (..., F, C, C) + + """ + # outer product: (..., C_1, T) x (..., C_2, T) -> (..., T, C, C_2) + psd_Y = FC.einsum("...ct,...et->...tce", [xs, xs.conj()]) + + # Averaging mask along C: (..., C, T) -> (..., T) + mask = mask.mean(dim=-2) + + # Normalized mask along T: (..., T) + if normalization: + # If assuming the tensor is padded with zero, the summation along + # the time axis is same regardless of the padding length. + mask = mask / (mask.sum(dim=-1, keepdim=True) + eps) + + # psd: (..., T, C, C) + psd = psd_Y * mask[..., None, None] + # (..., T, C, C) -> (..., C, C) + psd = psd.sum(dim=-3) + + return psd + + +def get_mvdr_vector( + psd_s: ComplexTensor, + psd_n: ComplexTensor, + reference_vector: torch.Tensor, + eps: float = 1e-15, +) -> ComplexTensor: + """Return the MVDR(Minimum Variance Distortionless Response) vector: + + h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u + + Reference: + On optimal frequency-domain multichannel linear filtering + for noise reduction; M. 
Souden et al., 2010; + https://ieeexplore.ieee.org/document/5089420 + + Args: + psd_s (ComplexTensor): (..., F, C, C) + psd_n (ComplexTensor): (..., F, C, C) + reference_vector (torch.Tensor): (..., C) + eps (float): + Returns: + beamform_vector (ComplexTensor)r: (..., F, C) + """ + # Add eps + C = psd_n.size(-1) + eye = torch.eye(C, dtype=psd_n.dtype, device=psd_n.device) + shape = [1 for _ in range(psd_n.dim() - 2)] + [C, C] + eye = eye.view(*shape) + psd_n += eps * eye + + # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3) + numerator = FC.einsum("...ec,...cd->...ed", [psd_n.inverse(), psd_s]) + # ws: (..., C, C) / (...,) -> (..., C, C) + ws = numerator / (FC.trace(numerator)[..., None, None] + eps) + # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) + beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector]) + return beamform_vector + + +def apply_beamforming_vector( + beamform_vector: ComplexTensor, mix: ComplexTensor +) -> ComplexTensor: + # (..., C) x (..., C, T) -> (..., T) + es = FC.einsum("...c,...ct->...t", [beamform_vector.conj(), mix]) + return es diff --git a/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py b/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc8031b807ebc5270dd712c514fd20ed0b1f061c --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py @@ -0,0 +1,177 @@ +from distutils.version import LooseVersion +from typing import Tuple + +import torch +from torch.nn import functional as F + +from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector +from espnet.nets.pytorch_backend.frontends.beamformer import get_mvdr_vector +from espnet.nets.pytorch_backend.frontends.beamformer import ( + get_power_spectral_density_matrix, # noqa: H301 +) +from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator +from torch_complex.tensor import ComplexTensor + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0") +is_torch_1_3_plus = LooseVersion(torch.__version__) >= LooseVersion("1.3.0") + + +class DNN_Beamformer(torch.nn.Module): + """DNN mask based Beamformer + + Citation: + Multichannel End-to-end Speech Recognition; T. 
Ochiai et al., 2017; + https://arxiv.org/abs/1703.04783 + + """ + + def __init__( + self, + bidim, + btype="blstmp", + blayers=3, + bunits=300, + bprojs=320, + bnmask=2, + dropout_rate=0.0, + badim=320, + ref_channel: int = -1, + beamformer_type="mvdr", + ): + super().__init__() + self.mask = MaskEstimator( + btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask + ) + self.ref = AttentionReference(bidim, badim) + self.ref_channel = ref_channel + + self.nmask = bnmask + + if beamformer_type != "mvdr": + raise ValueError( + "Not supporting beamformer_type={}".format(beamformer_type) + ) + self.beamformer_type = beamformer_type + + def forward( + self, data: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: + """The forward function + + Notation: + B: Batch + C: Channel + T: Time or Sequence length + F: Freq + + Args: + data (ComplexTensor): (B, T, C, F) + ilens (torch.Tensor): (B,) + Returns: + enhanced (ComplexTensor): (B, T, F) + ilens (torch.Tensor): (B,) + + """ + + def apply_beamforming(data, ilens, psd_speech, psd_noise): + # u: (B, C) + if self.ref_channel < 0: + u, _ = self.ref(psd_speech, ilens) + else: + # (optional) Create onehot vector for fixed reference microphone + u = torch.zeros( + *(data.size()[:-3] + (data.size(-2),)), device=data.device + ) + u[..., self.ref_channel].fill_(1) + + ws = get_mvdr_vector(psd_speech, psd_noise, u) + enhanced = apply_beamforming_vector(ws, data) + + return enhanced, ws + + # data (B, T, C, F) -> (B, F, C, T) + data = data.permute(0, 3, 2, 1) + + # mask: (B, F, C, T) + masks, _ = self.mask(data, ilens) + assert self.nmask == len(masks) + + if self.nmask == 2: # (mask_speech, mask_noise) + mask_speech, mask_noise = masks + + psd_speech = get_power_spectral_density_matrix(data, mask_speech) + psd_noise = get_power_spectral_density_matrix(data, mask_noise) + + enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise) + + # (..., F, T) -> (..., T, F) + enhanced = enhanced.transpose(-1, -2) + mask_speech = mask_speech.transpose(-1, -3) + else: # multi-speaker case: (mask_speech1, ..., mask_noise) + mask_speech = list(masks[:-1]) + mask_noise = masks[-1] + + psd_speeches = [ + get_power_spectral_density_matrix(data, mask) for mask in mask_speech + ] + psd_noise = get_power_spectral_density_matrix(data, mask_noise) + + enhanced = [] + ws = [] + for i in range(self.nmask - 1): + psd_speech = psd_speeches.pop(i) + # treat all other speakers' psd_speech as noises + enh, w = apply_beamforming( + data, ilens, psd_speech, sum(psd_speeches) + psd_noise + ) + psd_speeches.insert(i, psd_speech) + + # (..., F, T) -> (..., T, F) + enh = enh.transpose(-1, -2) + mask_speech[i] = mask_speech[i].transpose(-1, -3) + + enhanced.append(enh) + ws.append(w) + + return enhanced, ilens, mask_speech + + +class AttentionReference(torch.nn.Module): + def __init__(self, bidim, att_dim): + super().__init__() + self.mlp_psd = torch.nn.Linear(bidim, att_dim) + self.gvec = torch.nn.Linear(att_dim, 1) + + def forward( + self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0 + ) -> Tuple[torch.Tensor, torch.LongTensor]: + """The forward function + + Args: + psd_in (ComplexTensor): (B, F, C, C) + ilens (torch.Tensor): (B,) + scaling (float): + Returns: + u (torch.Tensor): (B, C) + ilens (torch.Tensor): (B,) + """ + B, _, C = psd_in.size()[:3] + assert psd_in.size(2) == psd_in.size(3), psd_in.size() + # psd_in: (B, F, C, C) + datatype = torch.bool if is_torch_1_3_plus else torch.uint8 + 
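`get_mvdr_vector` above implements the filter `h = (Npsd^-1 @ Spsd) / Tr(Npsd^-1 @ Spsd) @ u`. A compact numpy sketch of the same formula for a single frequency bin, with random Hermitian matrices standing in for the mask-weighted PSD estimates (purely illustrative, no torch_complex):

import numpy as np

rng = np.random.default_rng(0)
C = 4                                                # toy number of channels

def random_psd(c):
    # random Hermitian, positive semi-definite matrix as a stand-in PSD estimate
    a = rng.standard_normal((c, c)) + 1j * rng.standard_normal((c, c))
    return a @ a.conj().T

psd_speech, psd_noise = random_psd(C), random_psd(C)
u = np.eye(C)[0]                                     # one-hot reference-microphone vector

numerator = np.linalg.inv(psd_noise) @ psd_speech    # Npsd^-1 @ Spsd
ws = numerator / np.trace(numerator)                 # normalize by the trace
h = ws @ u                                           # MVDR filter for this bin, shape (C,)

x = rng.standard_normal(C) + 1j * rng.standard_normal(C)   # one multichannel STFT frame
enhanced = h.conj() @ x                              # beamformed output sample
print(h.shape, enhanced)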
datatype2 = torch.bool if is_torch_1_2_plus else torch.uint8 + psd = psd_in.masked_fill( + torch.eye(C, dtype=datatype, device=psd_in.device).type(datatype2), 0 + ) + # psd: (B, F, C, C) -> (B, C, F) + psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2) + + # Calculate amplitude + psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5 + + # (B, C, F) -> (B, C, F2) + mlp_psd = self.mlp_psd(psd_feat) + # (B, C, F2) -> (B, C, 1) -> (B, C) + e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1) + u = F.softmax(scaling * e, dim=-1) + return u, ilens diff --git a/espnet/nets/pytorch_backend/frontends/dnn_wpe.py b/espnet/nets/pytorch_backend/frontends/dnn_wpe.py new file mode 100644 index 0000000000000000000000000000000000000000..33ccd11c71a636ab5cb5542c749d0a89ad1423a3 --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/dnn_wpe.py @@ -0,0 +1,93 @@ +from typing import Tuple + +from pytorch_wpe import wpe_one_iteration +import torch +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.frontends.mask_estimator import MaskEstimator +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask + + +class DNN_WPE(torch.nn.Module): + def __init__( + self, + wtype: str = "blstmp", + widim: int = 257, + wlayers: int = 3, + wunits: int = 300, + wprojs: int = 320, + dropout_rate: float = 0.0, + taps: int = 5, + delay: int = 3, + use_dnn_mask: bool = True, + iterations: int = 1, + normalization: bool = False, + ): + super().__init__() + self.iterations = iterations + self.taps = taps + self.delay = delay + + self.normalization = normalization + self.use_dnn_mask = use_dnn_mask + + self.inverse_power = True + + if self.use_dnn_mask: + self.mask_est = MaskEstimator( + wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1 + ) + + def forward( + self, data: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: + """The forward function + + Notation: + B: Batch + C: Channel + T: Time or Sequence length + F: Freq or Some dimension of the feature vector + + Args: + data: (B, C, T, F) + ilens: (B,) + Returns: + data: (B, C, T, F) + ilens: (B,) + """ + # (B, T, C, F) -> (B, F, C, T) + enhanced = data = data.permute(0, 3, 2, 1) + mask = None + + for i in range(self.iterations): + # Calculate power: (..., C, T) + power = enhanced.real ** 2 + enhanced.imag ** 2 + if i == 0 and self.use_dnn_mask: + # mask: (B, F, C, T) + (mask,), _ = self.mask_est(enhanced, ilens) + if self.normalization: + # Normalize along T + mask = mask / mask.sum(dim=-1)[..., None] + # (..., C, T) * (..., C, T) -> (..., C, T) + power = power * mask + + # Averaging along the channel axis: (..., C, T) -> (..., T) + power = power.mean(dim=-2) + + # enhanced: (..., C, T) -> (..., C, T) + enhanced = wpe_one_iteration( + data.contiguous(), + power, + taps=self.taps, + delay=self.delay, + inverse_power=self.inverse_power, + ) + + enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0) + + # (B, F, C, T) -> (B, T, C, F) + enhanced = enhanced.permute(0, 3, 2, 1) + if mask is not None: + mask = mask.transpose(-1, -3) + return enhanced, ilens, mask diff --git a/espnet/nets/pytorch_backend/frontends/feature_transform.py b/espnet/nets/pytorch_backend/frontends/feature_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..700f63fdd0831e4cc20a12ecde7f4c9bd360ca4c --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/feature_transform.py @@ -0,0 +1,263 @@ +from typing import List +from typing import Tuple +from typing import Union + +import 
librosa +import numpy as np +import torch +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask + + +class FeatureTransform(torch.nn.Module): + def __init__( + self, + # Mel options, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = 0.0, + fmax: float = None, + # Normalization + stats_file: str = None, + apply_uttmvn: bool = True, + uttmvn_norm_means: bool = True, + uttmvn_norm_vars: bool = False, + ): + super().__init__() + self.apply_uttmvn = apply_uttmvn + + self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax) + self.stats_file = stats_file + if stats_file is not None: + self.global_mvn = GlobalMVN(stats_file) + else: + self.global_mvn = None + + if self.apply_uttmvn is not None: + self.uttmvn = UtteranceMVN( + norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars + ) + else: + self.uttmvn = None + + def forward( + self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]] + ) -> Tuple[torch.Tensor, torch.LongTensor]: + # (B, T, F) or (B, T, C, F) + if x.dim() not in (3, 4): + raise ValueError(f"Input dim must be 3 or 4: {x.dim()}") + if not torch.is_tensor(ilens): + ilens = torch.from_numpy(np.asarray(ilens)).to(x.device) + + if x.dim() == 4: + # h: (B, T, C, F) -> h: (B, T, F) + if self.training: + # Select 1ch randomly + ch = np.random.randint(x.size(2)) + h = x[:, :, ch, :] + else: + # Use the first channel + h = x[:, :, 0, :] + else: + h = x + + # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) + h = h.real ** 2 + h.imag ** 2 + + h, _ = self.logmel(h, ilens) + if self.stats_file is not None: + h, _ = self.global_mvn(h, ilens) + if self.apply_uttmvn: + h, _ = self.uttmvn(h, ilens) + + return h, ilens + + +class LogMel(torch.nn.Module): + """Convert STFT to fbank feats + + The arguments is same as librosa.filters.mel + + Args: + fs: number > 0 [scalar] sampling rate of the incoming signal + n_fft: int > 0 [scalar] number of FFT components + n_mels: int > 0 [scalar] number of Mel bands to generate + fmin: float >= 0 [scalar] lowest frequency (in Hz) + fmax: float >= 0 [scalar] highest frequency (in Hz). + If `None`, use `fmax = fs / 2.0` + htk: use HTK formula instead of Slaney + norm: {None, 1, np.inf} [scalar] + if 1, divide the triangular mel weights by the width of the mel band + (area normalization). Otherwise, leave all the triangles aiming for + a peak value of 1.0 + + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = 0.0, + fmax: float = None, + htk: bool = False, + norm=1, + ): + super().__init__() + + _mel_options = dict( + sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm + ) + self.mel_options = _mel_options + + # Note(kamo): The mel matrix of librosa is different from kaldi. 
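+        # librosa.filters.mel returns a (n_mels, 1 + n_fft // 2) matrix; it is transposed and
+        # registered as a buffer so that forward() computes power_spectrum @ melmat followed
+        # by a log with a small (1e-20) additive floor. For example, with n_fft=512 and
+        # n_mels=80, a (B, T, 257) power spectrum is mapped to (B, T, 80) log-mel features.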
+ melmat = librosa.filters.mel(**_mel_options) + # melmat: (D2, D1) -> (D1, D2) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + + def extra_repr(self): + return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) + + def forward( + self, feat: torch.Tensor, ilens: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.LongTensor]: + # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) + mel_feat = torch.matmul(feat, self.melmat) + + logmel_feat = (mel_feat + 1e-20).log() + # Zero padding + logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0) + return logmel_feat, ilens + + +class GlobalMVN(torch.nn.Module): + """Apply global mean and variance normalization + + Args: + stats_file(str): npy file of 1-dim array or text file. + From the _first element to + the {(len(array) - 1) / 2}th element are treated as + the sum of features, + and the rest excluding the last elements are + treated as the sum of the square value of features, + and the last elements eqauls to the number of samples. + std_floor(float): + """ + + def __init__( + self, + stats_file: str, + norm_means: bool = True, + norm_vars: bool = True, + eps: float = 1.0e-20, + ): + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + + self.stats_file = stats_file + stats = np.load(stats_file) + + stats = stats.astype(float) + assert (len(stats) - 1) % 2 == 0, stats.shape + + count = stats.flatten()[-1] + mean = stats[: (len(stats) - 1) // 2] / count + var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean + std = np.maximum(np.sqrt(var), eps) + + self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32))) + self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32))) + + def extra_repr(self): + return ( + f"stats_file={self.stats_file}, " + f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + ) + + def forward( + self, x: torch.Tensor, ilens: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.LongTensor]: + # feat: (B, T, D) + if self.norm_means: + x += self.bias.type_as(x) + x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) + + if self.norm_vars: + x *= self.scale.type_as(x) + return x, ilens + + +class UtteranceMVN(torch.nn.Module): + def __init__( + self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20 + ): + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + self.eps = eps + + def extra_repr(self): + return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + + def forward( + self, x: torch.Tensor, ilens: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.LongTensor]: + return utterance_mvn( + x, ilens, norm_means=self.norm_means, norm_vars=self.norm_vars, eps=self.eps + ) + + +def utterance_mvn( + x: torch.Tensor, + ilens: torch.LongTensor, + norm_means: bool = True, + norm_vars: bool = False, + eps: float = 1.0e-20, +) -> Tuple[torch.Tensor, torch.LongTensor]: + """Apply utterance mean and variance normalization + + Args: + x: (B, T, D), assumed zero padded + ilens: (B, T, D) + norm_means: + norm_vars: + eps: + + """ + ilens_ = ilens.type_as(x) + # mean: (B, D) + mean = x.sum(dim=1) / ilens_[:, None] + + if norm_means: + x -= mean[:, None, :] + x_ = x + else: + x_ = x - mean[:, None, :] + + # Zero padding + x_.masked_fill(make_pad_mask(ilens, x_, 1), 0.0) + if norm_vars: + var = x_.pow(2).sum(dim=1) / ilens_[:, None] + var = torch.clamp(var, min=eps) + x /= var.sqrt()[:, None, :] + x_ = x + return x_, ilens + + +def feature_transform_for(args, 
n_fft): + return FeatureTransform( + # Mel options, + fs=args.fbank_fs, + n_fft=n_fft, + n_mels=args.n_mels, + fmin=args.fbank_fmin, + fmax=args.fbank_fmax, + # Normalization + stats_file=args.stats_file, + apply_uttmvn=args.apply_uttmvn, + uttmvn_norm_means=args.uttmvn_norm_means, + uttmvn_norm_vars=args.uttmvn_norm_vars, + ) diff --git a/espnet/nets/pytorch_backend/frontends/frontend.py b/espnet/nets/pytorch_backend/frontends/frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..7231f68b35f4300c55d78ec5a7fd33ffa46ed7ab --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/frontend.py @@ -0,0 +1,151 @@ +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy +import torch +import torch.nn as nn +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.frontends.dnn_beamformer import DNN_Beamformer +from espnet.nets.pytorch_backend.frontends.dnn_wpe import DNN_WPE + + +class Frontend(nn.Module): + def __init__( + self, + idim: int, + # WPE options + use_wpe: bool = False, + wtype: str = "blstmp", + wlayers: int = 3, + wunits: int = 300, + wprojs: int = 320, + wdropout_rate: float = 0.0, + taps: int = 5, + delay: int = 3, + use_dnn_mask_for_wpe: bool = True, + # Beamformer options + use_beamformer: bool = False, + btype: str = "blstmp", + blayers: int = 3, + bunits: int = 300, + bprojs: int = 320, + bnmask: int = 2, + badim: int = 320, + ref_channel: int = -1, + bdropout_rate=0.0, + ): + super().__init__() + + self.use_beamformer = use_beamformer + self.use_wpe = use_wpe + self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe + # use frontend for all the data, + # e.g. in the case of multi-speaker speech separation + self.use_frontend_for_all = bnmask > 2 + + if self.use_wpe: + if self.use_dnn_mask_for_wpe: + # Use DNN for power estimation + # (Not observed significant gains) + iterations = 1 + else: + # Performing as conventional WPE, without DNN Estimator + iterations = 2 + + self.wpe = DNN_WPE( + wtype=wtype, + widim=idim, + wunits=wunits, + wprojs=wprojs, + wlayers=wlayers, + taps=taps, + delay=delay, + dropout_rate=wdropout_rate, + iterations=iterations, + use_dnn_mask=use_dnn_mask_for_wpe, + ) + else: + self.wpe = None + + if self.use_beamformer: + self.beamformer = DNN_Beamformer( + btype=btype, + bidim=idim, + bunits=bunits, + bprojs=bprojs, + blayers=blayers, + bnmask=bnmask, + dropout_rate=bdropout_rate, + badim=badim, + ref_channel=ref_channel, + ) + else: + self.beamformer = None + + def forward( + self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]] + ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]: + assert len(x) == len(ilens), (len(x), len(ilens)) + # (B, T, F) or (B, T, C, F) + if x.dim() not in (3, 4): + raise ValueError(f"Input dim must be 3 or 4: {x.dim()}") + if not torch.is_tensor(ilens): + ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device) + + mask = None + h = x + if h.dim() == 4: + if self.training: + choices = [(False, False)] if not self.use_frontend_for_all else [] + if self.use_wpe: + choices.append((True, False)) + + if self.use_beamformer: + choices.append((False, True)) + + use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))] + + else: + use_wpe = self.use_wpe + use_beamformer = self.use_beamformer + + # 1. WPE + if use_wpe: + # h: (B, T, C, F) -> h: (B, T, C, F) + h, ilens, mask = self.wpe(h, ilens) + + # 2. 
Beamformer + if use_beamformer: + # h: (B, T, C, F) -> h: (B, T, F) + h, ilens, mask = self.beamformer(h, ilens) + + return h, ilens, mask + + +def frontend_for(args, idim): + return Frontend( + idim=idim, + # WPE options + use_wpe=args.use_wpe, + wtype=args.wtype, + wlayers=args.wlayers, + wunits=args.wunits, + wprojs=args.wprojs, + wdropout_rate=args.wdropout_rate, + taps=args.wpe_taps, + delay=args.wpe_delay, + use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe, + # Beamformer options + use_beamformer=args.use_beamformer, + btype=args.btype, + blayers=args.blayers, + bunits=args.bunits, + bprojs=args.bprojs, + bnmask=args.bnmask, + badim=args.badim, + ref_channel=args.ref_channel, + bdropout_rate=args.bdropout_rate, + ) diff --git a/espnet/nets/pytorch_backend/frontends/mask_estimator.py b/espnet/nets/pytorch_backend/frontends/mask_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..48aaf0a1df1ea25a149fc88d5f54e7e2d85f6fa1 --- /dev/null +++ b/espnet/nets/pytorch_backend/frontends/mask_estimator.py @@ -0,0 +1,77 @@ +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN +from espnet.nets.pytorch_backend.rnn.encoders import RNNP + + +class MaskEstimator(torch.nn.Module): + def __init__(self, type, idim, layers, units, projs, dropout, nmask=1): + super().__init__() + subsample = np.ones(layers + 1, dtype=np.int) + + typ = type.lstrip("vgg").rstrip("p") + if type[-1] == "p": + self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ) + else: + self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ) + + self.type = type + self.nmask = nmask + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(projs, idim) for _ in range(nmask)] + ) + + def forward( + self, xs: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]: + """The forward function + + Args: + xs: (B, F, C, T) + ilens: (B,) + Returns: + hs (torch.Tensor): The hidden vector (B, F, C, T) + masks: A tuple of the masks. 
(B, F, C, T) + ilens: (B,) + """ + assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0)) + _, _, C, input_length = xs.size() + # (B, F, C, T) -> (B, C, T, F) + xs = xs.permute(0, 2, 3, 1) + + # Calculate amplitude: (B, C, T, F) -> (B, C, T, F) + xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5 + # xs: (B, C, T, F) -> xs: (B * C, T, F) + xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1)) + # ilens: (B,) -> ilens_: (B * C) + ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1) + + # xs: (B * C, T, F) -> xs: (B * C, T, D) + xs, _, _ = self.brnn(xs, ilens_) + # xs: (B * C, T, D) -> xs: (B, C, T, D) + xs = xs.view(-1, C, xs.size(-2), xs.size(-1)) + + masks = [] + for linear in self.linears: + # xs: (B, C, T, D) -> mask:(B, C, T, F) + mask = linear(xs) + + mask = torch.sigmoid(mask) + # Zero padding + mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0) + + # (B, C, T, F) -> (B, F, C, T) + mask = mask.permute(0, 3, 1, 2) + + # Take cares of multi gpu cases: If input_length > max(ilens) + if mask.size(-1) < input_length: + mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0) + masks.append(mask) + + return tuple(masks), ilens diff --git a/espnet/nets/pytorch_backend/gtn_ctc.py b/espnet/nets/pytorch_backend/gtn_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..c7c71545c74dcbdd2272d2d15e4d7af3f92af72a --- /dev/null +++ b/espnet/nets/pytorch_backend/gtn_ctc.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""GTN CTC implementation.""" + +import gtn +import torch + + +class GTNCTCLossFunction(torch.autograd.Function): + """GTN CTC module.""" + + # Copied from FB's GTN example implementation: + # https://github.com/facebookresearch/gtn_applications/blob/master/utils.py#L251 + + @staticmethod + def create_ctc_graph(target, blank_idx): + """Build gtn graph. + + :param list target: single target sequence + :param int blank_idx: index of blank token + :return: gtn graph of target sequence + :rtype: gtn.Graph + """ + g_criterion = gtn.Graph(False) + L = len(target) + S = 2 * L + 1 + for s in range(S): + idx = (s - 1) // 2 + g_criterion.add_node(s == 0, s == S - 1 or s == S - 2) + label = target[idx] if s % 2 else blank_idx + g_criterion.add_arc(s, s, label) + if s > 0: + g_criterion.add_arc(s - 1, s, label) + if s % 2 and s > 1 and label != target[idx - 1]: + g_criterion.add_arc(s - 2, s, label) + g_criterion.arc_sort(False) + return g_criterion + + @staticmethod + def forward(ctx, log_probs, targets, blank_idx=0, reduction="none"): + """Forward computation. 
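+
+        Each target is expanded by create_ctc_graph above into the standard CTC topology
+        (blank, y1, blank, ..., yL, blank) with self-loops, forward arcs, and skip arcs
+        between distinct labels; the loss is the negated forward score of the intersection
+        of that criterion graph with the emission graph built from log_probs.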
+ + :param torch.tensor log_probs: batched log softmax probabilities (B, Tmax, oDim) + :param list targets: batched target sequences, list of lists + :param int blank_idx: index of blank token + :return: ctc loss value + :rtype: torch.Tensor + """ + B, T, C = log_probs.shape + losses = [None] * B + scales = [None] * B + emissions_graphs = [None] * B + + def process(b): + # create emission graph + g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad) + cpu_data = log_probs[b].cpu().contiguous() + g_emissions.set_weights(cpu_data.data_ptr()) + + # create criterion graph + g_criterion = GTNCTCLossFunction.create_ctc_graph(targets[b], blank_idx) + # compose the graphs + g_loss = gtn.negate( + gtn.forward_score(gtn.intersect(g_emissions, g_criterion)) + ) + + scale = 1.0 + if reduction == "mean": + L = len(targets[b]) + scale = 1.0 / L if L > 0 else scale + elif reduction != "none": + raise ValueError("invalid value for reduction '" + str(reduction) + "'") + + # Save for backward: + losses[b] = g_loss + scales[b] = scale + emissions_graphs[b] = g_emissions + + gtn.parallel_for(process, range(B)) + + ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape) + loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)]) + return torch.mean(loss.cuda() if log_probs.is_cuda else loss) + + @staticmethod + def backward(ctx, grad_output): + """Backward computation. + + :param torch.tensor grad_output: backward passed gradient value + :return: cumulative gradient output + :rtype: (torch.Tensor, None, None, None) + """ + losses, scales, emissions_graphs, in_shape = ctx.auxiliary_data + B, T, C = in_shape + input_grad = torch.empty((B, T, C)) + + def process(b): + gtn.backward(losses[b], False) + emissions = emissions_graphs[b] + grad = emissions.grad().weights_to_numpy() + input_grad[b] = torch.from_numpy(grad).view(1, T, C) * scales[b] + + gtn.parallel_for(process, range(B)) + + if grad_output.is_cuda: + input_grad = input_grad.cuda() + input_grad *= grad_output / B + + return ( + input_grad, + None, # targets + None, # blank_idx + None, # reduction + ) diff --git a/espnet/nets/pytorch_backend/initialization.py b/espnet/nets/pytorch_backend/initialization.py new file mode 100644 index 0000000000000000000000000000000000000000..6ecdc8dab86d0eca2b9fb98a8a2434cc0d1b6b88 --- /dev/null +++ b/espnet/nets/pytorch_backend/initialization.py @@ -0,0 +1,55 @@ +#!/usr/bin/env python + +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Initialization functions for RNN sequence-to-sequence models.""" + +import math + + +def lecun_normal_init_parameters(module): + """Initialize parameters in the LeCun's manner.""" + for p in module.parameters(): + data = p.data + if data.dim() == 1: + # bias + data.zero_() + elif data.dim() == 2: + # linear weight + n = data.size(1) + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() in (3, 4): + # conv weight + n = data.size(1) + for k in data.size()[2:]: + n *= k + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + else: + raise NotImplementedError + + +def uniform_init_parameters(module): + """Initialize parameters with an uniform distribution.""" + for p in module.parameters(): + data = p.data + if data.dim() == 1: + # bias + data.uniform_(-0.1, 0.1) + elif data.dim() == 2: + # linear weight + data.uniform_(-0.1, 0.1) + elif data.dim() in (3, 4): + # conv weight + pass # use the pytorch default + else: + raise NotImplementedError + + +def 
set_forget_bias_to_one(bias): + """Initialize a bias vector in the forget gate with one.""" + n = bias.size(0) + start, end = n // 4, n // 2 + bias.data[start:end].fill_(1.0) diff --git a/espnet/nets/pytorch_backend/lm/__init__.py b/espnet/nets/pytorch_backend/lm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/lm/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/lm/default.py b/espnet/nets/pytorch_backend/lm/default.py new file mode 100644 index 0000000000000000000000000000000000000000..01bb26ea4a071e1672952ee0cfb754d16ad6d8e6 --- /dev/null +++ b/espnet/nets/pytorch_backend/lm/default.py @@ -0,0 +1,431 @@ +"""Default Recurrent Neural Network Languge Model in `lm_train.py`.""" + +from typing import Any +from typing import List +from typing import Tuple + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +from espnet.nets.lm_interface import LMInterface +from espnet.nets.pytorch_backend.e2e_asr import to_device +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.utils.cli_utils import strtobool + + +class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module): + """Default RNNLM for `LMInterface` Implementation. + + Note: + PyTorch seems to have memory leak when one GPU compute this after data parallel. + If parallel GPUs compute this, it seems to be fine. + See also https://github.com/espnet/espnet/issues/1075 + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments to command line argument parser.""" + parser.add_argument( + "--type", + type=str, + default="lstm", + nargs="?", + choices=["lstm", "gru"], + help="Which type of RNN to use", + ) + parser.add_argument( + "--layer", "-l", type=int, default=2, help="Number of hidden layers" + ) + parser.add_argument( + "--unit", "-u", type=int, default=650, help="Number of hidden units" + ) + parser.add_argument( + "--embed-unit", + default=None, + type=int, + help="Number of hidden units in embedding layer, " + "if it is not specified, it keeps the same number with hidden units.", + ) + parser.add_argument( + "--dropout-rate", type=float, default=0.5, help="dropout probability" + ) + parser.add_argument( + "--emb-dropout-rate", + type=float, + default=0.0, + help="emb dropout probability", + ) + parser.add_argument( + "--tie-weights", + type=strtobool, + default=False, + help="Tie input and output embeddings", + ) + return parser + + def __init__(self, n_vocab, args): + """Initialize class. + + Args: + n_vocab (int): The size of the vocabulary + args (argparse.Namespace): configurations. 
see py:method:`add_arguments` + + """ + nn.Module.__init__(self) + # NOTE: for a compatibility with less than 0.5.0 version models + dropout_rate = getattr(args, "dropout_rate", 0.0) + # NOTE: for a compatibility with less than 0.6.1 version models + embed_unit = getattr(args, "embed_unit", None) + # NOTE: for a compatibility with less than 0.9.7 version models + emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0) + # NOTE: for a compatibility with less than 0.9.7 version models + tie_weights = getattr(args, "tie_weights", False) + + self.model = ClassifierWithState( + RNNLM( + n_vocab, + args.layer, + args.unit, + embed_unit, + args.type, + dropout_rate, + emb_dropout_rate, + tie_weights, + ) + ) + + def state_dict(self): + """Dump state dict.""" + return self.model.state_dict() + + def load_state_dict(self, d): + """Load state dict.""" + self.model.load_state_dict(d) + + def forward(self, x, t): + """Compute LM loss value from buffer sequences. + + Args: + x (torch.Tensor): Input ids. (batch, len) + t (torch.Tensor): Target ids. (batch, len) + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + loss = 0 + logp = 0 + count = torch.tensor(0).long() + state = None + batch_size, sequence_length = x.shape + for i in range(sequence_length): + # Compute the loss at this time step and accumulate it + state, loss_batch = self.model(state, x[:, i], t[:, i]) + non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype) + loss += loss_batch.mean() * non_zeros + logp += torch.sum(loss_batch * non_zeros) + count += int(non_zeros) + return loss / batch_size, loss, count.to(loss.device) + + def score(self, y, state, x): + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): 2D encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (n_vocab) + and next state for ys + + """ + new_state, scores = self.model.predict(state, y[-1].unsqueeze(0)) + return scores.squeeze(0), new_state + + def final_score(self, state): + """Score eos. + + Args: + state: Scorer state for prefix tokens + + Returns: + float: final score + + """ + return self.model.final(state) + + # batch beam search API (see BatchScorerInterface) + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. 
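+
+        Note:
+            Per-hypothesis states of the form {key: [one tensor per layer]} are stacked
+            into batched tensors before calling self.model.predict, and split back into
+            a per-hypothesis list afterwards.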
+ + """ + # merge states + n_batch = len(ys) + n_layers = self.model.predictor.n_layers + if self.model.predictor.typ == "lstm": + keys = ("c", "h") + else: + keys = ("h",) + + if states[0] is None: + states = None + else: + # transpose state of [batch, key, layer] into [key, layer, batch] + states = { + k: [ + torch.stack([states[b][k][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + for k in keys + } + states, logp = self.model.predict(states, ys[:, -1]) + + # transpose state of [key, layer, batch] into [batch, key, layer] + return ( + logp, + [ + {k: [states[k][i][b] for i in range(n_layers)] for k in keys} + for b in range(n_batch) + ], + ) + + +class ClassifierWithState(nn.Module): + """A wrapper for pytorch RNNLM.""" + + def __init__( + self, predictor, lossfun=nn.CrossEntropyLoss(reduction="none"), label_key=-1 + ): + """Initialize class. + + :param torch.nn.Module predictor : The RNNLM + :param function lossfun : The loss function to use + :param int/str label_key : + + """ + if not (isinstance(label_key, (int, str))): + raise TypeError("label_key must be int or str, but is %s" % type(label_key)) + super(ClassifierWithState, self).__init__() + self.lossfun = lossfun + self.y = None + self.loss = None + self.label_key = label_key + self.predictor = predictor + + def forward(self, state, *args, **kwargs): + """Compute the loss value for an input and label pair. + + Notes: + It also computes accuracy and stores it to the attribute. + When ``label_key`` is ``int``, the corresponding element in ``args`` + is treated as ground truth labels. And when it is ``str``, the + element in ``kwargs`` is used. + The all elements of ``args`` and ``kwargs`` except the groundtruth + labels are features. + It feeds features to the predictor and compare the result + with ground truth labels. + + :param torch.Tensor state : the LM state + :param list[torch.Tensor] args : Input minibatch + :param dict[torch.Tensor] kwargs : Input minibatch + :return loss value + :rtype torch.Tensor + + """ + if isinstance(self.label_key, int): + if not (-len(args) <= self.label_key < len(args)): + msg = "Label key %d is out of bounds" % self.label_key + raise ValueError(msg) + t = args[self.label_key] + if self.label_key == -1: + args = args[:-1] + else: + args = args[: self.label_key] + args[self.label_key + 1 :] + elif isinstance(self.label_key, str): + if self.label_key not in kwargs: + msg = 'Label key "%s" is not found' % self.label_key + raise ValueError(msg) + t = kwargs[self.label_key] + del kwargs[self.label_key] + + self.y = None + self.loss = None + state, self.y = self.predictor(state, *args, **kwargs) + self.loss = self.lossfun(self.y, t) + return state, self.loss + + def predict(self, state, x): + """Predict log probabilities for given state and input x using the predictor. 
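+
+        If the predictor has a truthy "normalized" attribute, its output is returned as-is;
+        otherwise log_softmax is applied to the predicted logits.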
+ + :param torch.Tensor state : The current state + :param torch.Tensor x : The input + :return a tuple (new state, log prob vector) + :rtype (torch.Tensor, torch.Tensor) + """ + if hasattr(self.predictor, "normalized") and self.predictor.normalized: + return self.predictor(state, x) + else: + state, z = self.predictor(state, x) + return state, F.log_softmax(z, dim=1) + + def buff_predict(self, state, x, n): + """Predict new tokens from buffered inputs.""" + if self.predictor.__class__.__name__ == "RNNLM": + return self.predict(state, x) + + new_state = [] + new_log_y = [] + for i in range(n): + state_i = None if state is None else state[i] + state_i, log_y = self.predict(state_i, x[i].unsqueeze(0)) + new_state.append(state_i) + new_log_y.append(log_y) + + return new_state, torch.cat(new_log_y) + + def final(self, state, index=None): + """Predict final log probabilities for given state using the predictor. + + :param state: The state + :return The final log probabilities + :rtype torch.Tensor + """ + if hasattr(self.predictor, "final"): + if index is not None: + return self.predictor.final(state[index]) + else: + return self.predictor.final(state) + else: + return 0.0 + + +# Definition of a recurrent net for language modeling +class RNNLM(nn.Module): + """A pytorch RNNLM.""" + + def __init__( + self, + n_vocab, + n_layers, + n_units, + n_embed=None, + typ="lstm", + dropout_rate=0.5, + emb_dropout_rate=0.0, + tie_weights=False, + ): + """Initialize class. + + :param int n_vocab: The size of the vocabulary + :param int n_layers: The number of layers to create + :param int n_units: The number of units per layer + :param str typ: The RNN type + """ + super(RNNLM, self).__init__() + if n_embed is None: + n_embed = n_units + + self.embed = nn.Embedding(n_vocab, n_embed) + + if emb_dropout_rate == 0.0: + self.embed_drop = None + else: + self.embed_drop = nn.Dropout(emb_dropout_rate) + + if typ == "lstm": + self.rnn = nn.ModuleList( + [nn.LSTMCell(n_embed, n_units)] + + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)] + ) + else: + self.rnn = nn.ModuleList( + [nn.GRUCell(n_embed, n_units)] + + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)] + ) + + self.dropout = nn.ModuleList( + [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)] + ) + self.lo = nn.Linear(n_units, n_vocab) + self.n_layers = n_layers + self.n_units = n_units + self.typ = typ + + logging.info("Tie weights set to {}".format(tie_weights)) + logging.info("Dropout set to {}".format(dropout_rate)) + logging.info("Emb Dropout set to {}".format(emb_dropout_rate)) + + if tie_weights: + assert ( + n_embed == n_units + ), "Tie Weights: True need embedding and final dimensions to match" + self.lo.weight = self.embed.weight + + # initialize parameters from uniform distribution + for param in self.parameters(): + param.data.uniform_(-0.1, 0.1) + + def zero_state(self, batchsize): + """Initialize state.""" + p = next(self.parameters()) + return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype) + + def forward(self, state, x): + """Forward neural networks.""" + if state is None: + h = [to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers)] + state = {"h": h} + if self.typ == "lstm": + c = [ + to_device(x, self.zero_state(x.size(0))) + for n in range(self.n_layers) + ] + state = {"c": c, "h": h} + + h = [None] * self.n_layers + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(x)) + else: + emb = self.embed(x) + if self.typ == "lstm": + c = [None] * self.n_layers + 
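+            # The stacked LSTMCells run one layer at a time: layer 0 consumes the embedding
+            # (after dropout), each higher layer consumes the hidden state of the layer below
+            # (after dropout), and the per-layer (h, c) pairs become the new recurrent state.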
h[0], c[0] = self.rnn[0]( + self.dropout[0](emb), (state["h"][0], state["c"][0]) + ) + for n in range(1, self.n_layers): + h[n], c[n] = self.rnn[n]( + self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n]) + ) + state = {"c": c, "h": h} + else: + h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0]) + for n in range(1, self.n_layers): + h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n]) + state = {"h": h} + y = self.lo(self.dropout[-1](h[-1])) + return state, y diff --git a/espnet/nets/pytorch_backend/lm/seq_rnn.py b/espnet/nets/pytorch_backend/lm/seq_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5f026e3811c790f283dc9298e1221d783c0e4f --- /dev/null +++ b/espnet/nets/pytorch_backend/lm/seq_rnn.py @@ -0,0 +1,178 @@ +"""Sequential implementation of Recurrent Neural Network Language Model.""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from espnet.nets.lm_interface import LMInterface + + +class SequentialRNNLM(LMInterface, torch.nn.Module): + """Sequential RNNLM. + + See also: + https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py + + """ + + @staticmethod + def add_arguments(parser): + """Add arguments to command line argument parser.""" + parser.add_argument( + "--type", + type=str, + default="lstm", + nargs="?", + choices=["lstm", "gru"], + help="Which type of RNN to use", + ) + parser.add_argument( + "--layer", "-l", type=int, default=2, help="Number of hidden layers" + ) + parser.add_argument( + "--unit", "-u", type=int, default=650, help="Number of hidden units" + ) + parser.add_argument( + "--dropout-rate", type=float, default=0.5, help="dropout probability" + ) + return parser + + def __init__(self, n_vocab, args): + """Initialize class. + + Args: + n_vocab (int): The size of the vocabulary + args (argparse.Namespace): configurations. see py:method:`add_arguments` + + """ + torch.nn.Module.__init__(self) + self._setup( + rnn_type=args.type.upper(), + ntoken=n_vocab, + ninp=args.unit, + nhid=args.unit, + nlayers=args.layer, + dropout=args.dropout_rate, + ) + + def _setup( + self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False + ): + self.drop = nn.Dropout(dropout) + self.encoder = nn.Embedding(ntoken, ninp) + if rnn_type in ["LSTM", "GRU"]: + self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) + else: + try: + nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] + except KeyError: + raise ValueError( + "An invalid option for `--model` was supplied, " + "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']" + ) + self.rnn = nn.RNN( + ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout + ) + self.decoder = nn.Linear(nhid, ntoken) + + # Optionally tie weights as in: + # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) + # https://arxiv.org/abs/1608.05859 + # and + # "Tying Word Vectors and Word Classifiers: + # A Loss Framework for Language Modeling" (Inan et al. 
2016) + # https://arxiv.org/abs/1611.01462 + if tie_weights: + if nhid != ninp: + raise ValueError( + "When using the tied flag, nhid must be equal to emsize" + ) + self.decoder.weight = self.encoder.weight + + self._init_weights() + + self.rnn_type = rnn_type + self.nhid = nhid + self.nlayers = nlayers + + def _init_weights(self): + # NOTE: original init in pytorch/examples + # initrange = 0.1 + # self.encoder.weight.data.uniform_(-initrange, initrange) + # self.decoder.bias.data.zero_() + # self.decoder.weight.data.uniform_(-initrange, initrange) + # NOTE: our default.py:RNNLM init + for param in self.parameters(): + param.data.uniform_(-0.1, 0.1) + + def forward(self, x, t): + """Compute LM loss value from buffer sequences. + + Args: + x (torch.Tensor): Input ids. (batch, len) + t (torch.Tensor): Target ids. (batch, len) + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + y = self._before_loss(x, None)[0] + mask = (x != 0).to(y.dtype) + loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + logp = loss * mask.view(-1) + logp = logp.sum() + count = mask.sum() + return logp / count, logp, count + + def _before_loss(self, input, hidden): + emb = self.drop(self.encoder(input)) + output, hidden = self.rnn(emb, hidden) + output = self.drop(output) + decoded = self.decoder( + output.view(output.size(0) * output.size(1), output.size(2)) + ) + return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden + + def init_state(self, x): + """Get an initial state for decoding. + + Args: + x (torch.Tensor): The encoded feature tensor + + Returns: initial state + + """ + bsz = 1 + weight = next(self.parameters()) + if self.rnn_type == "LSTM": + return ( + weight.new_zeros(self.nlayers, bsz, self.nhid), + weight.new_zeros(self.nlayers, bsz, self.nhid), + ) + else: + return weight.new_zeros(self.nlayers, bsz, self.nhid) + + def score(self, y, state, x): + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): 2D encoder feature that generates ys. 
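+                (The encoder feature is not used by the language model itself; the
+                argument is kept for scorer-interface compatibility.)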
+ + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (n_vocab) + and next state for ys + + """ + y, new_state = self._before_loss(y[-1].view(1, 1), state) + logp = y.log_softmax(dim=-1).view(-1) + return logp, new_state diff --git a/espnet/nets/pytorch_backend/lm/transformer.py b/espnet/nets/pytorch_backend/lm/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..42c2f86d461b5d6125f4b5455b7b31cd6944f75d --- /dev/null +++ b/espnet/nets/pytorch_backend/lm/transformer.py @@ -0,0 +1,252 @@ +"""Transformer language model.""" + +from typing import Any +from typing import List +from typing import Tuple + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F + +from espnet.nets.lm_interface import LMInterface +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.utils.cli_utils import strtobool + + +class TransformerLM(nn.Module, LMInterface, BatchScorerInterface): + """Transformer language model.""" + + @staticmethod + def add_arguments(parser): + """Add arguments to command line argument parser.""" + parser.add_argument( + "--layer", type=int, default=4, help="Number of hidden layers" + ) + parser.add_argument( + "--unit", + type=int, + default=1024, + help="Number of hidden units in feedforward layer", + ) + parser.add_argument( + "--att-unit", + type=int, + default=256, + help="Number of hidden units in attention layer", + ) + parser.add_argument( + "--embed-unit", + type=int, + default=128, + help="Number of hidden units in embedding layer", + ) + parser.add_argument( + "--head", type=int, default=2, help="Number of multi head attention" + ) + parser.add_argument( + "--dropout-rate", type=float, default=0.5, help="dropout probability" + ) + parser.add_argument( + "--att-dropout-rate", + type=float, + default=0.0, + help="att dropout probability", + ) + parser.add_argument( + "--emb-dropout-rate", + type=float, + default=0.0, + help="emb dropout probability", + ) + parser.add_argument( + "--tie-weights", + type=strtobool, + default=False, + help="Tie input and output embeddings", + ) + parser.add_argument( + "--pos-enc", + default="sinusoidal", + choices=["sinusoidal", "none"], + help="positional encoding", + ) + return parser + + def __init__(self, n_vocab, args): + """Initialize class. + + Args: + n_vocab (int): The size of the vocabulary + args (argparse.Namespace): configurations. 
see py:method:`add_arguments` + + """ + nn.Module.__init__(self) + + # NOTE: for a compatibility with less than 0.9.7 version models + emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0) + # NOTE: for a compatibility with less than 0.9.7 version models + tie_weights = getattr(args, "tie_weights", False) + # NOTE: for a compatibility with less than 0.9.7 version models + att_dropout_rate = getattr(args, "att_dropout_rate", 0.0) + + if args.pos_enc == "sinusoidal": + pos_enc_class = PositionalEncoding + elif args.pos_enc == "none": + + def pos_enc_class(*args, **kwargs): + return nn.Sequential() # indentity + + else: + raise ValueError(f"unknown pos-enc option: {args.pos_enc}") + + self.embed = nn.Embedding(n_vocab, args.embed_unit) + + if emb_dropout_rate == 0.0: + self.embed_drop = None + else: + self.embed_drop = nn.Dropout(emb_dropout_rate) + + self.encoder = Encoder( + idim=args.embed_unit, + attention_dim=args.att_unit, + attention_heads=args.head, + linear_units=args.unit, + num_blocks=args.layer, + dropout_rate=args.dropout_rate, + attention_dropout_rate=att_dropout_rate, + input_layer="linear", + pos_enc_class=pos_enc_class, + ) + self.decoder = nn.Linear(args.att_unit, n_vocab) + + logging.info("Tie weights set to {}".format(tie_weights)) + logging.info("Dropout set to {}".format(args.dropout_rate)) + logging.info("Emb Dropout set to {}".format(emb_dropout_rate)) + logging.info("Att Dropout set to {}".format(att_dropout_rate)) + + if tie_weights: + assert ( + args.att_unit == args.embed_unit + ), "Tie Weights: True need embedding and final dimensions to match" + self.decoder.weight = self.embed.weight + + def _target_mask(self, ys_in_pad): + ys_mask = ys_in_pad != 0 + m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) + return ys_mask.unsqueeze(-2) & m + + def forward( + self, x: torch.Tensor, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute LM loss value from buffer sequences. + + Args: + x (torch.Tensor): Input ids. (batch, len) + t (torch.Tensor): Target ids. (batch, len) + + Returns: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of + loss to backward (scalar), + negative log-likelihood of t: -log p(t) (scalar) and + the number of elements in x (scalar) + + Notes: + The last two return values are used + in perplexity: p(t)^{-n} = exp(-log p(t) / n) + + """ + xm = x != 0 + + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(x)) + else: + emb = self.embed(x) + + h, _ = self.encoder(emb, self._target_mask(x)) + y = self.decoder(h) + loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + mask = xm.to(dtype=loss.dtype) + logp = loss * mask.view(-1) + logp = logp.sum() + count = mask.sum() + return logp / count, logp, count + + def score( + self, y: torch.Tensor, state: Any, x: torch.Tensor + ) -> Tuple[torch.Tensor, Any]: + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): encoder feature that generates ys. 
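+                (The encoder feature is not used by the language model itself; state
+                carries the per-layer caches returned by encoder.forward_one_step.)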
+ + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (n_vocab) + and next state for ys + + """ + y = y.unsqueeze(0) + + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(y)) + else: + emb = self.embed(y) + + h, _, cache = self.encoder.forward_one_step( + emb, self._target_mask(y), cache=state + ) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1).squeeze(0) + return logp, cache + + # batch beam search API (see BatchScorerInterface) + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch (required). + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + if self.embed_drop is not None: + emb = self.embed_drop(self.embed(ys)) + else: + emb = self.embed(ys) + + # batch decoding + h, _, states = self.encoder.forward_one_step( + emb, self._target_mask(ys), cache=batch_state + ) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list diff --git a/espnet/nets/pytorch_backend/maskctc/__init__.py b/espnet/nets/pytorch_backend/maskctc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/maskctc/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/maskctc/add_mask_token.py b/espnet/nets/pytorch_backend/maskctc/add_mask_token.py new file mode 100644 index 0000000000000000000000000000000000000000..e503a0235b5c9a211dc6a7702039f16332748c0d --- /dev/null +++ b/espnet/nets/pytorch_backend/maskctc/add_mask_token.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Waseda University (Yosuke Higuchi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Token masking module for Masked LM.""" + +import numpy + + +def mask_uniform(ys_pad, mask_token, eos, ignore_id): + """Replace random tokens with label and add label. + + The number of is chosen from a uniform distribution + between one and the target sequence's length. 
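+
+    Selected positions in the input are overwritten with mask_token, the corresponding
+    target positions keep the original token, and every other target position is set
+    to ignore_id.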
+ :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :param int mask_token: index of + :param int eos: index of + :param int ignore_id: index of padding + :return: padded tensor (B, Lmax) + :rtype: torch.Tensor + :return: padded tensor (B, Lmax) + :rtype: torch.Tensor + """ + from espnet.nets.pytorch_backend.nets_utils import pad_list + + ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + ys_out = [y.new(y.size()).fill_(ignore_id) for y in ys] + ys_in = [y.clone() for y in ys] + for i in range(len(ys)): + num_samples = numpy.random.randint(1, len(ys[i]) + 1) + idx = numpy.random.choice(len(ys[i]), num_samples) + + ys_in[i][idx] = mask_token + ys_out[i][idx] = ys[i][idx] + + return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) diff --git a/espnet/nets/pytorch_backend/maskctc/mask.py b/espnet/nets/pytorch_backend/maskctc/mask.py new file mode 100644 index 0000000000000000000000000000000000000000..c3341d4bbc7b9b93128f1083ae6c5f95cb86a2e1 --- /dev/null +++ b/espnet/nets/pytorch_backend/maskctc/mask.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Waseda University (Yosuke Higuchi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Attention masking module for Masked LM.""" + + +def square_mask(ys_in_pad, ignore_id): + """Create attention mask to avoid attending on padding tokens. + + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :param int ignore_id: index of padding + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor (B, Lmax, Lmax) + """ + ys_mask = (ys_in_pad != ignore_id).unsqueeze(-2) + ymax = ys_mask.size(-1) + ys_mask_tmp = ys_mask.transpose(1, 2).repeat(1, 1, ymax) + ys_mask = ys_mask.repeat(1, ymax, 1) & ys_mask_tmp + + return ys_mask diff --git a/espnet/nets/pytorch_backend/nets_utils.py b/espnet/nets/pytorch_backend/nets_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51e13fbf65cc447786cc5585c171a51a8d77aeab --- /dev/null +++ b/espnet/nets/pytorch_backend/nets_utils.py @@ -0,0 +1,498 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + if isinstance(m, torch.nn.Module): + device = next(m.parameters()).device + elif isinstance(m, torch.Tensor): + device = m.device + else: + raise TypeError( + "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}" + ) + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). 
+ + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, : xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor 
containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). 
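+
+    Note:
+        pad_outputs is reshaped back to (B, Lmax, D) internally, and positions where
+        pad_targets equals ignore_label are excluded from both the numerator and the
+        denominator of the accuracy.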
+ + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) + ).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. + + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == "c": + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if "real" not in x or "imag" not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x["real"], x["imag"]) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ( + "x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x)) + ) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == "transformer": + return np.array([1]) + + elif mode == "mt" and arch == "rnn": + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning("Subsampling is not performed for machine translation.") + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif ( + (mode == "asr" and arch in ("rnn", "rnn-t")) + or (mode == "mt" and arch == "rnn") + or (mode == "st" and arch == "rnn") + ): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." 
+ ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mix": + subsample = np.ones( + train_args.elayers_sd + train_args.elayers + 1, dtype=np.int + ) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range( + min(train_args.elayers_sd + train_args.elayers + 1, len(ss)) + ): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mulenc": + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[ + idx + ].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Encoder %d: Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN.", + idx + 1, + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch)) + + +def rename_state_dict( + old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor] +): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f"Rename: {old_prefix} -> {new_prefix}") + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from espnet.nets.pytorch_backend.conformer.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/espnet/nets/pytorch_backend/rnn/__init__.py b/espnet/nets/pytorch_backend/rnn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/rnn/argument.py b/espnet/nets/pytorch_backend/rnn/argument.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c89d25f52882f0c99ec3e8c8a182e3b6dc5ee7 --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/argument.py @@ -0,0 +1,156 @@ +# Copyright 2020 Hirofumi Inaguma +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Conformer common arguments.""" + + +def add_arguments_rnn_encoder_common(group): + """Define common arguments for RNN encoder.""" + group.add_argument( + "--etype", + default="blstmp", + type=str, + choices=[ + "lstm", + "blstm", + "lstmp", + "blstmp", + "vgglstmp", + "vggblstmp", + "vgglstm", + "vggblstm", + "gru", + "bgru", + "grup", + "bgrup", + "vgggrup", + "vggbgrup", + "vgggru", + "vggbgru", + ], + help="Type of encoder network architecture", + ) + group.add_argument( + "--elayers", + default=4, + type=int, + help="Number of encoder layers", 
+ ) + group.add_argument( + "--eunits", + "-u", + default=300, + type=int, + help="Number of encoder hidden units", + ) + group.add_argument( + "--eprojs", default=320, type=int, help="Number of encoder projection units" + ) + group.add_argument( + "--subsample", + default="1", + type=str, + help="Subsample input frames x_y_z means " + "subsample every x frame at 1st layer, " + "every y frame at 2nd layer etc.", + ) + return group + + +def add_arguments_rnn_decoder_common(group): + """Define common arguments for RNN decoder.""" + group.add_argument( + "--dtype", + default="lstm", + type=str, + choices=["lstm", "gru"], + help="Type of decoder network architecture", + ) + group.add_argument( + "--dlayers", default=1, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=320, type=int, help="Number of decoder hidden units" + ) + group.add_argument( + "--dropout-rate-decoder", + default=0.0, + type=float, + help="Dropout rate for the decoder", + ) + group.add_argument( + "--sampling-probability", + default=0.0, + type=float, + help="Ratio of predicted labels fed back to decoder", + ) + group.add_argument( + "--lsm-type", + const="", + default="", + type=str, + nargs="?", + choices=["", "unigram"], + help="Apply label smoothing with a specified distribution type", + ) + return group + + +def add_arguments_rnn_attention_common(group): + """Define common arguments for RNN attention.""" + group.add_argument( + "--atype", + default="dot", + type=str, + choices=[ + "noatt", + "dot", + "add", + "location", + "coverage", + "coverage_location", + "location2d", + "location_recurrent", + "multi_head_dot", + "multi_head_add", + "multi_head_loc", + "multi_head_multi_res_loc", + ], + help="Type of attention architecture", + ) + group.add_argument( + "--adim", + default=320, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--awin", default=5, type=int, help="Window size for location2d attention" + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + group.add_argument( + "--aconv-chans", + default=-1, + type=int, + help="Number of attention convolution channels \ + (negative value indicates no location-aware attention)", + ) + group.add_argument( + "--aconv-filts", + default=100, + type=int, + help="Number of attention convolution filters \ + (negative value indicates no location-aware attention)", + ) + group.add_argument( + "--dropout-rate", + default=0.0, + type=float, + help="Dropout rate for the encoder", + ) + return group diff --git a/espnet/nets/pytorch_backend/rnn/attentions.py b/espnet/nets/pytorch_backend/rnn/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..8458c156f45817f11ec40f905dcec0626def741c --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/attentions.py @@ -0,0 +1,1808 @@ +"""Attention modules for RNN.""" + +import math +import six + +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import to_device + + +def _apply_attention_constraint( + e, last_attended_idx, backward_window=1, forward_window=3 +): + """Apply monotonic attention constraint. + + This function apply the monotonic attention constraint + introduced in `Deep Voice 3: Scaling + Text-to-Speech with Convolutional Sequence Learning`_. + + Args: + e (Tensor): Attention energy before applying softmax (1, T). 
+ last_attended_idx (int): The index of the inputs of the last attended [0, T]. + backward_window (int, optional): Backward window size in attention constraint. + forward_window (int, optional): Forward window size in attetion constraint. + + Returns: + Tensor: Monotonic constrained attention energy (1, T). + + .. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`: + https://arxiv.org/abs/1710.07654 + + """ + if e.size(0) != 1: + raise NotImplementedError("Batch attention constraining is not yet supported.") + backward_idx = last_attended_idx - backward_window + forward_idx = last_attended_idx + forward_window + if backward_idx > 0: + e[:, :backward_idx] = -float("inf") + if forward_idx < e.size(1): + e[:, forward_idx:] = -float("inf") + return e + + +class NoAtt(torch.nn.Module): + """No attention""" + + def __init__(self): + super(NoAtt, self).__init__() + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.c = None + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.c = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): + """NoAtt forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: dummy (does not use) + :param torch.Tensor att_prev: dummy (does not use) + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights + :rtype: torch.Tensor + """ + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + + # initialize attention weight with uniform dist. 
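+ # Editor's note (illustrative): with enc_hs_len = [5, 3], make_pad_mask marks the padded
+ # tail of each row, 1.0 - mask keeps only the valid frames, and dividing by the lengths
+ # gives uniform weights of 1/5 and 1/3 over those frames, so the context c is simply the
+ # mean of the valid encoder states.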
+ if att_prev is None: + # if no bias, 0 0-pad goes 0 + mask = 1.0 - make_pad_mask(enc_hs_len).float() + att_prev = mask / mask.new(enc_hs_len).unsqueeze(-1) + att_prev = att_prev.to(self.enc_h) + self.c = torch.sum( + self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1 + ) + + return self.c, att_prev + + +class AttDot(torch.nn.Module): + """Dot product attention + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_enc_h + """ + + def __init__(self, eprojs, dunits, att_dim, han_mode=False): + super(AttDot, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): + """AttDot forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: dummy (does not use) + :param torch.Tensor att_prev: dummy (does not use) + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weight (B x T_max) + :rtype: torch.Tensor + """ + + batch = enc_hs_pad.size(0) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h)) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + e = torch.sum( + self.pre_compute_enc_h + * torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim), + dim=2, + ) # utt x frame + + # NOTE consider zero padding when compute w. 
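+ # Editor's note (illustrative): padded frames are set to -inf so that softmax assigns them
+ # exactly zero weight; with scaling=2.0 and e = [1.0, 0.5, -inf], the resulting weights are
+ # roughly [0.73, 0.27, 0.00].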
+ if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + return c, w + + +class AttAdd(torch.nn.Module): + """Additive attention + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_enc_h + """ + + def __init__(self, eprojs, dunits, att_dim, han_mode=False): + super(AttAdd, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.gvec = torch.nn.Linear(att_dim, 1) + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): + """AttAdd forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: dummy (does not use) + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights (B x T_max) + :rtype: torch.Tensor + """ + + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_z_tiled)).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + return c, w + + +class AttLoc(torch.nn.Module): + """location-aware attention module. 
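+
+ Example (editor's sketch; the layer sizes below are assumed for illustration only):
+
+ >>> att = AttLoc(eprojs=4, dunits=3, att_dim=8, aconv_chans=4, aconv_filts=2)
+ >>> enc_hs_pad = torch.randn(2, 7, 4)  # (B, T_max, D_enc)
+ >>> c, w = att(enc_hs_pad, [7, 5], dec_z=None, att_prev=None)
+ >>> c.shape, w.shape
+ (torch.Size([2, 4]), torch.Size([2, 7]))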
+ + Reference: Attention-Based Models for Speech Recognition + (https://arxiv.org/pdf/1506.07503.pdf) + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_enc_h + """ + + def __init__( + self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False + ): + super(AttLoc, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.gvec = torch.nn.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward( + self, + enc_hs_pad, + enc_hs_len, + dec_z, + att_prev, + scaling=2.0, + last_attended_idx=None, + backward_window=1, + forward_window=3, + ): + """Calcualte AttLoc forward propagation. + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: previous attention weight (B x T_max) + :param float scaling: scaling parameter before applying softmax + :param torch.Tensor forward_window: + forward window size when constraining attention + :param int last_attended_idx: index of the inputs of the last attended + :param int backward_window: backward window size in attention constraint + :param int forward_window: forward window size in attetion constraint + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights (B x T_max) + :rtype: torch.Tensor + """ + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + # initialize attention weight with uniform dist. 
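+ # Editor's note: when last_attended_idx is given (typically at TTS inference time),
+ # _apply_attention_constraint below keeps only the energies inside the window
+ # [last_attended_idx - backward_window, last_attended_idx + forward_window], which
+ # enforces a roughly monotonic alignment.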
+ if att_prev is None: + # if no bias, 0 0-pad goes 0 + att_prev = 1.0 - make_pad_mask(enc_hs_len).to( + device=dec_z.device, dtype=dec_z.dtype + ) + att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) + + # att_prev: utt x frame -> utt x 1 x 1 x frame + # -> utt x att_conv_chans x 1 x frame + att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) + # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans + att_conv = att_conv.squeeze(2).transpose(1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE: consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + + # apply monotonic attention constraint (mainly for TTS) + if last_attended_idx is not None: + e = _apply_attention_constraint( + e, last_attended_idx, backward_window, forward_window + ) + + w = F.softmax(scaling * e, dim=1) + + # weighted sum over flames + # utt x hdim + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + return c, w + + +class AttCov(torch.nn.Module): + """Coverage mechanism attention + + Reference: Get To The Point: Summarization with Pointer-Generator Network + (https://arxiv.org/abs/1704.04368) + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_enc_h + """ + + def __init__(self, eprojs, dunits, att_dim, han_mode=False): + super(AttCov, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.wvec = torch.nn.Linear(1, att_dim) + self.gvec = torch.nn.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): + """AttCov forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param list att_prev_list: list of previous attention weight + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: list of previous attention weights + :rtype: list + """ + + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + # initialize attention weight with uniform dist. 
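+ # Editor's note (illustrative): the coverage vector computed below is the element-wise sum
+ # of all previous attention distributions, e.g. past weights [0.6, 0.4, 0.0] and
+ # [0.1, 0.8, 0.1] give cov_vec = [0.7, 1.2, 0.1], telling the model which frames have
+ # already received attention.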
+ if att_prev_list is None: + # if no bias, 0 0-pad goes 0 + att_prev_list = to_device( + enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()) + ) + att_prev_list = [ + att_prev_list / att_prev_list.new(enc_hs_len).unsqueeze(-1) + ] + + # att_prev_list: L' * [B x T] => cov_vec B x T + cov_vec = sum(att_prev_list) + # cov_vec: B x T => B x T x 1 => B x T x att_dim + cov_vec = self.wvec(cov_vec.unsqueeze(-1)) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(cov_vec + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + att_prev_list += [w] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + return c, att_prev_list + + +class AttLoc2D(torch.nn.Module): + """2D location-aware attention + + This attention is an extended version of location aware attention. + It take not only one frame before attention weights, + but also earlier frames into account. + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param int att_win: attention window size (default=5) + :param bool han_mode: + flag to swith on mode of hierarchical attention and not store pre_compute_enc_h + """ + + def __init__( + self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False + ): + super(AttLoc2D, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (att_win, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.gvec = torch.nn.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.aconv_chans = aconv_chans + self.att_win = att_win + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): + """AttLoc2D forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max) + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights (B x att_win x T_max) + :rtype: torch.Tensor + """ + + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + 
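+ # Editor's note: unlike AttLoc, the location features here are computed from the last
+ # att_win attention maps stacked into an (att_win x T_max) "image", so the convolution
+ # summarizes several past decoding steps; the window is updated at the end of this method
+ # by appending the new weights and dropping the oldest map.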
+ if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + # initialize attention weight with uniform dist. + if att_prev is None: + # B * [Li x att_win] + # if no bias, 0 0-pad goes 0 + att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) + att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) + att_prev = att_prev.unsqueeze(1).expand(-1, self.att_win, -1) + + # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax + att_conv = self.loc_conv(att_prev.unsqueeze(1)) + # att_conv: B x C x 1 x Tmax -> B x Tmax x C + att_conv = att_conv.squeeze(2).transpose(1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + # update att_prev: B x att_win x Tmax -> B x att_win+1 x Tmax + # -> B x att_win x Tmax + att_prev = torch.cat([att_prev, w.unsqueeze(1)], dim=1) + att_prev = att_prev[:, 1:] + + return c, att_prev + + +class AttLocRec(torch.nn.Module): + """location-aware recurrent attention + + This attention is an extended version of location aware attention. + With the use of RNN, + it take the effect of the history of attention weights into account. 
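+ Concretely, the convolved previous weights are max-pooled over time and fed to an LSTM
+ cell (att_lstm) whose hidden state serves as the location term in the attention energy,
+ so the cell accumulates the attention history across decoding steps.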
+ + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: + flag to swith on mode of hierarchical attention and not store pre_compute_enc_h + """ + + def __init__( + self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False + ): + super(AttLocRec, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False) + self.gvec = torch.nn.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_states, scaling=2.0): + """AttLocRec forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param tuple att_prev_states: previous attention weight and lstm states + ((B, T_max), ((B, att_dim), (B, att_dim))) + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights and lstm states (w, (hx, cx)) + ((B, T_max), ((B, att_dim), (B, att_dim))) + :rtype: tuple + """ + + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + if att_prev_states is None: + # initialize attention weight with uniform dist. + # if no bias, 0 0-pad goes 0 + att_prev = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float())) + att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1) + + # initialize lstm states + att_h = enc_hs_pad.new_zeros(batch, self.att_dim) + att_c = enc_hs_pad.new_zeros(batch, self.att_dim) + att_states = (att_h, att_c) + else: + att_prev = att_prev_states[0] + att_states = att_prev_states[1] + + # B x 1 x 1 x T -> B x C x 1 x T + att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) + # apply non-linear + att_conv = F.relu(att_conv) + # B x C x 1 x T -> B x C x 1 x 1 -> B x C + att_conv = F.max_pool2d(att_conv, (1, att_conv.size(3))).view(batch, -1) + + att_h, att_c = self.att_lstm(att_conv, att_states) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(att_h.unsqueeze(1) + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE consider zero padding when compute w. 
+ if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + return c, (w, (att_h, att_c)) + + +class AttCovLoc(torch.nn.Module): + """Coverage mechanism location aware attention + + This attention is a combination of coverage and location-aware attentions. + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: + flag to swith on mode of hierarchical attention and not store pre_compute_enc_h + """ + + def __init__( + self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False + ): + super(AttCovLoc, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.gvec = torch.nn.Linear(att_dim, 1) + + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.aconv_chans = aconv_chans + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0): + """AttCovLoc forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param list att_prev_list: list of previous attention weight + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: list of previous attention weights + :rtype: list + """ + + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + # initialize attention weight with uniform dist. 
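+ # Editor's note: the difference from AttLoc is that the input to loc_conv below is the
+ # coverage vector (the running sum of all past attention weights) rather than only the
+ # single previous attention distribution.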
+ if att_prev_list is None: + # if no bias, 0 0-pad goes 0 + mask = 1.0 - make_pad_mask(enc_hs_len).float() + att_prev_list = [ + to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) + ] + + # att_prev_list: L' * [B x T] => cov_vec B x T + cov_vec = sum(att_prev_list) + + # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T + att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length)) + # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans + att_conv = att_conv.squeeze(2).transpose(1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w = F.softmax(scaling * e, dim=1) + att_prev_list += [w] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + return c, att_prev_list + + +class AttMultiHeadDot(torch.nn.Module): + """Multi head dot product attention + + Reference: Attention is all you need + (https://arxiv.org/abs/1706.03762) + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int aheads: # heads of multi head attention + :param int att_dim_k: dimension k in multi head attention + :param int att_dim_v: dimension v in multi head attention + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_k and pre_compute_v + """ + + def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): + super(AttMultiHeadDot, self).__init__() + self.mlp_q = torch.nn.ModuleList() + self.mlp_k = torch.nn.ModuleList() + self.mlp_v = torch.nn.ModuleList() + for _ in six.moves.range(aheads): + self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] + self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] + self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] + self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) + self.dunits = dunits + self.eprojs = eprojs + self.aheads = aheads + self.att_dim_k = att_dim_k + self.att_dim_v = att_dim_v + self.scaling = 1.0 / math.sqrt(att_dim_k) + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): + """AttMultiHeadDot forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: dummy (does not use) + :return: attention weighted encoder state (B x D_enc) + :rtype: torch.Tensor + :return: list of previous attention weight (B x T_max) * aheads + :rtype: list + """ + + batch = enc_hs_pad.size(0) + # pre-compute all k and v outside the decoder loop + if self.pre_compute_k is None or self.han_mode: + self.enc_h = enc_hs_pad 
# utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_k = [ + torch.tanh(self.mlp_k[h](self.enc_h)) + for h in six.moves.range(self.aheads) + ] + + if self.pre_compute_v is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_v = [ + self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + c = [] + w = [] + for h in six.moves.range(self.aheads): + e = torch.sum( + self.pre_compute_k[h] + * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k), + dim=2, + ) # utt x frame + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w += [F.softmax(self.scaling * e, dim=1)] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c += [ + torch.sum( + self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 + ) + ] + + # concat all of c + c = self.mlp_o(torch.cat(c, dim=1)) + + return c, w + + +class AttMultiHeadAdd(torch.nn.Module): + """Multi head additive attention + + Reference: Attention is all you need + (https://arxiv.org/abs/1706.03762) + + This attention is multi head attention using additive attention for each head. + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int aheads: # heads of multi head attention + :param int att_dim_k: dimension k in multi head attention + :param int att_dim_v: dimension v in multi head attention + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_k and pre_compute_v + """ + + def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False): + super(AttMultiHeadAdd, self).__init__() + self.mlp_q = torch.nn.ModuleList() + self.mlp_k = torch.nn.ModuleList() + self.mlp_v = torch.nn.ModuleList() + self.gvec = torch.nn.ModuleList() + for _ in six.moves.range(aheads): + self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] + self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] + self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] + self.gvec += [torch.nn.Linear(att_dim_k, 1)] + self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) + self.dunits = dunits + self.eprojs = eprojs + self.aheads = aheads + self.att_dim_k = att_dim_k + self.att_dim_v = att_dim_v + self.scaling = 1.0 / math.sqrt(att_dim_k) + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): + """AttMultiHeadAdd forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: dummy (does not use) + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: list of previous attention weight (B x T_max) * aheads + :rtype: list + """ + + batch = enc_hs_pad.size(0) + # pre-compute all k and v 
outside the decoder loop + if self.pre_compute_k is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_k = [ + self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if self.pre_compute_v is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_v = [ + self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + c = [] + w = [] + for h in six.moves.range(self.aheads): + e = self.gvec[h]( + torch.tanh( + self.pre_compute_k[h] + + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) + ) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w += [F.softmax(self.scaling * e, dim=1)] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c += [ + torch.sum( + self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 + ) + ] + + # concat all of c + c = self.mlp_o(torch.cat(c, dim=1)) + + return c, w + + +class AttMultiHeadLoc(torch.nn.Module): + """Multi head location based attention + + Reference: Attention is all you need + (https://arxiv.org/abs/1706.03762) + + This attention is multi head attention using location-aware attention for each head. + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int aheads: # heads of multi head attention + :param int att_dim_k: dimension k in multi head attention + :param int att_dim_v: dimension v in multi head attention + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_k and pre_compute_v + """ + + def __init__( + self, + eprojs, + dunits, + aheads, + att_dim_k, + att_dim_v, + aconv_chans, + aconv_filts, + han_mode=False, + ): + super(AttMultiHeadLoc, self).__init__() + self.mlp_q = torch.nn.ModuleList() + self.mlp_k = torch.nn.ModuleList() + self.mlp_v = torch.nn.ModuleList() + self.gvec = torch.nn.ModuleList() + self.loc_conv = torch.nn.ModuleList() + self.mlp_att = torch.nn.ModuleList() + for _ in six.moves.range(aheads): + self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] + self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] + self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] + self.gvec += [torch.nn.Linear(att_dim_k, 1)] + self.loc_conv += [ + torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + ] + self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] + self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) + self.dunits = dunits + self.eprojs = eprojs + self.aheads = aheads + self.att_dim_k = att_dim_k + self.att_dim_v = att_dim_v + self.scaling = 1.0 / math.sqrt(att_dim_k) + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + + def forward(self, 
enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0): + """AttMultiHeadLoc forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: + list of previous attention weight (B x T_max) * aheads + :param float scaling: scaling parameter before applying softmax + :return: attention weighted encoder state (B x D_enc) + :rtype: torch.Tensor + :return: list of previous attention weight (B x T_max) * aheads + :rtype: list + """ + + batch = enc_hs_pad.size(0) + # pre-compute all k and v outside the decoder loop + if self.pre_compute_k is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_k = [ + self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if self.pre_compute_v is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_v = [ + self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + if att_prev is None: + att_prev = [] + for _ in six.moves.range(self.aheads): + # if no bias, 0 0-pad goes 0 + mask = 1.0 - make_pad_mask(enc_hs_len).float() + att_prev += [ + to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) + ] + + c = [] + w = [] + for h in six.moves.range(self.aheads): + att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) + att_conv = att_conv.squeeze(2).transpose(1, 2) + att_conv = self.mlp_att[h](att_conv) + + e = self.gvec[h]( + torch.tanh( + self.pre_compute_k[h] + + att_conv + + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) + ) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w += [F.softmax(scaling * e, dim=1)] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c += [ + torch.sum( + self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 + ) + ] + + # concat all of c + c = self.mlp_o(torch.cat(c, dim=1)) + + return c, w + + +class AttMultiHeadMultiResLoc(torch.nn.Module): + """Multi head multi resolution location based attention + + Reference: Attention is all you need + (https://arxiv.org/abs/1706.03762) + + This attention is multi head attention using location-aware attention for each head. + Furthermore, it uses different filter size for each head. + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int aheads: # heads of multi head attention + :param int att_dim_k: dimension k in multi head attention + :param int att_dim_v: dimension v in multi head attention + :param int aconv_chans: maximum # channels of attention convolution + each head use #ch = aconv_chans * (head + 1) / aheads + e.g. 
aheads=4, aconv_chans=100 => filter size = 25, 50, 75, 100 + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: flag to swith on mode of hierarchical attention + and not store pre_compute_k and pre_compute_v + """ + + def __init__( + self, + eprojs, + dunits, + aheads, + att_dim_k, + att_dim_v, + aconv_chans, + aconv_filts, + han_mode=False, + ): + super(AttMultiHeadMultiResLoc, self).__init__() + self.mlp_q = torch.nn.ModuleList() + self.mlp_k = torch.nn.ModuleList() + self.mlp_v = torch.nn.ModuleList() + self.gvec = torch.nn.ModuleList() + self.loc_conv = torch.nn.ModuleList() + self.mlp_att = torch.nn.ModuleList() + for h in six.moves.range(aheads): + self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)] + self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)] + self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)] + self.gvec += [torch.nn.Linear(att_dim_k, 1)] + afilts = aconv_filts * (h + 1) // aheads + self.loc_conv += [ + torch.nn.Conv2d( + 1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False + ) + ] + self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)] + self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False) + self.dunits = dunits + self.eprojs = eprojs + self.aheads = aheads + self.att_dim_k = att_dim_k + self.att_dim_v = att_dim_v + self.scaling = 1.0 / math.sqrt(att_dim_k) + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + self.han_mode = han_mode + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_k = None + self.pre_compute_v = None + self.mask = None + + def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev): + """AttMultiHeadMultiResLoc forward + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: list of previous attention weight + (B x T_max) * aheads + :return: attention weighted encoder state (B x D_enc) + :rtype: torch.Tensor + :return: list of previous attention weight (B x T_max) * aheads + :rtype: list + """ + + batch = enc_hs_pad.size(0) + # pre-compute all k and v outside the decoder loop + if self.pre_compute_k is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_k = [ + self.mlp_k[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if self.pre_compute_v is None or self.han_mode: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_v = [ + self.mlp_v[h](self.enc_h) for h in six.moves.range(self.aheads) + ] + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + if att_prev is None: + att_prev = [] + for _ in six.moves.range(self.aheads): + # if no bias, 0 0-pad goes 0 + mask = 1.0 - make_pad_mask(enc_hs_len).float() + att_prev += [ + to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1)) + ] + + c = [] + w = [] + for h in six.moves.range(self.aheads): + att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length)) + att_conv = att_conv.squeeze(2).transpose(1, 2) + att_conv = self.mlp_att[h](att_conv) + + e = self.gvec[h]( + torch.tanh( + self.pre_compute_k[h] + + att_conv + + 
self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k) + ) + ).squeeze(2) + + # NOTE consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + w += [F.softmax(self.scaling * e, dim=1)] + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c += [ + torch.sum( + self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1 + ) + ] + + # concat all of c + c = self.mlp_o(torch.cat(c, dim=1)) + + return c, w + + +class AttForward(torch.nn.Module): + """Forward attention module. + + Reference: + Forward attention in sequence-to-sequence acoustic modeling for speech synthesis + (https://arxiv.org/pdf/1807.06736.pdf) + + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + """ + + def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts): + super(AttForward, self).__init__() + self.mlp_enc = torch.nn.Linear(eprojs, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.gvec = torch.nn.Linear(att_dim, 1) + self.dunits = dunits + self.eprojs = eprojs + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def reset(self): + """reset states""" + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + + def forward( + self, + enc_hs_pad, + enc_hs_len, + dec_z, + att_prev, + scaling=1.0, + last_attended_idx=None, + backward_window=1, + forward_window=3, + ): + """Calculate AttForward forward propagation. + + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B x D_dec) + :param torch.Tensor att_prev: attention weights of previous step + :param float scaling: scaling parameter before applying softmax + :param int last_attended_idx: index of the inputs of the last attended + :param int backward_window: backward window size in attention constraint + :param int forward_window: forward window size in attetion constraint + :return: attention weighted encoder state (B, D_enc) + :rtype: torch.Tensor + :return: previous attention weights (B x T_max) + :rtype: torch.Tensor + """ + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + if att_prev is None: + # initial attention will be [1, 0, 0, ...] 
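+ # Editor's note: later in this method the new weights are multiplied by
+ # (att_prev + att_prev shifted right by one frame) and re-normalized, so probability mass
+ # can only stay on the current frame or advance by one, which is what makes forward
+ # attention monotonic.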
+ att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) + att_prev[:, 0] = 1.0 + + # att_prev: utt x frame -> utt x 1 x 1 x frame + # -> utt x att_conv_chans x 1 x frame + att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) + # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans + att_conv = att_conv.squeeze(2).transpose(1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv) + ).squeeze(2) + + # NOTE: consider zero padding when compute w. + if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + + # apply monotonic attention constraint (mainly for TTS) + if last_attended_idx is not None: + e = _apply_attention_constraint( + e, last_attended_idx, backward_window, forward_window + ) + + w = F.softmax(scaling * e, dim=1) + + # forward attention + att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] + w = (att_prev + att_prev_shift) * w + # NOTE: clamp is needed to avoid nan gradient + w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1) + + return c, w + + +class AttForwardTA(torch.nn.Module): + """Forward attention with transition agent module. + + Reference: + Forward attention in sequence-to-sequence acoustic modeling for speech synthesis + (https://arxiv.org/pdf/1807.06736.pdf) + + :param int eunits: # units of encoder + :param int dunits: # units of decoder + :param int att_dim: attention dimension + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param int odim: output dimension + """ + + def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim): + super(AttForwardTA, self).__init__() + self.mlp_enc = torch.nn.Linear(eunits, att_dim) + self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False) + self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1) + self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False) + self.loc_conv = torch.nn.Conv2d( + 1, + aconv_chans, + (1, 2 * aconv_filts + 1), + padding=(0, aconv_filts), + bias=False, + ) + self.gvec = torch.nn.Linear(att_dim, 1) + self.dunits = dunits + self.eunits = eunits + self.att_dim = att_dim + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.trans_agent_prob = 0.5 + + def reset(self): + self.h_length = None + self.enc_h = None + self.pre_compute_enc_h = None + self.mask = None + self.trans_agent_prob = 0.5 + + def forward( + self, + enc_hs_pad, + enc_hs_len, + dec_z, + att_prev, + out_prev, + scaling=1.0, + last_attended_idx=None, + backward_window=1, + forward_window=3, + ): + """Calculate AttForwardTA forward propagation. 
+ + :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits) + :param list enc_hs_len: padded encoder hidden state length (B) + :param torch.Tensor dec_z: decoder hidden state (B, dunits) + :param torch.Tensor att_prev: attention weights of previous step + :param torch.Tensor out_prev: decoder outputs of previous step (B, odim) + :param float scaling: scaling parameter before applying softmax + :param int last_attended_idx: index of the inputs of the last attended + :param int backward_window: backward window size in attention constraint + :param int forward_window: forward window size in attetion constraint + :return: attention weighted encoder state (B, dunits) + :rtype: torch.Tensor + :return: previous attention weights (B, Tmax) + :rtype: torch.Tensor + """ + batch = len(enc_hs_pad) + # pre-compute all h outside the decoder loop + if self.pre_compute_enc_h is None: + self.enc_h = enc_hs_pad # utt x frame x hdim + self.h_length = self.enc_h.size(1) + # utt x frame x att_dim + self.pre_compute_enc_h = self.mlp_enc(self.enc_h) + + if dec_z is None: + dec_z = enc_hs_pad.new_zeros(batch, self.dunits) + else: + dec_z = dec_z.view(batch, self.dunits) + + if att_prev is None: + # initial attention will be [1, 0, 0, ...] + att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2]) + att_prev[:, 0] = 1.0 + + # att_prev: utt x frame -> utt x 1 x 1 x frame + # -> utt x att_conv_chans x 1 x frame + att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length)) + # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans + att_conv = att_conv.squeeze(2).transpose(1, 2) + # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim + att_conv = self.mlp_att(att_conv) + + # dec_z_tiled: utt x frame x att_dim + dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim) + + # dot with gvec + # utt x frame x att_dim -> utt x frame + e = self.gvec( + torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled) + ).squeeze(2) + + # NOTE consider zero padding when compute w. 
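+ # Editor's note (illustrative): the transition agent probability updated at the end of this
+ # method (a sigmoid over [context, previous output, decoder state]) controls the
+ # stay/advance trade-off below: with trans_agent_prob = 0.5 the previous weights and their
+ # shifted copy contribute equally, while larger values favour staying on the current frame.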
+ if self.mask is None: + self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len)) + e.masked_fill_(self.mask, -float("inf")) + + # apply monotonic attention constraint (mainly for TTS) + if last_attended_idx is not None: + e = _apply_attention_constraint( + e, last_attended_idx, backward_window, forward_window + ) + + w = F.softmax(scaling * e, dim=1) + + # forward attention + att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1] + w = ( + self.trans_agent_prob * att_prev + + (1 - self.trans_agent_prob) * att_prev_shift + ) * w + # NOTE: clamp is needed to avoid nan gradient + w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1) + + # weighted sum over flames + # utt x hdim + # NOTE use bmm instead of sum(*) + c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1) + + # update transition agent prob + self.trans_agent_prob = torch.sigmoid( + self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1)) + ) + + return c, w + + +def att_for(args, num_att=1, han_mode=False): + """Instantiates an attention module given the program arguments + + :param Namespace args: The arguments + :param int num_att: number of attention modules + (in multi-speaker case, it can be 2 or more) + :param bool han_mode: switch on/off mode of hierarchical attention network (HAN) + :rtype torch.nn.Module + :return: The attention module + """ + att_list = torch.nn.ModuleList() + num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility + aheads = getattr(args, "aheads", None) + awin = getattr(args, "awin", None) + aconv_chans = getattr(args, "aconv_chans", None) + aconv_filts = getattr(args, "aconv_filts", None) + + if num_encs == 1: + for i in range(num_att): + att = initial_att( + args.atype, + args.eprojs, + args.dunits, + aheads, + args.adim, + awin, + aconv_chans, + aconv_filts, + ) + att_list.append(att) + elif num_encs > 1: # no multi-speaker mode + if han_mode: + att = initial_att( + args.han_type, + args.eprojs, + args.dunits, + args.han_heads, + args.han_dim, + args.han_win, + args.han_conv_chans, + args.han_conv_filts, + han_mode=True, + ) + return att + else: + att_list = torch.nn.ModuleList() + for idx in range(num_encs): + att = initial_att( + args.atype[idx], + args.eprojs, + args.dunits, + aheads[idx], + args.adim[idx], + awin[idx], + aconv_chans[idx], + aconv_filts[idx], + ) + att_list.append(att) + else: + raise ValueError( + "Number of encoders needs to be more than one. 
{}".format(num_encs) + ) + return att_list + + +def initial_att( + atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False +): + """Instantiates a single attention module + + :param str atype: attention type + :param int eprojs: # projection-units of encoder + :param int dunits: # units of decoder + :param int aheads: # heads of multi head attention + :param int adim: attention dimension + :param int awin: attention window size + :param int aconv_chans: # channels of attention convolution + :param int aconv_filts: filter size of attention convolution + :param bool han_mode: flag to swith on mode of hierarchical attention + :return: The attention module + """ + + if atype == "noatt": + att = NoAtt() + elif atype == "dot": + att = AttDot(eprojs, dunits, adim, han_mode) + elif atype == "add": + att = AttAdd(eprojs, dunits, adim, han_mode) + elif atype == "location": + att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) + elif atype == "location2d": + att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode) + elif atype == "location_recurrent": + att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) + elif atype == "coverage": + att = AttCov(eprojs, dunits, adim, han_mode) + elif atype == "coverage_location": + att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode) + elif atype == "multi_head_dot": + att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode) + elif atype == "multi_head_add": + att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode) + elif atype == "multi_head_loc": + att = AttMultiHeadLoc( + eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode + ) + elif atype == "multi_head_multi_res_loc": + att = AttMultiHeadMultiResLoc( + eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode + ) + return att + + +def att_to_numpy(att_ws, att): + """Converts attention weights to a numpy array given the attention + + :param list att_ws: The attention weights + :param torch.nn.Module att: The attention + :rtype: np.ndarray + :return: The numpy array of the attention weights + """ + # convert to numpy array with the shape (B, Lmax, Tmax) + if isinstance(att, AttLoc2D): + # att_ws => list of previous concate attentions + att_ws = torch.stack([aw[:, -1] for aw in att_ws], dim=1).cpu().numpy() + elif isinstance(att, (AttCov, AttCovLoc)): + # att_ws => list of list of previous attentions + att_ws = ( + torch.stack([aw[idx] for idx, aw in enumerate(att_ws)], dim=1).cpu().numpy() + ) + elif isinstance(att, AttLocRec): + # att_ws => list of tuple of attention and hidden states + att_ws = torch.stack([aw[0] for aw in att_ws], dim=1).cpu().numpy() + elif isinstance( + att, + (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc), + ): + # att_ws => list of list of each head attention + n_heads = len(att_ws[0]) + att_ws_sorted_by_head = [] + for h in six.moves.range(n_heads): + att_ws_head = torch.stack([aw[h] for aw in att_ws], dim=1) + att_ws_sorted_by_head += [att_ws_head] + att_ws = torch.stack(att_ws_sorted_by_head, dim=1).cpu().numpy() + else: + # att_ws => list of attentions + att_ws = torch.stack(att_ws, dim=1).cpu().numpy() + return att_ws diff --git a/espnet/nets/pytorch_backend/rnn/decoders.py b/espnet/nets/pytorch_backend/rnn/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..dc04e20e925fc3fc93eb9b7dd51c146660d075e2 --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/decoders.py 
@@ -0,0 +1,1218 @@ +from distutils.version import LooseVersion +import logging +import math +import random +import six + +import numpy as np +import torch +import torch.nn.functional as F + +from argparse import Namespace + +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH +from espnet.nets.e2e_asr_common import end_detect + +from espnet.nets.pytorch_backend.rnn.attentions import att_to_numpy + +from espnet.nets.pytorch_backend.nets_utils import mask_by_length +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.scorer_interface import ScorerInterface + +MAX_DECODER_OUTPUT = 5 +CTC_SCORING_RATIO = 1.5 + + +class Decoder(torch.nn.Module, ScorerInterface): + """Decoder module + + :param int eprojs: encoder projection units + :param int odim: dimension of outputs + :param str dtype: gru or lstm + :param int dlayers: decoder layers + :param int dunits: decoder units + :param int sos: start of sequence symbol id + :param int eos: end of sequence symbol id + :param torch.nn.Module att: attention module + :param int verbose: verbose level + :param list char_list: list of character strings + :param ndarray labeldist: distribution of label smoothing + :param float lsm_weight: label smoothing weight + :param float sampling_probability: scheduled sampling probability + :param float dropout: dropout rate + :param float context_residual: if True, use context vector for token generation + :param float replace_sos: use for multilingual (speech/text) translation + """ + + def __init__( + self, + eprojs, + odim, + dtype, + dlayers, + dunits, + sos, + eos, + att, + verbose=0, + char_list=None, + labeldist=None, + lsm_weight=0.0, + sampling_probability=0.0, + dropout=0.0, + context_residual=False, + replace_sos=False, + num_encs=1, + ): + + torch.nn.Module.__init__(self) + self.dtype = dtype + self.dunits = dunits + self.dlayers = dlayers + self.context_residual = context_residual + self.embed = torch.nn.Embedding(odim, dunits) + self.dropout_emb = torch.nn.Dropout(p=dropout) + + self.decoder = torch.nn.ModuleList() + self.dropout_dec = torch.nn.ModuleList() + self.decoder += [ + torch.nn.LSTMCell(dunits + eprojs, dunits) + if self.dtype == "lstm" + else torch.nn.GRUCell(dunits + eprojs, dunits) + ] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + for _ in six.moves.range(1, self.dlayers): + self.decoder += [ + torch.nn.LSTMCell(dunits, dunits) + if self.dtype == "lstm" + else torch.nn.GRUCell(dunits, dunits) + ] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + # NOTE: dropout is applied only for the vertical connections + # see https://arxiv.org/pdf/1409.2329.pdf + self.ignore_id = -1 + + if context_residual: + self.output = torch.nn.Linear(dunits + eprojs, odim) + else: + self.output = torch.nn.Linear(dunits, odim) + + self.loss = None + self.att = att + self.dunits = dunits + self.sos = sos + self.eos = eos + self.odim = odim + self.verbose = verbose + self.char_list = char_list + # for label smoothing + self.labeldist = labeldist + self.vlabeldist = None + self.lsm_weight = lsm_weight + self.sampling_probability = sampling_probability + self.dropout = dropout + self.num_encs = num_encs + + # for multilingual E2E-ST + self.replace_sos = replace_sos + + self.logzero = -10000000000.0 + + def zero_state(self, hs_pad): + return hs_pad.new_zeros(hs_pad.size(0), self.dunits) + + def 
rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): + if self.dtype == "lstm": + z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0])) + for i in six.moves.range(1, self.dlayers): + z_list[i], c_list[i] = self.decoder[i]( + self.dropout_dec[i - 1](z_list[i - 1]), (z_prev[i], c_prev[i]) + ) + else: + z_list[0] = self.decoder[0](ey, z_prev[0]) + for i in six.moves.range(1, self.dlayers): + z_list[i] = self.decoder[i]( + self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i] + ) + return z_list, c_list + + def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None): + """Decoder forward + + :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D) + [in multi-encoder case, + list of torch.Tensor, + [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ] + :param torch.Tensor hlens: batch of lengths of hidden state sequences (B) + [in multi-encoder case, list of torch.Tensor, + [(B), (B), ..., ] + :param torch.Tensor ys_pad: batch of padded character id sequence tensor + (B, Lmax) + :param int strm_idx: stream index indicates the index of decoding stream. + :param torch.Tensor lang_ids: batch of target language id tensor (B, 1) + :return: attention loss value + :rtype: torch.Tensor + :return: accuracy + :rtype: float + """ + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + hs_pad = [hs_pad] + hlens = [hlens] + + # TODO(kan-bayashi): need to make more smart way + ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys + # attention index for the attention module + # in SPA (speaker parallel attention), + # att_idx is used to select attention module. In other cases, it is 0. + att_idx = min(strm_idx, len(self.att) - 1) + + # hlens should be list of list of integer + hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)] + + self.loss = None + # prepare input and output word sequences with sos/eos IDs + eos = ys[0].new([self.eos]) + sos = ys[0].new([self.sos]) + if self.replace_sos: + ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)] + else: + ys_in = [torch.cat([sos, y], dim=0) for y in ys] + ys_out = [torch.cat([y, eos], dim=0) for y in ys] + + # padding for ys with -1 + # pys: utt x olen + ys_in_pad = pad_list(ys_in, self.eos) + ys_out_pad = pad_list(ys_out, self.ignore_id) + + # get dim, length info + batch = ys_out_pad.size(0) + olength = ys_out_pad.size(1) + for idx in range(self.num_encs): + logging.info( + self.__class__.__name__ + + "Number of Encoder:{}; enc{}: input lengths: {}.".format( + self.num_encs, idx + 1, hlens[idx] + ) + ) + logging.info( + self.__class__.__name__ + + " output lengths: " + + str([y.size(0) for y in ys_out]) + ) + + # initialization + c_list = [self.zero_state(hs_pad[0])] + z_list = [self.zero_state(hs_pad[0])] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad[0])) + z_list.append(self.zero_state(hs_pad[0])) + z_all = [] + if self.num_encs == 1: + att_w = None + self.att[att_idx].reset() # reset pre-computation of h + else: + att_w_list = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * (self.num_encs) # atts + for idx in range(self.num_encs + 1): + self.att[idx].reset() # reset pre-computation of h in atts and han + + # pre-computation of embedding + eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim + + # loop for an output sequence + for i in six.moves.range(olength): + if self.num_encs == 1: + att_c, att_w = self.att[att_idx]( + 
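+                    # attention call: (encoder states, encoder lengths,
+                    # first-layer decoder state, previous attention weights)
+                    # -> (context vector att_c, updated attention weights att_w)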
hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w + ) + else: + for idx in range(self.num_encs): + att_c_list[idx], att_w_list[idx] = self.att[idx]( + hs_pad[idx], + hlens[idx], + self.dropout_dec[0](z_list[0]), + att_w_list[idx], + ) + hs_pad_han = torch.stack(att_c_list, dim=1) + hlens_han = [self.num_encs] * len(ys_in) + att_c, att_w_list[self.num_encs] = self.att[self.num_encs]( + hs_pad_han, + hlens_han, + self.dropout_dec[0](z_list[0]), + att_w_list[self.num_encs], + ) + if i > 0 and random.random() < self.sampling_probability: + logging.info(" scheduled sampling ") + z_out = self.output(z_all[-1]) + z_out = np.argmax(z_out.detach().cpu(), axis=1) + z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out))) + ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim) + else: + ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + if self.context_residual: + z_all.append( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) # utt x (zdim + hdim) + else: + z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim) + + z_all = torch.stack(z_all, dim=1).view(batch * olength, -1) + # compute loss + y_all = self.output(z_all) + if LooseVersion(torch.__version__) < LooseVersion("1.0"): + reduction_str = "elementwise_mean" + else: + reduction_str = "mean" + self.loss = F.cross_entropy( + y_all, + ys_out_pad.view(-1), + ignore_index=self.ignore_id, + reduction=reduction_str, + ) + # compute perplexity + ppl = math.exp(self.loss.item()) + # -1: eos, which is removed in the loss computation + self.loss *= np.mean([len(x) for x in ys_in]) - 1 + acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id) + logging.info("att loss:" + "".join(str(self.loss.item()).split("\n"))) + + # show predicted character sequence for debug + if self.verbose > 0 and self.char_list is not None: + ys_hat = y_all.view(batch, olength, -1) + ys_true = ys_out_pad + for (i, y_hat), y_true in zip( + enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy() + ): + if i == MAX_DECODER_OUTPUT: + break + idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1) + idx_true = y_true[y_true != self.ignore_id] + seq_hat = [self.char_list[int(idx)] for idx in idx_hat] + seq_true = [self.char_list[int(idx)] for idx in idx_true] + seq_hat = "".join(seq_hat) + seq_true = "".join(seq_true) + logging.info("groundtruth[%d]: " % i + seq_true) + logging.info("prediction [%d]: " % i + seq_hat) + + if self.labeldist is not None: + if self.vlabeldist is None: + self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist)) + loss_reg = -torch.sum( + (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0 + ) / len(ys_in) + self.loss = (1.0 - self.lsm_weight) * self.loss + self.lsm_weight * loss_reg + + return self.loss, acc, ppl + + def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0): + """beam search implementation + + :param torch.Tensor h: encoder hidden state (T, eprojs) + [in multi-encoder case, list of torch.Tensor, + [(T1, eprojs), (T2, eprojs), ...] ] + :param torch.Tensor lpz: ctc log softmax output (T, odim) + [in multi-encoder case, list of torch.Tensor, + [(T1, odim), (T2, odim), ...] 
] + :param Namespace recog_args: argument Namespace containing options + :param char_list: list of character strings + :param torch.nn.Module rnnlm: language module + :param int strm_idx: + stream index for speaker parallel attention in multi-speaker case + :return: N-best decoding results + :rtype: list of dicts + """ + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + h = [h] + lpz = [lpz] + if self.num_encs > 1 and lpz is None: + lpz = [lpz] * self.num_encs + + for idx in range(self.num_encs): + logging.info( + "Number of Encoder:{}; enc{}: input lengths: {}.".format( + self.num_encs, idx + 1, h[0].size(0) + ) + ) + att_idx = min(strm_idx, len(self.att) - 1) + # initialization + c_list = [self.zero_state(h[0].unsqueeze(0))] + z_list = [self.zero_state(h[0].unsqueeze(0))] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(h[0].unsqueeze(0))) + z_list.append(self.zero_state(h[0].unsqueeze(0))) + if self.num_encs == 1: + a = None + self.att[att_idx].reset() # reset pre-computation of h + else: + a = [None] * (self.num_encs + 1) # atts + han + att_w_list = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * (self.num_encs) # atts + for idx in range(self.num_encs + 1): + self.att[idx].reset() # reset pre-computation of h in atts and han + + # search parms + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = getattr(recog_args, "ctc_weight", False) # for NMT + + if lpz[0] is not None and self.num_encs > 1: + # weights-ctc, + # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss + weights_ctc_dec = recog_args.weights_ctc_dec / np.sum( + recog_args.weights_ctc_dec + ) # normalize + logging.info( + "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]) + ) + else: + weights_ctc_dec = [1.0] + + # preprate sos + if self.replace_sos and recog_args.tgt_lang: + y = char_list.index(recog_args.tgt_lang) + else: + y = self.sos + logging.info(" index: " + str(y)) + logging.info(" mark: " + char_list[y]) + vy = h[0].new_zeros(1).long() + + maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)]) + if recog_args.maxlenratio != 0: + # maxlen >= 1 + maxlen = max(1, int(recog_args.maxlenratio * maxlen)) + minlen = int(recog_args.minlenratio * maxlen) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialize hypothesis + if rnnlm: + hyp = { + "score": 0.0, + "yseq": [y], + "c_prev": c_list, + "z_prev": z_list, + "a_prev": a, + "rnnlm_prev": None, + } + else: + hyp = { + "score": 0.0, + "yseq": [y], + "c_prev": c_list, + "z_prev": z_list, + "a_prev": a, + } + if lpz[0] is not None: + ctc_prefix_score = [ + CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np) + for idx in range(self.num_encs) + ] + hyp["ctc_state_prev"] = [ + ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs) + ] + hyp["ctc_score_prev"] = [0.0] * self.num_encs + if ctc_weight != 1.0: + # pre-pruning based on attention scores + ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO)) + else: + ctc_beam = lpz[0].shape[-1] + hyps = [hyp] + ended_hyps = [] + + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + hyps_best_kept = [] + for hyp in hyps: + vy[0] = hyp["yseq"][i] + ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim + if self.num_encs == 1: + att_c, att_w = self.att[att_idx]( + h[0].unsqueeze(0), + [h[0].size(0)], + 
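+                        # decoder state and attention weights are taken from
+                        # this hypothesis, so each beam entry attends
+                        # independently of the others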
self.dropout_dec[0](hyp["z_prev"][0]), + hyp["a_prev"], + ) + else: + for idx in range(self.num_encs): + att_c_list[idx], att_w_list[idx] = self.att[idx]( + h[idx].unsqueeze(0), + [h[idx].size(0)], + self.dropout_dec[0](hyp["z_prev"][0]), + hyp["a_prev"][idx], + ) + h_han = torch.stack(att_c_list, dim=1) + att_c, att_w_list[self.num_encs] = self.att[self.num_encs]( + h_han, + [self.num_encs], + self.dropout_dec[0](hyp["z_prev"][0]), + hyp["a_prev"][self.num_encs], + ) + ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) + z_list, c_list = self.rnn_forward( + ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"] + ) + + # get nbest local scores and their ids + if self.context_residual: + logits = self.output( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) + else: + logits = self.output(self.dropout_dec[-1](z_list[-1])) + local_att_scores = F.log_softmax(logits, dim=1) + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy) + local_scores = ( + local_att_scores + recog_args.lm_weight * local_lm_scores + ) + else: + local_scores = local_att_scores + + if lpz[0] is not None: + local_best_scores, local_best_ids = torch.topk( + local_att_scores, ctc_beam, dim=1 + ) + ctc_scores, ctc_states = ( + [None] * self.num_encs, + [None] * self.num_encs, + ) + for idx in range(self.num_encs): + ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx]( + hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx] + ) + local_scores = (1.0 - ctc_weight) * local_att_scores[ + :, local_best_ids[0] + ] + if self.num_encs == 1: + local_scores += ctc_weight * torch.from_numpy( + ctc_scores[0] - hyp["ctc_score_prev"][0] + ) + else: + for idx in range(self.num_encs): + local_scores += ( + ctc_weight + * weights_ctc_dec[idx] + * torch.from_numpy( + ctc_scores[idx] - hyp["ctc_score_prev"][idx] + ) + ) + if rnnlm: + local_scores += ( + recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]] + ) + local_best_scores, joint_best_ids = torch.topk( + local_scores, beam, dim=1 + ) + local_best_ids = local_best_ids[:, joint_best_ids[0]] + else: + local_best_scores, local_best_ids = torch.topk( + local_scores, beam, dim=1 + ) + + for j in six.moves.range(beam): + new_hyp = {} + # [:] is needed! 
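+                    # (a shallow copy: z_list and c_list are updated in place
+                    # for the next hypothesis, so each new_hyp must keep its
+                    # own list of state references)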
+ new_hyp["z_prev"] = z_list[:] + new_hyp["c_prev"] = c_list[:] + if self.num_encs == 1: + new_hyp["a_prev"] = att_w[:] + else: + new_hyp["a_prev"] = [ + att_w_list[idx][:] for idx in range(self.num_encs + 1) + ] + new_hyp["score"] = hyp["score"] + local_best_scores[0, j] + new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"])) + new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"] + new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j]) + if rnnlm: + new_hyp["rnnlm_prev"] = rnnlm_state + if lpz[0] is not None: + new_hyp["ctc_state_prev"] = [ + ctc_states[idx][joint_best_ids[0, j]] + for idx in range(self.num_encs) + ] + new_hyp["ctc_score_prev"] = [ + ctc_scores[idx][joint_best_ids[0, j]] + for idx in range(self.num_encs) + ] + # will be (2 x beam) hyps at most + hyps_best_kept.append(new_hyp) + + hyps_best_kept = sorted( + hyps_best_kept, key=lambda x: x["score"], reverse=True + )[:beam] + + # sort and get nbest + hyps = hyps_best_kept + logging.debug("number of pruned hypotheses: " + str(len(hyps))) + logging.debug( + "best hypo: " + + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]) + ) + + # add eos in the final loop to avoid that there are no ended hyps + if i == maxlen - 1: + logging.info("adding in the last position in the loop") + for hyp in hyps: + hyp["yseq"].append(self.eos) + + # add ended hypotheses to a final list, + # and removed them from current hypotheses + # (this will be a problem, number of hyps < beam) + remained_hyps = [] + for hyp in hyps: + if hyp["yseq"][-1] == self.eos: + # only store the sequence that has more than minlen outputs + # also add penalty + if len(hyp["yseq"]) > minlen: + hyp["score"] += (i + 1) * penalty + if rnnlm: # Word LM needs to add final score + hyp["score"] += recog_args.lm_weight * rnnlm.final( + hyp["rnnlm_prev"] + ) + ended_hyps.append(hyp) + else: + remained_hyps.append(hyp) + + # end detection + if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: + logging.info("end detected at %d", i) + break + + hyps = remained_hyps + if len(hyps) > 0: + logging.debug("remaining hypotheses: " + str(len(hyps))) + else: + logging.info("no hypothesis. Finish decoding.") + break + + for hyp in hyps: + logging.debug( + "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]) + ) + + logging.debug("number of ended hypotheses: " + str(len(ended_hyps))) + + nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps), recog_args.nbest) + ] + + # check number of hypotheses + if len(nbest_hyps) == 0: + logging.warning( + "there is no N-best results, " + "perform recognition again with smaller minlenratio." 
+ ) + # should copy because Namespace will be overwritten globally + recog_args = Namespace(**vars(recog_args)) + recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) + if self.num_encs == 1: + return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm) + else: + return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm) + + logging.info("total log probability: " + str(nbest_hyps[0]["score"])) + logging.info( + "normalized log probability: " + + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"])) + ) + + # remove sos + return nbest_hyps + + def recognize_beam_batch( + self, + h, + hlens, + lpz, + recog_args, + char_list, + rnnlm=None, + normalize_score=True, + strm_idx=0, + lang_ids=None, + ): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + h = [h] + hlens = [hlens] + lpz = [lpz] + if self.num_encs > 1 and lpz is None: + lpz = [lpz] * self.num_encs + + att_idx = min(strm_idx, len(self.att) - 1) + for idx in range(self.num_encs): + logging.info( + "Number of Encoder:{}; enc{}: input lengths: {}.".format( + self.num_encs, idx + 1, h[idx].size(1) + ) + ) + h[idx] = mask_by_length(h[idx], hlens[idx], 0.0) + + # search params + batch = len(hlens[0]) + beam = recog_args.beam_size + penalty = recog_args.penalty + ctc_weight = getattr(recog_args, "ctc_weight", 0) # for NMT + att_weight = 1.0 - ctc_weight + ctc_margin = getattr( + recog_args, "ctc_window_margin", 0 + ) # use getattr to keep compatibility + # weights-ctc, + # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss + if lpz[0] is not None and self.num_encs > 1: + weights_ctc_dec = recog_args.weights_ctc_dec / np.sum( + recog_args.weights_ctc_dec + ) # normalize + logging.info( + "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]) + ) + else: + weights_ctc_dec = [1.0] + + n_bb = batch * beam + pad_b = to_device(h[0], torch.arange(batch) * beam).view(-1, 1) + + max_hlen = np.amin([max(hlens[idx]) for idx in range(self.num_encs)]) + if recog_args.maxlenratio == 0: + maxlen = max_hlen + else: + maxlen = max(1, int(recog_args.maxlenratio * max_hlen)) + minlen = int(recog_args.minlenratio * max_hlen) + logging.info("max output length: " + str(maxlen)) + logging.info("min output length: " + str(minlen)) + + # initialization + c_prev = [ + to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) + ] + z_prev = [ + to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) + ] + c_list = [ + to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) + ] + z_list = [ + to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers) + ] + vscores = to_device(h[0], torch.zeros(batch, beam)) + + rnnlm_state = None + if self.num_encs == 1: + a_prev = [None] + att_w_list, ctc_scorer, ctc_state = [None], [None], [None] + self.att[att_idx].reset() # reset pre-computation of h + else: + a_prev = [None] * (self.num_encs + 1) # atts + han + att_w_list = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * (self.num_encs) # atts + ctc_scorer, ctc_state = [None] * (self.num_encs), [None] * (self.num_encs) + for idx in range(self.num_encs + 1): + self.att[idx].reset() # reset pre-computation of h in atts and han + + if self.replace_sos and recog_args.tgt_lang: + logging.info(" index: " + str(char_list.index(recog_args.tgt_lang))) + logging.info(" mark: " + recog_args.tgt_lang) + yseq = [ + 
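+                # in multilingual mode every one of the batch * beam hypotheses
+                # starts from the target-language id instead of the <sos> id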
[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb) + ] + elif lang_ids is not None: + # NOTE: used for evaluation during training + yseq = [ + [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb) + ] + else: + logging.info(" index: " + str(self.sos)) + logging.info(" mark: " + char_list[self.sos]) + yseq = [[self.sos] for _ in six.moves.range(n_bb)] + + accum_odim_ids = [self.sos for _ in six.moves.range(n_bb)] + stop_search = [False for _ in six.moves.range(batch)] + nbest_hyps = [[] for _ in six.moves.range(batch)] + ended_hyps = [[] for _ in range(batch)] + + exp_hlens = [ + hlens[idx].repeat(beam).view(beam, batch).transpose(0, 1).contiguous() + for idx in range(self.num_encs) + ] + exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)] + exp_h = [ + h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() + for idx in range(self.num_encs) + ] + exp_h = [ + exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2]) + for idx in range(self.num_encs) + ] + + if lpz[0] is not None: + scoring_num = min( + int(beam * CTC_SCORING_RATIO) + if att_weight > 0.0 and not lpz[0].is_cuda + else 0, + lpz[0].size(-1), + ) + ctc_scorer = [ + CTCPrefixScoreTH( + lpz[idx], + hlens[idx], + 0, + self.eos, + margin=ctc_margin, + ) + for idx in range(self.num_encs) + ] + + for i in six.moves.range(maxlen): + logging.debug("position " + str(i)) + + vy = to_device(h[0], torch.LongTensor(self._get_last_yseq(yseq))) + ey = self.dropout_emb(self.embed(vy)) + if self.num_encs == 1: + att_c, att_w = self.att[att_idx]( + exp_h[0], exp_hlens[0], self.dropout_dec[0](z_prev[0]), a_prev[0] + ) + att_w_list = [att_w] + else: + for idx in range(self.num_encs): + att_c_list[idx], att_w_list[idx] = self.att[idx]( + exp_h[idx], + exp_hlens[idx], + self.dropout_dec[0](z_prev[0]), + a_prev[idx], + ) + exp_h_han = torch.stack(att_c_list, dim=1) + att_c, att_w_list[self.num_encs] = self.att[self.num_encs]( + exp_h_han, + [self.num_encs] * n_bb, + self.dropout_dec[0](z_prev[0]), + a_prev[self.num_encs], + ) + ey = torch.cat((ey, att_c), dim=1) + + # attention decoder + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev) + if self.context_residual: + logits = self.output( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) + else: + logits = self.output(self.dropout_dec[-1](z_list[-1])) + local_scores = att_weight * F.log_softmax(logits, dim=1) + + # rnnlm + if rnnlm: + rnnlm_state, local_lm_scores = rnnlm.buff_predict(rnnlm_state, vy, n_bb) + local_scores = local_scores + recog_args.lm_weight * local_lm_scores + + # ctc + if ctc_scorer[0]: + local_scores[:, 0] = self.logzero # avoid choosing blank + part_ids = ( + torch.topk(local_scores, scoring_num, dim=-1)[1] + if scoring_num > 0 + else None + ) + for idx in range(self.num_encs): + att_w = att_w_list[idx] + att_w_ = att_w if isinstance(att_w, torch.Tensor) else att_w[0] + local_ctc_scores, ctc_state[idx] = ctc_scorer[idx]( + yseq, ctc_state[idx], part_ids, att_w_ + ) + local_scores = ( + local_scores + + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores + ) + + local_scores = local_scores.view(batch, beam, self.odim) + if i == 0: + local_scores[:, 1:, :] = self.logzero + + # accumulate scores + eos_vscores = local_scores[:, :, self.eos] + vscores + vscores = vscores.view(batch, beam, 1).repeat(1, 1, self.odim) + vscores[:, :, self.eos] = self.logzero + vscores = (vscores + local_scores).view(batch, -1) + + # global pruning + accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 
1) + accum_odim_ids = ( + torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist() + ) + accum_padded_beam_ids = ( + (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist() + ) + + y_prev = yseq[:][:] + yseq = self._index_select_list(yseq, accum_padded_beam_ids) + yseq = self._append_ids(yseq, accum_odim_ids) + vscores = accum_best_scores + vidx = to_device(h[0], torch.LongTensor(accum_padded_beam_ids)) + + a_prev = [] + num_atts = self.num_encs if self.num_encs == 1 else self.num_encs + 1 + for idx in range(num_atts): + if isinstance(att_w_list[idx], torch.Tensor): + _a_prev = torch.index_select( + att_w_list[idx].view(n_bb, *att_w_list[idx].shape[1:]), 0, vidx + ) + elif isinstance(att_w_list[idx], list): + # handle the case of multi-head attention + _a_prev = [ + torch.index_select(att_w_one.view(n_bb, -1), 0, vidx) + for att_w_one in att_w_list[idx] + ] + else: + # handle the case of location_recurrent when return is a tuple + _a_prev_ = torch.index_select( + att_w_list[idx][0].view(n_bb, -1), 0, vidx + ) + _h_prev_ = torch.index_select( + att_w_list[idx][1][0].view(n_bb, -1), 0, vidx + ) + _c_prev_ = torch.index_select( + att_w_list[idx][1][1].view(n_bb, -1), 0, vidx + ) + _a_prev = (_a_prev_, (_h_prev_, _c_prev_)) + a_prev.append(_a_prev) + z_prev = [ + torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) + for li in range(self.dlayers) + ] + c_prev = [ + torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) + for li in range(self.dlayers) + ] + + # pick ended hyps + if i >= minlen: + k = 0 + penalty_i = (i + 1) * penalty + thr = accum_best_scores[:, -1] + for samp_i in six.moves.range(batch): + if stop_search[samp_i]: + k = k + beam + continue + for beam_j in six.moves.range(beam): + _vscore = None + if eos_vscores[samp_i, beam_j] > thr[samp_i]: + yk = y_prev[k][:] + if len(yk) <= min( + hlens[idx][samp_i] for idx in range(self.num_encs) + ): + _vscore = eos_vscores[samp_i][beam_j] + penalty_i + elif i == maxlen - 1: + yk = yseq[k][:] + _vscore = vscores[samp_i][beam_j] + penalty_i + if _vscore: + yk.append(self.eos) + if rnnlm: + _vscore += recog_args.lm_weight * rnnlm.final( + rnnlm_state, index=k + ) + _score = _vscore.data.cpu().numpy() + ended_hyps[samp_i].append( + {"yseq": yk, "vscore": _vscore, "score": _score} + ) + k = k + 1 + + # end detection + stop_search = [ + stop_search[samp_i] or end_detect(ended_hyps[samp_i], i) + for samp_i in six.moves.range(batch) + ] + stop_search_summary = list(set(stop_search)) + if len(stop_search_summary) == 1 and stop_search_summary[0]: + break + + if rnnlm: + rnnlm_state = self._index_select_lm_state(rnnlm_state, 0, vidx) + if ctc_scorer[0]: + for idx in range(self.num_encs): + ctc_state[idx] = ctc_scorer[idx].index_select_state( + ctc_state[idx], accum_best_ids + ) + + torch.cuda.empty_cache() + + dummy_hyps = [ + {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])} + ] + ended_hyps = [ + ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps + for samp_i in six.moves.range(batch) + ] + if normalize_score: + for samp_i in six.moves.range(batch): + for x in ended_hyps[samp_i]: + x["score"] /= len(x["yseq"]) + + nbest_hyps = [ + sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[ + : min(len(ended_hyps[samp_i]), recog_args.nbest) + ] + for samp_i in six.moves.range(batch) + ] + + return nbest_hyps + + def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None): + """Calculate all of attentions + + :param torch.Tensor hs_pad: batch of padded hidden 
state sequences + (B, Tmax, D) + in multi-encoder case, list of torch.Tensor, + [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ] + :param torch.Tensor hlen: batch of lengths of hidden state sequences (B) + [in multi-encoder case, list of torch.Tensor, + [(B), (B), ..., ] + :param torch.Tensor ys_pad: + batch of padded character id sequence tensor (B, Lmax) + :param int strm_idx: + stream index for parallel speaker attention in multi-speaker case + :param torch.Tensor lang_ids: batch of target language id tensor (B, 1) + :return: attention weights with the following shape, + 1) multi-head case => attention weights (B, H, Lmax, Tmax), + 2) multi-encoder case => + [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)] + 3) other case => attention weights (B, Lmax, Tmax). + :rtype: float ndarray + """ + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + hs_pad = [hs_pad] + hlen = [hlen] + + # TODO(kan-bayashi): need to make more smart way + ys = [y[y != self.ignore_id] for y in ys_pad] # parse padded ys + att_idx = min(strm_idx, len(self.att) - 1) + + # hlen should be list of list of integer + hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)] + + self.loss = None + # prepare input and output word sequences with sos/eos IDs + eos = ys[0].new([self.eos]) + sos = ys[0].new([self.sos]) + if self.replace_sos: + ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)] + else: + ys_in = [torch.cat([sos, y], dim=0) for y in ys] + ys_out = [torch.cat([y, eos], dim=0) for y in ys] + + # padding for ys with -1 + # pys: utt x olen + ys_in_pad = pad_list(ys_in, self.eos) + ys_out_pad = pad_list(ys_out, self.ignore_id) + + # get length info + olength = ys_out_pad.size(1) + + # initialization + c_list = [self.zero_state(hs_pad[0])] + z_list = [self.zero_state(hs_pad[0])] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad[0])) + z_list.append(self.zero_state(hs_pad[0])) + att_ws = [] + if self.num_encs == 1: + att_w = None + self.att[att_idx].reset() # reset pre-computation of h + else: + att_w_list = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * (self.num_encs) # atts + for idx in range(self.num_encs + 1): + self.att[idx].reset() # reset pre-computation of h in atts and han + + # pre-computation of embedding + eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim + + # loop for an output sequence + for i in six.moves.range(olength): + if self.num_encs == 1: + att_c, att_w = self.att[att_idx]( + hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w + ) + att_ws.append(att_w) + else: + for idx in range(self.num_encs): + att_c_list[idx], att_w_list[idx] = self.att[idx]( + hs_pad[idx], + hlen[idx], + self.dropout_dec[0](z_list[0]), + att_w_list[idx], + ) + hs_pad_han = torch.stack(att_c_list, dim=1) + hlen_han = [self.num_encs] * len(ys_in) + att_c, att_w_list[self.num_encs] = self.att[self.num_encs]( + hs_pad_han, + hlen_han, + self.dropout_dec[0](z_list[0]), + att_w_list[self.num_encs], + ) + att_ws.append(att_w_list.copy()) + ey = torch.cat((eys[:, i, :], att_c), dim=1) # utt x (zdim + hdim) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + + if self.num_encs == 1: + # convert to numpy array with the shape (B, Lmax, Tmax) + att_ws = att_to_numpy(att_ws, self.att[att_idx]) + else: + _att_ws = [] + for idx, ws in enumerate(zip(*att_ws)): + ws = att_to_numpy(ws, self.att[idx]) + _att_ws.append(ws) + att_ws 
= _att_ws + return att_ws + + @staticmethod + def _get_last_yseq(exp_yseq): + last = [] + for y_seq in exp_yseq: + last.append(y_seq[-1]) + return last + + @staticmethod + def _append_ids(yseq, ids): + if isinstance(ids, list): + for i, j in enumerate(ids): + yseq[i].append(j) + else: + for i in range(len(yseq)): + yseq[i].append(ids) + return yseq + + @staticmethod + def _index_select_list(yseq, lst): + new_yseq = [] + for i in lst: + new_yseq.append(yseq[i][:]) + return new_yseq + + @staticmethod + def _index_select_lm_state(rnnlm_state, dim, vidx): + if isinstance(rnnlm_state, dict): + new_state = {} + for k, v in rnnlm_state.items(): + new_state[k] = [torch.index_select(vi, dim, vidx) for vi in v] + elif isinstance(rnnlm_state, list): + new_state = [] + for i in vidx: + new_state.append(rnnlm_state[int(i)][:]) + return new_state + + # scorer interface methods + def init_state(self, x): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + x = [x] + + c_list = [self.zero_state(x[0].unsqueeze(0))] + z_list = [self.zero_state(x[0].unsqueeze(0))] + for _ in six.moves.range(1, self.dlayers): + c_list.append(self.zero_state(x[0].unsqueeze(0))) + z_list.append(self.zero_state(x[0].unsqueeze(0))) + # TODO(karita): support strm_index for `asr_mix` + strm_index = 0 + att_idx = min(strm_index, len(self.att) - 1) + if self.num_encs == 1: + a = None + self.att[att_idx].reset() # reset pre-computation of h + else: + a = [None] * (self.num_encs + 1) # atts + han + for idx in range(self.num_encs + 1): + self.att[idx].reset() # reset pre-computation of h in atts and han + return dict( + c_prev=c_list[:], + z_prev=z_list[:], + a_prev=a, + workspace=(att_idx, z_list, c_list), + ) + + def score(self, yseq, state, x): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + x = [x] + + att_idx, z_list, c_list = state["workspace"] + vy = yseq[-1].unsqueeze(0) + ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim + if self.num_encs == 1: + att_c, att_w = self.att[att_idx]( + x[0].unsqueeze(0), + [x[0].size(0)], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"], + ) + else: + att_w = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * (self.num_encs) # atts + for idx in range(self.num_encs): + att_c_list[idx], att_w[idx] = self.att[idx]( + x[idx].unsqueeze(0), + [x[idx].size(0)], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"][idx], + ) + h_han = torch.stack(att_c_list, dim=1) + att_c, att_w[self.num_encs] = self.att[self.num_encs]( + h_han, + [self.num_encs], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"][self.num_encs], + ) + ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) + z_list, c_list = self.rnn_forward( + ey, z_list, c_list, state["z_prev"], state["c_prev"] + ) + if self.context_residual: + logits = self.output( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) + else: + logits = self.output(self.dropout_dec[-1](z_list[-1])) + logp = F.log_softmax(logits, dim=1).squeeze(0) + return ( + logp, + dict( + c_prev=c_list[:], + z_prev=z_list[:], + a_prev=att_w, + workspace=(att_idx, z_list, c_list), + ), + ) + + +def decoder_for(args, odim, sos, eos, att, labeldist): + return Decoder( + args.eprojs, + odim, + args.dtype, + args.dlayers, + args.dunits, + sos, + eos, + att, + args.verbose, + args.char_list, + labeldist, + args.lsm_weight, + 
args.sampling_probability, + args.dropout_rate_decoder, + getattr(args, "context_residual", False), # use getattr to keep compatibility + getattr(args, "replace_sos", False), # use getattr to keep compatibility + getattr(args, "num_encs", 1), + ) # use getattr to keep compatibility diff --git a/espnet/nets/pytorch_backend/rnn/encoders.py b/espnet/nets/pytorch_backend/rnn/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..f01acd5a6a452a18c91bcab48c5a1944fe057a72 --- /dev/null +++ b/espnet/nets/pytorch_backend/rnn/encoders.py @@ -0,0 +1,372 @@ +import logging +import six + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from espnet.nets.e2e_asr_common import get_vgg2l_odim +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import to_device + + +class RNNP(torch.nn.Module): + """RNN with projection layer module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of projection units + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"): + super(RNNP, self).__init__() + bidir = typ[0] == "b" + for i in six.moves.range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = hdim + + RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU + rnn = RNN( + inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True + ) + + setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) + + # bottleneck layer to merge + if bidir: + setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim)) + else: + setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim)) + + self.elayers = elayers + self.cdim = cdim + self.subsample = subsample + self.typ = typ + self.bidir = bidir + self.dropout = dropout + + def forward(self, xs_pad, ilens, prev_state=None): + """RNNP forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state sequences (B, Tmax, hdim) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + elayer_states = [] + for layer in six.moves.range(self.elayers): + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) + rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) + rnn.flatten_parameters() + if prev_state is not None and rnn.bidirectional: + prev_state = reset_backward_rnn_state(prev_state) + ys, states = rnn( + xs_pack, hx=None if prev_state is None else prev_state[layer] + ) + elayer_states.append(states) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + sub = self.subsample[layer + 1] + if sub > 1: + ys_pad = ys_pad[:, ::sub] + ilens = torch.tensor([int(i + 1) // sub for i in ilens]) + # (sum _utt frame_utt) x dim + projection_layer = getattr(self, "bt%d" % layer) + projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2))) + xs_pad = 
projected.view(ys_pad.size(0), ys_pad.size(1), -1) + if layer < self.elayers - 1: + xs_pad = torch.tanh(F.dropout(xs_pad, p=self.dropout)) + + return xs_pad, ilens, elayer_states # x: utt list of frame x dim + + +class RNN(torch.nn.Module): + """RNN module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of final projection units + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"): + super(RNN, self).__init__() + bidir = typ[0] == "b" + self.nbrnn = ( + torch.nn.LSTM( + idim, + cdim, + elayers, + batch_first=True, + dropout=dropout, + bidirectional=bidir, + ) + if "lstm" in typ + else torch.nn.GRU( + idim, + cdim, + elayers, + batch_first=True, + dropout=dropout, + bidirectional=bidir, + ) + ) + if bidir: + self.l_last = torch.nn.Linear(cdim * 2, hdim) + else: + self.l_last = torch.nn.Linear(cdim, hdim) + self.typ = typ + + def forward(self, xs_pad, ilens, prev_state=None): + """RNN forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) + self.nbrnn.flatten_parameters() + if prev_state is not None and self.nbrnn.bidirectional: + # We assume that when previous state is passed, + # it means that we're streaming the input + # and therefore cannot propagate backward BRNN state + # (otherwise it goes in the wrong direction) + prev_state = reset_backward_rnn_state(prev_state) + ys, states = self.nbrnn(xs_pack, hx=prev_state) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + # (sum _utt frame_utt) x dim + projected = torch.tanh( + self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2))) + ) + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + return xs_pad, ilens, states # x: utt list of frame x dim + + +def reset_backward_rnn_state(states): + """Sets backward BRNN states to zeroes + + Useful in processing of sliding windows over the inputs + """ + if isinstance(states, (list, tuple)): + for state in states: + state[1::2] = 0.0 + else: + states[1::2] = 0.0 + return states + + +class VGG2L(torch.nn.Module): + """VGG-like module + + :param int in_channel: number of input channels + """ + + def __init__(self, in_channel=1): + super(VGG2L, self).__init__() + # CNN layer (VGG motivated) + self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1) + self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1) + self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1) + self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1) + + self.in_channel = in_channel + + def forward(self, xs_pad, ilens, **kwargs): + """VGG2L forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) + :rtype: torch.Tensor + """ + 
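+        # the (B, Tmax, D) features are reshaped to (B, in_channel, Tmax,
+        # D // in_channel), passed through two VGG-style conv blocks, each
+        # followed by 2x2 max-pooling, so the time and feature axes are both
+        # downsampled by a factor of 4; ilens is reduced accordingly by two
+        # ceil-divisions by 2, and the channel and feature axes are finally
+        # flattened back into a single feature dimension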
logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + + # x: utt x frame x dim + # xs_pad = F.pad_sequence(xs_pad) + + # x: utt x 1 (input channel num) x frame x dim + xs_pad = xs_pad.view( + xs_pad.size(0), + xs_pad.size(1), + self.in_channel, + xs_pad.size(2) // self.in_channel, + ).transpose(1, 2) + + # NOTE: max_pool1d ? + xs_pad = F.relu(self.conv1_1(xs_pad)) + xs_pad = F.relu(self.conv1_2(xs_pad)) + xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) + + xs_pad = F.relu(self.conv2_1(xs_pad)) + xs_pad = F.relu(self.conv2_2(xs_pad)) + xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) + if torch.is_tensor(ilens): + ilens = ilens.cpu().numpy() + else: + ilens = np.array(ilens, dtype=np.float32) + ilens = np.array(np.ceil(ilens / 2), dtype=np.int64) + ilens = np.array( + np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64 + ).tolist() + + # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim) + xs_pad = xs_pad.transpose(1, 2) + xs_pad = xs_pad.contiguous().view( + xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3) + ) + return xs_pad, ilens, None # no state in this layer + + +class Encoder(torch.nn.Module): + """Encoder module + + :param str etype: type of encoder network + :param int idim: number of dimensions of encoder network + :param int elayers: number of layers of encoder network + :param int eunits: number of lstm units of encoder network + :param int eprojs: number of projection units of encoder network + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param int in_channel: number of input channels + """ + + def __init__( + self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1 + ): + super(Encoder, self).__init__() + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ["lstm", "gru", "blstm", "bgru"]: + logging.error("Error: need to specify an appropriate encoder architecture") + + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNNP( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + subsample, + dropout, + typ=typ, + ), + ] + ) + logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder") + else: + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNN( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + dropout, + typ=typ, + ), + ] + ) + logging.info("Use CNN-VGG + " + typ.upper() + " for encoder") + self.conv_subsampling_factor = 4 + else: + if etype[-1] == "p": + self.enc = torch.nn.ModuleList( + [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)] + ) + logging.info(typ.upper() + " with every-layer projection for encoder") + else: + self.enc = torch.nn.ModuleList( + [RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)] + ) + logging.info(typ.upper() + " without projection for encoder") + self.conv_subsampling_factor = 1 + + def forward(self, xs_pad, ilens, prev_states=None): + """Encoder forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...) 
+ :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) + current_states.append(states) + + # make mask to remove bias value in padded part + mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1)) + + return xs_pad.masked_fill(mask, 0.0), ilens, current_states + + +def encoder_for(args, idim, subsample): + """Instantiates an encoder module given the program arguments + + :param Namespace args: The arguments + :param int or List of integer idim: dimension of input, e.g. 83, or + List of dimensions of inputs, e.g. [83,83] + :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or + List of subsample factors of each encoder. + e.g. [[1,2,2,1,1], [1,2,2,1,1]] + :rtype torch.nn.Module + :return: The encoder module + """ + num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility + if num_encs == 1: + # compatible with single encoder asr mode + return Encoder( + args.etype, + idim, + args.elayers, + args.eunits, + args.eprojs, + subsample, + args.dropout_rate, + ) + elif num_encs >= 1: + enc_list = torch.nn.ModuleList() + for idx in range(num_encs): + enc = Encoder( + args.etype[idx], + idim[idx], + args.elayers[idx], + args.eunits[idx], + args.eprojs, + subsample[idx], + args.dropout_rate[idx], + ) + enc_list.append(enc) + return enc_list + else: + raise ValueError( + "Number of encoders needs to be more than one. {}".format(num_encs) + ) diff --git a/espnet/nets/pytorch_backend/streaming/__init__.py b/espnet/nets/pytorch_backend/streaming/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/streaming/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/streaming/segment.py b/espnet/nets/pytorch_backend/streaming/segment.py new file mode 100644 index 0000000000000000000000000000000000000000..45a2758c9d3c86d17cdd55884952e88d0b21240e --- /dev/null +++ b/espnet/nets/pytorch_backend/streaming/segment.py @@ -0,0 +1,129 @@ +import numpy as np +import torch + + +class SegmentStreamingE2E(object): + """SegmentStreamingE2E constructor. 
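+
+    Splits an input stream into utterance-like segments on the fly: the
+    unidirectional encoder is run incrementally, greedy CTC output is used to
+    detect the onset of speech, and once a long enough run of CTC blanks is
+    observed the buffered segment is passed to the decoder for recognition.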
+ + :param E2E e2e: E2E ASR object + :param recog_args: arguments for "recognize" method of E2E + """ + + def __init__(self, e2e, recog_args, rnnlm=None): + self._e2e = e2e + self._recog_args = recog_args + self._char_list = e2e.char_list + self._rnnlm = rnnlm + + self._e2e.eval() + + self._blank_idx_in_char_list = -1 + for idx in range(len(self._char_list)): + if self._char_list[idx] == self._e2e.blank: + self._blank_idx_in_char_list = idx + break + + self._subsampling_factor = np.prod(e2e.subsample) + self._activates = 0 + self._blank_dur = 0 + + self._previous_input = [] + self._previous_encoder_recurrent_state = None + self._encoder_states = [] + self._ctc_posteriors = [] + + assert ( + self._recog_args.batchsize <= 1 + ), "SegmentStreamingE2E works only with batch size <= 1" + assert ( + "b" not in self._e2e.etype + ), "SegmentStreamingE2E works only with uni-directional encoders" + + def accept_input(self, x): + """Call this method each time a new batch of input is available.""" + + self._previous_input.extend(x) + h, ilen = self._e2e.subsample_frames(x) + + # Run encoder and apply greedy search on CTC softmax output + h, _, self._previous_encoder_recurrent_state = self._e2e.enc( + h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state + ) + z = self._e2e.ctc.argmax(h).squeeze(0) + + if self._activates == 0 and z[0] != self._blank_idx_in_char_list: + self._activates = 1 + + # Rerun encoder with zero state at onset of detection + tail_len = self._subsampling_factor * ( + self._recog_args.streaming_onset_margin + 1 + ) + h, ilen = self._e2e.subsample_frames( + np.reshape( + self._previous_input[-tail_len:], [-1, len(self._previous_input[0])] + ) + ) + h, _, self._previous_encoder_recurrent_state = self._e2e.enc( + h.unsqueeze(0), ilen, None + ) + + hyp = None + if self._activates == 1: + self._encoder_states.extend(h.squeeze(0)) + self._ctc_posteriors.extend(self._e2e.ctc.log_softmax(h).squeeze(0)) + + if z[0] == self._blank_idx_in_char_list: + self._blank_dur += 1 + else: + self._blank_dur = 0 + + if self._blank_dur >= self._recog_args.streaming_min_blank_dur: + seg_len = ( + len(self._encoder_states) + - self._blank_dur + + self._recog_args.streaming_offset_margin + ) + if seg_len > 0: + # Run decoder with a detected segment + h = torch.cat(self._encoder_states[:seg_len], dim=0).view( + -1, self._encoder_states[0].size(0) + ) + if self._recog_args.ctc_weight > 0.0: + lpz = torch.cat(self._ctc_posteriors[:seg_len], dim=0).view( + -1, self._ctc_posteriors[0].size(0) + ) + if self._recog_args.batchsize > 0: + lpz = lpz.unsqueeze(0) + normalize_score = False + else: + lpz = None + normalize_score = True + + if self._recog_args.batchsize == 0: + hyp = self._e2e.dec.recognize_beam( + h, lpz, self._recog_args, self._char_list, self._rnnlm + ) + else: + hlens = torch.tensor([h.shape[0]]) + hyp = self._e2e.dec.recognize_beam_batch( + h.unsqueeze(0), + hlens, + lpz, + self._recog_args, + self._char_list, + self._rnnlm, + normalize_score=normalize_score, + )[0] + + self._activates = 0 + self._blank_dur = 0 + + tail_len = ( + self._subsampling_factor + * self._recog_args.streaming_onset_margin + ) + self._previous_input = self._previous_input[-tail_len:] + self._encoder_states = [] + self._ctc_posteriors = [] + + return hyp diff --git a/espnet/nets/pytorch_backend/streaming/window.py b/espnet/nets/pytorch_backend/streaming/window.py new file mode 100644 index 0000000000000000000000000000000000000000..5565c232eb6feebfd0595fa46d07c0ecfc32c3dc --- /dev/null +++ 
b/espnet/nets/pytorch_backend/streaming/window.py @@ -0,0 +1,81 @@ +import torch + + +# TODO(pzelasko): Currently allows half-streaming only; +# needs streaming attention decoder implementation +class WindowStreamingE2E(object): + """WindowStreamingE2E constructor. + + :param E2E e2e: E2E ASR object + :param recog_args: arguments for "recognize" method of E2E + """ + + def __init__(self, e2e, recog_args, rnnlm=None): + self._e2e = e2e + self._recog_args = recog_args + self._char_list = e2e.char_list + self._rnnlm = rnnlm + + self._e2e.eval() + + self._offset = 0 + self._previous_encoder_recurrent_state = None + self._encoder_states = [] + self._ctc_posteriors = [] + self._last_recognition = None + + assert ( + self._recog_args.ctc_weight > 0.0 + ), "WindowStreamingE2E works only with combined CTC and attention decoders." + + def accept_input(self, x): + """Call this method each time a new batch of input is available.""" + + h, ilen = self._e2e.subsample_frames(x) + + # Streaming encoder + h, _, self._previous_encoder_recurrent_state = self._e2e.enc( + h.unsqueeze(0), ilen, self._previous_encoder_recurrent_state + ) + self._encoder_states.append(h.squeeze(0)) + + # CTC posteriors for the incoming audio + self._ctc_posteriors.append(self._e2e.ctc.log_softmax(h).squeeze(0)) + + def _input_window_for_decoder(self, use_all=False): + if use_all: + return ( + torch.cat(self._encoder_states, dim=0), + torch.cat(self._ctc_posteriors, dim=0), + ) + + def select_unprocessed_windows(window_tensors): + last_offset = self._offset + offset_traversed = 0 + selected_windows = [] + for es in window_tensors: + if offset_traversed > last_offset: + selected_windows.append(es) + continue + offset_traversed += es.size(1) + return torch.cat(selected_windows, dim=0) + + return ( + select_unprocessed_windows(self._encoder_states), + select_unprocessed_windows(self._ctc_posteriors), + ) + + def decode_with_attention_offline(self): + """Run the attention decoder offline. + + Works even if the previous layers (encoder and CTC decoder) were + being run in the online mode. + This method should be run after all the audio has been consumed. + This is used mostly to compare the results between offline + and online implementation of the previous layers. 
+ """ + h, lpz = self._input_window_for_decoder(use_all=True) + + return self._e2e.dec.recognize_beam( + h, lpz, self._recog_args, self._char_list, self._rnnlm + ) diff --git a/espnet/nets/pytorch_backend/tacotron2/__init__.py b/espnet/nets/pytorch_backend/tacotron2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/tacotron2/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/tacotron2/cbhg.py b/espnet/nets/pytorch_backend/tacotron2/cbhg.py new file mode 100644 index 0000000000000000000000000000000000000000..c869e0f8c63c89e9e51fb130632b625ab906c3b5 --- /dev/null +++ b/espnet/nets/pytorch_backend/tacotron2/cbhg.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""CBHG related modules.""" + +import torch +import torch.nn.functional as F + +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask + + +class CBHGLoss(torch.nn.Module): + """Loss function module for CBHG.""" + + def __init__(self, use_masking=True): + """Initialize CBHG loss module. + + Args: + use_masking (bool): Whether to mask padded part in loss calculation. + + """ + super(CBHGLoss, self).__init__() + self.use_masking = use_masking + + def forward(self, cbhg_outs, spcs, olens): + """Calculate forward propagation. + + Args: + cbhg_outs (Tensor): Batch of CBHG outputs (B, Lmax, spc_dim). + spcs (Tensor): Batch of groundtruth of spectrogram (B, Lmax, spc_dim). + olens (LongTensor): Batch of the lengths of each sequence (B,). + + Returns: + Tensor: L1 loss value + Tensor: Mean square error loss value. + + """ + # perform masking for padded values + if self.use_masking: + mask = make_non_pad_mask(olens).unsqueeze(-1).to(spcs.device) + spcs = spcs.masked_select(mask) + cbhg_outs = cbhg_outs.masked_select(mask) + + # calculate loss + cbhg_l1_loss = F.l1_loss(cbhg_outs, spcs) + cbhg_mse_loss = F.mse_loss(cbhg_outs, spcs) + + return cbhg_l1_loss, cbhg_mse_loss + + +class CBHG(torch.nn.Module): + """CBHG module to convert log Mel-filterbanks to linear spectrogram. + + This is a module of CBHG introduced + in `Tacotron: Towards End-to-End Speech Synthesis`_. + The CBHG converts the sequence of log Mel-filterbanks into linear spectrogram. + + .. _`Tacotron: Towards End-to-End Speech Synthesis`: + https://arxiv.org/abs/1703.10135 + + """ + + def __init__( + self, + idim, + odim, + conv_bank_layers=8, + conv_bank_chans=128, + conv_proj_filts=3, + conv_proj_chans=256, + highway_layers=4, + highway_units=128, + gru_units=256, + ): + """Initialize CBHG module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + conv_bank_layers (int, optional): The number of convolution bank layers. + conv_bank_chans (int, optional): The number of channels in convolution bank. + conv_proj_filts (int, optional): + Kernel size of convolutional projection layer. + conv_proj_chans (int, optional): + The number of channels in convolutional projection layer. + highway_layers (int, optional): The number of highway network layers. + highway_units (int, optional): The number of highway network units. + gru_units (int, optional): The number of GRU units (for both directions). 
+ + """ + super(CBHG, self).__init__() + self.idim = idim + self.odim = odim + self.conv_bank_layers = conv_bank_layers + self.conv_bank_chans = conv_bank_chans + self.conv_proj_filts = conv_proj_filts + self.conv_proj_chans = conv_proj_chans + self.highway_layers = highway_layers + self.highway_units = highway_units + self.gru_units = gru_units + + # define 1d convolution bank + self.conv_bank = torch.nn.ModuleList() + for k in range(1, self.conv_bank_layers + 1): + if k % 2 != 0: + padding = (k - 1) // 2 + else: + padding = ((k - 1) // 2, (k - 1) // 2 + 1) + self.conv_bank += [ + torch.nn.Sequential( + torch.nn.ConstantPad1d(padding, 0.0), + torch.nn.Conv1d( + idim, self.conv_bank_chans, k, stride=1, padding=0, bias=True + ), + torch.nn.BatchNorm1d(self.conv_bank_chans), + torch.nn.ReLU(), + ) + ] + + # define max pooling (need padding for one-side to keep same length) + self.max_pool = torch.nn.Sequential( + torch.nn.ConstantPad1d((0, 1), 0.0), torch.nn.MaxPool1d(2, stride=1) + ) + + # define 1d convolution projection + self.projections = torch.nn.Sequential( + torch.nn.Conv1d( + self.conv_bank_chans * self.conv_bank_layers, + self.conv_proj_chans, + self.conv_proj_filts, + stride=1, + padding=(self.conv_proj_filts - 1) // 2, + bias=True, + ), + torch.nn.BatchNorm1d(self.conv_proj_chans), + torch.nn.ReLU(), + torch.nn.Conv1d( + self.conv_proj_chans, + self.idim, + self.conv_proj_filts, + stride=1, + padding=(self.conv_proj_filts - 1) // 2, + bias=True, + ), + torch.nn.BatchNorm1d(self.idim), + ) + + # define highway network + self.highways = torch.nn.ModuleList() + self.highways += [torch.nn.Linear(idim, self.highway_units)] + for _ in range(self.highway_layers): + self.highways += [HighwayNet(self.highway_units)] + + # define bidirectional GRU + self.gru = torch.nn.GRU( + self.highway_units, + gru_units // 2, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + # define final projection + self.output = torch.nn.Linear(gru_units, odim, bias=True) + + def forward(self, xs, ilens): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of the padded sequences of inputs (B, Tmax, idim). + ilens (LongTensor): Batch of lengths of each input sequence (B,). + + Return: + Tensor: Batch of the padded sequence of outputs (B, Tmax, odim). + LongTensor: Batch of lengths of each output sequence (B,). + + """ + xs = xs.transpose(1, 2) # (B, idim, Tmax) + convs = [] + for k in range(self.conv_bank_layers): + convs += [self.conv_bank[k](xs)] + convs = torch.cat(convs, dim=1) # (B, #CH * #BANK, Tmax) + convs = self.max_pool(convs) + convs = self.projections(convs).transpose(1, 2) # (B, Tmax, idim) + xs = xs.transpose(1, 2) + convs + # + 1 for dimension adjustment layer + for i in range(self.highway_layers + 1): + xs = self.highways[i](xs) + + # sort by length + xs, ilens, sort_idx = self._sort_by_length(xs, ilens) + + # total_length needs for DataParallel + # (see https://github.com/pytorch/pytorch/pull/6327) + total_length = xs.size(1) + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + xs = pack_padded_sequence(xs, ilens.cpu(), batch_first=True) + self.gru.flatten_parameters() + xs, _ = self.gru(xs) + xs, ilens = pad_packed_sequence(xs, batch_first=True, total_length=total_length) + + # revert sorting by length + xs, ilens = self._revert_sort_by_length(xs, ilens, sort_idx) + + xs = self.output(xs) # (B, Tmax, odim) + + return xs, ilens + + def inference(self, x): + """Inference. + + Args: + x (Tensor): The sequences of inputs (T, idim). 
+ + Return: + Tensor: The sequence of outputs (T, odim). + + """ + assert len(x.size()) == 2 + xs = x.unsqueeze(0) + ilens = x.new([x.size(0)]).long() + + return self.forward(xs, ilens)[0][0] + + def _sort_by_length(self, xs, ilens): + sort_ilens, sort_idx = ilens.sort(0, descending=True) + return xs[sort_idx], ilens[sort_idx], sort_idx + + def _revert_sort_by_length(self, xs, ilens, sort_idx): + _, revert_idx = sort_idx.sort(0) + return xs[revert_idx], ilens[revert_idx] + + +class HighwayNet(torch.nn.Module): + """Highway Network module. + + This is a module of Highway Network introduced in `Highway Networks`_. + + .. _`Highway Networks`: https://arxiv.org/abs/1505.00387 + + """ + + def __init__(self, idim): + """Initialize Highway Network module. + + Args: + idim (int): Dimension of the inputs. + + """ + super(HighwayNet, self).__init__() + self.idim = idim + self.projection = torch.nn.Sequential( + torch.nn.Linear(idim, idim), torch.nn.ReLU() + ) + self.gate = torch.nn.Sequential(torch.nn.Linear(idim, idim), torch.nn.Sigmoid()) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of inputs (B, ..., idim). + + Returns: + Tensor: Batch of outputs, which are the same shape as inputs (B, ..., idim). + + """ + proj = self.projection(x) + gate = self.gate(x) + return proj * gate + x * (1.0 - gate) diff --git a/espnet/nets/pytorch_backend/tacotron2/decoder.py b/espnet/nets/pytorch_backend/tacotron2/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a5b9ba23b50fa0ef3b7ec9e753d2ae5e43eb2a --- /dev/null +++ b/espnet/nets/pytorch_backend/tacotron2/decoder.py @@ -0,0 +1,676 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Tacotron2 decoder related modules.""" + +import six + +import torch +import torch.nn.functional as F + +from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA + + +def decoder_init(m): + """Initialize decoder parameters.""" + if isinstance(m, torch.nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("tanh")) + + +class ZoneOutCell(torch.nn.Module): + """ZoneOut Cell module. + + This is a module of zoneout described in + `Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`_. + This code is modified from `eladhoffer/seq2seq.pytorch`_. + + Examples: + >>> lstm = torch.nn.LSTMCell(16, 32) + >>> lstm = ZoneOutCell(lstm, 0.5) + + .. _`Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations`: + https://arxiv.org/abs/1606.01305 + + .. _`eladhoffer/seq2seq.pytorch`: + https://github.com/eladhoffer/seq2seq.pytorch + + """ + + def __init__(self, cell, zoneout_rate=0.1): + """Initialize zone out cell module. + + Args: + cell (torch.nn.Module): Pytorch recurrent cell module + e.g. `torch.nn.Module.LSTMCell`. + zoneout_rate (float, optional): Probability of zoneout from 0.0 to 1.0. + + """ + super(ZoneOutCell, self).__init__() + self.cell = cell + self.hidden_size = cell.hidden_size + self.zoneout_rate = zoneout_rate + if zoneout_rate > 1.0 or zoneout_rate < 0.0: + raise ValueError( + "zoneout probability must be in the range from 0.0 to 1.0." + ) + + def forward(self, inputs, hidden): + """Calculate forward propagation. + + Args: + inputs (Tensor): Batch of input tensor (B, input_size). + hidden (tuple): + - Tensor: Batch of initial hidden states (B, hidden_size). + - Tensor: Batch of initial cell states (B, hidden_size). 
+ + Returns: + tuple: + - Tensor: Batch of next hidden states (B, hidden_size). + - Tensor: Batch of next cell states (B, hidden_size). + + """ + next_hidden = self.cell(inputs, hidden) + next_hidden = self._zoneout(hidden, next_hidden, self.zoneout_rate) + return next_hidden + + def _zoneout(self, h, next_h, prob): + # apply recursively + if isinstance(h, tuple): + num_h = len(h) + if not isinstance(prob, tuple): + prob = tuple([prob] * num_h) + return tuple( + [self._zoneout(h[i], next_h[i], prob[i]) for i in range(num_h)] + ) + + if self.training: + mask = h.new(*h.size()).bernoulli_(prob) + return mask * h + (1 - mask) * next_h + else: + return prob * h + (1 - prob) * next_h + + +class Prenet(torch.nn.Module): + """Prenet module for decoder of Spectrogram prediction network. + + This is a module of Prenet in the decoder of Spectrogram prediction network, + which described in `Natural TTS + Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. + The Prenet preforms nonlinear conversion + of inputs before input to auto-regressive lstm, + which helps to learn diagonal attentions. + + Note: + This module alway applies dropout even in evaluation. + See the detail in `Natural TTS Synthesis by + Conditioning WaveNet on Mel Spectrogram Predictions`_. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + """ + + def __init__(self, idim, n_layers=2, n_units=256, dropout_rate=0.5): + """Initialize prenet module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + n_layers (int, optional): The number of prenet layers. + n_units (int, optional): The number of prenet units. + + """ + super(Prenet, self).__init__() + self.dropout_rate = dropout_rate + self.prenet = torch.nn.ModuleList() + for layer in six.moves.range(n_layers): + n_inputs = idim if layer == 0 else n_units + self.prenet += [ + torch.nn.Sequential(torch.nn.Linear(n_inputs, n_units), torch.nn.ReLU()) + ] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, ..., idim). + + Returns: + Tensor: Batch of output tensors (B, ..., odim). + + """ + for i in six.moves.range(len(self.prenet)): + x = F.dropout(self.prenet[i](x), self.dropout_rate) + return x + + +class Postnet(torch.nn.Module): + """Postnet module for Spectrogram prediction network. + + This is a module of Postnet in Spectrogram prediction network, + which described in `Natural TTS Synthesis by + Conditioning WaveNet on Mel Spectrogram Predictions`_. + The Postnet predicts refines the predicted + Mel-filterbank of the decoder, + which helps to compensate the detail sturcture of spectrogram. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + """ + + def __init__( + self, + idim, + odim, + n_layers=5, + n_chans=512, + n_filts=5, + dropout_rate=0.5, + use_batch_norm=True, + ): + """Initialize postnet module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + n_layers (int, optional): The number of layers. + n_filts (int, optional): The number of filter size. + n_units (int, optional): The number of filter channels. + use_batch_norm (bool, optional): Whether to use batch normalization.. + dropout_rate (float, optional): Dropout rate.. 
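# Minimal illustrative sketch: the zoneout rule implemented in _zoneout() above.
# During training a Bernoulli mask keeps each unit of the previous state with
# probability `prob`; at evaluation time the deterministic expectation is used.
# The toy tensors are assumptions, not real decoder states.
import torch

prob = 0.1
h = torch.zeros(2, 4)                     # previous hidden state
next_h = torch.ones(2, 4)                 # candidate state from the wrapped cell

mask = torch.empty_like(h).bernoulli_(prob)            # training-time behaviour
train_out = mask * h + (1 - mask) * next_h

eval_out = prob * h + (1 - prob) * next_h              # evaluation-time behaviour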
+ + """ + super(Postnet, self).__init__() + self.postnet = torch.nn.ModuleList() + for layer in six.moves.range(n_layers - 1): + ichans = odim if layer == 0 else n_chans + ochans = odim if layer == n_layers - 1 else n_chans + if use_batch_norm: + self.postnet += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + ochans, + n_filts, + stride=1, + padding=(n_filts - 1) // 2, + bias=False, + ), + torch.nn.BatchNorm1d(ochans), + torch.nn.Tanh(), + torch.nn.Dropout(dropout_rate), + ) + ] + else: + self.postnet += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + ochans, + n_filts, + stride=1, + padding=(n_filts - 1) // 2, + bias=False, + ), + torch.nn.Tanh(), + torch.nn.Dropout(dropout_rate), + ) + ] + ichans = n_chans if n_layers != 1 else odim + if use_batch_norm: + self.postnet += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + odim, + n_filts, + stride=1, + padding=(n_filts - 1) // 2, + bias=False, + ), + torch.nn.BatchNorm1d(odim), + torch.nn.Dropout(dropout_rate), + ) + ] + else: + self.postnet += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + odim, + n_filts, + stride=1, + padding=(n_filts - 1) // 2, + bias=False, + ), + torch.nn.Dropout(dropout_rate), + ) + ] + + def forward(self, xs): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax). + + Returns: + Tensor: Batch of padded output tensor. (B, odim, Tmax). + + """ + for i in six.moves.range(len(self.postnet)): + xs = self.postnet[i](xs) + return xs + + +class Decoder(torch.nn.Module): + """Decoder module of Spectrogram prediction network. + + This is a module of decoder of Spectrogram prediction network in Tacotron2, + which described in `Natural TTS + Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_. + The decoder generates the sequence of + features from the sequence of the hidden states. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + """ + + def __init__( + self, + idim, + odim, + att, + dlayers=2, + dunits=1024, + prenet_layers=2, + prenet_units=256, + postnet_layers=5, + postnet_chans=512, + postnet_filts=5, + output_activation_fn=None, + cumulate_att_w=True, + use_batch_norm=True, + use_concate=True, + dropout_rate=0.5, + zoneout_rate=0.1, + reduction_factor=1, + ): + """Initialize Tacotron2 decoder module. + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + att (torch.nn.Module): Instance of attention class. + dlayers (int, optional): The number of decoder lstm layers. + dunits (int, optional): The number of decoder lstm units. + prenet_layers (int, optional): The number of prenet layers. + prenet_units (int, optional): The number of prenet units. + postnet_layers (int, optional): The number of postnet layers. + postnet_filts (int, optional): The number of postnet filter size. + postnet_chans (int, optional): The number of postnet filter channels. + output_activation_fn (torch.nn.Module, optional): + Activation function for outputs. + cumulate_att_w (bool, optional): + Whether to cumulate previous attention weight. + use_batch_norm (bool, optional): Whether to use batch normalization. + use_concate (bool, optional): Whether to concatenate encoder embedding + with decoder lstm outputs. + dropout_rate (float, optional): Dropout rate. + zoneout_rate (float, optional): Zoneout rate. + reduction_factor (int, optional): Reduction factor. 
+ + """ + super(Decoder, self).__init__() + + # store the hyperparameters + self.idim = idim + self.odim = odim + self.att = att + self.output_activation_fn = output_activation_fn + self.cumulate_att_w = cumulate_att_w + self.use_concate = use_concate + self.reduction_factor = reduction_factor + + # check attention type + if isinstance(self.att, AttForwardTA): + self.use_att_extra_inputs = True + else: + self.use_att_extra_inputs = False + + # define lstm network + prenet_units = prenet_units if prenet_layers != 0 else odim + self.lstm = torch.nn.ModuleList() + for layer in six.moves.range(dlayers): + iunits = idim + prenet_units if layer == 0 else dunits + lstm = torch.nn.LSTMCell(iunits, dunits) + if zoneout_rate > 0.0: + lstm = ZoneOutCell(lstm, zoneout_rate) + self.lstm += [lstm] + + # define prenet + if prenet_layers > 0: + self.prenet = Prenet( + idim=odim, + n_layers=prenet_layers, + n_units=prenet_units, + dropout_rate=dropout_rate, + ) + else: + self.prenet = None + + # define postnet + if postnet_layers > 0: + self.postnet = Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=dropout_rate, + ) + else: + self.postnet = None + + # define projection layers + iunits = idim + dunits if use_concate else dunits + self.feat_out = torch.nn.Linear(iunits, odim * reduction_factor, bias=False) + self.prob_out = torch.nn.Linear(iunits, reduction_factor) + + # initialize + self.apply(decoder_init) + + def _zero_state(self, hs): + init_hs = hs.new_zeros(hs.size(0), self.lstm[0].hidden_size) + return init_hs + + def forward(self, hs, hlens, ys): + """Calculate forward propagation. + + Args: + hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim). + hlens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): + Batch of the sequences of padded target features (B, Lmax, odim). + + Returns: + Tensor: Batch of output tensors after postnet (B, Lmax, odim). + Tensor: Batch of output tensors before postnet (B, Lmax, odim). + Tensor: Batch of logits of stop prediction (B, Lmax). + Tensor: Batch of attention weights (B, Lmax, Tmax). + + Note: + This computation is performed in teacher-forcing manner. 
+ + """ + # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim) + if self.reduction_factor > 1: + ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + + # length list should be list of int + hlens = list(map(int, hlens)) + + # initialize hidden states of decoder + c_list = [self._zero_state(hs)] + z_list = [self._zero_state(hs)] + for _ in six.moves.range(1, len(self.lstm)): + c_list += [self._zero_state(hs)] + z_list += [self._zero_state(hs)] + prev_out = hs.new_zeros(hs.size(0), self.odim) + + # initialize attention + prev_att_w = None + self.att.reset() + + # loop for an output sequence + outs, logits, att_ws = [], [], [] + for y in ys.transpose(0, 1): + if self.use_att_extra_inputs: + att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out) + else: + att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w) + prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out + xs = torch.cat([att_c, prenet_out], dim=1) + z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0])) + for i in six.moves.range(1, len(self.lstm)): + z_list[i], c_list[i] = self.lstm[i]( + z_list[i - 1], (z_list[i], c_list[i]) + ) + zcs = ( + torch.cat([z_list[-1], att_c], dim=1) + if self.use_concate + else z_list[-1] + ) + outs += [self.feat_out(zcs).view(hs.size(0), self.odim, -1)] + logits += [self.prob_out(zcs)] + att_ws += [att_w] + prev_out = y # teacher forcing + if self.cumulate_att_w and prev_att_w is not None: + prev_att_w = prev_att_w + att_w # Note: error when use += + else: + prev_att_w = att_w + + logits = torch.cat(logits, dim=1) # (B, Lmax) + before_outs = torch.cat(outs, dim=2) # (B, odim, Lmax) + att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax) + + if self.reduction_factor > 1: + before_outs = before_outs.view( + before_outs.size(0), self.odim, -1 + ) # (B, odim, Lmax) + + if self.postnet is not None: + after_outs = before_outs + self.postnet(before_outs) # (B, odim, Lmax) + else: + after_outs = before_outs + before_outs = before_outs.transpose(2, 1) # (B, Lmax, odim) + after_outs = after_outs.transpose(2, 1) # (B, Lmax, odim) + logits = logits + + # apply activation function for scaling + if self.output_activation_fn is not None: + before_outs = self.output_activation_fn(before_outs) + after_outs = self.output_activation_fn(after_outs) + + return after_outs, before_outs, logits, att_ws + + def inference( + self, + h, + threshold=0.5, + minlenratio=0.0, + maxlenratio=10.0, + use_att_constraint=False, + backward_window=None, + forward_window=None, + ): + """Generate the sequence of features given the sequences of characters. + + Args: + h (Tensor): Input sequence of encoder hidden states (T, C). + threshold (float, optional): Threshold to stop generation. + minlenratio (float, optional): Minimum length ratio. + If set to 1.0 and the length of input is 10, + the minimum length of outputs will be 10 * 1 = 10. + minlenratio (float, optional): Minimum length ratio. + If set to 10 and the length of input is 10, + the maximum length of outputs will be 10 * 10 = 100. + use_att_constraint (bool): + Whether to apply attention constraint introduced in `Deep Voice 3`_. + backward_window (int): Backward window size in attention constraint. + forward_window (int): Forward window size in attention constraint. + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Attention weights (L, T). + + Note: + This computation is performed in auto-regressive manner. + + .. 
_`Deep Voice 3`: https://arxiv.org/abs/1710.07654 + + """ + # setup + assert len(h.size()) == 2 + hs = h.unsqueeze(0) + ilens = [h.size(0)] + maxlen = int(h.size(0) * maxlenratio) + minlen = int(h.size(0) * minlenratio) + + # initialize hidden states of decoder + c_list = [self._zero_state(hs)] + z_list = [self._zero_state(hs)] + for _ in six.moves.range(1, len(self.lstm)): + c_list += [self._zero_state(hs)] + z_list += [self._zero_state(hs)] + prev_out = hs.new_zeros(1, self.odim) + + # initialize attention + prev_att_w = None + self.att.reset() + + # setup for attention constraint + if use_att_constraint: + last_attended_idx = 0 + else: + last_attended_idx = None + + # loop for an output sequence + idx = 0 + outs, att_ws, probs = [], [], [] + while True: + # updated index + idx += self.reduction_factor + + # decoder calculation + if self.use_att_extra_inputs: + att_c, att_w = self.att( + hs, + ilens, + z_list[0], + prev_att_w, + prev_out, + last_attended_idx=last_attended_idx, + backward_window=backward_window, + forward_window=forward_window, + ) + else: + att_c, att_w = self.att( + hs, + ilens, + z_list[0], + prev_att_w, + last_attended_idx=last_attended_idx, + backward_window=backward_window, + forward_window=forward_window, + ) + + att_ws += [att_w] + prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out + xs = torch.cat([att_c, prenet_out], dim=1) + z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0])) + for i in six.moves.range(1, len(self.lstm)): + z_list[i], c_list[i] = self.lstm[i]( + z_list[i - 1], (z_list[i], c_list[i]) + ) + zcs = ( + torch.cat([z_list[-1], att_c], dim=1) + if self.use_concate + else z_list[-1] + ) + outs += [self.feat_out(zcs).view(1, self.odim, -1)] # [(1, odim, r), ...] + probs += [torch.sigmoid(self.prob_out(zcs))[0]] # [(r), ...] + if self.output_activation_fn is not None: + prev_out = self.output_activation_fn(outs[-1][:, :, -1]) # (1, odim) + else: + prev_out = outs[-1][:, :, -1] # (1, odim) + if self.cumulate_att_w and prev_att_w is not None: + prev_att_w = prev_att_w + att_w # Note: error when use += + else: + prev_att_w = att_w + if use_att_constraint: + last_attended_idx = int(att_w.argmax()) + + # check whether to finish generation + if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: + # check mininum length + if idx < minlen: + continue + outs = torch.cat(outs, dim=2) # (1, odim, L) + if self.postnet is not None: + outs = outs + self.postnet(outs) # (1, odim, L) + outs = outs.transpose(2, 1).squeeze(0) # (L, odim) + probs = torch.cat(probs, dim=0) + att_ws = torch.cat(att_ws, dim=0) + break + + if self.output_activation_fn is not None: + outs = self.output_activation_fn(outs) + + return outs, probs, att_ws + + def calculate_all_attentions(self, hs, hlens, ys): + """Calculate all of the attention weights. + + Args: + hs (Tensor): Batch of the sequences of padded hidden states (B, Tmax, idim). + hlens (LongTensor): Batch of lengths of each input batch (B,). + ys (Tensor): + Batch of the sequences of padded target features (B, Lmax, odim). + + Returns: + numpy.ndarray: Batch of attention weights (B, Lmax, Tmax). + + Note: + This computation is performed in teacher-forcing manner. 
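# Minimal illustrative sketch: the stopping rule applied at the end of each
# inference() step above. Generation ends once any of the r stop probabilities
# reaches `threshold` or the index exceeds maxlen, and the minimum length is
# still respected. The toy values are assumptions.
import torch

threshold, maxlen, minlen = 0.5, 100, 5
probs_step = torch.sigmoid(torch.tensor([-2.0, 1.0]))  # r = 2 stop logits, toy
idx = 12

should_stop = int(sum(probs_step >= threshold)) > 0 or idx >= maxlen
finished = should_stop and idx >= minlen               # respect minimum length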
+ + """ + # thin out frames (B, Lmax, odim) -> (B, Lmax/r, odim) + if self.reduction_factor > 1: + ys = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + + # length list should be list of int + hlens = list(map(int, hlens)) + + # initialize hidden states of decoder + c_list = [self._zero_state(hs)] + z_list = [self._zero_state(hs)] + for _ in six.moves.range(1, len(self.lstm)): + c_list += [self._zero_state(hs)] + z_list += [self._zero_state(hs)] + prev_out = hs.new_zeros(hs.size(0), self.odim) + + # initialize attention + prev_att_w = None + self.att.reset() + + # loop for an output sequence + att_ws = [] + for y in ys.transpose(0, 1): + if self.use_att_extra_inputs: + att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w, prev_out) + else: + att_c, att_w = self.att(hs, hlens, z_list[0], prev_att_w) + att_ws += [att_w] + prenet_out = self.prenet(prev_out) if self.prenet is not None else prev_out + xs = torch.cat([att_c, prenet_out], dim=1) + z_list[0], c_list[0] = self.lstm[0](xs, (z_list[0], c_list[0])) + for i in six.moves.range(1, len(self.lstm)): + z_list[i], c_list[i] = self.lstm[i]( + z_list[i - 1], (z_list[i], c_list[i]) + ) + prev_out = y # teacher forcing + if self.cumulate_att_w and prev_att_w is not None: + prev_att_w = prev_att_w + att_w # Note: error when use += + else: + prev_att_w = att_w + + att_ws = torch.stack(att_ws, dim=1) # (B, Lmax, Tmax) + + return att_ws diff --git a/espnet/nets/pytorch_backend/tacotron2/encoder.py b/espnet/nets/pytorch_backend/tacotron2/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fee4b1c555205ba7ef0176cc033743d9a360bafa --- /dev/null +++ b/espnet/nets/pytorch_backend/tacotron2/encoder.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Tacotron2 encoder related modules.""" + +import six + +import torch + +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + + +def encoder_init(m): + """Initialize encoder parameters.""" + if isinstance(m, torch.nn.Conv1d): + torch.nn.init.xavier_uniform_(m.weight, torch.nn.init.calculate_gain("relu")) + + +class Encoder(torch.nn.Module): + """Encoder module of Spectrogram prediction network. + + This is a module of encoder of Spectrogram prediction network in Tacotron2, + which described in `Natural TTS Synthesis by Conditioning WaveNet on Mel + Spectrogram Predictions`_. This is the encoder which converts either a sequence + of characters or acoustic features into the sequence of hidden states. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + """ + + def __init__( + self, + idim, + input_layer="embed", + embed_dim=512, + elayers=1, + eunits=512, + econv_layers=3, + econv_chans=512, + econv_filts=5, + use_batch_norm=True, + use_residual=False, + dropout_rate=0.5, + padding_idx=0, + ): + """Initialize Tacotron2 encoder module. + + Args: + idim (int) Dimension of the inputs. + input_layer (str): Input layer type. + embed_dim (int, optional) Dimension of character embedding. + elayers (int, optional) The number of encoder blstm layers. + eunits (int, optional) The number of encoder blstm units. + econv_layers (int, optional) The number of encoder conv layers. + econv_filts (int, optional) The number of encoder conv filter size. + econv_chans (int, optional) The number of encoder conv filter channels. 
+ use_batch_norm (bool, optional) Whether to use batch normalization. + use_residual (bool, optional) Whether to use residual connection. + dropout_rate (float, optional) Dropout rate. + + """ + super(Encoder, self).__init__() + # store the hyperparameters + self.idim = idim + self.use_residual = use_residual + + # define network layer modules + if input_layer == "linear": + self.embed = torch.nn.Linear(idim, econv_chans) + elif input_layer == "embed": + self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=padding_idx) + else: + raise ValueError("unknown input_layer: " + input_layer) + + if econv_layers > 0: + self.convs = torch.nn.ModuleList() + for layer in six.moves.range(econv_layers): + ichans = ( + embed_dim if layer == 0 and input_layer == "embed" else econv_chans + ) + if use_batch_norm: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + econv_chans, + econv_filts, + stride=1, + padding=(econv_filts - 1) // 2, + bias=False, + ), + torch.nn.BatchNorm1d(econv_chans), + torch.nn.ReLU(), + torch.nn.Dropout(dropout_rate), + ) + ] + else: + self.convs += [ + torch.nn.Sequential( + torch.nn.Conv1d( + ichans, + econv_chans, + econv_filts, + stride=1, + padding=(econv_filts - 1) // 2, + bias=False, + ), + torch.nn.ReLU(), + torch.nn.Dropout(dropout_rate), + ) + ] + else: + self.convs = None + if elayers > 0: + iunits = econv_chans if econv_layers != 0 else embed_dim + self.blstm = torch.nn.LSTM( + iunits, eunits // 2, elayers, batch_first=True, bidirectional=True + ) + else: + self.blstm = None + + # initialize + self.apply(encoder_init) + + def forward(self, xs, ilens=None): + """Calculate forward propagation. + + Args: + xs (Tensor): Batch of the padded sequence. Either character ids (B, Tmax) + or acoustic feature (B, Tmax, idim * encoder_reduction_factor). Padded + value should be 0. + ilens (LongTensor): Batch of lengths of each input batch (B,). + + Returns: + Tensor: Batch of the sequences of encoder states(B, Tmax, eunits). + LongTensor: Batch of lengths of each sequence (B,) + + """ + xs = self.embed(xs).transpose(1, 2) + if self.convs is not None: + for i in six.moves.range(len(self.convs)): + if self.use_residual: + xs += self.convs[i](xs) + else: + xs = self.convs[i](xs) + if self.blstm is None: + return xs.transpose(1, 2) + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + xs = pack_padded_sequence(xs.transpose(1, 2), ilens.cpu(), batch_first=True) + self.blstm.flatten_parameters() + xs, _ = self.blstm(xs) # (B, Tmax, C) + xs, hlens = pad_packed_sequence(xs, batch_first=True) + + return xs, hlens + + def inference(self, x): + """Inference. + + Args: + x (Tensor): The sequeunce of character ids (T,) + or acoustic feature (T, idim * encoder_reduction_factor). + + Returns: + Tensor: The sequences of encoder states(T, eunits). 
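# Minimal illustrative sketch: the pack/pad round trip used around the BLSTM in
# the Tacotron2 encoder forward above, so padded frames do not leak into the
# recurrent states. The toy dimensions are assumptions.
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

blstm = torch.nn.LSTM(8, 4, 1, batch_first=True, bidirectional=True)
xs = torch.randn(2, 6, 8)                  # (B, Tmax, C), already embedded, toy
ilens = torch.tensor([6, 4])               # true lengths, sorted descending

packed = pack_padded_sequence(xs, ilens.cpu(), batch_first=True)
out, _ = blstm(packed)
out, hlens = pad_packed_sequence(out, batch_first=True)   # (B, Tmax, 2 * 4)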
+ + """ + xs = x.unsqueeze(0) + ilens = torch.tensor([x.size(0)]) + + return self.forward(xs, ilens)[0][0] diff --git a/espnet/nets/pytorch_backend/transducer/__init__.py b/espnet/nets/pytorch_backend/transducer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/transducer/arguments.py b/espnet/nets/pytorch_backend/transducer/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..96f9fda4a096c2ca4bfb61e4faded3960a7ee401 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/arguments.py @@ -0,0 +1,321 @@ +"""Transducer model arguments.""" + +import ast +from distutils.util import strtobool + + +def add_encoder_general_arguments(group): + """Define general arguments for encoder.""" + group.add_argument( + "--etype", + default="blstmp", + type=str, + choices=[ + "custom", + "lstm", + "blstm", + "lstmp", + "blstmp", + "vgglstmp", + "vggblstmp", + "vgglstm", + "vggblstm", + "gru", + "bgru", + "grup", + "bgrup", + "vgggrup", + "vggbgrup", + "vgggru", + "vggbgru", + ], + help="Type of encoder network architecture", + ) + group.add_argument( + "--dropout-rate", + default=0.0, + type=float, + help="Dropout rate for the encoder", + ) + + return group + + +def add_rnn_encoder_arguments(group): + """Define arguments for RNN encoder.""" + group.add_argument( + "--elayers", + default=4, + type=int, + help="Number of encoder layers (for shared recognition part " + "in multi-speaker asr mode)", + ) + group.add_argument( + "--eunits", + "-u", + default=300, + type=int, + help="Number of encoder hidden units", + ) + group.add_argument( + "--eprojs", default=320, type=int, help="Number of encoder projection units" + ) + group.add_argument( + "--subsample", + default="1", + type=str, + help="Subsample input frames x_y_z means subsample every x frame " + "at 1st layer, every y frame at 2nd layer etc.", + ) + + return group + + +def add_custom_encoder_arguments(group): + """Define arguments for Custom encoder.""" + group.add_argument( + "--enc-block-arch", + type=eval, + action="append", + default=None, + help="Encoder architecture definition by blocks", + ) + group.add_argument( + "--enc-block-repeat", + default=0, + type=int, + help="Repeat N times the provided encoder blocks if N > 1", + ) + group.add_argument( + "--custom-enc-input-layer", + type=str, + default="conv2d", + choices=["conv2d", "vgg2l", "linear", "embed"], + help="Custom encoder input layer type", + ) + group.add_argument( + "--custom-enc-positional-encoding-type", + type=str, + default="abs_pos", + choices=["abs_pos", "scaled_abs_pos", "rel_pos"], + help="Custom encoder positional encoding layer type", + ) + group.add_argument( + "--custom-enc-self-attn-type", + type=str, + default="self_attn", + choices=["self_attn", "rel_self_attn"], + help="Custom encoder self-attention type", + ) + group.add_argument( + "--custom-enc-pw-activation-type", + type=str, + default="relu", + choices=["relu", "hardtanh", "selu", "swish"], + help="Custom encoder pointwise activation type", + ) + group.add_argument( + "--custom-enc-conv-mod-activation-type", + type=str, + default="swish", + choices=["relu", "hardtanh", "selu", "swish"], + help="Custom encoder convolutional module activation type", + ) + + return group + + +def add_decoder_general_arguments(group): + """Define general arguments for encoder.""" 
+ group.add_argument( + "--dtype", + default="lstm", + type=str, + choices=["lstm", "gru", "custom"], + help="Type of decoder to use", + ) + group.add_argument( + "--dropout-rate-decoder", + default=0.0, + type=float, + help="Dropout rate for the decoder", + ) + group.add_argument( + "--dropout-rate-embed-decoder", + default=0.0, + type=float, + help="Dropout rate for the decoder embedding layer", + ) + + return group + + +def add_rnn_decoder_arguments(group): + """Define arguments for RNN decoder.""" + group.add_argument( + "--dec-embed-dim", + default=320, + type=int, + help="Number of decoder embeddings dimensions", + ) + group.add_argument( + "--dlayers", default=1, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=320, type=int, help="Number of decoder hidden units" + ) + + return group + + +def add_custom_decoder_arguments(group): + """Define arguments for Custom decoder.""" + group.add_argument( + "--dec-block-arch", + type=eval, + action="append", + default=None, + help="Custom decoder blocks definition", + ) + group.add_argument( + "--dec-block-repeat", + default=1, + type=int, + help="Repeat N times the provided decoder blocks if N > 1", + ) + group.add_argument( + "--custom-dec-input-layer", + type=str, + default="embed", + choices=["linear", "embed"], + help="Custom decoder input layer type", + ) + group.add_argument( + "--custom-dec-pw-activation-type", + type=str, + default="relu", + choices=["relu", "hardtanh", "selu", "swish"], + help="Custom decoder pointwise activation type", + ) + + return group + + +def add_custom_training_arguments(group): + """Define arguments for training with Custom architecture.""" + group.add_argument( + "--transformer-warmup-steps", + default=25000, + type=int, + help="Optimizer warmup steps", + ) + group.add_argument( + "--transformer-lr", + default=10.0, + type=float, + help="Initial value of learning rate", + ) + + return group + + +def add_transducer_arguments(group): + """Define general arguments for transducer model.""" + group.add_argument( + "--trans-type", + default="warp-transducer", + type=str, + choices=["warp-transducer", "warp-rnnt"], + help="Type of transducer implementation to calculate loss.", + ) + group.add_argument( + "--transducer-weight", + default=1.0, + type=float, + help="Weight of transducer loss when auxiliary task is used.", + ) + group.add_argument( + "--joint-dim", + default=320, + type=int, + help="Number of dimensions in joint space", + ) + group.add_argument( + "--joint-activation-type", + type=str, + default="tanh", + choices=["relu", "tanh", "swish"], + help="Joint network activation type", + ) + group.add_argument( + "--score-norm", + type=strtobool, + nargs="?", + default=True, + help="Normalize transducer scores by length", + ) + + return group + + +def add_auxiliary_task_arguments(group): + """Add arguments for auxiliary task.""" + group.add_argument( + "--aux-task-type", + nargs="?", + default=None, + choices=["default", "symm_kl_div", "both"], + help="Type of auxiliary task.", + ) + group.add_argument( + "--aux-task-layer-list", + default=None, + type=ast.literal_eval, + help="List of layers to use for auxiliary task.", + ) + group.add_argument( + "--aux-task-weight", + default=0.3, + type=float, + help="Weight of auxiliary task loss.", + ) + group.add_argument( + "--aux-ctc", + type=strtobool, + nargs="?", + default=False, + help="Whether to use CTC as auxiliary task.", + ) + group.add_argument( + "--aux-ctc-weight", + default=1.0, + type=float, + help="Weight of 
auxiliary task loss", + ) + group.add_argument( + "--aux-ctc-dropout-rate", + default=0.0, + type=float, + help="Dropout rate for auxiliary CTC", + ) + group.add_argument( + "--aux-cross-entropy", + type=strtobool, + nargs="?", + default=False, + help="Whether to use CE as auxiliary task for the prediction network.", + ) + group.add_argument( + "--aux-cross-entropy-smoothing", + default=0.0, + type=float, + help="Smoothing rate for cross-entropy. If > 0, enables label smoothing loss.", + ) + group.add_argument( + "--aux-cross-entropy-weight", + default=0.5, + type=float, + help="Weight of auxiliary task loss", + ) + + return group diff --git a/espnet/nets/pytorch_backend/transducer/auxiliary_task.py b/espnet/nets/pytorch_backend/transducer/auxiliary_task.py new file mode 100644 index 0000000000000000000000000000000000000000..998bcf0cc29c121a0513aeedffb230a110d1e784 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/auxiliary_task.py @@ -0,0 +1,114 @@ +"""Auxiliary task implementation for transducer models.""" + +from itertools import chain +from typing import List +from typing import Tuple +from typing import Union + +import torch +import torch.nn.functional as F + +from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface + + +class AuxiliaryTask(torch.nn.Module): + """Auxiliary task module.""" + + def __init__( + self, + decoder: Union[torch.nn.Module, TransducerDecoderInterface], + joint_network: torch.nn.Module, + rnnt_criterion: torch.nn.Module, + aux_task_type: str, + aux_task_weight: int, + encoder_out_dim: int, + joint_dim: int, + ): + """Auxiliary task initialization. + + Args: + decoder: Decoder module + joint_network: Joint network module + aux_task_type: Auxiliary task type + aux_task_weight: Auxiliary task weight + encoder_out: Encoder output dimension + joint_dim: Joint space dimension + + """ + super().__init__() + + self.rnnt_criterion = rnnt_criterion + + self.mlp_net = torch.nn.Sequential( + torch.nn.Linear(encoder_out_dim, joint_dim), + torch.nn.ReLU(), + torch.nn.Linear(joint_dim, joint_dim), + ) + + self.decoder = decoder + self.joint_network = joint_network + + self.aux_task_type = aux_task_type + self.aux_task_weight = aux_task_weight + + def forward( + self, + enc_out_aux: List, + dec_out: torch.Tensor, + main_joint: torch.Tensor, + target: torch.Tensor, + pred_len: torch.Tensor, + target_len: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward auxiliary task. 
+ + Args: + enc_out_aux: List of encoder intermediate outputs + dec_out: Decoder outputs + main_joint: Joint output for main task + target: Target labels + pred_len: Prediction lengths + target_len: Target lengths + + Returns: + : (Weighted auxiliary transducer loss, Weighted auxiliary symmetric KL loss) + + """ + aux_trans = 0 + aux_symm_kl = 0 + + for p in chain(self.decoder.parameters(), self.joint_network.parameters()): + p.requires_grad = False + + for i, enc_aux in enumerate(enc_out_aux): + aux_mlp = self.mlp_net(enc_aux) + + aux_joint = self.joint_network( + aux_mlp.unsqueeze(2), + dec_out.unsqueeze(1), + is_aux=True, + ) + + if self.aux_task_type != "symm_kl_div": + aux_trans += self.rnnt_criterion( + aux_joint, + target, + pred_len, + target_len, + ) + + if self.aux_task_type != "default": + aux_symm_kl += F.kl_div( + F.log_softmax(main_joint, dim=-1), + F.softmax(aux_joint, dim=-1), + reduction="mean", + ) + F.kl_div( + F.log_softmax(aux_joint, dim=-1), + F.softmax(main_joint, dim=-1), + reduction="mean", + ) + + for p in chain(self.decoder.parameters(), self.joint_network.parameters()): + p.requires_grad = True + + return self.aux_task_weight * aux_trans, self.aux_task_weight * aux_symm_kl diff --git a/espnet/nets/pytorch_backend/transducer/blocks.py b/espnet/nets/pytorch_backend/transducer/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..875053ab69bcf6353c970b6585a0e369c4d946b3 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/blocks.py @@ -0,0 +1,556 @@ +"""Set of methods to create custom architecture.""" + +from collections import Counter + +import torch + +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import ( + EncoderLayer as ConformerEncoderLayer, # noqa: H301 +) + +from espnet.nets.pytorch_backend.nets_utils import get_activation + +from espnet.nets.pytorch_backend.transducer.causal_conv1d import CausalConv1d +from espnet.nets.pytorch_backend.transducer.transformer_decoder_layer import ( + DecoderLayer, # noqa: H301 +) +from espnet.nets.pytorch_backend.transducer.tdnn import TDNN +from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L + +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import MultiSequential +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +def check_and_prepare(net_part, blocks_arch, input_layer): + """Check consecutive block shapes match and prepare input parameters. + + Args: + net_part (str): either 'encoder' or 'decoder' + blocks_arch (list): list of blocks for network part (type and parameters) + input_layer (str): input layer type + + Return: + input_layer (str): input layer type + input_layer_odim (int): output dim of input layer + input_dropout_rate (float): dropout rate of input layer + input_pos_dropout_rate (float): dropout rate of input layer positional enc. 
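# Minimal illustrative sketch: the symmetric KL term computed in
# AuxiliaryTask.forward above, written out for two toy joint-network outputs.
# The shapes are assumptions.
import torch
import torch.nn.functional as F

main_joint = torch.randn(2, 3, 4, 7)       # (B, T, U, vocab), toy sizes
aux_joint = torch.randn(2, 3, 4, 7)

aux_symm_kl = F.kl_div(
    F.log_softmax(main_joint, dim=-1),
    F.softmax(aux_joint, dim=-1),
    reduction="mean",
) + F.kl_div(
    F.log_softmax(aux_joint, dim=-1),
    F.softmax(main_joint, dim=-1),
    reduction="mean",
)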
+ out_dim (int): output dim of last block + + """ + input_dropout_rate = sorted( + Counter( + b["dropout-rate"] for b in blocks_arch if "dropout-rate" in b + ).most_common(), + key=lambda x: x[0], + reverse=True, + ) + + input_pos_dropout_rate = sorted( + Counter( + b["pos-dropout-rate"] for b in blocks_arch if "pos-dropout-rate" in b + ).most_common(), + key=lambda x: x[0], + reverse=True, + ) + + input_dropout_rate = input_dropout_rate[0][0] if input_dropout_rate else 0.0 + input_pos_dropout_rate = ( + input_pos_dropout_rate[0][0] if input_pos_dropout_rate else 0.0 + ) + + cmp_io = [] + has_transformer = False + has_conformer = False + for i in range(len(blocks_arch)): + if "type" in blocks_arch[i]: + block_type = blocks_arch[i]["type"] + else: + raise ValueError("type is not defined in the " + str(i + 1) + "th block.") + + if block_type == "transformer": + if not {"d_hidden", "d_ff", "heads"}.issubset(blocks_arch[i]): + raise ValueError( + "Block " + + str(i + 1) + + "in " + + net_part + + ": Transformer block format is: {'type: transformer', " + "'d_hidden': int, 'd_ff': int, 'heads': int, [...]}" + ) + + has_transformer = True + cmp_io.append((blocks_arch[i]["d_hidden"], blocks_arch[i]["d_hidden"])) + elif block_type == "conformer": + if net_part != "encoder": + raise ValueError( + "Block " + str(i + 1) + ": conformer type is only for encoder part." + ) + + if not { + "d_hidden", + "d_ff", + "heads", + "macaron_style", + "use_conv_mod", + }.issubset(blocks_arch[i]): + raise ValueError( + "Block " + + str(i + 1) + + " in " + + net_part + + ": Conformer block format is {'type: conformer', " + "'d_hidden': int, 'd_ff': int, 'heads': int, " + "'macaron_style': bool, 'use_conv_mod': bool, [...]}" + ) + + if ( + blocks_arch[i]["use_conv_mod"] is True + and "conv_mod_kernel" not in blocks_arch[i] + ): + raise ValueError( + "Block " + + str(i + 1) + + ": 'use_conv_mod' is True but 'use_conv_kernel' is not specified" + ) + + has_conformer = True + cmp_io.append((blocks_arch[i]["d_hidden"], blocks_arch[i]["d_hidden"])) + elif block_type == "causal-conv1d": + if not {"idim", "odim", "kernel_size"}.issubset(blocks_arch[i]): + raise ValueError( + "Block " + + str(i + 1) + + " in " + + net_part + + ": causal conv1d block format is: {'type: causal-conv1d', " + "'idim': int, 'odim': int, 'kernel_size': int}" + ) + + if i == 0: + input_layer = "c-embed" + + cmp_io.append((blocks_arch[i]["idim"], blocks_arch[i]["odim"])) + elif block_type == "tdnn": + if not {"idim", "odim", "ctx_size", "dilation", "stride"}.issubset( + blocks_arch[i] + ): + raise ValueError( + "Block " + + str(i + 1) + + " in " + + net_part + + ": TDNN block format is: {'type: tdnn', " + "'idim': int, 'odim': int, 'ctx_size': int, " + "'dilation': int, 'stride': int, [...]}" + ) + + cmp_io.append((blocks_arch[i]["idim"], blocks_arch[i]["odim"])) + else: + raise NotImplementedError( + "Wrong type for block " + + str(i + 1) + + " in " + + net_part + + ". Currently supported: " + "tdnn, causal-conv1d or transformer" + ) + + if has_transformer and has_conformer: + raise NotImplementedError( + net_part + ": transformer and conformer blocks " + "can't be defined in the same net part." 
+ ) + + for i in range(1, len(cmp_io)): + if cmp_io[(i - 1)][1] != cmp_io[i][0]: + raise ValueError( + "Output/Input mismatch between blocks " + + str(i) + + " and " + + str(i + 1) + + " in " + + net_part + ) + + if blocks_arch[0]["type"] in ("tdnn", "causal-conv1d"): + input_layer_odim = blocks_arch[0]["idim"] + else: + input_layer_odim = blocks_arch[0]["d_hidden"] + + if blocks_arch[-1]["type"] in ("tdnn", "causal-conv1d"): + out_dim = blocks_arch[-1]["odim"] + else: + out_dim = blocks_arch[-1]["d_hidden"] + + return ( + input_layer, + input_layer_odim, + input_dropout_rate, + input_pos_dropout_rate, + out_dim, + ) + + +def get_pos_enc_and_att_class(net_part, pos_enc_type, self_attn_type): + """Get positional encoding and self attention module class. + + Args: + net_part (str): either 'encoder' or 'decoder' + pos_enc_type (str): positional encoding type + self_attn_type (str): self-attention type + + Return: + pos_enc_class (torch.nn.Module): positional encoding class + self_attn_class (torch.nn.Module): self-attention class + + """ + if pos_enc_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_type == "rel_pos": + if net_part == "encoder" and self_attn_type != "rel_self_attn": + raise ValueError("'rel_pos' is only compatible with 'rel_self_attn'") + pos_enc_class = RelPositionalEncoding + else: + raise NotImplementedError( + "pos_enc_type should be either 'abs_pos', 'scaled_abs_pos' or 'rel_pos'" + ) + + if self_attn_type == "rel_self_attn": + self_attn_class = RelPositionMultiHeadedAttention + else: + self_attn_class = MultiHeadedAttention + + return pos_enc_class, self_attn_class + + +def build_input_layer( + input_layer, + idim, + odim, + pos_enc_class, + dropout_rate_embed, + dropout_rate, + pos_dropout_rate, + padding_idx, +): + """Build input layer. 
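# Minimal illustrative sketch: a block list that satisfies the checks in
# check_and_prepare() above, i.e. the output dim of each block matches the
# input dim of the next and only one of transformer/conformer is used. The
# values are assumptions, not a recommended architecture.
blocks_arch = [
    {"type": "tdnn", "idim": 80, "odim": 256,
     "ctx_size": 3, "dilation": 1, "stride": 1},
    {"type": "transformer", "d_hidden": 256, "d_ff": 1024,
     "heads": 4, "dropout-rate": 0.1},
]
# check_and_prepare("encoder", blocks_arch, "linear") would derive
# input_layer_odim = 80 (first block is a tdnn) and out_dim = 256 (last block).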
+ + Args: + input_layer (str): input layer type + idim (int): input dimension + odim (int): output dimension + pos_enc_class (class): positional encoding class + dropout_rate_embed (float): dropout rate for embedding layer + dropout_rate (float): dropout rate for input layer + pos_dropout_rate (float): dropout rate for positional encoding + padding_idx (int): padding index for embedding input layer (if specified) + + Returns: + (torch.nn.*): input layer module + subsampling_factor (int): subsampling factor + + """ + if pos_enc_class.__name__ == "RelPositionalEncoding": + pos_enc_class_subsampling = pos_enc_class(odim, pos_dropout_rate) + else: + pos_enc_class_subsampling = None + + if input_layer == "linear": + return ( + torch.nn.Sequential( + torch.nn.Linear(idim, odim), + torch.nn.LayerNorm(odim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(odim, pos_dropout_rate), + ), + 1, + ) + elif input_layer == "conv2d": + return Conv2dSubsampling(idim, odim, dropout_rate, pos_enc_class_subsampling), 4 + elif input_layer == "vgg2l": + return VGG2L(idim, odim, pos_enc_class_subsampling), 4 + elif input_layer == "embed": + return ( + torch.nn.Sequential( + torch.nn.Embedding(idim, odim, padding_idx=padding_idx), + pos_enc_class(odim, pos_dropout_rate), + ), + 1, + ) + elif input_layer == "c-embed": + return ( + torch.nn.Sequential( + torch.nn.Embedding(idim, odim, padding_idx=padding_idx), + torch.nn.Dropout(dropout_rate_embed), + ), + 1, + ) + else: + raise NotImplementedError("Support: linear, conv2d, vgg2l and embed") + + +def build_transformer_block(net_part, block_arch, pw_layer_type, pw_activation_type): + """Build function for transformer block. + + Args: + net_part (str): either 'encoder' or 'decoder' + block_arch (dict): transformer block parameters + pw_layer_type (str): positionwise layer type + pw_activation_type (str): positionwise activation type + + Returns: + (function): function to create transformer block + + """ + d_hidden = block_arch["d_hidden"] + d_ff = block_arch["d_ff"] + heads = block_arch["heads"] + + dropout_rate = block_arch["dropout-rate"] if "dropout-rate" in block_arch else 0.0 + pos_dropout_rate = ( + block_arch["pos-dropout-rate"] if "pos-dropout-rate" in block_arch else 0.0 + ) + att_dropout_rate = ( + block_arch["att-dropout-rate"] if "att-dropout-rate" in block_arch else 0.0 + ) + + if pw_layer_type == "linear": + pw_layer = PositionwiseFeedForward + pw_activation = get_activation(pw_activation_type) + pw_layer_args = (d_hidden, d_ff, pos_dropout_rate, pw_activation) + else: + raise NotImplementedError("Transformer block only supports linear yet.") + + if net_part == "encoder": + transformer_layer_class = EncoderLayer + elif net_part == "decoder": + transformer_layer_class = DecoderLayer + + return lambda: transformer_layer_class( + d_hidden, + MultiHeadedAttention(heads, d_hidden, att_dropout_rate), + pw_layer(*pw_layer_args), + dropout_rate, + ) + + +def build_conformer_block( + block_arch, + self_attn_class, + pw_layer_type, + pw_activation_type, + conv_mod_activation_type, +): + """Build function for conformer block. 
+ + Args: + block_arch (dict): conformer block parameters + self_attn_type (str): self-attention module type + pw_layer_type (str): positionwise layer type + pw_activation_type (str): positionwise activation type + conv_mod_activation_type (str): convolutional module activation type + + Returns: + (function): function to create conformer block + + """ + d_hidden = block_arch["d_hidden"] + d_ff = block_arch["d_ff"] + heads = block_arch["heads"] + macaron_style = block_arch["macaron_style"] + use_conv_mod = block_arch["use_conv_mod"] + + dropout_rate = block_arch["dropout-rate"] if "dropout-rate" in block_arch else 0.0 + pos_dropout_rate = ( + block_arch["pos-dropout-rate"] if "pos-dropout-rate" in block_arch else 0.0 + ) + att_dropout_rate = ( + block_arch["att-dropout-rate"] if "att-dropout-rate" in block_arch else 0.0 + ) + + if pw_layer_type == "linear": + pw_layer = PositionwiseFeedForward + pw_activation = get_activation(pw_activation_type) + pw_layer_args = (d_hidden, d_ff, pos_dropout_rate, pw_activation) + else: + raise NotImplementedError("Conformer block only supports linear yet.") + + if use_conv_mod: + conv_layer = ConvolutionModule + conv_activation = get_activation(conv_mod_activation_type) + conv_layers_args = (d_hidden, block_arch["conv_mod_kernel"], conv_activation) + + return lambda: ConformerEncoderLayer( + d_hidden, + self_attn_class(heads, d_hidden, att_dropout_rate), + pw_layer(*pw_layer_args), + pw_layer(*pw_layer_args) if macaron_style else None, + conv_layer(*conv_layers_args) if use_conv_mod else None, + dropout_rate, + ) + + +def build_causal_conv1d_block(block_arch): + """Build function for causal conv1d block. + + Args: + block_arch (dict): causal conv1d block parameters + + Returns: + (function): function to create causal conv1d block + + """ + idim = block_arch["idim"] + odim = block_arch["odim"] + kernel_size = block_arch["kernel_size"] + + return lambda: CausalConv1d(idim, odim, kernel_size) + + +def build_tdnn_block(block_arch): + """Build function for tdnn block. + + Args: + block_arch (dict): tdnn block parameters + + Returns: + (function): function to create tdnn block + + """ + idim = block_arch["idim"] + odim = block_arch["odim"] + ctx_size = block_arch["ctx_size"] + dilation = block_arch["dilation"] + stride = block_arch["stride"] + + use_batch_norm = ( + block_arch["use-batch-norm"] if "use-batch-norm" in block_arch else False + ) + use_relu = block_arch["use-relu"] if "use-relu" in block_arch else False + + dropout_rate = block_arch["dropout-rate"] if "dropout-rate" in block_arch else 0.0 + + return lambda: TDNN( + idim, + odim, + ctx_size=ctx_size, + dilation=dilation, + stride=stride, + dropout_rate=dropout_rate, + batch_norm=use_batch_norm, + relu=use_relu, + ) + + +def build_blocks( + net_part, + idim, + input_layer, + blocks_arch, + repeat_block=0, + self_attn_type="self_attn", + positional_encoding_type="abs_pos", + positionwise_layer_type="linear", + positionwise_activation_type="relu", + conv_mod_activation_type="relu", + dropout_rate_embed=0.0, + padding_idx=-1, +): + """Build block for customizable architecture. 
+ + Args: + net_part (str): either 'encoder' or 'decoder' + idim (int): dimension of inputs + input_layer (str): input layer type + blocks_arch (list): list of blocks for network part (type and parameters) + repeat_block (int): repeat provided blocks N times if N > 1 + positional_encoding_type (str): positional encoding layer type + positionwise_layer_type (str): linear + positionwise_activation_type (str): positionwise activation type + conv_mod_activation_type (str): convolutional module activation type + dropout_rate_embed (float): dropout rate for embedding + padding_idx (int): padding index for embedding input layer (if specified) + + Returns: + in_layer (torch.nn.*): input layer + all_blocks (MultiSequential): all blocks for network part + out_dim (int): dimension of last block output + conv_subsampling_factor (int): subsampling factor in frontend CNN + + """ + fn_modules = [] + + ( + input_layer, + input_layer_odim, + input_dropout_rate, + input_pos_dropout_rate, + out_dim, + ) = check_and_prepare(net_part, blocks_arch, input_layer) + + pos_enc_class, self_attn_class = get_pos_enc_and_att_class( + net_part, positional_encoding_type, self_attn_type + ) + + in_layer, conv_subsampling_factor = build_input_layer( + input_layer, + idim, + input_layer_odim, + pos_enc_class, + dropout_rate_embed, + input_dropout_rate, + input_pos_dropout_rate, + padding_idx, + ) + + for i in range(len(blocks_arch)): + block_type = blocks_arch[i]["type"] + + if block_type == "tdnn": + module = build_tdnn_block(blocks_arch[i]) + elif block_type == "transformer": + module = build_transformer_block( + net_part, + blocks_arch[i], + positionwise_layer_type, + positionwise_activation_type, + ) + elif block_type == "conformer": + module = build_conformer_block( + blocks_arch[i], + self_attn_class, + positionwise_layer_type, + positionwise_activation_type, + conv_mod_activation_type, + ) + elif block_type == "causal-conv1d": + module = build_causal_conv1d_block(blocks_arch[i]) + + fn_modules.append(module) + + if repeat_block > 1: + fn_modules = fn_modules * repeat_block + + return ( + in_layer, + MultiSequential(*[fn() for fn in fn_modules]), + out_dim, + conv_subsampling_factor, + ) diff --git a/espnet/nets/pytorch_backend/transducer/causal_conv1d.py b/espnet/nets/pytorch_backend/transducer/causal_conv1d.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8f331578f384ab93ad0cf8271a84c7cc0bfa80 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/causal_conv1d.py @@ -0,0 +1,59 @@ +"""CausalConv1d module definition for custom decoder.""" + +import torch + + +class CausalConv1d(torch.nn.Module): + """CausalConv1d module for custom decoder. + + Args: + idim (int): dimension of inputs + odim (int): dimension of outputs + kernel_size (int): size of convolving kernel + stride (int): stride of the convolution + dilation (int): spacing between the kernel points + groups (int): number of blocked connections from ichannels to ochannels + bias (bool): whether to add a learnable bias to the output + + """ + + def __init__( + self, idim, odim, kernel_size, stride=1, dilation=1, groups=1, bias=True + ): + """Construct a CausalConv1d object.""" + super().__init__() + + self._pad = (kernel_size - 1) * dilation + + self.causal_conv1d = torch.nn.Conv1d( + idim, + odim, + kernel_size=kernel_size, + stride=stride, + padding=self._pad, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x, x_mask, cache=None): + """CausalConv1d forward for x. 
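For reference, a hypothetical blocks_arch configuration and the objects build_blocks hands back (values are illustrative; whether a particular combination of input layer and block types is accepted depends on check_and_prepare, defined earlier in this file):

from espnet.nets.pytorch_backend.transducer.blocks import build_blocks

blocks_arch = [
    {"type": "transformer", "d_hidden": 256, "d_ff": 1024, "heads": 4, "dropout-rate": 0.1},
]

in_layer, blocks, out_dim, subsampling = build_blocks(
    "encoder",
    80,              # idim: input feature dimension
    "conv2d",        # frontend: Conv2dSubsampling, reported subsampling factor 4
    blocks_arch,
    repeat_block=2,  # the single block definition above is instantiated twice
)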
+ + Args: + x (torch.Tensor): input torch (B, U, idim) + x_mask (torch.Tensor): (B, 1, U) + + Returns: + x (torch.Tensor): input torch (B, sub(U), attention_dim) + x_mask (torch.Tensor): (B, 1, sub(U)) + + """ + x = x.permute(0, 2, 1) + x = self.causal_conv1d(x) + + if self._pad != 0: + x = x[:, :, : -self._pad] + + x = x.permute(0, 2, 1) + + return x, x_mask diff --git a/espnet/nets/pytorch_backend/transducer/custom_decoder.py b/espnet/nets/pytorch_backend/transducer/custom_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ca37fe645ff71635f3ce78d6d3929cb005f58b37 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/custom_decoder.py @@ -0,0 +1,277 @@ +"""Custom decoder definition for transducer models.""" + +import torch + +from espnet.nets.pytorch_backend.transducer.blocks import build_blocks +from espnet.nets.pytorch_backend.transducer.utils import check_batch_state +from espnet.nets.pytorch_backend.transducer.utils import check_state +from espnet.nets.pytorch_backend.transducer.utils import pad_sequence +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface + + +class CustomDecoder(TransducerDecoderInterface, torch.nn.Module): + """Custom decoder module for transducer models. + + Args: + odim (int): dimension of outputs + dec_arch (list): list of layer definitions + input_layer (str): input layer type + repeat_block (int): repeat provided blocks N times if N > 1 + positional_encoding_type (str): positional encoding type + positionwise_layer_type (str): linear + positionwise_activation_type (str): positionwise activation type + dropout_rate_embed (float): dropout rate for embedding layer (if specified) + blank (int): blank symbol ID + + """ + + def __init__( + self, + odim, + dec_arch, + input_layer="embed", + repeat_block=0, + joint_activation_type="tanh", + positional_encoding_type="abs_pos", + positionwise_layer_type="linear", + positionwise_activation_type="relu", + dropout_rate_embed=0.0, + blank=0, + ): + """Construct a CustomDecoder object.""" + torch.nn.Module.__init__(self) + + self.embed, self.decoders, ddim, _ = build_blocks( + "decoder", + odim, + input_layer, + dec_arch, + repeat_block=repeat_block, + positional_encoding_type=positional_encoding_type, + positionwise_layer_type=positionwise_layer_type, + positionwise_activation_type=positionwise_activation_type, + dropout_rate_embed=dropout_rate_embed, + padding_idx=blank, + ) + + self.after_norm = LayerNorm(ddim) + + self.dlayers = len(self.decoders) + self.dunits = ddim + self.odim = odim + + self.blank = blank + + def set_device(self, device): + """Set GPU device to use. + + Args: + device (torch.device): device id + + """ + self.device = device + + def init_state(self, batch_size=None, device=None, dtype=None): + """Initialize decoder states. + + Args: + None + + Returns: + state (list): batch of decoder decoder states [L x None] + + """ + state = [None] * self.dlayers + + return state + + def forward(self, tgt, tgt_mask, memory): + """Forward custom decoder. 
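The symmetric padding followed by right-trimming above is what makes the convolution causal: an output frame only depends on inputs at the same or earlier positions. A quick sketch (module path as added in this diff; shapes illustrative):

import torch
from espnet.nets.pytorch_backend.transducer.causal_conv1d import CausalConv1d

torch.manual_seed(0)
conv = CausalConv1d(idim=8, odim=8, kernel_size=3)

x = torch.randn(1, 10, 8)        # (B, U, idim)
y_full, _ = conv(x, None)

x_mod = x.clone()
x_mod[:, 5:, :] = 0.0            # perturb only frames >= 5
y_mod, _ = conv(x_mod, None)

# frames 0-4 are unaffected by the change to future frames
assert torch.allclose(y_full[:, :5], y_mod[:, :5])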
+ + Args: + tgt (torch.Tensor): input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor + (batch, maxlen_out, #mels) in the other cases + tgt_mask (torch.Tensor): input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + memory (torch.Tensor): encoded memory, float32 (batch, maxlen_in, feat) + + Return: + tgt (torch.Tensor): decoder output (batch, maxlen_out, dim_dec) + tgt_mask (torch.Tensor): score mask before softmax (batch, maxlen_out) + + """ + tgt = self.embed(tgt) + + tgt, tgt_mask = self.decoders(tgt, tgt_mask) + tgt = self.after_norm(tgt) + + return tgt, tgt_mask + + def score(self, hyp, cache): + """Forward one step. + + Args: + hyp (dataclass): hypothesis + cache (dict): states cache + + Returns: + y (torch.Tensor): decoder outputs (1, dec_dim) + (list): decoder states + [L x (1, max_len, dec_dim)] + lm_tokens (torch.Tensor): token id for LM (1) + + """ + tgt = torch.tensor([hyp.yseq], device=self.device) + lm_tokens = tgt[:, -1] + + str_yseq = "".join(list(map(str, hyp.yseq))) + + if str_yseq in cache: + y, new_state = cache[str_yseq] + else: + tgt_mask = subsequent_mask(len(hyp.yseq)).unsqueeze_(0) + + state = check_state(hyp.dec_state, (tgt.size(1) - 1), self.blank) + + tgt = self.embed(tgt) + + new_state = [] + for s, decoder in zip(state, self.decoders): + tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s) + new_state.append(tgt) + + y = self.after_norm(tgt[:, -1]) + + cache[str_yseq] = (y, new_state) + + return y[0], new_state, lm_tokens + + def batch_score(self, hyps, batch_states, cache, use_lm): + """Forward batch one step. + + Args: + hyps (list): batch of hypotheses + batch_states (list): decoder states + [L x (B, max_len, dec_dim)] + cache (dict): states cache + + Returns: + batch_y (torch.Tensor): decoder output (B, dec_dim) + batch_states (list): decoder states + [L x (B, max_len, dec_dim)] + lm_tokens (torch.Tensor): batch of token ids for LM (B) + + """ + final_batch = len(hyps) + + process = [] + done = [None for _ in range(final_batch)] + + for i, hyp in enumerate(hyps): + str_yseq = "".join(list(map(str, hyp.yseq))) + + if str_yseq in cache: + done[i] = cache[str_yseq] + else: + process.append((str_yseq, hyp.yseq, hyp.dec_state)) + + if process: + _tokens = pad_sequence([p[1] for p in process], self.blank) + batch_tokens = torch.LongTensor(_tokens, device=self.device) + + tgt_mask = ( + subsequent_mask(batch_tokens.size(-1)) + .unsqueeze_(0) + .expand(len(process), -1, -1) + ) + + dec_state = self.create_batch_states( + self.init_state(), + [p[2] for p in process], + _tokens, + ) + + tgt = self.embed(batch_tokens) + + next_state = [] + for s, decoder in zip(dec_state, self.decoders): + tgt, tgt_mask = decoder(tgt, tgt_mask, cache=s) + next_state.append(tgt) + + tgt = self.after_norm(tgt[:, -1]) + + j = 0 + for i in range(final_batch): + if done[i] is None: + new_state = self.select_state(next_state, j) + + done[i] = (tgt[j], new_state) + cache[process[j][0]] = (tgt[j], new_state) + + j += 1 + + self.create_batch_states( + batch_states, [d[1] for d in done], [[0] + h.yseq for h in hyps] + ) + batch_y = torch.stack([d[0] for d in done]) + + if use_lm: + lm_tokens = torch.LongTensor( + [hyp.yseq[-1] for hyp in hyps], device=self.device + ) + + return batch_y, batch_states, lm_tokens + + return batch_y, batch_states, None + + def select_state(self, batch_states, idx): + """Get decoder state from batch of states, for given id. 
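A rough construction sketch for the custom decoder: an embedding input layer followed by a repeated transformer block (import path from this diff; sizes are illustrative, and the architecture still has to pass the validation done in build_blocks):

from espnet.nets.pytorch_backend.transducer.custom_decoder import CustomDecoder

dec_arch = [
    {"type": "transformer", "d_hidden": 256, "d_ff": 1024, "heads": 4, "dropout-rate": 0.1},
]

decoder = CustomDecoder(
    odim=5000,            # vocabulary size, including the blank symbol
    dec_arch=dec_arch,
    input_layer="embed",  # token embedding + positional encoding
    repeat_block=2,
    blank=0,
)
# decoder.dlayers == 2 blocks, decoder.dunits == dimension of the last block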
+ + Args: + batch_states (list): batch of decoder states + [L x (B, max_len, dec_dim)] + idx (int): index to extract state from batch of states + + Returns: + state_idx (list): decoder states for given id + [L x (1, max_len, dec_dim)] + + """ + if batch_states[0] is None: + return batch_states + + state_idx = [batch_states[layer][idx] for layer in range(self.dlayers)] + + return state_idx + + def create_batch_states(self, batch_states, l_states, check_list): + """Create batch of decoder states. + + Args: + batch_states (list): batch of decoder states + [L x (B, max_len, dec_dim)] + l_states (list): list of decoder states + [B x [L x (1, max_len, dec_dim)]] + check_list (list): list of sequences for max_len + + Returns: + batch_states (list): batch of decoder states + [L x (B, max_len, dec_dim)] + + """ + if l_states[0][0] is None: + return batch_states + + max_len = max(len(elem) for elem in check_list) - 1 + + for layer in range(self.dlayers): + batch_states[layer] = check_batch_state( + [s[layer] for s in l_states], max_len, self.blank + ) + + return batch_states diff --git a/espnet/nets/pytorch_backend/transducer/custom_encoder.py b/espnet/nets/pytorch_backend/transducer/custom_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..6a024bf12e624533c24f134ee44408459cd53c26 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/custom_encoder.py @@ -0,0 +1,124 @@ +"""Cutom encoder definition for transducer models.""" + +import torch + +from espnet.nets.pytorch_backend.transducer.blocks import build_blocks +from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +class CustomEncoder(torch.nn.Module): + """Custom encoder module for transducer models. 
+ + Args: + idim (int): input dim + enc_arch (list): list of encoder blocks (type and parameters) + input_layer (str): input layer type + repeat_block (int): repeat provided block N times if N > 1 + self_attn_type (str): type of self-attention + positional_encoding_type (str): positional encoding type + positionwise_layer_type (str): linear + positionwise_activation_type (str): positionwise activation type + conv_mod_activation_type (str): convolutional module activation type + normalize_before (bool): whether to use layer_norm before the first block + aux_task_layer_list (list): list of layer ids for intermediate output + padding_idx (int): padding_idx for embedding input layer (if specified) + + """ + + def __init__( + self, + idim, + enc_arch, + input_layer="linear", + repeat_block=0, + self_attn_type="selfattn", + positional_encoding_type="abs_pos", + positionwise_layer_type="linear", + positionwise_activation_type="relu", + conv_mod_activation_type="relu", + normalize_before=True, + aux_task_layer_list=[], + padding_idx=-1, + ): + """Construct an CustomEncoder object.""" + super().__init__() + + ( + self.embed, + self.encoders, + self.enc_out, + self.conv_subsampling_factor, + ) = build_blocks( + "encoder", + idim, + input_layer, + enc_arch, + repeat_block=repeat_block, + self_attn_type=self_attn_type, + positional_encoding_type=positional_encoding_type, + positionwise_layer_type=positionwise_layer_type, + positionwise_activation_type=positionwise_activation_type, + conv_mod_activation_type=conv_mod_activation_type, + padding_idx=padding_idx, + ) + + self.normalize_before = normalize_before + + if self.normalize_before: + self.after_norm = LayerNorm(self.enc_out) + + self.n_blocks = len(enc_arch) * repeat_block + + self.aux_task_layer_list = aux_task_layer_list + + def forward(self, xs, masks): + """Encode input sequence. + + Args: + xs (torch.Tensor): input tensor + masks (torch.Tensor): input mask + + Returns: + xs (torch.Tensor or tuple): + position embedded output or + (position embedded output, auxiliary outputs) + mask (torch.Tensor): position embedded mask + + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + if self.aux_task_layer_list: + aux_xs_list = [] + + for b in range(self.n_blocks): + xs, masks = self.encoders[b](xs, masks) + + if b in self.aux_task_layer_list: + if isinstance(xs, tuple): + aux_xs = xs[0] + else: + aux_xs = xs + + if self.normalize_before: + aux_xs_list.append(self.after_norm(aux_xs)) + else: + aux_xs_list.append(aux_xs) + else: + xs, masks = self.encoders(xs, masks) + + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + + if self.aux_task_layer_list: + return (xs, aux_xs_list), masks + + return xs, masks diff --git a/espnet/nets/pytorch_backend/transducer/error_calculator.py b/espnet/nets/pytorch_backend/transducer/error_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..5989f1ea2f4cfd306845bbd27a484c5bdcfffaeb --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/error_calculator.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +"""CER/WER monitoring for transducer models.""" + +import editdistance + +from espnet.nets.beam_search_transducer import BeamSearchTransducer + + +class ErrorCalculator(object): + """Calculate CER and WER for transducer models. 
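And the encoder counterpart, again as an illustrative sketch only (path from this diff; whether these exact values pass the architecture checks in build_blocks is not guaranteed):

import torch
from espnet.nets.pytorch_backend.transducer.custom_encoder import CustomEncoder

enc_arch = [
    {"type": "transformer", "d_hidden": 144, "d_ff": 576, "heads": 4, "dropout-rate": 0.1},
]
encoder = CustomEncoder(idim=80, enc_arch=enc_arch, input_layer="conv2d", repeat_block=2)

xs = torch.randn(2, 100, 80)                      # (B, T, idim)
masks = torch.ones(2, 1, 100, dtype=torch.bool)   # (B, 1, T)
out, out_masks = encoder(xs, masks)               # out: (B, T', 144), T' ~ T / 4 after Conv2d subsampling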
+ + Args: + decoder (torch.nn.Module|TransducerDecoderInterface): decoder module + joint_network (torch.nn.Module): joint network module + token_list (list): list of tokens + sym_space (str): space symbol + sym_blank (str): blank symbol + report_cer (boolean): compute CER option + report_wer (boolean): compute WER option + + """ + + def __init__( + self, + decoder, + joint_network, + token_list, + sym_space, + sym_blank, + report_cer=False, + report_wer=False, + ): + """Construct an ErrorCalculator object for transducer model.""" + super().__init__() + + self.beam_search = BeamSearchTransducer( + decoder=decoder, + joint_network=joint_network, + beam_size=1, + ) + + self.decoder = decoder + + self.token_list = token_list + self.space = sym_space + self.blank = sym_blank + + self.report_cer = report_cer + self.report_wer = report_wer + + def __call__(self, hs_pad, ys_pad): + """Calculate sentence-level WER/CER score for transducer models. + + Args: + hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): sentence-level CER score + (float): sentence-level WER score + + """ + cer, wer = None, None + + batchsize = int(hs_pad.size(0)) + batch_nbest = [] + + hs_pad = hs_pad.to(next(self.decoder.parameters()).device) + + for b in range(batchsize): + nbest_hyps = self.beam_search(hs_pad[b]) + batch_nbest.append(nbest_hyps[-1]) + + ys_hat = [nbest_hyp.yseq[1:] for nbest_hyp in batch_nbest] + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu()) + + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + + return cer, wer + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + Args: + ys_hat (torch.Tensor): prediction (batch, seqlen) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (list): token list of prediction + (list): token list of reference + + """ + seqs_hat, seqs_true = [], [] + + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + + seq_hat = [self.token_list[int(idx)] for idx in y_hat] + seq_true = [self.token_list[int(idx)] for idx in y_true if int(idx) != -1] + + seq_hat_text = "".join(seq_hat).replace(self.space, " ") + seq_hat_text = seq_hat_text.replace(self.blank, "") + seq_true_text = "".join(seq_true).replace(self.space, " ") + + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score for transducer model. + + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level CER score + + """ + char_eds, char_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + + hyp_chars = seq_hat_text.replace(" ", "") + ref_chars = seq_true_text.replace(" ", "") + + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score for transducer model. 
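The CER computed above is simply the total character-level edit distance over space-stripped text, normalized by the total reference length across the batch. For example:

import editdistance

hyps = ["the cat sat", "hello word"]
refs = ["the cat sat", "hello world"]

char_eds = [
    editdistance.eval(h.replace(" ", ""), r.replace(" ", ""))
    for h, r in zip(hyps, refs)
]
char_lens = [len(r.replace(" ", "")) for r in refs]

cer = float(sum(char_eds)) / sum(char_lens)   # 1 edit / 19 reference chars ~ 0.053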
+ + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level WER score + + """ + word_eds, word_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + + return float(sum(word_eds)) / sum(word_ref_lens) diff --git a/espnet/nets/pytorch_backend/transducer/initializer.py b/espnet/nets/pytorch_backend/transducer/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..e218bb203e2a37985822cd4a524afc567029fcac --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/initializer.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +"""Parameter initialization for transducer model.""" + +import math + +from espnet.nets.pytorch_backend.initialization import set_forget_bias_to_one + + +def initializer(model, args): + """Initialize transducer model. + + Args: + model (torch.nn.Module): transducer instance + args (Namespace): argument Namespace containing options + + """ + for name, p in model.named_parameters(): + if any(x in name for x in ["enc.", "dec.", "joint_network"]): + # rnn based parts + joint network + if p.dim() == 1: + # bias + p.data.zero_() + elif p.dim() == 2: + # linear weight + n = p.size(1) + stdv = 1.0 / math.sqrt(n) + p.data.normal_(0, stdv) + elif p.dim() in (3, 4): + # conv weight + n = p.size(1) + for k in p.size()[2:]: + n *= k + stdv = 1.0 / math.sqrt(n) + p.data.normal_(0, stdv) + + if args.dtype != "custom": + model.dec.embed.weight.data.normal_(0, 1) + + for i in range(model.dec.dlayers): + set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_ih_l0")) + set_forget_bias_to_one(getattr(model.dec.decoder[i], "bias_hh_l0")) diff --git a/espnet/nets/pytorch_backend/transducer/joint_network.py b/espnet/nets/pytorch_backend/transducer/joint_network.py new file mode 100644 index 0000000000000000000000000000000000000000..d88f2ee7d559bcbf07536a3efc5f993930ee141e --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/joint_network.py @@ -0,0 +1,56 @@ +"""Transducer joint network implementation.""" + +import torch + +from espnet.nets.pytorch_backend.nets_utils import get_activation + + +class JointNetwork(torch.nn.Module): + """Transducer joint network module. + + Args: + joint_space_size: Dimension of joint space + joint_activation_type: Activation type for joint network + + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + decoder_output_size: int, + joint_space_size: int, + joint_activation_type: int, + ): + """Joint network initializer.""" + super().__init__() + + self.lin_enc = torch.nn.Linear(encoder_output_size, joint_space_size) + self.lin_dec = torch.nn.Linear( + decoder_output_size, joint_space_size, bias=False + ) + + self.lin_out = torch.nn.Linear(joint_space_size, vocab_size) + + self.joint_activation = get_activation(joint_activation_type) + + def forward( + self, h_enc: torch.Tensor, h_dec: torch.Tensor, is_aux: bool = False + ) -> torch.Tensor: + """Joint computation of z. 
+ + Args: + h_enc: Batch of expanded hidden state (B, T, 1, D_enc) + h_dec: Batch of expanded hidden state (B, 1, U, D_dec) + + Returns: + z: Output (B, T, U, vocab_size) + + """ + if is_aux: + z = self.joint_activation(h_enc + self.lin_dec(h_dec)) + else: + z = self.joint_activation(self.lin_enc(h_enc) + self.lin_dec(h_dec)) + z = self.lin_out(z) + + return z diff --git a/espnet/nets/pytorch_backend/transducer/loss.py b/espnet/nets/pytorch_backend/transducer/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..543049cc0ec87859c8e88a746a3cb6ebfbd02eec --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/loss.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 + +"""Transducer loss module.""" + +import torch + + +class TransLoss(torch.nn.Module): + """Transducer loss module. + + Args: + trans_type (str): type of transducer implementation to calculate loss. + blank_id (int): blank symbol id + """ + + def __init__(self, trans_type, blank_id): + """Construct an TransLoss object.""" + super().__init__() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if trans_type == "warp-transducer": + from warprnnt_pytorch import RNNTLoss + + self.trans_loss = RNNTLoss(blank=blank_id) + elif trans_type == "warp-rnnt": + if device.type == "cuda": + try: + from warp_rnnt import rnnt_loss + + self.trans_loss = rnnt_loss + except ImportError: + raise ImportError( + "warp-rnnt is not installed. Please re-setup" + " espnet or use 'warp-transducer'" + ) + else: + raise ValueError("warp-rnnt is not supported in CPU mode") + + self.trans_type = trans_type + self.blank_id = blank_id + + def forward(self, pred_pad, target, pred_len, target_len): + """Compute path-aware regularization transducer loss. + + Args: + pred_pad (torch.Tensor): Batch of predicted sequences + (batch, maxlen_in, maxlen_out+1, odim) + target (torch.Tensor): Batch of target sequences (batch, maxlen_out) + pred_len (torch.Tensor): batch of lengths of predicted sequences (batch) + target_len (torch.tensor): batch of lengths of target sequences (batch) + + Returns: + loss (torch.Tensor): transducer loss + + """ + dtype = pred_pad.dtype + if dtype != torch.float32: + # warp-transducer and warp-rnnt only support float32 + pred_pad = pred_pad.to(dtype=torch.float32) + + if self.trans_type == "warp-rnnt": + log_probs = torch.log_softmax(pred_pad, dim=-1) + + loss = self.trans_loss( + log_probs, + target, + pred_len, + target_len, + reduction="mean", + blank=self.blank_id, + gather=True, + ) + else: + loss = self.trans_loss(pred_pad, target, pred_len, target_len) + loss = loss.to(dtype=dtype) + + return loss diff --git a/espnet/nets/pytorch_backend/transducer/rnn_decoder.py b/espnet/nets/pytorch_backend/transducer/rnn_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..55ee4b8708d83664ed7e7a7c7e4baa81584316ee --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/rnn_decoder.py @@ -0,0 +1,292 @@ +"""RNN decoder for transducer-based models.""" + +import torch + +from espnet.nets.transducer_decoder_interface import TransducerDecoderInterface + + +class DecoderRNNT(TransducerDecoderInterface, torch.nn.Module): + """RNN-T Decoder module. 
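Shape-wise, the joint network broadcasts encoder frames against prediction-network states before projecting to the vocabulary. A small sketch (sizes illustrative; path as added in this diff):

import torch
from espnet.nets.pytorch_backend.transducer.joint_network import JointNetwork

joint = JointNetwork(
    vocab_size=500,
    encoder_output_size=256,
    decoder_output_size=256,
    joint_space_size=320,
    joint_activation_type="tanh",
)

h_enc = torch.randn(2, 50, 1, 256)   # (B, T, 1, D_enc)
h_dec = torch.randn(2, 1, 20, 256)   # (B, 1, U, D_dec)
z = joint(h_enc, h_dec)              # broadcast to (B, T, U, vocab_size) = (2, 50, 20, 500)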
+ + Args: + odim (int): dimension of outputs + dtype (str): gru or lstm + dlayers (int): # prediction layers + dunits (int): # prediction units + blank (int): blank symbol id + embed_dim (int): dimension of embeddings + dropout (float): dropout rate + dropout_embed (float): embedding dropout rate + + """ + + def __init__( + self, + odim, + dtype, + dlayers, + dunits, + blank, + embed_dim, + dropout=0.0, + dropout_embed=0.0, + ): + """Transducer initializer.""" + super().__init__() + + self.embed = torch.nn.Embedding(odim, embed_dim, padding_idx=blank) + self.dropout_embed = torch.nn.Dropout(p=dropout_embed) + + dec_net = torch.nn.LSTM if dtype == "lstm" else torch.nn.GRU + + self.decoder = torch.nn.ModuleList( + [dec_net(embed_dim, dunits, 1, batch_first=True)] + ) + self.dropout_dec = torch.nn.Dropout(p=dropout) + + for _ in range(1, dlayers): + self.decoder += [dec_net(dunits, dunits, 1, batch_first=True)] + + self.dlayers = dlayers + self.dunits = dunits + self.dtype = dtype + + self.odim = odim + + self.ignore_id = -1 + self.blank = blank + + self.multi_gpus = torch.cuda.device_count() > 1 + + def set_device(self, device): + """Set GPU device to use. + + Args: + device (torch.device): device id + + """ + self.device = device + + def set_data_type(self, data_type): + """Set GPU device to use. + + Args: + data_type (torch.dtype): Tensor data type + + """ + self.data_type = data_type + + def init_state(self, batch_size): + """Initialize decoder states. + + Args: + batch_size (int): Batch size + + Returns: + (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + + """ + h_n = torch.zeros( + self.dlayers, + batch_size, + self.dunits, + device=self.device, + dtype=self.data_type, + ) + + if self.dtype == "lstm": + c_n = torch.zeros( + self.dlayers, + batch_size, + self.dunits, + device=self.device, + dtype=self.data_type, + ) + + return (h_n, c_n) + + return (h_n, None) + + def rnn_forward(self, y, state): + """RNN forward. + + Args: + y (torch.Tensor): batch of input features (B, emb_dim) + state (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + + Returns: + y (torch.Tensor): batch of output features (B, dec_dim) + (tuple): batch of decoder states + (L, B, dec_dim), (L, B, dec_dim)) + + """ + h_prev, c_prev = state + h_next, c_next = self.init_state(y.size(0)) + + for layer in range(self.dlayers): + if self.dtype == "lstm": + y, ( + h_next[layer : layer + 1], + c_next[layer : layer + 1], + ) = self.decoder[layer]( + y, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]) + ) + else: + y, h_next[layer : layer + 1] = self.decoder[layer]( + y, hx=h_prev[layer : layer + 1] + ) + + y = self.dropout_dec(y) + + return y, (h_next, c_next) + + def forward(self, hs_pad, ys_in_pad): + """Forward function for transducer. + + Args: + hs_pad (torch.Tensor): + batch of padded hidden state sequences (B, Tmax, D) + ys_in_pad (torch.Tensor): + batch of padded character id sequence tensor (B, Lmax+1) + + Returns: + z (torch.Tensor): output (B, T, U, odim) + + """ + self.set_device(hs_pad.device) + self.set_data_type(hs_pad.dtype) + + state = self.init_state(hs_pad.size(0)) + eys = self.dropout_embed(self.embed(ys_in_pad)) + + h_dec, _ = self.rnn_forward(eys, state) + + return h_dec + + def score(self, hyp, cache): + """Forward one step. 
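A minimal training-mode sketch of this prediction network (sizes illustrative): note that in this forward the encoder output is only used to pick the device and dtype, while the labels drive the RNN.

import torch
from espnet.nets.pytorch_backend.transducer.rnn_decoder import DecoderRNNT

decoder = DecoderRNNT(
    odim=500, dtype="lstm", dlayers=2, dunits=320, blank=0,
    embed_dim=256, dropout=0.1, dropout_embed=0.1,
)

hs_pad = torch.randn(2, 50, 320)             # encoder output (B, Tmax, D)
ys_in_pad = torch.randint(1, 500, (2, 21))   # label sequences with the blank prepended
ys_in_pad[:, 0] = 0

h_dec = decoder(hs_pad, ys_in_pad)           # (B, Lmax + 1, dunits) = (2, 21, 320)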
+ + Args: + hyp (dataclass): hypothesis + cache (dict): states cache + + Returns: + y (torch.Tensor): decoder outputs (1, dec_dim) + state (tuple): decoder states + ((L, 1, dec_dim), (L, 1, dec_dim)), + (torch.Tensor): token id for LM (1,) + + """ + vy = torch.full((1, 1), hyp.yseq[-1], dtype=torch.long, device=self.device) + + str_yseq = "".join(list(map(str, hyp.yseq))) + + if str_yseq in cache: + y, state = cache[str_yseq] + else: + ey = self.embed(vy) + + y, state = self.rnn_forward(ey, hyp.dec_state) + cache[str_yseq] = (y, state) + + return y[0][0], state, vy[0] + + def batch_score(self, hyps, batch_states, cache, use_lm): + """Forward batch one step. + + Args: + hyps (list): batch of hypotheses + batch_states (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + cache (dict): states cache + use_lm (bool): whether a LM is used for decoding + + Returns: + batch_y (torch.Tensor): decoder output (B, dec_dim) + batch_states (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + lm_tokens (torch.Tensor): batch of token ids for LM (B) + + """ + final_batch = len(hyps) + + process = [] + done = [None] * final_batch + + for i, hyp in enumerate(hyps): + str_yseq = "".join(list(map(str, hyp.yseq))) + + if str_yseq in cache: + done[i] = cache[str_yseq] + else: + process.append((str_yseq, hyp.yseq[-1], hyp.dec_state)) + + if process: + tokens = torch.LongTensor([[p[1]] for p in process], device=self.device) + dec_state = self.create_batch_states( + self.init_state(tokens.size(0)), [p[2] for p in process] + ) + + ey = self.embed(tokens) + y, dec_state = self.rnn_forward(ey, dec_state) + + j = 0 + for i in range(final_batch): + if done[i] is None: + new_state = self.select_state(dec_state, j) + + done[i] = (y[j], new_state) + cache[process[j][0]] = (y[j], new_state) + + j += 1 + + batch_y = torch.cat([d[0] for d in done], dim=0) + batch_states = self.create_batch_states(batch_states, [d[1] for d in done]) + + if use_lm: + lm_tokens = torch.LongTensor([h.yseq[-1] for h in hyps], device=self.device) + + return batch_y, batch_states, lm_tokens + + return batch_y, batch_states, None + + def select_state(self, batch_states, idx): + """Get decoder state from batch of states, for given id. + + Args: + batch_states (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + idx (int): index to extract state from batch of states + + Returns: + (tuple): decoder states for given id + ((L, 1, dec_dim), (L, 1, dec_dim)) + + """ + return ( + batch_states[0][:, idx : idx + 1, :], + batch_states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None, + ) + + def create_batch_states(self, batch_states, l_states, l_tokens=None): + """Create batch of decoder states. + + Args: + batch_states (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + l_states (list): list of decoder states + [L x ((1, dec_dim), (1, dec_dim))] + + Returns: + batch_states (tuple): batch of decoder states + ((L, B, dec_dim), (L, B, dec_dim)) + + """ + return ( + torch.cat([s[0] for s in l_states], dim=1), + torch.cat([s[1] for s in l_states], dim=1) + if self.dtype == "lstm" + else None, + ) diff --git a/espnet/nets/pytorch_backend/transducer/rnn_encoder.py b/espnet/nets/pytorch_backend/transducer/rnn_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3751c89f305bea2940a0ddeed12eea8324fd62bf --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/rnn_encoder.py @@ -0,0 +1,535 @@ +"""RNN encoder implementation for transducer-based models. 
+ +These classes are based on the ones in espnet.nets.pytorch_backend.rnn.encoders, +and modified to output intermediate layers representation based on a list of +layers given as input. These additional outputs are intended to be used with +auxiliary tasks. +It should be noted that, here, RNN class rely on a stack of 1-layer LSTM instead +of a multi-layer LSTM for that purpose. + +""" + +import argparse +import logging +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from espnet.nets.e2e_asr_common import get_vgg2l_odim +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import to_device + + +class RNNP(torch.nn.Module): + """RNN with projection layer module. + + Args: + idim: Dimension of inputs + elayers: Dimension of encoder layers + cdim: Number of units (results in cdim * 2 if bidirectional) + hdim: Number of projection units + subsample: List of subsampling number + dropout: Dropout rate + typ: RNN type + aux_task_layer_list: List of layer ids for intermediate output + + """ + + def __init__( + self, + idim: int, + elayers: int, + cdim: int, + hdim: int, + subsample: np.ndarray, + dropout: float, + typ: str = "blstm", + aux_task_layer_list: List = [], + ): + """Initialize RNNP module.""" + super(RNNP, self).__init__() + + bidir = typ[0] == "b" + for i in range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = hdim + + RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU + rnn = RNN( + inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True + ) + + setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) + + if bidir: + setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim)) + else: + setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim)) + + self.elayers = elayers + self.cdim = cdim + self.subsample = subsample + self.typ = typ + self.bidir = bidir + self.dropout = dropout + + self.aux_task_layer_list = aux_task_layer_list + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_state: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, List], torch.Tensor]: + """RNNP forward. 
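The subsample array is read from index 1 onward, so subsample[i + 1] decimates the output of layer i. A sketch (illustrative sizes; input lengths must be sorted in decreasing order for pack_padded_sequence):

import numpy as np
import torch
from espnet.nets.pytorch_backend.transducer.rnn_encoder import RNNP

rnnp = RNNP(
    idim=80, elayers=3, cdim=320, hdim=320,
    subsample=np.array([1, 2, 2, 1]), dropout=0.1, typ="blstm",
)

xs = torch.randn(2, 100, 80)          # (B, Tmax, idim)
ilens = torch.tensor([100, 80])       # sorted descending
ys, olens, states = rnnp(xs, ilens)   # ys: (2, 25, 320), olens: [25, 20] after two 2x decimations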
+ + Args: + xs_pad: Batch of padded input sequences (B, Tmax, idim) + ilens: Batch of lengths of input sequences (B) + prev_state: Batch of previous RNN states + + Returns: + : Batch of padded output sequences (B, Tmax, hdim) + or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)]) + : Batch of lengths of output sequences (B) + : Batch of hidden state sequences (B, Tmax, hdim) + + """ + logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + + aux_xs_list = [] + elayer_states = [] + for layer in range(self.elayers): + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + + xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) + rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) + rnn.flatten_parameters() + + if prev_state is not None and rnn.bidirectional: + prev_state = reset_backward_rnn_state(prev_state) + + ys, states = rnn( + xs_pack, hx=None if prev_state is None else prev_state[layer] + ) + elayer_states.append(states) + + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + + sub = self.subsample[layer + 1] + if sub > 1: + ys_pad = ys_pad[:, ::sub] + ilens = torch.tensor([int(i + 1) // sub for i in ilens]) + + projection_layer = getattr(self, "bt%d" % layer) + projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2))) + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + + if layer in self.aux_task_layer_list: + aux_xs_list.append(xs_pad) + + if layer < self.elayers - 1: + xs_pad = torch.tanh(F.dropout(xs_pad, p=self.dropout)) + + if aux_xs_list: + return (xs_pad, aux_xs_list), ilens, elayer_states + else: + return xs_pad, ilens, elayer_states + + +class RNN(torch.nn.Module): + """RNN module. + + Args: + idim: Dimension of inputs + elayers: Number of encoder layers + cdim: Number of rnn units (resulted in cdim * 2 if bidirectional) + hdim: Number of final projection units + dropout: Dropout rate + typ: The RNN type + + """ + + def __init__( + self, + idim: int, + elayers: int, + cdim: int, + hdim: int, + dropout: float, + typ: str = "blstm", + aux_task_layer_list: List = [], + ): + """Initialize RNN module.""" + super(RNN, self).__init__() + + bidir = typ[0] == "b" + + for i in range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = cdim + + layer_type = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU + rnn = layer_type( + inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True + ) + + setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) + + self.dropout = torch.nn.Dropout(p=dropout) + + self.elayers = elayers + self.cdim = cdim + self.hdim = hdim + self.typ = typ + self.bidir = bidir + + self.l_last = torch.nn.Linear(cdim, hdim) + + self.aux_task_layer_list = aux_task_layer_list + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_state: Optional[torch.Tensor] = None, + ) -> Union[Tuple[torch.Tensor, List], torch.Tensor]: + """RNN forward. 
+ + Args: + xs_pad: Batch of padded input sequences (B, Tmax, idim) + ilens: Batch of lengths of input sequences (B) + prev_state: Batch of previous RNN states + + Returns: + : Batch of padded output sequences (B, Tmax, hdim) + or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)]) + : Batch of lengths of output sequences (B) + : Batch of hidden state sequences (B, Tmax, hdim) + + """ + logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + + aux_xs_list = [] + elayer_states = [] + for layer in range(self.elayers): + if not isinstance(ilens, torch.Tensor): + ilens = torch.tensor(ilens) + + xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True) + + rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) + rnn.flatten_parameters() + + if prev_state is not None and rnn.bidirectional: + prev_state = reset_backward_rnn_state(prev_state) + + xs, states = rnn( + xs_pack, hx=None if prev_state is None else prev_state[layer] + ) + elayer_states.append(states) + + xs_pad, ilens = pad_packed_sequence(xs, batch_first=True) + + if self.bidir: + xs_pad = xs_pad[:, :, : self.cdim] + xs_pad[:, :, self.cdim :] + + if layer in self.aux_task_layer_list: + aux_projected = torch.tanh( + self.l_last(xs_pad.contiguous().view(-1, xs_pad.size(2))) + ) + aux_xs_pad = aux_projected.view(xs_pad.size(0), xs_pad.size(1), -1) + + aux_xs_list.append(aux_xs_pad) + + if layer < self.elayers - 1: + xs_pad = self.dropout(xs_pad) + + projected = torch.tanh( + self.l_last(xs_pad.contiguous().view(-1, xs_pad.size(2))) + ) + xs_pad = projected.view(xs_pad.size(0), xs_pad.size(1), -1) + + if aux_xs_list: + return (xs_pad, aux_xs_list), ilens, elayer_states + else: + return xs_pad, ilens, elayer_states + + +def reset_backward_rnn_state( + states: Union[torch.Tensor, Tuple, List] +) -> Union[torch.Tensor, Tuple, List]: + """Set backward BRNN states to zeroes. + + Args: + states: RNN states + + Returns: + states: RNN states with backward set to zeroes + + """ + if isinstance(states, (list, tuple)): + for state in states: + state[1::2] = 0.0 + else: + states[1::2] = 0.0 + return states + + +class VGG2L(torch.nn.Module): + """VGG-like module. + + Args: + in_channel: number of input channels + + """ + + def __init__(self, in_channel: int = 1): + """Initialize VGG-like module.""" + super(VGG2L, self).__init__() + + # CNN layer (VGG motivated) + self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1) + self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1) + self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1) + self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1) + + self.in_channel = in_channel + + def forward(self, xs_pad: torch.Tensor, ilens: torch.Tensor, **kwargs): + """VGG2L forward. 
+ + Args: + xs_pad: Batch of padded input sequences (B, Tmax, D) + ilens: Batch of lengths of input sequences (B) + + Returns: + : Batch of padded output sequences (B, Tmax // 4, 128 * D // 4) + : Batch of lengths of output sequences (B) + + """ + logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens)) + + xs_pad = xs_pad.view( + xs_pad.size(0), + xs_pad.size(1), + self.in_channel, + xs_pad.size(2) // self.in_channel, + ).transpose(1, 2) + + xs_pad = F.relu(self.conv1_1(xs_pad)) + xs_pad = F.relu(self.conv1_2(xs_pad)) + xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) + + xs_pad = F.relu(self.conv2_1(xs_pad)) + xs_pad = F.relu(self.conv2_2(xs_pad)) + xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True) + + if torch.is_tensor(ilens): + ilens = ilens.cpu().numpy() + else: + ilens = np.array(ilens, dtype=np.float32) + ilens = np.array(np.ceil(ilens / 2), dtype=np.int64) + ilens = np.array( + np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64 + ).tolist() + + xs_pad = xs_pad.transpose(1, 2) + xs_pad = xs_pad.contiguous().view( + xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3) + ) + + return xs_pad, ilens, None + + +class Encoder(torch.nn.Module): + """Encoder module. + + Args: + etype: Type of encoder network + idim: Number of dimensions of encoder network + elayers: Number of layers of encoder network + eunits: Number of RNN units of encoder network + eprojs: Number of projection units of encoder network + subsample: List of subsampling numbers + dropout: Dropout rate + in_channel: Number of input channels + + """ + + def __init__( + self, + etype: str, + idim: int, + elayers: int, + eunits: int, + eprojs: int, + subsample: np.ndarray, + dropout: float, + in_channel: int = 1, + aux_task_layer_list: List = [], + ): + """Initialize Encoder module.""" + super(Encoder, self).__init__() + + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ["lstm", "gru", "blstm", "bgru"]: + logging.error("Error: need to specify an appropriate encoder architecture") + + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNNP( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + subsample, + dropout, + typ=typ, + aux_task_layer_list=aux_task_layer_list, + ), + ] + ) + logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder") + else: + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNN( + get_vgg2l_odim(idim, in_channel=in_channel), + elayers, + eunits, + eprojs, + dropout, + typ=typ, + aux_task_layer_list=aux_task_layer_list, + ), + ] + ) + logging.info("Use CNN-VGG + " + typ.upper() + " for encoder") + self.conv_subsampling_factor = 4 + else: + if etype[-1] == "p": + self.enc = torch.nn.ModuleList( + [ + RNNP( + idim, + elayers, + eunits, + eprojs, + subsample, + dropout, + typ=typ, + aux_task_layer_list=aux_task_layer_list, + ) + ] + ) + logging.info(typ.upper() + " with every-layer projection for encoder") + else: + self.enc = torch.nn.ModuleList( + [ + RNN( + idim, + elayers, + eunits, + eprojs, + dropout, + typ=typ, + aux_task_layer_list=aux_task_layer_list, + ) + ] + ) + logging.info(typ.upper() + " without projection for encoder") + self.conv_subsampling_factor = 1 + + def forward(self, xs_pad, ilens, prev_states=None): + """Forward encoder. + + Args: + xs_pad: Batch of padded input sequences (B, Tmax, idim) + ilens: Batch of lengths of input sequences (B) + prev_state: Batch of previous encoder hidden states (B, ??) 
+ + Returns: + : Batch of padded output sequences (B, Tmax, hdim) + or tuple w/ aux outputs ((B, Tmax, hdim), [L x (B, Tmax, hdim)]) + : Batch of lengths of output sequences (B) + : Batch of hidden state sequences (B, Tmax, hdim) + + """ + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module( + xs_pad, + ilens, + prev_state=prev_state, + ) + current_states.append(states) + + if isinstance(xs_pad, tuple): + final_xs_pad, aux_xs_list = xs_pad[0], xs_pad[1] + + mask = to_device(final_xs_pad, make_pad_mask(ilens).unsqueeze(-1)) + + aux_xs_list = [layer.masked_fill(mask, 0.0) for layer in aux_xs_list] + + return ( + ( + final_xs_pad.masked_fill(mask, 0.0), + aux_xs_list, + ), + ilens, + current_states, + ) + else: + mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1)) + + return xs_pad.masked_fill(mask, 0.0), ilens, current_states + + +def encoder_for( + args: argparse.Namespace, + idim: Union[int, List], + subsample: np.ndarray, + aux_task_layer_list: List = [], +) -> Union[torch.nn.Module, List[torch.nn.Module]]: + """Instantiate an encoder module given the program arguments. + + Args: + args: The model arguments + idim: Dimension of inputs or list of dimensions of inputs for each encoder + subsample: subsample factors or list of subsample factors for each encoder + + Returns: + : The encoder module or list of encoder modules + + """ + return Encoder( + args.etype, + idim, + args.elayers, + args.eunits, + args.eprojs, + subsample, + args.dropout_rate, + aux_task_layer_list=aux_task_layer_list, + ) diff --git a/espnet/nets/pytorch_backend/transducer/tdnn.py b/espnet/nets/pytorch_backend/transducer/tdnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1041a81f85ea212f4de147aaae65341c31ea2a9d --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/tdnn.py @@ -0,0 +1,163 @@ +"""TDNN modules definition for transformer encoder.""" + +import logging +from typing import Tuple +from typing import Union + +import torch + + +class TDNN(torch.nn.Module): + """TDNN implementation with symmetric context. + + Args: + idim: Dimension of inputs + odim: Dimension of outputs + ctx_size: Size of context window + stride: Stride of the sliding blocks + dilation: Parameter to control the stride of + elements within the neighborhood + batch_norm: Whether to use batch normalization + relu: Whether to use non-linearity layer (ReLU) + + """ + + def __init__( + self, + idim: int, + odim: int, + ctx_size: int = 5, + dilation: int = 1, + stride: int = 1, + batch_norm: bool = False, + relu: bool = True, + dropout_rate: float = 0.0, + ): + """Construct a TDNN object.""" + super().__init__() + + self.idim = idim + self.odim = odim + + self.ctx_size = ctx_size + self.stride = stride + self.dilation = dilation + + self.batch_norm = batch_norm + self.relu = relu + + self.tdnn = torch.nn.Conv1d( + idim, odim, ctx_size, stride=stride, dilation=dilation + ) + + if self.relu: + self.relu_func = torch.nn.ReLU() + + if self.batch_norm: + self.bn = torch.nn.BatchNorm1d(odim) + + self.dropout = torch.nn.Dropout(p=dropout_rate) + + def forward( + self, + x_input: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + masks: torch.Tensor, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]: + """Forward TDNN. 
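Context-size arithmetic for the TDNN block: with stride 1, the time axis shrinks by (ctx_size - 1) * dilation, and the mask is trimmed and strided to match. A sketch (path as added in this diff):

import torch
from espnet.nets.pytorch_backend.transducer.tdnn import TDNN

tdnn = TDNN(idim=256, odim=256, ctx_size=5, dilation=1, stride=1)

xs = torch.randn(2, 100, 256)                    # (B, T, idim)
masks = torch.ones(2, 1, 100, dtype=torch.bool)  # (B, 1, T)

ys, new_masks = tdnn(xs, masks)                  # ys: (2, 96, 256), new_masks: (2, 1, 96)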
+ + Args: + x_input: Input tensor (B, T, idim) or ((B, T, idim), (B, T, att_dim)) + or ((B, T, idim), (B, 2*T-1, att_dim)) + masks: Input mask (B, 1, T) + + Returns: + x_output: Output tensor (B, sub(T), odim) + or ((B, sub(T), odim), (B, sub(T), att_dim)) + mask: Output mask (B, 1, sub(T)) + + """ + if isinstance(x_input, tuple): + xs, pos_emb = x_input[0], x_input[1] + else: + xs, pos_emb = x_input, None + + # The bidirect_pos is used to distinguish legacy_rel_pos and rel_pos in + # Conformer model. Note the `legacy_rel_pos` will be deprecated in the future. + # Details can be found in https://github.com/espnet/espnet/pull/2816. + if pos_emb is not None and pos_emb.size(1) == 2 * xs.size(1) - 1: + logging.warning("Using bidirectional relative postitional encoding.") + bidirect_pos = True + else: + bidirect_pos = False + + xs = xs.transpose(1, 2) + xs = self.tdnn(xs) + + if self.relu: + xs = self.relu_func(xs) + + xs = self.dropout(xs) + + if self.batch_norm: + xs = self.bn(xs) + + xs = xs.transpose(1, 2) + + return self.create_outputs(xs, pos_emb, masks, bidirect_pos=bidirect_pos) + + def create_outputs( + self, + xs: torch.Tensor, + pos_emb: torch.Tensor, + masks: torch.Tensor, + bidirect_pos: bool = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]: + """Create outputs with subsampled version of pos_emb and masks. + + Args: + xs: Output tensor (B, sub(T), odim) + pos_emb: Input positional embedding tensor (B, T, att_dim) + or (B, 2*T-1, att_dim) + masks: Input mask (B, 1, T) + bidirect_pos: whether to use bidirectional positional embedding + + Returns: + xs: Output tensor (B, sub(T), odim) + pos_emb: Output positional embedding tensor (B, sub(T), att_dim) + or (B, 2*sub(T)-1, att_dim) + masks: Output mask (B, 1, sub(T)) + + """ + sub = (self.ctx_size - 1) * self.dilation + + if masks is not None: + if sub != 0: + masks = masks[:, :, :-sub] + + masks = masks[:, :, :: self.stride] + + if pos_emb is not None: + # If the bidirect_pos is true, the pos_emb will include both positive and + # negative embeddings. Refer to https://github.com/espnet/espnet/pull/2816. + if bidirect_pos: + pos_emb_positive = pos_emb[:, : pos_emb.size(1) // 2 + 1, :] + pos_emb_negative = pos_emb[:, pos_emb.size(1) // 2 :, :] + + if sub != 0: + pos_emb_positive = pos_emb_positive[:, :-sub, :] + pos_emb_negative = pos_emb_negative[:, :-sub, :] + + pos_emb_positive = pos_emb_positive[:, :: self.stride, :] + pos_emb_negative = pos_emb_negative[:, :: self.stride, :] + pos_emb = torch.cat( + [pos_emb_positive, pos_emb_negative[:, 1:, :]], dim=1 + ) + else: + if sub != 0: + pos_emb = pos_emb[:, :-sub, :] + + pos_emb = pos_emb[:, :: self.stride, :] + + return (xs, pos_emb), masks + + return xs, masks diff --git a/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py b/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..e84070c65e9ff6c4bbe60270da6130a5a78ad709 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/transformer_decoder_layer.py @@ -0,0 +1,75 @@ +"""Decoder layer definition for transformer-transducer models.""" + +import torch +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class DecoderLayer(nn.Module): + """Single decoder layer module for transformer-transducer models. 
+ + Args: + size (int): input dim + self_attn (MultiHeadedAttention): self attention module + feed_forward (PositionwiseFeedForward): feed forward layer module + dropout_rate (float): dropout rate + normalize_before (bool): whether to use layer_norm before the first block + + """ + + def __init__(self, size, self_attn, feed_forward, dropout_rate): + """Construct an DecoderLayer object.""" + super().__init__() + + self.self_attn = self_attn + self.feed_forward = feed_forward + + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + + self.dropout = nn.Dropout(dropout_rate) + + self.size = size + + def forward(self, tgt, tgt_mask, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): decoded previous target features (B, Lmax, idim) + tgt_mask (torch.Tensor): mask for tgt (B, Lmax) + cache (torch.Tensor): cached output (B, Lmax-1, idim) + + Returns: + tgt (torch.Tensor): decoder target features (B, Lmax, odim) + tgt_mask (torch.Tensor): mask for tgt (B, Lmax) + """ + residual = tgt + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + else: + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + + if tgt_mask is not None: + tgt_mask = tgt_mask[:, -1:, :] + + tgt = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_mask)) + + residual = tgt + tgt = self.norm2(tgt) + + tgt = residual + self.dropout(self.feed_forward(tgt)) + + if cache is not None: + tgt = torch.cat([cache, tgt], dim=1) + + return tgt, tgt_mask diff --git a/espnet/nets/pytorch_backend/transducer/utils.py b/espnet/nets/pytorch_backend/transducer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..077eb1896dbbc022e7f238c33adbf05b541cb370 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/utils.py @@ -0,0 +1,340 @@ +"""Utility functions for transducer models.""" + +import os + +import numpy as np +import torch + +from espnet.nets.pytorch_backend.nets_utils import pad_list + + +def prepare_loss_inputs(ys_pad, hlens, blank_id=0, ignore_id=-1): + """Prepare tensors for transducer loss computation. 
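The cache path above recomputes only the newest position and concatenates it back onto the cached outputs. A rough sketch of the incremental call, assuming the standard espnet transformer attention, feed-forward, and mask helpers:

import torch
from espnet.nets.pytorch_backend.transducer.transformer_decoder_layer import DecoderLayer
from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import PositionwiseFeedForward
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask

layer = DecoderLayer(
    256,
    MultiHeadedAttention(4, 256, 0.0),
    PositionwiseFeedForward(256, 1024, 0.0),
    dropout_rate=0.0,
)

tgt = torch.randn(1, 6, 256)                # six decoded positions so far
tgt_mask = subsequent_mask(6).unsqueeze(0)  # (1, 6, 6) causal mask

full, _ = layer(tgt, tgt_mask)                           # recompute all positions
step, _ = layer(tgt, tgt_mask, cache=full[:, :-1, :])    # only the last position is recomputed
# step has shape (1, 6, 256); rows 0-4 are taken verbatim from the cache.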
+ + Args: + ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax) + hlens (torch.Tensor): batch of hidden sequence lengthts (B) + or batch of masks (B, 1, Tmax) + blank_id (int): index of blank label + ignore_id (int): index of initial padding + + Returns: + ys_in_pad (torch.Tensor): batch of padded target sequences + blank (B, Lmax + 1) + target (torch.Tensor): batch of padded target sequences (B, Lmax) + pred_len (torch.Tensor): batch of hidden sequence lengths (B) + target_len (torch.Tensor): batch of output sequence lengths (B) + + """ + device = ys_pad.device + + ys = [y[y != ignore_id] for y in ys_pad] + blank = ys[0].new([blank_id]) + + ys_in_pad = pad_list([torch.cat([blank, y], dim=0) for y in ys], blank_id) + ys_out_pad = pad_list([torch.cat([y, blank], dim=0) for y in ys], ignore_id) + + target = pad_list(ys, blank_id).type(torch.int32).to(device) + target_len = torch.IntTensor([y.size(0) for y in ys]).to(device) + + if torch.is_tensor(hlens): + if hlens.dim() > 1: + hs = [h[h != 0] for h in hlens] + hlens = list(map(int, [h.size(0) for h in hs])) + else: + hlens = list(map(int, hlens)) + + pred_len = torch.IntTensor(hlens).to(device) + + return ys_in_pad, ys_out_pad, target, pred_len, target_len + + +def valid_aux_task_layer_list(aux_layer_ids, enc_num_layers): + """Check whether input list of auxiliary layer ids is valid. + + Return the valid list sorted with duplicated removed. + + Args: + aux_layer_ids (list): Auxiliary layers ids + enc_num_layers (int): Number of encoder layers + + Returns: + valid (list): Validated list of layers for auxiliary task + + """ + if ( + not isinstance(aux_layer_ids, list) + or not aux_layer_ids + or not all(isinstance(layer, int) for layer in aux_layer_ids) + ): + raise ValueError("--aux-task-layer-list argument takes a list of layer ids.") + + sorted_list = sorted(aux_layer_ids, key=int, reverse=False) + valid = list(filter(lambda x: 0 <= x < enc_num_layers, sorted_list)) + + if sorted_list != valid: + raise ValueError( + "Provided list of layer ids for auxiliary task is incorrect. " + "IDs should be between [0, %d]" % (enc_num_layers - 1) + ) + + return valid + + +def is_prefix(x, pref): + """Check prefix. + + Args: + x (list): token id sequence + pref (list): token id sequence + + Returns: + (boolean): whether pref is a prefix of x. + + """ + if len(pref) >= len(x): + return False + + for i in range(len(pref)): + if pref[i] != x[i]: + return False + + return True + + +def substract(x, subset): + """Remove elements of subset if corresponding token id sequence exist in x. + + Args: + x (list): set of hypotheses + subset (list): subset of hypotheses + + Returns: + final (list): new set + + """ + final = [] + + for x_ in x: + if any(x_.yseq == sub.yseq for sub in subset): + continue + final.append(x_) + + return final + + +def select_lm_state(lm_states, idx, lm_layers, is_wordlm): + """Get LM state from batch for given id. + + Args: + lm_states (list or dict): batch of LM states + idx (int): index to extract state from batch state + lm_layers (int): number of LM layers + is_wordlm (bool): whether provided LM is a word-LM + + Returns: + idx_state (dict): LM state for given id + + """ + if is_wordlm: + idx_state = lm_states[idx] + else: + idx_state = {} + + idx_state["c"] = [lm_states["c"][layer][idx] for layer in range(lm_layers)] + idx_state["h"] = [lm_states["h"][layer][idx] for layer in range(lm_layers)] + + return idx_state + + +def create_lm_batch_state(lm_states_list, lm_layers, is_wordlm): + """Create batch of LM states. 
+ + Args: + lm_states (list or dict): list of individual LM states + lm_layers (int): number of LM layers + is_wordlm (bool): whether provided LM is a word-LM + + Returns: + batch_states (list): batch of LM states + + """ + if is_wordlm: + batch_states = lm_states_list + else: + batch_states = {} + + batch_states["c"] = [ + torch.stack([state["c"][layer] for state in lm_states_list]) + for layer in range(lm_layers) + ] + batch_states["h"] = [ + torch.stack([state["h"][layer] for state in lm_states_list]) + for layer in range(lm_layers) + ] + + return batch_states + + +def init_lm_state(lm_model): + """Initialize LM state. + + Args: + lm_model (torch.nn.Module): LM module + + Returns: + lm_state (dict): initial LM state + + """ + lm_layers = len(lm_model.rnn) + lm_units_typ = lm_model.typ + lm_units = lm_model.n_units + + p = next(lm_model.parameters()) + + h = [ + torch.zeros(lm_units).to(device=p.device, dtype=p.dtype) + for _ in range(lm_layers) + ] + + lm_state = {"h": h} + + if lm_units_typ == "lstm": + lm_state["c"] = [ + torch.zeros(lm_units).to(device=p.device, dtype=p.dtype) + for _ in range(lm_layers) + ] + + return lm_state + + +def recombine_hyps(hyps): + """Recombine hypotheses with equivalent output sequence. + + Args: + hyps (list): list of hypotheses + + Returns: + final (list): list of recombined hypotheses + + """ + final = [] + + for hyp in hyps: + seq_final = [f.yseq for f in final if f.yseq] + + if hyp.yseq in seq_final: + seq_pos = seq_final.index(hyp.yseq) + + final[seq_pos].score = np.logaddexp(final[seq_pos].score, hyp.score) + else: + final.append(hyp) + + return hyps + + +def pad_sequence(seqlist, pad_token): + """Left pad list of token id sequences. + + Args: + seqlist (list): list of token id sequences + pad_token (int): padding token id + + Returns: + final (list): list of padded token id sequences + + """ + maxlen = max(len(x) for x in seqlist) + + final = [([pad_token] * (maxlen - len(x))) + x for x in seqlist] + + return final + + +def check_state(state, max_len, pad_token): + """Check state and left pad or trim if necessary. + + Args: + state (list): list of of L decoder states (in_len, dec_dim) + max_len (int): maximum length authorized + pad_token (int): padding token id + + Returns: + final (list): list of L padded decoder states (1, max_len, dec_dim) + + """ + if state is None or max_len < 1 or state[0].size(1) == max_len: + return state + + curr_len = state[0].size(1) + + if curr_len > max_len: + trim_val = int(state[0].size(1) - max_len) + + for i, s in enumerate(state): + state[i] = s[:, trim_val:, :] + else: + layers = len(state) + ddim = state[0].size(2) + + final_dims = (1, max_len, ddim) + final = [state[0].data.new(*final_dims).fill_(pad_token) for _ in range(layers)] + + for i, s in enumerate(state): + final[i][:, (max_len - s.size(1)) : max_len, :] = s + + return final + + return state + + +def check_batch_state(state, max_len, pad_token): + """Check batch of states and left pad or trim if necessary. 
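Note that this pad_sequence pads on the left (unlike torch.nn.utils.rnn.pad_sequence), so the most recent labels of each hypothesis stay aligned at the end. For example:

from espnet.nets.pytorch_backend.transducer.utils import pad_sequence

prefixes = [[0, 7, 21], [0, 7], [0]]
print(pad_sequence(prefixes, pad_token=0))
# [[0, 7, 21], [0, 0, 7], [0, 0, 0]]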
+ + Args: + state (list): list of of L decoder states (B, ?, dec_dim) + max_len (int): maximum length authorized + pad_token (int): padding token id + + Returns: + final (list): list of L decoder states (B, pred_len, dec_dim) + + """ + final_dims = (len(state), max_len, state[0].size(1)) + final = state[0].data.new(*final_dims).fill_(pad_token) + + for i, s in enumerate(state): + curr_len = s.size(0) + + if curr_len < max_len: + final[i, (max_len - curr_len) : max_len, :] = s + else: + final[i, :, :] = s[(curr_len - max_len) :, :] + + return final + + +def custom_torch_load(model_path, model, training=True): + """Load transducer model modules and parameters with training-only ones removed. + + Args: + model_path (str): Model path + model (torch.nn.Module): The model with pretrained modules + + """ + if "snapshot" in os.path.basename(model_path): + model_state_dict = torch.load( + model_path, map_location=lambda storage, loc: storage + )["model"] + else: + model_state_dict = torch.load( + model_path, map_location=lambda storage, loc: storage + ) + + if not training: + model_state_dict = { + k: v for k, v in model_state_dict.items() if not k.startswith("aux") + } + + model.load_state_dict(model_state_dict) + + del model_state_dict diff --git a/espnet/nets/pytorch_backend/transducer/vgg2l.py b/espnet/nets/pytorch_backend/transducer/vgg2l.py new file mode 100644 index 0000000000000000000000000000000000000000..18aeafb0f32c1feea7f38c28645ecac2d461b0e5 --- /dev/null +++ b/espnet/nets/pytorch_backend/transducer/vgg2l.py @@ -0,0 +1,89 @@ +"""VGG2L module definition for transformer encoder.""" + +from typing import Tuple +from typing import Union + +import torch + + +class VGG2L(torch.nn.Module): + """VGG2L module for custom encoder. + + Args: + idim: Dimension of inputs + odim: Dimension of outputs + pos_enc: Positional encoding class + + """ + + def __init__(self, idim: int, odim: int, pos_enc: torch.nn.Module = None): + """Construct a VGG2L object.""" + super().__init__() + + self.vgg2l = torch.nn.Sequential( + torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((3, 2)), + torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + if pos_enc is not None: + self.output = torch.nn.Sequential( + torch.nn.Linear(128 * ((idim // 2) // 2), odim), pos_enc + ) + else: + self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor + ) -> Union[ + Tuple[torch.Tensor, torch.Tensor], + Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor], + ]: + """VGG2L forward for x. + + Args: + x: Input tensor (B, T, idim) + x_mask: Input mask (B, 1, T) + + Returns: + x: Output tensor (B, sub(T), odim) + or ((B, sub(T), odim), (B, sub(T), att_dim)) + x_mask: Output mask (B, 1, sub(T)) + + """ + x = x.unsqueeze(1) + x = self.vgg2l(x) + + b, c, t, f = x.size() + + x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if x_mask is not None: + x_mask = self.create_new_mask(x_mask) + + return x, x_mask + + def create_new_mask(self, x_mask: torch.Tensor) -> torch.Tensor: + """Create a subsampled version of x_mask. 
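        The mask is decimated the same way the two max-pooling layers decimate
        time, first by a factor of 3 and then by 2. A quick sketch (lengths are
        illustrative only)::

            import torch
            from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L

            vgg = VGG2L(idim=80, odim=256)
            mask = torch.ones(1, 1, 30, dtype=torch.bool)
            assert vgg.create_new_mask(mask).shape == (1, 1, 5)   # 30 -> 10 -> 5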
+ + Args: + x_mask: Input mask (B, 1, T) + + Returns: + x_mask: Output mask (B, 1, sub(T)) + + """ + x_t1 = x_mask.size(2) - (x_mask.size(2) % 3) + x_mask = x_mask[:, :, :x_t1][:, :, ::3] + + x_t2 = x_mask.size(2) - (x_mask.size(2) % 2) + x_mask = x_mask[:, :, :x_t2][:, :, ::2] + + return x_mask diff --git a/espnet/nets/pytorch_backend/transformer/__init__.py b/espnet/nets/pytorch_backend/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/pytorch_backend/transformer/add_sos_eos.py b/espnet/nets/pytorch_backend/transformer/add_sos_eos.py new file mode 100644 index 0000000000000000000000000000000000000000..1f763bc97dba0122f13794a6c5dc738b4ff0d825 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/add_sos_eos.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Unility funcitons for Transformer.""" + +import torch + + +def add_sos_eos(ys_pad, sos, eos, ignore_id): + """Add and labels. + + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :param int sos: index of + :param int eos: index of + :param int ignore_id: index of padding + :return: padded tensor (B, Lmax) + :rtype: torch.Tensor + :return: padded tensor (B, Lmax) + :rtype: torch.Tensor + """ + from espnet.nets.pytorch_backend.nets_utils import pad_list + + _sos = ys_pad.new([sos]) + _eos = ys_pad.new([eos]) + ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys + ys_in = [torch.cat([_sos, y], dim=0) for y in ys] + ys_out = [torch.cat([y, _eos], dim=0) for y in ys] + return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) diff --git a/espnet/nets/pytorch_backend/transformer/argument.py b/espnet/nets/pytorch_backend/transformer/argument.py new file mode 100644 index 0000000000000000000000000000000000000000..216a68d90c38f168fc34634427bb5e9ead044ad7 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/argument.py @@ -0,0 +1,159 @@ +# Copyright 2020 Hirofumi Inaguma +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Transformer common arguments.""" + + +from distutils.util import strtobool + + +def add_arguments_transformer_common(group): + """Add Transformer common arguments.""" + group.add_argument( + "--transformer-init", + type=str, + default="pytorch", + choices=[ + "pytorch", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + ], + help="how to initialize transformer parameters", + ) + group.add_argument( + "--transformer-input-layer", + type=str, + default="conv2d", + choices=["conv2d", "linear", "embed"], + help="transformer input layer type", + ) + group.add_argument( + "--transformer-attn-dropout-rate", + default=None, + type=float, + help="dropout in transformer attention. 
use --dropout-rate if None is set", + ) + group.add_argument( + "--transformer-lr", + default=10.0, + type=float, + help="Initial value of learning rate", + ) + group.add_argument( + "--transformer-warmup-steps", + default=25000, + type=int, + help="optimizer warmup steps", + ) + group.add_argument( + "--transformer-length-normalized-loss", + default=True, + type=strtobool, + help="normalize loss by length", + ) + group.add_argument( + "--transformer-encoder-selfattn-layer-type", + type=str, + default="selfattn", + choices=[ + "selfattn", + "rel_selfattn", + "lightconv", + "lightconv2d", + "dynamicconv", + "dynamicconv2d", + "light-dynamicconv2d", + ], + help="transformer encoder self-attention layer type", + ) + group.add_argument( + "--transformer-decoder-selfattn-layer-type", + type=str, + default="selfattn", + choices=[ + "selfattn", + "lightconv", + "lightconv2d", + "dynamicconv", + "dynamicconv2d", + "light-dynamicconv2d", + ], + help="transformer decoder self-attention layer type", + ) + # Lightweight/Dynamic convolution related parameters. + # See https://arxiv.org/abs/1912.11793v2 + # and https://arxiv.org/abs/1901.10430 for detail of the method. + # Configurations used in the first paper are in + # egs/{csj, librispeech}/asr1/conf/tuning/ld_conv/ + group.add_argument( + "--wshare", + default=4, + type=int, + help="Number of parameter shargin for lightweight convolution", + ) + group.add_argument( + "--ldconv-encoder-kernel-length", + default="21_23_25_27_29_31_33_35_37_39_41_43", + type=str, + help="kernel size for lightweight/dynamic convolution: " + 'Encoder side. For example, "21_23_25" means kernel length 21 for ' + "First layer, 23 for Second layer and so on.", + ) + group.add_argument( + "--ldconv-decoder-kernel-length", + default="11_13_15_17_19_21", + type=str, + help="kernel size for lightweight/dynamic convolution: " + 'Decoder side. 
For example, "21_23_25" means kernel length 21 for ' + "First layer, 23 for Second layer and so on.", + ) + group.add_argument( + "--ldconv-usebias", + type=strtobool, + default=False, + help="use bias term in lightweight/dynamic convolution", + ) + group.add_argument( + "--dropout-rate", + default=0.0, + type=float, + help="Dropout rate for the encoder", + ) + # Encoder + group.add_argument( + "--elayers", + default=4, + type=int, + help="Number of encoder layers (for shared recognition part " + "in multi-speaker asr mode)", + ) + group.add_argument( + "--eunits", + "-u", + default=300, + type=int, + help="Number of encoder hidden units", + ) + # Attention + group.add_argument( + "--adim", + default=320, + type=int, + help="Number of attention transformation dimensions", + ) + group.add_argument( + "--aheads", + default=4, + type=int, + help="Number of heads for multi head attention", + ) + # Decoder + group.add_argument( + "--dlayers", default=1, type=int, help="Number of decoder layers" + ) + group.add_argument( + "--dunits", default=320, type=int, help="Number of decoder hidden units" + ) + return group diff --git a/espnet/nets/pytorch_backend/transformer/attention.py b/espnet/nets/pytorch_backend/transformer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..8d8b68089ec7629b22346f538dab359ff7560acd --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/attention.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + + Returns: + torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). + torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). + torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. + + Args: + value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). 
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). + mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). + + Returns: + torch.Tensor: Transformed value (#batch, time1, d_model) + weighted by the attention score (#batch, time1, time2). + + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min + ) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute scaled dot product attention. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, time2). + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). 
+ key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). + pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + Paper: https://arxiv.org/abs/1901.02860 + + Args: + n_head (int): The number of heads. + n_feat (int): The number of features. + dropout_rate (float): Dropout rate. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + + """ + + def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + self.zero_triu = zero_triu + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x): + """Compute relative positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: Output tensor. + + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x)[ + :, :, :, : x.size(-1) // 2 + 1 + ] # only keep the positions from 0 to time2 + + if self.zero_triu: + ones = torch.ones((x.size(2), x.size(3)), device=x.device) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + + Args: + query (torch.Tensor): Query tensor (#batch, time1, size). + key (torch.Tensor): Key tensor (#batch, time2, size). + value (torch.Tensor): Value tensor (#batch, time2, size). 
+ pos_emb (torch.Tensor): Positional embedding tensor + (#batch, 2*time1-1, size). + mask (torch.Tensor): Mask tensor (#batch, 1, time2) or + (#batch, time1, time2). + + Returns: + torch.Tensor: Output tensor (#batch, time1, d_model). + + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py b/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4dd676ab841b5fa8fd237248615952755813ced --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/contextual_block_encoder_layer.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Emiru Tsunoo +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class ContextualBlockEncoderLayer(nn.Module): + """Contexutal Block Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + total_layer_num (int): Total number of layers + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + total_layer_num, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(ContextualBlockEncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + self.total_layer_num = total_layer_num + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x, mask, past_ctx=None, next_ctx=None, layer_idx=0, cache=None): + """Compute encoded features. 
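        A shape-level sketch of a single call (the attention and feed-forward
        modules, block layout, and sizes below are illustrative assumptions, not
        configuration taken from this patch); the input is laid out as
        (batch, num_blocks, block_size + 2, dim), with the two extra frames per
        block reserved for context vectors::

            import torch
            from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention
            from espnet.nets.pytorch_backend.transformer.contextual_block_encoder_layer import (
                ContextualBlockEncoderLayer,
            )
            from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
                PositionwiseFeedForward,
            )

            size = 256
            layer = ContextualBlockEncoderLayer(
                size,
                MultiHeadedAttention(4, size, 0.0),
                PositionwiseFeedForward(size, 1024, 0.0),
                dropout_rate=0.0,
                total_layer_num=6,
            )
            x = torch.randn(2, 3, 10, size)   # 2 utterances, 3 blocks of 8 frames + 2 context slots
            out, mask, cur_ctx, nxt_ctx, idx = layer(x, None)
            assert out.shape == x.shape and idx == 1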
+ + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + past_ctx (torch.Tensor): Previous contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). + cur_ctx (torch.Tensor): Current contexutal vector + next_ctx (torch.Tensor): Next contexutal vector + layer_idx (int): layer index number + + """ + nbatch = x.size(0) + nblock = x.size(1) + + if past_ctx is not None: + if next_ctx is None: + # store all context vectors in one tensor + next_ctx = past_ctx.new_zeros( + nbatch, nblock, self.total_layer_num, x.size(-1) + ) + else: + x[:, :, 0] = past_ctx[:, :, layer_idx] + + # reshape ( nbatch, nblock, block_size + 2, dim ) + # -> ( nbatch * nblock, block_size + 2, dim ) + x = x.view(-1, x.size(-2), x.size(-1)) + if mask is not None: + mask = mask.view(-1, mask.size(-2), mask.size(-1)) + + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + layer_idx += 1 + # reshape ( nbatch * nblock, block_size + 2, dim ) + # -> ( nbatch, nblock, block_size + 2, dim ) + x = x.view(nbatch, -1, x.size(-2), x.size(-1)).squeeze(1) + if mask is not None: + mask = mask.view(nbatch, -1, mask.size(-2), mask.size(-1)).squeeze(1) + + if next_ctx is not None and layer_idx < self.total_layer_num: + next_ctx[:, 0, layer_idx, :] = x[:, 0, -1, :] + next_ctx[:, 1:, layer_idx, :] = x[:, 0:-1, -1, :] + + return x, mask, next_ctx, next_ctx, layer_idx diff --git a/espnet/nets/pytorch_backend/transformer/decoder.py b/espnet/nets/pytorch_backend/transformer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..08ddbbbb98db04b32cc7cceeb3fbca38fec85029 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/decoder.py @@ -0,0 +1,339 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Decoder definition.""" + +import logging + +from typing import Any +from typing import List +from typing import Tuple + +import torch + +from espnet.nets.pytorch_backend.nets_utils import rename_state_dict +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer +from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.lightconv import 
LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.scorer_interface import BatchScorerInterface + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563 + rename_state_dict(prefix + "output_norm.", prefix + "after_norm.", state_dict) + + +class Decoder(BatchScorerInterface, torch.nn.Module): + """Transfomer decoder module. + + Args: + odim (int): Output diminsion. + self_attention_layer_type (str): Self-attention layer type. + attention_dim (int): Dimention of attention. + attention_heads (int): The number of heads of multi head attention. + conv_wshare (int): The number of kernel of convolution. Only used in + self_attention_layer_type == "lightconv*" or "dynamiconv*". + conv_kernel_length (Union[int, str]): Kernel size str of convolution + (e.g. 71_71_71_71_71_71). Only used in self_attention_layer_type + == "lightconv*" or "dynamiconv*". + conv_usebias (bool): Whether to use bias in convolution. Only used in + self_attention_layer_type == "lightconv*" or "dynamiconv*". + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + positional_dropout_rate (float): Dropout rate after adding positional encoding. + self_attention_dropout_rate (float): Dropout rate in self-attention. + src_attention_dropout_rate (float): Dropout rate in source-attention. + input_layer (Union[str, torch.nn.Module]): Input layer type. + use_output_layer (bool): Whether to use output layer. + pos_enc_class (torch.nn.Module): Positional encoding module class. + `PositionalEncoding `or `ScaledPositionalEncoding` + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__( + self, + odim, + selfattention_layer_type="selfattn", + attention_dim=256, + attention_heads=4, + conv_wshare=4, + conv_kernel_length=11, + conv_usebias=False, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + self_attention_dropout_rate=0.0, + src_attention_dropout_rate=0.0, + input_layer="embed", + use_output_layer=True, + pos_enc_class=PositionalEncoding, + normalize_before=True, + concat_after=False, + ): + """Construct an Decoder object.""" + torch.nn.Module.__init__(self) + self._register_load_state_dict_pre_hook(_pre_hook) + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(odim, attention_dim), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(odim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise NotImplementedError("only `embed` or torch.nn.Module is supported.") + self.normalize_before = normalize_before + + # self-attention module definition + if selfattention_layer_type == "selfattn": + logging.info("decoder self-attention layer type = self-attention") + decoder_selfattn_layer = MultiHeadedAttention + decoder_selfattn_layer_args = [ + ( + attention_heads, + attention_dim, + self_attention_dropout_rate, + ) + ] * num_blocks + elif selfattention_layer_type == "lightconv": + logging.info("decoder self-attention layer type = lightweight convolution") + decoder_selfattn_layer = LightweightConvolution + decoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + self_attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + True, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "lightconv2d": + logging.info( + "decoder self-attention layer " + "type = lightweight convolution 2-dimentional" + ) + decoder_selfattn_layer = LightweightConvolution2D + decoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + self_attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + True, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv": + logging.info("decoder self-attention layer type = dynamic convolution") + decoder_selfattn_layer = DynamicConvolution + decoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + self_attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + True, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv2d": + logging.info( + "decoder self-attention layer type = dynamic convolution 2-dimentional" + ) + decoder_selfattn_layer = DynamicConvolution2D + decoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + self_attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + True, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + decoder_selfattn_layer(*decoder_selfattn_layer_args[lnum]), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + 
dropout_rate, + normalize_before, + concat_after, + ), + ) + self.selfattention_layer_type = selfattention_layer_type + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, odim) + else: + self.output_layer = None + + def forward(self, tgt, tgt_mask, memory, memory_mask): + """Forward decoder. + + Args: + tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out) if + input_layer == "embed". In the other case, input tensor + (#batch, maxlen_out, odim). + tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out). + dtype=torch.uint8 in PyTorch 1.2- and dtype=torch.bool in PyTorch 1.2+ + (include 1.2). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, feat). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + dtype=torch.uint8 in PyTorch 1.2- and dtype=torch.bool in PyTorch 1.2+ + (include 1.2). + + Returns: + torch.Tensor: Decoded token score before softmax (#batch, maxlen_out, odim) + if use_output_layer is True. In the other case,final block outputs + (#batch, maxlen_out, attention_dim). + torch.Tensor: Score mask before softmax (#batch, maxlen_out). + + """ + x = self.embed(tgt) + x, tgt_mask, memory, memory_mask = self.decoders( + x, tgt_mask, memory, memory_mask + ) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + return x, tgt_mask + + def forward_one_step(self, tgt, tgt_mask, memory, cache=None): + """Forward one step. + + Args: + tgt (torch.Tensor): Input token ids, int64 (#batch, maxlen_out). + tgt_mask (torch.Tensor): Input token mask (#batch, maxlen_out). + dtype=torch.uint8 in PyTorch 1.2- and dtype=torch.bool in PyTorch 1.2+ + (include 1.2). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, feat). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). + + Returns: + torch.Tensor: Output tensor (batch, maxlen_out, odim). + List[torch.Tensor]: List of cache tensors of each decoder layer. + + """ + x = self.embed(tgt) + if cache is None: + cache = [None] * len(self.decoders) + new_cache = [] + for c, decoder in zip(cache, self.decoders): + x, tgt_mask, memory, memory_mask = decoder( + x, tgt_mask, memory, None, cache=c + ) + new_cache.append(x) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.output_layer is not None: + y = torch.log_softmax(self.output_layer(y), dim=-1) + + return y, new_cache + + # beam search API (see ScorerInterface) + def score(self, ys, state, x): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) + if self.selfattention_layer_type != "selfattn": + # TODO(karita): implement cache + logging.warning( + f"{self.selfattention_layer_type} does not support cached decoding." + ) + state = None + logp, state = self.forward_one_step( + ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state + ) + return logp.squeeze(0), state + + # batch beam search API (see BatchScorerInterface) + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch (required). + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). 
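        A minimal sketch of this batch beam-search API (vocabulary size and
        shapes are assumed for illustration)::

            import torch
            from espnet.nets.pytorch_backend.transformer.decoder import Decoder

            dec = Decoder(odim=500, attention_dim=256, attention_heads=4, num_blocks=6)
            memory = torch.randn(2, 30, 256)       # encoder outputs for 2 hypotheses
            ys = torch.tensor([[1, 5], [1, 7]])    # current token prefixes
            logp, states = dec.batch_score(ys, [None, None], memory)
            assert logp.shape == (2, 500)                      # next-token log-probabilities
            assert len(states) == 2 and len(states[0]) == 6    # per-hypothesis, per-layer cache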
+ + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.decoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) + logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list diff --git a/espnet/nets/pytorch_backend/transformer/decoder_layer.py b/espnet/nets/pytorch_backend/transformer/decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..98e5e011dfb2c12bf487d3a82ae5cd579e09d602 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/decoder_layer.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Decoder self-attention layer definition.""" + +import torch +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class DecoderLayer(nn.Module): + """Single decoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + src_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` instance can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + + """ + + def __init__( + self, + size, + self_attn, + src_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an DecoderLayer object.""" + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.norm3 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear1 = nn.Linear(size + size, size) + self.concat_linear2 = nn.Linear(size + size, size) + + def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): + """Compute decoded features. + + Args: + tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). + tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). + memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). + memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). + cache (List[torch.Tensor]): List of cached tensors. + Each tensor shape should be (#batch, maxlen_out - 1, size). 
+ + Returns: + torch.Tensor: Output tensor(#batch, maxlen_out, size). + torch.Tensor: Mask for output tensor (#batch, maxlen_out). + torch.Tensor: Encoded memory (#batch, maxlen_in, size). + torch.Tensor: Encoded memory mask (#batch, maxlen_in). + + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + + if cache is None: + tgt_q = tgt + tgt_q_mask = tgt_mask + else: + # compute only the last frame query keeping dim: max_time_out -> 1 + assert cache.shape == ( + tgt.shape[0], + tgt.shape[1] - 1, + self.size, + ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" + tgt_q = tgt[:, -1:, :] + residual = residual[:, -1:, :] + tgt_q_mask = None + if tgt_mask is not None: + tgt_q_mask = tgt_mask[:, -1:, :] + + if self.concat_after: + tgt_concat = torch.cat( + (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 + ) + x = residual + self.concat_linear1(tgt_concat) + else: + x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + if self.concat_after: + x_concat = torch.cat( + (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 + ) + x = residual + self.concat_linear2(x_concat) + else: + x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask)) + if not self.normalize_before: + x = self.norm2(x) + + residual = x + if self.normalize_before: + x = self.norm3(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm3(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, tgt_mask, memory, memory_mask diff --git a/espnet/nets/pytorch_backend/transformer/dynamic_conv.py b/espnet/nets/pytorch_backend/transformer/dynamic_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..bfff34fbcb6bdaca0126b98029c231cbd08414df --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/dynamic_conv.py @@ -0,0 +1,125 @@ +"""Dynamic Convolution module.""" + +import numpy +import torch +from torch import nn +import torch.nn.functional as F + + +MIN_VALUE = float(numpy.finfo(numpy.float32).min) + + +class DynamicConvolution(nn.Module): + """Dynamic Convolution layer. + + This implementation is based on + https://github.com/pytorch/fairseq/tree/master/fairseq + + Args: + wshare (int): the number of kernel of convolution + n_feat (int): the number of features + dropout_rate (float): dropout_rate + kernel_size (int): kernel size (length) + use_kernel_mask (bool): Use causal mask or not for convolution kernel + use_bias (bool): Use bias term or not. 
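    A minimal usage sketch (shapes are illustrative): the layer keeps the
    self-attention calling convention `(query, key, value, mask)`, but only the
    query is actually used::

        import torch
        from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution

        conv = DynamicConvolution(wshare=4, n_feat=256, dropout_rate=0.0, kernel_size=11)
        x = torch.randn(2, 50, 256)                     # (batch, time, d_model)
        mask = torch.ones(2, 1, 50, dtype=torch.bool)   # same layout as self-attention masks
        y = conv(x, x, x, mask)
        assert y.shape == x.shape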
+ + """ + + def __init__( + self, + wshare, + n_feat, + dropout_rate, + kernel_size, + use_kernel_mask=False, + use_bias=False, + ): + """Construct Dynamic Convolution layer.""" + super(DynamicConvolution, self).__init__() + + assert n_feat % wshare == 0 + self.wshare = wshare + self.use_kernel_mask = use_kernel_mask + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.attn = None + + # linear -> GLU -- -> lightconv -> linear + # \ / + # Linear + self.linear1 = nn.Linear(n_feat, n_feat * 2) + self.linear2 = nn.Linear(n_feat, n_feat) + self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size) + nn.init.xavier_uniform(self.linear_weight.weight) + self.act = nn.GLU() + + # dynamic conv related + self.use_bias = use_bias + if self.use_bias: + self.bias = nn.Parameter(torch.Tensor(n_feat)) + + def forward(self, query, key, value, mask): + """Forward of 'Dynamic Convolution'. + + This function takes query, key and value but uses only quert. + This is just for compatibility with self-attention layer (attention.py) + + Args: + query (torch.Tensor): (batch, time1, d_model) input tensor + key (torch.Tensor): (batch, time2, d_model) NOT USED + value (torch.Tensor): (batch, time2, d_model) NOT USED + mask (torch.Tensor): (batch, time1, time2) mask + + Return: + x (torch.Tensor): (batch, time1, d_model) ouput + + """ + # linear -> GLU -- -> lightconv -> linear + # \ / + # Linear + x = query + B, T, C = x.size() + H = self.wshare + k = self.kernel_size + + # first liner layer + x = self.linear1(x) + + # GLU activation + x = self.act(x) + + # get kernel of convolution + weight = self.linear_weight(x) # B x T x kH + weight = F.dropout(weight, self.dropout_rate, training=self.training) + weight = weight.view(B, T, H, k).transpose(1, 2).contiguous() # B x H x T x k + weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype) + weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf")) + weight_new = weight_new.to(x.device) # B x H x T x T+k-1 + weight_new.as_strided( + (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1) + ).copy_(weight) + weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k) + if self.use_kernel_mask: + kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0) + weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf")) + weight_new = F.softmax(weight_new, dim=-1) + self.attn = weight_new + weight_new = weight_new.view(B * H, T, T) + + # convolution + x = x.transpose(1, 2).contiguous() # B x C x T + x = x.view(B * H, int(C / H), T).transpose(1, 2) + x = torch.bmm(weight_new, x) # BH x T x C/H + x = x.transpose(1, 2).contiguous().view(B, C, T) + + if self.use_bias: + x = x + self.bias.view(1, -1, 1) + x = x.transpose(1, 2) # B x T x C + + if mask is not None and not self.use_kernel_mask: + mask = mask.transpose(-1, -2) + x = x.masked_fill(mask == 0, 0.0) + + # second linear layer + x = self.linear2(x) + return x diff --git a/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py b/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..dd49719b3d930e57bb9aa5faaf2d562b6d0c5f21 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/dynamic_conv2d.py @@ -0,0 +1,138 @@ +"""Dynamic 2-Dimentional Convolution module.""" + +import numpy +import torch +from torch import nn +import torch.nn.functional as F + + +MIN_VALUE = float(numpy.finfo(numpy.float32).min) + + +class DynamicConvolution2D(nn.Module): + 
"""Dynamic 2-Dimentional Convolution layer. + + This implementation is based on + https://github.com/pytorch/fairseq/tree/master/fairseq + + Args: + wshare (int): the number of kernel of convolution + n_feat (int): the number of features + dropout_rate (float): dropout_rate + kernel_size (int): kernel size (length) + use_kernel_mask (bool): Use causal mask or not for convolution kernel + use_bias (bool): Use bias term or not. + + """ + + def __init__( + self, + wshare, + n_feat, + dropout_rate, + kernel_size, + use_kernel_mask=False, + use_bias=False, + ): + """Construct Dynamic 2-Dimentional Convolution layer.""" + super(DynamicConvolution2D, self).__init__() + + assert n_feat % wshare == 0 + self.wshare = wshare + self.use_kernel_mask = use_kernel_mask + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.padding_size = int(kernel_size / 2) + self.attn_t = None + self.attn_f = None + + # linear -> GLU -- -> lightconv -> linear + # \ / + # Linear + self.linear1 = nn.Linear(n_feat, n_feat * 2) + self.linear2 = nn.Linear(n_feat * 2, n_feat) + self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size) + nn.init.xavier_uniform(self.linear_weight.weight) + self.linear_weight_f = nn.Linear(n_feat, kernel_size) + nn.init.xavier_uniform(self.linear_weight_f.weight) + self.act = nn.GLU() + + # dynamic conv related + self.use_bias = use_bias + if self.use_bias: + self.bias = nn.Parameter(torch.Tensor(n_feat)) + + def forward(self, query, key, value, mask): + """Forward of 'Dynamic 2-Dimentional Convolution'. + + This function takes query, key and value but uses only query. + This is just for compatibility with self-attention layer (attention.py) + + Args: + query (torch.Tensor): (batch, time1, d_model) input tensor + key (torch.Tensor): (batch, time2, d_model) NOT USED + value (torch.Tensor): (batch, time2, d_model) NOT USED + mask (torch.Tensor): (batch, time1, time2) mask + + Return: + x (torch.Tensor): (batch, time1, d_model) ouput + + """ + # linear -> GLU -- -> lightconv -> linear + # \ / + # Linear + x = query + B, T, C = x.size() + H = self.wshare + k = self.kernel_size + + # first liner layer + x = self.linear1(x) + + # GLU activation + x = self.act(x) + + # convolution of frequency axis + weight_f = self.linear_weight_f(x).view(B * T, 1, k) # B x T x k + self.attn_f = weight_f.view(B, T, k).unsqueeze(1) + xf = F.conv1d( + x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T + ) + xf = xf.view(B, T, C) + + # get kernel of convolution + weight = self.linear_weight(x) # B x T x kH + weight = F.dropout(weight, self.dropout_rate, training=self.training) + weight = weight.view(B, T, H, k).transpose(1, 2).contiguous() # B x H x T x k + weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype) + weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf")) + weight_new = weight_new.to(x.device) # B x H x T x T+k-1 + weight_new.as_strided( + (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1) + ).copy_(weight) + weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k) + if self.use_kernel_mask: + kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0) + weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf")) + weight_new = F.softmax(weight_new, dim=-1) + self.attn_t = weight_new + weight_new = weight_new.view(B * H, T, T) + + # convolution + x = x.transpose(1, 2).contiguous() # B x C x T + x = x.view(B * H, int(C / H), T).transpose(1, 2) + x = torch.bmm(weight_new, x) + 
x = x.transpose(1, 2).contiguous().view(B, C, T) + + if self.use_bias: + x = x + self.bias.view(1, -1, 1) + x = x.transpose(1, 2) # B x T x C + x = torch.cat((x, xf), -1) # B x T x Cx2 + + if mask is not None and not self.use_kernel_mask: + mask = mask.transpose(-1, -2) + x = x.masked_fill(mask == 0, 0.0) + + # second linear layer + x = self.linear2(x) + return x diff --git a/espnet/nets/pytorch_backend/transformer/embedding.py b/espnet/nets/pytorch_backend/transformer/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3a92e0f8cc9091744d3b9434fb9dfe8f8231c9df --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/embedding.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positional Encoding Module.""" + +import math + +import torch + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + reverse (bool): Whether to reverse the input position. Only for + the class LegacyRelPositionalEncoding. We remove it in the current + class RelPositionalEncoding. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class ScaledPositionalEncoding(PositionalEncoding): + """Scaled positional encoding module. + + See Sec. 3.2 https://arxiv.org/abs/1809.08895 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. 
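    Unlike the base class, which scales the input by sqrt(d_model) before adding
    the encoding, this variant adds `alpha * pe` with a learnable `alpha`
    initialised to 1.0. A minimal sketch (illustrative shapes)::

        import torch
        from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding

        pos_enc = ScaledPositionalEncoding(d_model=256, dropout_rate=0.1)
        x = torch.randn(2, 100, 256)
        assert pos_enc(x).shape == x.shape   # encoding added, shape unchanged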
+ + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def reset_parameters(self): + """Reset parameters.""" + self.alpha.data = torch.tensor(1.0) + + def forward(self, x): + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + + """ + self.extend_pe(x) + x = x + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(x) + + +class LegacyRelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module (old version). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__( + d_model=d_model, + dropout_rate=dropout_rate, + max_len=max_len, + reverse=True, + ) + + def forward(self, x): + """Compute positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module (new implementation). + + Details can be found in https://github.com/espnet/espnet/pull/2816. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + selfattention_layer_type (str): Encoder attention layer type. + padding_idx (int): Padding idx for input_layer=embed. 
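    A minimal construction sketch (feature dimension and lengths are assumed for
    illustration); with the default `conv2d` input layer, `Conv2dSubsampling`
    reduces the time axis by roughly a factor of 4 before the self-attention
    blocks::

        import torch
        from espnet.nets.pytorch_backend.transformer.encoder import Encoder

        enc = Encoder(idim=80, attention_dim=256, attention_heads=4,
                      linear_units=2048, num_blocks=6, input_layer="conv2d")
        xs = torch.randn(2, 100, 80)                     # (batch, time, idim)
        masks = torch.ones(2, 1, 100, dtype=torch.bool)  # (batch, 1, time)
        hs, hs_masks = enc(xs, masks)                    # time subsampled ~4x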
+ + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + conv_wshare=4, + conv_kernel_length="11", + conv_usebias=False, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + pos_enc_class=PositionalEncoding, + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + selfattention_layer_type="selfattn", + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Encoder, self).__init__() + self._register_load_state_dict_pre_hook(_pre_hook) + + self.conv_subsampling_factor = 1 + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 4 + elif input_layer == "conv2d-scaled-pos-enc": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + self.conv_subsampling_factor = 4 + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 6 + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(idim, attention_dim, dropout_rate) + self.conv_subsampling_factor = 8 + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + self.conv_subsampling_factor = 4 + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + positionwise_layer, positionwise_layer_args = self.get_positionwise_layer( + positionwise_layer_type, + attention_dim, + linear_units, + dropout_rate, + positionwise_conv_kernel_size, + ) + if selfattention_layer_type in [ + "selfattn", + "rel_selfattn", + "legacy_rel_selfattn", + ]: + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = [ + ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + ] * num_blocks + elif selfattention_layer_type == "lightconv": + logging.info("encoder self-attention layer type = lightweight convolution") + encoder_selfattn_layer = LightweightConvolution + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "lightconv2d": + logging.info( + "encoder self-attention layer " + "type = lightweight convolution 2-dimentional" + ) + encoder_selfattn_layer = LightweightConvolution2D + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in 
range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv": + logging.info("encoder self-attention layer type = dynamic convolution") + encoder_selfattn_layer = DynamicConvolution + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + elif selfattention_layer_type == "dynamicconv2d": + logging.info( + "encoder self-attention layer type = dynamic convolution 2-dimentional" + ) + encoder_selfattn_layer = DynamicConvolution2D + encoder_selfattn_layer_args = [ + ( + conv_wshare, + attention_dim, + attention_dropout_rate, + int(conv_kernel_length.split("_")[lnum]), + False, + conv_usebias, + ) + for lnum in range(num_blocks) + ] + else: + raise NotImplementedError(selfattention_layer_type) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args[lnum]), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def get_positionwise_layer( + self, + positionwise_layer_type="linear", + attention_dim=256, + linear_units=2048, + dropout_rate=0.1, + positionwise_conv_kernel_size=1, + ): + """Define positionwise layer.""" + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = (attention_dim, linear_units, dropout_rate) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + return positionwise_layer, positionwise_layer_args + + def forward(self, xs, masks): + """Encode input sequence. + + Args: + xs (torch.Tensor): Input tensor (#batch, time, idim). + masks (torch.Tensor): Mask tensor (#batch, time). + + Returns: + torch.Tensor: Output tensor (#batch, time, attention_dim). + torch.Tensor: Mask tensor (#batch, time). + + """ + if isinstance( + self.embed, + (Conv2dSubsampling, Conv2dSubsampling6, Conv2dSubsampling8, VGG2L), + ): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + xs, masks = self.encoders(xs, masks) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks + + def forward_one_step(self, xs, masks, cache=None): + """Encode input frame. + + Args: + xs (torch.Tensor): Input tensor. + masks (torch.Tensor): Mask tensor. + cache (List[torch.Tensor]): List of cache tensors. + + Returns: + torch.Tensor: Output tensor. + torch.Tensor: Mask tensor. + List[torch.Tensor]: List of new cache tensors. 
+ + """ + if isinstance(self.embed, Conv2dSubsampling): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + if cache is None: + cache = [None for _ in range(len(self.encoders))] + new_cache = [] + for c, e in zip(cache, self.encoders): + xs, masks = e(xs, masks, cache=c) + new_cache.append(xs) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks, new_cache diff --git a/espnet/nets/pytorch_backend/transformer/encoder_layer.py b/espnet/nets/pytorch_backend/transformer/encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..5758fbc663f4146c25304e9ad65714d1a46594c6 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/encoder_layer.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + Args: + size (int): Input dimension. + self_attn (torch.nn.Module): Self-attention module instance. + `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance + can be used as the argument. + feed_forward (torch.nn.Module): Feed-forward module instance. + `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance + can be used as the argument. + dropout_rate (float): Dropout rate. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x, mask, cache=None): + """Compute encoded features. + + Args: + x_input (torch.Tensor): Input tensor (#batch, time, size). + mask (torch.Tensor): Mask tensor for the input (#batch, time). + cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). + + Returns: + torch.Tensor: Output tensor (#batch, time, size). + torch.Tensor: Mask tensor (#batch, time). 
+ + """ + residual = x + if self.normalize_before: + x = self.norm1(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if self.concat_after: + x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(self.self_attn(x_q, x, x, mask)) + if not self.normalize_before: + x = self.norm1(x) + + residual = x + if self.normalize_before: + x = self.norm2(x) + x = residual + self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm2(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + return x, mask diff --git a/espnet/nets/pytorch_backend/transformer/encoder_mix.py b/espnet/nets/pytorch_backend/transformer/encoder_mix.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4bd94536849241f45367c29ecb3e178145cd39 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/encoder_mix.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder Mix definition.""" + +import torch + +from espnet.nets.pytorch_backend.transducer.vgg2l import VGG2L +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +class EncoderMix(Encoder, torch.nn.Module): + """Transformer encoder module. + + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of decoder blocks + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str or torch.nn.Module input_layer: input layer type + :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + :param str positionwise_layer_type: linear of conv1d + :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + :param int padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks_sd=4, + num_blocks_rec=8, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + pos_enc_class=PositionalEncoding, + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + padding_idx=-1, + num_spkrs=2, + ): + """Construct an Encoder object.""" + super(EncoderMix, self).__init__( + idim=idim, + selfattention_layer_type="selfattn", + attention_dim=attention_dim, + attention_heads=attention_heads, + linear_units=linear_units, + num_blocks=num_blocks_rec, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + attention_dropout_rate=attention_dropout_rate, + input_layer=input_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + concat_after=concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + padding_idx=padding_idx, + ) + positionwise_layer, positionwise_layer_args = self.get_positionwise_layer( + positionwise_layer_type, + attention_dim, + linear_units, + dropout_rate, + positionwise_conv_kernel_size, + ) + self.num_spkrs = num_spkrs + self.encoders_sd = torch.nn.ModuleList( + [ + repeat( + num_blocks_sd, + lambda lnum: EncoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + for i in range(num_spkrs) + ] + ) + + def forward(self, xs, masks): + """Encode input sequence. + + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + xs_sd, masks_sd = [None] * self.num_spkrs, [None] * self.num_spkrs + + for ns in range(self.num_spkrs): + xs_sd[ns], masks_sd[ns] = self.encoders_sd[ns](xs, masks) + xs_sd[ns], masks_sd[ns] = self.encoders(xs_sd[ns], masks_sd[ns]) # Enc_rec + if self.normalize_before: + xs_sd[ns] = self.after_norm(xs_sd[ns]) + return xs_sd, masks_sd + + def forward_one_step(self, xs, masks, cache=None): + """Encode input frame. 
+ + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :param List[torch.Tensor] cache: cache tensors + :return: position embedded tensor, mask and new cache + :rtype Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + if isinstance(self.embed, Conv2dSubsampling): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + new_cache_sd = [] + for ns in range(self.num_spkrs): + if cache is None: + cache = [ + None for _ in range(len(self.encoders_sd) + len(self.encoders_rec)) + ] + new_cache = [] + for c, e in zip(cache[: len(self.encoders_sd)], self.encoders_sd[ns]): + xs, masks = e(xs, masks, cache=c) + new_cache.append(xs) + for c, e in zip(cache[: len(self.encoders_sd) :], self.encoders_rec): + xs, masks = e(xs, masks, cache=c) + new_cache.append(xs) + new_cache_sd.append(new_cache) + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks, new_cache_sd diff --git a/espnet/nets/pytorch_backend/transformer/initializer.py b/espnet/nets/pytorch_backend/transformer/initializer.py new file mode 100644 index 0000000000000000000000000000000000000000..1bce5459c3de47630cdafbace242c0d467c5e733 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/initializer.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Parameter initialization.""" + +import torch + +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm + + +def initialize(model, init_type="pytorch"): + """Initialize Transformer module. + + :param torch.nn.Module model: transformer instance + :param str init_type: initialization type + """ + if init_type == "pytorch": + return + + # weight init + for p in model.parameters(): + if p.dim() > 1: + if init_type == "xavier_uniform": + torch.nn.init.xavier_uniform_(p.data) + elif init_type == "xavier_normal": + torch.nn.init.xavier_normal_(p.data) + elif init_type == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") + elif init_type == "kaiming_normal": + torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") + else: + raise ValueError("Unknown initialization: " + init_type) + # bias init + for p in model.parameters(): + if p.dim() == 1: + p.data.zero_() + + # reset some modules with default init + for m in model.modules(): + if isinstance(m, (torch.nn.Embedding, LayerNorm)): + m.reset_parameters() diff --git a/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py b/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..0d8b30338a778da9ba27870d51db24afd10d9b24 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Label smoothing module.""" + +import torch +from torch import nn + + +class LabelSmoothingLoss(nn.Module): + """Label-smoothing loss. 
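+
+ A minimal usage sketch (vocabulary size and shapes are assumed for illustration;
+ the parameters are documented below):
+ >>> import torch
+ >>> criterion = LabelSmoothingLoss(size=100, padding_idx=-1, smoothing=0.1)
+ >>> logits = torch.randn(2, 10, 100)           # (batch, seqlen, class)
+ >>> target = torch.randint(0, 100, (2, 10))    # padded positions would carry -1
+ >>> loss = criterion(logits, target)           # scalar tensor, normalized by batch size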
+ + :param int size: the number of class + :param int padding_idx: ignored class id + :param float smoothing: smoothing rate (0.0 means the conventional CE) + :param bool normalize_length: normalize loss by sequence length if True + :param torch.nn.Module criterion: loss function to be smoothed + """ + + def __init__( + self, + size, + padding_idx, + smoothing, + normalize_length=False, + criterion=nn.KLDivLoss(reduction="none"), + ): + """Construct an LabelSmoothingLoss object.""" + super(LabelSmoothingLoss, self).__init__() + self.criterion = criterion + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.size = size + self.true_dist = None + self.normalize_length = normalize_length + + def forward(self, x, target): + """Compute loss between x and target. + + :param torch.Tensor x: prediction (batch, seqlen, class) + :param torch.Tensor target: + target signal masked with self.padding_id (batch, seqlen) + :return: scalar float value + :rtype torch.Tensor + """ + assert x.size(2) == self.size + batch_size = x.size(0) + x = x.view(-1, self.size) + target = target.view(-1) + with torch.no_grad(): + true_dist = x.clone() + true_dist.fill_(self.smoothing / (self.size - 1)) + ignore = target == self.padding_idx # (B,) + total = len(target) - ignore.sum().item() + target = target.masked_fill(ignore, 0) # avoid -1 index + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) + denom = total if self.normalize_length else batch_size + return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom diff --git a/espnet/nets/pytorch_backend/transformer/layer_norm.py b/espnet/nets/pytorch_backend/transformer/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..b47530ece7d32bc476e71b4e479e8af0ced708a5 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/layer_norm.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer normalization module.""" + +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + + Args: + nout (int): Output dim size. + dim (int): Dimension to be normalized. + + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized tensor. + + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) diff --git a/espnet/nets/pytorch_backend/transformer/lightconv.py b/espnet/nets/pytorch_backend/transformer/lightconv.py new file mode 100644 index 0000000000000000000000000000000000000000..a940c6d9042563185e5a673f42137bcff7fa8d18 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/lightconv.py @@ -0,0 +1,112 @@ +"""Lightweight Convolution Module.""" + +import numpy +import torch +from torch import nn +import torch.nn.functional as F + + +MIN_VALUE = float(numpy.finfo(numpy.float32).min) + + +class LightweightConvolution(nn.Module): + """Lightweight Convolution layer. 
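+
+ Usage sketch (sizes are assumptions; the layer mirrors the self-attention call
+ signature, so the key/value arguments are accepted but ignored):
+ >>> import torch
+ >>> conv = LightweightConvolution(wshare=4, n_feat=256, dropout_rate=0.1, kernel_size=11)
+ >>> x = torch.randn(2, 50, 256)
+ >>> mask = torch.ones(2, 1, 50)
+ >>> y = conv(x, x, x, mask)                    # (2, 50, 256)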
+ + This implementation is based on + https://github.com/pytorch/fairseq/tree/master/fairseq + + Args: + wshare (int): the number of kernel of convolution + n_feat (int): the number of features + dropout_rate (float): dropout_rate + kernel_size (int): kernel size (length) + use_kernel_mask (bool): Use causal mask or not for convolution kernel + use_bias (bool): Use bias term or not. + + """ + + def __init__( + self, + wshare, + n_feat, + dropout_rate, + kernel_size, + use_kernel_mask=False, + use_bias=False, + ): + """Construct Lightweight Convolution layer.""" + super(LightweightConvolution, self).__init__() + + assert n_feat % wshare == 0 + self.wshare = wshare + self.use_kernel_mask = use_kernel_mask + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.padding_size = int(kernel_size / 2) + + # linear -> GLU -> lightconv -> linear + self.linear1 = nn.Linear(n_feat, n_feat * 2) + self.linear2 = nn.Linear(n_feat, n_feat) + self.act = nn.GLU() + + # lightconv related + self.weight = nn.Parameter( + torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1) + ) + self.use_bias = use_bias + if self.use_bias: + self.bias = nn.Parameter(torch.Tensor(n_feat)) + + # mask of kernel + kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2)) + kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1)) + self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1) + + def forward(self, query, key, value, mask): + """Forward of 'Lightweight Convolution'. + + This function takes query, key and value but uses only query. + This is just for compatibility with self-attention layer (attention.py) + + Args: + query (torch.Tensor): (batch, time1, d_model) input tensor + key (torch.Tensor): (batch, time2, d_model) NOT USED + value (torch.Tensor): (batch, time2, d_model) NOT USED + mask (torch.Tensor): (batch, time1, time2) mask + + Return: + x (torch.Tensor): (batch, time1, d_model) ouput + + """ + # linear -> GLU -> lightconv -> linear + x = query + B, T, C = x.size() + H = self.wshare + + # first liner layer + x = self.linear1(x) + + # GLU activation + x = self.act(x) + + # lightconv + x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T + weight = F.dropout(self.weight, self.dropout_rate, training=self.training) + if self.use_kernel_mask: + self.kernel_mask = self.kernel_mask.to(x.device) + weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf")) + weight = F.softmax(weight, dim=-1) + x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view( + B, C, T + ) + if self.use_bias: + x = x + self.bias.view(1, -1, 1) + x = x.transpose(1, 2) # B x T x C + + if mask is not None and not self.use_kernel_mask: + mask = mask.transpose(-1, -2) + x = x.masked_fill(mask == 0, 0.0) + + # second linear layer + x = self.linear2(x) + return x diff --git a/espnet/nets/pytorch_backend/transformer/lightconv2d.py b/espnet/nets/pytorch_backend/transformer/lightconv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..5effb9729ef53e149a940640b519e41ad2d2a8f5 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/lightconv2d.py @@ -0,0 +1,124 @@ +"""Lightweight 2-Dimentional Convolution module.""" + +import numpy +import torch +from torch import nn +import torch.nn.functional as F + + +MIN_VALUE = float(numpy.finfo(numpy.float32).min) + + +class LightweightConvolution2D(nn.Module): + """Lightweight 2-Dimentional Convolution layer. 
+ + This implementation is based on + https://github.com/pytorch/fairseq/tree/master/fairseq + + Args: + wshare (int): the number of kernel of convolution + n_feat (int): the number of features + dropout_rate (float): dropout_rate + kernel_size (int): kernel size (length) + use_kernel_mask (bool): Use causal mask or not for convolution kernel + use_bias (bool): Use bias term or not. + + """ + + def __init__( + self, + wshare, + n_feat, + dropout_rate, + kernel_size, + use_kernel_mask=False, + use_bias=False, + ): + """Construct Lightweight 2-Dimentional Convolution layer.""" + super(LightweightConvolution2D, self).__init__() + + assert n_feat % wshare == 0 + self.wshare = wshare + self.use_kernel_mask = use_kernel_mask + self.dropout_rate = dropout_rate + self.kernel_size = kernel_size + self.padding_size = int(kernel_size / 2) + + # linear -> GLU -> lightconv -> linear + self.linear1 = nn.Linear(n_feat, n_feat * 2) + self.linear2 = nn.Linear(n_feat * 2, n_feat) + self.act = nn.GLU() + + # lightconv related + self.weight = nn.Parameter( + torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1) + ) + self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1)) + self.use_bias = use_bias + if self.use_bias: + self.bias = nn.Parameter(torch.Tensor(n_feat)) + + # mask of kernel + kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2)) + kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1)) + self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1) + + def forward(self, query, key, value, mask): + """Forward of 'Lightweight 2-Dimentional Convolution'. + + This function takes query, key and value but uses only query. + This is just for compatibility with self-attention layer (attention.py) + + Args: + query (torch.Tensor): (batch, time1, d_model) input tensor + key (torch.Tensor): (batch, time2, d_model) NOT USED + value (torch.Tensor): (batch, time2, d_model) NOT USED + mask (torch.Tensor): (batch, time1, time2) mask + + Return: + x (torch.Tensor): (batch, time1, d_model) ouput + + """ + # linear -> GLU -> lightconv -> linear + x = query + B, T, C = x.size() + H = self.wshare + + # first liner layer + x = self.linear1(x) + + # GLU activation + x = self.act(x) + + # convolution along frequency axis + weight_f = F.softmax(self.weight_f, dim=-1) + weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training) + weight_new = torch.zeros( + B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype + ).copy_(weight_f) + xf = F.conv1d( + x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T + ).view(B, T, C) + + # lightconv + x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T + weight = F.dropout(self.weight, self.dropout_rate, training=self.training) + if self.use_kernel_mask: + self.kernel_mask = self.kernel_mask.to(x.device) + weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf")) + weight = F.softmax(weight, dim=-1) + x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view( + B, C, T + ) + if self.use_bias: + x = x + self.bias.view(1, -1, 1) + x = x.transpose(1, 2) # B x T x C + x = torch.cat((x, xf), -1) # B x T x Cx2 + + if mask is not None and not self.use_kernel_mask: + mask = mask.transpose(-1, -2) + x = x.masked_fill(mask == 0, 0.0) + + # second linear layer + x = self.linear2(x) + return x diff --git a/espnet/nets/pytorch_backend/transformer/mask.py b/espnet/nets/pytorch_backend/transformer/mask.py new file mode 100644 index 
0000000000000000000000000000000000000000..d9245d0917ca9c5179edb9f9b84ee9c65750e994 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/mask.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Mask module.""" + +from distutils.version import LooseVersion + +import torch + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0") +# LooseVersion('1.2.0') == LooseVersion(torch.__version__) can't include e.g. 1.2.0+aaa +is_torch_1_2 = ( + LooseVersion("1.3") > LooseVersion(torch.__version__) >= LooseVersion("1.2") +) +datatype = torch.bool if is_torch_1_2_plus else torch.uint8 + + +def subsequent_mask(size, device="cpu", dtype=datatype): + """Create mask for subsequent steps (size, size). + + :param int size: size of mask + :param str device: "cpu" or "cuda" or torch.Tensor.device + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor + >>> subsequent_mask(3) + [[1, 0, 0], + [1, 1, 0], + [1, 1, 1]] + """ + if is_torch_1_2 and dtype == torch.bool: + # torch=1.2 doesn't support tril for bool tensor + ret = torch.ones(size, size, device=device, dtype=torch.uint8) + return torch.tril(ret, out=ret).type(dtype) + else: + ret = torch.ones(size, size, device=device, dtype=dtype) + return torch.tril(ret, out=ret) + + +def target_mask(ys_in_pad, ignore_id): + """Create mask for decoder self-attention. + + :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) + :param int ignore_id: index of padding + :param torch.dtype dtype: result dtype + :rtype: torch.Tensor (B, Lmax, Lmax) + """ + ys_mask = ys_in_pad != ignore_id + m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) + return ys_mask.unsqueeze(-2) & m diff --git a/espnet/nets/pytorch_backend/transformer/multi_layer_conv.py b/espnet/nets/pytorch_backend/transformer/multi_layer_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb0717b060d5815d44c83b711f8fc4659987f3a --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/multi_layer_conv.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-leyered conv1d designed + to replace positionwise feed-forward network + in Transforner block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, + in_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. 
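+
+        Example (a sketch with assumed channel sizes; the second convolution projects
+        back to in_chans, so input and output widths match):
+            >>> import torch
+            >>> m = MultiLayeredConv1d(in_chans=256, hidden_chans=1024, kernel_size=3, dropout_rate=0.1)
+            >>> y = m(torch.randn(2, 50, 256))  # (2, 50, 256)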
+ + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (torch.Tensor): Batch of input tensors (B, T, in_chans). + + Returns: + torch.Tensor: Batch of output tensors (B, T, hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/espnet/nets/pytorch_backend/transformer/optimizer.py b/espnet/nets/pytorch_backend/transformer/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff5ba740f091a9b95e163acca75960ba80cc7d0 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/optimizer.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Optimizer module.""" + +import torch + + +class NoamOpt(object): + """Optim wrapper that implements rate.""" + + def __init__(self, model_size, factor, warmup, optimizer): + """Construct an NoamOpt object.""" + self.optimizer = optimizer + self._step = 0 + self.warmup = warmup + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def get_std_opt(model_params, d_model, warmup, factor): + """Get standard NoamOpt.""" + base = torch.optim.Adam(model_params, lr=0, betas=(0.9, 0.98), eps=1e-9) + return NoamOpt(d_model, factor, warmup, base) diff --git a/espnet/nets/pytorch_backend/transformer/plot.py b/espnet/nets/pytorch_backend/transformer/plot.py 
new file mode 100644 index 0000000000000000000000000000000000000000..b44673fddc8b891c0fdd70a102983fc86c5081dd --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/plot.py @@ -0,0 +1,146 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging + +import matplotlib.pyplot as plt +import numpy +import os + +from espnet.asr import asr_utils + + +def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None): + # dynamically import matplotlib due to not found error + from matplotlib.ticker import MaxNLocator + + d = os.path.dirname(filename) + if not os.path.exists(d): + os.makedirs(d) + w, h = plt.figaspect(1.0 / len(att_w)) + fig = plt.Figure(figsize=(w * 2, h * 2)) + axes = fig.subplots(1, len(att_w)) + if len(att_w) == 1: + axes = [axes] + for ax, aw in zip(axes, att_w): + # plt.subplot(1, len(att_w), h) + ax.imshow(aw.astype(numpy.float32), aspect="auto") + ax.set_xlabel("Input") + ax.set_ylabel("Output") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + # Labels for major ticks + if xtokens is not None: + ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, len(xtokens))) + ax.set_xticks(numpy.linspace(0, len(xtokens) - 1, 1), minor=True) + ax.set_xticklabels(xtokens + [""], rotation=40) + if ytokens is not None: + ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, len(ytokens))) + ax.set_yticks(numpy.linspace(0, len(ytokens) - 1, 1), minor=True) + ax.set_yticklabels(ytokens + [""]) + fig.tight_layout() + return fig + + +def savefig(plot, filename): + plot.savefig(filename) + plt.clf() + + +def plot_multi_head_attention( + data, + uttid_list, + attn_dict, + outdir, + suffix="png", + savefn=savefig, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, + subsampling_factor=4, +): + """Plot multi head attentions. + + :param dict data: utts info from json file + :param List uttid_list: utterance IDs + :param dict[str, torch.Tensor] attn_dict: multi head attention dict. 
+ values should be torch.Tensor (head, input_length, output_length) + :param str outdir: dir to save fig + :param str suffix: filename suffix including image type (e.g., png) + :param savefn: function to save + :param str ikey: key to access input + :param int iaxis: dimension to access input + :param str okey: key to access output + :param int oaxis: dimension to access output + :param subsampling_factor: subsampling factor in encoder + + """ + for name, att_ws in attn_dict.items(): + for idx, att_w in enumerate(att_ws): + data_i = data[uttid_list[idx]] + filename = "%s/%s.%s.%s" % (outdir, uttid_list[idx], name, suffix) + dec_len = int(data_i[okey][oaxis]["shape"][0]) + 1 # +1 for + enc_len = int(data_i[ikey][iaxis]["shape"][0]) + is_mt = "token" in data_i[ikey][iaxis].keys() + # for ASR/ST + if not is_mt: + enc_len //= subsampling_factor + xtokens, ytokens = None, None + if "encoder" in name: + att_w = att_w[:, :enc_len, :enc_len] + # for MT + if is_mt: + xtokens = data_i[ikey][iaxis]["token"].split() + ytokens = xtokens[:] + elif "decoder" in name: + if "self" in name: + # self-attention + att_w = att_w[:, :dec_len, :dec_len] + if "token" in data_i[okey][oaxis].keys(): + ytokens = data_i[okey][oaxis]["token"].split() + [""] + xtokens = [""] + data_i[okey][oaxis]["token"].split() + else: + # cross-attention + att_w = att_w[:, :dec_len, :enc_len] + if "token" in data_i[okey][oaxis].keys(): + ytokens = data_i[okey][oaxis]["token"].split() + [""] + # for MT + if is_mt: + xtokens = data_i[ikey][iaxis]["token"].split() + else: + logging.warning("unknown name for shaping attention") + fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens) + savefn(fig, filename) + + +class PlotAttentionReport(asr_utils.PlotAttentionReport): + def plotfn(self, *args, **kwargs): + kwargs["ikey"] = self.ikey + kwargs["iaxis"] = self.iaxis + kwargs["okey"] = self.okey + kwargs["oaxis"] = self.oaxis + kwargs["subsampling_factor"] = self.factor + plot_multi_head_attention(*args, **kwargs) + + def __call__(self, trainer): + attn_dict, uttid_list = self.get_attention_weights() + suffix = "ep.{.updater.epoch}.png".format(trainer) + self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig) + + def get_attention_weights(self): + return_batch, uttid_list = self.transform(self.data, return_uttid=True) + batch = self.converter([return_batch], self.device) + if isinstance(batch, tuple): + att_ws = self.att_vis_fn(*batch) + elif isinstance(batch, dict): + att_ws = self.att_vis_fn(**batch) + return att_ws, uttid_list + + def log_attentions(self, logger, step): + def log_fig(plot, filename): + logger.add_figure(os.path.basename(filename), plot, step) + plt.clf() + + attn_dict, uttid_list = self.get_attention_weights() + self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig) diff --git a/espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py b/espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py new file mode 100644 index 0000000000000000000000000000000000000000..5a66445e9557c9ea5f4ad382a7532f9d4204ff54 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. 
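+
+    Example (an illustrative sketch; the dimensions are assumed):
+        >>> import torch
+        >>> ff = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
+        >>> y = ff(torch.randn(2, 50, 256))     # (2, 50, 256)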
+ + Args: + idim (int): Input dimenstion. + hidden_units (int): The number of hidden units. + dropout_rate (float): Dropout rate. + + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): + """Construct an PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = activation + + def forward(self, x): + """Forward funciton.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) diff --git a/espnet/nets/pytorch_backend/transformer/repeat.py b/espnet/nets/pytorch_backend/transformer/repeat.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d2676a8020bbb4cb44e84a199baece2c9e763b --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/repeat.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Repeat the same layer definition.""" + +import torch + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential.""" + + def forward(self, *args): + """Repeat.""" + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """Repeat module N times. + + Args: + N (int): Number of repeat time. + fn (Callable): Function to generate module. + + Returns: + MultiSequential: Repeated model instance. + + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/espnet/nets/pytorch_backend/transformer/subsampling.py b/espnet/nets/pytorch_backend/transformer/subsampling.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5a736d3aa0075ff44965096184bfa4971973e9 --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/subsampling.py @@ -0,0 +1,196 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" + +import torch + +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding + + +class TooShortUttError(Exception): + """Raised when the utt is too short for subsampling. + + Args: + message (str): Message for error catch + actual_size (int): the short size that cannot pass the subsampling + limit (int): the limit size for subsampling + + """ + + def __init__(self, message, actual_size, limit): + """Construct a TooShortUttError for error handler.""" + super().__init__(message) + self.actual_size = actual_size + self.limit = limit + + +def check_short_utt(ins, size): + """Check if the utterance is too short for subsampling.""" + if isinstance(ins, Conv2dSubsampling) and size < 7: + return True, 7 + if isinstance(ins, Conv2dSubsampling6) and size < 11: + return True, 11 + if isinstance(ins, Conv2dSubsampling8) and size < 15: + return True, 15 + return False, -1 + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
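+
+    Example (an illustrative sketch; the input sizes are assumptions):
+        >>> import torch
+        >>> subsample = Conv2dSubsampling(idim=80, odim=256, dropout_rate=0.0)
+        >>> x = torch.randn(2, 100, 80)
+        >>> x_mask = torch.ones(2, 1, 100, dtype=torch.bool)
+        >>> y, m = subsample(x, x_mask)
+        >>> y.shape, m.shape
+        (torch.Size([2, 24, 256]), torch.Size([2, 1, 24]))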
+ + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2] + + def __getitem__(self, key): + """Get item. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling6(torch.nn.Module): + """Convolutional 2D subsampling (to 1/6 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling6 object.""" + super(Conv2dSubsampling6, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 6. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 6. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(torch.nn.Module): + """Convolutional 2D subsampling (to 1/8 length). + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + pos_enc (torch.nn.Module): Custom position encoding layer. 
+ + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling8 object.""" + super(Conv2dSubsampling8, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 8. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 8. + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py b/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py new file mode 100644 index 0000000000000000000000000000000000000000..239d3f1ade7f03435e44bf2414f7ab59cb055e6f --- /dev/null +++ b/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py @@ -0,0 +1,61 @@ +# Copyright 2020 Emiru Tsunoo +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" + +import math +import torch + + +class Conv2dSubsamplingWOPosEnc(torch.nn.Module): + """Convolutional 2D subsampling. + + Args: + idim (int): Input dimension. + odim (int): Output dimension. + dropout_rate (float): Dropout rate. + kernels (list): kernel sizes + strides (list): stride sizes + + """ + + def __init__(self, idim, odim, dropout_rate, kernels, strides): + """Construct an Conv2dSubsamplingWOPosEnc object.""" + assert len(kernels) == len(strides) + super().__init__() + conv = [] + olen = idim + for i, (k, s) in enumerate(zip(kernels, strides)): + conv += [ + torch.nn.Conv2d(1 if i == 0 else odim, odim, k, s), + torch.nn.ReLU(), + ] + olen = math.floor((olen - k) / s + 1) + self.conv = torch.nn.Sequential(*conv) + self.out = torch.nn.Linear(odim * olen, odim) + self.strides = strides + self.kernels = kernels + + def forward(self, x, x_mask): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, time, idim). + x_mask (torch.Tensor): Input mask (#batch, 1, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + torch.Tensor: Subsampled mask (#batch, 1, time'), + where time' = time // 4. 
+ + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + for k, s in zip(self.kernels, self.strides): + x_mask = x_mask[:, :, : -k + 1 : s] + return x, x_mask diff --git a/espnet/nets/pytorch_backend/wavenet.py b/espnet/nets/pytorch_backend/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..a14870e5b5af2874fca6c55f37b806a616bc30e0 --- /dev/null +++ b/espnet/nets/pytorch_backend/wavenet.py @@ -0,0 +1,448 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi (Nagoya University) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder.""" + +import logging +import sys +import time + +import numpy as np +import torch +import torch.nn.functional as F + +from torch import nn + + +def encode_mu_law(x, mu=256): + """Perform mu-law encoding. + + Args: + x (ndarray): Audio signal with the range from -1 to 1. + mu (int): Quantized level. + + Returns: + ndarray: Quantized audio signal with the range from 0 to mu - 1. + + """ + mu = mu - 1 + fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu) + return np.floor((fx + 1) / 2 * mu + 0.5).astype(np.int64) + + +def decode_mu_law(y, mu=256): + """Perform mu-law decoding. + + Args: + x (ndarray): Quantized audio signal with the range from 0 to mu - 1. + mu (int): Quantized level. + + Returns: + ndarray: Audio signal with the range from -1 to 1. + + """ + mu = mu - 1 + fx = (y - 0.5) / mu * 2 - 1 + x = np.sign(fx) / mu * ((1 + mu) ** np.abs(fx) - 1) + return x + + +def initialize(m): + """Initilize conv layers with xavier. + + Args: + m (torch.nn.Module): Torch module. + + """ + if isinstance(m, nn.Conv1d): + nn.init.xavier_uniform_(m.weight) + nn.init.constant_(m.bias, 0.0) + + if isinstance(m, nn.ConvTranspose2d): + nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0.0) + + +class OneHot(nn.Module): + """Convert to one-hot vector. + + Args: + depth (int): Dimension of one-hot vector. + + """ + + def __init__(self, depth): + super(OneHot, self).__init__() + self.depth = depth + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (LongTensor): long tensor variable with the shape (B, T) + + Returns: + Tensor: float tensor variable with the shape (B, depth, T) + + """ + x = x % self.depth + x = torch.unsqueeze(x, 2) + x_onehot = x.new_zeros(x.size(0), x.size(1), self.depth).float() + + return x_onehot.scatter_(2, x, 1) + + +class CausalConv1d(nn.Module): + """1D dilated causal convolution.""" + + def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True): + super(CausalConv1d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.dilation = dilation + self.padding = padding = (kernel_size - 1) * dilation + self.conv = nn.Conv1d( + in_channels, + out_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor with the shape (B, in_channels, T). + + Returns: + Tensor: Tensor with the shape (B, out_channels, T) + + """ + x = self.conv(x) + if self.padding != 0: + x = x[:, :, : -self.padding] + return x + + +class UpSampling(nn.Module): + """Upsampling layer with deconvolution. + + Args: + upsampling_factor (int): Upsampling factor. 
+ + """ + + def __init__(self, upsampling_factor, bias=True): + super(UpSampling, self).__init__() + self.upsampling_factor = upsampling_factor + self.bias = bias + self.conv = nn.ConvTranspose2d( + 1, + 1, + kernel_size=(1, self.upsampling_factor), + stride=(1, self.upsampling_factor), + bias=self.bias, + ) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor with the shape (B, C, T) + + Returns: + Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor. + + """ + x = x.unsqueeze(1) # B x 1 x C x T + x = self.conv(x) # B x 1 x C x T' + return x.squeeze(1) + + +class WaveNet(nn.Module): + """Conditional wavenet. + + Args: + n_quantize (int): Number of quantization. + n_aux (int): Number of aux feature dimension. + n_resch (int): Number of filter channels for residual block. + n_skipch (int): Number of filter channels for skip connection. + dilation_depth (int): Number of dilation depth + (e.g. if set 10, max dilation = 2^(10-1)). + dilation_repeat (int): Number of dilation repeat. + kernel_size (int): Filter size of dilated causal convolution. + upsampling_factor (int): Upsampling factor. + + """ + + def __init__( + self, + n_quantize=256, + n_aux=28, + n_resch=512, + n_skipch=256, + dilation_depth=10, + dilation_repeat=3, + kernel_size=2, + upsampling_factor=0, + ): + super(WaveNet, self).__init__() + self.n_aux = n_aux + self.n_quantize = n_quantize + self.n_resch = n_resch + self.n_skipch = n_skipch + self.kernel_size = kernel_size + self.dilation_depth = dilation_depth + self.dilation_repeat = dilation_repeat + self.upsampling_factor = upsampling_factor + + self.dilations = [ + 2 ** i for i in range(self.dilation_depth) + ] * self.dilation_repeat + self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1 + + # for preprocessing + self.onehot = OneHot(self.n_quantize) + self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size) + if self.upsampling_factor > 0: + self.upsampling = UpSampling(self.upsampling_factor) + + # for residual blocks + self.dil_sigmoid = nn.ModuleList() + self.dil_tanh = nn.ModuleList() + self.aux_1x1_sigmoid = nn.ModuleList() + self.aux_1x1_tanh = nn.ModuleList() + self.skip_1x1 = nn.ModuleList() + self.res_1x1 = nn.ModuleList() + for d in self.dilations: + self.dil_sigmoid += [ + CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d) + ] + self.dil_tanh += [ + CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d) + ] + self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)] + self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)] + self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)] + self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)] + + # for postprocessing + self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1) + self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1) + + def forward(self, x, h): + """Calculate forward propagation. + + Args: + x (LongTensor): Quantized input waveform tensor with the shape (B, T). + h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T). + + Returns: + Tensor: Logits with the shape (B, T, n_quantize). 
+ + """ + # preprocess + output = self._preprocess(x) + if self.upsampling_factor > 0: + h = self.upsampling(h) + + # residual block + skip_connections = [] + for i in range(len(self.dilations)): + output, skip = self._residual_forward( + output, + h, + self.dil_sigmoid[i], + self.dil_tanh[i], + self.aux_1x1_sigmoid[i], + self.aux_1x1_tanh[i], + self.skip_1x1[i], + self.res_1x1[i], + ) + skip_connections.append(skip) + + # skip-connection part + output = sum(skip_connections) + output = self._postprocess(output) + + return output + + def generate(self, x, h, n_samples, interval=None, mode="sampling"): + """Generate a waveform with fast genration algorithm. + + This generation based on `Fast WaveNet Generation Algorithm`_. + + Args: + x (LongTensor): Initial waveform tensor with the shape (T,). + h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux). + n_samples (int): Number of samples to be generated. + interval (int, optional): Log interval. + mode (str, optional): "sampling" or "argmax". + + Return: + ndarray: Generated quantized waveform (n_samples). + + .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482 + + """ + # reshape inputs + assert len(x.shape) == 1 + assert len(h.shape) == 2 and h.shape[1] == self.n_aux + x = x.unsqueeze(0) + h = h.transpose(0, 1).unsqueeze(0) + + # perform upsampling + if self.upsampling_factor > 0: + h = self.upsampling(h) + + # padding for shortage + if n_samples > h.shape[2]: + h = F.pad(h, (0, n_samples - h.shape[2]), "replicate") + + # padding if the length less than + n_pad = self.receptive_field - x.size(1) + if n_pad > 0: + x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2) + h = F.pad(h, (n_pad, 0), "replicate") + + # prepare buffer + output = self._preprocess(x) + h_ = h[:, :, : x.size(1)] + output_buffer = [] + buffer_size = [] + for i, d in enumerate(self.dilations): + output, _ = self._residual_forward( + output, + h_, + self.dil_sigmoid[i], + self.dil_tanh[i], + self.aux_1x1_sigmoid[i], + self.aux_1x1_tanh[i], + self.skip_1x1[i], + self.res_1x1[i], + ) + if d == 2 ** (self.dilation_depth - 1): + buffer_size.append(self.kernel_size - 1) + else: + buffer_size.append(d * 2 * (self.kernel_size - 1)) + output_buffer.append(output[:, :, -buffer_size[i] - 1 : -1]) + + # generate + samples = x[0] + start_time = time.time() + for i in range(n_samples): + output = samples[-self.kernel_size * 2 + 1 :].unsqueeze(0) + output = self._preprocess(output) + h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1) + output_buffer_next = [] + skip_connections = [] + for j, d in enumerate(self.dilations): + output, skip = self._generate_residual_forward( + output, + h_, + self.dil_sigmoid[j], + self.dil_tanh[j], + self.aux_1x1_sigmoid[j], + self.aux_1x1_tanh[j], + self.skip_1x1[j], + self.res_1x1[j], + ) + output = torch.cat([output_buffer[j], output], dim=2) + output_buffer_next.append(output[:, :, -buffer_size[j] :]) + skip_connections.append(skip) + + # update buffer + output_buffer = output_buffer_next + + # get predicted sample + output = sum(skip_connections) + output = self._postprocess(output)[0] + if mode == "sampling": + posterior = F.softmax(output[-1], dim=0) + dist = torch.distributions.Categorical(posterior) + sample = dist.sample().unsqueeze(0) + elif mode == "argmax": + sample = output.argmax(-1) + else: + logging.error("mode should be sampling or argmax") + sys.exit(1) + samples = torch.cat([samples, sample], dim=0) + + # show progress + if interval is not None and (i + 1) % interval 
== 0: + elapsed_time_per_sample = (time.time() - start_time) / interval + logging.info( + "%d/%d estimated time = %.3f sec (%.3f sec / sample)" + % ( + i + 1, + n_samples, + (n_samples - i - 1) * elapsed_time_per_sample, + elapsed_time_per_sample, + ) + ) + start_time = time.time() + + return samples[-n_samples:].cpu().numpy() + + def _preprocess(self, x): + x = self.onehot(x).transpose(1, 2) + output = self.causal(x) + return output + + def _postprocess(self, x): + output = F.relu(x) + output = self.conv_post_1(output) + output = F.relu(output) # B x C x T + output = self.conv_post_2(output).transpose(1, 2) # B x T x C + return output + + def _residual_forward( + self, + x, + h, + dil_sigmoid, + dil_tanh, + aux_1x1_sigmoid, + aux_1x1_tanh, + skip_1x1, + res_1x1, + ): + output_sigmoid = dil_sigmoid(x) + output_tanh = dil_tanh(x) + aux_output_sigmoid = aux_1x1_sigmoid(h) + aux_output_tanh = aux_1x1_tanh(h) + output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh( + output_tanh + aux_output_tanh + ) + skip = skip_1x1(output) + output = res_1x1(output) + output = output + x + return output, skip + + def _generate_residual_forward( + self, + x, + h, + dil_sigmoid, + dil_tanh, + aux_1x1_sigmoid, + aux_1x1_tanh, + skip_1x1, + res_1x1, + ): + output_sigmoid = dil_sigmoid(x)[:, :, -1:] + output_tanh = dil_tanh(x)[:, :, -1:] + aux_output_sigmoid = aux_1x1_sigmoid(h) + aux_output_tanh = aux_1x1_tanh(h) + output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh( + output_tanh + aux_output_tanh + ) + skip = skip_1x1(output) + output = res_1x1(output) + output = output + x[:, :, -1:] # B x C x 1 + return output, skip diff --git a/espnet/nets/scorer_interface.py b/espnet/nets/scorer_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..c4865d296b6fc8f5f8c969b0ba4f0f1cdd93bc67 --- /dev/null +++ b/espnet/nets/scorer_interface.py @@ -0,0 +1,188 @@ +"""Scorer interface module.""" + +from typing import Any +from typing import List +from typing import Tuple + +import torch +import warnings + + +class ScorerInterface: + """Scorer interface for beam search. + + The scorer performs scoring of the all tokens in vocabulary. + + Examples: + * Search heuristics + * :class:`espnet.nets.scorers.length_bonus.LengthBonus` + * Decoder networks of the sequence-to-sequence models + * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder` + * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder` + * Neural language models + * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM` + * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM` + * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM` + + """ + + def init_state(self, x: torch.Tensor) -> Any: + """Get an initial state for decoding (optional). + + Args: + x (torch.Tensor): The encoded feature tensor + + Returns: initial state + + """ + return None + + def select_state(self, state: Any, i: int, new_id: int = None) -> Any: + """Select state with relative ids in the main beam search. + + Args: + state: Decoder state for prefix tokens + i (int): Index to select a state in the main beam search + new_id (int): New label index to select a state if necessary + + Returns: + state: pruned state + + """ + return None if state is None else state[i] + + def score( + self, y: torch.Tensor, state: Any, x: torch.Tensor + ) -> Tuple[torch.Tensor, Any]: + """Score new token (required). + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. 
+ state: Scorer state for prefix tokens + x (torch.Tensor): The encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + scores for next token that has a shape of `(n_vocab)` + and next state for ys + + """ + raise NotImplementedError + + def final_score(self, state: Any) -> float: + """Score eos (optional). + + Args: + state: Scorer state for prefix tokens + + Returns: + float: final score + + """ + return 0.0 + + +class BatchScorerInterface(ScorerInterface): + """Batch scorer interface.""" + + def batch_init_state(self, x: torch.Tensor) -> Any: + """Get an initial state for decoding (optional). + + Args: + x (torch.Tensor): The encoded feature tensor + + Returns: initial state + + """ + return self.init_state(x) + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch (required). + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + warnings.warn( + "{} batch score is implemented through for loop not parallelized".format( + self.__class__.__name__ + ) + ) + scores = list() + outstates = list() + for i, (y, state, x) in enumerate(zip(ys, states, xs)): + score, outstate = self.score(y, state, x) + outstates.append(outstate) + scores.append(score) + scores = torch.cat(scores, 0).view(ys.shape[0], -1) + return scores, outstates + + +class PartialScorerInterface(ScorerInterface): + """Partial scorer interface for beam search. + + The partial scorer performs scoring when non-partial scorer finished scoring, + and recieves pre-pruned next tokens to score because it is too heavy to score + all the tokens. + + Examples: + * Prefix search for connectionist-temporal-classification models + * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer` + + """ + + def score_partial( + self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor + ) -> Tuple[torch.Tensor, Any]: + """Score new token (required). + + Args: + y (torch.Tensor): 1D prefix token + next_tokens (torch.Tensor): torch.int64 next token to score + state: decoder state for prefix tokens + x (torch.Tensor): The encoder feature that generates ys + + Returns: + tuple[torch.Tensor, Any]: + Tuple of a score tensor for y that has a shape `(len(next_tokens),)` + and next state for ys + + """ + raise NotImplementedError + + +class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface): + """Batch partial scorer interface for beam search.""" + + def batch_score_partial( + self, + ys: torch.Tensor, + next_tokens: torch.Tensor, + states: List[Any], + xs: torch.Tensor, + ) -> Tuple[torch.Tensor, Any]: + """Score new token (required). + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). 
+ + Returns: + tuple[torch.Tensor, Any]: + Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)` + and next states for ys + """ + raise NotImplementedError diff --git a/espnet/nets/scorers/__init__.py b/espnet/nets/scorers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/nets/scorers/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/nets/scorers/ctc.py b/espnet/nets/scorers/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..9305e40a428c2b1e820cfdc7c81927e219c42895 --- /dev/null +++ b/espnet/nets/scorers/ctc.py @@ -0,0 +1,158 @@ +"""ScorerInterface implementation for CTC.""" + +import numpy as np +import torch + +from espnet.nets.ctc_prefix_score import CTCPrefixScore +from espnet.nets.ctc_prefix_score import CTCPrefixScoreTH +from espnet.nets.scorer_interface import BatchPartialScorerInterface + + +class CTCPrefixScorer(BatchPartialScorerInterface): + """Decoder interface wrapper for CTCPrefixScore.""" + + def __init__(self, ctc: torch.nn.Module, eos: int): + """Initialize class. + + Args: + ctc (torch.nn.Module): The CTC implementaiton. + For example, :class:`espnet.nets.pytorch_backend.ctc.CTC` + eos (int): The end-of-sequence id. + + """ + self.ctc = ctc + self.eos = eos + self.impl = None + + def init_state(self, x: torch.Tensor): + """Get an initial state for decoding. + + Args: + x (torch.Tensor): The encoded feature tensor + + Returns: initial state + + """ + logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy() + # TODO(karita): use CTCPrefixScoreTH + self.impl = CTCPrefixScore(logp, 0, self.eos, np) + return 0, self.impl.initial_state() + + def select_state(self, state, i, new_id=None): + """Select state with relative ids in the main beam search. + + Args: + state: Decoder state for prefix tokens + i (int): Index to select a state in the main beam search + new_id (int): New label id to select a state if necessary + + Returns: + state: pruned state + + """ + if type(state) == tuple: + if len(state) == 2: # for CTCPrefixScore + sc, st = state + return sc[i], st[i] + else: # for CTCPrefixScoreTH (need new_id > 0) + r, log_psi, f_min, f_max, scoring_idmap = state + s = log_psi[i, new_id].expand(log_psi.size(1)) + if scoring_idmap is not None: + return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max + else: + return r[:, :, i, new_id], s, f_min, f_max + return None if state is None else state[i] + + def score_partial(self, y, ids, state, x): + """Score new token. + + Args: + y (torch.Tensor): 1D prefix token + next_tokens (torch.Tensor): torch.int64 next token to score + state: decoder state for prefix tokens + x (torch.Tensor): 2D encoder feature that generates ys + + Returns: + tuple[torch.Tensor, Any]: + Tuple of a score tensor for y that has a shape `(len(next_tokens),)` + and next state for ys + + """ + prev_score, state = state + presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state) + tscore = torch.as_tensor( + presub_score - prev_score, device=x.device, dtype=x.dtype + ) + return tscore, (presub_score, new_st) + + def batch_init_state(self, x: torch.Tensor): + """Get an initial state for decoding. 
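+
+        Example of the intended call pattern (an illustrative sketch;
+        `ctc_module` and `enc_out` are placeholders, not part of this class):
+
+        >>> scorer = CTCPrefixScorer(ctc_module, eos=2)   # any espnet CTC module
+        >>> scorer.batch_init_state(enc_out)              # enc_out: (T, D) encoder output
+        >>> # subsequent batch_score_partial() calls reuse the prepared state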
+ + Args: + x (torch.Tensor): The encoded feature tensor + + Returns: initial state + + """ + logp = self.ctc.log_softmax(x.unsqueeze(0)) # assuming batch_size = 1 + xlen = torch.tensor([logp.size(1)]) + self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos) + return None + + def batch_score_partial(self, y, ids, state, x): + """Score new token. + + Args: + y (torch.Tensor): 1D prefix token + ids (torch.Tensor): torch.int64 next token to score + state: decoder state for prefix tokens + x (torch.Tensor): 2D encoder feature that generates ys + + Returns: + tuple[torch.Tensor, Any]: + Tuple of a score tensor for y that has a shape `(len(next_tokens),)` + and next state for ys + + """ + batch_state = ( + ( + torch.stack([s[0] for s in state], dim=2), + torch.stack([s[1] for s in state]), + state[0][2], + state[0][3], + ) + if state[0] is not None + else None + ) + return self.impl(y, batch_state, ids) + + def extend_prob(self, x: torch.Tensor): + """Extend probs for decoding. + + This extention is for streaming decoding + as in Eq (14) in https://arxiv.org/abs/2006.14941 + + Args: + x (torch.Tensor): The encoded feature tensor + + """ + logp = self.ctc.log_softmax(x.unsqueeze(0)) + self.impl.extend_prob(logp) + + def extend_state(self, state): + """Extend state for decoding. + + This extention is for streaming decoding + as in Eq (14) in https://arxiv.org/abs/2006.14941 + + Args: + state: The states of hyps + + Returns: exteded state + + """ + new_state = [] + for s in state: + new_state.append(self.impl.extend_state(s)) + + return new_state diff --git a/espnet/nets/scorers/length_bonus.py b/espnet/nets/scorers/length_bonus.py new file mode 100644 index 0000000000000000000000000000000000000000..fe32a616211591308c8e7ade144e856230d211d4 --- /dev/null +++ b/espnet/nets/scorers/length_bonus.py @@ -0,0 +1,61 @@ +"""Length bonus module.""" +from typing import Any +from typing import List +from typing import Tuple + +import torch + +from espnet.nets.scorer_interface import BatchScorerInterface + + +class LengthBonus(BatchScorerInterface): + """Length bonus in beam search.""" + + def __init__(self, n_vocab: int): + """Initialize class. + + Args: + n_vocab (int): The number of tokens in vocabulary for beam search + + """ + self.n = n_vocab + + def score(self, y, state, x): + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): 2D encoder feature that generates ys. + + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (n_vocab) + and None + + """ + return torch.tensor([1.0], device=x.device, dtype=x.dtype).expand(self.n), None + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. 
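+
+        Example (a minimal sketch; the tensor shapes are illustrative only):
+
+        >>> bonus = LengthBonus(n_vocab=5)
+        >>> ys = torch.zeros(2, 3, dtype=torch.int64)
+        >>> xs = torch.zeros(2, 7, 4)
+        >>> scores, _ = bonus.batch_score(ys, [None, None], xs)
+        >>> scores.shape
+        torch.Size([2, 5])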
+ + """ + return ( + torch.tensor([1.0], device=xs.device, dtype=xs.dtype).expand( + ys.shape[0], self.n + ), + None, + ) diff --git a/espnet/nets/scorers/ngram.py b/espnet/nets/scorers/ngram.py new file mode 100644 index 0000000000000000000000000000000000000000..701bbbdc30401d28fb5176649486867b0b797499 --- /dev/null +++ b/espnet/nets/scorers/ngram.py @@ -0,0 +1,102 @@ +"""Ngram lm implement.""" + +from abc import ABC + +import kenlm +import torch + +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorer_interface import PartialScorerInterface + + +class Ngrambase(ABC): + """Ngram base implemented throught ScorerInterface.""" + + def __init__(self, ngram_model, token_list): + """Initialize Ngrambase. + + Args: + ngram_model: ngram model path + token_list: token list from dict or model.json + + """ + self.chardict = [x if x != "" else "" for x in token_list] + self.charlen = len(self.chardict) + self.lm = kenlm.LanguageModel(ngram_model) + self.tmpkenlmstate = kenlm.State() + + def init_state(self, x): + """Initialize tmp state.""" + state = kenlm.State() + self.lm.NullContextWrite(state) + return state + + def score_partial_(self, y, next_token, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + next_token: next token need to be score + state: previous state + x: encoded feature + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + out_state = kenlm.State() + ys = self.chardict[y[-1]] if y.shape[0] > 1 else "" + self.lm.BaseScore(state, ys, out_state) + scores = torch.empty_like(next_token, dtype=x.dtype, device=y.device) + for i, j in enumerate(next_token): + scores[i] = self.lm.BaseScore( + out_state, self.chardict[j], self.tmpkenlmstate + ) + return scores, out_state + + +class NgramFullScorer(Ngrambase, BatchScorerInterface): + """Fullscorer for ngram.""" + + def score(self, y, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + state: previous state + x: encoded feature + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + return self.score_partial_(y, torch.tensor(range(self.charlen)), state, x) + + +class NgramPartScorer(Ngrambase, PartialScorerInterface): + """Partialscorer for ngram.""" + + def score_partial(self, y, next_token, state, x): + """Score interface for both full and partial scorer. + + Args: + y: previous char + next_token: next token need to be score + state: previous state + x: encoded feature + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + return self.score_partial_(y, next_token, state, x) + + def select_state(self, state, i): + """Empty select state for scorer interface.""" + return state diff --git a/espnet/nets/st_interface.py b/espnet/nets/st_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..68a656f52e6d8779da961dd55a3287f1d4400ca1 --- /dev/null +++ b/espnet/nets/st_interface.py @@ -0,0 +1,67 @@ +"""ST Interface module.""" + +from espnet.nets.asr_interface import ASRInterface +from espnet.utils.dynamic_import import dynamic_import + + +class STInterface(ASRInterface): + """ST Interface for ESPnet model implementation. 
+ + NOTE: This class is inherited from ASRInterface to enable joint translation + and recognition when performing multi-task learning with the ASR task. + + """ + + def translate(self, x, trans_args, char_list=None, rnnlm=None, ensemble_models=[]): + """Recognize x for evaluation. + + :param ndarray x: input acouctic feature (B, T, D) or (T, D) + :param namespace trans_args: argment namespace contraining options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("translate method is not implemented") + + def translate_batch(self, x, trans_args, char_list=None, rnnlm=None): + """Beam search implementation for batch. + + :param torch.Tensor x: encoder hidden state sequences (B, Tmax, Henc) + :param namespace trans_args: argument namespace containing options + :param list char_list: list of characters + :param torch.nn.Module rnnlm: language model module + :return: N-best decoding results + :rtype: list + """ + raise NotImplementedError("Batch decoding is not supported yet.") + + +predefined_st = { + "pytorch": { + "rnn": "espnet.nets.pytorch_backend.e2e_st:E2E", + "transformer": "espnet.nets.pytorch_backend.e2e_st_transformer:E2E", + }, + # "chainer": { + # "rnn": "espnet.nets.chainer_backend.e2e_st:E2E", + # "transformer": "espnet.nets.chainer_backend.e2e_st_transformer:E2E", + # } +} + + +def dynamic_import_st(module, backend): + """Import ST models dynamically. + + Args: + module (str): module_name:class_name or alias in `predefined_st` + backend (str): NN backend. e.g., pytorch, chainer + + Returns: + type: ST class + + """ + model_class = dynamic_import(module, predefined_st.get(backend, dict())) + assert issubclass( + model_class, STInterface + ), f"{module} does not implement STInterface" + return model_class diff --git a/espnet/nets/transducer_decoder_interface.py b/espnet/nets/transducer_decoder_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..111771305f8b119ce89c7cd658127737be811615 --- /dev/null +++ b/espnet/nets/transducer_decoder_interface.py @@ -0,0 +1,153 @@ +"""Transducer decoder interface module.""" + +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + + +@dataclass +class Hypothesis: + """Default hypothesis definition for beam search.""" + + score: float + yseq: List[int] + dec_state: Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[torch.Tensor], torch.Tensor + ] + lm_state: Union[Dict[str, Any], List[Any]] = None + + +@dataclass +class NSCHypothesis(Hypothesis): + """Extended hypothesis definition for NSC beam search.""" + + y: List[torch.Tensor] = None + lm_scores: torch.Tensor = None + + +class TransducerDecoderInterface: + """Decoder interface for transducer models.""" + + def init_state( + self, + batch_size: int, + device: torch.device, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ]: + """Initialize decoder states. 
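+
+        Implementations typically return an `(h, c)` tuple of zero tensors
+        for RNN decoders, or a list of per-layer caches (possibly `None`)
+        for self-attention decoders; see the Union return type above. This
+        describes the common cases, not a requirement of the interface.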
+ + Args: + batch_size: Batch size for initial state + device: Device for initial state + + Returns: + state: Initialized state + + """ + raise NotImplementedError("init_state method is not implemented") + + def score( + self, + hyp: Union[Hypothesis, NSCHypothesis], + cache: Dict[str, Any], + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + torch.Tensor, + List[Optional[torch.Tensor]], + ]: + """Forward one hypothesis. + + Args: + hyp: Hypothesis. + cache: Pairs of (y, state) for each token sequence (key) + + Returns: + y: Decoder outputs + new_state: New decoder state + lm_tokens: Token id for LM + + """ + raise NotImplementedError("score method is not implemented") + + def batch_score( + self, + hyps: Union[List[Hypothesis], List[NSCHypothesis]], + batch_states: Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ], + cache: Dict[str, Any], + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + torch.Tensor, + List[Optional[torch.Tensor]], + ]: + """Forward batch of hypotheses. + + Args: + hyps: Batch of hypotheses + batch_states: Batch of decoder states + cache: pairs of (y, state) for each token sequence (key) + + Returns: + batch_y: Decoder outputs + batch_states: Batch of decoder states + lm_tokens: Batch of token ids for LM + + """ + raise NotImplementedError("batch_score method is not implemented") + + def select_state( + self, + batch_states: Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ], + idx: int, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ]: + """Get decoder state from batch for given id. + + Args: + batch_states: Batch of decoder states + idx: Index to extract state from batch + + Returns: + state_idx: Decoder state for given id + + """ + raise NotImplementedError("select_state method is not implemented") + + def create_batch_states( + self, + batch_states: Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ], + l_states: List[ + Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + List[Optional[torch.Tensor]], + ] + ], + l_tokens: List[List[int]], + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], List[Optional[torch.Tensor]] + ]: + """Create batch of decoder states. 
+ + Args: + batch_states: Batch of decoder states + l_states: List of decoder states + l_tokens: List of token sequences for input batch + + Returns: + batch_states: Batch of decoder states + + """ + raise NotImplementedError("create_batch_states method is not implemented") diff --git a/espnet/nets/tts_interface.py b/espnet/nets/tts_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..587d72792373ea4e9143a6443bac1f156d00fb90 --- /dev/null +++ b/espnet/nets/tts_interface.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- + +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""TTS Interface realted modules.""" + +from espnet.asr.asr_utils import torch_load + + +try: + import chainer +except ImportError: + Reporter = None +else: + + class Reporter(chainer.Chain): + """Reporter module.""" + + def report(self, dicts): + """Report values from a given dict.""" + for d in dicts: + chainer.reporter.report(d, self) + + +class TTSInterface(object): + """TTS Interface for ESPnet model implementation.""" + + @staticmethod + def add_arguments(parser): + """Add model specific argments to parser.""" + return parser + + def __init__(self): + """Initilize TTS module.""" + self.reporter = Reporter() + + def forward(self, *args, **kwargs): + """Calculate TTS forward propagation. + + Returns: + Tensor: Loss value. + + """ + raise NotImplementedError("forward method is not implemented") + + def inference(self, *args, **kwargs): + """Generate the sequence of features given the sequences of characters. + + Returns: + Tensor: The sequence of generated features (L, odim). + Tensor: The sequence of stop probabilities (L,). + Tensor: The sequence of attention weights (L, T). + + """ + raise NotImplementedError("inference method is not implemented") + + def calculate_all_attentions(self, *args, **kwargs): + """Calculate TTS attention weights. + + Args: + Tensor: Batch of attention weights (B, Lmax, Tmax). + + """ + raise NotImplementedError("calculate_all_attentions method is not implemented") + + def load_pretrained_model(self, model_path): + """Load pretrained model parameters.""" + torch_load(model_path, self) + + @property + def attention_plot_class(self): + """Plot attention weights.""" + from espnet.asr.asr_utils import PlotAttentionReport + + return PlotAttentionReport + + @property + def base_plot_keys(self): + """Return base key names to plot during training. + + The keys should match what `chainer.reporter` reports. + if you add the key `loss`, + the reporter will report `main/loss` and `validation/main/loss` values. + also `loss.png` will be created as a figure visulizing `main/loss` + and `validation/main/loss` values. + + Returns: + list[str]: Base keys to plot during training. 
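+
+        Example (hypothetical subclass): returning `["loss", "l1_loss"]`
+        would additionally make the reporter track `main/l1_loss` and
+        `validation/main/l1_loss` and plot them in `l1_loss.png`.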
+ + """ + return ["loss"] diff --git a/espnet/optimizer/__init__.py b/espnet/optimizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/optimizer/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/optimizer/chainer.py b/espnet/optimizer/chainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0fb6f4b3fab082926c66a81f8aa4cf19e7a6b849 --- /dev/null +++ b/espnet/optimizer/chainer.py @@ -0,0 +1,98 @@ +"""Chainer optimizer builders.""" +import argparse + +import chainer +from chainer.optimizer_hooks import WeightDecay + +from espnet.optimizer.factory import OptimizerFactoryInterface +from espnet.optimizer.parser import adadelta +from espnet.optimizer.parser import adam +from espnet.optimizer.parser import sgd + + +class AdamFactory(OptimizerFactoryInterface): + """Adam factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return adam(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + opt = chainer.optimizers.Adam( + alpha=args.lr, + beta1=args.beta1, + beta2=args.beta2, + ) + opt.setup(target) + opt.add_hook(WeightDecay(args.weight_decay)) + return opt + + +class SGDFactory(OptimizerFactoryInterface): + """SGD factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return sgd(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + opt = chainer.optimizers.SGD( + lr=args.lr, + ) + opt.setup(target) + opt.add_hook(WeightDecay(args.weight_decay)) + return opt + + +class AdadeltaFactory(OptimizerFactoryInterface): + """Adadelta factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return adadelta(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + opt = chainer.optimizers.AdaDelta( + rho=args.rho, + eps=args.eps, + ) + opt.setup(target) + opt.add_hook(WeightDecay(args.weight_decay)) + return opt + + +OPTIMIZER_FACTORY_DICT = { + "adam": AdamFactory, + "sgd": SGDFactory, + "adadelta": AdadeltaFactory, +} diff --git a/espnet/optimizer/factory.py b/espnet/optimizer/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..37e19b0692e37f4d8c06b6055b64b0311540f783 --- /dev/null +++ b/espnet/optimizer/factory.py @@ -0,0 +1,69 @@ +"""Import optimizer class dynamically.""" +import argparse + +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.fill_missing_args import fill_missing_args + + +class OptimizerFactoryInterface: + """Optimizer adaptor.""" + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. 
+ + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + raise NotImplementedError() + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return parser + + @classmethod + def build(cls, target, **kwargs): + """Initialize optimizer with python-level args. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + + Returns: + new Optimizer + + """ + args = argparse.Namespace(**kwargs) + args = fill_missing_args(args, cls.add_arguments) + return cls.from_args(target, args) + + +def dynamic_import_optimizer(name: str, backend: str) -> OptimizerFactoryInterface: + """Import optimizer class dynamically. + + Args: + name (str): alias name or dynamic import syntax `module:class` + backend (str): backend name e.g., chainer or pytorch + + Returns: + OptimizerFactoryInterface or FunctionalOptimizerAdaptor + + """ + if backend == "pytorch": + from espnet.optimizer.pytorch import OPTIMIZER_FACTORY_DICT + + return OPTIMIZER_FACTORY_DICT[name] + elif backend == "chainer": + from espnet.optimizer.chainer import OPTIMIZER_FACTORY_DICT + + return OPTIMIZER_FACTORY_DICT[name] + else: + raise NotImplementedError(f"unsupported backend: {backend}") + + factory_class = dynamic_import(name) + assert issubclass(factory_class, OptimizerFactoryInterface) + return factory_class diff --git a/espnet/optimizer/parser.py b/espnet/optimizer/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..3347359ba35612f6c89c1e7cad0b871a7d15dccb --- /dev/null +++ b/espnet/optimizer/parser.py @@ -0,0 +1,25 @@ +"""Common optimizer default config for multiple backends.""" + + +def sgd(parser): + """Add arguments.""" + parser.add_argument("--lr", type=float, default=1.0, help="Learning rate") + parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay") + return parser + + +def adam(parser): + """Add arguments.""" + parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") + parser.add_argument("--beta1", type=float, default=0.9, help="Beta1") + parser.add_argument("--beta2", type=float, default=0.999, help="Beta2") + parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay") + return parser + + +def adadelta(parser): + """Add arguments.""" + parser.add_argument("--rho", type=float, default=0.95, help="Rho") + parser.add_argument("--eps", type=float, default=1e-8, help="Eps") + parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay") + return parser diff --git a/espnet/optimizer/pytorch.py b/espnet/optimizer/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..7914e36b999b50de79e3ed666dfcfa60ef8265a1 --- /dev/null +++ b/espnet/optimizer/pytorch.py @@ -0,0 +1,93 @@ +"""PyTorch optimizer builders.""" +import argparse + +import torch + +from espnet.optimizer.factory import OptimizerFactoryInterface +from espnet.optimizer.parser import adadelta +from espnet.optimizer.parser import adam +from espnet.optimizer.parser import sgd + + +class AdamFactory(OptimizerFactoryInterface): + """Adam factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return adam(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. 
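+
+        The expected fields (`lr`, `beta1`, `beta2`, `weight_decay`) and
+        their defaults are defined by the shared `adam` parser in
+        `espnet.optimizer.parser`.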
+ + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + return torch.optim.Adam( + target, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(args.beta1, args.beta2), + ) + + +class SGDFactory(OptimizerFactoryInterface): + """SGD factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return sgd(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + return torch.optim.SGD( + target, + lr=args.lr, + weight_decay=args.weight_decay, + ) + + +class AdadeltaFactory(OptimizerFactoryInterface): + """Adadelta factory.""" + + @staticmethod + def add_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Register args.""" + return adadelta(parser) + + @staticmethod + def from_args(target, args: argparse.Namespace): + """Initialize optimizer from argparse Namespace. + + Args: + target: for pytorch `model.parameters()`, + for chainer `model` + args (argparse.Namespace): parsed command-line args + + """ + return torch.optim.Adadelta( + target, + rho=args.rho, + eps=args.eps, + weight_decay=args.weight_decay, + ) + + +OPTIMIZER_FACTORY_DICT = { + "adam": AdamFactory, + "sgd": SGDFactory, + "adadelta": AdadeltaFactory, +} diff --git a/espnet/scheduler/__init__.py b/espnet/scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/scheduler/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/scheduler/chainer.py b/espnet/scheduler/chainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3ec422c41b64ab9a3f1e7e5ed5af87402bef938b --- /dev/null +++ b/espnet/scheduler/chainer.py @@ -0,0 +1,25 @@ +"""Chainer optimizer schdulers.""" + +from typing import List + +from chainer.optimizer import Optimizer + +from espnet.scheduler.scheduler import SchedulerInterface + + +class ChainerScheduler: + """Chainer optimizer scheduler.""" + + def __init__(self, schedulers: List[SchedulerInterface], optimizer: Optimizer): + """Initialize class.""" + self.schedulers = schedulers + self.optimizer = optimizer + self.init_values = dict() + for s in self.schedulers: + self.init_values[s.key] = getattr(self.optimizer, s.key) + + def step(self, n_iter: int): + """Update optimizer by scheduling.""" + for s in self.schedulers: + new_val = self.init_values[s.key] * s.scale(n_iter) + setattr(self.optimizer, s.key, new_val) diff --git a/espnet/scheduler/pytorch.py b/espnet/scheduler/pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..0e944b15d4bcf16fcac6443e46ffc01038dc281e --- /dev/null +++ b/espnet/scheduler/pytorch.py @@ -0,0 +1,25 @@ +"""PyTorch optimizer schdulers.""" + +from typing import List + +from torch.optim import Optimizer + +from espnet.scheduler.scheduler import SchedulerInterface + + +class PyTorchScheduler: + """PyTorch optimizer scheduler.""" + + def __init__(self, schedulers: List[SchedulerInterface], optimizer: Optimizer): + """Initialize class.""" + self.schedulers = schedulers + self.optimizer = optimizer + for s in self.schedulers: + for group in optimizer.param_groups: + group.setdefault("initial_" + s.key, group[s.key]) + + def step(self, 
n_iter: int): + """Update optimizer by scheduling.""" + for s in self.schedulers: + for group in self.optimizer.param_groups: + group[s.key] = group["initial_" + s.key] * s.scale(n_iter) diff --git a/espnet/scheduler/scheduler.py b/espnet/scheduler/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..8d81368884c9e7e6d4ed237e427ac89608e5e73f --- /dev/null +++ b/espnet/scheduler/scheduler.py @@ -0,0 +1,180 @@ +"""Schedulers.""" + +import argparse + +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.fill_missing_args import fill_missing_args + + +class _PrefixParser: + def __init__(self, parser, prefix): + self.parser = parser + self.prefix = prefix + + def add_argument(self, name, **kwargs): + assert name.startswith("--") + self.parser.add_argument(self.prefix + name[2:], **kwargs) + + +class SchedulerInterface: + """Scheduler interface.""" + + alias = "" + + def __init__(self, key: str, args: argparse.Namespace): + """Initialize class.""" + self.key = key + prefix = key + "_" + self.alias + "_" + for k, v in vars(args).items(): + if k.startswith(prefix): + setattr(self, k[len(prefix) :], v) + + def get_arg(self, name): + """Get argument without prefix.""" + return getattr(self.args, f"{self.key}_{self.alias}_{name}") + + @classmethod + def add_arguments(cls, key: str, parser: argparse.ArgumentParser): + """Add arguments for CLI.""" + group = parser.add_argument_group(f"{cls.alias} scheduler") + cls._add_arguments(_PrefixParser(parser=group, prefix=f"--{key}-{cls.alias}-")) + return parser + + @staticmethod + def _add_arguments(parser: _PrefixParser): + pass + + @classmethod + def build(cls, key: str, **kwargs): + """Initialize this class with python-level args. + + Args: + key (str): key of hyper parameter + + Returns: + LMinterface: A new instance of LMInterface. + + """ + + def add(parser): + return cls.add_arguments(key, parser) + + kwargs = {f"{key}_{cls.alias}_" + k: v for k, v in kwargs.items()} + args = argparse.Namespace(**kwargs) + args = fill_missing_args(args, add) + return cls(key, args) + + def scale(self, n_iter: int) -> float: + """Scale at `n_iter`. + + Args: + n_iter (int): number of current iterations. + + Returns: + float: current scale of learning rate. + + """ + raise NotImplementedError() + + +SCHEDULER_DICT = {} + + +def register_scheduler(cls): + """Register scheduler.""" + SCHEDULER_DICT[cls.alias] = cls.__module__ + ":" + cls.__name__ + return cls + + +def dynamic_import_scheduler(module): + """Import Scheduler class dynamically. + + Args: + module (str): module_name:class_name or alias in `SCHEDULER_DICT` + + Returns: + type: Scheduler class + + """ + model_class = dynamic_import(module, SCHEDULER_DICT) + assert issubclass( + model_class, SchedulerInterface + ), f"{module} does not implement SchedulerInterface" + return model_class + + +@register_scheduler +class NoScheduler(SchedulerInterface): + """Scheduler which does nothing.""" + + alias = "none" + + def scale(self, n_iter): + """Scale of lr.""" + return 1.0 + + +@register_scheduler +class NoamScheduler(SchedulerInterface): + """Warmup + InverseSqrt decay scheduler. + + Args: + noam_warmup (int): number of warmup iterations. + + """ + + alias = "noam" + + @staticmethod + def _add_arguments(parser: _PrefixParser): + """Add scheduler args.""" + parser.add_argument( + "--warmup", type=int, default=1000, help="Number of warmup iterations." 
+ ) + + def __init__(self, key, args): + """Initialize class.""" + super().__init__(key, args) + self.normalize = 1 / (self.warmup * self.warmup ** -1.5) + + def scale(self, step): + """Scale of lr.""" + step += 1 # because step starts from 0 + return self.normalize * min(step ** -0.5, step * self.warmup ** -1.5) + + +@register_scheduler +class CyclicCosineScheduler(SchedulerInterface): + """Cyclic cosine annealing. + + Args: + cosine_warmup (int): number of warmup iterations. + cosine_total (int): number of total annealing iterations. + + Notes: + Proposed in https://openreview.net/pdf?id=BJYwwY9ll + (and https://arxiv.org/pdf/1608.03983.pdf). + Used in the GPT2 config of Megatron-LM https://github.com/NVIDIA/Megatron-LM + + """ + + alias = "cosine" + + @staticmethod + def _add_arguments(parser: _PrefixParser): + """Add scheduler args.""" + parser.add_argument( + "--warmup", type=int, default=1000, help="Number of warmup iterations." + ) + parser.add_argument( + "--total", + type=int, + default=100000, + help="Number of total annealing iterations.", + ) + + def scale(self, n_iter): + """Scale of lr.""" + import math + + return 0.5 * (math.cos(math.pi * (n_iter - self.warmup) / self.total) + 1) diff --git a/espnet/st/__init__.py b/espnet/st/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/st/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/st/pytorch_backend/__init__.py b/espnet/st/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/st/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/st/pytorch_backend/st.py b/espnet/st/pytorch_backend/st.py new file mode 100644 index 0000000000000000000000000000000000000000..d6824280c3998d59f522713df845732228f3bedd --- /dev/null +++ b/espnet/st/pytorch_backend/st.py @@ -0,0 +1,687 @@ +# Copyright 2019 Kyoto University (Hirofumi Inaguma) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Training/decoding definition for the speech translation task.""" + +import json +import logging +import os +import sys + +from chainer import training +from chainer.training import extensions +import numpy as np +from tensorboardX import SummaryWriter +import torch + +from espnet.asr.asr_utils import adadelta_eps_decay +from espnet.asr.asr_utils import adam_lr_decay +from espnet.asr.asr_utils import add_results_to_json +from espnet.asr.asr_utils import CompareValueTrigger +from espnet.asr.asr_utils import restore_snapshot +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.pytorch_backend.asr_init import load_trained_model +from espnet.asr.pytorch_backend.asr_init import load_trained_modules + +from espnet.nets.pytorch_backend.e2e_asr import pad_list +from espnet.nets.st_interface import STInterface +from espnet.utils.dataset import ChainerDataLoader +from espnet.utils.dataset import TransformDataset +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.iterators import ShufflingEnabler +from 
espnet.utils.training.tensorboard_logger import TensorboardLogger +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +from espnet.asr.pytorch_backend.asr import CustomConverter as ASRCustomConverter +from espnet.asr.pytorch_backend.asr import CustomEvaluator +from espnet.asr.pytorch_backend.asr import CustomUpdater + +import matplotlib + +matplotlib.use("Agg") + +if sys.version_info[0] == 2: + from itertools import izip_longest as zip_longest +else: + from itertools import zip_longest as zip_longest + + +class CustomConverter(ASRCustomConverter): + """Custom batch converter for Pytorch. + + Args: + subsampling_factor (int): The subsampling factor. + dtype (torch.dtype): Data type to convert. + use_source_text (bool): use source transcription. + + """ + + def __init__( + self, subsampling_factor=1, dtype=torch.float32, use_source_text=False + ): + """Construct a CustomConverter object.""" + super().__init__(subsampling_factor=subsampling_factor, dtype=dtype) + self.use_source_text = use_source_text + + def __call__(self, batch, device=torch.device("cpu")): + """Transform a batch and send it to a device. + + Args: + batch (list): The batch to transform. + device (torch.device): The device to send to. + + Returns: + tuple(torch.Tensor, torch.Tensor, torch.Tensor) + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys, ys_src = batch[0] + + # get batch of lengths of input sequences + ilens = np.array([x.shape[0] for x in xs]) + ilens = torch.from_numpy(ilens).to(device) + + xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to( + device, dtype=self.dtype + ) + + ys_pad = pad_list( + [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys], + self.ignore_id, + ).to(device) + + if self.use_source_text: + ys_pad_src = pad_list( + [torch.from_numpy(np.array(y, dtype=np.int64)) for y in ys_src], + self.ignore_id, + ).to(device) + else: + ys_pad_src = None + + return xs_pad, ilens, ys_pad, ys_pad_src + + +def train(args): + """Train with the given args. + + Args: + args (namespace): The program arguments. 
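+
+    The namespace is expected to come from the ST training argument parser.
+    This function builds the model and optimizer, wraps the json data in
+    chainer-style iterators, attaches evaluation, plotting and snapshot
+    extensions, and runs the trainer, optionally resuming from `args.resume`.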
+ + """ + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + idim = int(valid_json[utts[0]]["input"][0]["shape"][-1]) + odim = int(valid_json[utts[0]]["output"][0]["shape"][-1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # Initialize with pre-trained ASR encoder and MT decoder + if args.enc_init is not None or args.dec_init is not None: + model = load_trained_modules(idim, odim, args, interface=STInterface) + else: + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, STInterface) + total_subsampling_factor = model.get_total_subsampling_factor() + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to " + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + if args.train_dtype in ("float16", "float32", "float64"): + dtype = getattr(torch, args.train_dtype) + else: + dtype = torch.float32 + model = model.to(device=device, dtype=dtype) + + logging.warning( + "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adadelta": + optimizer = torch.optim.Adadelta( + model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "adam": + optimizer = torch.optim.Adam( + model.parameters(), lr=args.lr, weight_decay=args.weight_decay + ) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + optimizer = get_std_opt( + model.parameters(), + args.adim, + args.transformer_warmup_steps, + args.transformer_lr, + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # setup apex.amp + if args.train_dtype in ("O0", "O1", "O2", "O3"): + try: + from apex import amp + except ImportError as e: + logging.error( + f"You need to install apex for --train-dtype {args.train_dtype}. 
" + "See https://github.com/NVIDIA/apex#linux" + ) + raise e + if args.opt == "noam": + model, optimizer.optimizer = amp.initialize( + model, optimizer.optimizer, opt_level=args.train_dtype + ) + else: + model, optimizer = amp.initialize( + model, optimizer, opt_level=args.train_dtype + ) + use_apex = True + else: + use_apex = False + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # Setup a converter + converter = CustomConverter( + subsampling_factor=model.subsample[0], + dtype=dtype, + use_source_text=args.asr_weight > 0 or args.mt_weight > 0, + ) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + # make minibatch list (variable length) + train = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + valid = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + iaxis=0, + oaxis=0, + ) + + load_tr = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + ) + load_cv = LoadInputsAndTargets( + mode="asr", + load_output=True, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + # default collate function converts numpy array to pytorch tensor + # we used an empty collate function instead which returns list + train_iter = ChainerDataLoader( + dataset=TransformDataset(train, lambda data: converter([load_tr(data)])), + batch_size=1, + num_workers=args.n_iter_processes, + shuffle=not use_sortagrad, + collate_fn=lambda x: x[0], + ) + valid_iter = ChainerDataLoader( + dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.n_iter_processes, + ) + + # Set up a trainer + updater = CustomUpdater( + model, + args.grad_clip, + {"main": train_iter}, + optimizer, + device, + args.ngpu, + args.grad_noise, + args.accum_grad, + use_apex=use_apex, + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Resume from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # Evaluate the model with the test dataset for each epoch + if args.save_interval_iters > 0: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu), + trigger=(args.save_interval_iters, "iteration"), + ) + 
else: + trainer.extend( + CustomEvaluator(model, {"main": valid_iter}, reporter, device, args.ngpu) + ) + + # Save attention weight at each epoch + if args.num_save_attention > 0: + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["input"][0]["shape"][1]), + reverse=True, + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + subsampling_factor=total_subsampling_factor, + ) + trainer.extend(att_reporter, trigger=(1, "epoch")) + else: + att_reporter = None + + # Save CTC prob at each epoch + if (args.asr_weight > 0 and args.mtlalpha > 0) and args.num_save_ctc > 0: + # NOTE: sort it by output lengths + data = sorted( + list(valid_json.items())[: args.num_save_ctc], + key=lambda x: int(x[1]["output"][0]["shape"][0]), + reverse=True, + ) + if hasattr(model, "module"): + ctc_vis_fn = model.module.calculate_all_ctc_probs + plot_class = model.module.ctc_plot_class + else: + ctc_vis_fn = model.calculate_all_ctc_probs + plot_class = model.ctc_plot_class + ctc_reporter = plot_class( + ctc_vis_fn, + data, + args.outdir + "/ctc_prob", + converter=converter, + transform=load_cv, + device=device, + subsampling_factor=total_subsampling_factor, + ) + trainer.extend(ctc_reporter, trigger=(1, "epoch")) + else: + ctc_reporter = None + + # Make a plot for training and validation values + trainer.extend( + extensions.PlotReport( + [ + "main/loss", + "validation/main/loss", + "main/loss_asr", + "validation/main/loss_asr", + "main/loss_mt", + "validation/main/loss_mt", + "main/loss_st", + "validation/main/loss_st", + ], + "epoch", + file_name="loss.png", + ) + ) + trainer.extend( + extensions.PlotReport( + [ + "main/acc", + "validation/main/acc", + "main/acc_asr", + "validation/main/acc_asr", + "main/acc_mt", + "validation/main/acc_mt", + ], + "epoch", + file_name="acc.png", + ) + ) + trainer.extend( + extensions.PlotReport( + ["main/bleu", "validation/main/bleu"], "epoch", file_name="bleu.png" + ) + ) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger("validation/main/loss"), + ) + trainer.extend( + snapshot_object(model, "model.acc.best"), + trigger=training.triggers.MaxValueTrigger("validation/main/acc"), + ) + + # save snapshot which contains model and optimizer states + if args.save_interval_iters > 0: + trainer.extend( + torch_snapshot(filename="snapshot.iter.{.updater.iteration}"), + trigger=(args.save_interval_iters, "iteration"), + ) + else: + trainer.extend(torch_snapshot(), trigger=(1, "epoch")) + + # epsilon decay in the optimizer + if args.opt == "adadelta": + if args.criterion == "acc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( 
+ "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adadelta_eps_decay(args.eps_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + elif args.opt == "adam": + if args.criterion == "acc": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.acc.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + trainer.extend( + adam_lr_decay(args.lr_decay), + trigger=CompareValueTrigger( + "validation/main/acc", + lambda best_value, current_value: best_value > current_value, + ), + ) + elif args.criterion == "loss": + trainer.extend( + restore_snapshot( + model, args.outdir + "/model.loss.best", load_fn=torch_load + ), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + trainer.extend( + adam_lr_decay(args.lr_decay), + trigger=CompareValueTrigger( + "validation/main/loss", + lambda best_value, current_value: best_value < current_value, + ), + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend( + extensions.LogReport(trigger=(args.report_interval_iters, "iteration")) + ) + report_keys = [ + "epoch", + "iteration", + "main/loss", + "main/loss_st", + "main/loss_asr", + "validation/main/loss", + "validation/main/loss_st", + "validation/main/loss_asr", + "main/acc", + "validation/main/acc", + ] + if args.asr_weight > 0: + report_keys.append("main/acc_asr") + report_keys.append("validation/main/acc_asr") + report_keys += ["elapsed_time"] + if args.opt == "adadelta": + trainer.extend( + extensions.observe_value( + "eps", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "eps" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("eps") + elif args.opt in ["adam", "noam"]: + trainer.extend( + extensions.observe_value( + "lr", + lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][ + "lr" + ], + ), + trigger=(args.report_interval_iters, "iteration"), + ) + report_keys.append("lr") + if args.asr_weight > 0: + if args.mtlalpha > 0: + report_keys.append("main/cer_ctc") + report_keys.append("validation/main/cer_ctc") + if args.mtlalpha < 1: + if args.report_cer: + report_keys.append("validation/main/cer") + if args.report_wer: + report_keys.append("validation/main/wer") + if args.report_bleu: + report_keys.append("main/bleu") + report_keys.append("validation/main/bleu") + trainer.extend( + extensions.PrintReport(report_keys), + trigger=(args.report_interval_iters, "iteration"), + ) + + trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters)) + set_early_stop(trainer, args) + + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + trainer.extend( + TensorboardLogger( + SummaryWriter(args.tensorboard_dir), + att_reporter=att_reporter, + ctc_reporter=ctc_reporter, + ), + trigger=(args.report_interval_iters, "iteration"), + ) + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +def trans(args): + """Decode with the given args. + + Args: + args (namespace): The program arguments. 
+ + """ + set_deterministic_pytorch(args) + model, train_args = load_trained_model(args.model) + assert isinstance(model, STInterface) + model.trans_args = args + + # gpu + if args.ngpu == 1: + gpu_id = list(range(args.ngpu)) + logging.info("gpu id: " + str(gpu_id)) + model.cuda() + + # read json data + with open(args.trans_json, "rb") as f: + js = json.load(f)["utts"] + new_js = {} + + load_inputs_and_targets = LoadInputsAndTargets( + mode="asr", + load_output=False, + sort_in_input_length=False, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, + ) + + if args.batchsize == 0: + with torch.no_grad(): + for idx, name in enumerate(js.keys(), 1): + logging.info("(%d/%d) decoding " + name, idx, len(js.keys())) + batch = [(name, js[name])] + feat = load_inputs_and_targets(batch)[0][0] + nbest_hyps = model.translate( + feat, + args, + train_args.char_list, + ) + new_js[name] = add_results_to_json( + js[name], nbest_hyps, train_args.char_list + ) + + else: + + def grouper(n, iterable, fillvalue=None): + kargs = [iter(iterable)] * n + return zip_longest(*kargs, fillvalue=fillvalue) + + # sort data if batchsize > 1 + keys = list(js.keys()) + if args.batchsize > 1: + feat_lens = [js[key]["input"][0]["shape"][0] for key in keys] + sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i]) + keys = [keys[i] for i in sorted_index] + + with torch.no_grad(): + for names in grouper(args.batchsize, keys, None): + names = [name for name in names if name] + batch = [(name, js[name]) for name in names] + feats = load_inputs_and_targets(batch)[0] + nbest_hyps = model.translate_batch( + feats, + args, + train_args.char_list, + ) + + for i, nbest_hyp in enumerate(nbest_hyps): + name = names[i] + new_js[name] = add_results_to_json( + js[name], nbest_hyp, train_args.char_list + ) + + with open(args.result_label, "wb") as f: + f.write( + json.dumps( + {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) diff --git a/espnet/transform/__init__.py b/espnet/transform/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f78ea5dbc9ae296270ed1cf2688313d52d6480b3 --- /dev/null +++ b/espnet/transform/__init__.py @@ -0,0 +1 @@ +"""Initialize main package.""" diff --git a/espnet/transform/add_deltas.py b/espnet/transform/add_deltas.py new file mode 100644 index 0000000000000000000000000000000000000000..93f941c5f04ffb84776c2fcafd59229b6b5e8fd4 --- /dev/null +++ b/espnet/transform/add_deltas.py @@ -0,0 +1,34 @@ +import numpy as np + + +def delta(feat, window): + assert window > 0 + delta_feat = np.zeros_like(feat) + for i in range(1, window + 1): + delta_feat[:-i] += i * feat[i:] + delta_feat[i:] += -i * feat[:-i] + delta_feat[-i:] += i * feat[-1] + delta_feat[:i] += -i * feat[0] + delta_feat /= 2 * sum(i ** 2 for i in range(1, window + 1)) + return delta_feat + + +def add_deltas(x, window=2, order=2): + feats = [x] + for _ in range(order): + feats.append(delta(feats[-1], window)) + return np.concatenate(feats, axis=1) + + +class AddDeltas(object): + def __init__(self, window=2, order=2): + self.window = window + self.order = order + + def __repr__(self): + return "{name}(window={window}, order={order}".format( + name=self.__class__.__name__, window=self.window, order=self.order + ) + + def __call__(self, x): + return add_deltas(x, window=self.window, order=self.order) diff --git a/espnet/transform/channel_selector.py b/espnet/transform/channel_selector.py 
new file mode 100644 index 0000000000000000000000000000000000000000..9f303bd507787997244f1c33a590e366bd0300fd --- /dev/null +++ b/espnet/transform/channel_selector.py @@ -0,0 +1,45 @@ +import numpy + + +class ChannelSelector(object): + """Select 1ch from multi-channel signal """ + + def __init__(self, train_channel="random", eval_channel=0, axis=1): + self.train_channel = train_channel + self.eval_channel = eval_channel + self.axis = axis + + def __repr__(self): + return ( + "{name}(train_channel={train_channel}, " + "eval_channel={eval_channel}, axis={axis})".format( + name=self.__class__.__name__, + train_channel=self.train_channel, + eval_channel=self.eval_channel, + axis=self.axis, + ) + ) + + def __call__(self, x, train=True): + # Assuming x: [Time, Channel] by default + + if x.ndim <= self.axis: + # If the dimension is insufficient, then unsqueeze + # (e.g [Time] -> [Time, 1]) + ind = tuple( + slice(None) if i < x.ndim else None for i in range(self.axis + 1) + ) + x = x[ind] + + if train: + channel = self.train_channel + else: + channel = self.eval_channel + + if channel == "random": + ch = numpy.random.randint(0, x.shape[self.axis]) + else: + ch = channel + + ind = tuple(slice(None) if i != self.axis else ch for i in range(x.ndim)) + return x[ind] diff --git a/espnet/transform/cmvn.py b/espnet/transform/cmvn.py new file mode 100644 index 0000000000000000000000000000000000000000..085b243841d7a0ab18359782f3dee0bacffce3a5 --- /dev/null +++ b/espnet/transform/cmvn.py @@ -0,0 +1,144 @@ +import io + +import h5py +import kaldiio +import numpy as np + + +class CMVN(object): + def __init__( + self, + stats, + norm_means=True, + norm_vars=False, + filetype="mat", + utt2spk=None, + spk2utt=None, + reverse=False, + std_floor=1.0e-20, + ): + self.stats_file = stats + self.norm_means = norm_means + self.norm_vars = norm_vars + self.reverse = reverse + + if isinstance(stats, dict): + stats_dict = dict(stats) + else: + # Use for global CMVN + if filetype == "mat": + stats_dict = {None: kaldiio.load_mat(stats)} + # Use for global CMVN + elif filetype == "npy": + stats_dict = {None: np.load(stats)} + # Use for speaker CMVN + elif filetype == "ark": + self.accept_uttid = True + stats_dict = dict(kaldiio.load_ark(stats)) + # Use for speaker CMVN + elif filetype == "hdf5": + self.accept_uttid = True + stats_dict = h5py.File(stats) + else: + raise ValueError("Not supporting filetype={}".format(filetype)) + + if utt2spk is not None: + self.utt2spk = {} + with io.open(utt2spk, "r", encoding="utf-8") as f: + for line in f: + utt, spk = line.rstrip().split(None, 1) + self.utt2spk[utt] = spk + elif spk2utt is not None: + self.utt2spk = {} + with io.open(spk2utt, "r", encoding="utf-8") as f: + for line in f: + spk, utts = line.rstrip().split(None, 1) + for utt in utts.split(): + self.utt2spk[utt] = spk + else: + self.utt2spk = None + + # Kaldi makes a matrix for CMVN which has a shape of (2, feat_dim + 1), + # and the first vector contains the sum of feats and the second is + # the sum of squares. The last value of the first, i.e. stats[0,-1], + # is the number of samples for this statistics. 
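+        # NOTE: precompute bias (= -mean) and scale (= 1 / std) per key here,
+        # so that __call__ reduces to a cheap elementwise add / multiply per utterance.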
+ self.bias = {} + self.scale = {} + for spk, stats in stats_dict.items(): + assert len(stats) == 2, stats.shape + + count = stats[0, -1] + + # If the feature has two or more dimensions + if not (np.isscalar(count) or isinstance(count, (int, float))): + # The first is only used + count = count.flatten()[0] + + mean = stats[0, :-1] / count + # V(x) = E(x^2) - (E(x))^2 + var = stats[1, :-1] / count - mean * mean + std = np.maximum(np.sqrt(var), std_floor) + self.bias[spk] = -mean + self.scale[spk] = 1 / std + + def __repr__(self): + return ( + "{name}(stats_file={stats_file}, " + "norm_means={norm_means}, norm_vars={norm_vars}, " + "reverse={reverse})".format( + name=self.__class__.__name__, + stats_file=self.stats_file, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + reverse=self.reverse, + ) + ) + + def __call__(self, x, uttid=None): + if self.utt2spk is not None: + spk = self.utt2spk[uttid] + else: + spk = uttid + + if not self.reverse: + if self.norm_means: + x = np.add(x, self.bias[spk]) + if self.norm_vars: + x = np.multiply(x, self.scale[spk]) + + else: + if self.norm_vars: + x = np.divide(x, self.scale[spk]) + if self.norm_means: + x = np.subtract(x, self.bias[spk]) + + return x + + +class UtteranceCMVN(object): + def __init__(self, norm_means=True, norm_vars=False, std_floor=1.0e-20): + self.norm_means = norm_means + self.norm_vars = norm_vars + self.std_floor = std_floor + + def __repr__(self): + return "{name}(norm_means={norm_means}, norm_vars={norm_vars})".format( + name=self.__class__.__name__, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + ) + + def __call__(self, x, uttid=None): + # x: [Time, Dim] + square_sums = (x ** 2).sum(axis=0) + mean = x.mean(axis=0) + + if self.norm_means: + x = np.subtract(x, mean) + + if self.norm_vars: + var = square_sums / x.shape[0] - mean ** 2 + std = np.maximum(np.sqrt(var), self.std_floor) + x = np.divide(x, std) + + return x diff --git a/espnet/transform/functional.py b/espnet/transform/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..6226cddaef80494d11d44568fd222f79e7e6f328 --- /dev/null +++ b/espnet/transform/functional.py @@ -0,0 +1,71 @@ +import inspect + +from espnet.transform.transform_interface import TransformInterface +from espnet.utils.check_kwargs import check_kwargs + + +class FuncTrans(TransformInterface): + """Functional Transformation + + WARNING: + Builtin or C/C++ functions may not work properly + because this class heavily depends on the `inspect` module. + + Usage: + + >>> def foo_bar(x, a=1, b=2): + ... '''Foo bar + ... :param x: input + ... :param int a: default 1 + ... :param int b: default 2 + ... ''' + ... return x + a - b + + + >>> class FooBar(FuncTrans): + ... _func = foo_bar + ... __doc__ = foo_bar.__doc__ + """ + + _func = None + + def __init__(self, **kwargs): + self.kwargs = kwargs + check_kwargs(self.func, kwargs) + + def __call__(self, x): + return self.func(x, **self.kwargs) + + @classmethod + def add_arguments(cls, parser): + fname = cls._func.__name__.replace("_", "-") + group = parser.add_argument_group(fname + " transformation setting") + for k, v in cls.default_params().items(): + # TODO(karita): get help and choices from docstring? 
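+            # Each keyword argument of the wrapped function that has a default value
+            # becomes a CLI option "--<func-name>-<param-name>", and the default's
+            # Python type is reused as the argparse type for parsing.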
+ attr = k.replace("_", "-") + group.add_argument(f"--{fname}-{attr}", default=v, type=type(v)) + return parser + + @property + def func(self): + return type(self)._func + + @classmethod + def default_params(cls): + try: + d = dict(inspect.signature(cls._func).parameters) + except ValueError: + d = dict() + return { + k: v.default for k, v in d.items() if v.default != inspect.Parameter.empty + } + + def __repr__(self): + params = self.default_params() + params.update(**self.kwargs) + ret = self.__class__.__name__ + "(" + if len(params) == 0: + return ret + ")" + for k, v in params.items(): + ret += "{}={}, ".format(k, v) + return ret[:-2] + ")" diff --git a/espnet/transform/perturb.py b/espnet/transform/perturb.py new file mode 100644 index 0000000000000000000000000000000000000000..a05b72794c2904e08d37298861fbd4e1871b6405 --- /dev/null +++ b/espnet/transform/perturb.py @@ -0,0 +1,344 @@ +import librosa +import numpy +import scipy +import soundfile + +from espnet.utils.io_utils import SoundHDF5File + + +class SpeedPerturbation(object): + """SpeedPerturbation + + The speed perturbation in kaldi uses sox-speed instead of sox-tempo, + and sox-speed just to resample the input, + i.e pitch and tempo are changed both. + + "Why use speed option instead of tempo -s in SoX for speed perturbation" + https://groups.google.com/forum/#!topic/kaldi-help/8OOG7eE4sZ8 + + Warning: + This function is very slow because of resampling. + I recommmend to apply speed-perturb outside the training using sox. + + """ + + def __init__( + self, + lower=0.9, + upper=1.1, + utt2ratio=None, + keep_length=True, + res_type="kaiser_best", + seed=None, + ): + self.res_type = res_type + self.keep_length = keep_length + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + self.utt2ratio = {} + # Use the scheduled ratio for each utterances + self.utt2ratio_file = utt2ratio + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + self.utt2ratio = None + # The ratio is given on runtime randomly + self.lower = lower + self.upper = upper + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, " "keep_length={}, res_type={})".format( + self.__class__.__name__, + self.lower, + self.upper, + self.keep_length, + self.res_type, + ) + else: + return "{}({}, res_type={})".format( + self.__class__.__name__, self.utt2ratio_file, self.res_type + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + # Note1: resample requires the sampling-rate of input and output, + # but actually only the ratio is used. + y = librosa.resample(x, ratio, 1, res_type=self.res_type) + + if self.keep_length: + diff = abs(len(x) - len(y)) + if len(y) > len(x): + # Truncate noise + y = y[diff // 2 : -((diff + 1) // 2)] + elif len(y) < len(x): + # Assume the time-axis is the first: (Time, Channel) + pad_width = [(diff // 2, (diff + 1) // 2)] + [ + (0, 0) for _ in range(y.ndim - 1) + ] + y = numpy.pad( + y, pad_width=pad_width, constant_values=0, mode="constant" + ) + return y + + +class BandpassPerturbation(object): + """BandpassPerturbation + + Randomly dropout along the frequency axis. 
+ + The original idea comes from the following: + "randomly-selected frequency band was cut off under the constraint of + leaving at least 1,000 Hz band within the range of less than 4,000Hz." + (The Hitachi/JHU CHiME-5 system: Advances in speech recognition for + everyday home environments using multiple microphone arrays; + http://spandh.dcs.shef.ac.uk/chime_workshop/papers/CHiME_2018_paper_kanda.pdf) + + """ + + def __init__(self, lower=0.0, upper=0.75, seed=None, axes=(-1,)): + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + # x_stft: (Time, Channel, Freq) + self.axes = axes + + def __repr__(self): + return "{}(lower={}, upper={})".format( + self.__class__.__name__, self.lower, self.upper + ) + + def __call__(self, x_stft, uttid=None, train=True): + if not train: + return x_stft + + if x_stft.ndim == 1: + raise RuntimeError( + "Input in time-freq domain: " "(Time, Channel, Freq) or (Time, Freq)" + ) + + ratio = self.state.uniform(self.lower, self.upper) + axes = [i if i >= 0 else x_stft.ndim - i for i in self.axes] + shape = [s if i in axes else 1 for i, s in enumerate(x_stft.shape)] + + mask = self.state.randn(*shape) > ratio + x_stft *= mask + return x_stft + + +class VolumePerturbation(object): + def __init__(self, lower=-1.6, upper=1.6, utt2ratio=None, dbunit=True, seed=None): + self.dbunit = dbunit + self.utt2ratio_file = utt2ratio + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + self.lower = None + self.upper = None + self.accept_uttid = True + + with open(utt2ratio, "r") as f: + for line in f: + utt, ratio = line.rstrip().split(None, 1) + ratio = float(ratio) + self.utt2ratio[utt] = ratio + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit + ) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if self.accept_uttid: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + if self.dbunit: + ratio = 10 ** (ratio / 20) + return x * ratio + + +class NoiseInjection(object): + """Add isotropic noise""" + + def __init__( + self, + utt2noise=None, + lower=-20, + upper=-5, + utt2ratio=None, + filetype="list", + dbunit=True, + seed=None, + ): + self.utt2noise_file = utt2noise + self.utt2ratio_file = utt2ratio + self.filetype = filetype + self.dbunit = dbunit + self.lower = lower + self.upper = upper + self.state = numpy.random.RandomState(seed) + + if utt2ratio is not None: + # Use the scheduled ratio for each utterances + self.utt2ratio = {} + with open(utt2noise, "r") as f: + for line in f: + utt, snr = line.rstrip().split(None, 1) + snr = float(snr) + self.utt2ratio[utt] = snr + else: + # The ratio is given on runtime randomly + self.utt2ratio = None + + if utt2noise is not None: + self.utt2noise = {} + if filetype == "list": + with open(utt2noise, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + # Load all files in memory + self.utt2noise[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2noise = 
SoundHDF5File(utt2noise, "r") + else: + raise ValueError(filetype) + else: + self.utt2noise = None + + if utt2noise is not None and utt2ratio is not None: + if set(self.utt2ratio) != set(self.utt2noise): + raise RuntimeError( + "The uttids mismatch between {} and {}".format(utt2ratio, utt2noise) + ) + + def __repr__(self): + if self.utt2ratio is None: + return "{}(lower={}, upper={}, dbunit={})".format( + self.__class__.__name__, self.lower, self.upper, self.dbunit + ) + else: + return '{}("{}", dbunit={})'.format( + self.__class__.__name__, self.utt2ratio_file, self.dbunit + ) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + x = x.astype(numpy.float32) + + # 1. Get ratio of noise to signal in sound pressure level + if uttid is not None and self.utt2ratio is not None: + ratio = self.utt2ratio[uttid] + else: + ratio = self.state.uniform(self.lower, self.upper) + + if self.dbunit: + ratio = 10 ** (ratio / 20) + scale = ratio * numpy.sqrt((x ** 2).mean()) + + # 2. Get noise + if self.utt2noise is not None: + # Get noise from the external source + if uttid is not None: + noise, rate = self.utt2noise[uttid] + else: + # Randomly select the noise source + noise = self.state.choice(list(self.utt2noise.values())) + # Normalize the level + noise /= numpy.sqrt((noise ** 2).mean()) + + # Adjust the noise length + diff = abs(len(x) - len(noise)) + offset = self.state.randint(0, diff) + if len(noise) > len(x): + # Truncate noise + noise = noise[offset : -(diff - offset)] + else: + noise = numpy.pad(noise, pad_width=[offset, diff - offset], mode="wrap") + + else: + # Generate white noise + noise = self.state.normal(0, 1, x.shape) + + # 3. Add noise to signal + return x + noise * scale + + +class RIRConvolve(object): + def __init__(self, utt2rir, filetype="list"): + self.utt2rir_file = utt2rir + self.filetype = filetype + + self.utt2rir = {} + if filetype == "list": + with open(utt2rir, "r") as f: + for line in f: + utt, filename = line.rstrip().split(None, 1) + signal, rate = soundfile.read(filename, dtype="int16") + self.utt2rir[utt] = (signal, rate) + + elif filetype == "sound.hdf5": + self.utt2rir = SoundHDF5File(utt2rir, "r") + else: + raise NotImplementedError(filetype) + + def __repr__(self): + return '{}("{}")'.format(self.__class__.__name__, self.utt2rir_file) + + def __call__(self, x, uttid=None, train=True): + if not train: + return x + + x = x.astype(numpy.float32) + + if x.ndim != 1: + # Must be single channel + raise RuntimeError( + "Input x must be one dimensional array, but got {}".format(x.shape) + ) + + rir, rate = self.utt2rir[uttid] + if rir.ndim == 2: + # FIXME(kamo): Use chainer.convolution_1d? 
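+            # Multi-channel RIR: convolve the single-channel input with each
+            # channel's impulse response and stack the results channel-wise.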
+ # return [Time, Channel] + return numpy.stack( + [scipy.convolve(x, r, mode="same") for r in rir], axis=-1 + ) + else: + return scipy.convolve(x, rir, mode="same") diff --git a/espnet/transform/spec_augment.py b/espnet/transform/spec_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..789bf187a2d5264a585e9153cf62c89c7ace6a7f --- /dev/null +++ b/espnet/transform/spec_augment.py @@ -0,0 +1,202 @@ +"""Spec Augment module for preprocessing i.e., data augmentation""" + +import random + +import numpy +from PIL import Image +from PIL.Image import BICUBIC + +from espnet.transform.functional import FuncTrans + + +def time_warp(x, max_time_warp=80, inplace=False, mode="PIL"): + """time warp for spec augment + + move random center frame by the random width ~ uniform(-window, window) + :param numpy.ndarray x: spectrogram (time, freq) + :param int max_time_warp: maximum time frames to warp + :param bool inplace: overwrite x with the result + :param str mode: "PIL" (default, fast, not differentiable) or "sparse_image_warp" + (slow, differentiable) + :returns numpy.ndarray: time warped spectrogram (time, freq) + """ + window = max_time_warp + if mode == "PIL": + t = x.shape[0] + if t - window <= window: + return x + # NOTE: randrange(a, b) emits a, a + 1, ..., b - 1 + center = random.randrange(window, t - window) + warped = random.randrange(center - window, center + window) + 1 # 1 ... t - 1 + + left = Image.fromarray(x[:center]).resize((x.shape[1], warped), BICUBIC) + right = Image.fromarray(x[center:]).resize((x.shape[1], t - warped), BICUBIC) + if inplace: + x[:warped] = left + x[warped:] = right + return x + return numpy.concatenate((left, right), 0) + elif mode == "sparse_image_warp": + import torch + + from espnet.utils import spec_augment + + # TODO(karita): make this differentiable again + return spec_augment.time_warp(torch.from_numpy(x), window).numpy() + else: + raise NotImplementedError( + "unknown resize mode: " + + mode + + ", choose one from (PIL, sparse_image_warp)." 
+ ) + + +class TimeWarp(FuncTrans): + _func = time_warp + __doc__ = time_warp.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def freq_mask(x, F=30, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray x: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = x + else: + cloned = x.copy() + + num_mel_channels = cloned.shape[1] + fs = numpy.random.randint(0, F, size=(n_mask, 2)) + + for f, mask_end in fs: + f_zero = random.randrange(0, num_mel_channels - f) + mask_end += f_zero + + # avoids randrange error if values are equal and range is empty + if f_zero == f_zero + f: + continue + + if replace_with_zero: + cloned[:, f_zero:mask_end] = 0 + else: + cloned[:, f_zero:mask_end] = cloned.mean() + return cloned + + +class FreqMask(FuncTrans): + _func = freq_mask + __doc__ = freq_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def time_mask(spec, T=40, n_mask=2, replace_with_zero=True, inplace=False): + """freq mask for spec agument + + :param numpy.ndarray spec: (time, freq) + :param int n_mask: the number of masks + :param bool inplace: overwrite + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + if inplace: + cloned = spec + else: + cloned = spec.copy() + len_spectro = cloned.shape[0] + ts = numpy.random.randint(0, T, size=(n_mask, 2)) + for t, mask_end in ts: + # avoid randint range error + if len_spectro - t <= 0: + continue + t_zero = random.randrange(0, len_spectro - t) + + # avoids randrange error if values are equal and range is empty + if t_zero == t_zero + t: + continue + + mask_end += t_zero + if replace_with_zero: + cloned[t_zero:mask_end] = 0 + else: + cloned[t_zero:mask_end] = cloned.mean() + return cloned + + +class TimeMask(FuncTrans): + _func = time_mask + __doc__ = time_mask.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) + + +def spec_augment( + x, + resize_mode="PIL", + max_time_warp=80, + max_freq_width=27, + n_freq_mask=2, + max_time_width=100, + n_time_mask=2, + inplace=True, + replace_with_zero=True, +): + """spec agument + + apply random time warping and time/freq masking + default setting is based on LD (Librispeech double) in Table 2 + https://arxiv.org/pdf/1904.08779.pdf + + :param numpy.ndarray x: (time, freq) + :param str resize_mode: "PIL" (fast, nondifferentiable) or "sparse_image_warp" + (slow, differentiable) + :param int max_time_warp: maximum frames to warp the center frame in spectrogram (W) + :param int freq_mask_width: maximum width of the random freq mask (F) + :param int n_freq_mask: the number of the random freq mask (m_F) + :param int time_mask_width: maximum width of the random time mask (T) + :param int n_time_mask: the number of the random time mask (m_T) + :param bool inplace: overwrite intermediate array + :param bool replace_with_zero: pad zero on mask if true else use mean + """ + assert isinstance(x, numpy.ndarray) + assert x.ndim == 2 + x = time_warp(x, max_time_warp, inplace=inplace, mode=resize_mode) + x = freq_mask( + x, + max_freq_width, + n_freq_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, + ) + x = time_mask( + x, + max_time_width, + n_time_mask, + inplace=inplace, + replace_with_zero=replace_with_zero, + ) + return x + + +class 
SpecAugment(FuncTrans): + _func = spec_augment + __doc__ = spec_augment.__doc__ + + def __call__(self, x, train): + if not train: + return x + return super().__call__(x) diff --git a/espnet/transform/spectrogram.py b/espnet/transform/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..518a00efea4b5e83b37104a3010643048a06863e --- /dev/null +++ b/espnet/transform/spectrogram.py @@ -0,0 +1,307 @@ +import librosa +import numpy as np + + +def stft( + x, n_fft, n_shift, win_length=None, window="hann", center=True, pad_mode="reflect" +): + # x: [Time, Channel] + if x.ndim == 1: + single_channel = True + # x: [Time] -> [Time, Channel] + x = x[:, None] + else: + single_channel = False + x = x.astype(np.float32) + + # FIXME(kamo): librosa.stft can't use multi-channel? + # x: [Time, Channel, Freq] + x = np.stack( + [ + librosa.stft( + x[:, ch], + n_fft=n_fft, + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + pad_mode=pad_mode, + ).T + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel, Freq] -> [Time, Freq] + x = x[:, 0] + return x + + +def istft(x, n_shift, win_length=None, window="hann", center=True): + # x: [Time, Channel, Freq] + if x.ndim == 2: + single_channel = True + # x: [Time, Freq] -> [Time, Channel, Freq] + x = x[:, None, :] + else: + single_channel = False + + # x: [Time, Channel] + x = np.stack( + [ + librosa.istft( + x[:, ch].T, # [Time, Freq] -> [Freq, Time] + hop_length=n_shift, + win_length=win_length, + window=window, + center=center, + ) + for ch in range(x.shape[1]) + ], + axis=1, + ) + + if single_channel: + # x: [Time, Channel] -> [Time] + x = x[:, 0] + return x + + +def stft2logmelspectrogram(x_stft, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + # x_stft: (Time, Channel, Freq) or (Time, Freq) + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + + # spc: (Time, Channel, Freq) or (Time, Freq) + spc = np.abs(x_stft) + # mel_basis: (Mel_freq, Freq) + mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax) + # lmspc: (Time, Channel, Mel_freq) or (Time, Mel_freq) + lmspc = np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + + return lmspc + + +def spectrogram(x, n_fft, n_shift, win_length=None, window="hann"): + # x: (Time, Channel) -> spc: (Time, Channel, Freq) + spc = np.abs(stft(x, n_fft, n_shift, win_length, window=window)) + return spc + + +def logmelspectrogram( + x, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + pad_mode="reflect", +): + # stft: (Time, Channel, Freq) or (Time, Freq) + x_stft = stft( + x, + n_fft=n_fft, + n_shift=n_shift, + win_length=win_length, + window=window, + pad_mode=pad_mode, + ) + + return stft2logmelspectrogram( + x_stft, fs=fs, n_mels=n_mels, n_fft=n_fft, fmin=fmin, fmax=fmax, eps=eps + ) + + +class Spectrogram(object): + def __init__(self, n_fft, n_shift, win_length=None, window="hann"): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + + def __repr__(self): + return ( + "{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + ) + + def __call__(self, x): + return spectrogram( + x, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + + +class 
LogMelSpectrogram(object): + def __init__( + self, + fs, + n_mels, + n_fft, + n_shift, + win_length=None, + window="hann", + fmin=None, + fmax=None, + eps=1e-10, + ): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "n_shift={n_shift}, win_length={win_length}, window={window}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, + ) + ) + + def __call__(self, x): + return logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + ) + + +class Stft2LogMelSpectrogram(object): + def __init__(self, fs, n_mels, n_fft, fmin=None, fmax=None, eps=1e-10): + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.fmin = fmin + self.fmax = fmax + self.eps = eps + + def __repr__(self): + return ( + "{name}(fs={fs}, n_mels={n_mels}, n_fft={n_fft}, " + "fmin={fmin}, fmax={fmax}, eps={eps}))".format( + name=self.__class__.__name__, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, + eps=self.eps, + ) + ) + + def __call__(self, x): + return stft2logmelspectrogram( + x, + fs=self.fs, + n_mels=self.n_mels, + n_fft=self.n_fft, + fmin=self.fmin, + fmax=self.fmax, + ) + + +class Stft(object): + def __init__( + self, + n_fft, + n_shift, + win_length=None, + window="hann", + center=True, + pad_mode="reflect", + ): + self.n_fft = n_fft + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + self.pad_mode = pad_mode + + def __repr__(self): + return ( + "{name}(n_fft={n_fft}, n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center}, pad_mode={pad_mode})".format( + name=self.__class__.__name__, + n_fft=self.n_fft, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, + ) + ) + + def __call__(self, x): + return stft( + x, + self.n_fft, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode=self.pad_mode, + ) + + +class IStft(object): + def __init__(self, n_shift, win_length=None, window="hann", center=True): + self.n_shift = n_shift + self.win_length = win_length + self.window = window + self.center = center + + def __repr__(self): + return ( + "{name}(n_shift={n_shift}, " + "win_length={win_length}, window={window}," + "center={center})".format( + name=self.__class__.__name__, + n_shift=self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + ) + ) + + def __call__(self, x): + return istft( + x, + self.n_shift, + win_length=self.win_length, + window=self.window, + center=self.center, + ) diff --git a/espnet/transform/transform_interface.py b/espnet/transform/transform_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..8a6aba45b0fbd4896e05a39b8b4aee774f86dd69 --- /dev/null +++ b/espnet/transform/transform_interface.py @@ -0,0 +1,20 @@ +# TODO(karita): add this to all the transform impl. 
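+# A transform only needs to implement __call__; add_arguments and __repr__
+# have sensible defaults, so concrete transforms can stay minimal.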
+class TransformInterface: + """Transform Interface""" + + def __call__(self, x): + raise NotImplementedError("__call__ method is not implemented") + + @classmethod + def add_arguments(cls, parser): + return parser + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class Identity(TransformInterface): + """Identity Function""" + + def __call__(self, x): + return x diff --git a/espnet/transform/transformation.py b/espnet/transform/transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..49418f560394069996cd4d5d18c112d9bfa2a7aa --- /dev/null +++ b/espnet/transform/transformation.py @@ -0,0 +1,158 @@ +from collections import OrderedDict +import copy +import io +import logging +import sys + +import yaml + +from espnet.utils.dynamic_import import dynamic_import + + +PY2 = sys.version_info[0] == 2 + +if PY2: + from collections import Sequence + from funcsigs import signature +else: + # The ABCs from 'collections' will stop working in 3.8 + from collections.abc import Sequence + from inspect import signature + + +# TODO(karita): inherit TransformInterface +# TODO(karita): register cmd arguments in asr_train.py +import_alias = dict( + identity="espnet.transform.transform_interface:Identity", + time_warp="espnet.transform.spec_augment:TimeWarp", + time_mask="espnet.transform.spec_augment:TimeMask", + freq_mask="espnet.transform.spec_augment:FreqMask", + spec_augment="espnet.transform.spec_augment:SpecAugment", + speed_perturbation="espnet.transform.perturb:SpeedPerturbation", + volume_perturbation="espnet.transform.perturb:VolumePerturbation", + noise_injection="espnet.transform.perturb:NoiseInjection", + bandpass_perturbation="espnet.transform.perturb:BandpassPerturbation", + rir_convolve="espnet.transform.perturb:RIRConvolve", + delta="espnet.transform.add_deltas:AddDeltas", + cmvn="espnet.transform.cmvn:CMVN", + utterance_cmvn="espnet.transform.cmvn:UtteranceCMVN", + fbank="espnet.transform.spectrogram:LogMelSpectrogram", + spectrogram="espnet.transform.spectrogram:Spectrogram", + stft="espnet.transform.spectrogram:Stft", + istft="espnet.transform.spectrogram:IStft", + stft2fbank="espnet.transform.spectrogram:Stft2LogMelSpectrogram", + wpe="espnet.transform.wpe:WPE", + channel_selector="espnet.transform.channel_selector:ChannelSelector", +) + + +class Transformation(object): + """Apply some functions to the mini-batch + + Examples: + >>> kwargs = {"process": [{"type": "fbank", + ... "n_mels": 80, + ... "fs": 16000}, + ... {"type": "cmvn", + ... "stats": "data/train/cmvn.ark", + ... "norm_vars": True}, + ... {"type": "delta", "window": 2, "order": 2}]} + >>> transform = Transformation(kwargs) + >>> bs = 10 + >>> xs = [np.random.randn(100, 80).astype(np.float32) + ... 
for _ in range(bs)] + >>> xs = transform(xs) + """ + + def __init__(self, conffile=None): + if conffile is not None: + if isinstance(conffile, dict): + self.conf = copy.deepcopy(conffile) + else: + with io.open(conffile, encoding="utf-8") as f: + self.conf = yaml.safe_load(f) + assert isinstance(self.conf, dict), type(self.conf) + else: + self.conf = {"mode": "sequential", "process": []} + + self.functions = OrderedDict() + if self.conf.get("mode", "sequential") == "sequential": + for idx, process in enumerate(self.conf["process"]): + assert isinstance(process, dict), type(process) + opts = dict(process) + process_type = opts.pop("type") + class_obj = dynamic_import(process_type, import_alias) + # TODO(karita): assert issubclass(class_obj, TransformInterface) + try: + self.functions[idx] = class_obj(**opts) + except TypeError: + try: + signa = signature(class_obj) + except ValueError: + # Some function, e.g. built-in function, are failed + pass + else: + logging.error( + "Expected signature: {}({})".format( + class_obj.__name__, signa + ) + ) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"]) + ) + + def __repr__(self): + rep = "\n" + "\n".join( + " {}: {}".format(k, v) for k, v in self.functions.items() + ) + return "{}({})".format(self.__class__.__name__, rep) + + def __call__(self, xs, uttid_list=None, **kwargs): + """Return new mini-batch + + :param Union[Sequence[np.ndarray], np.ndarray] xs: + :param Union[Sequence[str], str] uttid_list: + :return: batch: + :rtype: List[np.ndarray] + """ + if not isinstance(xs, Sequence): + is_batch = False + xs = [xs] + else: + is_batch = True + + if isinstance(uttid_list, str): + uttid_list = [uttid_list for _ in range(len(xs))] + + if self.conf.get("mode", "sequential") == "sequential": + for idx in range(len(self.conf["process"])): + func = self.functions[idx] + # TODO(karita): use TrainingTrans and UttTrans to check __call__ args + # Derive only the args which the func has + try: + param = signature(func).parameters + except ValueError: + # Some function, e.g. 
built-in function, are failed + param = {} + _kwargs = {k: v for k, v in kwargs.items() if k in param} + try: + if uttid_list is not None and "uttid" in param: + xs = [func(x, u, **_kwargs) for x, u in zip(xs, uttid_list)] + else: + xs = [func(x, **_kwargs) for x in xs] + except Exception: + logging.fatal( + "Catch a exception from {}th func: {}".format(idx, func) + ) + raise + else: + raise NotImplementedError( + "Not supporting mode={}".format(self.conf["mode"]) + ) + + if is_batch: + return xs + else: + return xs[0] diff --git a/espnet/transform/wpe.py b/espnet/transform/wpe.py new file mode 100644 index 0000000000000000000000000000000000000000..8aed97e6bf503f5ef00117a94b7829d8ce3aa8ec --- /dev/null +++ b/espnet/transform/wpe.py @@ -0,0 +1,45 @@ +from nara_wpe.wpe import wpe + + +class WPE(object): + def __init__( + self, taps=10, delay=3, iterations=3, psd_context=0, statistics_mode="full" + ): + self.taps = taps + self.delay = delay + self.iterations = iterations + self.psd_context = psd_context + self.statistics_mode = statistics_mode + + def __repr__(self): + return ( + "{name}(taps={taps}, delay={delay}" + "iterations={iterations}, psd_context={psd_context}, " + "statistics_mode={statistics_mode})".format( + name=self.__class__.__name__, + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, + ) + ) + + def __call__(self, xs): + """Return enhanced + + :param np.ndarray xs: (Time, Channel, Frequency) + :return: enhanced_xs + :rtype: np.ndarray + + """ + # nara_wpe.wpe: (F, C, T) + xs = wpe( + xs.transpose((2, 1, 0)), + taps=self.taps, + delay=self.delay, + iterations=self.iterations, + psd_context=self.psd_context, + statistics_mode=self.statistics_mode, + ) + return xs.transpose(2, 1, 0) diff --git a/espnet/tts/__init__.py b/espnet/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/tts/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/tts/pytorch_backend/__init__.py b/espnet/tts/pytorch_backend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/tts/pytorch_backend/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/tts/pytorch_backend/tts.py b/espnet/tts/pytorch_backend/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..8e1e7bf1030c34e68f77426c292e8201e7cad6ed --- /dev/null +++ b/espnet/tts/pytorch_backend/tts.py @@ -0,0 +1,748 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""E2E-TTS training / decoding functions.""" + +import copy +import json +import logging +import math +import os +import time + +import chainer +import kaldiio +import numpy as np +import torch + +from chainer import training +from chainer.training import extensions + +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.pytorch_backend.asr_init import load_trained_modules +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.dataset import ChainerDataLoader 
+from espnet.utils.dataset import TransformDataset +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.evaluator import BaseEvaluator + +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +from espnet.utils.training.iterators import ShufflingEnabler + +import matplotlib + +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from tensorboardX import SummaryWriter + +matplotlib.use("Agg") + + +class CustomEvaluator(BaseEvaluator): + """Custom evaluator.""" + + def __init__(self, model, iterator, target, device): + """Initilize module. + + Args: + model (torch.nn.Module): Pytorch model instance. + iterator (chainer.dataset.Iterator): Iterator for validation. + target (chainer.Chain): Dummy chain instance. + device (torch.device): The device to be used in evaluation. + + """ + super(CustomEvaluator, self).__init__(iterator, target) + self.model = model + self.device = device + + # The core part of the update routine can be customized by overriding. + def evaluate(self): + """Evaluate over validation iterator.""" + iterator = self._iterators["main"] + + if self.eval_hook: + self.eval_hook(self) + + if hasattr(iterator, "reset"): + iterator.reset() + it = iterator + else: + it = copy.copy(iterator) + + summary = chainer.reporter.DictSummary() + + self.model.eval() + with torch.no_grad(): + for batch in it: + if isinstance(batch, tuple): + x = tuple(arr.to(self.device) for arr in batch) + else: + x = batch + for key in x.keys(): + x[key] = x[key].to(self.device) + observation = {} + with chainer.reporter.report_scope(observation): + # convert to torch tensor + if isinstance(x, tuple): + self.model(*x) + else: + self.model(**x) + summary.add(observation) + self.model.train() + + return summary.compute_mean() + + +class CustomUpdater(training.StandardUpdater): + """Custom updater.""" + + def __init__(self, model, grad_clip, iterator, optimizer, device, accum_grad=1): + """Initilize module. + + Args: + model (torch.nn.Module) model: Pytorch model instance. + grad_clip (float) grad_clip : The gradient clipping value. + iterator (chainer.dataset.Iterator): Iterator for training. + optimizer (torch.optim.Optimizer) : Pytorch optimizer instance. + device (torch.device): The device to be used in training. + + """ + super(CustomUpdater, self).__init__(iterator, optimizer) + self.model = model + self.grad_clip = grad_clip + self.device = device + self.clip_grad_norm = torch.nn.utils.clip_grad_norm_ + self.accum_grad = accum_grad + self.forward_count = 0 + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Update model one step.""" + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. 
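+        # Gradient accumulation: every call runs forward/backward on one batch,
+        # but optimizer.step() is taken only once per `accum_grad` calls, and the
+        # step is skipped entirely when the gradient norm is NaN.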
+ train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + + # Get the next batch (a list of json files) + batch = train_iter.next() + if isinstance(batch, tuple): + x = tuple(arr.to(self.device) for arr in batch) + else: + x = batch + for key in x.keys(): + x[key] = x[key].to(self.device) + + # compute loss and gradient + if isinstance(x, tuple): + loss = self.model(*x).mean() / self.accum_grad + else: + loss = self.model(**x).mean() / self.accum_grad + loss.backward() + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + # compute the gradient norm to check if it is normal or not + grad_norm = self.clip_grad_norm(self.model.parameters(), self.grad_clip) + logging.debug("grad norm={}".format(grad_norm)) + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.step() + optimizer.zero_grad() + + def update(self): + """Run update function.""" + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class CustomConverter(object): + """Custom converter.""" + + def __init__(self): + """Initilize module.""" + # NOTE: keep as class for future development + pass + + def __call__(self, batch, device=torch.device("cpu")): + """Convert a given batch. + + Args: + batch (list): List of ndarrays. + device (torch.device): The device to be send. + + Returns: + dict: Dict of converted tensors. + + Examples: + >>> batch = [([np.arange(5), np.arange(3)], + [np.random.randn(8, 2), np.random.randn(4, 2)], + None, None)] + >>> conveter = CustomConverter() + >>> conveter(batch, torch.device("cpu")) + {'xs': tensor([[0, 1, 2, 3, 4], + [0, 1, 2, 0, 0]]), + 'ilens': tensor([5, 3]), + 'ys': tensor([[[-0.4197, -1.1157], + [-1.5837, -0.4299], + [-2.0491, 0.9215], + [-2.4326, 0.8891], + [ 1.2323, 1.7388], + [-0.3228, 0.6656], + [-0.6025, 1.3693], + [-1.0778, 1.3447]], + [[ 0.1768, -0.3119], + [ 0.4386, 2.5354], + [-1.2181, -0.5918], + [-0.6858, -0.8843], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000]]]), + 'labels': tensor([[0., 0., 0., 0., 0., 0., 0., 1.], + [0., 0., 0., 1., 1., 1., 1., 1.]]), + 'olens': tensor([8, 4])} + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys, spembs, extras = batch[0] + + # get list of lengths (must be tensor for DataParallel) + ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).long().to(device) + olens = torch.from_numpy(np.array([y.shape[0] for y in ys])).long().to(device) + + # perform padding and conversion to tensor + xs = pad_list([torch.from_numpy(x).long() for x in xs], 0).to(device) + ys = pad_list([torch.from_numpy(y).float() for y in ys], 0).to(device) + + # make labels for stop prediction + labels = ys.new_zeros(ys.size(0), ys.size(1)) + for i, l in enumerate(olens): + labels[i, l - 1 :] = 1.0 + + # prepare dict + new_batch = { + "xs": xs, + "ilens": ilens, + "ys": ys, + "labels": labels, + "olens": olens, + } + + # load speaker embedding + if spembs is not None: + spembs = torch.from_numpy(np.array(spembs)).float() + new_batch["spembs"] = spembs.to(device) + + # load second target + if extras is not None: + extras = pad_list([torch.from_numpy(extra).float() for extra in extras], 0) + new_batch["extras"] = extras.to(device) + + return new_batch + + +def train(args): + """Train E2E-TTS model.""" + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not 
available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + + # reverse input and output dimension + idim = int(valid_json[utts[0]]["output"][0]["shape"][1]) + odim = int(valid_json[utts[0]]["input"][0]["shape"][1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # get extra input and output dimenstion + if args.use_speaker_embedding: + args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0]) + else: + args.spk_embed_dim = None + if args.use_second_target: + args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1]) + else: + args.spc_dim = None + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to" + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + # specify model architecture + if args.enc_init is not None or args.dec_init is not None: + model = load_trained_modules(idim, odim, args, TTSInterface) + else: + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, TTSInterface) + logging.info(model) + reporter = model.reporter + + # check the use of multi-gpu + if args.ngpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + model = model.to(device) + + # freeze modules, if specified + if args.freeze_mods: + if hasattr(model, "module"): + freeze_mods = ["module." + x for x in args.freeze_mods] + else: + freeze_mods = args.freeze_mods + + for mod, param in model.named_parameters(): + if any(mod.startswith(key) for key in freeze_mods): + logging.info(f"{mod} is frozen not to be updated.") + param.requires_grad = False + + model_params = filter(lambda x: x.requires_grad, model.parameters()) + else: + model_params = model.parameters() + + logging.warning( + "num. model params: {:,} (num. 
trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adam": + optimizer = torch.optim.Adam( + model_params, args.lr, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + optimizer = get_std_opt( + model_params, args.adim, args.transformer_warmup_steps, args.transformer_lr + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + if use_sortagrad: + args.batch_sort_key = "input" + # make minibatch list (variable length) + train_batchset = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + batch_sort_key=args.batch_sort_key, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + swap_io=True, + iaxis=0, + oaxis=0, + ) + valid_batchset = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + batch_sort_key=args.batch_sort_key, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + swap_io=True, + iaxis=0, + oaxis=0, + ) + + load_tr = LoadInputsAndTargets( + mode="tts", + use_speaker_embedding=args.use_speaker_embedding, + use_second_target=args.use_second_target, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + keep_all_data_on_mem=args.keep_all_data_on_mem, + ) + + load_cv = LoadInputsAndTargets( + mode="tts", + use_speaker_embedding=args.use_speaker_embedding, + use_second_target=args.use_second_target, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + keep_all_data_on_mem=args.keep_all_data_on_mem, + ) + + converter = CustomConverter() + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + train_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset( + train_batchset, lambda data: converter([load_tr(data)]) + ), + batch_size=1, + num_workers=args.num_iter_processes, + shuffle=not use_sortagrad, + collate_fn=lambda x: x[0], + ) + } + valid_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset( + valid_batchset, lambda data: converter([load_cv(data)]) + ), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.num_iter_processes, + ) + } + + # Set up a trainer + updater = CustomUpdater( + model, args.grad_clip, train_iter, optimizer, device, args.accum_grad + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + # Resume 
from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # set intervals + eval_interval = (args.eval_interval_epochs, "epoch") + save_interval = (args.save_interval_epochs, "epoch") + report_interval = (args.report_interval_iters, "iteration") + + # Evaluate the model with the test dataset for each epoch + trainer.extend( + CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval + ) + + # Save snapshot for each epoch + trainer.extend(torch_snapshot(), trigger=save_interval) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger( + "validation/main/loss", trigger=eval_interval + ), + ) + + # Save attention figure for each epoch + if args.num_save_attention > 0: + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["output"][0]["shape"][0]), + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + reduction_factor = model.module.reduction_factor + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + reduction_factor = model.reduction_factor + if reduction_factor > 1: + # fix the length to crop attention weight plot correctly + data = copy.deepcopy(data) + for idx in range(len(data)): + ilen = data[idx][1]["input"][0]["shape"][0] + data[idx][1]["input"][0]["shape"][0] = ilen // reduction_factor + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + reverse=True, + ) + trainer.extend(att_reporter, trigger=eval_interval) + else: + att_reporter = None + + # Make a plot for training and validation values + if hasattr(model, "module"): + base_plot_keys = model.module.base_plot_keys + else: + base_plot_keys = model.base_plot_keys + plot_keys = [] + for key in base_plot_keys: + plot_key = ["main/" + key, "validation/main/" + key] + trainer.extend( + extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"), + trigger=eval_interval, + ) + plot_keys += plot_key + trainer.extend( + extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"), + trigger=eval_interval, + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend(extensions.LogReport(trigger=report_interval)) + report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys + trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval) + trainer.extend(extensions.ProgressBar(), trigger=report_interval) + + set_early_stop(trainer, args) + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +@torch.no_grad() +def decode(args): + """Decode with E2E-TTS model.""" + set_deterministic_pytorch(args) + # read training config + idim, odim, train_args = get_model_conf(args.model, args.model_conf) + + # show arguments + for key in sorted(vars(args).keys()): + logging.info("args: " + key + ": " + str(vars(args)[key])) + + # define model + model_class = dynamic_import(train_args.model_module) + model = 
model_class(idim, odim, train_args) + assert isinstance(model, TTSInterface) + logging.info(model) + + # load trained model parameters + logging.info("reading model parameters from " + args.model) + torch_load(args.model, model) + model.eval() + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + model = model.to(device) + + # read json data + with open(args.json, "rb") as f: + js = json.load(f)["utts"] + + # check directory + outdir = os.path.dirname(args.out) + if len(outdir) != 0 and not os.path.exists(outdir): + os.makedirs(outdir) + + load_inputs_and_targets = LoadInputsAndTargets( + mode="tts", + load_input=False, + sort_in_input_length=False, + use_speaker_embedding=train_args.use_speaker_embedding, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + + # define function for plot prob and att_ws + def _plot_and_save(array, figname, figsize=(6, 4), dpi=150): + import matplotlib.pyplot as plt + + shape = array.shape + if len(shape) == 1: + # for eos probability + plt.figure(figsize=figsize, dpi=dpi) + plt.plot(array) + plt.xlabel("Frame") + plt.ylabel("Probability") + plt.ylim([0, 1]) + elif len(shape) == 2: + # for tacotron 2 attention weights, whose shape is (out_length, in_length) + plt.figure(figsize=figsize, dpi=dpi) + plt.imshow(array, aspect="auto") + plt.xlabel("Input") + plt.ylabel("Output") + elif len(shape) == 4: + # for transformer attention weights, + # whose shape is (#leyers, #heads, out_length, in_length) + plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi) + for idx1, xs in enumerate(array): + for idx2, x in enumerate(xs, 1): + plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2) + plt.imshow(x, aspect="auto") + plt.xlabel("Input") + plt.ylabel("Output") + else: + raise NotImplementedError("Support only from 1D to 4D array.") + plt.tight_layout() + if not os.path.exists(os.path.dirname(figname)): + # NOTE: exist_ok = True is needed for parallel process decoding + os.makedirs(os.path.dirname(figname), exist_ok=True) + plt.savefig(figname) + plt.close() + + # define function to calculate focus rate + # (see section 3.3 in https://arxiv.org/abs/1905.09263) + def _calculate_focus_rete(att_ws): + if att_ws is None: + # fastspeech case -> None + return 1.0 + elif len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + return float(att_ws.max(dim=-1)[0].mean()) + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + return float(att_ws.max(dim=-1)[0].mean(dim=-1).max()) + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + + # define function to convert attention to duration + def _convert_att_to_duration(att_ws): + if len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + pass + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + # get the most diagonal head according to focus rate + att_ws = torch.cat( + [att_w for att_w in att_ws], dim=0 + ) # (#heads * #layers, L, T) + diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1) # (#heads * #layers,) + diagonal_head_idx = diagonal_scores.argmax() + att_ws = att_ws[diagonal_head_idx] # (L, T) + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + # calculate duration from 2d attention weight + durations = torch.stack( + [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])] + ) + return durations.view(-1, 1).float() + + # define writer 
instances + feat_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(o=args.out)) + if args.save_durations: + dur_writer = kaldiio.WriteHelper( + "ark,scp:{o}.ark,{o}.scp".format(o=args.out.replace("feats", "durations")) + ) + if args.save_focus_rates: + fr_writer = kaldiio.WriteHelper( + "ark,scp:{o}.ark,{o}.scp".format(o=args.out.replace("feats", "focus_rates")) + ) + + # start decoding + for idx, utt_id in enumerate(js.keys()): + # setup inputs + batch = [(utt_id, js[utt_id])] + data = load_inputs_and_targets(batch) + x = torch.LongTensor(data[0][0]).to(device) + spemb = None + if train_args.use_speaker_embedding: + spemb = torch.FloatTensor(data[1][0]).to(device) + + # decode and write + start_time = time.time() + outs, probs, att_ws = model.inference(x, args, spemb=spemb) + logging.info( + "inference speed = %.1f frames / sec." + % (int(outs.size(0)) / (time.time() - start_time)) + ) + if outs.size(0) == x.size(0) * args.maxlenratio: + logging.warning("output length reaches maximum length (%s)." % utt_id) + focus_rate = _calculate_focus_rete(att_ws) + logging.info( + "(%d/%d) %s (size: %d->%d, focus rate: %.3f)" + % (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0), focus_rate) + ) + feat_writer[utt_id] = outs.cpu().numpy() + if args.save_durations: + ds = _convert_att_to_duration(att_ws) + dur_writer[utt_id] = ds.cpu().numpy() + if args.save_focus_rates: + fr_writer[utt_id] = np.array(focus_rate).reshape(1, 1) + + # plot and save prob and att_ws + if probs is not None: + _plot_and_save( + probs.cpu().numpy(), + os.path.dirname(args.out) + "/probs/%s_prob.png" % utt_id, + ) + if att_ws is not None: + _plot_and_save( + att_ws.cpu().numpy(), + os.path.dirname(args.out) + "/att_ws/%s_att_ws.png" % utt_id, + ) + + # close file object + feat_writer.close() + if args.save_durations: + dur_writer.close() + if args.save_focus_rates: + fr_writer.close() diff --git a/espnet/utils/__init__.py b/espnet/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/utils/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/utils/check_kwargs.py b/espnet/utils/check_kwargs.py new file mode 100644 index 0000000000000000000000000000000000000000..593bfa2485318dde1e1f4731012e0c15f80042cb --- /dev/null +++ b/espnet/utils/check_kwargs.py @@ -0,0 +1,20 @@ +import inspect + + +def check_kwargs(func, kwargs, name=None): + """check kwargs are valid for func + + If kwargs are invalid, raise TypeError as same as python default + :param function func: function to be validated + :param dict kwargs: keyword arguments for func + :param str name: name used in TypeError (default is func name) + """ + try: + params = inspect.signature(func).parameters + except ValueError: + return + if name is None: + name = func.__name__ + for k in kwargs.keys(): + if k not in params: + raise TypeError(f"{name}() got an unexpected keyword argument '{k}'") diff --git a/espnet/utils/cli_readers.py b/espnet/utils/cli_readers.py new file mode 100644 index 0000000000000000000000000000000000000000..5df116ab6cf3bce9ce9449a3b7b77da0e8856bf6 --- /dev/null +++ b/espnet/utils/cli_readers.py @@ -0,0 +1,234 @@ +import io +import logging +import sys + +import h5py +import kaldiio +import soundfile + +from espnet.utils.io_utils import SoundHDF5File + + +def file_reader_helper( + rspecifier: str, + filetype: str = "mat", + return_shape: bool = False, + segments: str = None, +): + """Read uttid and array in kaldi 
style + + This function might be a bit confusing as "ark" is used + for HDF5 to imitate "kaldi-rspecifier". + + Args: + rspecifier: Give as "ark:feats.ark" or "scp:feats.scp" + filetype: "mat" is kaldi-martix, "hdf5": HDF5 + return_shape: Return the shape of the matrix, + instead of the matrix. This can reduce IO cost for HDF5. + Returns: + Generator[Tuple[str, np.ndarray], None, None]: + + Examples: + Read from kaldi-matrix ark file: + + >>> for u, array in file_reader_helper('ark:feats.ark', 'mat'): + ... array + + Read from HDF5 file: + + >>> for u, array in file_reader_helper('ark:feats.h5', 'hdf5'): + ... array + + """ + if filetype == "mat": + return KaldiReader(rspecifier, return_shape=return_shape, segments=segments) + elif filetype == "hdf5": + return HDF5Reader(rspecifier, return_shape=return_shape) + elif filetype == "sound.hdf5": + return SoundHDF5Reader(rspecifier, return_shape=return_shape) + elif filetype == "sound": + return SoundReader(rspecifier, return_shape=return_shape) + else: + raise NotImplementedError(f"filetype={filetype}") + + +class KaldiReader: + def __init__(self, rspecifier, return_shape=False, segments=None): + self.rspecifier = rspecifier + self.return_shape = return_shape + self.segments = segments + + def __iter__(self): + with kaldiio.ReadHelper(self.rspecifier, segments=self.segments) as reader: + for key, array in reader: + if self.return_shape: + array = array.shape + yield key, array + + +class HDF5Reader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError( + 'Give "rspecifier" such as "ark:some.ark: {}"'.format(self.rspecifier) + ) + self.rspecifier = rspecifier + self.ark_or_scp, self.filepath = self.rspecifier.split(":", 1) + if self.ark_or_scp not in ["ark", "scp"]: + raise ValueError(f"Must be scp or ark: {self.ark_or_scp}") + + self.return_shape = return_shape + + def __iter__(self): + if self.ark_or_scp == "scp": + hdf5_dict = {} + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, value = line.rstrip().split(None, 1) + + if ":" not in value: + raise RuntimeError( + "scp file for hdf5 should be like: " + '"uttid filepath.h5:key": {}({})'.format( + line, self.filepath + ) + ) + path, h5_key = value.split(":", 1) + + hdf5_file = hdf5_dict.get(path) + if hdf5_file is None: + try: + hdf5_file = h5py.File(path, "r") + except Exception: + logging.error("Error when loading {}".format(path)) + raise + hdf5_dict[path] = hdf5_file + + try: + data = hdf5_file[h5_key] + except Exception: + logging.error( + "Error when loading {} with key={}".format(path, h5_key) + ) + raise + + if self.return_shape: + yield key, data.shape + else: + yield key, data[()] + + # Closing all files + for k in hdf5_dict: + try: + hdf5_dict[k].close() + except Exception: + pass + + else: + if self.filepath == "-": + # Required h5py>=2.9 + filepath = io.BytesIO(sys.stdin.buffer.read()) + else: + filepath = self.filepath + with h5py.File(filepath, "r") as f: + for key in f: + if self.return_shape: + yield key, f[key].shape + else: + yield key, f[key][()] + + +class SoundHDF5Reader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError( + 'Give "rspecifier" such as "ark:some.ark: {}"'.format(rspecifier) + ) + self.ark_or_scp, self.filepath = rspecifier.split(":", 1) + if self.ark_or_scp not in ["ark", "scp"]: + raise ValueError(f"Must be scp or ark: {self.ark_or_scp}") + self.return_shape = return_shape + + def __iter__(self): + if self.ark_or_scp == "scp": + 
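+            # Illustrative (hypothetical) scp contents for this branch, where
+            # each line maps an utterance id to "<hdf5 file>:<dataset key>":
+            #   utt1 dump/train/data.1.h5:utt1
+            #   utt2 dump/train/data.1.h5:utt2
+            # The dict below caches one open SoundHDF5File per path, so a file
+            # referenced by many utterances is only opened once.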
hdf5_dict = {} + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, value = line.rstrip().split(None, 1) + + if ":" not in value: + raise RuntimeError( + "scp file for hdf5 should be like: " + '"uttid filepath.h5:key": {}({})'.format( + line, self.filepath + ) + ) + path, h5_key = value.split(":", 1) + + hdf5_file = hdf5_dict.get(path) + if hdf5_file is None: + try: + hdf5_file = SoundHDF5File(path, "r") + except Exception: + logging.error("Error when loading {}".format(path)) + raise + hdf5_dict[path] = hdf5_file + + try: + data = hdf5_file[h5_key] + except Exception: + logging.error( + "Error when loading {} with key={}".format(path, h5_key) + ) + raise + + # Change Tuple[ndarray, int] -> Tuple[int, ndarray] + # (soundfile style -> scipy style) + array, rate = data + if self.return_shape: + array = array.shape + yield key, (rate, array) + + # Closing all files + for k in hdf5_dict: + try: + hdf5_dict[k].close() + except Exception: + pass + + else: + if self.filepath == "-": + # Required h5py>=2.9 + filepath = io.BytesIO(sys.stdin.buffer.read()) + else: + filepath = self.filepath + for key, (a, r) in SoundHDF5File(filepath, "r").items(): + if self.return_shape: + a = a.shape + yield key, (r, a) + + +class SoundReader: + def __init__(self, rspecifier, return_shape=False): + if ":" not in rspecifier: + raise ValueError( + 'Give "rspecifier" such as "scp:some.scp: {}"'.format(rspecifier) + ) + self.ark_or_scp, self.filepath = rspecifier.split(":", 1) + if self.ark_or_scp != "scp": + raise ValueError( + 'Only supporting "scp" for sound file: {}'.format(self.ark_or_scp) + ) + self.return_shape = return_shape + + def __iter__(self): + with open(self.filepath, "r", encoding="utf-8") as f: + for line in f: + key, sound_file_path = line.rstrip().split(None, 1) + # Assume PCM16 + array, rate = soundfile.read(sound_file_path, dtype="int16") + # Change Tuple[ndarray, int] -> Tuple[int, ndarray] + # (soundfile style -> scipy style) + if self.return_shape: + array = array.shape + yield key, (rate, array) diff --git a/espnet/utils/cli_utils.py b/espnet/utils/cli_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a4cd15b72f832d9118aa7a7377a13de16c329b --- /dev/null +++ b/espnet/utils/cli_utils.py @@ -0,0 +1,65 @@ +from collections.abc import Sequence +from distutils.util import strtobool as dist_strtobool +import sys + +import numpy + + +def strtobool(x): + # distutils.util.strtobool returns integer, but it's confusing, + return bool(dist_strtobool(x)) + + +def get_commandline_args(): + extra_chars = [ + " ", + ";", + "&", + "(", + ")", + "|", + "^", + "<", + ">", + "?", + "*", + "[", + "]", + "$", + "`", + '"', + "\\", + "!", + "{", + "}", + ] + + # Escape the extra characters for shell + argv = [ + arg.replace("'", "'\\''") + if all(char not in arg for char in extra_chars) + else "'" + arg.replace("'", "'\\''") + "'" + for arg in sys.argv + ] + + return sys.executable + " " + " ".join(argv) + + +def is_scipy_wav_style(value): + # If Tuple[int, numpy.ndarray] or not + return ( + isinstance(value, Sequence) + and len(value) == 2 + and isinstance(value[0], int) + and isinstance(value[1], numpy.ndarray) + ) + + +def assert_scipy_wav_style(value): + assert is_scipy_wav_style( + value + ), "Must be Tuple[int, numpy.ndarray], but got {}".format( + type(value) + if not isinstance(value, Sequence) + else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value)) + ) diff --git a/espnet/utils/cli_writers.py b/espnet/utils/cli_writers.py new file 
mode 100644 index 0000000000000000000000000000000000000000..42c350a33dbb09e0f52576cda6da9b2fff58c189 --- /dev/null +++ b/espnet/utils/cli_writers.py @@ -0,0 +1,282 @@ +from pathlib import Path +from typing import Dict + +import h5py +import kaldiio +import numpy +import soundfile + +from espnet.utils.cli_utils import assert_scipy_wav_style +from espnet.utils.io_utils import SoundHDF5File + + +def file_writer_helper( + wspecifier: str, + filetype: str = "mat", + write_num_frames: str = None, + compress: bool = False, + compression_method: int = 2, + pcm_format: str = "wav", +): + """Write matrices in kaldi style + + Args: + wspecifier: e.g. ark,scp:out.ark,out.scp + filetype: "mat" is kaldi-martix, "hdf5": HDF5 + write_num_frames: e.g. 'ark,t:num_frames.txt' + compress: Compress or not + compression_method: Specify compression level + + Write in kaldi-matrix-ark with "kaldi-scp" file: + + >>> with file_writer_helper('ark,scp:out.ark,out.scp') as f: + >>> f['uttid'] = array + + This "scp" has the following format: + + uttidA out.ark:1234 + uttidB out.ark:2222 + + where, 1234 and 2222 points the strating byte address of the matrix. + (For detail, see official documentation of Kaldi) + + Write in HDF5 with "scp" file: + + >>> with file_writer_helper('ark,scp:out.h5,out.scp', 'hdf5') as f: + >>> f['uttid'] = array + + This "scp" file is created as: + + uttidA out.h5:uttidA + uttidB out.h5:uttidB + + HDF5 can be, unlike "kaldi-ark", accessed to any keys, + so originally "scp" is not required for random-reading. + Nevertheless we create "scp" for HDF5 because it is useful + for some use-case. e.g. Concatenation, Splitting. + + """ + if filetype == "mat": + return KaldiWriter( + wspecifier, + write_num_frames=write_num_frames, + compress=compress, + compression_method=compression_method, + ) + elif filetype == "hdf5": + return HDF5Writer( + wspecifier, write_num_frames=write_num_frames, compress=compress + ) + elif filetype == "sound.hdf5": + return SoundHDF5Writer( + wspecifier, write_num_frames=write_num_frames, pcm_format=pcm_format + ) + elif filetype == "sound": + return SoundWriter( + wspecifier, write_num_frames=write_num_frames, pcm_format=pcm_format + ) + else: + raise NotImplementedError(f"filetype={filetype}") + + +class BaseWriter: + def __setitem__(self, key, value): + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + try: + self.writer.close() + except Exception: + pass + + if self.writer_scp is not None: + try: + self.writer_scp.close() + except Exception: + pass + + if self.writer_nframe is not None: + try: + self.writer_nframe.close() + except Exception: + pass + + +def get_num_frames_writer(write_num_frames: str): + """get_num_frames_writer + + Examples: + >>> get_num_frames_writer('ark,t:num_frames.txt') + """ + if write_num_frames is not None: + if ":" not in write_num_frames: + raise ValueError( + 'Must include ":", write_num_frames={}'.format(write_num_frames) + ) + + nframes_type, nframes_file = write_num_frames.split(":", 1) + if nframes_type != "ark,t": + raise ValueError( + "Only supporting text mode. " + "e.g. 
--write-num-frames=ark,t:foo.txt :" + "{}".format(nframes_type) + ) + + return open(nframes_file, "w", encoding="utf-8") + + +class KaldiWriter(BaseWriter): + def __init__( + self, wspecifier, write_num_frames=None, compress=False, compression_method=2 + ): + if compress: + self.writer = kaldiio.WriteHelper( + wspecifier, compression_method=compression_method + ) + else: + self.writer = kaldiio.WriteHelper(wspecifier) + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + self.writer[key] = value + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value)}\n") + + +def parse_wspecifier(wspecifier: str) -> Dict[str, str]: + """Parse wspecifier to dict + + Examples: + >>> parse_wspecifier('ark,scp:out.ark,out.scp') + {'ark': 'out.ark', 'scp': 'out.scp'} + + """ + ark_scp, filepath = wspecifier.split(":", 1) + if ark_scp not in ["ark", "scp,ark", "ark,scp"]: + raise ValueError("{} is not allowed: {}".format(ark_scp, wspecifier)) + ark_scps = ark_scp.split(",") + filepaths = filepath.split(",") + if len(ark_scps) != len(filepaths): + raise ValueError("Mismatch: {} and {}".format(ark_scp, filepath)) + spec_dict = dict(zip(ark_scps, filepaths)) + return spec_dict + + +class HDF5Writer(BaseWriter): + """HDF5Writer + + Examples: + >>> with HDF5Writer('ark:out.h5', compress=True) as f: + ... f['key'] = array + """ + + def __init__(self, wspecifier, write_num_frames=None, compress=False): + spec_dict = parse_wspecifier(wspecifier) + self.filename = spec_dict["ark"] + + if compress: + self.kwargs = {"compression": "gzip"} + else: + self.kwargs = {} + self.writer = h5py.File(spec_dict["ark"], "w") + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + self.writer.create_dataset(key, data=value, **self.kwargs) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {self.filename}:{key}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value)}\n") + + +class SoundHDF5Writer(BaseWriter): + """SoundHDF5Writer + + Examples: + >>> fs = 16000 + >>> with SoundHDF5Writer('ark:out.h5') as f: + ... 
f['key'] = fs, array + """ + + def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"): + self.pcm_format = pcm_format + spec_dict = parse_wspecifier(wspecifier) + self.filename = spec_dict["ark"] + self.writer = SoundHDF5File(spec_dict["ark"], "w", format=self.pcm_format) + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + assert_scipy_wav_style(value) + # Change Tuple[int, ndarray] -> Tuple[ndarray, int] + # (scipy style -> soundfile style) + value = (value[1], value[0]) + self.writer.create_dataset(key, data=value) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {self.filename}:{key}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(value[0])}\n") + + +class SoundWriter(BaseWriter): + """SoundWriter + + Examples: + >>> fs = 16000 + >>> with SoundWriter('ark,scp:outdir,out.scp') as f: + ... f['key'] = fs, array + """ + + def __init__(self, wspecifier, write_num_frames=None, pcm_format="wav"): + self.pcm_format = pcm_format + spec_dict = parse_wspecifier(wspecifier) + # e.g. ark,scp:dirname,wav.scp + # -> The wave files are found in dirname/*.wav + self.dirname = spec_dict["ark"] + Path(self.dirname).mkdir(parents=True, exist_ok=True) + self.writer = None + + if "scp" in spec_dict: + self.writer_scp = open(spec_dict["scp"], "w", encoding="utf-8") + else: + self.writer_scp = None + if write_num_frames is not None: + self.writer_nframe = get_num_frames_writer(write_num_frames) + else: + self.writer_nframe = None + + def __setitem__(self, key, value): + assert_scipy_wav_style(value) + rate, signal = value + wavfile = Path(self.dirname) / (key + "." + self.pcm_format) + soundfile.write(wavfile, signal.astype(numpy.int16), rate) + + if self.writer_scp is not None: + self.writer_scp.write(f"{key} {wavfile}\n") + if self.writer_nframe is not None: + self.writer_nframe.write(f"{key} {len(signal)}\n") diff --git a/espnet/utils/dataset.py b/espnet/utils/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ac37ae6f7c7d0100ef54eeb938fbe8c86b4f1065 --- /dev/null +++ b/espnet/utils/dataset.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""pytorch dataset and dataloader implementation for chainer training.""" + +import torch +import torch.utils.data + + +class TransformDataset(torch.utils.data.Dataset): + """Transform Dataset for pytorch backend. + + Args: + data: list object from make_batchset + transfrom: transform function + + """ + + def __init__(self, data, transform): + """Init function.""" + super(TransformDataset).__init__() + self.data = data + self.transform = transform + + def __len__(self): + """Len function.""" + return len(self.data) + + def __getitem__(self, idx): + """[] operator.""" + return self.transform(self.data[idx]) + + +class ChainerDataLoader(object): + """Pytorch dataloader in chainer style. 
+ + Args: + all args for torch.utils.data.dataloader.Dataloader + + """ + + def __init__(self, **kwargs): + """Init function.""" + self.loader = torch.utils.data.dataloader.DataLoader(**kwargs) + self.len = len(kwargs["dataset"]) + self.current_position = 0 + self.epoch = 0 + self.iter = None + self.kwargs = kwargs + + def next(self): + """Implement next function.""" + if self.iter is None: + self.iter = iter(self.loader) + try: + ret = next(self.iter) + except StopIteration: + self.iter = None + return self.next() + self.current_position += 1 + if self.current_position == self.len: + self.epoch = self.epoch + 1 + self.current_position = 0 + return ret + + def __iter__(self): + """Implement iter function.""" + for batch in self.loader: + yield batch + + @property + def epoch_detail(self): + """Epoch_detail required by chainer.""" + return self.epoch + self.current_position / self.len + + def serialize(self, serializer): + """Serialize and deserialize function.""" + epoch = serializer("epoch", self.epoch) + current_position = serializer("current_position", self.current_position) + self.epoch = epoch + self.current_position = current_position + + def start_shuffle(self): + """Shuffle function for sortagrad.""" + self.kwargs["shuffle"] = True + self.loader = torch.utils.data.dataloader.DataLoader(**self.kwargs) + + def finalize(self): + """Implement finalize function.""" + del self.loader diff --git a/espnet/utils/deterministic_utils.py b/espnet/utils/deterministic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..95de435220cd1c844826311354e1da038a07db4e --- /dev/null +++ b/espnet/utils/deterministic_utils.py @@ -0,0 +1,55 @@ +import logging +import os + +import chainer +import torch + + +def set_deterministic_pytorch(args): + """Ensures pytorch produces deterministic results depending on the program arguments + + :param Namespace args: The program arguments + """ + # seed setting + torch.manual_seed(args.seed) + + # debug mode setting + # 0 would be fastest, but 1 seems to be reasonable + # considering reproducibility + # remove type check + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = ( + False # https://github.com/pytorch/pytorch/issues/6351 + ) + if args.debugmode < 2: + chainer.config.type_check = False + logging.info("torch type check is disabled") + # use deterministic computation or not + if args.debugmode < 1: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + logging.info("torch cudnn deterministic is disabled") + + +def set_deterministic_chainer(args): + """Ensures chainer produces deterministic results depending on the program arguments + + :param Namespace args: The program arguments + """ + # seed setting (chainer seed may not need it) + os.environ["CHAINER_SEED"] = str(args.seed) + logging.info("chainer seed = " + os.environ["CHAINER_SEED"]) + + # debug mode setting + # 0 would be fastest, but 1 seems to be reasonable + # considering reproducibility + # remove type check + if args.debugmode < 2: + chainer.config.type_check = False + logging.info("chainer type check is disabled") + # use deterministic computation or not + if args.debugmode < 1: + chainer.config.cudnn_deterministic = False + logging.info("chainer cudnn deterministic is disabled") + else: + chainer.config.cudnn_deterministic = True diff --git a/espnet/utils/dynamic_import.py b/espnet/utils/dynamic_import.py new file mode 100644 index 0000000000000000000000000000000000000000..db885d0069bfb8f59dcf03f5477c13706574b217 --- 
/dev/null +++ b/espnet/utils/dynamic_import.py @@ -0,0 +1,23 @@ +import importlib + + +def dynamic_import(import_path, alias=dict()): + """dynamic import module and class + + :param str import_path: syntax 'module_name:class_name' + e.g., 'espnet.transform.add_deltas:AddDeltas' + :param dict alias: shortcut for registered class + :return: imported class + """ + if import_path not in alias and ":" not in import_path: + raise ValueError( + "import_path should be one of {} or " + 'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : ' + "{}".format(set(alias), import_path) + ) + if ":" not in import_path: + import_path = alias[import_path] + + module_name, objname = import_path.split(":") + m = importlib.import_module(module_name) + return getattr(m, objname) diff --git a/espnet/utils/fill_missing_args.py b/espnet/utils/fill_missing_args.py new file mode 100644 index 0000000000000000000000000000000000000000..a0fd117529569976780436c0d79e7ce158cd44e9 --- /dev/null +++ b/espnet/utils/fill_missing_args.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +# Copyright 2018 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import argparse +import logging + + +def fill_missing_args(args, add_arguments): + """Fill missing arguments in args. + + Args: + args (Namespace or None): Namesapce containing hyperparameters. + add_arguments (function): Function to add arguments. + + Returns: + Namespace: Arguments whose missing ones are filled with default value. + + Examples: + >>> from argparse import Namespace + >>> from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2 + >>> args = Namespace() + >>> fill_missing_args(args, Tacotron2.add_arguments_fn) + Namespace(aconv_chans=32, aconv_filts=15, adim=512, atype='location', ...) + + """ + # check argument type + assert isinstance(args, argparse.Namespace) or args is None + assert callable(add_arguments) + + # get default arguments + default_args, _ = add_arguments(argparse.ArgumentParser()).parse_known_args() + + # convert to dict + args = {} if args is None else vars(args) + default_args = vars(default_args) + + for key, value in default_args.items(): + if key not in args: + logging.info( + 'attribute "%s" does not exist. use default %s.' % (key, str(value)) + ) + args[key] = value + + return argparse.Namespace(**args) diff --git a/espnet/utils/io_utils.py b/espnet/utils/io_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56277c2a5b8c611c40a0f366fdeda67f1c2a5492 --- /dev/null +++ b/espnet/utils/io_utils.py @@ -0,0 +1,641 @@ +from collections import OrderedDict +import io +import logging +import os + +import h5py +import kaldiio +import numpy as np +import soundfile + +from espnet.transform.transformation import Transformation + + +class LoadInputsAndTargets(object): + """Create a mini-batch from a list of dicts + + >>> batch = [('utt1', + ... dict(input=[dict(feat='some.ark:123', + ... filetype='mat', + ... name='input1', + ... shape=[100, 80])], + ... output=[dict(tokenid='1 2 3 4', + ... name='target1', + ... 
shape=[4, 31])]])) + >>> l = LoadInputsAndTargets() + >>> feat, target = l(batch) + + :param: str mode: Specify the task mode, "asr" or "tts" + :param: str preprocess_conf: The path of a json file for pre-processing + :param: bool load_input: If False, not to load the input data + :param: bool load_output: If False, not to load the output data + :param: bool sort_in_input_length: Sort the mini-batch in descending order + of the input length + :param: bool use_speaker_embedding: Used for tts mode only + :param: bool use_second_target: Used for tts mode only + :param: dict preprocess_args: Set some optional arguments for preprocessing + :param: Optional[dict] preprocess_args: Used for tts mode only + """ + + def __init__( + self, + mode="asr", + preprocess_conf=None, + load_input=True, + load_output=True, + sort_in_input_length=True, + use_speaker_embedding=False, + use_second_target=False, + preprocess_args=None, + keep_all_data_on_mem=False, + ): + self._loaders = {} + if mode not in ["asr", "tts", "mt", "vc"]: + raise ValueError("Only asr or tts are allowed: mode={}".format(mode)) + if preprocess_conf is not None: + self.preprocessing = Transformation(preprocess_conf) + logging.warning( + "[Experimental feature] Some preprocessing will be done " + "for the mini-batch creation using {}".format(self.preprocessing) + ) + else: + # If conf doesn't exist, this function don't touch anything. + self.preprocessing = None + + if use_second_target and use_speaker_embedding and mode == "tts": + raise ValueError( + 'Choose one of "use_second_target" and ' '"use_speaker_embedding "' + ) + if ( + (use_second_target or use_speaker_embedding) + and mode != "tts" + and mode != "vc" + ): + logging.warning( + '"use_second_target" and "use_speaker_embedding" is ' + "used only for tts or vc mode" + ) + + self.mode = mode + self.load_output = load_output + self.load_input = load_input + self.sort_in_input_length = sort_in_input_length + self.use_speaker_embedding = use_speaker_embedding + self.use_second_target = use_second_target + if preprocess_args is None: + self.preprocess_args = {} + else: + assert isinstance(preprocess_args, dict), type(preprocess_args) + self.preprocess_args = dict(preprocess_args) + + self.keep_all_data_on_mem = keep_all_data_on_mem + + def __call__(self, batch, return_uttid=False): + """Function to load inputs and targets from list of dicts + + :param List[Tuple[str, dict]] batch: list of dict which is subset of + loaded data.json + :param bool return_uttid: return utterance ID information for visualization + :return: list of input token id sequences [(L_1), (L_2), ..., (L_B)] + :return: list of input feature sequences + [(T_1, D), (T_2, D), ..., (T_B, D)] + :rtype: list of float ndarray + :return: list of target token id sequences [(L_1), (L_2), ..., (L_B)] + :rtype: list of int ndarray + + """ + x_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]] + y_feats_dict = OrderedDict() # OrderedDict[str, List[np.ndarray]] + uttid_list = [] # List[str] + + for uttid, info in batch: + uttid_list.append(uttid) + + if self.load_input: + # Note(kamo): This for-loop is for multiple inputs + for idx, inp in enumerate(info["input"]): + # {"input": + # [{"feat": "some/path.h5:F01_050C0101_PED_REAL", + # "filetype": "hdf5", + # "name": "input1", ...}], ...} + x = self._get_from_loader( + filepath=inp["feat"], filetype=inp.get("filetype", "mat") + ) + x_feats_dict.setdefault(inp["name"], []).append(x) + # FIXME(kamo): Dirty way to load only speaker_embedding + elif self.mode == "tts" and 
self.use_speaker_embedding: + for idx, inp in enumerate(info["input"]): + if idx != 1 and len(info["input"]) > 1: + x = None + else: + x = self._get_from_loader( + filepath=inp["feat"], filetype=inp.get("filetype", "mat") + ) + x_feats_dict.setdefault(inp["name"], []).append(x) + + if self.load_output: + if self.mode == "mt": + x = np.fromiter( + map(int, info["output"][1]["tokenid"].split()), dtype=np.int64 + ) + x_feats_dict.setdefault(info["output"][1]["name"], []).append(x) + + for idx, inp in enumerate(info["output"]): + if "tokenid" in inp: + # ======= Legacy format for output ======= + # {"output": [{"tokenid": "1 2 3 4"}]) + x = np.fromiter( + map(int, inp["tokenid"].split()), dtype=np.int64 + ) + else: + # ======= New format ======= + # {"input": + # [{"feat": "some/path.h5:F01_050C0101_PED_REAL", + # "filetype": "hdf5", + # "name": "target1", ...}], ...} + x = self._get_from_loader( + filepath=inp["feat"], filetype=inp.get("filetype", "mat") + ) + + y_feats_dict.setdefault(inp["name"], []).append(x) + + if self.mode == "asr": + return_batch, uttid_list = self._create_batch_asr( + x_feats_dict, y_feats_dict, uttid_list + ) + elif self.mode == "tts": + _, info = batch[0] + eos = int(info["output"][0]["shape"][1]) - 1 + return_batch, uttid_list = self._create_batch_tts( + x_feats_dict, y_feats_dict, uttid_list, eos + ) + elif self.mode == "mt": + return_batch, uttid_list = self._create_batch_mt( + x_feats_dict, y_feats_dict, uttid_list + ) + elif self.mode == "vc": + return_batch, uttid_list = self._create_batch_vc( + x_feats_dict, y_feats_dict, uttid_list + ) + else: + raise NotImplementedError(self.mode) + + if self.preprocessing is not None: + # Apply pre-processing all input features + for x_name in return_batch.keys(): + if x_name.startswith("input"): + return_batch[x_name] = self.preprocessing( + return_batch[x_name], uttid_list, **self.preprocess_args + ) + + if return_uttid: + return tuple(return_batch.values()), uttid_list + + # Doesn't return the names now. + return tuple(return_batch.values()) + + def _create_batch_asr(self, x_feats_dict, y_feats_dict, uttid_list): + """Create a OrderedDict for the mini-batch + + :param OrderedDict x_feats_dict: + e.g. {"input1": [ndarray, ndarray, ...], + "input2": [ndarray, ndarray, ...]} + :param OrderedDict y_feats_dict: + e.g. 
{"target1": [ndarray, ndarray, ...], + "target2": [ndarray, ndarray, ...]} + :param: List[str] uttid_list: + Give uttid_list to sort in the same order as the mini-batch + :return: batch, uttid_list + :rtype: Tuple[OrderedDict, List[str]] + """ + # handle single-input and multi-input (paralell) asr mode + xs = list(x_feats_dict.values()) + + if self.load_output: + ys = list(y_feats_dict.values()) + assert len(xs[0]) == len(ys[0]), (len(xs[0]), len(ys[0])) + + # get index of non-zero length samples + nonzero_idx = list(filter(lambda i: len(ys[0][i]) > 0, range(len(ys[0])))) + for n in range(1, len(y_feats_dict)): + nonzero_idx = filter(lambda i: len(ys[n][i]) > 0, nonzero_idx) + else: + # Note(kamo): Be careful not to make nonzero_idx to a generator + nonzero_idx = list(range(len(xs[0]))) + + if self.sort_in_input_length: + # sort in input lengths based on the first input + nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[0][i])) + else: + nonzero_sorted_idx = nonzero_idx + + if len(nonzero_sorted_idx) != len(xs[0]): + logging.warning( + "Target sequences include empty tokenid (batch {} -> {}).".format( + len(xs[0]), len(nonzero_sorted_idx) + ) + ) + + # remove zero-length samples + xs = [[x[i] for i in nonzero_sorted_idx] for x in xs] + uttid_list = [uttid_list[i] for i in nonzero_sorted_idx] + + x_names = list(x_feats_dict.keys()) + if self.load_output: + ys = [[y[i] for i in nonzero_sorted_idx] for y in ys] + y_names = list(y_feats_dict.keys()) + + # Keeping x_name and y_name, e.g. input1, for future extension + return_batch = OrderedDict( + [ + *[(x_name, x) for x_name, x in zip(x_names, xs)], + *[(y_name, y) for y_name, y in zip(y_names, ys)], + ] + ) + else: + return_batch = OrderedDict([(x_name, x) for x_name, x in zip(x_names, xs)]) + return return_batch, uttid_list + + def _create_batch_mt(self, x_feats_dict, y_feats_dict, uttid_list): + """Create a OrderedDict for the mini-batch + + :param OrderedDict x_feats_dict: + :param OrderedDict y_feats_dict: + :return: batch, uttid_list + :rtype: Tuple[OrderedDict, List[str]] + """ + # Create a list from the first item + xs = list(x_feats_dict.values())[0] + + if self.load_output: + ys = list(y_feats_dict.values())[0] + assert len(xs) == len(ys), (len(xs), len(ys)) + + # get index of non-zero length samples + nonzero_idx = filter(lambda i: len(ys[i]) > 0, range(len(ys))) + else: + nonzero_idx = range(len(xs)) + + if self.sort_in_input_length: + # sort in input lengths + nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[i])) + else: + nonzero_sorted_idx = nonzero_idx + + if len(nonzero_sorted_idx) != len(xs): + logging.warning( + "Target sequences include empty tokenid (batch {} -> {}).".format( + len(xs), len(nonzero_sorted_idx) + ) + ) + + # remove zero-length samples + xs = [xs[i] for i in nonzero_sorted_idx] + uttid_list = [uttid_list[i] for i in nonzero_sorted_idx] + + x_name = list(x_feats_dict.keys())[0] + if self.load_output: + ys = [ys[i] for i in nonzero_sorted_idx] + y_name = list(y_feats_dict.keys())[0] + + return_batch = OrderedDict([(x_name, xs), (y_name, ys)]) + else: + return_batch = OrderedDict([(x_name, xs)]) + return return_batch, uttid_list + + def _create_batch_tts(self, x_feats_dict, y_feats_dict, uttid_list, eos): + """Create a OrderedDict for the mini-batch + + :param OrderedDict x_feats_dict: + e.g. {"input1": [ndarray, ndarray, ...], + "input2": [ndarray, ndarray, ...]} + :param OrderedDict y_feats_dict: + e.g. 
{"target1": [ndarray, ndarray, ...], + "target2": [ndarray, ndarray, ...]} + :param: List[str] uttid_list: + :param int eos: + :return: batch, uttid_list + :rtype: Tuple[OrderedDict, List[str]] + """ + # Use the output values as the input feats for tts mode + xs = list(y_feats_dict.values())[0] + # get index of non-zero length samples + nonzero_idx = list(filter(lambda i: len(xs[i]) > 0, range(len(xs)))) + # sort in input lengths + if self.sort_in_input_length: + # sort in input lengths + nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[i])) + else: + nonzero_sorted_idx = nonzero_idx + # remove zero-length samples + xs = [xs[i] for i in nonzero_sorted_idx] + uttid_list = [uttid_list[i] for i in nonzero_sorted_idx] + # Added eos into input sequence + xs = [np.append(x, eos) for x in xs] + + if self.load_input: + ys = list(x_feats_dict.values())[0] + assert len(xs) == len(ys), (len(xs), len(ys)) + ys = [ys[i] for i in nonzero_sorted_idx] + + spembs = None + spcs = None + spembs_name = "spembs_none" + spcs_name = "spcs_none" + + if self.use_second_target: + spcs = list(x_feats_dict.values())[1] + spcs = [spcs[i] for i in nonzero_sorted_idx] + spcs_name = list(x_feats_dict.keys())[1] + + if self.use_speaker_embedding: + spembs = list(x_feats_dict.values())[1] + spembs = [spembs[i] for i in nonzero_sorted_idx] + spembs_name = list(x_feats_dict.keys())[1] + + x_name = list(y_feats_dict.keys())[0] + y_name = list(x_feats_dict.keys())[0] + + return_batch = OrderedDict( + [(x_name, xs), (y_name, ys), (spembs_name, spembs), (spcs_name, spcs)] + ) + elif self.use_speaker_embedding: + if len(x_feats_dict) == 0: + raise IndexError("No speaker embedding is provided") + elif len(x_feats_dict) == 1: + spembs_idx = 0 + else: + spembs_idx = 1 + + spembs = list(x_feats_dict.values())[spembs_idx] + spembs = [spembs[i] for i in nonzero_sorted_idx] + + x_name = list(y_feats_dict.keys())[0] + spembs_name = list(x_feats_dict.keys())[spembs_idx] + + return_batch = OrderedDict([(x_name, xs), (spembs_name, spembs)]) + else: + x_name = list(y_feats_dict.keys())[0] + + return_batch = OrderedDict([(x_name, xs)]) + return return_batch, uttid_list + + def _create_batch_vc(self, x_feats_dict, y_feats_dict, uttid_list): + """Create a OrderedDict for the mini-batch + + :param OrderedDict x_feats_dict: + e.g. {"input1": [ndarray, ndarray, ...], + "input2": [ndarray, ndarray, ...]} + :param OrderedDict y_feats_dict: + e.g. 
{"target1": [ndarray, ndarray, ...], + "target2": [ndarray, ndarray, ...]} + :param: List[str] uttid_list: + :return: batch, uttid_list + :rtype: Tuple[OrderedDict, List[str]] + """ + # Create a list from the first item + xs = list(x_feats_dict.values())[0] + + # get index of non-zero length samples + nonzero_idx = list(filter(lambda i: len(xs[i]) > 0, range(len(xs)))) + + # sort in input lengths + if self.sort_in_input_length: + # sort in input lengths + nonzero_sorted_idx = sorted(nonzero_idx, key=lambda i: -len(xs[i])) + else: + nonzero_sorted_idx = nonzero_idx + + # remove zero-length samples + xs = [xs[i] for i in nonzero_sorted_idx] + uttid_list = [uttid_list[i] for i in nonzero_sorted_idx] + + if self.load_output: + ys = list(y_feats_dict.values())[0] + assert len(xs) == len(ys), (len(xs), len(ys)) + ys = [ys[i] for i in nonzero_sorted_idx] + + spembs = None + spcs = None + spembs_name = "spembs_none" + spcs_name = "spcs_none" + + if self.use_second_target: + raise ValueError("Currently second target not supported.") + spcs = list(x_feats_dict.values())[1] + spcs = [spcs[i] for i in nonzero_sorted_idx] + spcs_name = list(x_feats_dict.keys())[1] + + if self.use_speaker_embedding: + spembs = list(x_feats_dict.values())[1] + spembs = [spembs[i] for i in nonzero_sorted_idx] + spembs_name = list(x_feats_dict.keys())[1] + + x_name = list(x_feats_dict.keys())[0] + y_name = list(y_feats_dict.keys())[0] + + return_batch = OrderedDict( + [(x_name, xs), (y_name, ys), (spembs_name, spembs), (spcs_name, spcs)] + ) + elif self.use_speaker_embedding: + if len(x_feats_dict) == 0: + raise IndexError("No speaker embedding is provided") + elif len(x_feats_dict) == 1: + spembs_idx = 0 + else: + spembs_idx = 1 + + spembs = list(x_feats_dict.values())[spembs_idx] + spembs = [spembs[i] for i in nonzero_sorted_idx] + + x_name = list(x_feats_dict.keys())[0] + spembs_name = list(x_feats_dict.keys())[spembs_idx] + + return_batch = OrderedDict([(x_name, xs), (spembs_name, spembs)]) + else: + x_name = list(x_feats_dict.keys())[0] + + return_batch = OrderedDict([(x_name, xs)]) + return return_batch, uttid_list + + def _get_from_loader(self, filepath, filetype): + """Return ndarray + + In order to make the fds to be opened only at the first referring, + the loader are stored in self._loaders + + >>> ndarray = loader.get_from_loader( + ... 'some/path.h5:F01_050C0101_PED_REAL', filetype='hdf5') + + :param: str filepath: + :param: str filetype: + :return: + :rtype: np.ndarray + """ + if filetype == "hdf5": + # e.g. + # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL", + # "filetype": "hdf5", + # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL" + filepath, key = filepath.split(":", 1) + + loader = self._loaders.get(filepath) + if loader is None: + # To avoid disk access, create loader only for the first time + loader = h5py.File(filepath, "r") + self._loaders[filepath] = loader + return loader[key][()] + elif filetype == "sound.hdf5": + # e.g. + # {"input": [{"feat": "some/path.h5:F01_050C0101_PED_REAL", + # "filetype": "sound.hdf5", + # -> filepath = "some/path.h5", key = "F01_050C0101_PED_REAL" + filepath, key = filepath.split(":", 1) + + loader = self._loaders.get(filepath) + if loader is None: + # To avoid disk access, create loader only for the first time + loader = SoundHDF5File(filepath, "r", dtype="int16") + self._loaders[filepath] = loader + array, rate = loader[key] + return array + elif filetype == "sound": + # e.g. 
+ # {"input": [{"feat": "some/path.wav", + # "filetype": "sound"}, + # Assume PCM16 + if not self.keep_all_data_on_mem: + array, _ = soundfile.read(filepath, dtype="int16") + return array + if filepath not in self._loaders: + array, _ = soundfile.read(filepath, dtype="int16") + self._loaders[filepath] = array + return self._loaders[filepath] + elif filetype == "npz": + # e.g. + # {"input": [{"feat": "some/path.npz:F01_050C0101_PED_REAL", + # "filetype": "npz", + filepath, key = filepath.split(":", 1) + + loader = self._loaders.get(filepath) + if loader is None: + # To avoid disk access, create loader only for the first time + loader = np.load(filepath) + self._loaders[filepath] = loader + return loader[key] + elif filetype == "npy": + # e.g. + # {"input": [{"feat": "some/path.npy", + # "filetype": "npy"}, + if not self.keep_all_data_on_mem: + return np.load(filepath) + if filepath not in self._loaders: + self._loaders[filepath] = np.load(filepath) + return self._loaders[filepath] + elif filetype in ["mat", "vec"]: + # e.g. + # {"input": [{"feat": "some/path.ark:123", + # "filetype": "mat"}]}, + # In this case, "123" indicates the starting points of the matrix + # load_mat can load both matrix and vector + if not self.keep_all_data_on_mem: + return kaldiio.load_mat(filepath) + if filepath not in self._loaders: + self._loaders[filepath] = kaldiio.load_mat(filepath) + return self._loaders[filepath] + elif filetype == "scp": + # e.g. + # {"input": [{"feat": "some/path.scp:F01_050C0101_PED_REAL", + # "filetype": "scp", + filepath, key = filepath.split(":", 1) + loader = self._loaders.get(filepath) + if loader is None: + # To avoid disk access, create loader only for the first time + loader = kaldiio.load_scp(filepath) + self._loaders[filepath] = loader + return loader[key] + else: + raise NotImplementedError("Not supported: loader_type={}".format(filetype)) + + +class SoundHDF5File(object): + """Collecting sound files to a HDF5 file + + >>> f = SoundHDF5File('a.flac.h5', mode='a') + >>> array = np.random.randint(0, 100, 100, dtype=np.int16) + >>> f['id'] = (array, 16000) + >>> array, rate = f['id'] + + + :param: str filepath: + :param: str mode: + :param: str format: The type used when saving wav. flac, nist, htk, etc. 
+ :param: str dtype: + + """ + + def __init__(self, filepath, mode="r+", format=None, dtype="int16", **kwargs): + self.filepath = filepath + self.mode = mode + self.dtype = dtype + + self.file = h5py.File(filepath, mode, **kwargs) + if format is None: + # filepath = a.flac.h5 -> format = flac + second_ext = os.path.splitext(os.path.splitext(filepath)[0])[1] + format = second_ext[1:] + if format.upper() not in soundfile.available_formats(): + # If not found, flac is selected + format = "flac" + + # This format affects only saving + self.format = format + + def __repr__(self): + return ''.format( + self.filepath, self.mode, self.format, self.dtype + ) + + def create_dataset(self, name, shape=None, data=None, **kwds): + f = io.BytesIO() + array, rate = data + soundfile.write(f, array, rate, format=self.format) + self.file.create_dataset(name, shape=shape, data=np.void(f.getvalue()), **kwds) + + def __setitem__(self, name, data): + self.create_dataset(name, data=data) + + def __getitem__(self, key): + data = self.file[key][()] + f = io.BytesIO(data.tobytes()) + array, rate = soundfile.read(f, dtype=self.dtype) + return array, rate + + def keys(self): + return self.file.keys() + + def values(self): + for k in self.file: + yield self[k] + + def items(self): + for k in self.file: + yield k, self[k] + + def __iter__(self): + return iter(self.file) + + def __contains__(self, item): + return item in self.file + + def __len__(self, item): + return len(self.file) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def close(self): + self.file.close() diff --git a/espnet/utils/spec_augment.py b/espnet/utils/spec_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c6bb2b856120a72cb5c229b985d43f0858276a --- /dev/null +++ b/espnet/utils/spec_augment.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- + +""" +This implementation is modified from https://github.com/zcaceres/spec_augment + +MIT License + +Copyright (c) 2019 Zach Caceres + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETjjHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+""" + +import random + +import torch + + +def specaug( + spec, W=5, F=30, T=40, num_freq_masks=2, num_time_masks=2, replace_with_zero=False +): + """SpecAugment + + Reference: + SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition + (https://arxiv.org/pdf/1904.08779.pdf) + + This implementation modified from https://github.com/zcaceres/spec_augment + + :param torch.Tensor spec: input tensor with the shape (T, dim) + :param int W: time warp parameter + :param int F: maximum width of each freq mask + :param int T: maximum width of each time mask + :param int num_freq_masks: number of frequency masks + :param int num_time_masks: number of time masks + :param bool replace_with_zero: if True, masked parts will be filled with 0, + if False, filled with mean + """ + return time_mask( + freq_mask( + time_warp(spec, W=W), + F=F, + num_masks=num_freq_masks, + replace_with_zero=replace_with_zero, + ), + T=T, + num_masks=num_time_masks, + replace_with_zero=replace_with_zero, + ) + + +def time_warp(spec, W=5): + """Time warping + + :param torch.Tensor spec: input tensor with shape (T, dim) + :param int W: time warp parameter + """ + spec = spec.unsqueeze(0) + spec_len = spec.shape[1] + num_rows = spec.shape[2] + device = spec.device + + y = num_rows // 2 + horizontal_line_at_ctr = spec[0, :, y] + assert len(horizontal_line_at_ctr) == spec_len + + point_to_warp = horizontal_line_at_ctr[random.randrange(W, spec_len - W)] + assert isinstance(point_to_warp, torch.Tensor) + + # Uniform distribution from (0,W) with chance to be up to W negative + dist_to_warp = random.randrange(-W, W) + src_pts, dest_pts = ( + torch.tensor([[[point_to_warp, y]]], device=device), + torch.tensor([[[point_to_warp + dist_to_warp, y]]], device=device), + ) + warped_spectro, dense_flows = sparse_image_warp(spec, src_pts, dest_pts) + return warped_spectro.squeeze(3).squeeze(0) + + +def freq_mask(spec, F=30, num_masks=1, replace_with_zero=False): + """Frequency masking + + :param torch.Tensor spec: input tensor with shape (T, dim) + :param int F: maximum width of each mask + :param int num_masks: number of masks + :param bool replace_with_zero: if True, masked parts will be filled with 0, + if False, filled with mean + """ + cloned = spec.unsqueeze(0).clone() + num_mel_channels = cloned.shape[2] + + for i in range(0, num_masks): + f = random.randrange(0, F) + f_zero = random.randrange(0, num_mel_channels - f) + + # avoids randrange error if values are equal and range is empty + if f_zero == f_zero + f: + return cloned.squeeze(0) + + mask_end = random.randrange(f_zero, f_zero + f) + if replace_with_zero: + cloned[0][:, f_zero:mask_end] = 0 + else: + cloned[0][:, f_zero:mask_end] = cloned.mean() + return cloned.squeeze(0) + + +def time_mask(spec, T=40, num_masks=1, replace_with_zero=False): + """Time masking + + :param torch.Tensor spec: input tensor with shape (T, dim) + :param int T: maximum width of each mask + :param int num_masks: number of masks + :param bool replace_with_zero: if True, masked parts will be filled with 0, + if False, filled with mean + """ + cloned = spec.unsqueeze(0).clone() + len_spectro = cloned.shape[1] + + for i in range(0, num_masks): + t = random.randrange(0, T) + t_zero = random.randrange(0, len_spectro - t) + + # avoids randrange error if values are equal and range is empty + if t_zero == t_zero + t: + return cloned.squeeze(0) + + mask_end = random.randrange(t_zero, t_zero + t) + if replace_with_zero: + cloned[0][t_zero:mask_end, :] = 0 + else: + cloned[0][t_zero:mask_end, :] 
= cloned.mean() + return cloned.squeeze(0) + + +def sparse_image_warp( + img_tensor, + source_control_point_locations, + dest_control_point_locations, + interpolation_order=2, + regularization_weight=0.0, + num_boundaries_points=0, +): + device = img_tensor.device + control_point_flows = dest_control_point_locations - source_control_point_locations + + batch_size, image_height, image_width = img_tensor.shape + flattened_grid_locations = get_flat_grid_locations( + image_height, image_width, device + ) + + flattened_flows = interpolate_spline( + dest_control_point_locations, + control_point_flows, + flattened_grid_locations, + interpolation_order, + regularization_weight, + ) + + dense_flows = create_dense_flows( + flattened_flows, batch_size, image_height, image_width + ) + + warped_image = dense_image_warp(img_tensor, dense_flows) + + return warped_image, dense_flows + + +def get_grid_locations(image_height, image_width, device): + y_range = torch.linspace(0, image_height - 1, image_height, device=device) + x_range = torch.linspace(0, image_width - 1, image_width, device=device) + y_grid, x_grid = torch.meshgrid(y_range, x_range) + return torch.stack((y_grid, x_grid), -1) + + +def flatten_grid_locations(grid_locations, image_height, image_width): + return torch.reshape(grid_locations, [image_height * image_width, 2]) + + +def get_flat_grid_locations(image_height, image_width, device): + y_range = torch.linspace(0, image_height - 1, image_height, device=device) + x_range = torch.linspace(0, image_width - 1, image_width, device=device) + y_grid, x_grid = torch.meshgrid(y_range, x_range) + return torch.stack((y_grid, x_grid), -1).reshape([image_height * image_width, 2]) + + +def create_dense_flows(flattened_flows, batch_size, image_height, image_width): + # possibly .view + return torch.reshape(flattened_flows, [batch_size, image_height, image_width, 2]) + + +def interpolate_spline( + train_points, + train_values, + query_points, + order, + regularization_weight=0.0, +): + # First, fit the spline to the observed data. + w, v = solve_interpolation(train_points, train_values, order, regularization_weight) + # Then, evaluate the spline at the query locations. + query_values = apply_interpolation(query_points, train_points, w, v, order) + + return query_values + + +def solve_interpolation(train_points, train_values, order, regularization_weight): + device = train_points.device + b, n, d = train_points.shape + k = train_values.shape[-1] + + c = train_points + f = train_values.float() + + matrix_a = phi(cross_squared_distance_matrix(c, c), order).unsqueeze(0) # [b, n, n] + + # Append ones to the feature values for the bias term in the linear model. + ones = torch.ones(1, dtype=train_points.dtype, device=device).view([-1, 1, 1]) + matrix_b = torch.cat((c, ones), 2).float() # [b, n, d + 1] + + # [b, n + d + 1, n] + left_block = torch.cat((matrix_a, torch.transpose(matrix_b, 2, 1)), 1) + + num_b_cols = matrix_b.shape[2] # d + 1 + + # In Tensorflow, zeros are used here. Pytorch solve fails with zeros + # for some reason we don't understand. + # So instead we use very tiny randn values (variance of one, zero mean) + # on one side of our multiplication. 
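+    # The assembled block system is the usual polyharmonic-spline one:
+    #
+    #     [ A    B ] [w]   [f]
+    #     [ B^T  0 ] [v] = [0]
+    #
+    # where A[i, j] = phi(||c_i - c_j||^2) as computed above, B = [c | 1],
+    # w holds the RBF weights and v the affine term; the near-zero random
+    # block below stands in for the exact zero block purely so the solver
+    # stays numerically stable.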
+ lhs_zeros = torch.randn((b, num_b_cols, num_b_cols), device=device) / 1e10 + right_block = torch.cat((matrix_b, lhs_zeros), 1) # [b, n + d + 1, d + 1] + lhs = torch.cat((left_block, right_block), 2) # [b, n + d + 1, n + d + 1] + + rhs_zeros = torch.zeros( + (b, d + 1, k), dtype=train_points.dtype, device=device + ).float() + rhs = torch.cat((f, rhs_zeros), 1) # [b, n + d + 1, k] + + # Then, solve the linear system and unpack the results. + X, LU = torch.gesv(rhs, lhs) + w = X[:, :n, :] + v = X[:, n:, :] + + return w, v + + +def cross_squared_distance_matrix(x, y): + """Pairwise squared distance between two (batch) matrices' rows (2nd dim). + + Computes the pairwise distances between rows of x and rows of y + Args: + x: [batch_size, n, d] float `Tensor` + y: [batch_size, m, d] float `Tensor` + Returns: + squared_dists: [batch_size, n, m] float `Tensor`, where + squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2 + """ + x_norm_squared = torch.sum(torch.mul(x, x)) + y_norm_squared = torch.sum(torch.mul(y, y)) + + x_y_transpose = torch.matmul(x.squeeze(0), y.squeeze(0).transpose(0, 1)) + + # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi- 2x_bi'x_bj + x_bj'x_bj + squared_dists = x_norm_squared - 2 * x_y_transpose + y_norm_squared + + return squared_dists.float() + + +def phi(r, order): + """Coordinate-wise nonlinearity used to define the order of the interpolation. + + See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition. + Args: + r: input op + order: interpolation order + Returns: + phi_k evaluated coordinate-wise on r, for k = r + """ + EPSILON = torch.tensor(1e-10, device=r.device) + # using EPSILON prevents log(0), sqrt0), etc. + # sqrt(0) is well-defined, but its gradient is not + if order == 1: + r = torch.max(r, EPSILON) + r = torch.sqrt(r) + return r + elif order == 2: + return 0.5 * r * torch.log(torch.max(r, EPSILON)) + elif order == 4: + return 0.5 * torch.square(r) * torch.log(torch.max(r, EPSILON)) + elif order % 2 == 0: + r = torch.max(r, EPSILON) + return 0.5 * torch.pow(r, 0.5 * order) * torch.log(r) + else: + r = torch.max(r, EPSILON) + return torch.pow(r, 0.5 * order) + + +def apply_interpolation(query_points, train_points, w, v, order): + """Apply polyharmonic interpolation model to data. + + Notes: + Given coefficients w and v for the interpolation model, we evaluate + interpolated function values at query_points. + + Args: + query_points: `[b, m, d]` x values to evaluate the interpolation at + train_points: `[b, n, d]` x values that act as the interpolation centers + ( the c variables in the wikipedia article) + w: `[b, n, k]` weights on each interpolation center + v: `[b, d, k]` weights on each input dimension + order: order of the interpolation + + Returns: + Polyharmonic interpolation evaluated at points defined in query_points. + """ + query_points = query_points.unsqueeze(0) + # First, compute the contribution from the rbf term. + pairwise_dists = cross_squared_distance_matrix( + query_points.float(), train_points.float() + ) + phi_pairwise_dists = phi(pairwise_dists, order) + + rbf_term = torch.matmul(phi_pairwise_dists, w) + + # Then, compute the contribution from the linear term. + # Pad query_points with ones, for the bias term in the linear model. + ones = torch.ones_like(query_points[..., :1]) + query_points_pad = torch.cat((query_points, ones), 2).float() + linear_term = torch.matmul(query_points_pad, v) + + return rbf_term + linear_term + + +def dense_image_warp(image, flow): + """Image warping using per-pixel flow vectors. 
+ + Apply a non-linear warp to the image, where the warp is specified by a dense + flow field of offset vectors that define the correspondences of pixel values + in the output image back to locations in the source image. Specifically, the + pixel value at output[b, j, i, c] is + images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. + The locations specified by this formula do not necessarily map to an int + index. Therefore, the pixel value is obtained by bilinear + interpolation of the 4 nearest pixels around + (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside + of the image, we use the nearest pixel values at the image boundary. + Args: + image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. + flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. + name: A name for the operation (optional). + Note that image and flow can be of type tf.half, tf.float32, or tf.float64, + and do not necessarily have to be the same type. + Returns: + A 4-D float `Tensor` with shape`[batch, height, width, channels]` + and same type as input image. + Raises: + ValueError: if height < 2 or width < 2 or the inputs have the wrong number + of dimensions. + """ + image = image.unsqueeze(3) # add a single channel dimension to image tensor + batch_size, height, width, channels = image.shape + device = image.device + + # The flow is defined on the image grid. Turn the flow into a list of query + # points in the grid space. + grid_x, grid_y = torch.meshgrid( + torch.arange(width, device=device), torch.arange(height, device=device) + ) + + stacked_grid = torch.stack((grid_y, grid_x), dim=2).float() + + batched_grid = stacked_grid.unsqueeze(-1).permute(3, 1, 0, 2) + + query_points_on_grid = batched_grid - flow + query_points_flattened = torch.reshape( + query_points_on_grid, [batch_size, height * width, 2] + ) + # Compute values at the query points, then reshape the result back to the + # image grid. + interpolated = interpolate_bilinear(image, query_points_flattened) + interpolated = torch.reshape(interpolated, [batch_size, height, width, channels]) + return interpolated + + +def interpolate_bilinear( + grid, query_points, name="interpolate_bilinear", indexing="ij" +): + """Similar to Matlab's interp2 function. + + Notes: + Finds values for query points on a grid using bilinear interpolation. + + Args: + grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. + query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. + name: a name for the operation (optional). + indexing: whether the query points are specified as row and column (ij), + or Cartesian coordinates (xy). + + Returns: + values: a 3-D `Tensor` with shape `[batch, N, channels]` + + Raises: + ValueError: if the indexing mode is invalid, or if the shape of the inputs + invalid. + """ + if indexing != "ij" and indexing != "xy": + raise ValueError("Indexing mode must be 'ij' or 'xy'") + + shape = grid.shape + if len(shape) != 4: + msg = "Grid must be 4 dimensional. 
Received size: " + raise ValueError(msg + str(grid.shape)) + + batch_size, height, width, channels = grid.shape + + shape = [batch_size, height, width, channels] + query_type = query_points.dtype + grid_type = grid.dtype + grid_device = grid.device + + num_queries = query_points.shape[1] + + alphas = [] + floors = [] + ceils = [] + index_order = [0, 1] if indexing == "ij" else [1, 0] + unstacked_query_points = query_points.unbind(2) + + for dim in index_order: + queries = unstacked_query_points[dim] + + size_in_indexing_dimension = shape[dim + 1] + + # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 + # is still a valid index into the grid. + max_floor = torch.tensor( + size_in_indexing_dimension - 2, dtype=query_type, device=grid_device + ) + min_floor = torch.tensor(0.0, dtype=query_type, device=grid_device) + maxx = torch.max(min_floor, torch.floor(queries)) + floor = torch.min(maxx, max_floor) + int_floor = floor.long() + floors.append(int_floor) + ceil = int_floor + 1 + ceils.append(ceil) + + # alpha has the same type as the grid, as we will directly use alpha + # when taking linear combinations of pixel values from the image. + + alpha = torch.tensor((queries - floor), dtype=grid_type, device=grid_device) + min_alpha = torch.tensor(0.0, dtype=grid_type, device=grid_device) + max_alpha = torch.tensor(1.0, dtype=grid_type, device=grid_device) + alpha = torch.min(torch.max(min_alpha, alpha), max_alpha) + + # Expand alpha to [b, n, 1] so we can use broadcasting + # (since the alpha values don't depend on the channel). + alpha = torch.unsqueeze(alpha, 2) + alphas.append(alpha) + + flattened_grid = torch.reshape(grid, [batch_size * height * width, channels]) + batch_offsets = torch.reshape( + torch.arange(batch_size, device=grid_device) * height * width, [batch_size, 1] + ) + + # This wraps array_ops.gather. We reshape the image data such that the + # batch, y, and x coordinates are pulled into the first dimension. + # Then we gather. Finally, we reshape the output back. It's possible this + # code would be made simpler by using array_ops.gather_nd. 
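+    # The four corner values gathered below are blended bilinearly:
+    #     interp_top    = top_left    + alpha[1] * (top_right    - top_left)
+    #     interp_bottom = bottom_left + alpha[1] * (bottom_right - bottom_left)
+    #     result        = interp_top  + alpha[0] * (interp_bottom - interp_top)
+    # where alpha[0] and alpha[1] are the fractional parts of the query
+    # coordinates along the two index-ordered dimensions.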
+ def gather(y_coords, x_coords, name): + linear_coordinates = batch_offsets + y_coords * width + x_coords + gathered_values = torch.gather(flattened_grid.t(), 1, linear_coordinates) + return torch.reshape(gathered_values, [batch_size, num_queries, channels]) + + # grab the pixel values in the 4 corners around each query point + top_left = gather(floors[0], floors[1], "top_left") + top_right = gather(floors[0], ceils[1], "top_right") + bottom_left = gather(ceils[0], floors[1], "bottom_left") + bottom_right = gather(ceils[0], ceils[1], "bottom_right") + + interp_top = alphas[1] * (top_right - top_left) + top_left + interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left + interp = alphas[0] * (interp_bottom - interp_top) + interp_top + + return interp diff --git a/espnet/utils/training/__init__.py b/espnet/utils/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f177368e62a5578b8706300e101f831a3972ac --- /dev/null +++ b/espnet/utils/training/__init__.py @@ -0,0 +1 @@ +"""Initialize sub package.""" diff --git a/espnet/utils/training/batchfy.py b/espnet/utils/training/batchfy.py new file mode 100644 index 0000000000000000000000000000000000000000..5deccb2d4b64f06b98c880917e8248d82972eb7f --- /dev/null +++ b/espnet/utils/training/batchfy.py @@ -0,0 +1,505 @@ +import itertools +import logging + +import numpy as np + + +def batchfy_by_seq( + sorted_data, + batch_size, + max_length_in, + max_length_out, + min_batch_size=1, + shortest_first=False, + ikey="input", + iaxis=0, + okey="output", + oaxis=0, +): + """Make batch set from json dictionary + + :param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json + :param int batch_size: batch size + :param int max_length_in: maximum length of input to decide adaptive batch size + :param int max_length_out: maximum length of output to decide adaptive batch size + :param int min_batch_size: mininum batch size (for multi-gpu) + :param bool shortest_first: Sort from batch with shortest samples + to longest if true, otherwise reverse + :param str ikey: key to access input + (for ASR ikey="input", for TTS, MT ikey="output".) + :param int iaxis: dimension to access input + (for ASR, TTS iaxis=0, for MT iaxis="1".) + :param str okey: key to access output + (for ASR, MT okey="output". for TTS okey="input".) + :param int oaxis: dimension to access output + (for ASR, TTS, MT oaxis=0, reserved for future research, -1 means all axis.) + :return: List[List[Tuple[str, dict]]] list of batches + """ + if batch_size <= 0: + raise ValueError(f"Invalid batch_size={batch_size}") + + # check #utts is more than min_batch_size + if len(sorted_data) < min_batch_size: + raise ValueError( + f"#utts({len(sorted_data)}) is less than min_batch_size({min_batch_size})." + ) + + # make list of minibatches + minibatches = [] + start = 0 + while True: + _, info = sorted_data[start] + ilen = int(info[ikey][iaxis]["shape"][0]) + olen = ( + int(info[okey][oaxis]["shape"][0]) + if oaxis >= 0 + else max(map(lambda x: int(x["shape"][0]), info[okey])) + ) + factor = max(int(ilen / max_length_in), int(olen / max_length_out)) + # change batchsize depending on the input and output length + # if ilen = 1000 and max_length_in = 800 + # then b = batchsize / 2 + # and max(min_batches, .) 
avoids batchsize = 0 + bs = max(min_batch_size, int(batch_size / (1 + factor))) + end = min(len(sorted_data), start + bs) + minibatch = sorted_data[start:end] + if shortest_first: + minibatch.reverse() + + # check each batch is more than minimum batchsize + if len(minibatch) < min_batch_size: + mod = min_batch_size - len(minibatch) % min_batch_size + additional_minibatch = [ + sorted_data[i] for i in np.random.randint(0, start, mod) + ] + if shortest_first: + additional_minibatch.reverse() + minibatch.extend(additional_minibatch) + minibatches.append(minibatch) + + if end == len(sorted_data): + break + start = end + + # batch: List[List[Tuple[str, dict]]] + return minibatches + + +def batchfy_by_bin( + sorted_data, + batch_bins, + num_batches=0, + min_batch_size=1, + shortest_first=False, + ikey="input", + okey="output", +): + """Make variably sized batch set, which maximizes + + the number of bins up to `batch_bins`. + + :param Dict[str, Dict[str, Any]] sorted_data: dictionary loaded from data.json + :param int batch_bins: Maximum frames of a batch + :param int num_batches: # number of batches to use (for debug) + :param int min_batch_size: minimum batch size (for multi-gpu) + :param int test: Return only every `test` batches + :param bool shortest_first: Sort from batch with shortest samples + to longest if true, otherwise reverse + + :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".) + :param str okey: key to access output (for ASR okey="output". for TTS okey="input".) + + :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches + """ + if batch_bins <= 0: + raise ValueError(f"invalid batch_bins={batch_bins}") + length = len(sorted_data) + idim = int(sorted_data[0][1][ikey][0]["shape"][1]) + odim = int(sorted_data[0][1][okey][0]["shape"][1]) + logging.info("# utts: " + str(len(sorted_data))) + minibatches = [] + start = 0 + n = 0 + while True: + # Dynamic batch size depending on size of samples + b = 0 + next_size = 0 + max_olen = 0 + while next_size < batch_bins and (start + b) < length: + ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) * idim + olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) * odim + if olen > max_olen: + max_olen = olen + next_size = (max_olen + ilen) * (b + 1) + if next_size <= batch_bins: + b += 1 + elif next_size == 0: + raise ValueError( + f"Can't fit one sample in batch_bins ({batch_bins}): " + f"Please increase the value" + ) + end = min(length, start + max(min_batch_size, b)) + batch = sorted_data[start:end] + if shortest_first: + batch.reverse() + minibatches.append(batch) + # Check for min_batch_size and fixes the batches if needed + i = -1 + while len(minibatches[i]) < min_batch_size: + missing = min_batch_size - len(minibatches[i]) + if -i == len(minibatches): + minibatches[i + 1].extend(minibatches[i]) + minibatches = minibatches[1:] + break + else: + minibatches[i].extend(minibatches[i - 1][:missing]) + minibatches[i - 1] = minibatches[i - 1][missing:] + i -= 1 + if end == length: + break + start = end + n += 1 + if num_batches > 0: + minibatches = minibatches[:num_batches] + lengths = [len(x) for x in minibatches] + logging.info( + str(len(minibatches)) + + " batches containing from " + + str(min(lengths)) + + " to " + + str(max(lengths)) + + " samples " + + "(avg " + + str(int(np.mean(lengths))) + + " samples)." 
+ ) + return minibatches + + +def batchfy_by_frame( + sorted_data, + max_frames_in, + max_frames_out, + max_frames_inout, + num_batches=0, + min_batch_size=1, + shortest_first=False, + ikey="input", + okey="output", +): + """Make variable batch set, which maximizes the number of frames to max_batch_frame. + + :param Dict[str, Dict[str, Any]] sorteddata: dictionary loaded from data.json + :param int max_frames_in: Maximum input frames of a batch + :param int max_frames_out: Maximum output frames of a batch + :param int max_frames_inout: Maximum input+output frames of a batch + :param int num_batches: # number of batches to use (for debug) + :param int min_batch_size: minimum batch size (for multi-gpu) + :param int test: Return only every `test` batches + :param bool shortest_first: Sort from batch with shortest samples + to longest if true, otherwise reverse + + :param str ikey: key to access input (for ASR ikey="input", for TTS ikey="output".) + :param str okey: key to access output (for ASR okey="output". for TTS okey="input".) + + :return: List[Tuple[str, Dict[str, List[Dict[str, Any]]]] list of batches + """ + if max_frames_in <= 0 and max_frames_out <= 0 and max_frames_inout <= 0: + raise ValueError( + "At least, one of `--batch-frames-in`, `--batch-frames-out` or " + "`--batch-frames-inout` should be > 0" + ) + length = len(sorted_data) + minibatches = [] + start = 0 + end = 0 + while end != length: + # Dynamic batch size depending on size of samples + b = 0 + max_olen = 0 + max_ilen = 0 + while (start + b) < length: + ilen = int(sorted_data[start + b][1][ikey][0]["shape"][0]) + if ilen > max_frames_in and max_frames_in != 0: + raise ValueError( + f"Can't fit one sample in --batch-frames-in ({max_frames_in}): " + f"Please increase the value" + ) + olen = int(sorted_data[start + b][1][okey][0]["shape"][0]) + if olen > max_frames_out and max_frames_out != 0: + raise ValueError( + f"Can't fit one sample in --batch-frames-out ({max_frames_out}): " + f"Please increase the value" + ) + if ilen + olen > max_frames_inout and max_frames_inout != 0: + raise ValueError( + f"Can't fit one sample in --batch-frames-out ({max_frames_inout}): " + f"Please increase the value" + ) + max_olen = max(max_olen, olen) + max_ilen = max(max_ilen, ilen) + in_ok = max_ilen * (b + 1) <= max_frames_in or max_frames_in == 0 + out_ok = max_olen * (b + 1) <= max_frames_out or max_frames_out == 0 + inout_ok = (max_ilen + max_olen) * ( + b + 1 + ) <= max_frames_inout or max_frames_inout == 0 + if in_ok and out_ok and inout_ok: + # add more seq in the minibatch + b += 1 + else: + # no more seq in the minibatch + break + end = min(length, start + b) + batch = sorted_data[start:end] + if shortest_first: + batch.reverse() + minibatches.append(batch) + # Check for min_batch_size and fixes the batches if needed + i = -1 + while len(minibatches[i]) < min_batch_size: + missing = min_batch_size - len(minibatches[i]) + if -i == len(minibatches): + minibatches[i + 1].extend(minibatches[i]) + minibatches = minibatches[1:] + break + else: + minibatches[i].extend(minibatches[i - 1][:missing]) + minibatches[i - 1] = minibatches[i - 1][missing:] + i -= 1 + start = end + if num_batches > 0: + minibatches = minibatches[:num_batches] + lengths = [len(x) for x in minibatches] + logging.info( + str(len(minibatches)) + + " batches containing from " + + str(min(lengths)) + + " to " + + str(max(lengths)) + + " samples" + + "(avg " + + str(int(np.mean(lengths))) + + " samples)." 
+ ) + + return minibatches + + +def batchfy_shuffle(data, batch_size, min_batch_size, num_batches, shortest_first): + import random + + logging.info("use shuffled batch.") + sorted_data = random.sample(data.items(), len(data.items())) + logging.info("# utts: " + str(len(sorted_data))) + # make list of minibatches + minibatches = [] + start = 0 + while True: + end = min(len(sorted_data), start + batch_size) + # check each batch is more than minimum batchsize + minibatch = sorted_data[start:end] + if shortest_first: + minibatch.reverse() + if len(minibatch) < min_batch_size: + mod = min_batch_size - len(minibatch) % min_batch_size + additional_minibatch = [ + sorted_data[i] for i in np.random.randint(0, start, mod) + ] + if shortest_first: + additional_minibatch.reverse() + minibatch.extend(additional_minibatch) + minibatches.append(minibatch) + if end == len(sorted_data): + break + start = end + + # for debugging + if num_batches > 0: + minibatches = minibatches[:num_batches] + logging.info("# minibatches: " + str(len(minibatches))) + return minibatches + + +BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"] +BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"] + + +def make_batchset( + data, + batch_size=0, + max_length_in=float("inf"), + max_length_out=float("inf"), + num_batches=0, + min_batch_size=1, + shortest_first=False, + batch_sort_key="input", + swap_io=False, + mt=False, + count="auto", + batch_bins=0, + batch_frames_in=0, + batch_frames_out=0, + batch_frames_inout=0, + iaxis=0, + oaxis=0, +): + """Make batch set from json dictionary + + if utts have "category" value, + + >>> data = {'utt1': {'category': 'A', 'input': ...}, + ... 'utt2': {'category': 'B', 'input': ...}, + ... 'utt3': {'category': 'B', 'input': ...}, + ... 'utt4': {'category': 'A', 'input': ...}} + >>> make_batchset(data, batchsize=2, ...) + [[('utt1', ...), ('utt4', ...)], [('utt2', ...), ('utt3': ...)]] + + Note that if any utts doesn't have "category", + perform as same as batchfy_by_{count} + + :param Dict[str, Dict[str, Any]] data: dictionary loaded from data.json + :param int batch_size: maximum number of sequences in a minibatch. + :param int batch_bins: maximum number of bins (frames x dim) in a minibatch. + :param int batch_frames_in: maximum number of input frames in a minibatch. + :param int batch_frames_out: maximum number of output frames in a minibatch. + :param int batch_frames_out: maximum number of input+output frames in a minibatch. + :param str count: strategy to count maximum size of batch. + For choices, see espnet.asr.batchfy.BATCH_COUNT_CHOICES + + :param int max_length_in: maximum length of input to decide adaptive batch size + :param int max_length_out: maximum length of output to decide adaptive batch size + :param int num_batches: # number of batches to use (for debug) + :param int min_batch_size: minimum batch size (for multi-gpu) + :param bool shortest_first: Sort from batch with shortest samples + to longest if true, otherwise reverse + :param str batch_sort_key: how to sort data before creating minibatches + ["input", "output", "shuffle"] + :param bool swap_io: if True, use "input" as output and "output" + as input in `data` dict + :param bool mt: if True, use 0-axis of "output" as output and 1-axis of "output" + as input in `data` dict + :param int iaxis: dimension to access input + (for ASR, TTS iaxis=0, for MT iaxis="1".) + :param int oaxis: dimension to access output (for ASR, TTS, MT oaxis=0, + reserved for future research, -1 means all axis.) 
+ :return: List[List[Tuple[str, dict]]] list of batches + """ + + # check args + if count not in BATCH_COUNT_CHOICES: + raise ValueError( + f"arg 'count' ({count}) should be one of {BATCH_COUNT_CHOICES}" + ) + if batch_sort_key not in BATCH_SORT_KEY_CHOICES: + raise ValueError( + f"arg 'batch_sort_key' ({batch_sort_key}) should be " + f"one of {BATCH_SORT_KEY_CHOICES}" + ) + + # TODO(karita): remove this by creating converter from ASR to TTS json format + batch_sort_axis = 0 + if swap_io: + # for TTS + ikey = "output" + okey = "input" + if batch_sort_key == "input": + batch_sort_key = "output" + elif batch_sort_key == "output": + batch_sort_key = "input" + elif mt: + # for MT + ikey = "output" + okey = "output" + batch_sort_key = "output" + batch_sort_axis = 1 + assert iaxis == 1 + assert oaxis == 0 + # NOTE: input is json['output'][1] and output is json['output'][0] + else: + ikey = "input" + okey = "output" + + if count == "auto": + if batch_size != 0: + count = "seq" + elif batch_bins != 0: + count = "bin" + elif batch_frames_in != 0 or batch_frames_out != 0 or batch_frames_inout != 0: + count = "frame" + else: + raise ValueError( + f"cannot detect `count` manually set one of {BATCH_COUNT_CHOICES}" + ) + logging.info(f"count is auto detected as {count}") + + if count != "seq" and batch_sort_key == "shuffle": + raise ValueError("batch_sort_key=shuffle is only available if batch_count=seq") + + category2data = {} # Dict[str, dict] + for k, v in data.items(): + category2data.setdefault(v.get("category"), {})[k] = v + + batches_list = [] # List[List[List[Tuple[str, dict]]]] + for d in category2data.values(): + if batch_sort_key == "shuffle": + batches = batchfy_shuffle( + d, batch_size, min_batch_size, num_batches, shortest_first + ) + batches_list.append(batches) + continue + + # sort it by input lengths (long to short) + sorted_data = sorted( + d.items(), + key=lambda data: int(data[1][batch_sort_key][batch_sort_axis]["shape"][0]), + reverse=not shortest_first, + ) + logging.info("# utts: " + str(len(sorted_data))) + if count == "seq": + batches = batchfy_by_seq( + sorted_data, + batch_size=batch_size, + max_length_in=max_length_in, + max_length_out=max_length_out, + min_batch_size=min_batch_size, + shortest_first=shortest_first, + ikey=ikey, + iaxis=iaxis, + okey=okey, + oaxis=oaxis, + ) + if count == "bin": + batches = batchfy_by_bin( + sorted_data, + batch_bins=batch_bins, + min_batch_size=min_batch_size, + shortest_first=shortest_first, + ikey=ikey, + okey=okey, + ) + if count == "frame": + batches = batchfy_by_frame( + sorted_data, + max_frames_in=batch_frames_in, + max_frames_out=batch_frames_out, + max_frames_inout=batch_frames_inout, + min_batch_size=min_batch_size, + shortest_first=shortest_first, + ikey=ikey, + okey=okey, + ) + batches_list.append(batches) + + if len(batches_list) == 1: + batches = batches_list[0] + else: + # Concat list. 
This way is faster than "sum(batch_list, [])" + batches = list(itertools.chain(*batches_list)) + + # for debugging + if num_batches > 0: + batches = batches[:num_batches] + logging.info("# minibatches: " + str(len(batches))) + + # batch: List[List[Tuple[str, dict]]] + return batches diff --git a/espnet/utils/training/evaluator.py b/espnet/utils/training/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..28586dd2f00b7d36be1a89011e6fa99132decb7a --- /dev/null +++ b/espnet/utils/training/evaluator.py @@ -0,0 +1,18 @@ +from chainer.training.extensions import Evaluator + +from espnet.utils.training.tensorboard_logger import TensorboardLogger + + +class BaseEvaluator(Evaluator): + """Base Evaluator in ESPnet""" + + def __call__(self, trainer=None): + ret = super().__call__(trainer) + try: + if trainer is not None: + # force tensorboard to report evaluation log + tb_logger = trainer.get_extension(TensorboardLogger.default_name) + tb_logger(trainer) + except ValueError: + pass + return ret diff --git a/espnet/utils/training/iterators.py b/espnet/utils/training/iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..1cabb1f1fa8523fcb35c381fd257924b6af01862 --- /dev/null +++ b/espnet/utils/training/iterators.py @@ -0,0 +1,101 @@ +import chainer +from chainer.iterators import MultiprocessIterator +from chainer.iterators import SerialIterator +from chainer.iterators import ShuffleOrderSampler +from chainer.training.extension import Extension + +import numpy as np + + +class ShufflingEnabler(Extension): + """An extension enabling shuffling on an Iterator""" + + def __init__(self, iterators): + """Inits the ShufflingEnabler + + :param list[Iterator] iterators: The iterators to enable shuffling on + """ + self.set = False + self.iterators = iterators + + def __call__(self, trainer): + """Calls the enabler on the given iterator + + :param trainer: The iterator + """ + if not self.set: + for iterator in self.iterators: + iterator.start_shuffle() + self.set = True + + +class ToggleableShufflingSerialIterator(SerialIterator): + """A SerialIterator having its shuffling property activated during training""" + + def __init__(self, dataset, batch_size, repeat=True, shuffle=True): + """Init the Iterator + + :param torch.nn.Tensor dataset: The dataset to take batches from + :param int batch_size: The batch size + :param bool repeat: Whether to repeat data (allow multiple epochs) + :param bool shuffle: Whether to shuffle the batches + """ + super(ToggleableShufflingSerialIterator, self).__init__( + dataset, batch_size, repeat, shuffle + ) + + def start_shuffle(self): + """Starts shuffling (or reshuffles) the batches""" + self._shuffle = True + if int(chainer._version.__version__[0]) <= 4: + self._order = np.random.permutation(len(self.dataset)) + else: + self.order_sampler = ShuffleOrderSampler() + self._order = self.order_sampler(np.arange(len(self.dataset)), 0) + + +class ToggleableShufflingMultiprocessIterator(MultiprocessIterator): + """A MultiprocessIterator having its shuffling property activated during training""" + + def __init__( + self, + dataset, + batch_size, + repeat=True, + shuffle=True, + n_processes=None, + n_prefetch=1, + shared_mem=None, + maxtasksperchild=20, + ): + """Init the iterator + + :param torch.nn.Tensor dataset: The dataset to take batches from + :param int batch_size: The batch size + :param bool repeat: Whether to repeat batches or not (enables multiple epochs) + :param bool shuffle: Whether to shuffle the order of the batches + 
:param int n_processes: How many processes to use + :param int n_prefetch: The number of prefetch to use + :param int shared_mem: How many memory to share between processes + :param int maxtasksperchild: Maximum number of tasks per child + """ + super(ToggleableShufflingMultiprocessIterator, self).__init__( + dataset=dataset, + batch_size=batch_size, + repeat=repeat, + shuffle=shuffle, + n_processes=n_processes, + n_prefetch=n_prefetch, + shared_mem=shared_mem, + maxtasksperchild=maxtasksperchild, + ) + + def start_shuffle(self): + """Starts shuffling (or reshuffles) the batches""" + self.shuffle = True + if int(chainer._version.__version__[0]) <= 4: + self._order = np.random.permutation(len(self.dataset)) + else: + self.order_sampler = ShuffleOrderSampler() + self._order = self.order_sampler(np.arange(len(self.dataset)), 0) + self._set_prefetch_state() diff --git a/espnet/utils/training/tensorboard_logger.py b/espnet/utils/training/tensorboard_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..2db49755037cd0a203ddeae478bed9dc13111788 --- /dev/null +++ b/espnet/utils/training/tensorboard_logger.py @@ -0,0 +1,51 @@ +from chainer.training.extension import Extension + + +class TensorboardLogger(Extension): + """A tensorboard logger extension""" + + default_name = "espnet_tensorboard_logger" + + def __init__( + self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0 + ): + """Init the extension + + :param SummaryWriter logger: The logger to use + :param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter + :param entries: The entries to watch + :param int epoch: The starting epoch + """ + self._entries = entries + self._att_reporter = att_reporter + self._ctc_reporter = ctc_reporter + self._logger = logger + self._epoch = epoch + + def __call__(self, trainer): + """Updates the events file with the new values + + :param trainer: The trainer + """ + observation = trainer.observation + for k, v in observation.items(): + if (self._entries is not None) and (k not in self._entries): + continue + if k is not None and v is not None: + if "cupy" in str(type(v)): + v = v.get() + if "cupy" in str(type(k)): + k = k.get() + self._logger.add_scalar(k, v, trainer.updater.iteration) + if ( + self._att_reporter is not None + and trainer.updater.get_iterator("main").epoch > self._epoch + ): + self._epoch = trainer.updater.get_iterator("main").epoch + self._att_reporter.log_attentions(self._logger, trainer.updater.iteration) + if ( + self._ctc_reporter is not None + and trainer.updater.get_iterator("main").epoch > self._epoch + ): + self._epoch = trainer.updater.get_iterator("main").epoch + self._ctc_reporter.log_ctc_probs(self._logger, trainer.updater.iteration) diff --git a/espnet/utils/training/train_utils.py b/espnet/utils/training/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..38f7cd4feb6b9e06f55f90090e6df4e0be4f0155 --- /dev/null +++ b/espnet/utils/training/train_utils.py @@ -0,0 +1,37 @@ +import chainer +import logging + + +def check_early_stop(trainer, epochs): + """Checks an early stopping trigger and warns the user if it's the case + + :param trainer: The trainer used for training + :param epochs: The maximum number of epochs + """ + end_epoch = trainer.updater.get_iterator("main").epoch + if end_epoch < (epochs - 1): + logging.warning( + "Hit early stop at epoch " + + str(end_epoch) + + "\nYou can change the patience or set it to 0 to run all epochs" + ) + + +def set_early_stop(trainer, args, 
is_lm=False): + """Sets the early stop trigger given the program arguments + + :param trainer: The trainer used for training + :param args: The program arguments + :param is_lm: If the trainer is for a LM (epoch instead of epochs) + """ + patience = args.patience + criterion = args.early_stop_criterion + epochs = args.epoch if is_lm else args.epochs + mode = "max" if "acc" in criterion else "min" + if patience > 0: + trainer.stop_trigger = chainer.training.triggers.EarlyStoppingTrigger( + monitor=criterion, + mode=mode, + patients=patience, + max_trigger=(epochs, "epoch"), + ) diff --git a/espnet/vc/pytorch_backend/vc.py b/espnet/vc/pytorch_backend/vc.py new file mode 100644 index 0000000000000000000000000000000000000000..ec35e20c3f5108494d65944e10ef244e2fdeb298 --- /dev/null +++ b/espnet/vc/pytorch_backend/vc.py @@ -0,0 +1,742 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Nagoya University (Wen-Chin Huang) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""E2E VC training / decoding functions.""" + +import copy +import json +import logging +import math +import os +import time + +import chainer +import kaldiio +import numpy as np +import torch + +from chainer import training +from chainer.training import extensions + +from espnet.asr.asr_utils import get_model_conf +from espnet.asr.asr_utils import snapshot_object +from espnet.asr.asr_utils import torch_load +from espnet.asr.asr_utils import torch_resume +from espnet.asr.asr_utils import torch_snapshot +from espnet.asr.pytorch_backend.asr_init import load_trained_modules +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet.nets.tts_interface import TTSInterface +from espnet.utils.dataset import ChainerDataLoader +from espnet.utils.dataset import TransformDataset +from espnet.utils.dynamic_import import dynamic_import +from espnet.utils.io_utils import LoadInputsAndTargets +from espnet.utils.training.batchfy import make_batchset +from espnet.utils.training.evaluator import BaseEvaluator + +from espnet.utils.deterministic_utils import set_deterministic_pytorch +from espnet.utils.training.train_utils import check_early_stop +from espnet.utils.training.train_utils import set_early_stop + +from espnet.utils.training.iterators import ShufflingEnabler + +import matplotlib + +from espnet.utils.training.tensorboard_logger import TensorboardLogger +from tensorboardX import SummaryWriter + +matplotlib.use("Agg") + + +class CustomEvaluator(BaseEvaluator): + """Custom evaluator.""" + + def __init__(self, model, iterator, target, device): + """Initilize module. + + Args: + model (torch.nn.Module): Pytorch model instance. + iterator (chainer.dataset.Iterator): Iterator for validation. + target (chainer.Chain): Dummy chain instance. + device (torch.device): The device to be used in evaluation. + + """ + super(CustomEvaluator, self).__init__(iterator, target) + self.model = model + self.device = device + + # The core part of the update routine can be customized by overriding. 
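+    # evaluate() below switches the model to eval mode, iterates over the
+    # validation set under torch.no_grad(), accumulates whatever the model
+    # reports into a chainer DictSummary, restores train mode, and returns
+    # the per-metric means.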
+ def evaluate(self): + """Evaluate over validation iterator.""" + iterator = self._iterators["main"] + + if self.eval_hook: + self.eval_hook(self) + + if hasattr(iterator, "reset"): + iterator.reset() + it = iterator + else: + it = copy.copy(iterator) + + summary = chainer.reporter.DictSummary() + + self.model.eval() + with torch.no_grad(): + for batch in it: + if isinstance(batch, tuple): + x = tuple(arr.to(self.device) for arr in batch) + else: + x = batch + for key in x.keys(): + x[key] = x[key].to(self.device) + observation = {} + with chainer.reporter.report_scope(observation): + # convert to torch tensor + if isinstance(x, tuple): + self.model(*x) + else: + self.model(**x) + summary.add(observation) + self.model.train() + + return summary.compute_mean() + + +class CustomUpdater(training.StandardUpdater): + """Custom updater.""" + + def __init__(self, model, grad_clip, iterator, optimizer, device, accum_grad=1): + """Initilize module. + + Args: + model (torch.nn.Module) model: Pytorch model instance. + grad_clip (float) grad_clip : The gradient clipping value. + iterator (chainer.dataset.Iterator): Iterator for training. + optimizer (torch.optim.Optimizer) : Pytorch optimizer instance. + device (torch.device): The device to be used in training. + + """ + super(CustomUpdater, self).__init__(iterator, optimizer) + self.model = model + self.grad_clip = grad_clip + self.device = device + self.clip_grad_norm = torch.nn.utils.clip_grad_norm_ + self.accum_grad = accum_grad + self.forward_count = 0 + + # The core part of the update routine can be customized by overriding. + def update_core(self): + """Update model one step.""" + # When we pass one iterator and optimizer to StandardUpdater.__init__, + # they are automatically named 'main'. + train_iter = self.get_iterator("main") + optimizer = self.get_optimizer("main") + + # Get the next batch (a list of json files) + batch = train_iter.next() + if isinstance(batch, tuple): + x = tuple(arr.to(self.device) for arr in batch) + else: + x = batch + for key in x.keys(): + x[key] = x[key].to(self.device) + + # compute loss and gradient + if isinstance(x, tuple): + loss = self.model(*x).mean() / self.accum_grad + else: + loss = self.model(**x).mean() / self.accum_grad + loss.backward() + + # update parameters + self.forward_count += 1 + if self.forward_count != self.accum_grad: + return + self.forward_count = 0 + + # compute the gradient norm to check if it is normal or not + grad_norm = self.clip_grad_norm(self.model.parameters(), self.grad_clip) + logging.debug("grad norm={}".format(grad_norm)) + if math.isnan(grad_norm): + logging.warning("grad norm is nan. Do not update model.") + else: + optimizer.step() + optimizer.zero_grad() + + def update(self): + """Run update function.""" + self.update_core() + if self.forward_count == 0: + self.iteration += 1 + + +class CustomConverter(object): + """Custom converter.""" + + def __init__(self): + """Initilize module.""" + # NOTE: keep as class for future development + pass + + def __call__(self, batch, device=torch.device("cpu")): + """Convert a given batch. + + Args: + batch (list): List of ndarrays. + device (torch.device): The device to be send. + + Returns: + dict: Dict of converted tensors. 
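+                The dict always contains "xs", "ilens", "ys", "labels" and
+                "olens"; "spembs" and "extras" are added only when speaker
+                embeddings or a second target are provided.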
+ + Examples: + >>> batch = [([np.arange(5), np.arange(3)], + [np.random.randn(8, 2), np.random.randn(4, 2)], + None, None)] + >>> conveter = CustomConverter() + >>> conveter(batch, torch.device("cpu")) + {'xs': tensor([[0, 1, 2, 3, 4], + [0, 1, 2, 0, 0]]), + 'ilens': tensor([5, 3]), + 'ys': tensor([[[-0.4197, -1.1157], + [-1.5837, -0.4299], + [-2.0491, 0.9215], + [-2.4326, 0.8891], + [ 1.2323, 1.7388], + [-0.3228, 0.6656], + [-0.6025, 1.3693], + [-1.0778, 1.3447]], + [[ 0.1768, -0.3119], + [ 0.4386, 2.5354], + [-1.2181, -0.5918], + [-0.6858, -0.8843], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000], + [ 0.0000, 0.0000]]]), + 'labels': tensor([[0., 0., 0., 0., 0., 0., 0., 1.], + [0., 0., 0., 1., 1., 1., 1., 1.]]), + 'olens': tensor([8, 4])} + + """ + # batch should be located in list + assert len(batch) == 1 + xs, ys, spembs, extras = batch[0] + + # get list of lengths (must be tensor for DataParallel) + ilens = torch.from_numpy(np.array([x.shape[0] for x in xs])).long().to(device) + olens = torch.from_numpy(np.array([y.shape[0] for y in ys])).long().to(device) + + # perform padding and conversion to tensor + xs = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device) + ys = pad_list([torch.from_numpy(y).float() for y in ys], 0).to(device) + + # make labels for stop prediction + labels = ys.new_zeros(ys.size(0), ys.size(1)) + for i, l in enumerate(olens): + labels[i, l - 1 :] = 1.0 + + # prepare dict + new_batch = { + "xs": xs, + "ilens": ilens, + "ys": ys, + "labels": labels, + "olens": olens, + } + + # load speaker embedding + if spembs is not None: + spembs = torch.from_numpy(np.array(spembs)).float() + new_batch["spembs"] = spembs.to(device) + + # load second target + if extras is not None: + extras = pad_list([torch.from_numpy(extra).float() for extra in extras], 0) + new_batch["extras"] = extras.to(device) + + return new_batch + + +def train(args): + """Train E2E VC model.""" + set_deterministic_pytorch(args) + + # check cuda availability + if not torch.cuda.is_available(): + logging.warning("cuda is not available") + + # get input and output dimension info + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + utts = list(valid_json.keys()) + + # In TTS, this is reversed, but not in VC. 
See `espnet.utils.training.batchfy` + idim = int(valid_json[utts[0]]["input"][0]["shape"][1]) + odim = int(valid_json[utts[0]]["output"][0]["shape"][1]) + logging.info("#input dims : " + str(idim)) + logging.info("#output dims: " + str(odim)) + + # get extra input and output dimenstion + if args.use_speaker_embedding: + args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0]) + else: + args.spk_embed_dim = None + if args.use_second_target: + args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1]) + else: + args.spc_dim = None + + # write model config + if not os.path.exists(args.outdir): + os.makedirs(args.outdir) + model_conf = args.outdir + "/model.json" + with open(model_conf, "wb") as f: + logging.info("writing a model config file to" + model_conf) + f.write( + json.dumps( + (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True + ).encode("utf_8") + ) + for key in sorted(vars(args).keys()): + logging.info("ARGS: " + key + ": " + str(vars(args)[key])) + + # specify model architecture + if args.enc_init is not None or args.dec_init is not None: + model = load_trained_modules(idim, odim, args, TTSInterface) + else: + model_class = dynamic_import(args.model_module) + model = model_class(idim, odim, args) + assert isinstance(model, TTSInterface) + logging.info(model) + reporter = model.reporter + + # freeze modules, if specified + if args.freeze_mods: + for mod, param in model.named_parameters(): + if any(mod.startswith(key) for key in args.freeze_mods): + logging.info("freezing %s" % mod) + param.requires_grad = False + + for mod, param in model.named_parameters(): + if not param.requires_grad: + logging.info("Frozen module %s" % mod) + + # check the use of multi-gpu + if args.ngpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu))) + if args.batch_size != 0: + logging.warning( + "batch size is automatically increased (%d -> %d)" + % (args.batch_size, args.batch_size * args.ngpu) + ) + args.batch_size *= args.ngpu + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + model = model.to(device) + + logging.warning( + "num. model params: {:,} (num. 
trained: {:,} ({:.1f}%))".format( + sum(p.numel() for p in model.parameters()), + sum(p.numel() for p in model.parameters() if p.requires_grad), + sum(p.numel() for p in model.parameters() if p.requires_grad) + * 100.0 + / sum(p.numel() for p in model.parameters()), + ) + ) + + # Setup an optimizer + if args.opt == "adam": + optimizer = torch.optim.Adam( + model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay + ) + elif args.opt == "noam": + from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt + + optimizer = get_std_opt( + model, args.adim, args.transformer_warmup_steps, args.transformer_lr + ) + elif args.opt == "lamb": + from pytorch_lamb import Lamb + + optimizer = Lamb( + model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.999) + ) + else: + raise NotImplementedError("unknown optimizer: " + args.opt) + + # FIXME: TOO DIRTY HACK + setattr(optimizer, "target", reporter) + setattr(optimizer, "serialize", lambda s: reporter.serialize(s)) + + # read json data + with open(args.train_json, "rb") as f: + train_json = json.load(f)["utts"] + with open(args.valid_json, "rb") as f: + valid_json = json.load(f)["utts"] + + use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0 + if use_sortagrad: + args.batch_sort_key = "input" + # make minibatch list (variable length) + train_batchset = make_batchset( + train_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + batch_sort_key=args.batch_sort_key, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + shortest_first=use_sortagrad, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + swap_io=False, + iaxis=0, + oaxis=0, + ) + valid_batchset = make_batchset( + valid_json, + args.batch_size, + args.maxlen_in, + args.maxlen_out, + args.minibatches, + batch_sort_key=args.batch_sort_key, + min_batch_size=args.ngpu if args.ngpu > 1 else 1, + count=args.batch_count, + batch_bins=args.batch_bins, + batch_frames_in=args.batch_frames_in, + batch_frames_out=args.batch_frames_out, + batch_frames_inout=args.batch_frames_inout, + swap_io=False, + iaxis=0, + oaxis=0, + ) + + load_tr = LoadInputsAndTargets( + mode="vc", + use_speaker_embedding=args.use_speaker_embedding, + use_second_target=args.use_second_target, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": True}, # Switch the mode of preprocessing + keep_all_data_on_mem=args.keep_all_data_on_mem, + ) + + load_cv = LoadInputsAndTargets( + mode="vc", + use_speaker_embedding=args.use_speaker_embedding, + use_second_target=args.use_second_target, + preprocess_conf=args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + keep_all_data_on_mem=args.keep_all_data_on_mem, + ) + + converter = CustomConverter() + # hack to make batchsize argument as 1 + # actual bathsize is included in a list + train_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset( + train_batchset, lambda data: converter([load_tr(data)]) + ), + batch_size=1, + num_workers=args.num_iter_processes, + shuffle=not use_sortagrad, + collate_fn=lambda x: x[0], + ) + } + valid_iter = { + "main": ChainerDataLoader( + dataset=TransformDataset( + valid_batchset, lambda data: converter([load_cv(data)]) + ), + batch_size=1, + shuffle=False, + collate_fn=lambda x: x[0], + num_workers=args.num_iter_processes, + ) + } + + # Set up a trainer + updater = CustomUpdater( + model, 
args.grad_clip, train_iter, optimizer, device, args.accum_grad + ) + trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir) + + # Resume from a snapshot + if args.resume: + logging.info("resumed from %s" % args.resume) + torch_resume(args.resume, trainer) + + # set intervals + eval_interval = (args.eval_interval_epochs, "epoch") + save_interval = (args.save_interval_epochs, "epoch") + report_interval = (args.report_interval_iters, "iteration") + + # Evaluate the model with the test dataset for each epoch + trainer.extend( + CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval + ) + + # Save snapshot for each epoch + trainer.extend(torch_snapshot(), trigger=save_interval) + + # Save best models + trainer.extend( + snapshot_object(model, "model.loss.best"), + trigger=training.triggers.MinValueTrigger( + "validation/main/loss", trigger=eval_interval + ), + ) + + # Save attention figure for each epoch + if args.num_save_attention > 0: + data = sorted( + list(valid_json.items())[: args.num_save_attention], + key=lambda x: int(x[1]["input"][0]["shape"][1]), + reverse=True, + ) + if hasattr(model, "module"): + att_vis_fn = model.module.calculate_all_attentions + plot_class = model.module.attention_plot_class + else: + att_vis_fn = model.calculate_all_attentions + plot_class = model.attention_plot_class + att_reporter = plot_class( + att_vis_fn, + data, + args.outdir + "/att_ws", + converter=converter, + transform=load_cv, + device=device, + reverse=True, + ) + trainer.extend(att_reporter, trigger=eval_interval) + else: + att_reporter = None + + # Make a plot for training and validation values + if hasattr(model, "module"): + base_plot_keys = model.module.base_plot_keys + else: + base_plot_keys = model.base_plot_keys + plot_keys = [] + for key in base_plot_keys: + plot_key = ["main/" + key, "validation/main/" + key] + trainer.extend( + extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"), + trigger=eval_interval, + ) + plot_keys += plot_key + trainer.extend( + extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"), + trigger=eval_interval, + ) + + # Write a log of evaluation statistics for each epoch + trainer.extend(extensions.LogReport(trigger=report_interval)) + report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys + trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval) + trainer.extend(extensions.ProgressBar(), trigger=report_interval) + + set_early_stop(trainer, args) + if args.tensorboard_dir is not None and args.tensorboard_dir != "": + writer = SummaryWriter(args.tensorboard_dir) + trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval) + + if use_sortagrad: + trainer.extend( + ShufflingEnabler([train_iter]), + trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"), + ) + + # Run the training + trainer.run() + check_early_stop(trainer, args.epochs) + + +@torch.no_grad() +def decode(args): + """Decode with E2E VC model.""" + set_deterministic_pytorch(args) + # read training config + idim, odim, train_args = get_model_conf(args.model, args.model_conf) + + # show arguments + for key in sorted(vars(args).keys()): + logging.info("args: " + key + ": " + str(vars(args)[key])) + + # define model + model_class = dynamic_import(train_args.model_module) + model = model_class(idim, odim, train_args) + assert isinstance(model, TTSInterface) + logging.info(model) + + # load trained model parameters + logging.info("reading model parameters from " + 
args.model) + torch_load(args.model, model) + model.eval() + + # set torch device + device = torch.device("cuda" if args.ngpu > 0 else "cpu") + model = model.to(device) + + # read json data + with open(args.json, "rb") as f: + js = json.load(f)["utts"] + + # check directory + outdir = os.path.dirname(args.out) + if len(outdir) != 0 and not os.path.exists(outdir): + os.makedirs(outdir) + + load_inputs_and_targets = LoadInputsAndTargets( + mode="vc", + load_output=False, + sort_in_input_length=False, + use_speaker_embedding=train_args.use_speaker_embedding, + preprocess_conf=train_args.preprocess_conf + if args.preprocess_conf is None + else args.preprocess_conf, + preprocess_args={"train": False}, # Switch the mode of preprocessing + ) + + # define function for plot prob and att_ws + def _plot_and_save(array, figname, figsize=(6, 4), dpi=150): + import matplotlib.pyplot as plt + + shape = array.shape + if len(shape) == 1: + # for eos probability + plt.figure(figsize=figsize, dpi=dpi) + plt.plot(array) + plt.xlabel("Frame") + plt.ylabel("Probability") + plt.ylim([0, 1]) + elif len(shape) == 2: + # for tacotron 2 attention weights, whose shape is (out_length, in_length) + plt.figure(figsize=figsize, dpi=dpi) + plt.imshow(array, aspect="auto") + plt.xlabel("Input") + plt.ylabel("Output") + elif len(shape) == 4: + # for transformer attention weights, + # whose shape is (#leyers, #heads, out_length, in_length) + plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi) + for idx1, xs in enumerate(array): + for idx2, x in enumerate(xs, 1): + plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2) + plt.imshow(x, aspect="auto") + plt.xlabel("Input") + plt.ylabel("Output") + else: + raise NotImplementedError("Support only from 1D to 4D array.") + plt.tight_layout() + if not os.path.exists(os.path.dirname(figname)): + # NOTE: exist_ok = True is needed for parallel process decoding + os.makedirs(os.path.dirname(figname), exist_ok=True) + plt.savefig(figname) + plt.close() + + # define function to calculate focus rate + # (see section 3.3 in https://arxiv.org/abs/1905.09263) + def _calculate_focus_rete(att_ws): + if att_ws is None: + # fastspeech case -> None + return 1.0 + elif len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + return float(att_ws.max(dim=-1)[0].mean()) + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + return float(att_ws.max(dim=-1)[0].mean(dim=-1).max()) + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + + # define function to convert attention to duration + def _convert_att_to_duration(att_ws): + if len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + pass + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + # get the most diagonal head according to focus rate + att_ws = torch.cat( + [att_w for att_w in att_ws], dim=0 + ) # (#heads * #layers, L, T) + diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1) # (#heads * #layers,) + diagonal_head_idx = diagonal_scores.argmax() + att_ws = att_ws[diagonal_head_idx] # (L, T) + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + # calculate duration from 2d attention weight + durations = torch.stack( + [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])] + ) + return durations.view(-1, 1).float() + + # define writer instances + feat_writer = kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(o=args.out)) + if args.save_durations: + dur_writer = kaldiio.WriteHelper( + 
"ark,scp:{o}.ark,{o}.scp".format(o=args.out.replace("feats", "durations")) + ) + if args.save_focus_rates: + fr_writer = kaldiio.WriteHelper( + "ark,scp:{o}.ark,{o}.scp".format(o=args.out.replace("feats", "focus_rates")) + ) + + # start decoding + for idx, utt_id in enumerate(js.keys()): + # setup inputs + batch = [(utt_id, js[utt_id])] + data = load_inputs_and_targets(batch) + x = torch.FloatTensor(data[0][0]).to(device) + spemb = None + if train_args.use_speaker_embedding: + spemb = torch.FloatTensor(data[1][0]).to(device) + + # decode and write + start_time = time.time() + outs, probs, att_ws = model.inference(x, args, spemb=spemb) + logging.info( + "inference speed = %.1f frames / sec." + % (int(outs.size(0)) / (time.time() - start_time)) + ) + if outs.size(0) == x.size(0) * args.maxlenratio: + logging.warning("output length reaches maximum length (%s)." % utt_id) + focus_rate = _calculate_focus_rete(att_ws) + logging.info( + "(%d/%d) %s (size: %d->%d, focus rate: %.3f)" + % (idx + 1, len(js.keys()), utt_id, x.size(0), outs.size(0), focus_rate) + ) + feat_writer[utt_id] = outs.cpu().numpy() + if args.save_durations: + ds = _convert_att_to_duration(att_ws) + dur_writer[utt_id] = ds.cpu().numpy() + if args.save_focus_rates: + fr_writer[utt_id] = np.array(focus_rate).reshape(1, 1) + + # plot and save prob and att_ws + if probs is not None: + _plot_and_save( + probs.cpu().numpy(), + os.path.dirname(args.out) + "/probs/%s_prob.png" % utt_id, + ) + if att_ws is not None: + _plot_and_save( + att_ws.cpu().numpy(), + os.path.dirname(args.out) + "/att_ws/%s_att_ws.png" % utt_id, + ) + + # close file object + feat_writer.close() + if args.save_durations: + dur_writer.close() + if args.save_focus_rates: + fr_writer.close() diff --git a/espnet/version.txt b/espnet/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..7e310bae19960a3c44b9f9095f1f95b1e4c49ad9 --- /dev/null +++ b/espnet/version.txt @@ -0,0 +1 @@ +0.9.9 diff --git a/espnet2/__init__.py b/espnet2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c92d2b7329f4e1164a2932fc149606492b87219 --- /dev/null +++ b/espnet2/__init__.py @@ -0,0 +1,3 @@ +"""Initialize espnet2 package.""" + +from espnet import __version__ # NOQA diff --git a/espnet2/asr/__init__.py b/espnet2/asr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/ctc.py b/espnet2/asr/ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..c652cff6c7fbf64cca02886fedc3824ede43d204 --- /dev/null +++ b/espnet2/asr/ctc.py @@ -0,0 +1,165 @@ +import logging + +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + + +class CTC(torch.nn.Module): + """CTC module. 
+ + Args: + odim: dimension of outputs + encoder_output_sizse: number of encoder projection units + dropout_rate: dropout rate (0.0 ~ 1.0) + ctc_type: builtin or warpctc + reduce: reduce the CTC loss into a scalar + """ + + def __init__( + self, + odim: int, + encoder_output_sizse: int, + dropout_rate: float = 0.0, + ctc_type: str = "builtin", + reduce: bool = True, + ignore_nan_grad: bool = False, + ): + assert check_argument_types() + super().__init__() + eprojs = encoder_output_sizse + self.dropout_rate = dropout_rate + self.ctc_lo = torch.nn.Linear(eprojs, odim) + self.ctc_type = ctc_type + self.ignore_nan_grad = ignore_nan_grad + + if self.ctc_type == "builtin": + self.ctc_loss = torch.nn.CTCLoss(reduction="none") + elif self.ctc_type == "warpctc": + import warpctc_pytorch as warp_ctc + + if ignore_nan_grad: + raise NotImplementedError( + "ignore_nan_grad option is not supported for warp_ctc" + ) + self.ctc_loss = warp_ctc.CTCLoss(size_average=True, reduce=reduce) + else: + raise ValueError( + f'ctc_type must be "builtin" or "warpctc": {self.ctc_type}' + ) + + self.reduce = reduce + + def loss_fn(self, th_pred, th_target, th_ilen, th_olen) -> torch.Tensor: + if self.ctc_type == "builtin": + th_pred = th_pred.log_softmax(2) + loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) + + if loss.requires_grad and self.ignore_nan_grad: + # ctc_grad: (L, B, O) + ctc_grad = loss.grad_fn(torch.ones_like(loss)) + ctc_grad = ctc_grad.sum([0, 2]) + indices = torch.isfinite(ctc_grad) + size = indices.long().sum() + if size == 0: + # Return as is + logging.warning( + "All samples in this mini-batch got nan grad." + " Returning nan value instead of CTC loss" + ) + elif size != th_pred.size(1): + logging.warning( + f"{th_pred.size(1) - size}/{th_pred.size(1)}" + " samples got nan grad." + " These were ignored for CTC loss." + ) + + # Create mask for target + target_mask = torch.full( + [th_target.size(0)], + 1, + dtype=torch.bool, + device=th_target.device, + ) + s = 0 + for ind, le in enumerate(th_olen): + if not indices[ind]: + target_mask[s : s + le] = 0 + s += le + + # Calc loss again using maksed data + loss = self.ctc_loss( + th_pred[:, indices, :], + th_target[target_mask], + th_ilen[indices], + th_olen[indices], + ) + else: + size = th_pred.size(1) + + if self.reduce: + # Batch-size average + loss = loss.sum() / size + else: + loss = loss / size + return loss + + elif self.ctc_type == "warpctc": + # warpctc only supports float32 + th_pred = th_pred.to(dtype=torch.float32) + + th_target = th_target.cpu().int() + th_ilen = th_ilen.cpu().int() + th_olen = th_olen.cpu().int() + loss = self.ctc_loss(th_pred, th_target, th_ilen, th_olen) + if self.reduce: + # NOTE: sum() is needed to keep consistency since warpctc + # return as tensor w/ shape (1,) + # but builtin return as tensor w/o shape (scalar). + loss = loss.sum() + return loss + else: + raise NotImplementedError + + def forward(self, hs_pad, hlens, ys_pad, ys_lens): + """Calculate CTC loss. 
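+
+        Example (a minimal sketch; shapes and values are illustrative only):
+
+            >>> ctc = CTC(odim=50, encoder_output_sizse=256)
+            >>> hs_pad = torch.randn(4, 100, 256)                 # (B, Tmax, D)
+            >>> hlens = torch.full((4,), 100, dtype=torch.long)   # (B,)
+            >>> ys_pad = torch.randint(1, 50, (4, 20))            # (B, Lmax)
+            >>> ys_lens = torch.full((4,), 20, dtype=torch.long)  # (B,)
+            >>> loss = ctc(hs_pad, hlens, ys_pad, ys_lens)        # scalar tensor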
+ + Args: + hs_pad: batch of padded hidden state sequences (B, Tmax, D) + hlens: batch of lengths of hidden state sequences (B) + ys_pad: batch of padded character id sequence tensor (B, Lmax) + ys_lens: batch of lengths of character sequence (B) + """ + # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab) + ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate)) + # ys_hat: (B, L, D) -> (L, B, D) + ys_hat = ys_hat.transpose(0, 1) + + # (B, L) -> (BxL,) + ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)]) + + loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to( + device=hs_pad.device, dtype=hs_pad.dtype + ) + + return loss + + def log_softmax(self, hs_pad): + """log_softmax of frame activations + + Args: + Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim) + """ + return F.log_softmax(self.ctc_lo(hs_pad), dim=2) + + def argmax(self, hs_pad): + """argmax of frame activations + + Args: + torch.Tensor hs_pad: 3d tensor (B, Tmax, eprojs) + Returns: + torch.Tensor: argmax applied 2d tensor (B, Tmax) + """ + return torch.argmax(self.ctc_lo(hs_pad), dim=2) diff --git a/espnet2/asr/decoder/__init__.py b/espnet2/asr/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/decoder/abs_decoder.py b/espnet2/asr/decoder/abs_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4ad18d5e36865e15b8889857fb8e463702eec42c --- /dev/null +++ b/espnet2/asr/decoder/abs_decoder.py @@ -0,0 +1,19 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + +from espnet.nets.scorer_interface import ScorerInterface + + +class AbsDecoder(torch.nn.Module, ScorerInterface, ABC): + @abstractmethod + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/asr/decoder/rnn_decoder.py b/espnet2/asr/decoder/rnn_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fc938225f3571e531849418bb075f23adfdea7a1 --- /dev/null +++ b/espnet2/asr/decoder/rnn_decoder.py @@ -0,0 +1,334 @@ +import random + +import numpy as np +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet.nets.pytorch_backend.rnn.attentions import initial_att +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.utils.get_default_kwargs import get_default_kwargs + + +def build_attention_list( + eprojs: int, + dunits: int, + atype: str = "location", + num_att: int = 1, + num_encs: int = 1, + aheads: int = 4, + adim: int = 320, + awin: int = 5, + aconv_chans: int = 10, + aconv_filts: int = 100, + han_mode: bool = False, + han_type=None, + han_heads: int = 4, + han_dim: int = 320, + han_conv_chans: int = -1, + han_conv_filts: int = 100, + han_win: int = 5, +): + + att_list = torch.nn.ModuleList() + if num_encs == 1: + for i in range(num_att): + att = initial_att( + atype, + eprojs, + dunits, + aheads, + adim, + awin, + aconv_chans, + aconv_filts, + ) + att_list.append(att) + elif num_encs > 1: # no multi-speaker mode + if han_mode: + att = initial_att( + han_type, + eprojs, + dunits, + han_heads, + han_dim, + han_win, + han_conv_chans, + 
han_conv_filts, + han_mode=True, + ) + return att + else: + att_list = torch.nn.ModuleList() + for idx in range(num_encs): + att = initial_att( + atype[idx], + eprojs, + dunits, + aheads[idx], + adim[idx], + awin[idx], + aconv_chans[idx], + aconv_filts[idx], + ) + att_list.append(att) + else: + raise ValueError( + "Number of encoders needs to be more than one. {}".format(num_encs) + ) + return att_list + + +class RNNDecoder(AbsDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + rnn_type: str = "lstm", + num_layers: int = 1, + hidden_size: int = 320, + sampling_probability: float = 0.0, + dropout: float = 0.0, + context_residual: bool = False, + replace_sos: bool = False, + num_encs: int = 1, + att_conf: dict = get_default_kwargs(build_attention_list), + ): + # FIXME(kamo): The parts of num_spk should be refactored more more more + assert check_argument_types() + if rnn_type not in {"lstm", "gru"}: + raise ValueError(f"Not supported: rnn_type={rnn_type}") + + super().__init__() + eprojs = encoder_output_size + self.dtype = rnn_type + self.dunits = hidden_size + self.dlayers = num_layers + self.context_residual = context_residual + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.odim = vocab_size + self.sampling_probability = sampling_probability + self.dropout = dropout + self.num_encs = num_encs + + # for multilingual translation + self.replace_sos = replace_sos + + self.embed = torch.nn.Embedding(vocab_size, hidden_size) + self.dropout_emb = torch.nn.Dropout(p=dropout) + + self.decoder = torch.nn.ModuleList() + self.dropout_dec = torch.nn.ModuleList() + self.decoder += [ + torch.nn.LSTMCell(hidden_size + eprojs, hidden_size) + if self.dtype == "lstm" + else torch.nn.GRUCell(hidden_size + eprojs, hidden_size) + ] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + for _ in range(1, self.dlayers): + self.decoder += [ + torch.nn.LSTMCell(hidden_size, hidden_size) + if self.dtype == "lstm" + else torch.nn.GRUCell(hidden_size, hidden_size) + ] + self.dropout_dec += [torch.nn.Dropout(p=dropout)] + # NOTE: dropout is applied only for the vertical connections + # see https://arxiv.org/pdf/1409.2329.pdf + + if context_residual: + self.output = torch.nn.Linear(hidden_size + eprojs, vocab_size) + else: + self.output = torch.nn.Linear(hidden_size, vocab_size) + + self.att_list = build_attention_list( + eprojs=eprojs, dunits=hidden_size, **att_conf + ) + + def zero_state(self, hs_pad): + return hs_pad.new_zeros(hs_pad.size(0), self.dunits) + + def rnn_forward(self, ey, z_list, c_list, z_prev, c_prev): + if self.dtype == "lstm": + z_list[0], c_list[0] = self.decoder[0](ey, (z_prev[0], c_prev[0])) + for i in range(1, self.dlayers): + z_list[i], c_list[i] = self.decoder[i]( + self.dropout_dec[i - 1](z_list[i - 1]), + (z_prev[i], c_prev[i]), + ) + else: + z_list[0] = self.decoder[0](ey, z_prev[0]) + for i in range(1, self.dlayers): + z_list[i] = self.decoder[i]( + self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i] + ) + return z_list, c_list + + def forward(self, hs_pad, hlens, ys_in_pad, ys_in_lens, strm_idx=0): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + hs_pad = [hs_pad] + hlens = [hlens] + + # attention index for the attention module + # in SPA (speaker parallel attention), + # att_idx is used to select attention module. In other cases, it is 0. 
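+        # Overall flow below: embed the ground-truth tokens, then for every
+        # output step compute an attention context over the encoder states,
+        # feed [embedding; context] (or a scheduled-sampling token) through the
+        # stacked RNN cells, and finally project the collected hidden states to
+        # vocabulary logits, zeroing positions beyond each target length.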
+ att_idx = min(strm_idx, len(self.att_list) - 1) + + # hlens should be list of list of integer + hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)] + + # get dim, length info + olength = ys_in_pad.size(1) + + # initialization + c_list = [self.zero_state(hs_pad[0])] + z_list = [self.zero_state(hs_pad[0])] + for _ in range(1, self.dlayers): + c_list.append(self.zero_state(hs_pad[0])) + z_list.append(self.zero_state(hs_pad[0])) + z_all = [] + if self.num_encs == 1: + att_w = None + self.att_list[att_idx].reset() # reset pre-computation of h + else: + att_w_list = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * self.num_encs # atts + for idx in range(self.num_encs + 1): + # reset pre-computation of h in atts and han + self.att_list[idx].reset() + + # pre-computation of embedding + eys = self.dropout_emb(self.embed(ys_in_pad)) # utt x olen x zdim + + # loop for an output sequence + for i in range(olength): + if self.num_encs == 1: + att_c, att_w = self.att_list[att_idx]( + hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w + ) + else: + for idx in range(self.num_encs): + att_c_list[idx], att_w_list[idx] = self.att_list[idx]( + hs_pad[idx], + hlens[idx], + self.dropout_dec[0](z_list[0]), + att_w_list[idx], + ) + hs_pad_han = torch.stack(att_c_list, dim=1) + hlens_han = [self.num_encs] * len(ys_in_pad) + att_c, att_w_list[self.num_encs] = self.att_list[self.num_encs]( + hs_pad_han, + hlens_han, + self.dropout_dec[0](z_list[0]), + att_w_list[self.num_encs], + ) + if i > 0 and random.random() < self.sampling_probability: + z_out = self.output(z_all[-1]) + z_out = np.argmax(z_out.detach().cpu(), axis=1) + z_out = self.dropout_emb(self.embed(to_device(self, z_out))) + ey = torch.cat((z_out, att_c), dim=1) # utt x (zdim + hdim) + else: + # utt x (zdim + hdim) + ey = torch.cat((eys[:, i, :], att_c), dim=1) + z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list) + if self.context_residual: + z_all.append( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) # utt x (zdim + hdim) + else: + z_all.append(self.dropout_dec[-1](z_list[-1])) # utt x (zdim) + + z_all = torch.stack(z_all, dim=1) + z_all = self.output(z_all) + z_all.masked_fill_( + make_pad_mask(ys_in_lens, z_all, 1), + 0, + ) + return z_all, ys_in_lens + + def init_state(self, x): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + x = [x] + + c_list = [self.zero_state(x[0].unsqueeze(0))] + z_list = [self.zero_state(x[0].unsqueeze(0))] + for _ in range(1, self.dlayers): + c_list.append(self.zero_state(x[0].unsqueeze(0))) + z_list.append(self.zero_state(x[0].unsqueeze(0))) + # TODO(karita): support strm_index for `asr_mix` + strm_index = 0 + att_idx = min(strm_index, len(self.att_list) - 1) + if self.num_encs == 1: + a = None + self.att_list[att_idx].reset() # reset pre-computation of h + else: + a = [None] * (self.num_encs + 1) # atts + han + for idx in range(self.num_encs + 1): + # reset pre-computation of h in atts and han + self.att_list[idx].reset() + return dict( + c_prev=c_list[:], + z_prev=z_list[:], + a_prev=a, + workspace=(att_idx, z_list, c_list), + ) + + def score(self, yseq, state, x): + # to support mutiple encoder asr mode, in single encoder mode, + # convert torch.Tensor to List of torch.Tensor + if self.num_encs == 1: + x = [x] + + att_idx, z_list, c_list = state["workspace"] + vy = yseq[-1].unsqueeze(0) + ey = self.dropout_emb(self.embed(vy)) # utt list (1) x zdim + 
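+        # Single beam-search step: attend with the cached attention weights
+        # (`a_prev`), advance the RNN with the cached hidden/cell states
+        # (`z_prev`/`c_prev`), and return the log-probabilities of the next
+        # token together with the updated state dict.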
if self.num_encs == 1: + att_c, att_w = self.att_list[att_idx]( + x[0].unsqueeze(0), + [x[0].size(0)], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"], + ) + else: + att_w = [None] * (self.num_encs + 1) # atts + han + att_c_list = [None] * self.num_encs # atts + for idx in range(self.num_encs): + att_c_list[idx], att_w[idx] = self.att_list[idx]( + x[idx].unsqueeze(0), + [x[idx].size(0)], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"][idx], + ) + h_han = torch.stack(att_c_list, dim=1) + att_c, att_w[self.num_encs] = self.att_list[self.num_encs]( + h_han, + [self.num_encs], + self.dropout_dec[0](state["z_prev"][0]), + state["a_prev"][self.num_encs], + ) + ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim) + z_list, c_list = self.rnn_forward( + ey, z_list, c_list, state["z_prev"], state["c_prev"] + ) + if self.context_residual: + logits = self.output( + torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1) + ) + else: + logits = self.output(self.dropout_dec[-1](z_list[-1])) + logp = F.log_softmax(logits, dim=1).squeeze(0) + return ( + logp, + dict( + c_prev=c_list[:], + z_prev=z_list[:], + a_prev=att_w, + workspace=(att_idx, z_list, c_list), + ), + ) diff --git a/espnet2/asr/decoder/transformer_decoder.py b/espnet2/asr/decoder/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..eb612773c9692515824a19e7d8a1e7a9d36ededf --- /dev/null +++ b/espnet2/asr/decoder/transformer_decoder.py @@ -0,0 +1,521 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Decoder definition.""" +from typing import Any +from typing import List +from typing import Sequence +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder_layer import DecoderLayer +from espnet.nets.pytorch_backend.transformer.dynamic_conv import DynamicConvolution +from espnet.nets.pytorch_backend.transformer.dynamic_conv2d import DynamicConvolution2D +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.lightconv import LightweightConvolution +from espnet.nets.pytorch_backend.transformer.lightconv2d import LightweightConvolution2D +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet2.asr.decoder.abs_decoder import AbsDecoder + + +class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface): + """Base class of Transfomer decoder module. 
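+
+    Concrete subclasses only need to build ``self.decoders``; this base class
+    provides the token embedding, target/memory masking, the full-sequence
+    ``forward``, the incremental ``forward_one_step``, and the beam-search
+    ``score``/``batch_score`` methods.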
+ + Args: + vocab_size: output dim + encoder_output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + self_attention_dropout_rate: dropout rate for attention + input_layer: input layer type + use_output_layer: whether to use output layer + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. x -> x + att(x) + """ + + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + ): + assert check_argument_types() + super().__init__() + attention_dim = encoder_output_size + + if input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(vocab_size, attention_dim), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(vocab_size, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}") + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + if use_output_layer: + self.output_layer = torch.nn.Linear(attention_dim, vocab_size) + else: + self.output_layer = None + + # Must set by the inheritance + self.decoders = None + + def forward( + self, + hs_pad: torch.Tensor, + hlens: torch.Tensor, + ys_in_pad: torch.Tensor, + ys_in_lens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward decoder. + + Args: + hs_pad: encoded memory, float32 (batch, maxlen_in, feat) + hlens: (batch) + ys_in_pad: + input token ids, int64 (batch, maxlen_out) + if input_layer == "embed" + input tensor (batch, maxlen_out, #mels) in the other cases + ys_in_lens: (batch) + Returns: + (tuple): tuple containing: + + x: decoded token score before softmax (batch, maxlen_out, token) + if use_output_layer is True, + olens: (batch, ) + """ + tgt = ys_in_pad + # tgt_mask: (B, 1, L) + tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device) + # m: (1, L, L) + m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0) + # tgt_mask: (B, L, L) + tgt_mask = tgt_mask & m + + memory = hs_pad + memory_mask = (~make_pad_mask(hlens))[:, None, :].to(memory.device) + + x = self.embed(tgt) + x, tgt_mask, memory, memory_mask = self.decoders( + x, tgt_mask, memory, memory_mask + ) + if self.normalize_before: + x = self.after_norm(x) + if self.output_layer is not None: + x = self.output_layer(x) + + olens = tgt_mask.sum(1) + return x, olens + + def forward_one_step( + self, + tgt: torch.Tensor, + tgt_mask: torch.Tensor, + memory: torch.Tensor, + cache: List[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Forward one step. 
+ + Args: + tgt: input token ids, int64 (batch, maxlen_out) + tgt_mask: input token mask, (batch, maxlen_out) + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (include 1.2) + memory: encoded memory, float32 (batch, maxlen_in, feat) + cache: cached output list of (batch, max_time_out-1, size) + Returns: + y, cache: NN output value and cache per `self.decoders`. + y.shape` is (batch, maxlen_out, token) + """ + x = self.embed(tgt) + if cache is None: + cache = [None] * len(self.decoders) + new_cache = [] + for c, decoder in zip(cache, self.decoders): + x, tgt_mask, memory, memory_mask = decoder( + x, tgt_mask, memory, None, cache=c + ) + new_cache.append(x) + + if self.normalize_before: + y = self.after_norm(x[:, -1]) + else: + y = x[:, -1] + if self.output_layer is not None: + y = torch.log_softmax(self.output_layer(y), dim=-1) + + return y, new_cache + + def score(self, ys, state, x): + """Score.""" + ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0) + logp, state = self.forward_one_step( + ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state + ) + return logp.squeeze(0), state + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.decoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0) + logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list + + +class TransformerDecoder(BaseTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + ): + assert check_argument_types() + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + MultiHeadedAttention( + attention_heads, attention_dim, self_attention_dropout_rate + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, 
+ concat_after, + ), + ) + + +class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + "conv_kernel_length must have equal number of values to num_blocks: " + f"{len(conv_kernel_length)} != {num_blocks}" + ) + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + LightweightConvolution( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + "conv_kernel_length must have equal number of values to num_blocks: " + f"{len(conv_kernel_length)} != {num_blocks}" + ) + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + + attention_dim = encoder_output_size + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + LightweightConvolution2D( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class 
DynamicConvolutionTransformerDecoder(BaseTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + "conv_kernel_length must have equal number of values to num_blocks: " + f"{len(conv_kernel_length)} != {num_blocks}" + ) + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + attention_dim = encoder_output_size + + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + DynamicConvolution( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) + + +class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder): + def __init__( + self, + vocab_size: int, + encoder_output_size: int, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + self_attention_dropout_rate: float = 0.0, + src_attention_dropout_rate: float = 0.0, + input_layer: str = "embed", + use_output_layer: bool = True, + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + conv_wshare: int = 4, + conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11), + conv_usebias: int = False, + ): + assert check_argument_types() + if len(conv_kernel_length) != num_blocks: + raise ValueError( + "conv_kernel_length must have equal number of values to num_blocks: " + f"{len(conv_kernel_length)} != {num_blocks}" + ) + super().__init__( + vocab_size=vocab_size, + encoder_output_size=encoder_output_size, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + input_layer=input_layer, + use_output_layer=use_output_layer, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + ) + attention_dim = encoder_output_size + + self.decoders = repeat( + num_blocks, + lambda lnum: DecoderLayer( + attention_dim, + DynamicConvolution2D( + wshare=conv_wshare, + n_feat=attention_dim, + dropout_rate=self_attention_dropout_rate, + kernel_size=conv_kernel_length[lnum], + use_kernel_mask=True, + use_bias=conv_usebias, + ), + MultiHeadedAttention( + attention_heads, attention_dim, src_attention_dropout_rate + ), + PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), + dropout_rate, + normalize_before, + concat_after, + ), + ) diff --git a/espnet2/asr/encoder/__init__.py b/espnet2/asr/encoder/__init__.py new 
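The convolution-based decoders above only swap the self-attention module inside each DecoderLayer; everything else comes from BaseTransformerDecoder. A minimal instantiation sketch (illustrative shapes and values; it assumes the espnet2 modules added in this diff are importable):

import torch
from espnet2.asr.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,
)

vocab_size, enc_dim = 500, 256
decoder = LightweightConvolutionTransformerDecoder(
    vocab_size=vocab_size,
    encoder_output_size=enc_dim,
    num_blocks=6,
    # one kernel size is required per decoder block
    conv_kernel_length=(11, 11, 11, 11, 11, 11),
)

hs_pad = torch.randn(2, 50, enc_dim)              # encoder output (B, T, D)
hlens = torch.tensor([50, 40])
ys_in_pad = torch.randint(0, vocab_size, (2, 7))  # <sos>-prefixed target ids
ys_in_lens = torch.tensor([7, 5])
logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)  # (B, L, vocab)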
file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/encoder/abs_encoder.py b/espnet2/asr/encoder/abs_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1fb7c97c35b5939a01ce9df3db1faf45a19d9938 --- /dev/null +++ b/espnet2/asr/encoder/abs_encoder.py @@ -0,0 +1,21 @@ +from abc import ABC +from abc import abstractmethod +from typing import Optional +from typing import Tuple + +import torch + + +class AbsEncoder(torch.nn.Module, ABC): + @abstractmethod + def output_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError diff --git a/espnet2/asr/encoder/conformer_encoder.py b/espnet2/asr/encoder/conformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9608301196763ac89dc04dffc7c3ca2cae4a5d --- /dev/null +++ b/espnet2/asr/encoder/conformer_encoder.py @@ -0,0 +1,304 @@ +# Copyright 2020 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Conformer encoder definition.""" + +from typing import Optional +from typing import Tuple + +import logging +import torch + +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import get_activation +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 + LegacyRelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 + LegacyRelPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import check_short_utt +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 +from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError +from espnet2.asr.encoder.abs_encoder import AbsEncoder + + +class ConformerEncoder(AbsEncoder): + """Conformer encoder module. + + Args: + input_size (int): Input dimension. + output_size (int): Dimention of attention. + attention_heads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + num_blocks (int): The number of decoder blocks. + dropout_rate (float): Dropout rate. + attention_dropout_rate (float): Dropout rate in attention. 
+ positional_dropout_rate (float): Dropout rate after adding positional encoding. + input_layer (Union[str, torch.nn.Module]): Input layer type. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + If True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + If False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. + rel_pos_type (str): Whether to use the latest relative positional encoding or + the legacy one. The legacy relative positional encoding will be deprecated + in the future. More Details can be found in + https://github.com/espnet/espnet/pull/2816. + encoder_pos_enc_layer_type (str): Encoder positional encoding layer type. + encoder_attn_layer_type (str): Encoder attention layer type. + activation_type (str): Encoder activation function type. + macaron_style (bool): Whether to use macaron style for positionwise layer. + use_cnn_module (bool): Whether to use convolution module. + zero_triu (bool): Whether to zero the upper triangular part of attention matrix. + cnn_module_kernel (int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: str = "conv2d", + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 3, + macaron_style: bool = False, + rel_pos_type: str = "legacy", + pos_enc_layer_type: str = "rel_pos", + selfattention_layer_type: str = "rel_selfattn", + activation_type: str = "swish", + use_cnn_module: bool = True, + zero_triu: bool = False, + cnn_module_kernel: int = 31, + padding_idx: int = -1, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if rel_pos_type == "legacy": + if pos_enc_layer_type == "rel_pos": + pos_enc_layer_type = "legacy_rel_pos" + if selfattention_layer_type == "rel_selfattn": + selfattention_layer_type = "legacy_rel_selfattn" + elif rel_pos_type == "latest": + assert selfattention_layer_type != "legacy_rel_selfattn" + assert pos_enc_layer_type != "legacy_rel_pos" + else: + raise ValueError("unknown rel_pos_type: " + rel_pos_type) + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + elif pos_enc_layer_type == "legacy_rel_pos": + assert selfattention_layer_type == "legacy_rel_selfattn" + pos_enc_class = LegacyRelPositionalEncoding + logging.warning( + "Using legacy_rel_pos and it will be deprecated in the future." 
+ ) + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8( + input_size, + output_size, + dropout_rate, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(output_size, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + elif selfattention_layer_type == "legacy_rel_selfattn": + assert pos_enc_layer_type == "legacy_rel_pos" + encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + ) + logging.warning( + "Using legacy_rel_selfattn and it will be deprecated in the future." 
+ ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + attention_dropout_rate, + zero_triu, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (output_size, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Calculate forward propagation. + + Args: + xs_pad (torch.Tensor): Input tensor (#batch, L, input_size). + ilens (torch.Tensor): Input length (#batch). + prev_states (torch.Tensor): Not to be used now. + + Returns: + torch.Tensor: Output tensor (#batch, L, output_size). + torch.Tensor: Output length (#batch). + torch.Tensor: Not to be used now. + + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + xs_pad, masks = self.encoders(xs_pad, masks) + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None diff --git a/espnet2/asr/encoder/contextual_block_transformer_encoder.py b/espnet2/asr/encoder/contextual_block_transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f5765d45651449c9cff885ff05beb54186e385bf --- /dev/null +++ b/espnet2/asr/encoder/contextual_block_transformer_encoder.py @@ -0,0 +1,327 @@ +# Copyright 2020 Emiru Tsunoo +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" +from typing import Optional +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.contextual_block_encoder_layer import ( + ContextualBlockEncoderLayer, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from 
espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling_without_posenc import ( + Conv2dSubsamplingWOPosEnc, # noqa: H301 +) +from espnet2.asr.encoder.abs_encoder import AbsEncoder +import math + + +class ContextualBlockTransformerEncoder(AbsEncoder): + """Contextual Block Transformer encoder module. + + Details in Tsunoo et al. "Transformer ASR with contextual block processing" + (https://arxiv.org/abs/1910.07204) + + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + block_size: block size for contextual block processing + hop_Size: hop size for block processing + look_ahead: look-ahead size for block_processing + init_average: whether to use average as initial context (otherwise max values) + ctx_pos_enc: whether to use positional encoding to the context vectors + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + block_size: int = 40, + hop_size: int = 16, + look_ahead: int = 16, + init_average: bool = True, + ctx_pos_enc: bool = True, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + self.pos_enc = pos_enc_class(output_size, positional_dropout_rate) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsamplingWOPosEnc( + input_size, output_size, dropout_rate, kernels=[3, 3], strides=[2, 2] + ) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsamplingWOPosEnc( + input_size, output_size, dropout_rate, kernels=[3, 5], strides=[2, 3] + ) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsamplingWOPosEnc( + input_size, + output_size, + dropout_rate, + kernels=[3, 3, 3], + strides=[2, 2, 2], + ) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + ) + elif input_layer is None: + self.embed = None + else: + raise 
ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + self.encoders = repeat( + num_blocks, + lambda lnum: ContextualBlockEncoderLayer( + output_size, + MultiHeadedAttention( + attention_heads, output_size, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + num_blocks, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + # for block processing + self.block_size = block_size + self.hop_size = hop_size + self.look_ahead = look_ahead + self.init_average = init_average + self.ctx_pos_enc = ctx_pos_enc + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if isinstance(self.embed, Conv2dSubsamplingWOPosEnc): + xs_pad, masks = self.embed(xs_pad, masks) + elif self.embed is not None: + xs_pad = self.embed(xs_pad) + + # create empty output container + total_frame_num = xs_pad.size(1) + ys_pad = xs_pad.new_zeros(xs_pad.size()) + + past_size = self.block_size - self.hop_size - self.look_ahead + + # block_size could be 0 meaning infinite + # apply usual encoder for short sequence + if self.block_size == 0 or total_frame_num <= self.block_size: + xs_pad, masks, _, _, _ = self.encoders( + self.pos_enc(xs_pad), masks, None, None + ) + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None + + # start block processing + cur_hop = 0 + block_num = math.ceil( + float(total_frame_num - past_size - self.look_ahead) / float(self.hop_size) + ) + bsize = xs_pad.size(0) + addin = xs_pad.new_zeros( + bsize, block_num, xs_pad.size(-1) + ) # additional context embedding vecctors + + # first step + if self.init_average: # initialize with average value + addin[:, 0, :] = xs_pad.narrow(1, cur_hop, self.block_size).mean(1) + else: # initialize with max value + addin[:, 0, :] = xs_pad.narrow(1, cur_hop, self.block_size).max(1) + cur_hop += self.hop_size + # following steps + while cur_hop + self.block_size < total_frame_num: + if self.init_average: # initialize with average value + addin[:, cur_hop // self.hop_size, :] = xs_pad.narrow( + 1, cur_hop, self.block_size + ).mean(1) + else: # initialize with max value + addin[:, cur_hop // self.hop_size, :] = xs_pad.narrow( + 1, cur_hop, self.block_size + ).max(1) + cur_hop += self.hop_size + # last step + if cur_hop < total_frame_num and cur_hop // self.hop_size < block_num: + if self.init_average: # 
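+        # Chunk layout used below: every chunk holds block_size + 2 positions,
+        # [input context | block frames | context seed], i.e. index 0 carries
+        # the context vector handed over from the previous block and index
+        # block_size + 1 the context embedding of the current block.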
initialize with average value + addin[:, cur_hop // self.hop_size, :] = xs_pad.narrow( + 1, cur_hop, total_frame_num - cur_hop + ).mean(1) + else: # initialize with max value + addin[:, cur_hop // self.hop_size, :] = xs_pad.narrow( + 1, cur_hop, total_frame_num - cur_hop + ).max(1) + + if self.ctx_pos_enc: + addin = self.pos_enc(addin) + + xs_pad = self.pos_enc(xs_pad) + + # set up masks + mask_online = xs_pad.new_zeros( + xs_pad.size(0), block_num, self.block_size + 2, self.block_size + 2 + ) + mask_online.narrow(2, 1, self.block_size + 1).narrow( + 3, 0, self.block_size + 1 + ).fill_(1) + + xs_chunk = xs_pad.new_zeros( + bsize, block_num, self.block_size + 2, xs_pad.size(-1) + ) + + # fill the input + # first step + left_idx = 0 + block_idx = 0 + xs_chunk[:, block_idx, 1 : self.block_size + 1] = xs_pad.narrow( + -2, left_idx, self.block_size + ) + left_idx += self.hop_size + block_idx += 1 + # following steps + while left_idx + self.block_size < total_frame_num and block_idx < block_num: + xs_chunk[:, block_idx, 1 : self.block_size + 1] = xs_pad.narrow( + -2, left_idx, self.block_size + ) + left_idx += self.hop_size + block_idx += 1 + # last steps + last_size = total_frame_num - left_idx + xs_chunk[:, block_idx, 1 : last_size + 1] = xs_pad.narrow( + -2, left_idx, last_size + ) + + # fill the initial context vector + xs_chunk[:, 0, 0] = addin[:, 0] + xs_chunk[:, 1:, 0] = addin[:, 0 : block_num - 1] + xs_chunk[:, :, self.block_size + 1] = addin + + # forward + ys_chunk, mask_online, _, _, _ = self.encoders(xs_chunk, mask_online, xs_chunk) + + # copy output + # first step + offset = self.block_size - self.look_ahead - self.hop_size + 1 + left_idx = 0 + block_idx = 0 + cur_hop = self.block_size - self.look_ahead + ys_pad[:, left_idx:cur_hop] = ys_chunk[:, block_idx, 1 : cur_hop + 1] + left_idx += self.hop_size + block_idx += 1 + # following steps + while left_idx + self.block_size < total_frame_num and block_idx < block_num: + ys_pad[:, cur_hop : cur_hop + self.hop_size] = ys_chunk[ + :, block_idx, offset : offset + self.hop_size + ] + cur_hop += self.hop_size + left_idx += self.hop_size + block_idx += 1 + ys_pad[:, cur_hop:total_frame_num] = ys_chunk[ + :, block_idx, offset : last_size + 1, : + ] + + if self.normalize_before: + ys_pad = self.after_norm(ys_pad) + + olens = masks.squeeze(1).sum(1) + return ys_pad, olens, None diff --git a/espnet2/asr/encoder/rnn_encoder.py b/espnet2/asr/encoder/rnn_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fd57ebfd2d8a123e0e042dfa3f4f5df93e2f0815 --- /dev/null +++ b/espnet2/asr/encoder/rnn_encoder.py @@ -0,0 +1,115 @@ +from typing import Optional +from typing import Sequence +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN +from espnet.nets.pytorch_backend.rnn.encoders import RNNP +from espnet2.asr.encoder.abs_encoder import AbsEncoder + + +class RNNEncoder(AbsEncoder): + """RNNEncoder class. 
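+
+    Wraps the RNN/RNNP encoders from espnet.nets: with ``use_projection=True``
+    an RNNP stack (a projection layer and optional frame subsampling after
+    each recurrent layer) is used, otherwise a plain RNN stack.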
+ + Args: + input_size: The number of expected features in the input + output_size: The number of output features + hidden_size: The number of hidden features + bidirectional: If ``True`` becomes a bidirectional LSTM + use_projection: Use projection layer or not + num_layers: Number of recurrent layers + dropout: dropout probability + + """ + + def __init__( + self, + input_size: int, + rnn_type: str = "lstm", + bidirectional: bool = True, + use_projection: bool = True, + num_layers: int = 4, + hidden_size: int = 320, + output_size: int = 320, + dropout: float = 0.0, + subsample: Optional[Sequence[int]] = (2, 2, 1, 1), + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + self.rnn_type = rnn_type + self.bidirectional = bidirectional + self.use_projection = use_projection + + if rnn_type not in {"lstm", "gru"}: + raise ValueError(f"Not supported rnn_type={rnn_type}") + + if subsample is None: + subsample = np.ones(num_layers + 1, dtype=np.int) + else: + subsample = subsample[:num_layers] + # Append 1 at the beginning because the second or later is used + subsample = np.pad( + np.array(subsample, dtype=np.int), + [1, num_layers - len(subsample)], + mode="constant", + constant_values=1, + ) + + rnn_type = ("b" if bidirectional else "") + rnn_type + if use_projection: + self.enc = torch.nn.ModuleList( + [ + RNNP( + input_size, + num_layers, + hidden_size, + output_size, + subsample, + dropout, + typ=rnn_type, + ) + ] + ) + + else: + self.enc = torch.nn.ModuleList( + [ + RNN( + input_size, + num_layers, + hidden_size, + output_size, + dropout, + typ=rnn_type, + ) + ] + ) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) + current_states.append(states) + + if self.use_projection: + xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0) + else: + xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0) + return xs_pad, ilens, current_states diff --git a/espnet2/asr/encoder/transformer_encoder.py b/espnet2/asr/encoder/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..472339e4c143498f653c2d9d04cba2d5280c0c53 --- /dev/null +++ b/espnet2/asr/encoder/transformer_encoder.py @@ -0,0 +1,185 @@ +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" +from typing import Optional +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # 
noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import check_short_utt +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling6 +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling8 +from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError +from espnet2.asr.encoder.abs_encoder import AbsEncoder + + +class TransformerEncoder(AbsEncoder): + """Transformer encoder module. + + Args: + input_size: input dim + output_size: dimension of attention + attention_heads: the number of heads of multi head attention + linear_units: the number of units of position-wise feed forward + num_blocks: the number of decoder blocks + dropout_rate: dropout rate + attention_dropout_rate: dropout rate in attention + positional_dropout_rate: dropout rate after adding positional encoding + input_layer: input layer type + pos_enc_class: PositionalEncoding or ScaledPositionalEncoding + normalize_before: whether to use layer_norm before the first block + concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. + i.e. x -> x + att(x) + positionwise_layer_type: linear of conv1d + positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.0, + input_layer: Optional[str] = "conv2d", + pos_enc_class=PositionalEncoding, + normalize_before: bool = True, + concat_after: bool = False, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + padding_idx: int = -1, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(input_size, output_size), + torch.nn.LayerNorm(output_size), + torch.nn.Dropout(dropout_rate), + torch.nn.ReLU(), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate) + elif input_layer == "conv2d6": + self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate) + elif input_layer == "conv2d8": + self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx), + pos_enc_class(output_size, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(output_size, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + 
dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + output_size, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + output_size, + MultiHeadedAttention( + attention_heads, output_size, attention_dropout_rate + ), + positionwise_layer(*positionwise_layer_args), + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Embed positions in tensor. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. + Returns: + position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if ( + isinstance(self.embed, Conv2dSubsampling) + or isinstance(self.embed, Conv2dSubsampling6) + or isinstance(self.embed, Conv2dSubsampling8) + ): + short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1)) + if short_status: + raise TooShortUttError( + f"has {xs_pad.size(1)} frames and is too short for subsampling " + + f"(it needs more than {limit_size} frames), return empty results", + xs_pad.size(1), + limit_size, + ) + xs_pad, masks = self.embed(xs_pad, masks) + else: + xs_pad = self.embed(xs_pad) + xs_pad, masks = self.encoders(xs_pad, masks) + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None diff --git a/espnet2/asr/encoder/vgg_rnn_encoder.py b/espnet2/asr/encoder/vgg_rnn_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..8c36c8cf4f2bef3cc1db7f9c0cba3ea9ad902024 --- /dev/null +++ b/espnet2/asr/encoder/vgg_rnn_encoder.py @@ -0,0 +1,106 @@ +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet.nets.e2e_asr_common import get_vgg2l_odim +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN +from espnet.nets.pytorch_backend.rnn.encoders import RNNP +from espnet.nets.pytorch_backend.rnn.encoders import VGG2L +from espnet2.asr.encoder.abs_encoder import AbsEncoder + + +class VGGRNNEncoder(AbsEncoder): + """VGGRNNEncoder class. 
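+
+    The VGG2L front-end (two VGG-style convolution blocks with max-pooling)
+    reduces the frame rate before the recurrent layers, so no additional
+    subsampling is applied inside the RNN/RNNP stack.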
+ + Args: + input_size: The number of expected features in the input + bidirectional: If ``True`` becomes a bidirectional LSTM + use_projection: Use projection layer or not + num_layers: Number of recurrent layers + hidden_size: The number of hidden features + output_size: The number of output features + dropout: dropout probability + + """ + + def __init__( + self, + input_size: int, + rnn_type: str = "lstm", + bidirectional: bool = True, + use_projection: bool = True, + num_layers: int = 4, + hidden_size: int = 320, + output_size: int = 320, + dropout: float = 0.0, + in_channel: int = 1, + ): + assert check_argument_types() + super().__init__() + self._output_size = output_size + self.rnn_type = rnn_type + self.bidirectional = bidirectional + self.use_projection = use_projection + if rnn_type not in {"lstm", "gru"}: + raise ValueError(f"Not supported rnn_type={rnn_type}") + + # Subsample is not used for VGGRNN + subsample = np.ones(num_layers + 1, dtype=np.int) + rnn_type = ("b" if bidirectional else "") + rnn_type + if use_projection: + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNNP( + get_vgg2l_odim(input_size, in_channel=in_channel), + num_layers, + hidden_size, + output_size, + subsample, + dropout, + typ=rnn_type, + ), + ] + ) + + else: + self.enc = torch.nn.ModuleList( + [ + VGG2L(in_channel), + RNN( + get_vgg2l_odim(input_size, in_channel=in_channel), + num_layers, + hidden_size, + output_size, + dropout, + typ=rnn_type, + ), + ] + ) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) + current_states.append(states) + + if self.use_projection: + xs_pad.masked_fill_(make_pad_mask(ilens, xs_pad, 1), 0.0) + else: + xs_pad = xs_pad.masked_fill(make_pad_mask(ilens, xs_pad, 1), 0.0) + return xs_pad, ilens, current_states diff --git a/espnet2/asr/encoder/wav2vec2_encoder.py b/espnet2/asr/encoder/wav2vec2_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a9e6d6e8932885db7d85052a5c2dce552fa8e4 --- /dev/null +++ b/espnet2/asr/encoder/wav2vec2_encoder.py @@ -0,0 +1,165 @@ +# Copyright 2021 Xuankai Chang +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" +import contextlib +import copy +from filelock import FileLock +import logging +import os +from typing import Optional +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet2.asr.encoder.abs_encoder import AbsEncoder + + +class FairSeqWav2Vec2Encoder(AbsEncoder): + """FairSeq Wav2Vec2 encoder module. + + Args: + input_size: input dim + output_size: dimension of attention + w2v_url: url to Wav2Vec2.0 pretrained model + w2v_dir_path: directory to download the Wav2Vec2.0 pretrained model. + normalize_before: whether to use layer_norm before the first block + finetune_last_n_layers: last n layers to be finetuned in Wav2Vec2.0 + 0 means to finetune every layer if freeze_w2v=False. 
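+        freeze_finetune_updates: the wav2vec 2.0 parameters are kept frozen
+            (the encoder runs under ``torch.no_grad()``) for this many forward
+            updates before fine-tuning starts; 0 means fine-tuning starts
+            immediately.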
+ """ + + def __init__( + self, + input_size: int, + w2v_url: str, + w2v_dir_path: str = "./", + output_size: int = 256, + normalize_before: bool = False, + freeze_finetune_updates: int = 0, + ): + assert check_argument_types() + super().__init__() + + if w2v_url != "": + try: + import fairseq + from fairseq.models.wav2vec.wav2vec2 import Wav2Vec2Model + except Exception as e: + print("Error: FairSeq is not properly installed.") + print( + "Please install FairSeq: cd ${MAIN_ROOT}/tools && make fairseq.done" + ) + raise e + + self.w2v_model_path = download_w2v(w2v_url, w2v_dir_path) + + self._output_size = output_size + + models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [self.w2v_model_path], + arg_overrides={"data": w2v_dir_path}, + ) + model = models[0] + + if not isinstance(model, Wav2Vec2Model): + try: + model = model.w2v_encoder.w2v_model + except Exception as e: + print( + "Error: pretrained models should be within: " + "'Wav2Vec2Model, Wav2VecCTC' classes, etc." + ) + raise e + + self.encoders = model + + self.pretrained_params = copy.deepcopy(model.state_dict()) + + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = LayerNorm(output_size) + + if model.cfg.encoder_embed_dim != output_size: + # TODO(xkc09): try LSTM + self.output_layer = torch.nn.Sequential( + torch.nn.Linear(model.cfg.encoder_embed_dim, output_size), + ) + else: + self.output_layer = None + + self.freeze_finetune_updates = freeze_finetune_updates + self.register_buffer("num_updates", torch.LongTensor([0])) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Forward FairSeqWav2Vec2 Encoder. + + Args: + xs_pad: input tensor (B, L, D) + ilens: input length (B) + prev_states: Not to be used now. 
+ Returns: + position embedded tensor and mask + """ + masks = make_pad_mask(ilens).to(xs_pad.device) + + ft = self.freeze_finetune_updates <= self.num_updates + if self.num_updates <= self.freeze_finetune_updates: + self.num_updates += 1 + elif ft and self.num_updates == self.freeze_finetune_updates + 1: + self.num_updates += 1 + logging.info("Start fine-tuning wav2vec parameters!") + + with torch.no_grad() if not ft else contextlib.nullcontext(): + enc_outputs = self.encoders( + xs_pad, + masks, + features_only=True, + ) + + xs_pad = enc_outputs["x"] # (B,T,C), + masks = enc_outputs["padding_mask"] # (B, T) + + olens = (~masks).sum(dim=1) + + if self.output_layer is not None: + xs_pad = self.output_layer(xs_pad) + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + + return xs_pad, olens, None + + def reload_pretrained_parameters(self): + self.encoders.load_state_dict(self.pretrained_params) + logging.info("Pretrained Wav2Vec model parameters reloaded!") + + +def download_w2v(model_url, dir_path): + os.makedirs(dir_path, exist_ok=True) + + model_name = model_url.split("/")[-1] + model_path = os.path.join(dir_path, model_name) + + dict_url = "https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt" + dict_path = os.path.join(dir_path, dict_url.split("/")[-1]) + + with FileLock(model_path + ".lock"): + if not os.path.exists(model_path): + torch.hub.download_url_to_file(model_url, model_path) + torch.hub.download_url_to_file(dict_url, dict_path) + logging.info(f"Wav2Vec model downloaded {model_path}") + else: + logging.info(f"Wav2Vec model {model_path} already exists.") + + return model_path diff --git a/espnet2/asr/espnet_model.py b/espnet2/asr/espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..c5763d93a80af3965553a890c108db800ebd9069 --- /dev/null +++ b/espnet2/asr/espnet_model.py @@ -0,0 +1,297 @@ +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from typeguard import check_argument_types + +from espnet.nets.e2e_asr_common import ErrorCalculator +from espnet.nets.pytorch_backend.nets_utils import th_accuracy +from espnet.nets.pytorch_backend.transformer.add_sos_eos import add_sos_eos +from espnet.nets.pytorch_backend.transformer.label_smoothing_loss import ( + LabelSmoothingLoss, # noqa: H301 +) +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetASRModel(AbsESPnetModel): + """CTC-attention hybrid Encoder-Decoder model""" + + def __init__( + self, + vocab_size: int, + token_list: Union[Tuple[str, ...], List[str]], + frontend: Optional[AbsFrontend], + specaug: Optional[AbsSpecAug], + normalize: Optional[AbsNormalize], + preencoder: Optional[AbsPreEncoder], + encoder: AbsEncoder, + decoder: AbsDecoder, + ctc: 
CTC, + rnnt_decoder: None, + ctc_weight: float = 0.5, + ignore_id: int = -1, + lsm_weight: float = 0.0, + length_normalized_loss: bool = False, + report_cer: bool = True, + report_wer: bool = True, + sym_space: str = "", + sym_blank: str = "", + ): + assert check_argument_types() + assert 0.0 <= ctc_weight <= 1.0, ctc_weight + assert rnnt_decoder is None, "Not implemented" + + super().__init__() + # note that eos is the same as sos (equivalent ID) + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + self.vocab_size = vocab_size + self.ignore_id = ignore_id + self.ctc_weight = ctc_weight + self.token_list = token_list.copy() + + self.frontend = frontend + self.specaug = specaug + self.normalize = normalize + self.preencoder = preencoder + self.encoder = encoder + self.decoder = decoder + if ctc_weight == 0.0: + self.ctc = None + else: + self.ctc = ctc + self.rnnt_decoder = rnnt_decoder + self.criterion_att = LabelSmoothingLoss( + size=vocab_size, + padding_idx=ignore_id, + smoothing=lsm_weight, + normalize_length=length_normalized_loss, + ) + + if report_cer or report_wer: + self.error_calculator = ErrorCalculator( + token_list, sym_space, sym_blank, report_cer, report_wer + ) + else: + self.error_calculator = None + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + text: (Batch, Length) + text_lengths: (Batch,) + """ + assert text_lengths.dim() == 1, text_lengths.shape + # Check that batch_size is unified + assert ( + speech.shape[0] + == speech_lengths.shape[0] + == text.shape[0] + == text_lengths.shape[0] + ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) + batch_size = speech.shape[0] + + # for data-parallel + text = text[:, : text_lengths.max()] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # 2a. Attention-decoder branch + if self.ctc_weight == 1.0: + loss_att, acc_att, cer_att, wer_att = None, None, None, None + else: + loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 2b. CTC branch + if self.ctc_weight == 0.0: + loss_ctc, cer_ctc = None, None + else: + loss_ctc, cer_ctc = self._calc_ctc_loss( + encoder_out, encoder_out_lens, text, text_lengths + ) + + # 2c. 
RNN-T branch + if self.rnnt_decoder is not None: + _ = self._calc_rnnt_loss(encoder_out, encoder_out_lens, text, text_lengths) + + if self.ctc_weight == 0.0: + loss = loss_att + elif self.ctc_weight == 1.0: + loss = loss_ctc + else: + loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + + stats = dict( + loss=loss.detach(), + loss_att=loss_att.detach() if loss_att is not None else None, + loss_ctc=loss_ctc.detach() if loss_ctc is not None else None, + acc=acc_att, + cer=cer_att, + wer=wer_att, + cer_ctc=cer_ctc, + ) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + text: torch.Tensor, + text_lengths: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder. Note that this method is used by asr_inference.py + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch, ) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Data augmentation + if self.specaug is not None and self.training: + feats, feats_lengths = self.specaug(feats, feats_lengths) + + # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # Pre-encoder, e.g. used for raw input data + if self.preencoder is not None: + feats, feats_lengths = self.preencoder(feats, feats_lengths) + + # 4. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim2) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def _calc_att_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) + ys_in_lens = ys_pad_lens + 1 + + # 1. Forward decoder + decoder_out, _ = self.decoder( + encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens + ) + + # 2. 
Compute attention loss + loss_att = self.criterion_att(decoder_out, ys_out_pad) + acc_att = th_accuracy( + decoder_out.view(-1, self.vocab_size), + ys_out_pad, + ignore_label=self.ignore_id, + ) + + # Compute cer/wer using attention-decoder + if self.training or self.error_calculator is None: + cer_att, wer_att = None, None + else: + ys_hat = decoder_out.argmax(dim=-1) + cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) + + return loss_att, acc_att, cer_att, wer_att + + def _calc_ctc_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + # Calc CTC loss + loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) + + # Calc CER using CTC + cer_ctc = None + if not self.training and self.error_calculator is not None: + ys_hat = self.ctc.argmax(encoder_out).data + cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) + return loss_ctc, cer_ctc + + def _calc_rnnt_loss( + self, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_pad: torch.Tensor, + ys_pad_lens: torch.Tensor, + ): + raise NotImplementedError diff --git a/espnet2/asr/frontend/__init__.py b/espnet2/asr/frontend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/frontend/abs_frontend.py b/espnet2/asr/frontend/abs_frontend.py new file mode 100644 index 0000000000000000000000000000000000000000..538236fe9443a9d054164eb07ce8084bb1d2f5bf --- /dev/null +++ b/espnet2/asr/frontend/abs_frontend.py @@ -0,0 +1,17 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsFrontend(torch.nn.Module, ABC): + @abstractmethod + def output_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/asr/frontend/default.py b/espnet2/asr/frontend/default.py new file mode 100644 index 0000000000000000000000000000000000000000..6c4a5da7a91d4076884963b32bc9ca8b8ff69713 --- /dev/null +++ b/espnet2/asr/frontend/default.py @@ -0,0 +1,131 @@ +import copy +from typing import Optional +from typing import Tuple +from typing import Union + +import humanfriendly +import numpy as np +import torch +from torch_complex.tensor import ComplexTensor +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.frontends.frontend import Frontend +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.layers.log_mel import LogMel +from espnet2.layers.stft import Stft +from espnet2.utils.get_default_kwargs import get_default_kwargs + + +class DefaultFrontend(AbsFrontend): + """Conventional frontend structure for ASR. 
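# Quick numeric check of the hybrid CTC/attention weighting used in
# ESPnetASRModel.forward above (values are made up; only the formula is from the code):
ctc_weight = 0.3
loss_ctc, loss_att = 42.0, 10.0
loss = ctc_weight * loss_ctc + (1 - ctc_weight) * loss_att   # 0.3*42 + 0.7*10 = 19.6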
+ + Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN + """ + + def __init__( + self, + fs: Union[int, str] = 16000, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: int = None, + fmax: int = None, + htk: bool = False, + frontend_conf: Optional[dict] = get_default_kwargs(Frontend), + apply_stft: bool = True, + ): + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + + # Deepcopy (In general, dict shouldn't be used as default arg) + frontend_conf = copy.deepcopy(frontend_conf) + + if apply_stft: + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + center=center, + window=window, + normalized=normalized, + onesided=onesided, + ) + else: + self.stft = None + self.apply_stft = apply_stft + + if frontend_conf is not None: + self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) + else: + self.frontend = None + + self.logmel = LogMel( + fs=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + ) + self.n_mels = n_mels + + def output_size(self) -> int: + return self.n_mels + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. Stft: time -> time-freq + if self.stft is not None: + input_stft, feats_lens = self._compute_stft(input, input_lengths) + else: + input_stft = ComplexTensor(input[..., 0], input[..., 1]) + feats_lens = input_lengths + # 2. [Option] Speech enhancement + if self.frontend is not None: + assert isinstance(input_stft, ComplexTensor), type(input_stft) + # input_stft: (Batch, Length, [Channel], Freq) + input_stft, _, mask = self.frontend(input_stft, feats_lens) + + # 3. [Multi channel case]: Select a channel + if input_stft.dim() == 4: + # h: (B, T, C, F) -> h: (B, T, F) + if self.training: + # Select 1ch randomly + ch = np.random.randint(input_stft.size(2)) + input_stft = input_stft[:, :, ch, :] + else: + # Use the first channel + input_stft = input_stft[:, :, 0, :] + + # 4. STFT -> Power spectrum + # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) + input_power = input_stft.real ** 2 + input_stft.imag ** 2 + + # 5. Feature transform e.g. 
Stft -> Log-Mel-Fbank + # input_power: (Batch, [Channel,] Length, Freq) + # -> input_feats: (Batch, Length, Dim) + input_feats, _ = self.logmel(input_power, feats_lens) + + return input_feats, feats_lens + + def _compute_stft( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> torch.Tensor: + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # Change torch.Tensor to ComplexTensor + # input_stft: (..., F, 2) -> (..., F) + input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) + return input_stft, feats_lens diff --git a/espnet2/asr/frontend/windowing.py b/espnet2/asr/frontend/windowing.py new file mode 100644 index 0000000000000000000000000000000000000000..55600ca30d874bde73acb997ef6ae91c6d38af4f --- /dev/null +++ b/espnet2/asr/frontend/windowing.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +# 2020, Technische Universität München; Ludwig Kürzinger +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Sliding Window for raw audio input data.""" + +from espnet2.asr.frontend.abs_frontend import AbsFrontend +import torch +from typeguard import check_argument_types +from typing import Tuple + + +class SlidingWindow(AbsFrontend): + """Sliding Window. + + Provides a sliding window over a batched continuous raw audio tensor. + Optionally, provides padding (Currently not implemented). + Combine this module with a pre-encoder compatible with raw audio data, + for example Sinc convolutions. + + Known issues: + Output length is calculated incorrectly if audio shorter than win_length. + WARNING: trailing values are discarded - padding not implemented yet. + There is currently no additional window function applied to input values. + """ + + def __init__( + self, + win_length: int = 400, + hop_length: int = 160, + channels: int = 1, + padding: int = None, + fs=None, + ): + """Initialize. + + Args: + win_length: Length of frame. + hop_length: Relative starting point of next frame. + channels: Number of input channels. + padding: Padding (placeholder, currently not implemented). + fs: Sampling rate (placeholder for compatibility, not used). + """ + assert check_argument_types() + super().__init__() + self.fs = fs + self.win_length = win_length + self.hop_length = hop_length + self.channels = channels + self.padding = padding + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply a sliding window on the input. + + Args: + input: Input (B, T, C*D) or (B, T*C*D), with D=C=1. + input_lengths: Input lengths within batch. + + Returns: + Tensor: Output with dimensions (B, T, C, D), with D=win_length. + Tensor: Output lengths within batch. + """ + input_size = input.size() + B = input_size[0] + T = input_size[1] + C = self.channels + D = self.win_length + # (B, T, C) --> (T, B, C) + continuous = input.view(B, T, C).permute(1, 0, 2) + windowed = continuous.unfold(0, D, self.hop_length) + # (T, B, C, D) --> (B, T, C, D) + output = windowed.permute(1, 0, 2, 3).contiguous() + # After unfold(), windowed lengths change: + output_lengths = (input_lengths - self.win_length) // self.hop_length + 1 + return output, output_lengths + + def output_size(self) -> int: + """Return output length of feature dimension D, i.e. 
the window length.""" + return self.win_length diff --git a/espnet2/asr/preencoder/__init__.py b/espnet2/asr/preencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/preencoder/abs_preencoder.py b/espnet2/asr/preencoder/abs_preencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3ecdc6b91f00ffc5680e6d0bf3eeb86d58bf74a5 --- /dev/null +++ b/espnet2/asr/preencoder/abs_preencoder.py @@ -0,0 +1,17 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsPreEncoder(torch.nn.Module, ABC): + @abstractmethod + def output_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/asr/preencoder/sinc.py b/espnet2/asr/preencoder/sinc.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9dfa6e4c094b6f8cf37491895561a7ed358f53 --- /dev/null +++ b/espnet2/asr/preencoder/sinc.py @@ -0,0 +1,282 @@ +#!/usr/bin/env python3 +# 2020, Technische Universität München; Ludwig Kürzinger +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Sinc convolutions for raw audio input.""" + +from collections import OrderedDict +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.layers.sinc_conv import LogCompression +from espnet2.layers.sinc_conv import SincConv +import humanfriendly +import torch +from typeguard import check_argument_types +from typing import Optional +from typing import Tuple +from typing import Union + + +class LightweightSincConvs(AbsPreEncoder): + """Lightweight Sinc Convolutions. + + Instead of using precomputed features, end-to-end speech recognition + can also be done directly from raw audio using sinc convolutions, as + described in "Lightweight End-to-End Speech Recognition from Raw Audio + Data Using Sinc-Convolutions" by Kürzinger et al. + https://arxiv.org/abs/2010.07597 + + To use Sinc convolutions in your model instead of the default f-bank + frontend, set this module as your pre-encoder with `preencoder: sinc` + and use the input of the sliding window frontend with + `frontend: sliding_window` in your yaml configuration file. + So that the process flow is: + + Frontend (SlidingWindow) -> SpecAug -> Normalization -> + Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder + + Note that this method also performs data augmentation in time domain + (vs. in spectral domain in the default frontend). + Use `plot_sinc_filters.py` to visualize the learned Sinc filters. + """ + + def __init__( + self, + fs: Union[int, str, float] = 16000, + in_channels: int = 1, + out_channels: int = 256, + activation_type: str = "leakyrelu", + dropout_type: str = "dropout", + windowing_type: str = "hamming", + scale_type: str = "mel", + ): + """Initialize the module. + + Args: + fs: Sample rate. + in_channels: Number of input channels. + out_channels: Number of output channels (for each input channel). + activation_type: Choice of activation function. + dropout_type: Choice of dropout function. + windowing_type: Choice of windowing function. + scale_type: Choice of filter-bank initialization scale. 
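# Sketch of the raw-audio pipeline described in the LightweightSincConvs docstring
# above (illustrative; shapes assume the defaults win_length=400, hop_length=160, fs=16k):
import torch
from espnet2.asr.frontend.windowing import SlidingWindow
from espnet2.asr.preencoder.sinc import LightweightSincConvs

frontend = SlidingWindow(win_length=400, hop_length=160)
preencoder = LightweightSincConvs(fs=16000, in_channels=1, out_channels=256)

speech = torch.randn(2, 16000)                     # one second of raw audio per utterance
lengths = torch.tensor([16000, 12000])
frames, frame_lens = frontend(speech, lengths)     # (2, 98, 1, 400); 98 = (16000-400)//160 + 1
feats, feat_lens = preencoder(frames, frame_lens)  # (2, 98, 256); D_in=400 collapses to D_out=1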
+ """ + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + self.fs = fs + self.in_channels = in_channels + self.out_channels = out_channels + self.activation_type = activation_type + self.dropout_type = dropout_type + self.windowing_type = windowing_type + self.scale_type = scale_type + + self.choices_dropout = { + "dropout": torch.nn.Dropout, + "spatial": SpatialDropout, + "dropout2d": torch.nn.Dropout2d, + } + if dropout_type not in self.choices_dropout: + raise NotImplementedError( + f"Dropout type has to be one of " + f"{list(self.choices_dropout.keys())}", + ) + + self.choices_activation = { + "leakyrelu": torch.nn.LeakyReLU, + "relu": torch.nn.ReLU, + } + if activation_type not in self.choices_activation: + raise NotImplementedError( + f"Activation type has to be one of " + f"{list(self.choices_activation.keys())}", + ) + + # initialization + self._create_sinc_convs() + # Sinc filters require custom initialization + self.espnet_initialization_fn() + + def _create_sinc_convs(self): + blocks = OrderedDict() + + # SincConvBlock + out_channels = 128 + self.filters = SincConv( + self.in_channels, + out_channels, + kernel_size=101, + stride=1, + fs=self.fs, + window_func=self.windowing_type, + scale_type=self.scale_type, + ) + block = OrderedDict( + [ + ("Filters", self.filters), + ("LogCompression", LogCompression()), + ("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)), + ("AvgPool", torch.nn.AvgPool1d(2)), + ] + ) + blocks["SincConvBlock"] = torch.nn.Sequential(block) + in_channels = out_channels + + # First convolutional block, connects the sinc output to the front-end "body" + out_channels = 128 + blocks["DConvBlock1"] = self.gen_lsc_block( + in_channels, + out_channels, + depthwise_kernel_size=25, + depthwise_stride=2, + pointwise_groups=0, + avgpool=True, + dropout_probability=0.1, + ) + in_channels = out_channels + + # Second convolutional block, multiple convolutional layers + out_channels = self.out_channels + for layer in [2, 3, 4]: + blocks[f"DConvBlock{layer}"] = self.gen_lsc_block( + in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1 + ) + in_channels = out_channels + + # Third Convolutional block, acts as coupling to encoder + out_channels = self.out_channels + blocks["DConvBlock5"] = self.gen_lsc_block( + in_channels, + out_channels, + depthwise_kernel_size=7, + depthwise_stride=1, + pointwise_groups=0, + ) + + self.blocks = torch.nn.Sequential(blocks) + + def gen_lsc_block( + self, + in_channels: int, + out_channels: int, + depthwise_kernel_size: int = 9, + depthwise_stride: int = 1, + depthwise_groups=None, + pointwise_groups=0, + dropout_probability: float = 0.15, + avgpool=False, + ): + """Generate a convolutional block for Lightweight Sinc convolutions. + + Each block consists of either a depthwise or a depthwise-separable + convolutions together with dropout, (batch-)normalization layer, and + an optional average-pooling layer. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + depthwise_kernel_size: Kernel size of the depthwise convolution. + depthwise_stride: Stride of the depthwise convolution. + depthwise_groups: Number of groups of the depthwise convolution. + pointwise_groups: Number of groups of the pointwise convolution. + dropout_probability: Dropout probability in the block. + avgpool: If True, an AvgPool layer is inserted. + + Returns: + torch.nn.Sequential: Neural network building block. 
+ """ + block = OrderedDict() + if not depthwise_groups: + # GCD(in_channels, out_channels) to prevent size mismatches + depthwise_groups, r = in_channels, out_channels + while r != 0: + depthwise_groups, r = depthwise_groups, depthwise_groups % r + block["depthwise"] = torch.nn.Conv1d( + in_channels, + out_channels, + depthwise_kernel_size, + depthwise_stride, + groups=depthwise_groups, + ) + if pointwise_groups: + block["pointwise"] = torch.nn.Conv1d( + out_channels, out_channels, 1, 1, groups=pointwise_groups + ) + block["activation"] = self.choices_activation[self.activation_type]() + block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True) + if avgpool: + block["avgpool"] = torch.nn.AvgPool1d(2) + block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability) + return torch.nn.Sequential(block) + + def espnet_initialization_fn(self): + """Initialize sinc filters with filterbank values.""" + self.filters.init_filters() + for block in self.blocks: + for layer in block: + if type(layer) == torch.nn.BatchNorm1d and layer.affine: + layer.weight.data[:] = 1.0 + layer.bias.data[:] = 0.0 + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply Lightweight Sinc Convolutions. + + The input shall be formatted as (B, T, C_in, D_in) + with B as batch size, T as time dimension, C_in as channels, + and D_in as feature dimension. + + The output will then be (B, T, C_out*D_out) + with C_out and D_out as output dimensions. + + The current module structure only handles D_in=400, so that D_out=1. + Remark for the multichannel case: C_out is the number of out_channels + given at initialization multiplied with C_in. + """ + # Transform input data: + # (B, T, C_in, D_in) -> (B*T, C_in, D_in) + B, T, C_in, D_in = input.size() + input_frames = input.view(B * T, C_in, D_in) + output_frames = self.blocks.forward(input_frames) + + # ---TRANSFORM: (B*T, C_out, D_out) -> (B, T, C_out*D_out) + _, C_out, D_out = output_frames.size() + output_frames = output_frames.view(B, T, C_out * D_out) + return output_frames, input_lengths # no state in this layer + + def output_size(self) -> int: + """Get the output size.""" + return self.out_channels * self.in_channels + + +class SpatialDropout(torch.nn.Module): + """Spatial dropout module. + + Apply dropout to full channels on tensors of input (B, C, D) + """ + + def __init__( + self, + dropout_probability: float = 0.15, + shape: Optional[Union[tuple, list]] = None, + ): + """Initialize. + + Args: + dropout_probability: Dropout probability. + shape (tuple, list): Shape of input tensors. 
+ """ + assert check_argument_types() + super().__init__() + if shape is None: + shape = (0, 2, 1) + self.dropout = torch.nn.Dropout2d(dropout_probability) + self.shape = (shape,) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward of spatial dropout module.""" + y = x.permute(*self.shape) + y = self.dropout(y) + return y.permute(*self.shape) diff --git a/espnet2/asr/specaug/__init__.py b/espnet2/asr/specaug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/asr/specaug/abs_specaug.py b/espnet2/asr/specaug/abs_specaug.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbac418fb631ae29ab6bb01c2c82b47b0016d22 --- /dev/null +++ b/espnet2/asr/specaug/abs_specaug.py @@ -0,0 +1,18 @@ +from typing import Optional +from typing import Tuple + +import torch + + +class AbsSpecAug(torch.nn.Module): + """Abstract class for the augmentation of spectrogram + + The process-flow: + + Frontend -> SpecAug -> Normalization -> Encoder -> Decoder + """ + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + raise NotImplementedError diff --git a/espnet2/asr/specaug/specaug.py b/espnet2/asr/specaug/specaug.py new file mode 100644 index 0000000000000000000000000000000000000000..6cfeb1ce00e875f789dab7f411d6a1d0c947d2f3 --- /dev/null +++ b/espnet2/asr/specaug/specaug.py @@ -0,0 +1,84 @@ +from distutils.version import LooseVersion +from typing import Sequence +from typing import Union + +import torch + +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.layers.mask_along_axis import MaskAlongAxis +from espnet2.layers.time_warp import TimeWarp + + +if LooseVersion(torch.__version__) >= LooseVersion("1.1"): + DEFAULT_TIME_WARP_MODE = "bicubic" +else: + # pytorch1.0 doesn't implement bicubic + DEFAULT_TIME_WARP_MODE = "bilinear" + + +class SpecAug(AbsSpecAug): + """Implementation of SpecAug. + + Reference: + Daniel S. Park et al. + "SpecAugment: A Simple Data + Augmentation Method for Automatic Speech Recognition" + + .. warning:: + When using cuda mode, time_warp doesn't have reproducibility + due to `torch.nn.functional.interpolate`. 
+ + """ + + def __init__( + self, + apply_time_warp: bool = True, + time_warp_window: int = 5, + time_warp_mode: str = DEFAULT_TIME_WARP_MODE, + apply_freq_mask: bool = True, + freq_mask_width_range: Union[int, Sequence[int]] = (0, 20), + num_freq_mask: int = 2, + apply_time_mask: bool = True, + time_mask_width_range: Union[int, Sequence[int]] = (0, 100), + num_time_mask: int = 2, + ): + if not apply_time_warp and not apply_time_mask and not apply_freq_mask: + raise ValueError( + "Either one of time_warp, time_mask, or freq_mask should be applied", + ) + super().__init__() + self.apply_time_warp = apply_time_warp + self.apply_freq_mask = apply_freq_mask + self.apply_time_mask = apply_time_mask + + if apply_time_warp: + self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode) + else: + self.time_warp = None + + if apply_freq_mask: + self.freq_mask = MaskAlongAxis( + dim="freq", + mask_width_range=freq_mask_width_range, + num_mask=num_freq_mask, + ) + else: + self.freq_mask = None + + if apply_time_mask: + self.time_mask = MaskAlongAxis( + dim="time", + mask_width_range=time_mask_width_range, + num_mask=num_time_mask, + ) + else: + self.time_mask = None + + def forward(self, x, x_lengths=None): + if self.time_warp is not None: + x, x_lengths = self.time_warp(x, x_lengths) + if self.freq_mask is not None: + x, x_lengths = self.freq_mask(x, x_lengths) + if self.time_mask is not None: + x, x_lengths = self.time_mask(x, x_lengths) + return x, x_lengths diff --git a/espnet2/bin/__init__.py b/espnet2/bin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/bin/aggregate_stats_dirs.py b/espnet2/bin/aggregate_stats_dirs.py new file mode 100644 index 0000000000000000000000000000000000000000..b79e67c399d172a915f26055486ba85e127b30e4 --- /dev/null +++ b/espnet2/bin/aggregate_stats_dirs.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +import sys +from typing import Iterable +from typing import Union + +import numpy as np + +from espnet.utils.cli_utils import get_commandline_args + + +def aggregate_stats_dirs( + input_dir: Iterable[Union[str, Path]], + output_dir: Union[str, Path], + log_level: str, + skip_sum_stats: bool, +): + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s", + ) + + input_dirs = [Path(p) for p in input_dir] + output_dir = Path(output_dir) + + for mode in ["train", "valid"]: + with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f: + batch_keys = [line.strip() for line in f if line.strip() != ""] + with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f: + stats_keys = [line.strip() for line in f if line.strip() != ""] + (output_dir / mode).mkdir(parents=True, exist_ok=True) + + for key in batch_keys: + with (output_dir / mode / f"{key}_shape").open( + "w", encoding="utf-8" + ) as fout: + for idir in input_dirs: + with (idir / mode / f"{key}_shape").open( + "r", encoding="utf-8" + ) as fin: + # Read to the last in order to sort keys + # because the order can be changed if num_workers>=1 + lines = fin.readlines() + lines = sorted(lines, key=lambda x: x.split()[0]) + for line in lines: + fout.write(line) + + for key in stats_keys: + if not skip_sum_stats: + sum_stats = None + for idir in input_dirs: + stats = np.load(idir / mode / f"{key}_stats.npz") + if sum_stats is None: + sum_stats = dict(**stats) + else: + for k in stats: 
+ sum_stats[k] += stats[k] + + np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats) + + # if --write_collected_feats=true + p = Path(mode) / "collect_feats" / f"{key}.scp" + scp = input_dirs[0] / p + if scp.exists(): + (output_dir / p).parent.mkdir(parents=True, exist_ok=True) + with (output_dir / p).open("w", encoding="utf-8") as fout: + for idir in input_dirs: + with (idir / p).open("r", encoding="utf-8") as fin: + for line in fin: + fout.write(line) + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Aggregate statistics directories into one directory", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + parser.add_argument( + "--skip_sum_stats", + default=False, + action="store_true", + help="Skip computing the sum of statistics.", + ) + + parser.add_argument("--input_dir", action="append", help="Input directories") + parser.add_argument("--output_dir", required=True, help="Output directory") + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + aggregate_stats_dirs(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/asr_inference.py b/espnet2/bin/asr_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7e45732bac6d32126011b9924f997851389e11 --- /dev/null +++ b/espnet2/bin/asr_inference.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +import sys +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type +from typing import List + +from espnet.nets.batch_beam_search import BatchBeamSearch +from espnet.nets.batch_beam_search_online_sim import BatchBeamSearchOnlineSim +from espnet.nets.beam_search import BeamSearch +from espnet.nets.beam_search import Hypothesis +from espnet.nets.pytorch_backend.transformer.subsampling import TooShortUttError +from espnet.nets.scorer_interface import BatchScorerInterface +from espnet.nets.scorers.ctc import CTCPrefixScorer +from espnet.nets.scorers.length_bonus import LengthBonus +from espnet.utils.cli_utils import get_commandline_args +from espnet2.fileio.datadir_writer import DatadirWriter +from espnet2.tasks.asr import ASRTask +from espnet2.tasks.lm import LMTask +from espnet2.text.build_tokenizer import build_tokenizer +from espnet2.text.token_id_converter import TokenIDConverter +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.utils import config_argparse +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_none + + +class Speech2Text: + """Speech2Text class + + Examples: + >>> import soundfile + >>> speech2text = Speech2Text("asr_config.yml", "asr.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> speech2text(audio) + [(text, token, token_int, hypothesis object), ...] 
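# Sketch for the aggregate_stats_dirs helper defined earlier above (illustrative;
# the directory names are placeholders for per-job stats dumps):
from espnet2.bin.aggregate_stats_dirs import aggregate_stats_dirs

aggregate_stats_dirs(
    input_dir=["exp/asr_stats/split.1", "exp/asr_stats/split.2"],  # placeholder paths
    output_dir="exp/asr_stats",
    log_level="INFO",
    skip_sum_stats=False,
)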
+ + """ + + def __init__( + self, + asr_train_config: Union[Path, str], + asr_model_file: Union[Path, str] = None, + lm_train_config: Union[Path, str] = None, + lm_file: Union[Path, str] = None, + token_type: str = None, + bpemodel: str = None, + device: str = "cpu", + maxlenratio: float = 0.0, + minlenratio: float = 0.0, + batch_size: int = 1, + dtype: str = "float32", + beam_size: int = 20, + ctc_weight: float = 0.5, + lm_weight: float = 1.0, + penalty: float = 0.0, + nbest: int = 1, + streaming: bool = False, + ): + assert check_argument_types() + + # 1. Build ASR model + scorers = {} + asr_model, asr_train_args = ASRTask.build_model_from_file( + asr_train_config, asr_model_file, device + ) + asr_model.to(dtype=getattr(torch, dtype)).eval() + + decoder = asr_model.decoder + ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) + token_list = asr_model.token_list + scorers.update( + decoder=decoder, + ctc=ctc, + length_bonus=LengthBonus(len(token_list)), + ) + + # 2. Build Language model + if lm_train_config is not None: + lm, lm_train_args = LMTask.build_model_from_file( + lm_train_config, lm_file, device + ) + scorers["lm"] = lm.lm + + # 3. Build BeamSearch object + weights = dict( + decoder=1.0 - ctc_weight, + ctc=ctc_weight, + lm=lm_weight, + length_bonus=penalty, + ) + beam_search = BeamSearch( + beam_size=beam_size, + weights=weights, + scorers=scorers, + sos=asr_model.sos, + eos=asr_model.eos, + vocab_size=len(token_list), + token_list=token_list, + pre_beam_score_key=None if ctc_weight == 1.0 else "full", + ) + # TODO(karita): make all scorers batchfied + if batch_size == 1: + non_batch = [ + k + for k, v in beam_search.full_scorers.items() + if not isinstance(v, BatchScorerInterface) + ] + if len(non_batch) == 0: + if streaming: + beam_search.__class__ = BatchBeamSearchOnlineSim + beam_search.set_streaming_config(asr_train_config) + logging.info("BatchBeamSearchOnlineSim implementation is selected.") + else: + beam_search.__class__ = BatchBeamSearch + logging.info("BatchBeamSearch implementation is selected.") + else: + logging.warning( + f"As non-batch scorers {non_batch} are found, " + f"fall back to non-batch implementation." + ) + beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() + for scorer in scorers.values(): + if isinstance(scorer, torch.nn.Module): + scorer.to(device=device, dtype=getattr(torch, dtype)).eval() + logging.info(f"Beam_search: {beam_search}") + logging.info(f"Decoding device={device}, dtype={dtype}") + + # 4. [Optional] Build Text converter: e.g. 
bpe-sym -> Text + if token_type is None: + token_type = asr_train_args.token_type + if bpemodel is None: + bpemodel = asr_train_args.bpemodel + + if token_type is None: + tokenizer = None + elif token_type == "bpe": + if bpemodel is not None: + tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) + else: + tokenizer = None + else: + tokenizer = build_tokenizer(token_type=token_type) + converter = TokenIDConverter(token_list=token_list) + logging.info(f"Text tokenizer: {tokenizer}") + + self.asr_model = asr_model + self.asr_train_args = asr_train_args + self.converter = converter + self.tokenizer = tokenizer + self.beam_search = beam_search + self.maxlenratio = maxlenratio + self.minlenratio = minlenratio + self.device = device + self.dtype = dtype + self.nbest = nbest + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray] + ) -> List[Tuple[Optional[str], List[str], List[int], Hypothesis]]: + """Inference + + Args: + data: Input speech data + Returns: + text, token, token_int, hyp + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.tensor(speech) + + # data: (Nsamples,) -> (1, Nsamples) + speech = speech.unsqueeze(0).to(getattr(torch, self.dtype)) + # lenghts: (1,) + lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) + batch = {"speech": speech, "speech_lengths": lengths} + + # a. To device + batch = to_device(batch, device=self.device) + + # b. Forward Encoder + enc, _ = self.asr_model.encode(**batch) + assert len(enc) == 1, len(enc) + + # c. Passed the encoder result and the beam search + nbest_hyps = self.beam_search( + x=enc[0], maxlenratio=self.maxlenratio, minlenratio=self.minlenratio + ) + nbest_hyps = nbest_hyps[: self.nbest] + + results = [] + for hyp in nbest_hyps: + assert isinstance(hyp, Hypothesis), type(hyp) + + # remove sos/eos and get results + token_int = hyp.yseq[1:-1].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x != 0, token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + + if self.tokenizer is not None: + text = self.tokenizer.tokens2text(token) + else: + text = None + results.append((text, token, token_int, hyp)) + + assert check_return_type(results) + return results + + +def inference( + output_dir: str, + maxlenratio: float, + minlenratio: float, + batch_size: int, + dtype: str, + beam_size: int, + ngpu: int, + seed: int, + ctc_weight: float, + lm_weight: float, + penalty: float, + nbest: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + asr_train_config: str, + asr_model_file: str, + lm_train_config: Optional[str], + lm_file: Optional[str], + word_lm_train_config: Optional[str], + word_lm_file: Optional[str], + token_type: Optional[str], + bpemodel: Optional[str], + allow_variable_data_keys: bool, + streaming: bool, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if word_lm_train_config is not None: + raise NotImplementedError("Word LM is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. 
Set random-seed + set_all_random_seed(seed) + + # 2. Build speech2text + speech2text = Speech2Text( + asr_train_config=asr_train_config, + asr_model_file=asr_model_file, + lm_train_config=lm_train_config, + lm_file=lm_file, + token_type=token_type, + bpemodel=bpemodel, + device=device, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + dtype=dtype, + beam_size=beam_size, + ctc_weight=ctc_weight, + lm_weight=lm_weight, + penalty=penalty, + nbest=nbest, + streaming=streaming, + ) + + # 3. Build data-iterator + loader = ASRTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), + collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 7 .Start for-loop + # FIXME(kamo): The output format should be discussed about + with DatadirWriter(output_dir) as writer: + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + # N-best list of (text, token, token_int, hyp_object) + try: + results = speech2text(**batch) + except TooShortUttError as e: + logging.warning(f"Utterance {keys} {e}") + hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) + results = [[" ", [""], [2], hyp]] * nbest + + # Only supporting batch_size==1 + key = keys[0] + for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): + # Create a directory: outdir/{n}best_recog + ibest_writer = writer[f"{n}best_recog"] + + # Write the result to each file + ibest_writer["token"][key] = " ".join(token) + ibest_writer["token_int"][key] = " ".join(map(str, token_int)) + ibest_writer["score"][key] = str(hyp.score) + + if text is not None: + ibest_writer["text"][key] = text + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="ASR Decoding", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 
0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--asr_train_config", type=str, required=True) + group.add_argument("--asr_model_file", type=str, required=True) + group.add_argument("--lm_train_config", type=str) + group.add_argument("--lm_file", type=str) + group.add_argument("--word_lm_train_config", type=str) + group.add_argument("--word_lm_file", type=str) + + group = parser.add_argument_group("Beam-search related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") + group.add_argument("--beam_size", type=int, default=20, help="Beam size") + group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") + group.add_argument( + "--maxlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain max output length. " + "If maxlenratio=0.0 (default), it uses a end-detect " + "function " + "to automatically find maximum hypothesis lengths", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Input length ratio to obtain min output length", + ) + group.add_argument( + "--ctc_weight", + type=float, + default=0.5, + help="CTC weight in joint decoding", + ) + group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") + group.add_argument("--streaming", type=str2bool, default=False) + + group = parser.add_argument_group("Text converter related") + group.add_argument( + "--token_type", + type=str_or_none, + default=None, + choices=["char", "bpe", None], + help="The token type for ASR model. " + "If not given, refers from the training args", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model path of sentencepiece. " + "If not given, refers from the training args", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/asr_train.py b/espnet2/bin/asr_train.py new file mode 100644 index 0000000000000000000000000000000000000000..53243b60dd72be4f86d53eb8db5668113f65e274 --- /dev/null +++ b/espnet2/bin/asr_train.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +from espnet2.tasks.asr import ASRTask + + +def get_parser(): + parser = ASRTask.get_parser() + return parser + + +def main(cmd=None): + r"""ASR training. 
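# The decoding entry point of asr_inference.py defined above can also be driven
# programmatically; a hypothetical call (all paths and the nbest value are placeholders,
# and the "path,name,type" triple follows the str2triple_str comma format):
from espnet2.bin.asr_inference import main as asr_inference_main

asr_inference_main(cmd=[
    "--asr_train_config", "exp/asr_train/config.yaml",
    "--asr_model_file", "exp/asr_train/valid.acc.ave.pth",
    "--data_path_and_name_and_type", "dump/raw/test/wav.scp,speech,sound",
    "--output_dir", "exp/decode_test",
    "--nbest", "3",
])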
+ + Example: + + % python asr_train.py asr --print_config --optim adadelta \ + > conf/train_asr.yaml + % python asr_train.py --config conf/train_asr.yaml + """ + ASRTask.main(cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/diar_inference.py b/espnet2/bin/diar_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..46d23b247579dc5ee9b122bacc95f6424e27a560 --- /dev/null +++ b/espnet2/bin/diar_inference.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 + +import argparse +import logging +from pathlib import Path +import sys +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from tqdm import trange +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.fileio.npy_scp import NpyScpWriter +from espnet2.tasks.diar import DiarizationTask +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.utils import config_argparse +from espnet2.utils.types import humanfriendly_parse_size_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_none + + +class DiarizeSpeech: + """DiarizeSpeech class + + Examples: + >>> import soundfile + >>> diarization = DiarizeSpeech("diar_config.yaml", "diar.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> diarization(audio) + [(spk_id, start, end), (spk_id2, start2, end2)] + + """ + + def __init__( + self, + diar_train_config: Union[Path, str], + diar_model_file: Union[Path, str] = None, + segment_size: Optional[float] = None, + normalize_segment_scale: bool = False, + show_progressbar: bool = False, + device: str = "cpu", + dtype: str = "float32", + ): + assert check_argument_types() + + # 1. Build Diar model + diar_model, diar_train_args = DiarizationTask.build_model_from_file( + diar_train_config, diar_model_file, device + ) + diar_model.to(dtype=getattr(torch, dtype)).eval() + + self.device = device + self.dtype = dtype + self.diar_train_args = diar_train_args + self.diar_model = diar_model + + # only used when processing long speech, i.e. + # segment_size is not None and hop_size is not None + self.segment_size = segment_size + self.normalize_segment_scale = normalize_segment_scale + self.show_progressbar = show_progressbar + + self.num_spk = diar_model.num_spk + + self.segmenting = segment_size is not None + if self.segmenting: + logging.info("Perform segment-wise speaker diarization") + logging.info("Segment length = {} sec".format(segment_size)) + else: + logging.info("Perform direct speaker diarization on the input") + + @torch.no_grad() + def __call__( + self, speech: Union[torch.Tensor, np.ndarray], fs: int = 8000 + ) -> List[torch.Tensor]: + """Inference + + Args: + speech: Input speech data (Batch, Nsamples [, Channels]) + fs: sample rate + Returns: + [speaker_info1, speaker_info2, ...] + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech, np.ndarray): + speech = torch.as_tensor(speech) + + assert speech.dim() > 1, speech.size() + batch_size = speech.size(0) + speech = speech.to(getattr(torch, self.dtype)) + # lenghts: (B,) + lengths = speech.new_full( + [batch_size], dtype=torch.long, fill_value=speech.size(1) + ) + + # a. 
To device + speech = to_device(speech, device=self.device) + lengths = to_device(lengths, device=self.device) + + if self.segmenting and lengths[0] > self.segment_size * fs: + # Segment-wise speaker diarization + num_segments = int(np.ceil(speech.size(1) / (self.segment_size * fs))) + t = T = int(self.segment_size * fs) + pad_shape = speech[:, :T].shape + diarized_wavs = [] + range_ = trange if self.show_progressbar else range + for i in range_(num_segments): + st = int(i * self.segment_size * fs) + en = st + T + if en >= lengths[0]: + # en - st < T (last segment) + en = lengths[0] + speech_seg = speech.new_zeros(pad_shape) + t = en - st + speech_seg[:, :t] = speech[:, st:en] + else: + t = T + speech_seg = speech[:, st:en] # B x T [x C] + + lengths_seg = speech.new_full( + [batch_size], dtype=torch.long, fill_value=T + ) + # b. Diarization Forward + encoder_out, encoder_out_lens = self.diar_model.encode( + speech_seg, lengths_seg + ) + spk_prediction = self.diar_model.decoder(encoder_out, encoder_out_lens) + + # List[torch.Tensor(B, T, num_spks)] + diarized_wavs.append(spk_prediction) + + spk_prediction = torch.cat(diarized_wavs, dim=1) + else: + # b. Diarization Forward + encoder_out, encoder_out_lens = self.diar_model.encode(speech, lengths) + spk_prediction = self.diar_model.decoder(encoder_out, encoder_out_lens) + + assert spk_prediction.size(2) == self.num_spk, ( + spk_prediction.size(2), + self.num_spk, + ) + assert spk_prediction.size(0) == batch_size, ( + spk_prediction.size(0), + batch_size, + ) + spk_prediction = spk_prediction.cpu().numpy() + spk_prediction = 1 / (1 + np.exp(-spk_prediction)) + + return spk_prediction + + +def inference( + output_dir: str, + batch_size: int, + dtype: str, + fs: int, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + diar_train_config: str, + diar_model_file: str, + allow_variable_data_keys: bool, + segment_size: Optional[float], + show_progressbar: bool, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build separate_speech + diarize_speech = DiarizeSpeech( + diar_train_config=diar_train_config, + diar_model_file=diar_model_file, + segment_size=segment_size, + show_progressbar=show_progressbar, + device=device, + dtype=dtype, + ) + + # 3. Build data-iterator + loader = DiarizationTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=DiarizationTask.build_preprocess_fn( + diarize_speech.diar_train_args, False + ), + collate_fn=DiarizationTask.build_collate_fn( + diarize_speech.diar_train_args, False + ), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 4. 
Start for-loop + writer = NpyScpWriter(f"{output_dir}/predictions", f"{output_dir}/diarize.scp") + + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + + spk_predictions = diarize_speech(**batch) + for b in range(batch_size): + writer[keys[b]] = spk_predictions[b] + + writer.close() + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Speaker Diarization inference", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--fs", + type=humanfriendly_parse_size_or_none, + default=8000, + help="Sampling rate", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--diar_train_config", type=str, required=True) + group.add_argument("--diar_model_file", type=str, required=True) + + group = parser.add_argument_group("Data loading related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group = parser.add_argument_group("Diarize speech related") + group.add_argument( + "--segment_size", + type=float, + default=None, + help="Segment length in seconds for segment-wise speaker diarization", + ) + group.add_argument( + "--show_progressbar", + type=str2bool, + default=False, + help="Whether to show a progress bar when performing segment-wise speaker " + "diarization", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/diar_train.py b/espnet2/bin/diar_train.py new file mode 100644 index 0000000000000000000000000000000000000000..dca8d57d4e69ebd165d7200b7ba10952c5cda8e3 --- /dev/null +++ b/espnet2/bin/diar_train.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 + +from espnet2.tasks.diar import DiarizationTask + + +def get_parser(): + parser = DiarizationTask.get_parser() + return parser + + +def main(cmd=None): + r"""Speaker diarization training. 
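+
+    This entry point is a thin wrapper: both the argument parser and the
+    training loop are provided by espnet2.tasks.diar.DiarizationTask.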
+ + Example: + % python diar_train.py diar --print_config --optim adadelta \ + > conf/train_diar.yaml + % python diar_train.py --config conf/train_diar.yaml + """ + DiarizationTask.main(cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/enh_inference.py b/espnet2/bin/enh_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..54958d842fa90c04948e595cdc254404dd1e7b13 --- /dev/null +++ b/espnet2/bin/enh_inference.py @@ -0,0 +1,481 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +import sys +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import humanfriendly +import numpy as np +import torch +from tqdm import trange +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.fileio.sound_scp import SoundScpWriter +from espnet2.tasks.enh import EnhancementTask +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.utils import config_argparse +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_none + + +EPS = torch.finfo(torch.get_default_dtype()).eps + + +class SeparateSpeech: + """SeparateSpeech class + + Examples: + >>> import soundfile + >>> separate_speech = SeparateSpeech("enh_config.yml", "enh.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> separate_speech(audio) + [separated_audio1, separated_audio2, ...] + + """ + + def __init__( + self, + enh_train_config: Union[Path, str], + enh_model_file: Union[Path, str] = None, + segment_size: Optional[float] = None, + hop_size: Optional[float] = None, + normalize_segment_scale: bool = False, + show_progressbar: bool = False, + ref_channel: Optional[int] = None, + normalize_output_wav: bool = False, + device: str = "cpu", + dtype: str = "float32", + ): + assert check_argument_types() + + # 1. Build Enh model + enh_model, enh_train_args = EnhancementTask.build_model_from_file( + enh_train_config, enh_model_file, device + ) + enh_model.to(dtype=getattr(torch, dtype)).eval() + + self.device = device + self.dtype = dtype + self.enh_train_args = enh_train_args + self.enh_model = enh_model + + # only used when processing long speech, i.e. 
+ # segment_size is not None and hop_size is not None + self.segment_size = segment_size + self.hop_size = hop_size + self.normalize_segment_scale = normalize_segment_scale + self.normalize_output_wav = normalize_output_wav + self.show_progressbar = show_progressbar + + self.num_spk = enh_model.num_spk + task = "enhancement" if self.num_spk == 1 else "separation" + + # reference channel for processing multi-channel speech + if ref_channel is not None: + logging.info( + "Overwrite enh_model.separator.ref_channel with {}".format(ref_channel) + ) + enh_model.separator.ref_channel = ref_channel + self.ref_channel = ref_channel + else: + self.ref_channel = enh_model.ref_channel + + self.segmenting = segment_size is not None and hop_size is not None + if self.segmenting: + logging.info("Perform segment-wise speech %s" % task) + logging.info( + "Segment length = {} sec, hop length = {} sec".format( + segment_size, hop_size + ) + ) + else: + logging.info("Perform direct speech %s on the input" % task) + + @torch.no_grad() + def __call__( + self, speech_mix: Union[torch.Tensor, np.ndarray], fs: int = 8000 + ) -> List[torch.Tensor]: + """Inference + + Args: + speech_mix: Input speech data (Batch, Nsamples [, Channels]) + fs: sample rate + Returns: + [separated_audio1, separated_audio2, ...] + + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech_mix, np.ndarray): + speech_mix = torch.as_tensor(speech_mix) + + assert speech_mix.dim() > 1, speech_mix.size() + batch_size = speech_mix.size(0) + speech_mix = speech_mix.to(getattr(torch, self.dtype)) + # lenghts: (B,) + lengths = speech_mix.new_full( + [batch_size], dtype=torch.long, fill_value=speech_mix.size(1) + ) + + # a. To device + speech_mix = to_device(speech_mix, device=self.device) + lengths = to_device(lengths, device=self.device) + + if self.segmenting and lengths[0] > self.segment_size * fs: + # Segment-wise speech enhancement/separation + overlap_length = int(np.round(fs * (self.segment_size - self.hop_size))) + num_segments = int( + np.ceil((speech_mix.size(1) - overlap_length) / (self.hop_size * fs)) + ) + t = T = int(self.segment_size * fs) + pad_shape = speech_mix[:, :T].shape + enh_waves = [] + range_ = trange if self.show_progressbar else range + for i in range_(num_segments): + st = int(i * self.hop_size * fs) + en = st + T + if en >= lengths[0]: + # en - st < T (last segment) + en = lengths[0] + speech_seg = speech_mix.new_zeros(pad_shape) + t = en - st + speech_seg[:, :t] = speech_mix[:, st:en] + else: + t = T + speech_seg = speech_mix[:, st:en] # B x T [x C] + + lengths_seg = speech_mix.new_full( + [batch_size], dtype=torch.long, fill_value=T + ) + # b. Enhancement/Separation Forward + feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg) + feats, _, _ = self.enh_model.separator(feats, f_lens) + processed_wav = [ + self.enh_model.decoder(f, lengths_seg)[0] for f in feats + ] + if speech_seg.dim() > 2: + # multi-channel speech + speech_seg_ = speech_seg[:, self.ref_channel] + else: + speech_seg_ = speech_seg + + if self.normalize_segment_scale: + # normalize the energy of each separated stream + # to match the input energy + processed_wav = [ + self.normalize_scale(w, speech_seg_) for w in processed_wav + ] + # List[torch.Tensor(num_spk, B, T)] + enh_waves.append(torch.stack(processed_wav, dim=0)) + + # c. 
Stitch the enhanced segments together + waves = enh_waves[0] + for i in range(1, num_segments): + # permutation between separated streams in last and current segments + perm = self.cal_permumation( + waves[:, :, -overlap_length:], + enh_waves[i][:, :, :overlap_length], + criterion="si_snr", + ) + # repermute separated streams in current segment + for batch in range(batch_size): + enh_waves[i][:, batch] = enh_waves[i][perm[batch], batch] + + if i == num_segments - 1: + enh_waves[i][:, :, t:] = 0 + enh_waves_res_i = enh_waves[i][:, :, overlap_length:t] + else: + enh_waves_res_i = enh_waves[i][:, :, overlap_length:] + + # overlap-and-add (average over the overlapped part) + waves[:, :, -overlap_length:] = ( + waves[:, :, -overlap_length:] + enh_waves[i][:, :, :overlap_length] + ) / 2 + # concatenate the residual parts of the later segment + waves = torch.cat([waves, enh_waves_res_i], dim=2) + # ensure the stitched length is same as input + assert waves.size(2) == speech_mix.size(1), (waves.shape, speech_mix.shape) + waves = torch.unbind(waves, dim=0) + else: + # b. Enhancement/Separation Forward + feats, f_lens = self.enh_model.encoder(speech_mix, lengths) + feats, _, _ = self.enh_model.separator(feats, f_lens) + waves = [self.enh_model.decoder(f, lengths)[0] for f in feats] + + assert len(waves) == self.num_spk, len(waves) == self.num_spk + assert len(waves[0]) == batch_size, (len(waves[0]), batch_size) + if self.normalize_output_wav: + waves = [ + (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).cpu().numpy() + for w in waves + ] # list[(batch, sample)] + else: + waves = [w.cpu().numpy() for w in waves] + + return waves + + @staticmethod + @torch.no_grad() + def normalize_scale(enh_wav, ref_ch_wav): + """Normalize the energy of enh_wav to match that of ref_ch_wav. + + Args: + enh_wav (torch.Tensor): (B, Nsamples) + ref_ch_wav (torch.Tensor): (B, Nsamples) + Returns: + enh_wav (torch.Tensor): (B, Nsamples) + """ + ref_energy = torch.sqrt(torch.mean(ref_ch_wav.pow(2), dim=1)) + enh_energy = torch.sqrt(torch.mean(enh_wav.pow(2), dim=1)) + return enh_wav * (ref_energy / enh_energy)[:, None] + + @torch.no_grad() + def cal_permumation(self, ref_wavs, enh_wavs, criterion="si_snr"): + """Calculate the permutation between seaprated streams in two adjacent segments. 
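+
+        The permutation is estimated on the overlapping region only: each
+        ordering of the separated streams is scored with the given criterion
+        and the best one is kept, so that the speaker order stays consistent
+        from one segment to the next.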
+ + Args: + ref_wavs (List[torch.Tensor]): [(Batch, Nsamples)] + enh_wavs (List[torch.Tensor]): [(Batch, Nsamples)] + criterion (str): one of ("si_snr", "mse", "corr) + Returns: + perm (torch.Tensor): permutation for enh_wavs (Batch, num_spk) + """ + loss_func = { + "si_snr": self.enh_model.si_snr_loss, + "mse": lambda enh, ref: torch.mean((enh - ref).pow(2), dim=1), + "corr": lambda enh, ref: ( + (enh * ref).sum(dim=1) + / (enh.pow(2).sum(dim=1) * ref.pow(2).sum(dim=1) + EPS) + ).clamp(min=EPS, max=1 - EPS), + }[criterion] + + _, perm = self.enh_model._permutation_loss(ref_wavs, enh_wavs, loss_func) + return perm + + +def humanfriendly_or_none(value: str): + if value in ("none", "None", "NONE"): + return None + return humanfriendly.parse_size(value) + + +def inference( + output_dir: str, + batch_size: int, + dtype: str, + fs: int, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + enh_train_config: str, + enh_model_file: str, + allow_variable_data_keys: bool, + segment_size: Optional[float], + hop_size: Optional[float], + normalize_segment_scale: bool, + show_progressbar: bool, + ref_channel: Optional[int], + normalize_output_wav: bool, +): + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build separate_speech + separate_speech = SeparateSpeech( + enh_train_config=enh_train_config, + enh_model_file=enh_model_file, + segment_size=segment_size, + hop_size=hop_size, + normalize_segment_scale=normalize_segment_scale, + show_progressbar=show_progressbar, + ref_channel=ref_channel, + normalize_output_wav=normalize_output_wav, + device=device, + dtype=dtype, + ) + + # 3. Build data-iterator + loader = EnhancementTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=EnhancementTask.build_preprocess_fn( + separate_speech.enh_train_args, False + ), + collate_fn=EnhancementTask.build_collate_fn( + separate_speech.enh_train_args, False + ), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 4. Start for-loop + writers = [] + for i in range(separate_speech.num_spk): + writers.append( + SoundScpWriter(f"{output_dir}/wavs/{i + 1}", f"{output_dir}/spk{i + 1}.scp") + ) + + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} + + waves = separate_speech(**batch) + for (spk, w) in enumerate(waves): + for b in range(batch_size): + writers[spk][keys[b]] = fs, w[b] + + for writer in writers: + writer.close() + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Frontend inference", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. 
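+    # Illustrative invocation (paths and the "sound" filetype are placeholders;
+    # the data name must be "speech_mix" to match SeparateSpeech.__call__):
+    #   python -m espnet2.bin.enh_inference \
+    #     --enh_train_config exp/enh_train/config.yaml \
+    #     --enh_model_file exp/enh_train/valid.loss.best.pth \
+    #     --data_path_and_name_and_type dump/test/wav.scp,speech_mix,sound \
+    #     --output_dir exp/enh_inference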
+ parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--fs", type=humanfriendly_or_none, default=8000, help="Sampling rate" + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("Output data related") + group.add_argument( + "--normalize_output_wav", + type=str2bool, + default=False, + help="Whether to normalize the predicted wav to [-1~1]", + ) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--enh_train_config", type=str, required=True) + group.add_argument("--enh_model_file", type=str, required=True) + + group = parser.add_argument_group("Data loading related") + group.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + group = parser.add_argument_group("SeparateSpeech related") + group.add_argument( + "--segment_size", + type=float, + default=None, + help="Segment length in seconds for segment-wise speech enhancement/separation", + ) + group.add_argument( + "--hop_size", + type=float, + default=None, + help="Hop length in seconds for segment-wise speech enhancement/separation", + ) + group.add_argument( + "--normalize_segment_scale", + type=str2bool, + default=False, + help="Whether to normalize the energy of the separated streams in each segment", + ) + group.add_argument( + "--show_progressbar", + type=str2bool, + default=False, + help="Whether to show a progress bar when performing segment-wise speech " + "enhancement/separation", + ) + group.add_argument( + "--ref_channel", + type=int, + default=None, + help="If not None, this will overwrite the ref_channel defined in the " + "separator module (for multi-channel speech processing)", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/enh_scoring.py b/espnet2/bin/enh_scoring.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a42fdb07c5bc162749422ca969fa9029ebac4 --- /dev/null +++ b/espnet2/bin/enh_scoring.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +import argparse +import logging +import sys +from typing import List +from typing import Union + +from mir_eval.separation import bss_eval_sources +import numpy as np +from pystoi import stoi +import torch +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.enh.espnet_model import ESPnetEnhancementModel +from 
espnet2.fileio.datadir_writer import DatadirWriter +from espnet2.fileio.sound_scp import SoundScpReader +from espnet2.utils import config_argparse + + +def scoring( + output_dir: str, + dtype: str, + log_level: Union[int, str], + key_file: str, + ref_scp: List[str], + inf_scp: List[str], + ref_channel: int, +): + assert check_argument_types() + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + assert len(ref_scp) == len(inf_scp), ref_scp + num_spk = len(ref_scp) + + keys = [ + line.rstrip().split(maxsplit=1)[0] for line in open(key_file, encoding="utf-8") + ] + + ref_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in ref_scp] + inf_readers = [SoundScpReader(f, dtype=dtype, normalize=True) for f in inf_scp] + + # get sample rate + sample_rate, _ = ref_readers[0][keys[0]] + + # check keys + for inf_reader, ref_reader in zip(inf_readers, ref_readers): + assert inf_reader.keys() == ref_reader.keys() + + with DatadirWriter(output_dir) as writer: + for key in keys: + ref_audios = [ref_reader[key][1] for ref_reader in ref_readers] + inf_audios = [inf_reader[key][1] for inf_reader in inf_readers] + ref = np.array(ref_audios) + inf = np.array(inf_audios) + if ref.ndim > inf.ndim: + # multi-channel reference and single-channel output + ref = ref[..., ref_channel] + assert ref.shape == inf.shape, (ref.shape, inf.shape) + elif ref.ndim < inf.ndim: + # single-channel reference and multi-channel output + raise ValueError( + "Reference must be multi-channel when the \ + network output is multi-channel." + ) + elif ref.ndim == inf.ndim == 3: + # multi-channel reference and output + ref = ref[..., ref_channel] + inf = inf[..., ref_channel] + + sdr, sir, sar, perm = bss_eval_sources(ref, inf, compute_permutation=True) + + for i in range(num_spk): + stoi_score = stoi(ref[i], inf[int(perm[i])], fs_sig=sample_rate) + si_snr_score = -float( + ESPnetEnhancementModel.si_snr_loss( + torch.from_numpy(ref[i][None, ...]), + torch.from_numpy(inf[int(perm[i])][None, ...]), + ) + ) + writer[f"STOI_spk{i + 1}"][key] = str(stoi_score) + writer[f"SI_SNR_spk{i + 1}"][key] = str(si_snr_score) + writer[f"SDR_spk{i + 1}"][key] = str(sdr[i]) + writer[f"SAR_spk{i + 1}"][key] = str(sar[i]) + writer[f"SIR_spk{i + 1}"][key] = str(sir[i]) + # save permutation assigned script file + writer[f"wav_spk{i + 1}"][key] = inf_readers[perm[i]].data[key] + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Frontend inference", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. 
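+    # Illustrative invocation (paths are placeholders); repeat --ref_scp/--inf_scp
+    # once per speaker:
+    #   python -m espnet2.bin.enh_scoring \
+    #     --key_file dump/test/wav.scp \
+    #     --ref_scp dump/test/spk1.scp --inf_scp exp/enh_inference/spk1.scp \
+    #     --ref_scp dump/test/spk2.scp --inf_scp exp/enh_inference/spk2.scp \
+    #     --output_dir exp/enh_scoring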
+ + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--ref_scp", + type=str, + required=True, + action="append", + ) + group.add_argument( + "--inf_scp", + type=str, + required=True, + action="append", + ) + group.add_argument("--key_file", type=str) + group.add_argument("--ref_channel", type=int, default=0) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + scoring(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/enh_train.py b/espnet2/bin/enh_train.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4708eb87eb939596be0dd8cfbfaa1c832cf0b3 --- /dev/null +++ b/espnet2/bin/enh_train.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +from espnet2.tasks.enh import EnhancementTask + + +def get_parser(): + parser = EnhancementTask.get_parser() + return parser + + +def main(cmd=None): + r"""Enhancemnet frontend training. + + Example: + + % python enh_train.py asr --print_config --optim adadelta \ + > conf/train_enh.yaml + % python enh_train.py --config conf/train_enh.yaml + """ + EnhancementTask.main(cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/launch.py b/espnet2/bin/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c86f9b7dab514fadc76034f4026664b5af7f9e --- /dev/null +++ b/espnet2/bin/launch.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python3 +import argparse +import logging +import os +from pathlib import Path +import shlex +import shutil +import subprocess +import sys +import uuid + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Launch distributed process with appropriate options. ", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--cmd", + help="The path of cmd script of Kaldi: run.pl. queue.pl, or slurm.pl", + default="utils/run.pl", + ) + parser.add_argument( + "--log", + help="The path of log file used by cmd", + default="run.log", + ) + parser.add_argument( + "--max_num_log_files", + help="The maximum number of log-files to be kept", + default=1000, + ) + parser.add_argument( + "--ngpu", type=int, default=1, help="The number of GPUs per node" + ) + egroup = parser.add_mutually_exclusive_group() + egroup.add_argument("--num_nodes", type=int, default=1, help="The number of nodes") + egroup.add_argument( + "--host", + type=str, + default=None, + help="Directly specify the host names. The job are submitted via SSH. " + "Multiple host names can be specified by splitting by comma. e.g. host1,host2" + " You can also the device id after the host name with ':'. e.g. " + "host1:0:2:3,host2:0:2. If the device ids are specified in this way, " + "the value of --ngpu is ignored.", + ) + parser.add_argument( + "--envfile", + type=str_or_none, + default="path.sh", + help="Source the shell script before executing command. 
" + "This option is used when --host is specified.", + ) + + parser.add_argument( + "--multiprocessing_distributed", + type=str2bool, + default=True, + help="Distributed method is used when single-node mode.", + ) + parser.add_argument( + "--master_port", + type=int, + default=None, + help="Specify the port number of master" + "Master is a host machine has RANK0 process.", + ) + parser.add_argument( + "--master_addr", + type=str, + default=None, + help="Specify the address s of master. " + "Master is a host machine has RANK0 process.", + ) + parser.add_argument( + "--init_file_prefix", + type=str, + default=".dist_init_", + help="The file name prefix for init_file, which is used for " + "'Shared-file system initialization'. " + "This option is used when --port is not specified", + ) + parser.add_argument("args", type=str, nargs="+") + return parser + + +def main(cmd=None): + logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s" + logging.basicConfig(level=logging.INFO, format=logfmt) + logging.info(get_commandline_args()) + + parser = get_parser() + args = parser.parse_args(cmd) + args.cmd = shlex.split(args.cmd) + + if args.host is None and shutil.which(args.cmd[0]) is None: + raise RuntimeError( + f"The first args of --cmd should be a script path. e.g. utils/run.pl: " + f"{args.cmd[0]}" + ) + + # Specify init_method: + # See: https://pytorch.org/docs/stable/distributed.html#initialization + if args.host is None and args.num_nodes <= 1: + # Automatically set init_method if num_node=1 + init_method = None + else: + if args.master_port is None: + # Try "shared-file system initialization" if master_port is not specified + # Give random name to avoid reusing previous file + init_file = args.init_file_prefix + str(uuid.uuid4()) + init_file = Path(init_file).absolute() + Path(init_file).parent.mkdir(exist_ok=True, parents=True) + init_method = ["--dist_init_method", f"file://{init_file}"] + else: + init_method = ["--dist_master_port", str(args.master_port)] + + # This can be omitted if slurm mode + if args.master_addr is not None: + init_method += ["--dist_master_addr", args.master_addr] + elif args.host is not None: + init_method += [ + "--dist_master_addr", + args.host.split(",")[0].split(":")[0], + ] + + # Log-rotation + for i in range(args.max_num_log_files - 1, -1, -1): + if i == 0: + p = Path(args.log) + pn = p.parent / (p.stem + ".1" + p.suffix) + else: + _p = Path(args.log) + p = _p.parent / (_p.stem + f".{i}" + _p.suffix) + pn = _p.parent / (_p.stem + f".{i + 1}" + _p.suffix) + + if p.exists(): + if i == args.max_num_log_files - 1: + p.unlink() + else: + shutil.move(p, pn) + + processes = [] + # Submit command via SSH + if args.host is not None: + hosts = [] + ids_list = [] + # e.g. 
args.host = "host1:0:2,host2:0:1" + for host in args.host.split(","): + # e.g host = "host1:0:2" + sps = host.split(":") + host = sps[0] + if len(sps) > 1: + ids = [int(x) for x in sps[1:]] + else: + ids = list(range(args.ngpu)) + hosts.append(host) + ids_list.append(ids) + + world_size = sum(max(len(x), 1) for x in ids_list) + logging.info(f"{len(hosts)}nodes with world_size={world_size} via SSH") + + if args.envfile is not None: + env = f"source {args.envfile}" + else: + env = "" + + if args.log != "-": + Path(args.log).parent.mkdir(parents=True, exist_ok=True) + f = Path(args.log).open("w", encoding="utf-8") + else: + # Output to stdout/stderr + f = None + + rank = 0 + for host, ids in zip(hosts, ids_list): + ngpu = 1 if len(ids) > 0 else 0 + ids = ids if len(ids) > 0 else ["none"] + + for local_rank in ids: + cmd = ( + args.args + + [ + "--ngpu", + str(ngpu), + "--multiprocessing_distributed", + "false", + "--local_rank", + str(local_rank), + "--dist_rank", + str(rank), + "--dist_world_size", + str(world_size), + ] + + init_method + ) + if ngpu == 0: + # Gloo supports both GPU and CPU mode. + # See: https://pytorch.org/docs/stable/distributed.html + cmd += ["--dist_backend", "gloo"] + + heredoc = f"""<< EOF +set -euo pipefail +cd {os.getcwd()} +{env} +{" ".join([c if len(c) != 0 else "''" for c in cmd])} +EOF +""" + + # FIXME(kamo): The process will be alive + # even if this program is stopped because we don't set -t here, + # i.e. not assigning pty, + # and the program is not killed when SSH connection is closed. + process = subprocess.Popen( + ["ssh", host, "bash", heredoc], + stdout=f, + stderr=f, + ) + + processes.append(process) + + rank += 1 + + # If Single node + elif args.num_nodes <= 1: + if args.ngpu > 1: + if args.multiprocessing_distributed: + # NOTE: + # If multiprocessing_distributed=true, + # -> Distributed mode, which is multi-process and Multi-GPUs. + # and TCP initializetion is used if single-node case: + # e.g. init_method="tcp://localhost:20000" + logging.info(f"single-node with {args.ngpu}gpu on distributed mode") + else: + # NOTE: + # If multiprocessing_distributed=false + # -> "DataParallel" mode, which is single-process + # and Multi-GPUs with threading. 
+ # See: + # https://discuss.pytorch.org/t/why-torch-nn-parallel-distributeddataparallel-runs-faster-than-torch-nn-dataparallel-on-single-machine-with-multi-gpu/32977/2 + logging.info(f"single-node with {args.ngpu}gpu using DataParallel") + + # Using cmd as it is simply + cmd = ( + args.cmd + # arguments for ${cmd} + + ["--gpu", str(args.ngpu), args.log] + # arguments for *_train.py + + args.args + + [ + "--ngpu", + str(args.ngpu), + "--multiprocessing_distributed", + str(args.multiprocessing_distributed), + ] + ) + process = subprocess.Popen(cmd) + processes.append(process) + + elif Path(args.cmd[0]).name == "run.pl": + raise RuntimeError("run.pl doesn't support submitting to the other nodes.") + + elif Path(args.cmd[0]).name == "ssh.pl": + raise RuntimeError("Use --host option instead of ssh.pl") + + # If Slurm + elif Path(args.cmd[0]).name == "slurm.pl": + logging.info(f"{args.num_nodes}nodes and {args.ngpu}gpu-per-node using srun") + cmd = ( + args.cmd + # arguments for ${cmd} + + [ + "--gpu", + str(args.ngpu), + "--num_threads", + str(max(args.ngpu, 1)), + "--num_nodes", + str(args.num_nodes), + args.log, + "srun", + # Inherit all enviroment variable from parent process + "--export=ALL", + ] + # arguments for *_train.py + + args.args + + [ + "--ngpu", + str(args.ngpu), + "--multiprocessing_distributed", + "true", + "--dist_launcher", + "slurm", + ] + + init_method + ) + if args.ngpu == 0: + # Gloo supports both GPU and CPU mode. + # See: https://pytorch.org/docs/stable/distributed.html + cmd += ["--dist_backend", "gloo"] + process = subprocess.Popen(cmd) + processes.append(process) + + else: + # This pattern can also works with Slurm. + + logging.info(f"{args.num_nodes}nodes and {args.ngpu}gpu-per-node using mpirun") + cmd = ( + args.cmd + # arguments for ${cmd} + + [ + "--gpu", + str(args.ngpu), + "--num_threads", + str(max(args.ngpu, 1)), + # Make sure scheduler setting, i.e. conf/queue.conf + # so that --num_nodes requires 1process-per-node + "--num_nodes", + str(args.num_nodes), + args.log, + "mpirun", + # -np option can be omitted with Torque/PBS + "-np", + str(args.num_nodes), + ] + # arguments for *_train.py + + args.args + + [ + "--ngpu", + str(args.ngpu), + "--multiprocessing_distributed", + "true", + "--dist_launcher", + "mpi", + ] + + init_method + ) + if args.ngpu == 0: + # Gloo supports both GPU and CPU mode. 
+ # See: https://pytorch.org/docs/stable/distributed.html + cmd += ["--dist_backend", "gloo"] + process = subprocess.Popen(cmd) + processes.append(process) + + logging.info(f"log file: {args.log}") + + failed = False + while any(p.returncode is None for p in processes): + for process in processes: + # If any process is failed, try to kill the other processes too + if failed and process.returncode is not None: + process.kill() + else: + try: + process.wait(0.5) + except subprocess.TimeoutExpired: + pass + + if process.returncode is not None and process.returncode != 0: + failed = True + + for process in processes: + if process.returncode != 0: + print( + subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd), + file=sys.stderr, + ) + p = Path(args.log) + if p.exists(): + with p.open() as f: + lines = list(f) + raise RuntimeError( + f"\n################### The last 1000 lines of {args.log} " + f"###################\n" + "".join(lines[-1000:]) + ) + else: + raise RuntimeError + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/lm_calc_perplexity.py b/espnet2/bin/lm_calc_perplexity.py new file mode 100644 index 0000000000000000000000000000000000000000..97ba229afe365ff46a09459e13f78ce8a93dd2cf --- /dev/null +++ b/espnet2/bin/lm_calc_perplexity.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python3 +import argparse +import logging +from pathlib import Path +import sys +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from torch.nn.parallel import data_parallel +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.fileio.datadir_writer import DatadirWriter +from espnet2.tasks.lm import LMTask +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.forward_adaptor import ForwardAdaptor +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.utils import config_argparse +from espnet2.utils.types import float_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_none + + +def calc_perplexity( + output_dir: str, + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + train_config: Optional[str], + model_file: Optional[str], + log_base: Optional[float], + allow_variable_data_keys: bool, +): + assert check_argument_types() + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build LM + model, train_args = LMTask.build_model_from_file(train_config, model_file, device) + # Wrape model to make model.nll() data-parallel + wrapped_model = ForwardAdaptor(model, "nll") + wrapped_model.to(dtype=getattr(torch, dtype)).eval() + logging.info(f"Model:\n{model}") + + # 3. Build data-iterator + loader = LMTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=LMTask.build_preprocess_fn(train_args, False), + collate_fn=LMTask.build_collate_fn(train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 4. 
Start for-loop + with DatadirWriter(output_dir) as writer: + total_nll = 0.0 + total_ntokens = 0 + for keys, batch in loader: + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert len(keys) == _bs, f"{len(keys)} != {_bs}" + + with torch.no_grad(): + batch = to_device(batch, device) + if ngpu <= 1: + # NOTE(kamo): data_parallel also should work with ngpu=1, + # but for debuggability it's better to keep this block. + nll, lengths = wrapped_model(**batch) + else: + nll, lengths = data_parallel( + wrapped_model, (), range(ngpu), module_kwargs=batch + ) + + assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths)) + # nll: (B, L) -> (B,) + nll = nll.detach().cpu().numpy().sum(1) + # lengths: (B,) + lengths = lengths.detach().cpu().numpy() + total_nll += nll.sum() + total_ntokens += lengths.sum() + + for key, _nll, ntoken in zip(keys, nll, lengths): + if log_base is None: + utt_ppl = np.exp(_nll / ntoken) + else: + utt_ppl = log_base ** (_nll / ntoken / np.log(log_base)) + + # Write PPL of each utts for debugging or analysis + writer["utt2ppl"][key] = str(utt_ppl) + writer["utt2ntokens"][key] = str(ntoken) + + if log_base is None: + ppl = np.exp(total_nll / total_ntokens) + else: + ppl = log_base ** (total_nll / total_ntokens / np.log(log_base)) + + with (Path(output_dir) / "ppl").open("w", encoding="utf-8") as f: + f.write(f"{ppl}\n") + with (Path(output_dir) / "base").open("w", encoding="utf-8") as f: + if log_base is None: + _log_base = np.e + else: + _log_base = log_base + f.write(f"{_log_base}\n") + logging.info(f"PPL={ppl}") + + +def get_parser(): + parser = config_argparse.ArgumentParser( + description="Calc perplexity", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use '_' instead of '-' as separator. + # '-' is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--output_dir", type=str, required=True) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + parser.add_argument("--seed", type=int, default=0, help="Random seed") + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + parser.add_argument( + "--log_base", + type=float_or_none, + default=None, + help="The base of logarithm for Perplexity. 
" + "If None, napier's constant is used.", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument("--key_file", type=str_or_none) + group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) + + group = parser.add_argument_group("The model configuration related") + group.add_argument("--train_config", type=str) + group.add_argument("--model_file", type=str) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + calc_perplexity(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/lm_train.py b/espnet2/bin/lm_train.py new file mode 100644 index 0000000000000000000000000000000000000000..f60e9f3b89162f347a44e95265f3ba4c9d615fb1 --- /dev/null +++ b/espnet2/bin/lm_train.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +from espnet2.tasks.lm import LMTask + + +def get_parser(): + parser = LMTask.get_parser() + return parser + + +def main(cmd=None): + """LM training. + + Example: + + % python lm_train.py asr --print_config --optim adadelta + % python lm_train.py --config conf/train_asr.yaml + """ + LMTask.main(cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/pack.py b/espnet2/bin/pack.py new file mode 100644 index 0000000000000000000000000000000000000000..b152ba6ee76bcd99dcfd491f9d116f79acdb2354 --- /dev/null +++ b/espnet2/bin/pack.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 +import argparse +from typing import Type + +from espnet2.main_funcs.pack_funcs import pack + + +class PackedContents: + files = [] + yaml_files = [] + + +class ASRPackedContents(PackedContents): + # These names must be consistent with the argument of inference functions + files = ["asr_model_file", "lm_file"] + yaml_files = ["asr_train_config", "lm_train_config"] + + +class TTSPackedContents(PackedContents): + files = ["model_file"] + yaml_files = ["train_config"] + + +class EnhPackedContents(PackedContents): + files = ["model_file"] + yaml_files = ["train_config"] + + +def add_arguments(parser: argparse.ArgumentParser, contents: Type[PackedContents]): + parser.add_argument("--outpath", type=str, required=True) + for key in contents.yaml_files: + parser.add_argument(f"--{key}", type=str, default=None) + for key in contents.files: + parser.add_argument(f"--{key}", type=str, default=None) + parser.add_argument("--option", type=str, action="append", default=[]) + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Pack input files to archive format") + subparsers = parser.add_subparsers() + + # Create subparser for ASR + for name, contents in [ + ("asr", ASRPackedContents), + ("tts", TTSPackedContents), + ("enh", EnhPackedContents), + ]: + parser_asr = subparsers.add_parser( + name, + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + add_arguments(parser_asr, contents) + parser_asr.set_defaults(contents=contents) + return parser + + +def main(cmd=None): + parser = get_parser() + args = parser.parse_args(cmd) + if not hasattr(args, "contents"): + parser.print_help() + parser.exit(2) + + yaml_files = { + y: getattr(args, y) + for y in args.contents.yaml_files + if getattr(args, y) is not None + } + files = { + y: getattr(args, y) for y in args.contents.files if getattr(args, y) is not None + } + pack( + 
yaml_files=yaml_files, + files=files, + option=args.option, + outpath=args.outpath, + ) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/split_scps.py b/espnet2/bin/split_scps.py new file mode 100644 index 0000000000000000000000000000000000000000..557c70bac2c1741bf16595c9258ffd3ee7c21bd8 --- /dev/null +++ b/espnet2/bin/split_scps.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +import argparse +from collections import Counter +from itertools import zip_longest +import logging +from pathlib import Path +import sys +from typing import List +from typing import Optional + +from espnet.utils.cli_utils import get_commandline_args + + +def split_scps( + scps: List[str], + num_splits: int, + names: Optional[List[str]], + output_dir: str, + log_level: str, +): + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + if num_splits < 2: + raise RuntimeError(f"{num_splits} must be more than 1") + + if names is None: + names = [Path(s).name for s in scps] + if len(set(names)) != len(names): + raise RuntimeError(f"names are duplicated: {names}") + + for name in names: + (Path(output_dir) / name).mkdir(parents=True, exist_ok=True) + + scp_files = [open(s, "r", encoding="utf-8") for s in scps] + # Remove existing files + for n in range(num_splits): + for name in names: + if (Path(output_dir) / name / f"split.{n}").exists(): + (Path(output_dir) / name / f"split.{n}").unlink() + + counter = Counter() + linenum = -1 + for linenum, lines in enumerate(zip_longest(*scp_files)): + if any(line is None for line in lines): + raise RuntimeError("Number of lines are mismatched") + + prev_key = None + for line in lines: + key = line.rstrip().split(maxsplit=1)[0] + if prev_key is not None and prev_key != key: + raise RuntimeError("Not sorted or not having same keys") + + # Select a piece from split texts alternatively + num = linenum % num_splits + counter[num] += 1 + # Write lines respectively + for line, name in zip(lines, names): + # To reduce the number of opened file descriptors, open now + with (Path(output_dir) / name / f"split.{num}").open( + "a", encoding="utf-8" + ) as f: + f.write(line) + + if linenum + 1 < num_splits: + raise RuntimeError( + f"The number of lines is less than num_splits: {linenum + 1} < {num_splits}" + ) + + for name in names: + with (Path(output_dir) / name / "num_splits").open("w", encoding="utf-8") as f: + f.write(str(num_splits)) + logging.info(f"N lines of split text: {set(counter.values())}") + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Split scp files", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument("--scps", required=True, help="Input texts", nargs="+") + parser.add_argument("--names", help="Output names for each files", nargs="+") + parser.add_argument("--num_splits", help="Split number", type=int) + parser.add_argument("--output_dir", required=True, help="Output directory") + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + split_scps(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/tokenize_text.py b/espnet2/bin/tokenize_text.py new file mode 100644 index 
0000000000000000000000000000000000000000..d2ff08cb60be95e63880534d5be98eceeb4ee45b --- /dev/null +++ b/espnet2/bin/tokenize_text.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +import argparse +from collections import Counter +import logging +from pathlib import Path +import sys +from typing import List +from typing import Optional + +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.text.build_tokenizer import build_tokenizer +from espnet2.text.cleaner import TextCleaner +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + + +def field2slice(field: Optional[str]) -> slice: + """Convert field string to slice + + Note that field string accepts 1-based integer. + + Examples: + >>> field2slice("1-") + slice(0, None, None) + >>> field2slice("1-3") + slice(0, 3, None) + >>> field2slice("-3") + slice(None, 3, None) + + """ + field = field.strip() + try: + if "-" in field: + # e.g. "2-" or "2-5" or "-7" + s1, s2 = field.split("-", maxsplit=1) + if s1.strip() == "": + s1 = None + else: + s1 = int(s1) + if s1 == 0: + raise ValueError("1-based string") + if s2.strip() == "": + s2 = None + else: + s2 = int(s2) + else: + # e.g. "2" + s1 = int(field) + s2 = s1 + 1 + if s1 == 0: + raise ValueError("must be 1 or more value") + except ValueError: + raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}") + + # -1 because of 1-based integer following "cut" command + # e.g "1-3" -> slice(0, 3) + slic = slice(s1 - 1, s2) + return slic + + +def tokenize( + input: str, + output: str, + field: Optional[str], + delimiter: Optional[str], + token_type: str, + space_symbol: str, + non_linguistic_symbols: Optional[str], + bpemodel: Optional[str], + log_level: str, + write_vocabulary: bool, + vocabulary_size: int, + remove_non_linguistic_symbols: bool, + cutoff: int, + add_symbol: List[str], + cleaner: Optional[str], + g2p: Optional[str], +): + assert check_argument_types() + + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + if input == "-": + fin = sys.stdin + else: + fin = Path(input).open("r", encoding="utf-8") + if output == "-": + fout = sys.stdout + else: + p = Path(output) + p.parent.mkdir(parents=True, exist_ok=True) + fout = p.open("w", encoding="utf-8") + + cleaner = TextCleaner(cleaner) + tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + remove_non_linguistic_symbols=remove_non_linguistic_symbols, + g2p_type=g2p, + ) + + counter = Counter() + if field is not None: + field = field2slice(field) + + for line in fin: + line = line.rstrip() + if field is not None: + # e.g. field="2-" + # uttidA hello world!! -> hello world!! 
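+                # field2slice("2-") -> slice(1, None), so the 1-based column spec
+                # drops the utterance id and keeps the remaining columns.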
+ tokens = line.split(delimiter) + tokens = tokens[field] + if delimiter is None: + line = " ".join(tokens) + else: + line = delimiter.join(tokens) + + line = cleaner(line) + tokens = tokenizer.text2tokens(line) + if not write_vocabulary: + fout.write(" ".join(tokens) + "\n") + else: + for t in tokens: + counter[t] += 1 + + if not write_vocabulary: + return + + # ======= write_vocabulary mode from here ======= + # Sort by the number of occurrences in descending order + # and filter lower frequency words than cutoff value + words_and_counts = list( + filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1])) + ) + # Restrict the vocabulary size + if vocabulary_size > 0: + if vocabulary_size < len(add_symbol): + raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}") + words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)] + + # Parse the values of --add_symbol + for symbol_and_id in add_symbol: + # e.g symbol=":0" + try: + symbol, idx = symbol_and_id.split(":") + idx = int(idx) + except ValueError: + raise RuntimeError(f"Format error: e.g. ':0': {symbol_and_id}") + symbol = symbol.strip() + + # e.g. idx=0 -> append as the first symbol + # e.g. idx=-1 -> append as the last symbol + if idx < 0: + idx = len(words_and_counts) + 1 + idx + words_and_counts.insert(idx, (symbol, None)) + + # Write words + for w, c in words_and_counts: + fout.write(w + "\n") + + # Logging + total_count = sum(counter.values()) + invocab_count = sum(c for w, c in words_and_counts if c is not None) + logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %") + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Tokenize texts", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument( + "--input", "-i", required=True, help="Input text. - indicates sys.stdin" + ) + parser.add_argument( + "--output", "-o", required=True, help="Output text. - indicates sys.stdout" + ) + parser.add_argument( + "--field", + "-f", + help="The target columns of the input text as 1-based integer. 
e.g 2-", + ) + parser.add_argument( + "--token_type", + "-t", + default="char", + choices=["char", "bpe", "word", "phn"], + help="Token type", + ) + parser.add_argument("--delimiter", "-d", default=None, help="The delimiter") + parser.add_argument("--space_symbol", default="", help="The space symbol") + parser.add_argument("--bpemodel", default=None, help="The bpemodel file path") + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--remove_non_linguistic_symbols", + type=str2bool, + default=False, + help="Remove non-language-symbols from tokens", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=[ + None, + "g2p_en", + "g2p_en_no_space", + "pyopenjtalk", + "pyopenjtalk_kana", + "pyopenjtalk_accent", + "pyopenjtalk_accent_with_pause", + "pypinyin_g2p", + "pypinyin_g2p_phone", + "espeak_ng_arabic", + ], + default=None, + help="Specify g2p method if --token_type=phn", + ) + + group = parser.add_argument_group("write_vocabulary mode related") + group.add_argument( + "--write_vocabulary", + type=str2bool, + default=False, + help="Write tokens list instead of tokenized text per line", + ) + group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size") + group.add_argument( + "--cutoff", + default=0, + type=int, + help="cut-off frequency used for write-vocabulary mode", + ) + group.add_argument( + "--add_symbol", + type=str, + default=[], + action="append", + help="Append symbol e.g. --add_symbol ':0' --add_symbol ':1'", + ) + + return parser + + +def main(cmd=None): + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + tokenize(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/tts_inference.py b/espnet2/bin/tts_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e592fe3f18cd1d697524bb1b6bc7c1e9ed65f2 --- /dev/null +++ b/espnet2/bin/tts_inference.py @@ -0,0 +1,579 @@ +#!/usr/bin/env python3 + +"""TTS mode decoding.""" + +import argparse +import logging +from pathlib import Path +import shutil +import sys +import time +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from collections import defaultdict +import json + +import matplotlib +import numpy as np +import soundfile as sf +import torch +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.fileio.npy_scp import NpyScpWriter +from espnet2.tasks.tts import TTSTask +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.tts.duration_calculator import DurationCalculator +from espnet2.tts.fastspeech import FastSpeech +from espnet2.tts.fastspeech2 import FastSpeech2 +from espnet2.tts.fastespeech import FastESpeech +from espnet2.tts.tacotron2 import Tacotron2 +from espnet2.tts.transformer import Transformer +from espnet2.utils import config_argparse +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.griffin_lim import Spectrogram2Waveform +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import str2bool +from espnet2.utils.types import 
str2triple_str +from espnet2.utils.types import str_or_none + + +class Text2Speech: + """Speech2Text class + + Examples: + >>> import soundfile + >>> text2speech = Text2Speech("config.yml", "model.pth") + >>> wav = text2speech("Hello World")[0] + >>> soundfile.write("out.wav", wav.numpy(), text2speech.fs, "PCM_16") + + """ + + def __init__( + self, + train_config: Optional[Union[Path, str]], + model_file: Optional[Union[Path, str]] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 10.0, + use_teacher_forcing: bool = False, + use_att_constraint: bool = False, + backward_window: int = 1, + forward_window: int = 3, + speed_control_alpha: float = 1.0, + vocoder_conf: dict = None, + dtype: str = "float32", + device: str = "cpu", + ): + assert check_argument_types() + + model, train_args = TTSTask.build_model_from_file( + train_config, model_file, device + ) + model.to(dtype=getattr(torch, dtype)).eval() + self.device = device + self.dtype = dtype + self.train_args = train_args + self.model = model + self.tts = model.tts + self.normalize = model.normalize + self.feats_extract = model.feats_extract + self.duration_calculator = DurationCalculator() + self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False) + self.use_teacher_forcing = use_teacher_forcing + + logging.info(f"Normalization:\n{self.normalize}") + logging.info(f"TTS:\n{self.tts}") + + decode_config = {} + if isinstance(self.tts, (Tacotron2, Transformer)): + decode_config.update( + { + "threshold": threshold, + "maxlenratio": maxlenratio, + "minlenratio": minlenratio, + } + ) + if isinstance(self.tts, Tacotron2): + decode_config.update( + { + "use_att_constraint": use_att_constraint, + "forward_window": forward_window, + "backward_window": backward_window, + } + ) + if isinstance(self.tts, (FastSpeech, FastSpeech2, FastESpeech)): + decode_config.update({"alpha": speed_control_alpha}) + decode_config.update({"use_teacher_forcing": use_teacher_forcing}) + + self.decode_config = decode_config + + if vocoder_conf is None: + vocoder_conf = {} + if self.feats_extract is not None: + vocoder_conf.update(self.feats_extract.get_parameters()) + if ( + "n_fft" in vocoder_conf + and "n_shift" in vocoder_conf + and "fs" in vocoder_conf + ): + self.spc2wav = Spectrogram2Waveform(**vocoder_conf) + logging.info(f"Vocoder: {self.spc2wav}") + else: + self.spc2wav = None + logging.info("Vocoder is not used because vocoder_conf is not sufficient") + + @torch.no_grad() + def __call__( + self, + text: Union[str, torch.Tensor, np.ndarray], + speech: Union[torch.Tensor, np.ndarray] = None, + durations: Union[torch.Tensor, np.ndarray] = None, + ref_embs: torch.Tensor = None, + spembs: Union[torch.Tensor, np.ndarray] = None, # new addition + fg_inds: torch.Tensor = None, + ): + assert check_argument_types() + + if self.use_speech and speech is None: + raise RuntimeError("missing required argument: 'speech'") + + if isinstance(text, str): + # str -> np.ndarray + text = self.preprocess_fn("", {"text": text})["text"] + batch = {"text": text, "ref_embs": ref_embs, "ar_prior_inference": True, "fg_inds": fg_inds} # TC marker + if speech is not None: + batch["speech"] = speech + if durations is not None: + batch["durations"] = durations + if spembs is not None: + batch["spembs"] = spembs + + batch = to_device(batch, self.device) + outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss = self.model.inference( + **batch, **self.decode_config + ) + + if att_ws is not None: + duration, focus_rate = 
self.duration_calculator(att_ws) + else: + duration, focus_rate = None, None + + if self.spc2wav is not None: + wav = torch.tensor(self.spc2wav(outs_denorm.cpu().numpy())) + else: + wav = None + + return wav, outs, outs_denorm, probs, att_ws, duration, focus_rate, ref_embs + + @property + def fs(self) -> Optional[int]: + if self.spc2wav is not None: + return self.spc2wav.fs + else: + return None + + @property + def use_speech(self) -> bool: + """Check whether to require speech in inference. + + Returns: + bool: True if speech is required else False. + + """ + # TC marker, oorspr false -> set false for test_ref_embs, but true if testing wo duration + return self.use_teacher_forcing or getattr(self.tts, "use_gst", False) + + +def inference( + output_dir: str, + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + train_config: Optional[str], + model_file: Optional[str], + ref_embs: Optional[str], + threshold: float, + minlenratio: float, + maxlenratio: float, + use_teacher_forcing: bool, + use_att_constraint: bool, + backward_window: int, + forward_window: int, + speed_control_alpha: float, + allow_variable_data_keys: bool, + vocoder_conf: dict, +): + """Perform TTS model decoding.""" + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch decoding is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU decoding is supported") + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if len(ref_embs) > 0: + ref_emb_in = torch.load(ref_embs).squeeze(0) + else: + ref_emb_in = None + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build model + text2speech = Text2Speech( + train_config=train_config, + model_file=model_file, + threshold=threshold, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + use_teacher_forcing=use_teacher_forcing, + use_att_constraint=use_att_constraint, + backward_window=backward_window, + forward_window=forward_window, + speed_control_alpha=speed_control_alpha, + vocoder_conf=vocoder_conf, + dtype=dtype, + device=device, + ) + + # 3. Build data-iterator + if not text2speech.use_speech: + data_path_and_name_and_type = list( + filter(lambda x: x[1] != "speech", data_path_and_name_and_type) + ) + loader = TTSTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False), + collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + # 6. 
Start for-loop + output_dir = Path(output_dir) + (output_dir / "norm").mkdir(parents=True, exist_ok=True) + (output_dir / "denorm").mkdir(parents=True, exist_ok=True) + (output_dir / "speech_shape").mkdir(parents=True, exist_ok=True) + (output_dir / "wav").mkdir(parents=True, exist_ok=True) + (output_dir / "att_ws").mkdir(parents=True, exist_ok=True) + (output_dir / "probs").mkdir(parents=True, exist_ok=True) + (output_dir / "durations").mkdir(parents=True, exist_ok=True) + (output_dir / "focus_rates").mkdir(parents=True, exist_ok=True) + + # Lazy load to avoid the backend error + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + + with NpyScpWriter( + output_dir / "norm", + output_dir / "norm/feats.scp", + ) as norm_writer, NpyScpWriter( + output_dir / "denorm", output_dir / "denorm/feats.scp" + ) as denorm_writer, open( + output_dir / "speech_shape/speech_shape", "w" + ) as shape_writer, open( + output_dir / "durations/durations", "w" + ) as duration_writer, open( + output_dir / "focus_rates/focus_rates", "w" + ) as focus_rate_writer, open( + output_dir / "ref_embs", "w" + ) as ref_embs_writer: + ref_embs_list = [] + ref_embs_dict = defaultdict(list) + for idx, (keys, batch) in enumerate(loader, 1): + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert _bs == 1, _bs + + # Change to single sequence and remove *_length + # because inference() requires 1-seq, not mini-batch. + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + + start_time = time.perf_counter() + wav, outs, outs_denorm, probs, att_ws, duration, focus_rate, \ + ref_embs = text2speech(ref_embs=ref_emb_in, **batch) + + key = keys[0] + insize = next(iter(batch.values())).size(0) + 1 + logging.info( + "inference speed = {:.1f} frames / sec.".format( + int(outs.size(0)) / (time.perf_counter() - start_time) + ) + ) + logging.info(f"{key} (size:{insize}->{outs.size(0)})") + if outs.size(0) == insize * maxlenratio: + logging.warning(f"output length reaches maximum length ({key}).") + + norm_writer[key] = outs.cpu().numpy() + shape_writer.write(f"{key} " + ",".join(map(str, outs.shape)) + "\n") + + denorm_writer[key] = outs_denorm.cpu().numpy() + + if duration is not None: + # Save duration and fucus rates + duration_writer.write( + f"{key} " + " ".join(map(str, duration.cpu().numpy())) + "\n" + ) + focus_rate_writer.write(f"{key} {float(focus_rate):.5f}\n") + + # Plot attention weight + att_ws = att_ws.cpu().numpy() + + if att_ws.ndim == 2: + att_ws = att_ws[None][None] + elif att_ws.ndim != 4: + raise RuntimeError(f"Must be 2 or 4 dimension: {att_ws.ndim}") + + w, h = plt.figaspect(att_ws.shape[0] / att_ws.shape[1]) + fig = plt.Figure( + figsize=( + w * 1.3 * min(att_ws.shape[0], 2.5), + h * 1.3 * min(att_ws.shape[1], 2.5), + ) + ) + fig.suptitle(f"{key}") + axes = fig.subplots(att_ws.shape[0], att_ws.shape[1]) + if len(att_ws) == 1: + axes = [[axes]] + for ax, att_w in zip(axes, att_ws): + for ax_, att_w_ in zip(ax, att_w): + ax_.imshow(att_w_.astype(np.float32), aspect="auto") + ax_.set_xlabel("Input") + ax_.set_ylabel("Output") + ax_.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax_.yaxis.set_major_locator(MaxNLocator(integer=True)) + + fig.set_tight_layout({"rect": [0, 0.03, 1, 0.95]}) + fig.savefig(output_dir / f"att_ws/{key}.png") + fig.clf() + + if probs is not None: + # Plot stop token prediction + probs = probs.cpu().numpy() + + fig = plt.Figure() 
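+ # Descriptive note: for the autoregressive models (Tacotron 2 / Transformer) the
+ # per-frame stop-token probabilities returned by inference are plotted here and
+ # saved under probs/<key>.png so decoding termination can be inspected visually.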
+ ax = fig.add_subplot(1, 1, 1) + ax.plot(probs) + ax.set_title(f"{key}") + ax.set_xlabel("Output") + ax.set_ylabel("Stop probability") + ax.set_ylim(0, 1) + ax.grid(which="both") + + fig.set_tight_layout(True) + fig.savefig(output_dir / f"probs/{key}.png") + fig.clf() + + # TODO(kamo): Write scp + if wav is not None: + sf.write( + f"{output_dir}/wav/{key}.wav", wav.numpy(), text2speech.fs, "PCM_16" + ) + + if ref_embs is not None: + ref_emb_key = -1 + for index, ref_emb in enumerate(ref_embs_list): + if torch.equal(ref_emb, ref_embs): + ref_emb_key = index + if ref_emb_key == -1: + ref_emb_key = len(ref_embs_list) + ref_embs_list.append(ref_embs) + ref_embs_dict[ref_emb_key].append(key) + + ref_embs_writer.write(json.dumps(ref_embs_dict)) + for index, ref_emb in enumerate(ref_embs_list): + filename = "ref_embs_" + str(index) + ".pt" + torch.save(ref_emb, output_dir / filename) + + # remove duration related files if attention is not provided + if att_ws is None: + shutil.rmtree(output_dir / "att_ws") + shutil.rmtree(output_dir / "durations") + shutil.rmtree(output_dir / "focus_rates") + if probs is None: + shutil.rmtree(output_dir / "probs") + + +def get_parser(): + """Get argument parser.""" + parser = config_argparse.ArgumentParser( + description="TTS Decode", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use "_" instead of "-" as separator. + # "-" is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="The path of output directory", + ) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 
0 indicates CPU mode", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed", + ) + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument( + "--key_file", + type=str_or_none, + ) + group.add_argument( + "--allow_variable_data_keys", + type=str2bool, + default=False, + ) + group.add_argument( + "--ref_embs", + type=str, + default=False, + ) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--train_config", + type=str, + help="Training configuration file.", + ) + group.add_argument( + "--model_file", + type=str, + help="Model parameter file.", + ) + + group = parser.add_argument_group("Decoding related") + group.add_argument( + "--maxlenratio", + type=float, + default=10.0, + help="Maximum length ratio in decoding", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Minimum length ratio in decoding", + ) + group.add_argument( + "--threshold", + type=float, + default=0.5, + help="Threshold value in decoding", + ) + group.add_argument( + "--use_att_constraint", + type=str2bool, + default=False, + help="Whether to use attention constraint", + ) + group.add_argument( + "--backward_window", + type=int, + default=1, + help="Backward window value in attention constraint", + ) + group.add_argument( + "--forward_window", + type=int, + default=3, + help="Forward window value in attention constraint", + ) + group.add_argument( + "--use_teacher_forcing", + type=str2bool, + default=False, + help="Whether to use teacher forcing", + ) + parser.add_argument( + "--speed_control_alpha", + type=float, + default=1.0, + help="Alpha in FastSpeech to change the speed of generated speech", + ) + + group = parser.add_argument_group("Grriffin-Lim related") + group.add_argument( + "--vocoder_conf", + action=NestedDictAction, + default=get_default_kwargs(Spectrogram2Waveform), + help="The configuration for Grriffin-Lim", + ) + return parser + + +def main(cmd=None): + """Run TTS model decoding.""" + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + inference(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/tts_prior_train.py b/espnet2/bin/tts_prior_train.py new file mode 100644 index 0000000000000000000000000000000000000000..24dea4416abfb9256debb9032f66ba35ee901a08 --- /dev/null +++ b/espnet2/bin/tts_prior_train.py @@ -0,0 +1,463 @@ +#!/usr/bin/env python3 + +"""TTS model AR prior training.""" + +import argparse +import logging +from pathlib import Path +import sys +import time +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet.utils.cli_utils import get_commandline_args +from espnet2.tasks.tts import TTSTask +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed 
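+# This script freezes the pretrained TTS model and fits only the autoregressive
+# prior of the prosody encoder (see train_prior() below).
+# Example invocation (a sketch only: the paths are placeholders, while the flags
+# correspond to get_parser() defined later in this file):
+#   python -m espnet2.bin.tts_prior_train \
+#       --train_config <exp_dir>/config.yaml \
+#       --model_file <exp_dir>/model.pth \
+#       --data_path_and_name_and_type <dump_dir>/text,text,text \
+#       --data_path_and_name_and_type <dump_dir>/wav.scp,speech,sound \
+#       --output_dir <exp_dir>/ar_prior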
+from espnet2.tts.duration_calculator import DurationCalculator +from espnet2.tts.fastspeech import FastSpeech +from espnet2.tts.fastspeech2 import FastSpeech2 +from espnet2.tts.fastespeech import FastESpeech +from espnet2.tts.tacotron2 import Tacotron2 +from espnet2.tts.transformer import Transformer +from espnet2.utils import config_argparse +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.griffin_lim import Spectrogram2Waveform +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_none + +from espnet2.tts.prosody_encoder import ARPrior + +import torch.optim as optim + + +class Text2Speech: + """Speech2Text class + """ + + def __init__( + self, + train_config: Optional[Union[Path, str]], + model_file: Optional[Union[Path, str]] = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 10.0, + use_teacher_forcing: bool = False, + use_att_constraint: bool = False, + backward_window: int = 1, + forward_window: int = 3, + speed_control_alpha: float = 1.0, + vocoder_conf: dict = None, + dtype: str = "float32", + device: str = "cpu", + ): + assert check_argument_types() + + model, train_args = TTSTask.build_model_from_file( + train_config, model_file, device + ) + model.to(dtype=getattr(torch, dtype)).eval() + self.device = device + self.dtype = dtype + self.train_args = train_args + self.model = model + self.tts = model.tts + self.normalize = model.normalize + self.feats_extract = model.feats_extract + self.duration_calculator = DurationCalculator() + self.preprocess_fn = TTSTask.build_preprocess_fn(train_args, False) + self.use_teacher_forcing = use_teacher_forcing + + logging.info(f"Normalization:\n{self.normalize}") + logging.info(f"TTS:\n{self.tts}") + + decode_config = {} + if isinstance(self.tts, (Tacotron2, Transformer)): + decode_config.update( + { + "threshold": threshold, + "maxlenratio": maxlenratio, + "minlenratio": minlenratio, + } + ) + if isinstance(self.tts, Tacotron2): + decode_config.update( + { + "use_att_constraint": use_att_constraint, + "forward_window": forward_window, + "backward_window": backward_window, + } + ) + if isinstance(self.tts, (FastSpeech, FastSpeech2, FastESpeech)): + decode_config.update({"alpha": speed_control_alpha}) + decode_config.update({"use_teacher_forcing": use_teacher_forcing}) + + self.decode_config = decode_config + + if vocoder_conf is None: + vocoder_conf = {} + if self.feats_extract is not None: + vocoder_conf.update(self.feats_extract.get_parameters()) + if ( + "n_fft" in vocoder_conf + and "n_shift" in vocoder_conf + and "fs" in vocoder_conf + ): + self.spc2wav = Spectrogram2Waveform(**vocoder_conf) + logging.info(f"Vocoder: {self.spc2wav}") + else: + self.spc2wav = None + logging.info("Vocoder is not used because vocoder_conf is not sufficient") + + def __call__( + self, + text: Union[str, torch.Tensor, np.ndarray], + speech: Union[torch.Tensor, np.ndarray] = None, + durations: Union[torch.Tensor, np.ndarray] = None, + ref_embs: torch.Tensor = None, + ): + assert check_argument_types() + + if self.use_speech and speech is None: + raise RuntimeError("missing required argument: 'speech'") + + if isinstance(text, str): + # str -> np.ndarray + text = self.preprocess_fn("", {"text": text})["text"] + batch = {"text": text, "ref_embs": ref_embs} + if speech is not None: + batch["speech"] = speech + if durations is not None: + batch["durations"] = 
durations + + batch = to_device(batch, self.device) + outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss = \ + self.model.inference(**batch, **self.decode_config, train_ar_prior=True) + + return ar_prior_loss + + @property + def fs(self) -> Optional[int]: + if self.spc2wav is not None: + return self.spc2wav.fs + else: + return None + + @property + def use_speech(self) -> bool: + """Check whether to require speech in inference. + + Returns: + bool: True if speech is required else False. + + """ + # TC marker, oorspr false + return self.use_teacher_forcing or getattr(self.tts, "use_gst", True) + + +def train_prior( + output_dir: str, + batch_size: int, + dtype: str, + ngpu: int, + seed: int, + num_workers: int, + log_level: Union[int, str], + data_path_and_name_and_type: Sequence[Tuple[str, str, str]], + key_file: Optional[str], + train_config: Optional[str], + model_file: Optional[str], + threshold: float, + minlenratio: float, + maxlenratio: float, + use_teacher_forcing: bool, + use_att_constraint: bool, + backward_window: int, + forward_window: int, + speed_control_alpha: float, + allow_variable_data_keys: bool, + vocoder_conf: dict, +): + """Perform AR prior training.""" + assert check_argument_types() + if batch_size > 1: + raise NotImplementedError("batch AR prior training is not implemented") + if ngpu > 1: + raise NotImplementedError("only single GPU AR prior training is supported") + logging.basicConfig( + level=log_level, + format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + + if ngpu >= 1: + device = "cuda" + else: + device = "cpu" + + # 1. Set random-seed + set_all_random_seed(seed) + + # 2. Build model + text2speech = Text2Speech( + train_config=train_config, + model_file=model_file, + threshold=threshold, + maxlenratio=maxlenratio, + minlenratio=minlenratio, + use_teacher_forcing=use_teacher_forcing, + use_att_constraint=use_att_constraint, + backward_window=backward_window, + forward_window=forward_window, + speed_control_alpha=speed_control_alpha, + vocoder_conf=vocoder_conf, + dtype=dtype, + device=device, + ) + + # 3. 
Build data-iterator + if not text2speech.use_speech: + data_path_and_name_and_type = list( + filter(lambda x: x[1] != "speech", data_path_and_name_and_type) + ) + loader = TTSTask.build_streaming_iterator( + data_path_and_name_and_type, + dtype=dtype, + batch_size=batch_size, + key_file=key_file, + num_workers=num_workers, + preprocess_fn=TTSTask.build_preprocess_fn(text2speech.train_args, False), + collate_fn=TTSTask.build_collate_fn(text2speech.train_args, False), + allow_variable_data_keys=allow_variable_data_keys, + inference=True, + ) + + num_epochs = 500 + + # Freeze model + for param in text2speech.model.parameters(): + param.requires_grad = False + + text2speech.model.tts.prosody_encoder.ar_prior = ARPrior( + num_embeddings=32, + embedding_dim=384, + lstm_num_layers=1, + lstm_bidirectional=False, + ) + + text2speech.model.tts = text2speech.model.tts.to(device) + + optimizer = optim.SGD(text2speech.model.tts.parameters(), lr=0.001, momentum=0.9) + + since = time.time() + + for epoch in range(num_epochs): + print('Epoch {}/{}'.format(epoch, num_epochs - 1)) + print('-' * 10) + + # Each epoch has a training and validation phase + for phase in ['train']: # 'val' + if phase == 'train': + text2speech.model.tts.train() # Set model to training mode + else: + text2speech.model.tts.eval() # Set model to evaluate mode + + for idx, (keys, batch) in enumerate(loader, 1): + assert isinstance(batch, dict), type(batch) + assert all(isinstance(s, str) for s in keys), keys + _bs = len(next(iter(batch.values()))) + assert _bs == 1, _bs + + # Change to single sequence and remove *_length + # because inference() requires 1-seq, not mini-batch. + batch = { + k: v[0] for k, v in batch.items() if not k.endswith("_lengths") + } + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + # track history if only in train + with torch.set_grad_enabled(phase == 'train'): + loss = text2speech(**batch) + + # backward + optimize only if in training phase + if phase == 'train': + loss.backward() + optimizer.step() + + print('Loss: {:.4f}'.format(loss)) + + if epoch % 10 == 0: + torch.save(text2speech.model.state_dict(), "exp/tts_train_raw_phn_none/with_prior_" + str(epoch) + ".pth") + + time_elapsed = time.time() - since + print('Training complete in {:.0f}m {:.0f}s'.format( + time_elapsed // 60, time_elapsed % 60)) + + torch.save(text2speech.model.state_dict(), "exp/tts_train_raw_phn_none/with_prior.pth") + + +def get_parser(): + """Get argument parser.""" + parser = config_argparse.ArgumentParser( + description="TTS Decode", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Note(kamo): Use "_" instead of "-" as separator. + # "-" is confusing if written in yaml. + parser.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="The path of output directory", + ) + parser.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 
0 indicates CPU mode", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed", + ) + parser.add_argument( + "--dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type", + ) + parser.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="The batch size for inference", + ) + + group = parser.add_argument_group("Input data related") + group.add_argument( + "--data_path_and_name_and_type", + type=str2triple_str, + required=True, + action="append", + ) + group.add_argument( + "--key_file", + type=str_or_none, + ) + group.add_argument( + "--allow_variable_data_keys", + type=str2bool, + default=False, + ) + + group = parser.add_argument_group("The model configuration related") + group.add_argument( + "--train_config", + type=str, + help="Training configuration file.", + ) + group.add_argument( + "--model_file", + type=str, + help="Model parameter file.", + ) + + group = parser.add_argument_group("Decoding related") + group.add_argument( + "--maxlenratio", + type=float, + default=10.0, + help="Maximum length ratio in decoding", + ) + group.add_argument( + "--minlenratio", + type=float, + default=0.0, + help="Minimum length ratio in decoding", + ) + group.add_argument( + "--threshold", + type=float, + default=0.5, + help="Threshold value in decoding", + ) + group.add_argument( + "--use_att_constraint", + type=str2bool, + default=False, + help="Whether to use attention constraint", + ) + group.add_argument( + "--backward_window", + type=int, + default=1, + help="Backward window value in attention constraint", + ) + group.add_argument( + "--forward_window", + type=int, + default=3, + help="Forward window value in attention constraint", + ) + group.add_argument( + "--use_teacher_forcing", + type=str2bool, + default=False, + help="Whether to use teacher forcing", + ) + parser.add_argument( + "--speed_control_alpha", + type=float, + default=1.0, + help="Alpha in FastSpeech to change the speed of generated speech", + ) + + group = parser.add_argument_group("Grriffin-Lim related") + group.add_argument( + "--vocoder_conf", + action=NestedDictAction, + default=get_default_kwargs(Spectrogram2Waveform), + help="The configuration for Grriffin-Lim", + ) + return parser + + +def main(cmd=None): + """Run TTS model decoding.""" + print(get_commandline_args(), file=sys.stderr) + parser = get_parser() + args = parser.parse_args(cmd) + kwargs = vars(args) + kwargs.pop("config", None) + train_prior(**kwargs) + + +if __name__ == "__main__": + main() diff --git a/espnet2/bin/tts_train.py b/espnet2/bin/tts_train.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf487b8f1ddabe514477aab6894733fb8672a66 --- /dev/null +++ b/espnet2/bin/tts_train.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +from espnet2.tasks.tts import TTSTask + + +def get_parser(): + parser = TTSTask.get_parser() + return parser + + +def main(cmd=None): + """TTS training + + Example: + + % python tts_train.py asr --print_config --optim adadelta + % python tts_train.py --config conf/train_asr.yaml + """ + TTSTask.main(cmd=cmd) + + +if __name__ == "__main__": + main() diff --git a/espnet2/diar/__init__.py b/espnet2/diar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/diar/abs_diar.py b/espnet2/diar/abs_diar.py new file mode 100644 index 
0000000000000000000000000000000000000000..9cb2f2b2cc25e8870bae5a6b644962129d63b790 --- /dev/null +++ b/espnet2/diar/abs_diar.py @@ -0,0 +1,26 @@ +from abc import ABC +from abc import abstractmethod +from collections import OrderedDict +from typing import Tuple + +import torch + + +class AbsDiarization(torch.nn.Module, ABC): + # @abstractmethod + # def output_size(self) -> int: + # raise NotImplementedError + + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, OrderedDict]: + raise NotImplementedError + + @abstractmethod + def forward_rawwav( + self, input: torch.Tensor, ilens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, OrderedDict]: + raise NotImplementedError diff --git a/espnet2/diar/decoder/__init__.py b/espnet2/diar/decoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/diar/decoder/abs_decoder.py b/espnet2/diar/decoder/abs_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9a167414437fe0c55159908a6a4f54d8a2cb7d --- /dev/null +++ b/espnet2/diar/decoder/abs_decoder.py @@ -0,0 +1,20 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsDecoder(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @property + @abstractmethod + def num_spk(self): + raise NotImplementedError diff --git a/espnet2/diar/decoder/linear_decoder.py b/espnet2/diar/decoder/linear_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..61a3dea0aa73921f99dccdf83fd6a35f6d5a1585 --- /dev/null +++ b/espnet2/diar/decoder/linear_decoder.py @@ -0,0 +1,32 @@ +import torch + +from espnet2.diar.decoder.abs_decoder import AbsDecoder + + +class LinearDecoder(AbsDecoder): + """Linear decoder for speaker diarization """ + + def __init__( + self, + encoder_output_size: int, + num_spk: int = 2, + ): + super().__init__() + self._num_spk = num_spk + self.linear_decoder = torch.nn.Linear(encoder_output_size, num_spk) + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. 
+ + Args: + input (torch.Tensor): hidden_space [Batch, T, F] + ilens (torch.Tensor): input lengths [Batch] + """ + + output = self.linear_decoder(input) + + return output + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/diar/espnet_model.py b/espnet2/diar/espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cf923c8b7dca931c3e42124da1240538ae050403 --- /dev/null +++ b/espnet2/diar/espnet_model.py @@ -0,0 +1,281 @@ +# Copyright 2021 Jiatong Shi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +from contextlib import contextmanager +from distutils.version import LooseVersion +from itertools import permutations +from typing import Dict +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import to_device +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.diar.decoder.abs_decoder import AbsDecoder +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel + + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetDiarizationModel(AbsESPnetModel): + """Speaker Diarization model""" + + def __init__( + self, + frontend: Optional[AbsFrontend], + normalize: Optional[AbsNormalize], + label_aggregator: torch.nn.Module, + encoder: AbsEncoder, + decoder: AbsDecoder, + loss_type: str = "pit", # only support pit loss for now + ): + assert check_argument_types() + + super().__init__() + + self.encoder = encoder + self.decoder = decoder + self.num_spk = decoder.num_spk + self.normalize = normalize + self.frontend = frontend + self.label_aggregator = label_aggregator + self.loss_type = loss_type + + def forward( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor = None, + spk_labels: torch.Tensor = None, + spk_labels_lengths: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech: (Batch, samples) + speech_lengths: (Batch,) default None for chunk interator, + because the chunk-iterator does not + have the speech_lengths returned. + see in + espnet2/iterators/chunk_iter_factory.py + spk_labels: (Batch, ) + """ + assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape) + batch_size = speech.shape[0] + + # 1. Encoder + encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) + + # 2. Decoder (baiscally a predction layer after encoder_out) + pred = self.decoder(encoder_out, encoder_out_lens) + + # 3. 
Aggregate time-domain labels + spk_labels, spk_labels_lengths = self.label_aggregator( + spk_labels, spk_labels_lengths + ) + + if self.loss_type == "pit": + loss, perm_idx, perm_list, label_perm = self.pit_loss( + pred, spk_labels, encoder_out_lens + ) + + ( + correct, + num_frames, + speech_scored, + speech_miss, + speech_falarm, + speaker_scored, + speaker_miss, + speaker_falarm, + speaker_error, + ) = self.calc_diarization_error(pred, label_perm, encoder_out_lens) + + if speech_scored > 0 and num_frames > 0: + sad_mr, sad_fr, mi, fa, cf, acc, der = ( + speech_miss / speech_scored, + speech_falarm / speech_scored, + speaker_miss / speaker_scored, + speaker_falarm / speaker_scored, + speaker_error / speaker_scored, + correct / num_frames, + (speaker_miss + speaker_falarm + speaker_error) / speaker_scored, + ) + else: + sad_mr, sad_fr, mi, fa, cf, acc, der = 0, 0, 0, 0, 0, 0, 0 + stats = dict( + loss=loss.detach(), + sad_mr=sad_mr, + sad_fr=sad_fr, + mi=mi, + fa=fa, + cf=cf, + acc=acc, + der=der, + ) + else: + raise NotImplementedError + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def collect_feats( + self, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + spk_labels: torch.Tensor = None, + spk_labels_lengths: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + return {"feats": feats, "feats_lengths": feats_lengths} + + def encode( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Frontend + Encoder + + Args: + speech: (Batch, Length, ...) + speech_lengths: (Batch,) + """ + with autocast(False): + # 1. Extract feats + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + + # 2. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + + # 3. Forward encoder + # feats: (Batch, Length, Dim) + # -> encoder_out: (Batch, Length2, Dim) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + + assert encoder_out.size(0) == speech.size(0), ( + encoder_out.size(), + speech.size(0), + ) + assert encoder_out.size(1) <= encoder_out_lens.max(), ( + encoder_out.size(), + encoder_out_lens.max(), + ) + + return encoder_out, encoder_out_lens + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = speech.shape[0] + speech_lengths = ( + speech_lengths + if speech_lengths is not None + else torch.ones(batch_size).int() * speech.shape[1] + ) + + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. 
STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def pit_loss_single_permute(self, pred, label, length): + bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + mask = self.create_length_mask(length, label.size(1), label.size(2)) + loss = bce_loss(pred, label) + loss = loss * mask + loss = torch.sum(torch.mean(loss, dim=2), dim=1) + loss = torch.unsqueeze(loss, dim=1) + return loss + + def pit_loss(self, pred, label, lengths): + # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND + num_output = label.size(2) + permute_list = [np.array(p) for p in permutations(range(num_output))] + loss_list = [] + for p in permute_list: + label_perm = label[:, :, p] + loss_perm = self.pit_loss_single_permute(pred, label_perm, lengths) + loss_list.append(loss_perm) + loss = torch.cat(loss_list, dim=1) + min_loss, min_idx = torch.min(loss, dim=1) + loss = torch.sum(min_loss) / torch.sum(lengths.float()) + batch_size = len(min_idx) + label_list = [] + for i in range(batch_size): + label_list.append(label[i, :, permute_list[min_idx[i]]].data.cpu().numpy()) + label_permute = torch.from_numpy(np.array(label_list)).float() + return loss, min_idx, permute_list, label_permute + + def create_length_mask(self, length, max_len, num_output): + batch_size = len(length) + mask = torch.zeros(batch_size, max_len, num_output) + for i in range(batch_size): + mask[i, : length[i], :] = 1 + mask = to_device(self, mask) + return mask + + @staticmethod + def calc_diarization_error(pred, label, length): + # Note (jiatong): Credit to https://github.com/hitachi-speech/EEND + + (batch_size, max_len, num_output) = label.size() + # mask the padding part + mask = np.zeros((batch_size, max_len, num_output)) + for i in range(batch_size): + mask[i, : length[i], :] = 1 + + # pred and label have the shape (batch_size, max_len, num_output) + label_np = label.data.cpu().numpy().astype(int) + pred_np = (pred.data.cpu().numpy() > 0).astype(int) + label_np = label_np * mask + pred_np = pred_np * mask + length = length.data.cpu().numpy() + + # compute speech activity detection error + n_ref = np.sum(label_np, axis=2) + n_sys = np.sum(pred_np, axis=2) + speech_scored = float(np.sum(n_ref > 0)) + speech_miss = float(np.sum(np.logical_and(n_ref > 0, n_sys == 0))) + speech_falarm = float(np.sum(np.logical_and(n_ref == 0, n_sys > 0))) + + # compute speaker diarization error + speaker_scored = float(np.sum(n_ref)) + speaker_miss = float(np.sum(np.maximum(n_ref - n_sys, 0))) + speaker_falarm = float(np.sum(np.maximum(n_sys - n_ref, 0))) + n_map = np.sum(np.logical_and(label_np == 1, pred_np == 1), axis=2) + speaker_error = float(np.sum(np.minimum(n_ref, n_sys) - n_map)) + correct = float(1.0 * np.sum((label_np == pred_np) * mask) / num_output) + num_frames = np.sum(length) + return ( + correct, + num_frames, + speech_scored, + speech_miss, + speech_falarm, + speaker_scored, + speaker_miss, + speaker_falarm, + speaker_error, + ) diff --git a/espnet2/diar/label_processor.py b/espnet2/diar/label_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..3c981d997447ee6f91f4b9e3b5f6172304ed9559 --- /dev/null +++ b/espnet2/diar/label_processor.py @@ -0,0 +1,29 @@ +import torch + +from espnet2.layers.label_aggregation import 
LabelAggregate + + +class LabelProcessor(torch.nn.Module): + """Label aggregator for speaker diarization """ + + def __init__( + self, win_length: int = 512, hop_length: int = 128, center: bool = True + ): + super().__init__() + self.label_aggregator = LabelAggregate(win_length, hop_length, center) + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. + + Args: + input: (Batch, Nsamples, Label_dim) + ilens: (Batch) + Returns: + output: (Batch, Frames, Label_dim) + olens: (Batch) + + """ + + output, olens = self.label_aggregator(input, ilens) + + return output, olens diff --git a/espnet2/enh/__init__.py b/espnet2/enh/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/enh/abs_enh.py b/espnet2/enh/abs_enh.py new file mode 100644 index 0000000000000000000000000000000000000000..c28745e26d1089e47e8ec9d28c4bb627cc70ba64 --- /dev/null +++ b/espnet2/enh/abs_enh.py @@ -0,0 +1,26 @@ +from abc import ABC +from abc import abstractmethod +from collections import OrderedDict +from typing import Tuple + +import torch + + +class AbsEnhancement(torch.nn.Module, ABC): + # @abstractmethod + # def output_size(self) -> int: + # raise NotImplementedError + + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, OrderedDict]: + raise NotImplementedError + + @abstractmethod + def forward_rawwav( + self, input: torch.Tensor, ilens: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, OrderedDict]: + raise NotImplementedError diff --git a/espnet2/enh/decoder/abs_decoder.py b/espnet2/enh/decoder/abs_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab8cb6a557f19d2a19b37de7ba02851513fd0c4 --- /dev/null +++ b/espnet2/enh/decoder/abs_decoder.py @@ -0,0 +1,15 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsDecoder(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/enh/decoder/conv_decoder.py b/espnet2/enh/decoder/conv_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..aad83a243e8f66baf243d4fe41464324d769a6db --- /dev/null +++ b/espnet2/enh/decoder/conv_decoder.py @@ -0,0 +1,32 @@ +import torch + +from espnet2.enh.decoder.abs_decoder import AbsDecoder + + +class ConvDecoder(AbsDecoder): + """Transposed Convolutional decoder for speech enhancement and separation """ + + def __init__( + self, + channel: int, + kernel_size: int, + stride: int, + ): + super().__init__() + self.convtrans1d = torch.nn.ConvTranspose1d( + channel, 1, kernel_size, bias=False, stride=stride + ) + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. 
+ + Args: + input (torch.Tensor): spectrum [Batch, T, F] + ilens (torch.Tensor): input lengths [Batch] + """ + input = input.transpose(1, 2) + batch_size = input.shape[0] + wav = self.convtrans1d(input, output_size=(batch_size, 1, ilens.max())) + wav = wav.squeeze(1) + + return wav, ilens diff --git a/espnet2/enh/decoder/null_decoder.py b/espnet2/enh/decoder/null_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d62776a9e7e2d1d384e59f7a4d4f37bf263e66bd --- /dev/null +++ b/espnet2/enh/decoder/null_decoder.py @@ -0,0 +1,19 @@ +import torch + +from espnet2.enh.decoder.abs_decoder import AbsDecoder + + +class NullDecoder(AbsDecoder): + """Null decoder, return the same args.""" + + def __init__(self): + super().__init__() + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. The input should be the waveform already. + + Args: + input (torch.Tensor): wav [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + """ + return input, ilens diff --git a/espnet2/enh/decoder/stft_decoder.py b/espnet2/enh/decoder/stft_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..329e8b771fe1dac219bc751a19f0d415bc22947a --- /dev/null +++ b/espnet2/enh/decoder/stft_decoder.py @@ -0,0 +1,44 @@ +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.decoder.abs_decoder import AbsDecoder +from espnet2.layers.stft import Stft + + +class STFTDecoder(AbsDecoder): + """STFT decoder for speech enhancement and separation """ + + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window="hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + super().__init__() + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + def forward(self, input: ComplexTensor, ilens: torch.Tensor): + """Forward. 
+ + Args: + input (ComplexTensor): spectrum [Batch, T, F] + ilens (torch.Tensor): input lengths [Batch] + """ + if not isinstance(input, ComplexTensor): + raise TypeError("Only support ComplexTensor for stft decoder") + + wav, wav_lens = self.stft.inverse(input, ilens) + + return wav, wav_lens diff --git a/espnet2/enh/encoder/abs_encoder.py b/espnet2/enh/encoder/abs_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1afb68213b670e1ad5cb7135ade64603e80b0b --- /dev/null +++ b/espnet2/enh/encoder/abs_encoder.py @@ -0,0 +1,20 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsEncoder(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @property + @abstractmethod + def output_dim(self) -> int: + raise NotImplementedError diff --git a/espnet2/enh/encoder/conv_encoder.py b/espnet2/enh/encoder/conv_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a51b11daf0e9e758479fd7e54572a4b7ad5114 --- /dev/null +++ b/espnet2/enh/encoder/conv_encoder.py @@ -0,0 +1,47 @@ +import torch + +from espnet2.enh.encoder.abs_encoder import AbsEncoder + + +class ConvEncoder(AbsEncoder): + """Convolutional encoder for speech enhancement and separation """ + + def __init__( + self, + channel: int, + kernel_size: int, + stride: int, + ): + super().__init__() + self.conv1d = torch.nn.Conv1d( + 1, channel, kernel_size=kernel_size, stride=stride, bias=False + ) + self.stride = stride + self.kernel_size = kernel_size + + self._output_dim = channel + + @property + def output_dim(self) -> int: + return self._output_dim + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. + + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + Returns: + feature (torch.Tensor): mixed feature after encoder [Batch, flens, channel] + """ + assert input.dim() == 2, "Currently only support single channle input" + + input = torch.unsqueeze(input, 1) + + feature = self.conv1d(input) + feature = torch.nn.functional.relu(feature) + feature = feature.transpose(1, 2) + + flens = (ilens - self.kernel_size) // self.stride + 1 + + return feature, flens diff --git a/espnet2/enh/encoder/null_encoder.py b/espnet2/enh/encoder/null_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..4914c352c1c89f12875e221242025ef6b2368ad9 --- /dev/null +++ b/espnet2/enh/encoder/null_encoder.py @@ -0,0 +1,23 @@ +import torch + +from espnet2.enh.encoder.abs_encoder import AbsEncoder + + +class NullEncoder(AbsEncoder): + """Null encoder. """ + + def __init__(self): + super().__init__() + + @property + def output_dim(self) -> int: + return 1 + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. 
+ + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + """ + return input, ilens diff --git a/espnet2/enh/encoder/stft_encoder.py b/espnet2/enh/encoder/stft_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a81f07b2257b395b3a6990d42dd0ec9d0fb6c710 --- /dev/null +++ b/espnet2/enh/encoder/stft_encoder.py @@ -0,0 +1,51 @@ +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.encoder.abs_encoder import AbsEncoder +from espnet2.layers.stft import Stft + + +class STFTEncoder(AbsEncoder): + """STFT encoder for speech enhancement and separation """ + + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window="hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + super().__init__() + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self._output_dim = n_fft // 2 + 1 if onesided else n_fft + + @property + def output_dim(self) -> int: + return self._output_dim + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. + + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + Returns: + stft spectrum (torch.ComplexTensor): (Batch, Frames, Freq) + or (Batch, Frames, Channels, Freq) + """ + spectrum, flens = self.stft(input, ilens) + spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1]) + + return spectrum, flens diff --git a/espnet2/enh/espnet_model.py b/espnet2/enh/espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4daad4e033ce1b502e9ffef790284c532d454b --- /dev/null +++ b/espnet2/enh/espnet_model.py @@ -0,0 +1,653 @@ +from distutils.version import LooseVersion +from functools import reduce +from itertools import permutations +from typing import Dict +from typing import Optional +from typing import Tuple + +import torch +from torch_complex.tensor import ComplexTensor +from typeguard import check_argument_types + +from espnet2.enh.decoder.abs_decoder import AbsDecoder +from espnet2.enh.encoder.abs_encoder import AbsEncoder +from espnet2.enh.encoder.conv_encoder import ConvEncoder +from espnet2.enh.separator.abs_separator import AbsSeparator +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel + + +is_torch_1_3_plus = LooseVersion(torch.__version__) >= LooseVersion("1.3.0") +ALL_LOSS_TYPES = ( + # mse_loss(predicted_mask, target_label) + "mask_mse", + # mse_loss(enhanced_magnitude_spectrum, target_magnitude_spectrum) + "magnitude", + # mse_loss(enhanced_complex_spectrum, target_complex_spectrum) + "spectrum", + # log_mse_loss(enhanced_complex_spectrum, target_complex_spectrum) + "spectrum_log", + # si_snr(enhanced_waveform, target_waveform) + "si_snr", +) +EPS = torch.finfo(torch.get_default_dtype()).eps + + +class ESPnetEnhancementModel(AbsESPnetModel): + """Speech enhancement or separation Frontend model""" + + def __init__( + self, + encoder: AbsEncoder, + separator: AbsSeparator, + decoder: AbsDecoder, + stft_consistency: bool = False, + loss_type: str = "mask_mse", + mask_type: Optional[str] = None, + ): + assert check_argument_types() + + super().__init__() + + self.encoder = encoder + self.separator = separator + self.decoder = decoder + self.num_spk = separator.num_spk + self.num_noise_type = 
getattr(self.separator, "num_noise_type", 1) + + if loss_type != "si_snr" and isinstance(encoder, ConvEncoder): + raise TypeError(f"{loss_type} is not supported with {type(ConvEncoder)}") + + # get mask type for TF-domain models (only used when loss_type="mask_*") + self.mask_type = mask_type.upper() if mask_type else None + # get loss type for model training + self.loss_type = loss_type + # whether to compute the TF-domain loss while enforcing STFT consistency + self.stft_consistency = stft_consistency + + if stft_consistency and loss_type in ["mask_mse", "si_snr"]: + raise ValueError( + f"stft_consistency will not work when '{loss_type}' loss is used" + ) + + assert self.loss_type in ALL_LOSS_TYPES, self.loss_type + # for multi-channel signal + self.ref_channel = getattr(self.separator, "ref_channel", -1) + + @staticmethod + def _create_mask_label(mix_spec, ref_spec, mask_type="IAM"): + """Create mask label. + + Args: + mix_spec: ComplexTensor(B, T, F) + ref_spec: List[ComplexTensor(B, T, F), ...] + mask_type: str + Returns: + labels: List[Tensor(B, T, F), ...] or List[ComplexTensor(B, T, F), ...] + """ + + # Must be upper case + assert mask_type in [ + "IBM", + "IRM", + "IAM", + "PSM", + "NPSM", + "PSM^2", + ], f"mask type {mask_type} not supported" + mask_label = [] + for r in ref_spec: + mask = None + if mask_type == "IBM": + flags = [abs(r) >= abs(n) for n in ref_spec] + mask = reduce(lambda x, y: x * y, flags) + mask = mask.int() + elif mask_type == "IRM": + # TODO(Wangyou): need to fix this, + # as noise referecens are provided separately + mask = abs(r) / (sum(([abs(n) for n in ref_spec])) + EPS) + elif mask_type == "IAM": + mask = abs(r) / (abs(mix_spec) + EPS) + mask = mask.clamp(min=0, max=1) + elif mask_type == "PSM" or mask_type == "NPSM": + phase_r = r / (abs(r) + EPS) + phase_mix = mix_spec / (abs(mix_spec) + EPS) + # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) + cos_theta = ( + phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag + ) + mask = (abs(r) / (abs(mix_spec) + EPS)) * cos_theta + mask = ( + mask.clamp(min=0, max=1) + if mask_type == "NPSM" + else mask.clamp(min=-1, max=1) + ) + elif mask_type == "PSM^2": + # This is for training beamforming masks + phase_r = r / (abs(r) + EPS) + phase_mix = mix_spec / (abs(mix_spec) + EPS) + # cos(a - b) = cos(a)*cos(b) + sin(a)*sin(b) + cos_theta = ( + phase_r.real * phase_mix.real + phase_r.imag * phase_mix.imag + ) + mask = (abs(r).pow(2) / (abs(mix_spec).pow(2) + EPS)) * cos_theta + mask = mask.clamp(min=-1, max=1) + assert mask is not None, f"mask type {mask_type} not supported" + mask_label.append(mask) + return mask_label + + def forward( + self, + speech_mix: torch.Tensor, + speech_mix_lengths: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech_mix: (Batch, samples) or (Batch, samples, channels) + speech_ref: (Batch, num_speaker, samples) + or (Batch, num_speaker, samples, channels) + speech_mix_lengths: (Batch,), default None for chunk interator, + because the chunk-iterator does not have the + speech_lengths returned. 
see in + espnet2/iterators/chunk_iter_factory.py + """ + # clean speech signal of each speaker + speech_ref = [ + kwargs["speech_ref{}".format(spk + 1)] for spk in range(self.num_spk) + ] + # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels) + speech_ref = torch.stack(speech_ref, dim=1) + + if "noise_ref1" in kwargs: + # noise signal (optional, required when using + # frontend models with beamformering) + noise_ref = [ + kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type) + ] + # (Batch, num_noise_type, samples) or + # (Batch, num_noise_type, samples, channels) + noise_ref = torch.stack(noise_ref, dim=1) + else: + noise_ref = None + + # dereverberated (noisy) signal + # (optional, only used for frontend models with WPE) + if "dereverb_ref1" in kwargs: + # noise signal (optional, required when using + # frontend models with beamformering) + dereverb_speech_ref = [ + kwargs["dereverb_ref{}".format(n + 1)] + for n in range(self.num_spk) + if "dereverb_ref{}".format(n + 1) in kwargs + ] + assert len(dereverb_speech_ref) in (1, self.num_spk), len( + dereverb_speech_ref + ) + # (Batch, N, samples) or (Batch, N, samples, channels) + dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1) + else: + dereverb_speech_ref = None + + batch_size = speech_mix.shape[0] + speech_lengths = ( + speech_mix_lengths + if speech_mix_lengths is not None + else torch.ones(batch_size).int().fill_(speech_mix.shape[1]) + ) + assert speech_lengths.dim() == 1, speech_lengths.shape + # Check that batch_size is unified + assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], ( + speech_mix.shape, + speech_ref.shape, + speech_lengths.shape, + ) + + # for data-parallel + speech_ref = speech_ref[:, :, : speech_lengths.max()] + speech_mix = speech_mix[:, : speech_lengths.max()] + + loss, speech_pre, others, out_lengths, perm = self._compute_loss( + speech_mix, + speech_lengths, + speech_ref, + dereverb_speech_ref=dereverb_speech_ref, + noise_ref=noise_ref, + ) + + # add stats for logging + if self.loss_type != "si_snr": + if self.training: + si_snr = None + else: + speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in speech_pre] + speech_ref = torch.unbind(speech_ref, dim=1) + if speech_ref[0].dim() == 3: + # For si_snr loss, only select one channel as the reference + speech_ref = [sr[..., self.ref_channel] for sr in speech_ref] + # compute si-snr loss + si_snr_loss, perm = self._permutation_loss( + speech_ref, speech_pre, self.si_snr_loss, perm=perm + ) + si_snr = -si_snr_loss.detach() + + stats = dict( + si_snr=si_snr, + loss=loss.detach(), + ) + else: + stats = dict(si_snr=-loss.detach(), loss=loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _compute_loss( + self, + speech_mix, + speech_lengths, + speech_ref, + dereverb_speech_ref=None, + noise_ref=None, + cal_loss=True, + ): + """Compute loss according to self.loss_type. + + Args: + speech_mix: (Batch, samples) or (Batch, samples, channels) + speech_lengths: (Batch,), default None for chunk interator, + because the chunk-iterator does not have the + speech_lengths returned. 
see in + espnet2/iterators/chunk_iter_factory.py + speech_ref: (Batch, num_speaker, samples) + or (Batch, num_speaker, samples, channels) + dereverb_speech_ref: (Batch, N, samples) + or (Batch, num_speaker, samples, channels) + noise_ref: (Batch, num_noise_type, samples) + or (Batch, num_speaker, samples, channels) + cal_loss: whether to calculate enh loss, defualt is True + + Returns: + loss: (torch.Tensor) speech enhancement loss + speech_pre: (List[torch.Tensor] or List[ComplexTensor]) + enhanced speech or spectrum(s) + others: (OrderedDict) estimated masks or None + output_lengths: (Batch,) + perm: () best permutation + """ + feature_mix, flens = self.encoder(speech_mix, speech_lengths) + feature_pre, flens, others = self.separator(feature_mix, flens) + + if self.loss_type != "si_snr": + spectrum_mix = feature_mix + spectrum_pre = feature_pre + # predict separated speech and masks + if self.stft_consistency: + # pseudo STFT -> time-domain -> STFT (compute loss) + tmp_t_domain = [ + self.decoder(sp, speech_lengths)[0] for sp in spectrum_pre + ] + spectrum_pre = [ + self.encoder(sp, speech_lengths)[0] for sp in tmp_t_domain + ] + pass + + if spectrum_pre is not None and not isinstance( + spectrum_pre[0], ComplexTensor + ): + spectrum_pre = [ + ComplexTensor(*torch.unbind(sp, dim=-1)) for sp in spectrum_pre + ] + + if not cal_loss: + loss, perm = None, None + return loss, spectrum_pre, others, flens, perm + + # prepare reference speech and reference spectrum + speech_ref = torch.unbind(speech_ref, dim=1) + # List[ComplexTensor(Batch, T, F)] or List[ComplexTensor(Batch, T, C, F)] + spectrum_ref = [self.encoder(sr, speech_lengths)[0] for sr in speech_ref] + + # compute TF masking loss + if self.loss_type == "magnitude": + # compute loss on magnitude spectrum + assert spectrum_pre is not None + magnitude_pre = [abs(ps + 1e-15) for ps in spectrum_pre] + if spectrum_ref[0].dim() > magnitude_pre[0].dim(): + # only select one channel as the reference + magnitude_ref = [ + abs(sr[..., self.ref_channel, :]) for sr in spectrum_ref + ] + else: + magnitude_ref = [abs(sr) for sr in spectrum_ref] + + tf_loss, perm = self._permutation_loss( + magnitude_ref, magnitude_pre, self.tf_mse_loss + ) + elif self.loss_type.startswith("spectrum"): + # compute loss on complex spectrum + if self.loss_type == "spectrum": + loss_func = self.tf_mse_loss + elif self.loss_type == "spectrum_log": + loss_func = self.tf_log_mse_loss + else: + raise ValueError("Unsupported loss type: %s" % self.loss_type) + + assert spectrum_pre is not None + if spectrum_ref[0].dim() > spectrum_pre[0].dim(): + # only select one channel as the reference + spectrum_ref = [sr[..., self.ref_channel, :] for sr in spectrum_ref] + + tf_loss, perm = self._permutation_loss( + spectrum_ref, spectrum_pre, loss_func + ) + elif self.loss_type.startswith("mask"): + if self.loss_type == "mask_mse": + loss_func = self.tf_mse_loss + else: + raise ValueError("Unsupported loss type: %s" % self.loss_type) + + assert others is not None + mask_pre_ = [ + others["mask_spk{}".format(spk + 1)] for spk in range(self.num_spk) + ] + + # prepare ideal masks + mask_ref = self._create_mask_label( + spectrum_mix, spectrum_ref, mask_type=self.mask_type + ) + + # compute TF masking loss + tf_loss, perm = self._permutation_loss(mask_ref, mask_pre_, loss_func) + + if "mask_dereverb1" in others: + if dereverb_speech_ref is None: + raise ValueError( + "No dereverberated reference for training!\n" + 'Please specify "--use_dereverb_ref true" in run.sh' + ) + + mask_wpe_pre = [ + 
others["mask_dereverb{}".format(spk + 1)] + for spk in range(self.num_spk) + if "mask_dereverb{}".format(spk + 1) in others + ] + assert len(mask_wpe_pre) == dereverb_speech_ref.size(1), ( + len(mask_wpe_pre), + dereverb_speech_ref.size(1), + ) + dereverb_speech_ref = torch.unbind(dereverb_speech_ref, dim=1) + dereverb_spectrum_ref = [ + self.encoder(dr, speech_lengths)[0] + for dr in dereverb_speech_ref + ] + dereverb_mask_ref = self._create_mask_label( + spectrum_mix, dereverb_spectrum_ref, mask_type=self.mask_type + ) + + tf_dereverb_loss, perm_d = self._permutation_loss( + dereverb_mask_ref, mask_wpe_pre, loss_func + ) + tf_loss = tf_loss + tf_dereverb_loss + + if "mask_noise1" in others: + if noise_ref is None: + raise ValueError( + "No noise reference for training!\n" + 'Please specify "--use_noise_ref true" in run.sh' + ) + + noise_ref = torch.unbind(noise_ref, dim=1) + noise_spectrum_ref = [ + self.encoder(nr, speech_lengths)[0] for nr in noise_ref + ] + noise_mask_ref = self._create_mask_label( + spectrum_mix, noise_spectrum_ref, mask_type=self.mask_type + ) + + mask_noise_pre = [ + others["mask_noise{}".format(n + 1)] + for n in range(self.num_noise_type) + ] + tf_noise_loss, perm_n = self._permutation_loss( + noise_mask_ref, mask_noise_pre, loss_func + ) + tf_loss = tf_loss + tf_noise_loss + else: + raise ValueError("Unsupported loss type: %s" % self.loss_type) + + loss = tf_loss + return loss, spectrum_pre, others, flens, perm + + else: + speech_pre = [self.decoder(ps, speech_lengths)[0] for ps in feature_pre] + if not cal_loss: + loss, perm = None, None + return loss, speech_pre, None, speech_lengths, perm + + # speech_pre: list[(batch, sample)] + assert speech_pre[0].dim() == 2, speech_pre[0].dim() + + if speech_ref.dim() == 4: + # For si_snr loss of multi-channel input, + # only select one channel as the reference + speech_ref = speech_ref[..., self.ref_channel] + speech_ref = torch.unbind(speech_ref, dim=1) + + # compute si-snr loss + si_snr_loss, perm = self._permutation_loss( + speech_ref, speech_pre, self.si_snr_loss_zeromean + ) + loss = si_snr_loss + + return loss, speech_pre, None, speech_lengths, perm + + @staticmethod + def tf_mse_loss(ref, inf): + """time-frequency MSE loss. + + Args: + ref: (Batch, T, F) or (Batch, T, C, F) + inf: (Batch, T, F) or (Batch, T, C, F) + Returns: + loss: (Batch,) + """ + assert ref.shape == inf.shape, (ref.shape, inf.shape) + if not is_torch_1_3_plus: + # in case of binary masks + ref = ref.type(inf.dtype) + diff = ref - inf + if isinstance(diff, ComplexTensor): + mseloss = diff.real ** 2 + diff.imag ** 2 + else: + mseloss = diff ** 2 + if ref.dim() == 3: + mseloss = mseloss.mean(dim=[1, 2]) + elif ref.dim() == 4: + mseloss = mseloss.mean(dim=[1, 2, 3]) + else: + raise ValueError( + "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape) + ) + + return mseloss + + @staticmethod + def tf_log_mse_loss(ref, inf): + """time-frequency log-MSE loss. 
+ + Args: + ref: (Batch, T, F) or (Batch, T, C, F) + inf: (Batch, T, F) or (Batch, T, C, F) + Returns: + loss: (Batch,) + """ + assert ref.shape == inf.shape, (ref.shape, inf.shape) + if not is_torch_1_3_plus: + # in case of binary masks + ref = ref.type(inf.dtype) + diff = ref - inf + if isinstance(diff, ComplexTensor): + log_mse_loss = diff.real ** 2 + diff.imag ** 2 + else: + log_mse_loss = diff ** 2 + if ref.dim() == 3: + log_mse_loss = torch.log10(log_mse_loss.sum(dim=[1, 2])) * 10 + elif ref.dim() == 4: + log_mse_loss = torch.log10(log_mse_loss.sum(dim=[1, 2, 3])) * 10 + else: + raise ValueError( + "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape) + ) + + return log_mse_loss + + @staticmethod + def tf_l1_loss(ref, inf): + """time-frequency L1 loss. + + Args: + ref: (Batch, T, F) or (Batch, T, C, F) + inf: (Batch, T, F) or (Batch, T, C, F) + Returns: + loss: (Batch,) + """ + assert ref.shape == inf.shape, (ref.shape, inf.shape) + if not is_torch_1_3_plus: + # in case of binary masks + ref = ref.type(inf.dtype) + if isinstance(inf, ComplexTensor): + l1loss = abs(ref - inf + EPS) + else: + l1loss = abs(ref - inf) + if ref.dim() == 3: + l1loss = l1loss.mean(dim=[1, 2]) + elif ref.dim() == 4: + l1loss = l1loss.mean(dim=[1, 2, 3]) + else: + raise ValueError( + "Invalid input shape: ref={}, inf={}".format(ref.shape, inf.shape) + ) + return l1loss + + @staticmethod + def si_snr_loss(ref, inf): + """SI-SNR loss + + Args: + ref: (Batch, samples) + inf: (Batch, samples) + Returns: + loss: (Batch,) + """ + ref = ref / torch.norm(ref, p=2, dim=1, keepdim=True) + inf = inf / torch.norm(inf, p=2, dim=1, keepdim=True) + + s_target = (ref * inf).sum(dim=1, keepdims=True) * ref + e_noise = inf - s_target + + si_snr = 20 * ( + torch.log10(torch.norm(s_target, p=2, dim=1).clamp(min=EPS)) + - torch.log10(torch.norm(e_noise, p=2, dim=1).clamp(min=EPS)) + ) + return -si_snr + + @staticmethod + def si_snr_loss_zeromean(ref, inf): + """SI-SNR loss with zero-mean in pre-processing. + + Args: + ref: (Batch, samples) + inf: (Batch, samples) + Returns: + loss: (Batch,) + """ + assert ref.size() == inf.size() + B, T = ref.size() + # mask padding position along T + + # Step 1. Zero-mean norm + mean_target = torch.sum(ref, dim=1, keepdim=True) / T + mean_estimate = torch.sum(inf, dim=1, keepdim=True) / T + zero_mean_target = ref - mean_target + zero_mean_estimate = inf - mean_estimate + + # Step 2. SI-SNR with order + # reshape to use broadcast + s_target = zero_mean_target # [B, T] + s_estimate = zero_mean_estimate # [B, T] + # s_target = s / ||s||^2 + pair_wise_dot = torch.sum(s_estimate * s_target, dim=1, keepdim=True) # [B, 1] + s_target_energy = torch.sum(s_target ** 2, dim=1, keepdim=True) + EPS # [B, 1] + pair_wise_proj = pair_wise_dot * s_target / s_target_energy # [B, T] + # e_noise = s' - s_target + e_noise = s_estimate - pair_wise_proj # [B, T] + + # SI-SNR = 10 * log_10(||s_target||^2 / ||e_noise||^2) + pair_wise_si_snr = torch.sum(pair_wise_proj ** 2, dim=1) / ( + torch.sum(e_noise ** 2, dim=1) + EPS + ) + # print('pair_si_snr',pair_wise_si_snr[0,:]) + pair_wise_si_snr = 10 * torch.log10(pair_wise_si_snr + EPS) # [B] + # print(pair_wise_si_snr) + + return -1 * pair_wise_si_snr + + @staticmethod + def _permutation_loss(ref, inf, criterion, perm=None): + """The basic permutation loss function. + + Args: + ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk + inf (List[torch.Tensor]): [(batch, ...), ...] 
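For intuition, here is a self-contained sketch of the zero-mean SI-SNR metric that si_snr_loss_zeromean negates for training; the signal length and the 0.5 gain below are arbitrary, chosen only to show the scale invariance.

import torch

def si_snr(ref: torch.Tensor, est: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # SI-SNR in dB per batch entry, with zero-mean pre-processing.
    ref = ref - ref.mean(dim=1, keepdim=True)
    est = est - est.mean(dim=1, keepdim=True)
    # Project the estimate onto the reference (s_target), the rest is e_noise.
    proj = (est * ref).sum(dim=1, keepdim=True) * ref / (ref.pow(2).sum(dim=1, keepdim=True) + eps)
    noise = est - proj
    ratio = proj.pow(2).sum(dim=1) / (noise.pow(2).sum(dim=1) + eps)
    return 10 * torch.log10(ratio + eps)

ref = torch.randn(2, 16000)
est = 0.5 * ref + 0.01 * torch.randn(2, 16000)
print(si_snr(ref, est))          # high SI-SNR despite the 0.5 gain mismatch
print(si_snr(ref, 2.0 * est))    # unchanged: the metric is scale-invariant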
+ criterion (function): Loss function + perm (torch.Tensor): specified permutation (batch, num_spk) + Returns: + loss (torch.Tensor): minimum loss with the best permutation (batch) + perm (torch.Tensor): permutation for inf (batch, num_spk) + e.g. tensor([[1, 0, 2], [0, 1, 2]]) + """ + assert len(ref) == len(inf), (len(ref), len(inf)) + num_spk = len(ref) + + def pair_loss(permutation): + return sum( + [criterion(ref[s], inf[t]) for s, t in enumerate(permutation)] + ) / len(permutation) + + if perm is None: + device = ref[0].device + all_permutations = list(permutations(range(num_spk))) + losses = torch.stack([pair_loss(p) for p in all_permutations], dim=1) + loss, perm = torch.min(losses, dim=1) + perm = torch.index_select( + torch.tensor(all_permutations, device=device, dtype=torch.long), + 0, + perm, + ) + else: + loss = torch.tensor( + [ + torch.tensor( + [ + criterion( + ref[s][batch].unsqueeze(0), inf[t][batch].unsqueeze(0) + ) + for s, t in enumerate(p) + ] + ).mean() + for batch, p in enumerate(perm) + ] + ) + + return loss.mean(), perm + + def collect_feats( + self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs + ) -> Dict[str, torch.Tensor]: + # for data-parallel + speech_mix = speech_mix[:, : speech_mix_lengths.max()] + + feats, feats_lengths = speech_mix, speech_mix_lengths + return {"feats": feats, "feats_lengths": feats_lengths} diff --git a/espnet2/enh/layers/__init__.py b/espnet2/enh/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/enh/layers/beamformer.py b/espnet2/enh/layers/beamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..39f610a2d57c237e9b28167c16b91138584b948a --- /dev/null +++ b/espnet2/enh/layers/beamformer.py @@ -0,0 +1,645 @@ +from distutils.version import LooseVersion +from typing import List +from typing import Optional +from typing import Union + +import numpy as np +import torch +from torch_complex import functional as FC +from torch_complex.tensor import ComplexTensor + +is_torch_1_1_plus = LooseVersion(torch.__version__) >= LooseVersion("1.1.0") +EPS = torch.finfo(torch.double).eps + + +def complex_norm(c: ComplexTensor) -> torch.Tensor: + return torch.sqrt((c.real ** 2 + c.imag ** 2).sum(dim=-1, keepdim=True) + EPS) + + +def get_rtf( + psd_speech: ComplexTensor, + psd_noise: ComplexTensor, + reference_vector: Union[int, torch.Tensor, None] = None, + iterations: int = 3, + use_torch_solver: bool = True, +) -> ComplexTensor: + """Calculate the relative transfer function (RTF) using the power method. + + Algorithm: + 1) rtf = reference_vector + 2) for i in range(iterations): + rtf = (psd_noise^-1 @ psd_speech) @ rtf + rtf = rtf / ||rtf||_2 # this normalization can be skipped + 3) rtf = psd_noise @ rtf + 4) rtf = rtf / rtf[..., ref_channel, :] + Note: 4) Normalization at the reference channel is not performed here. 
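The permutation search performed by _permutation_loss can be illustrated with the following stand-alone sketch (toy MSE criterion and made-up shapes; the real method additionally supports a fixed permutation and arbitrary criteria).

from itertools import permutations
import torch

def pit_loss(ref: list, est: list, criterion):
    # Utterance-level permutation-invariant training: try every speaker
    # assignment and keep the one with the smallest average pairwise loss.
    perms = list(permutations(range(len(ref))))
    losses = torch.stack(
        [torch.stack([criterion(ref[s], est[t]) for s, t in enumerate(p)]).mean(dim=0)
         for p in perms],
        dim=1,
    )                                   # (Batch, num_permutations)
    loss, idx = losses.min(dim=1)
    best_perm = torch.tensor(perms)[idx]    # (Batch, num_spk)
    return loss.mean(), best_perm

mse = lambda a, b: ((a - b) ** 2).mean(dim=-1)          # per-batch criterion
ref = [torch.randn(3, 8000) for _ in range(2)]          # two reference speakers
est = [ref[1] + 0.01 * torch.randn(3, 8000), ref[0]]    # estimates in swapped order
loss, perm = pit_loss(ref, est, mse)
print(perm[0])    # tensor([1, 0]): the chosen permutation undoes the swap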
+ + Args: + psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C) + psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C) + reference_vector (torch.Tensor or int): (..., C) or scalar + iterations (int): number of iterations in power method + use_torch_solver (bool): Whether to use `solve` instead of `inverse` + Returns: + rtf (ComplexTensor): (..., F, C, 1) + """ + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + phi = FC.solve(psd_speech, psd_noise)[0] + else: + phi = FC.matmul(psd_noise.inverse2(), psd_speech) + rtf = ( + phi[..., reference_vector, None] + if isinstance(reference_vector, int) + else FC.matmul(phi, reference_vector[..., None, :, None]) + ) + for _ in range(iterations - 2): + rtf = FC.matmul(phi, rtf) + # rtf = rtf / complex_norm(rtf) + rtf = FC.matmul(psd_speech, rtf) + return rtf + + +def get_mvdr_vector( + psd_s: ComplexTensor, + psd_n: ComplexTensor, + reference_vector: torch.Tensor, + use_torch_solver: bool = True, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, +) -> ComplexTensor: + """Return the MVDR (Minimum Variance Distortionless Response) vector: + + h = (Npsd^-1 @ Spsd) / (Tr(Npsd^-1 @ Spsd)) @ u + + Reference: + On optimal frequency-domain multichannel linear filtering + for noise reduction; M. Souden et al., 2010; + https://ieeexplore.ieee.org/document/5089420 + + Args: + psd_s (ComplexTensor): speech covariance matrix (..., F, C, C) + psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C) + reference_vector (torch.Tensor): (..., C) + use_torch_solver (bool): Whether to use `solve` instead of `inverse` + diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n + diag_eps (float): + eps (float): + Returns: + beamform_vector (ComplexTensor): (..., F, C) + """ # noqa: D400 + if diagonal_loading: + psd_n = tik_reg(psd_n, reg=diag_eps, eps=eps) + + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + numerator = FC.solve(psd_s, psd_n)[0] + else: + numerator = FC.matmul(psd_n.inverse2(), psd_s) + # ws: (..., C, C) / (...,) -> (..., C, C) + ws = numerator / (FC.trace(numerator)[..., None, None] + eps) + # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) + beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector]) + return beamform_vector + + +def get_mvdr_vector_with_rtf( + psd_n: ComplexTensor, + psd_speech: ComplexTensor, + psd_noise: ComplexTensor, + iterations: int = 3, + reference_vector: Union[int, torch.Tensor, None] = None, + normalize_ref_channel: Optional[int] = None, + use_torch_solver: bool = True, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, +) -> ComplexTensor: + """Return the MVDR (Minimum Variance Distortionless Response) vector + calculated with RTF: + + h = (Npsd^-1 @ rtf) / (rtf^H @ Npsd^-1 @ rtf) + + Reference: + On optimal frequency-domain multichannel linear filtering + for noise reduction; M. 
Souden et al., 2010; + https://ieeexplore.ieee.org/document/5089420 + + Args: + psd_n (ComplexTensor): observation/noise covariance matrix (..., F, C, C) + psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C) + psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C) + iterations (int): number of iterations in power method + reference_vector (torch.Tensor or int): (..., C) or scalar + normalize_ref_channel (int): reference channel for normalizing the RTF + use_torch_solver (bool): Whether to use `solve` instead of `inverse` + diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n + diag_eps (float): + eps (float): + Returns: + beamform_vector (ComplexTensor): (..., F, C) + """ # noqa: H405, D205, D400 + if diagonal_loading: + psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps) + + # (B, F, C, 1) + rtf = get_rtf( + psd_speech, + psd_noise, + reference_vector, + iterations=iterations, + use_torch_solver=use_torch_solver, + ) + + # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1) + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + numerator = FC.solve(rtf, psd_n)[0].squeeze(-1) + else: + numerator = FC.matmul(psd_n.inverse2(), rtf).squeeze(-1) + denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator]) + if normalize_ref_channel is not None: + scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj() + beamforming_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps) + else: + beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps) + return beamforming_vector + + +def signal_framing( + signal: Union[torch.Tensor, ComplexTensor], + frame_length: int, + frame_step: int, + bdelay: int, + do_padding: bool = False, + pad_value: int = 0, + indices: List = None, +) -> Union[torch.Tensor, ComplexTensor]: + """Expand `signal` into several frames, with each frame of length `frame_length`. + + Args: + signal : (..., T) + frame_length: length of each segment + frame_step: step for selecting frames + bdelay: delay for WPD + do_padding: whether or not to pad the input signal at the beginning + of the time dimension + pad_value: value to fill in the padding + + Returns: + torch.Tensor: + if do_padding: (..., T, frame_length) + else: (..., T - bdelay - frame_length + 2, frame_length) + """ + frame_length2 = frame_length - 1 + # pad to the right at the last dimension of `signal` (time dimension) + if do_padding: + # (..., T) --> (..., T + bdelay + frame_length - 2) + signal = FC.pad(signal, (bdelay + frame_length2 - 1, 0), "constant", pad_value) + do_padding = False + + if indices is None: + # [[ 0, 1, ..., frame_length2 - 1, frame_length2 - 1 + bdelay ], + # [ 1, 2, ..., frame_length2, frame_length2 + bdelay ], + # [ 2, 3, ..., frame_length2 + 1, frame_length2 + 1 + bdelay ], + # ... 
+ # [ T-bdelay-frame_length2, ..., T-1-bdelay, T-1 ]] + indices = [ + [*range(i, i + frame_length2), i + frame_length2 + bdelay - 1] + for i in range(0, signal.shape[-1] - frame_length2 - bdelay + 1, frame_step) + ] + + if isinstance(signal, ComplexTensor): + real = signal_framing( + signal.real, + frame_length, + frame_step, + bdelay, + do_padding, + pad_value, + indices, + ) + imag = signal_framing( + signal.imag, + frame_length, + frame_step, + bdelay, + do_padding, + pad_value, + indices, + ) + return ComplexTensor(real, imag) + else: + # (..., T - bdelay - frame_length + 2, frame_length) + signal = signal[..., indices] + # signal[..., :-1] = -signal[..., :-1] + return signal + + +def get_covariances( + Y: ComplexTensor, + inverse_power: torch.Tensor, + bdelay: int, + btaps: int, + get_vector: bool = False, +) -> ComplexTensor: + """Calculates the power normalized spatio-temporal covariance + matrix of the framed signal. + + Args: + Y : Complext STFT signal with shape (B, F, C, T) + inverse_power : Weighting factor with shape (B, F, T) + + Returns: + Correlation matrix: (B, F, (btaps+1) * C, (btaps+1) * C) + Correlation vector: (B, F, btaps + 1, C, C) + """ # noqa: H405, D205, D400, D401 + assert inverse_power.dim() == 3, inverse_power.dim() + assert inverse_power.size(0) == Y.size(0), (inverse_power.size(0), Y.size(0)) + + Bs, Fdim, C, T = Y.shape + + # (B, F, C, T - bdelay - btaps + 1, btaps + 1) + Psi = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=False)[ + ..., : T - bdelay - btaps + 1, : + ] + # Reverse along btaps-axis: + # [tau, tau-bdelay, tau-bdelay-1, ..., tau-bdelay-frame_length+1] + Psi = FC.reverse(Psi, dim=-1) + Psi_norm = Psi * inverse_power[..., None, bdelay + btaps - 1 :, None] + + # let T' = T - bdelay - btaps + 1 + # (B, F, C, T', btaps + 1) x (B, F, C, T', btaps + 1) + # -> (B, F, btaps + 1, C, btaps + 1, C) + covariance_matrix = FC.einsum("bfdtk,bfetl->bfkdle", (Psi, Psi_norm.conj())) + + # (B, F, btaps + 1, C, btaps + 1, C) + # -> (B, F, (btaps + 1) * C, (btaps + 1) * C) + covariance_matrix = covariance_matrix.view( + Bs, Fdim, (btaps + 1) * C, (btaps + 1) * C + ) + + if get_vector: + # (B, F, C, T', btaps + 1) x (B, F, C, T') + # --> (B, F, btaps +1, C, C) + covariance_vector = FC.einsum( + "bfdtk,bfet->bfked", (Psi_norm, Y[..., bdelay + btaps - 1 :].conj()) + ) + return covariance_matrix, covariance_vector + else: + return covariance_matrix + + +def get_WPD_filter( + Phi: ComplexTensor, + Rf: ComplexTensor, + reference_vector: torch.Tensor, + use_torch_solver: bool = True, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, +) -> ComplexTensor: + """Return the WPD vector. + + WPD is the Weighted Power minimization Distortionless response + convolutional beamformer. As follows: + + h = (Rf^-1 @ Phi_{xx}) / tr[(Rf^-1) @ Phi_{xx}] @ u + + Reference: + T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer + for Simultaneous Denoising and Dereverberation," in IEEE Signal + Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi: + 10.1109/LSP.2019.2911179. + https://ieeexplore.ieee.org/document/8691481 + + Args: + Phi (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C) + is the PSD of zero-padded speech [x^T(t,f) 0 ... 0]^T. + Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C) + is the power normalized spatio-temporal covariance matrix. + reference_vector (torch.Tensor): (B, (btaps+1) * C) + is the reference_vector. 
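For reference, the Souden-style MVDR solution used by get_mvdr_vector earlier in this file can be written with native complex tensors and torch.linalg.solve (PyTorch 1.8+) instead of torch_complex. This is only an illustration with made-up statistics, not the patch's implementation.

import torch

def mvdr_souden(psd_s: torch.Tensor, psd_n: torch.Tensor, u: torch.Tensor,
                eps: float = 1e-8) -> torch.Tensor:
    # w = (Phi_n^-1 Phi_s) / tr(Phi_n^-1 Phi_s) @ u, for (..., F, C, C) PSDs.
    num = torch.linalg.solve(psd_n, psd_s)                    # Phi_n^-1 @ Phi_s
    trace = torch.diagonal(num, dim1=-2, dim2=-1).sum(-1)     # (..., F)
    ws = num / (trace[..., None, None] + eps)
    return torch.einsum("...fec,...c->...fe", ws, u.to(ws.dtype))   # (..., F, C)

B, F_, C = 2, 257, 4
x = torch.randn(B, F_, C, 10, dtype=torch.complex64)          # toy multi-channel STFT
psd_n = torch.einsum("...ct,...et->...ce", x, x.conj()) + 1e-3 * torch.eye(C)
psd_s = torch.einsum("...ct,...et->...ce", x[..., :5], x[..., :5].conj())
u = torch.eye(C)[0].expand(B, C)           # one-hot vector selecting channel 0
print(mvdr_souden(psd_s, psd_n, u).shape)  # torch.Size([2, 257, 4])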
+ use_torch_solver (bool): Whether to use `solve` instead of `inverse` + diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n + diag_eps (float): + eps (float): + + Returns: + filter_matrix (ComplexTensor): (B, F, (btaps + 1) * C) + """ + if diagonal_loading: + Rf = tik_reg(Rf, reg=diag_eps, eps=eps) + + # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3) + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + numerator = FC.solve(Phi, Rf)[0] + else: + numerator = FC.matmul(Rf.inverse2(), Phi) + # ws: (..., C, C) / (...,) -> (..., C, C) + ws = numerator / (FC.trace(numerator)[..., None, None] + eps) + # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) + beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector]) + # (B, F, (btaps + 1) * C) + return beamform_vector + + +def get_WPD_filter_v2( + Phi: ComplexTensor, + Rf: ComplexTensor, + reference_vector: torch.Tensor, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-8, +) -> ComplexTensor: + """Return the WPD vector (v2). + + This implementaion is more efficient than `get_WPD_filter` as + it skips unnecessary computation with zeros. + + Args: + Phi (ComplexTensor): (B, F, C, C) + is speech PSD. + Rf (ComplexTensor): (B, F, (btaps+1) * C, (btaps+1) * C) + is the power normalized spatio-temporal covariance matrix. + reference_vector (torch.Tensor): (B, C) + is the reference_vector. + diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n + diag_eps (float): + eps (float): + + Returns: + filter_matrix (ComplexTensor): (B, F, (btaps+1) * C) + """ + C = reference_vector.shape[-1] + if diagonal_loading: + Rf = tik_reg(Rf, reg=diag_eps, eps=eps) + inv_Rf = Rf.inverse2() + # (B, F, (btaps+1) * C, C) + inv_Rf_pruned = inv_Rf[..., :C] + # numerator: (..., C_1, C_2) x (..., C_2, C_3) -> (..., C_1, C_3) + numerator = FC.matmul(inv_Rf_pruned, Phi) + # ws: (..., (btaps+1) * C, C) / (...,) -> (..., (btaps+1) * C, C) + ws = numerator / (FC.trace(numerator[..., :C, :])[..., None, None] + eps) + # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1) + beamform_vector = FC.einsum("...fec,...c->...fe", [ws, reference_vector]) + # (B, F, (btaps+1) * C) + return beamform_vector + + +def get_WPD_filter_with_rtf( + psd_observed_bar: ComplexTensor, + psd_speech: ComplexTensor, + psd_noise: ComplexTensor, + iterations: int = 3, + reference_vector: Union[int, torch.Tensor, None] = None, + normalize_ref_channel: Optional[int] = None, + use_torch_solver: bool = True, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + eps: float = 1e-15, +) -> ComplexTensor: + """Return the WPD vector calculated with RTF. + + WPD is the Weighted Power minimization Distortionless response + convolutional beamformer. As follows: + + h = (Rf^-1 @ vbar) / (vbar^H @ R^-1 @ vbar) + + Reference: + T. Nakatani and K. Kinoshita, "A Unified Convolutional Beamformer + for Simultaneous Denoising and Dereverberation," in IEEE Signal + Processing Letters, vol. 26, no. 6, pp. 903-907, June 2019, doi: + 10.1109/LSP.2019.2911179. 
+ https://ieeexplore.ieee.org/document/8691481 + + Args: + psd_observed_bar (ComplexTensor): stacked observation covariance matrix + psd_speech (ComplexTensor): speech covariance matrix (..., F, C, C) + psd_noise (ComplexTensor): noise covariance matrix (..., F, C, C) + iterations (int): number of iterations in power method + reference_vector (torch.Tensor or int): (..., C) or scalar + normalize_ref_channel (int): reference channel for normalizing the RTF + use_torch_solver (bool): Whether to use `solve` instead of `inverse` + diagonal_loading (bool): Whether to add a tiny term to the diagonal of psd_n + diag_eps (float): + eps (float): + Returns: + beamform_vector (ComplexTensor)r: (..., F, C) + """ + C = psd_noise.size(-1) + if diagonal_loading: + psd_noise = tik_reg(psd_noise, reg=diag_eps, eps=eps) + + # (B, F, C, 1) + rtf = get_rtf( + psd_speech, + psd_noise, + reference_vector, + iterations=iterations, + use_torch_solver=use_torch_solver, + ) + + # (B, F, (K+1)*C, 1) + rtf = FC.pad(rtf, (0, 0, 0, psd_observed_bar.shape[-1] - C), "constant", 0) + # numerator: (..., C_1, C_2) x (..., C_2, 1) -> (..., C_1) + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + numerator = FC.solve(rtf, psd_observed_bar)[0].squeeze(-1) + else: + numerator = FC.matmul(psd_observed_bar.inverse2(), rtf).squeeze(-1) + denominator = FC.einsum("...d,...d->...", [rtf.squeeze(-1).conj(), numerator]) + if normalize_ref_channel is not None: + scale = rtf.squeeze(-1)[..., normalize_ref_channel, None].conj() + beamforming_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps) + else: + beamforming_vector = numerator / (denominator.real.unsqueeze(-1) + eps) + return beamforming_vector + + +def perform_WPD_filtering( + filter_matrix: ComplexTensor, Y: ComplexTensor, bdelay: int, btaps: int +) -> ComplexTensor: + """Perform WPD filtering. + + Args: + filter_matrix: Filter matrix (B, F, (btaps + 1) * C) + Y : Complex STFT signal with shape (B, F, C, T) + + Returns: + enhanced (ComplexTensor): (B, F, T) + """ + # (B, F, C, T) --> (B, F, C, T, btaps + 1) + Ytilde = signal_framing(Y, btaps + 1, 1, bdelay, do_padding=True, pad_value=0) + Ytilde = FC.reverse(Ytilde, dim=-1) + + Bs, Fdim, C, T = Y.shape + # --> (B, F, T, btaps + 1, C) --> (B, F, T, (btaps + 1) * C) + Ytilde = Ytilde.permute(0, 1, 3, 4, 2).contiguous().view(Bs, Fdim, T, -1) + # (B, F, T, 1) + enhanced = FC.einsum("...tc,...c->...t", [Ytilde, filter_matrix.conj()]) + return enhanced + + +def tik_reg(mat: ComplexTensor, reg: float = 1e-8, eps: float = 1e-8) -> ComplexTensor: + """Perform Tikhonov regularization (only modifying real part). 
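A rough, simplified sketch of the multi-frame stacking plus filtering that signal_framing and perform_WPD_filtering implement; the real code handles the prediction delay and tap ordering more carefully (including the tap-axis reversal), and the sizes here are arbitrary.

import torch

B, F_, C, T = 2, 65, 4, 50
btaps, bdelay = 2, 3
Y = torch.randn(B, F_, C, T, dtype=torch.complex64)        # multi-channel STFT

# Stack the current frame with btaps past frames delayed by bdelay, bdelay+1, ...
taps = [Y]
for k in range(1, btaps + 1):
    shift = bdelay + k - 1
    taps.append(torch.cat([Y.new_zeros(B, F_, C, shift), Y[..., :-shift]], dim=-1))
Ytilde = torch.cat(taps, dim=-2).transpose(-1, -2)         # (B, F, T, (btaps+1)*C)

# Apply a convolutional beamformer of shape (B, F, (btaps+1)*C) per frequency bin.
w = torch.randn(B, F_, (btaps + 1) * C, dtype=torch.complex64)
enhanced = torch.einsum("bftk,bfk->bft", Ytilde, w.conj())
print(enhanced.shape)                                      # torch.Size([2, 65, 50])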
+ + Args: + mat (ComplexTensor): input matrix (..., C, C) + reg (float): regularization factor + eps (float) + Returns: + ret (ComplexTensor): regularized matrix (..., C, C) + """ + # Add eps + C = mat.size(-1) + eye = torch.eye(C, dtype=mat.dtype, device=mat.device) + shape = [1 for _ in range(mat.dim() - 2)] + [C, C] + eye = eye.view(*shape).repeat(*mat.shape[:-2], 1, 1) + with torch.no_grad(): + epsilon = FC.trace(mat).real[..., None, None] * reg + # in case that correlation_matrix is all-zero + epsilon = epsilon + eps + mat = mat + epsilon * eye + return mat + + +############################################## +# Below are for Multi-Frame MVDR beamforming # +############################################## +# modified from https://gitlab.uni-oldenburg.de/hura4843/deep-mfmvdr/-/blob/master/deep_mfmvdr (# noqa: E501) +def get_adjacent(spec: ComplexTensor, filter_length: int = 5) -> ComplexTensor: + """Zero-pad and unfold stft, i.e., + + add zeros to the beginning so that, using the multi-frame signal model, + there will be as many output frames as input frames. + + Args: + spec (ComplexTensor): input spectrum (B, F, T) + filter_length (int): length for frame extension + Returns: + ret (ComplexTensor): output spectrum (B, F, T, filter_length) + """ # noqa: D400 + return ( + FC.pad(spec, pad=[filter_length - 1, 0]) + .unfold(dim=-1, size=filter_length, step=1) + .contiguous() + ) + + +def get_adjacent_th(spec: torch.Tensor, filter_length: int = 5) -> torch.Tensor: + """Zero-pad and unfold stft, i.e., + + add zeros to the beginning so that, using the multi-frame signal model, + there will be as many output frames as input frames. + + Args: + spec (torch.Tensor): input spectrum (B, F, T, 2) + filter_length (int): length for frame extension + Returns: + ret (torch.Tensor): output spectrum (B, F, T, filter_length, 2) + """ # noqa: D400 + return ( + torch.nn.functional.pad(spec, pad=[0, 0, filter_length - 1, 0]) + .unfold(dimension=-2, size=filter_length, step=1) + .transpose(-2, -1) + .contiguous() + ) + + +def vector_to_Hermitian(vec): + """Construct a Hermitian matrix from a vector of N**2 independent + real-valued elements. + + Args: + vec (torch.Tensor): (..., N ** 2) + Returns: + mat (ComplexTensor): (..., N, N) + """ # noqa: H405, D205, D400 + N = int(np.sqrt(vec.shape[-1])) + mat = torch.zeros(size=vec.shape[:-1] + (N, N, 2), device=vec.device) + + # real component + triu = np.triu_indices(N, 0) + triu2 = np.triu_indices(N, 1) # above main diagonal + tril = (triu2[1], triu2[0]) # below main diagonal; for symmetry + mat[(...,) + triu + (np.zeros(triu[0].shape[0]),)] = vec[..., : triu[0].shape[0]] + start = triu[0].shape[0] + mat[(...,) + tril + (np.zeros(tril[0].shape[0]),)] = mat[ + (...,) + triu2 + (np.zeros(triu2[0].shape[0]),) + ] + + # imaginary component + mat[(...,) + triu2 + (np.ones(triu2[0].shape[0]),)] = vec[ + ..., start : start + triu2[0].shape[0] + ] + mat[(...,) + tril + (np.ones(tril[0].shape[0]),)] = -mat[ + (...,) + triu2 + (np.ones(triu2[0].shape[0]),) + ] + + return ComplexTensor(mat[..., 0], mat[..., 1]) + + +def get_mfmvdr_vector(gammax, Phi, use_torch_solver: bool = True, eps: float = EPS): + """Compute conventional MFMPDR/MFMVDR filter. 
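The diagonal loading performed by tik_reg can be mirrored with native complex tensors as below; the all-zero PSD is a contrived worst case that shows why the extra eps term matters.

import torch

def diagonal_loading(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    # Add a small multiple of the (real) trace, plus eps, to the diagonal
    # of a batch of (..., C, C) matrices so they stay invertible.
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    trace = torch.diagonal(mat, dim1=-2, dim2=-1).sum(-1).real
    return mat + (trace[..., None, None] * reg + eps) * eye

psd = torch.zeros(3, 257, 4, 4, dtype=torch.complex64)     # worst case: all-zero PSD
loaded = diagonal_loading(psd)
print(torch.linalg.inv(loaded).isfinite().all())           # tensor(True): now invertible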
+ + Args: + gammax (ComplexTensor): (..., L, N) + Phi (ComplexTensor): (..., L, N, N) + use_torch_solver (bool): Whether to use `solve` instead of `inverse` + eps (float) + Returns: + beamforming_vector (ComplexTensor): (..., L, N) + """ + # (..., L, N) + if use_torch_solver and is_torch_1_1_plus: + # torch.solve is required, which is only available after pytorch 1.1.0+ + numerator = FC.solve(gammax.unsqueeze(-1), Phi)[0].squeeze(-1) + else: + numerator = FC.matmul(Phi.inverse2(), gammax.unsqueeze(-1)).squeeze(-1) + denominator = FC.einsum("...d,...d->...", [gammax.conj(), numerator]) + return numerator / (denominator.real.unsqueeze(-1) + eps) + + +def filter_minimum_gain_like( + G_min, w, y, alpha=None, k: float = 10.0, eps: float = EPS +): + """Approximate a minimum gain operation. + + speech_estimate = alpha w^H y + (1 - alpha) G_min Y, + where alpha = 1 / (1 + exp(-2 k x)), x = w^H y - G_min Y + + Args: + G_min (float): minimum gain + w (ComplexTensor): filter coefficients (..., L, N) + y (ComplexTensor): buffered and stacked input (..., L, N) + alpha: mixing factor + k (float): scaling in tanh-like function + esp (float) + Returns: + output (ComplexTensor): minimum gain-filtered output + alpha (float): optional + """ + # (..., L) + filtered_input = FC.einsum("...d,...d->...", [w.conj(), y]) + # (..., L) + Y = y[..., -1] + return minimum_gain_like(G_min, Y, filtered_input, alpha, k, eps) + + +def minimum_gain_like( + G_min, Y, filtered_input, alpha=None, k: float = 10.0, eps: float = EPS +): + if alpha is None: + diff = (filtered_input + eps).abs() - (G_min * Y + eps).abs() + alpha = 1.0 / (1.0 + torch.exp(-2 * k * diff)) + return_alpha = True + else: + return_alpha = False + output = alpha * filtered_input + (1 - alpha) * G_min * Y + if return_alpha: + return output, alpha + else: + return output diff --git a/espnet2/enh/layers/dnn_beamformer.py b/espnet2/enh/layers/dnn_beamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..94d802dd22fe19770b40d97c8e26d358abca97be --- /dev/null +++ b/espnet2/enh/layers/dnn_beamformer.py @@ -0,0 +1,516 @@ +from distutils.version import LooseVersion +from typing import List +from typing import Tuple +from typing import Union + +import logging +import torch +from torch.nn import functional as F +from torch_complex import functional as FC +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.frontends.beamformer import apply_beamforming_vector +from espnet.nets.pytorch_backend.frontends.beamformer import ( + get_power_spectral_density_matrix, # noqa: H301 +) +from espnet2.enh.layers.beamformer import get_covariances +from espnet2.enh.layers.beamformer import get_mvdr_vector +from espnet2.enh.layers.beamformer import get_mvdr_vector_with_rtf +from espnet2.enh.layers.beamformer import get_WPD_filter_v2 +from espnet2.enh.layers.beamformer import get_WPD_filter_with_rtf +from espnet2.enh.layers.beamformer import perform_WPD_filtering +from espnet2.enh.layers.mask_estimator import MaskEstimator + +is_torch_1_2_plus = LooseVersion(torch.__version__) >= LooseVersion("1.2.0") +is_torch_1_3_plus = LooseVersion(torch.__version__) >= LooseVersion("1.3.0") + + +BEAMFORMER_TYPES = ( + # Minimum Variance Distortionless Response beamformer + "mvdr", # RTF-based formula + "mvdr_souden", # Souden's solution + # Minimum Power Distortionless Response beamformer + "mpdr", # RTF-based formula + "mpdr_souden", # Souden's solution + # weighted MPDR beamformer + "wmpdr", # RTF-based formula + "wmpdr_souden", # Souden's 
solution + # Weighted Power minimization Distortionless response beamformer + "wpd", # RTF-based formula + "wpd_souden", # Souden's solution +) + + +class DNN_Beamformer(torch.nn.Module): + """DNN mask based Beamformer. + + Citation: + Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017; + http://proceedings.mlr.press/v70/ochiai17a/ochiai17a.pdf + + """ + + def __init__( + self, + bidim, + btype: str = "blstmp", + blayers: int = 3, + bunits: int = 300, + bprojs: int = 320, + num_spk: int = 1, + use_noise_mask: bool = True, + nonlinear: str = "sigmoid", + dropout_rate: float = 0.0, + badim: int = 320, + ref_channel: int = -1, + beamformer_type: str = "mvdr_souden", + rtf_iterations: int = 2, + eps: float = 1e-6, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + mask_flooring: bool = False, + flooring_thres: float = 1e-6, + use_torch_solver: bool = True, + # only for WPD beamformer + btaps: int = 5, + bdelay: int = 3, + ): + super().__init__() + bnmask = num_spk + 1 if use_noise_mask else num_spk + self.mask = MaskEstimator( + btype, + bidim, + blayers, + bunits, + bprojs, + dropout_rate, + nmask=bnmask, + nonlinear=nonlinear, + ) + self.ref = AttentionReference(bidim, badim) if ref_channel < 0 else None + self.ref_channel = ref_channel + + self.use_noise_mask = use_noise_mask + assert num_spk >= 1, num_spk + self.num_spk = num_spk + self.nmask = bnmask + + if beamformer_type not in BEAMFORMER_TYPES: + raise ValueError("Not supporting beamformer_type=%s" % beamformer_type) + if ( + beamformer_type == "mvdr_souden" or not beamformer_type.endswith("_souden") + ) and not use_noise_mask: + if num_spk == 1: + logging.warning( + "Initializing %s beamformer without noise mask " + "estimator (single-speaker case)" % beamformer_type.upper() + ) + logging.warning( + "(1 - speech_mask) will be used for estimating noise " + "PSD in %s beamformer!" % beamformer_type.upper() + ) + else: + logging.warning( + "Initializing %s beamformer without noise mask " + "estimator (multi-speaker case)" % beamformer_type.upper() + ) + logging.warning( + "Interference speech masks will be used for estimating " + "noise PSD in %s beamformer!" % beamformer_type.upper() + ) + + self.beamformer_type = beamformer_type + if not beamformer_type.endswith("_souden"): + assert rtf_iterations >= 2, rtf_iterations + # number of iterations in power method for estimating the RTF + self.rtf_iterations = rtf_iterations + + assert btaps >= 0 and bdelay >= 0, (btaps, bdelay) + self.btaps = btaps + self.bdelay = bdelay if self.btaps > 0 else 1 + self.eps = eps + self.diagonal_loading = diagonal_loading + self.diag_eps = diag_eps + self.mask_flooring = mask_flooring + self.flooring_thres = flooring_thres + self.use_torch_solver = use_torch_solver + + def forward( + self, + data: ComplexTensor, + ilens: torch.LongTensor, + powers: Union[List[torch.Tensor], None] = None, + ) -> Tuple[ComplexTensor, torch.LongTensor, torch.Tensor]: + """DNN_Beamformer forward function. + + Notation: + B: Batch + C: Channel + T: Time or Sequence length + F: Freq + + Args: + data (ComplexTensor): (B, T, C, F) + ilens (torch.Tensor): (B,) + powers (List[torch.Tensor] or None): used for wMPDR or WPD (B, F, T) + Returns: + enhanced (ComplexTensor): (B, T, F) + ilens (torch.Tensor): (B,) + masks (torch.Tensor): (B, T, C, F) + """ + + def apply_beamforming(data, ilens, psd_n, psd_speech, psd_distortion=None): + """Beamforming with the provided statistics. 
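apply_beamforming_vector is imported from espnet.nets.pytorch_backend.frontends.beamformer and is not shown in this diff; a plausible minimal equivalent (conjugate weights applied per frequency bin), together with the one-hot reference vector built in the code above, might look like the following sketch with made-up shapes.

import torch

B, F_, C, T = 2, 257, 4, 100
data = torch.randn(B, F_, C, T, dtype=torch.complex64)   # observed multi-channel STFT
w = torch.randn(B, F_, C, dtype=torch.complex64)         # beamforming weights per (B, F)

# One-hot reference vector for a fixed reference channel, as in the Souden branches above.
ref_channel = 0
u = torch.zeros(B, C, dtype=torch.double)
u[..., ref_channel].fill_(1)

# y(t, f) = w(f)^H x(t, f): conjugate weights times the observation, summed over channels.
enhanced = torch.einsum("...c,...ct->...t", w.conj(), data)
print(u[0], enhanced.shape)    # tensor([1., 0., 0., 0.], dtype=torch.float64) torch.Size([2, 257, 100])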
+ + Args: + data (ComplexTensor): (B, F, C, T) + ilens (torch.Tensor): (B,) + psd_n (ComplexTensor): + Noise covariance matrix for MVDR (B, F, C, C) + Observation covariance matrix for MPDR/wMPDR (B, F, C, C) + Stacked observation covariance for WPD (B,F,(btaps+1)*C,(btaps+1)*C) + psd_speech (ComplexTensor): Speech covariance matrix (B, F, C, C) + psd_distortion (ComplexTensor): Noise covariance matrix (B, F, C, C) + Return: + enhanced (ComplexTensor): (B, F, T) + ws (ComplexTensor): (B, F) or (B, F, (btaps+1)*C) + """ + # u: (B, C) + if self.ref_channel < 0: + u, _ = self.ref(psd_speech.to(dtype=data.dtype), ilens) + u = u.double() + else: + if self.beamformer_type.endswith("_souden"): + # (optional) Create onehot vector for fixed reference microphone + u = torch.zeros( + *(data.size()[:-3] + (data.size(-2),)), + device=data.device, + dtype=torch.double + ) + u[..., self.ref_channel].fill_(1) + else: + # for simplifying computation in RTF-based beamforming + u = self.ref_channel + + if self.beamformer_type in ("mvdr", "mpdr", "wmpdr"): + ws = get_mvdr_vector_with_rtf( + psd_n.double(), + psd_speech.double(), + psd_distortion.double(), + iterations=self.rtf_iterations, + reference_vector=u, + normalize_ref_channel=self.ref_channel, + use_torch_solver=self.use_torch_solver, + diagonal_loading=self.diagonal_loading, + diag_eps=self.diag_eps, + ) + enhanced = apply_beamforming_vector(ws, data.double()) + elif self.beamformer_type in ("mpdr_souden", "mvdr_souden", "wmpdr_souden"): + ws = get_mvdr_vector( + psd_speech.double(), + psd_n.double(), + u, + use_torch_solver=self.use_torch_solver, + diagonal_loading=self.diagonal_loading, + diag_eps=self.diag_eps, + ) + enhanced = apply_beamforming_vector(ws, data.double()) + elif self.beamformer_type == "wpd": + ws = get_WPD_filter_with_rtf( + psd_n.double(), + psd_speech.double(), + psd_distortion.double(), + iterations=self.rtf_iterations, + reference_vector=u, + normalize_ref_channel=self.ref_channel, + use_torch_solver=self.use_torch_solver, + diagonal_loading=self.diagonal_loading, + diag_eps=self.diag_eps, + ) + enhanced = perform_WPD_filtering( + ws, data.double(), self.bdelay, self.btaps + ) + elif self.beamformer_type == "wpd_souden": + ws = get_WPD_filter_v2( + psd_speech.double(), + psd_n.double(), + u, + diagonal_loading=self.diagonal_loading, + diag_eps=self.diag_eps, + ) + enhanced = perform_WPD_filtering( + ws, data.double(), self.bdelay, self.btaps + ) + else: + raise ValueError( + "Not supporting beamformer_type={}".format(self.beamformer_type) + ) + + return enhanced.to(dtype=data.dtype), ws.to(dtype=data.dtype) + + # data (B, T, C, F) -> (B, F, C, T) + data = data.permute(0, 3, 2, 1) + data_d = data.double() + + # mask: [(B, F, C, T)] + masks, _ = self.mask(data, ilens) + assert self.nmask == len(masks), len(masks) + # floor masks to increase numerical stability + if self.mask_flooring: + masks = [torch.clamp(m, min=self.flooring_thres) for m in masks] + + if self.num_spk == 1: # single-speaker case + if self.use_noise_mask: + # (mask_speech, mask_noise) + mask_speech, mask_noise = masks + else: + # (mask_speech,) + mask_speech = masks[0] + mask_noise = 1 - mask_speech + + if self.beamformer_type.startswith( + "wmpdr" + ) or self.beamformer_type.startswith("wpd"): + if powers is None: + power_input = data_d.real ** 2 + data_d.imag ** 2 + # Averaging along the channel axis: (..., C, T) -> (..., T) + powers = (power_input * mask_speech.double()).mean(dim=-2) + else: + assert len(powers) == 1, len(powers) + powers = powers[0] + 
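The mask-weighted power and its inverse that feed the wMPDR/WPD statistics can be sketched in isolation as follows (made-up shapes; the clamping threshold plays the role of self.eps).

import torch

B, F_, C, T = 2, 257, 4, 100
data = torch.randn(B, F_, C, T, dtype=torch.complex128)
mask_speech = torch.rand(B, F_, C, T)                      # DNN speech mask in [0, 1]

# Mask-weighted power, averaged over channels, then clamped before inversion.
power = (data.real ** 2 + data.imag ** 2) * mask_speech    # (B, F, C, T)
power = power.mean(dim=-2)                                 # (B, F, T)
inverse_power = 1 / torch.clamp(power, min=1e-6)

# Power-weighted observation covariance, mirroring the wMPDR branch of this forward pass.
psd_observed = torch.einsum(
    "...ct,...et->...ce", data * inverse_power[..., None, :], data.conj()
)
print(psd_observed.shape)                                  # torch.Size([2, 257, 4, 4])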
inverse_power = 1 / torch.clamp(powers, min=self.eps) + + psd_speech = get_power_spectral_density_matrix(data_d, mask_speech.double()) + if mask_noise is not None and ( + self.beamformer_type == "mvdr_souden" + or not self.beamformer_type.endswith("_souden") + ): + # MVDR or other RTF-based formulas + psd_noise = get_power_spectral_density_matrix( + data_d, mask_noise.double() + ) + if self.beamformer_type == "mvdr": + enhanced, ws = apply_beamforming( + data, ilens, psd_noise, psd_speech, psd_distortion=psd_noise + ) + elif self.beamformer_type == "mvdr_souden": + enhanced, ws = apply_beamforming(data, ilens, psd_noise, psd_speech) + elif self.beamformer_type == "mpdr": + psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()]) + enhanced, ws = apply_beamforming( + data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise + ) + elif self.beamformer_type == "mpdr_souden": + psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()]) + enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech) + elif self.beamformer_type == "wmpdr": + psd_observed = FC.einsum( + "...ct,...et->...ce", + [data_d * inverse_power[..., None, :], data_d.conj()], + ) + enhanced, ws = apply_beamforming( + data, ilens, psd_observed, psd_speech, psd_distortion=psd_noise + ) + elif self.beamformer_type == "wmpdr_souden": + psd_observed = FC.einsum( + "...ct,...et->...ce", + [data_d * inverse_power[..., None, :], data_d.conj()], + ) + enhanced, ws = apply_beamforming(data, ilens, psd_observed, psd_speech) + elif self.beamformer_type == "wpd": + psd_observed_bar = get_covariances( + data_d, inverse_power, self.bdelay, self.btaps, get_vector=False + ) + enhanced, ws = apply_beamforming( + data, ilens, psd_observed_bar, psd_speech, psd_distortion=psd_noise + ) + elif self.beamformer_type == "wpd_souden": + psd_observed_bar = get_covariances( + data_d, inverse_power, self.bdelay, self.btaps, get_vector=False + ) + enhanced, ws = apply_beamforming( + data, ilens, psd_observed_bar, psd_speech + ) + else: + raise ValueError( + "Not supporting beamformer_type={}".format(self.beamformer_type) + ) + + # (..., F, T) -> (..., T, F) + enhanced = enhanced.transpose(-1, -2) + else: # multi-speaker case + if self.use_noise_mask: + # (mask_speech1, ..., mask_noise) + mask_speech = list(masks[:-1]) + mask_noise = masks[-1] + else: + # (mask_speech1, ..., mask_speechX) + mask_speech = list(masks) + mask_noise = None + + if self.beamformer_type.startswith( + "wmpdr" + ) or self.beamformer_type.startswith("wpd"): + if powers is None: + power_input = data_d.real ** 2 + data_d.imag ** 2 + # Averaging along the channel axis: (..., C, T) -> (..., T) + powers = [ + (power_input * m.double()).mean(dim=-2) for m in mask_speech + ] + else: + assert len(powers) == self.num_spk, len(powers) + inverse_power = [1 / torch.clamp(p, min=self.eps) for p in powers] + + psd_speeches = [ + get_power_spectral_density_matrix(data_d, mask.double()) + for mask in mask_speech + ] + if mask_noise is not None and ( + self.beamformer_type == "mvdr_souden" + or not self.beamformer_type.endswith("_souden") + ): + # MVDR or other RTF-based formulas + psd_noise = get_power_spectral_density_matrix( + data_d, mask_noise.double() + ) + if self.beamformer_type in ("mpdr", "mpdr_souden"): + psd_observed = FC.einsum("...ct,...et->...ce", [data_d, data_d.conj()]) + elif self.beamformer_type in ("wmpdr", "wmpdr_souden"): + psd_observed = [ + FC.einsum( + "...ct,...et->...ce", + [data_d * inv_p[..., None, :], data_d.conj()], + ) + 
for inv_p in inverse_power + ] + elif self.beamformer_type in ("wpd", "wpd_souden"): + psd_observed_bar = [ + get_covariances( + data_d, inv_p, self.bdelay, self.btaps, get_vector=False + ) + for inv_p in inverse_power + ] + + enhanced, ws = [], [] + for i in range(self.num_spk): + psd_speech = psd_speeches.pop(i) + if ( + self.beamformer_type == "mvdr_souden" + or not self.beamformer_type.endswith("_souden") + ): + psd_noise_i = ( + psd_noise + sum(psd_speeches) + if mask_noise is not None + else sum(psd_speeches) + ) + # treat all other speakers' psd_speech as noises + if self.beamformer_type == "mvdr": + enh, w = apply_beamforming( + data, ilens, psd_noise_i, psd_speech, psd_distortion=psd_noise_i + ) + elif self.beamformer_type == "mvdr_souden": + enh, w = apply_beamforming(data, ilens, psd_noise_i, psd_speech) + elif self.beamformer_type == "mpdr": + enh, w = apply_beamforming( + data, + ilens, + psd_observed, + psd_speech, + psd_distortion=psd_noise_i, + ) + elif self.beamformer_type == "mpdr_souden": + enh, w = apply_beamforming(data, ilens, psd_observed, psd_speech) + elif self.beamformer_type == "wmpdr": + enh, w = apply_beamforming( + data, + ilens, + psd_observed[i], + psd_speech, + psd_distortion=psd_noise_i, + ) + elif self.beamformer_type == "wmpdr_souden": + enh, w = apply_beamforming(data, ilens, psd_observed[i], psd_speech) + elif self.beamformer_type == "wpd": + enh, w = apply_beamforming( + data, + ilens, + psd_observed_bar[i], + psd_speech, + psd_distortion=psd_noise_i, + ) + elif self.beamformer_type == "wpd_souden": + enh, w = apply_beamforming( + data, ilens, psd_observed_bar[i], psd_speech + ) + else: + raise ValueError( + "Not supporting beamformer_type={}".format(self.beamformer_type) + ) + psd_speeches.insert(i, psd_speech) + + # (..., F, T) -> (..., T, F) + enh = enh.transpose(-1, -2) + enhanced.append(enh) + ws.append(w) + + # (..., F, C, T) -> (..., T, C, F) + masks = [m.transpose(-1, -3) for m in masks] + return enhanced, ilens, masks + + def predict_mask( + self, data: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]: + """Predict masks for beamforming. + + Args: + data (ComplexTensor): (B, T, C, F), double precision + ilens (torch.Tensor): (B,) + Returns: + masks (torch.Tensor): (B, T, C, F) + ilens (torch.Tensor): (B,) + """ + masks, _ = self.mask(data.permute(0, 3, 2, 1).float(), ilens) + # (B, F, C, T) -> (B, T, C, F) + masks = [m.transpose(-1, -3) for m in masks] + return masks, ilens + + +class AttentionReference(torch.nn.Module): + def __init__(self, bidim, att_dim): + super().__init__() + self.mlp_psd = torch.nn.Linear(bidim, att_dim) + self.gvec = torch.nn.Linear(att_dim, 1) + + def forward( + self, psd_in: ComplexTensor, ilens: torch.LongTensor, scaling: float = 2.0 + ) -> Tuple[torch.Tensor, torch.LongTensor]: + """Attention-based reference forward function. 
+ + Args: + psd_in (ComplexTensor): (B, F, C, C) + ilens (torch.Tensor): (B,) + scaling (float): + Returns: + u (torch.Tensor): (B, C) + ilens (torch.Tensor): (B,) + """ + B, _, C = psd_in.size()[:3] + assert psd_in.size(2) == psd_in.size(3), psd_in.size() + # psd_in: (B, F, C, C) + datatype = torch.bool if is_torch_1_3_plus else torch.uint8 + datatype2 = torch.bool if is_torch_1_2_plus else torch.uint8 + psd = psd_in.masked_fill( + torch.eye(C, dtype=datatype, device=psd_in.device).type(datatype2), 0 + ) + # psd: (B, F, C, C) -> (B, C, F) + psd = (psd.sum(dim=-1) / (C - 1)).transpose(-1, -2) + + # Calculate amplitude + psd_feat = (psd.real ** 2 + psd.imag ** 2) ** 0.5 + + # (B, C, F) -> (B, C, F2) + mlp_psd = self.mlp_psd(psd_feat) + # (B, C, F2) -> (B, C, 1) -> (B, C) + e = self.gvec(torch.tanh(mlp_psd)).squeeze(-1) + u = F.softmax(scaling * e, dim=-1) + return u, ilens diff --git a/espnet2/enh/layers/dnn_wpe.py b/espnet2/enh/layers/dnn_wpe.py new file mode 100644 index 0000000000000000000000000000000000000000..c48affd6f010b8abbd4e4ae815384e1a9aa1ec22 --- /dev/null +++ b/espnet2/enh/layers/dnn_wpe.py @@ -0,0 +1,159 @@ +from typing import Tuple + +from pytorch_wpe import wpe_one_iteration +import torch +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet2.enh.layers.mask_estimator import MaskEstimator + + +class DNN_WPE(torch.nn.Module): + def __init__( + self, + wtype: str = "blstmp", + widim: int = 257, + wlayers: int = 3, + wunits: int = 300, + wprojs: int = 320, + dropout_rate: float = 0.0, + taps: int = 5, + delay: int = 3, + use_dnn_mask: bool = True, + nmask: int = 1, + nonlinear: str = "sigmoid", + iterations: int = 1, + normalization: bool = False, + eps: float = 1e-6, + diagonal_loading: bool = True, + diag_eps: float = 1e-7, + mask_flooring: bool = False, + flooring_thres: float = 1e-6, + use_torch_solver: bool = True, + ): + super().__init__() + self.iterations = iterations + self.taps = taps + self.delay = delay + self.eps = eps + + self.normalization = normalization + self.use_dnn_mask = use_dnn_mask + + self.inverse_power = True + self.diagonal_loading = diagonal_loading + self.diag_eps = diag_eps + self.mask_flooring = mask_flooring + self.flooring_thres = flooring_thres + self.use_torch_solver = use_torch_solver + + if self.use_dnn_mask: + self.nmask = nmask + self.mask_est = MaskEstimator( + wtype, + widim, + wlayers, + wunits, + wprojs, + dropout_rate, + nmask=nmask, + nonlinear=nonlinear, + ) + else: + self.nmask = 1 + + def forward( + self, data: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]: + """DNN_WPE forward function. 
+ + Notation: + B: Batch + C: Channel + T: Time or Sequence length + F: Freq or Some dimension of the feature vector + + Args: + data: (B, T, C, F) + ilens: (B,) + Returns: + enhanced (torch.Tensor or List[torch.Tensor]): (B, T, C, F) + ilens: (B,) + masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F) + power (List[torch.Tensor]): (B, F, T) + """ + # (B, T, C, F) -> (B, F, C, T) + data = data.permute(0, 3, 2, 1) + enhanced = [data for i in range(self.nmask)] + masks = None + power = None + + for i in range(self.iterations): + # Calculate power: (..., C, T) + power = [enh.real ** 2 + enh.imag ** 2 for enh in enhanced] + if i == 0 and self.use_dnn_mask: + # mask: (B, F, C, T) + masks, _ = self.mask_est(data, ilens) + # floor masks to increase numerical stability + if self.mask_flooring: + masks = [m.clamp(min=self.flooring_thres) for m in masks] + if self.normalization: + # Normalize along T + masks = [m / m.sum(dim=-1, keepdim=True) for m in masks] + # (..., C, T) * (..., C, T) -> (..., C, T) + power = [p * masks[i] for i, p in enumerate(power)] + + # Averaging along the channel axis: (..., C, T) -> (..., T) + power = [p.mean(dim=-2).clamp(min=self.eps) for p in power] + + # enhanced: (..., C, T) -> (..., C, T) + # NOTE(kamo): Calculate in double precision + enhanced = [ + wpe_one_iteration( + data.contiguous().double(), + p.double(), + taps=self.taps, + delay=self.delay, + inverse_power=self.inverse_power, + ) + for p in power + ] + enhanced = [ + enh.to(dtype=data.dtype).masked_fill(make_pad_mask(ilens, enh.real), 0) + for enh in enhanced + ] + + # (B, F, C, T) -> (B, T, C, F) + enhanced = [enh.permute(0, 3, 2, 1) for enh in enhanced] + if masks is not None: + masks = ( + [m.transpose(-1, -3) for m in masks] + if self.nmask > 1 + else masks[0].transpose(-1, -3) + ) + if self.nmask == 1: + enhanced = enhanced[0] + + return enhanced, ilens, masks, power + + def predict_mask( + self, data: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[torch.Tensor, torch.LongTensor]: + """Predict mask for WPE dereverberation. + + Args: + data (ComplexTensor): (B, T, C, F), double precision + ilens (torch.Tensor): (B,) + Returns: + masks (torch.Tensor or List[torch.Tensor]): (B, T, C, F) + ilens (torch.Tensor): (B,) + """ + if self.use_dnn_mask: + masks, ilens = self.mask_est(data.permute(0, 3, 2, 1).float(), ilens) + # (B, F, C, T) -> (B, T, C, F) + masks = [m.transpose(-1, -3) for m in masks] + if self.nmask == 1: + masks = masks[0] + else: + masks = None + return masks, ilens diff --git a/espnet2/enh/layers/dprnn.py b/espnet2/enh/layers/dprnn.py new file mode 100644 index 0000000000000000000000000000000000000000..827c754ac8673685c56fe0ddb67116adde7a1bcb --- /dev/null +++ b/espnet2/enh/layers/dprnn.py @@ -0,0 +1,241 @@ +# The implementation of DPRNN in +# Luo. et al. "Dual-path rnn: efficient long sequence modeling +# for time-domain single-channel speech separation." +# +# The code is based on: +# https://github.com/yluo42/TAC/blob/master/utility/models.py +# + + +import torch +from torch.autograd import Variable +import torch.nn as nn + + +EPS = torch.finfo(torch.get_default_dtype()).eps + + +class SingleRNN(nn.Module): + """Container module for a single RNN layer. + + args: + rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. + input_size: int, dimension of the input feature. The input should have shape + (batch, seq_len, input_size). + hidden_size: int, dimension of the hidden state. + dropout: float, dropout ratio. Default is 0. 
+ bidirectional: bool, whether the RNN layers are bidirectional. Default is False. + """ + + def __init__( + self, rnn_type, input_size, hidden_size, dropout=0, bidirectional=False + ): + super().__init__() + + rnn_type = rnn_type.upper() + + assert rnn_type in [ + "RNN", + "LSTM", + "GRU", + ], f"Only support 'RNN', 'LSTM' and 'GRU', current type: {rnn_type}" + + self.rnn_type = rnn_type + self.input_size = input_size + self.hidden_size = hidden_size + self.num_direction = int(bidirectional) + 1 + + self.rnn = getattr(nn, rnn_type)( + input_size, + hidden_size, + 1, + batch_first=True, + bidirectional=bidirectional, + ) + + self.dropout = nn.Dropout(p=dropout) + + # linear projection layer + self.proj = nn.Linear(hidden_size * self.num_direction, input_size) + + def forward(self, input): + # input shape: batch, seq, dim + # input = input.to(device) + output = input + rnn_output, _ = self.rnn(output) + rnn_output = self.dropout(rnn_output) + rnn_output = self.proj( + rnn_output.contiguous().view(-1, rnn_output.shape[2]) + ).view(output.shape) + return rnn_output + + +# dual-path RNN +class DPRNN(nn.Module): + """Deep dual-path RNN. + + args: + rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. + input_size: int, dimension of the input feature. The input should have shape + (batch, seq_len, input_size). + hidden_size: int, dimension of the hidden state. + output_size: int, dimension of the output size. + dropout: float, dropout ratio. Default is 0. + num_layers: int, number of stacked RNN layers. Default is 1. + bidirectional: bool, whether the RNN layers are bidirectional. Default is True. + """ + + def __init__( + self, + rnn_type, + input_size, + hidden_size, + output_size, + dropout=0, + num_layers=1, + bidirectional=True, + ): + super().__init__() + + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + + # dual-path RNN + self.row_rnn = nn.ModuleList([]) + self.col_rnn = nn.ModuleList([]) + self.row_norm = nn.ModuleList([]) + self.col_norm = nn.ModuleList([]) + for i in range(num_layers): + self.row_rnn.append( + SingleRNN( + rnn_type, input_size, hidden_size, dropout, bidirectional=True + ) + ) # intra-segment RNN is always noncausal + self.col_rnn.append( + SingleRNN( + rnn_type, + input_size, + hidden_size, + dropout, + bidirectional=bidirectional, + ) + ) + self.row_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) + # default is to use noncausal LayerNorm for inter-chunk RNN. + # For causal setting change it to causal normalization accordingly. 
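A shape-level sketch of the intra-chunk ("row") and inter-chunk ("col") passes that DPRNN.forward performs below; the real module additionally applies GroupNorm and residual connections after each pass and uses separate projected RNNs per layer (one LSTM is reused here only for brevity).

import torch
import torch.nn as nn

B, N, dim1, dim2 = 2, 64, 50, 20          # (batch, feature, intra-chunk length, num chunks)
x = torch.randn(B, N, dim1, dim2)
rnn = nn.LSTM(N, 32, batch_first=True, bidirectional=True)
proj = nn.Linear(64, N)                   # project 2*hidden back to the feature size

# Intra-chunk pass: fold the chunk axis into the batch and run the RNN along dim1.
row = x.permute(0, 3, 2, 1).reshape(B * dim2, dim1, N)
row, _ = rnn(row)
row = proj(row).reshape(B, dim2, dim1, N).permute(0, 3, 2, 1)   # back to (B, N, dim1, dim2)

# Inter-chunk pass: fold dim1 into the batch and run the RNN along dim2 instead.
col = x.permute(0, 2, 3, 1).reshape(B * dim1, dim2, N)
col, _ = rnn(col)
col = proj(col).reshape(B, dim1, dim2, N).permute(0, 3, 1, 2)   # back to (B, N, dim1, dim2)
print(row.shape, col.shape)   # torch.Size([2, 64, 50, 20]) twice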
+ self.col_norm.append(nn.GroupNorm(1, input_size, eps=1e-8)) + + # output layer + self.output = nn.Sequential(nn.PReLU(), nn.Conv2d(input_size, output_size, 1)) + + def forward(self, input): + # input shape: batch, N, dim1, dim2 + # apply RNN on dim1 first and then dim2 + # output shape: B, output_size, dim1, dim2 + # input = input.to(device) + batch_size, _, dim1, dim2 = input.shape + output = input + for i in range(len(self.row_rnn)): + row_input = ( + output.permute(0, 3, 2, 1) + .contiguous() + .view(batch_size * dim2, dim1, -1) + ) # B*dim2, dim1, N + row_output = self.row_rnn[i](row_input) # B*dim2, dim1, H + row_output = ( + row_output.view(batch_size, dim2, dim1, -1) + .permute(0, 3, 2, 1) + .contiguous() + ) # B, N, dim1, dim2 + row_output = self.row_norm[i](row_output) + output = output + row_output + + col_input = ( + output.permute(0, 2, 3, 1) + .contiguous() + .view(batch_size * dim1, dim2, -1) + ) # B*dim1, dim2, N + col_output = self.col_rnn[i](col_input) # B*dim1, dim2, H + col_output = ( + col_output.view(batch_size, dim1, dim2, -1) + .permute(0, 3, 1, 2) + .contiguous() + ) # B, N, dim1, dim2 + col_output = self.col_norm[i](col_output) + output = output + col_output + + output = self.output(output) # B, output_size, dim1, dim2 + + return output + + +def _pad_segment(input, segment_size): + # input is the features: (B, N, T) + batch_size, dim, seq_len = input.shape + segment_stride = segment_size // 2 + + rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size + if rest > 0: + pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type()) + input = torch.cat([input, pad], 2) + + pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type()) + input = torch.cat([pad_aux, input, pad_aux], 2) + + return input, rest + + +def split_feature(input, segment_size): + # split the feature into chunks of segment size + # input is the features: (B, N, T) + + input, rest = _pad_segment(input, segment_size) + batch_size, dim, seq_len = input.shape + segment_stride = segment_size // 2 + + segments1 = ( + input[:, :, :-segment_stride] + .contiguous() + .view(batch_size, dim, -1, segment_size) + ) + segments2 = ( + input[:, :, segment_stride:] + .contiguous() + .view(batch_size, dim, -1, segment_size) + ) + segments = ( + torch.cat([segments1, segments2], 3) + .view(batch_size, dim, -1, segment_size) + .transpose(2, 3) + ) + + return segments.contiguous(), rest + + +def merge_feature(input, rest): + # merge the splitted features into full utterance + # input is the features: (B, N, L, K) + + batch_size, dim, segment_size, _ = input.shape + segment_stride = segment_size // 2 + input = ( + input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2) + ) # B, N, K, L + + input1 = ( + input[:, :, :, :segment_size] + .contiguous() + .view(batch_size, dim, -1)[:, :, segment_stride:] + ) + input2 = ( + input[:, :, :, segment_size:] + .contiguous() + .view(batch_size, dim, -1)[:, :, :-segment_stride] + ) + + output = input1 + input2 + if rest > 0: + output = output[:, :, :-rest] + + return output.contiguous() # B, N, T diff --git a/espnet2/enh/layers/mask_estimator.py b/espnet2/enh/layers/mask_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..9e309b4d9a8af76ca18ed1b094c74eb2862a4496 --- /dev/null +++ b/espnet2/enh/layers/mask_estimator.py @@ -0,0 +1,91 @@ +from typing import Tuple + +import numpy as np +import torch +from torch.nn import functional as F +from torch_complex.tensor import 
ComplexTensor + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.encoders import RNN +from espnet.nets.pytorch_backend.rnn.encoders import RNNP + + +class MaskEstimator(torch.nn.Module): + def __init__( + self, type, idim, layers, units, projs, dropout, nmask=1, nonlinear="sigmoid" + ): + super().__init__() + subsample = np.ones(layers + 1, dtype=np.int) + + typ = type.lstrip("vgg").rstrip("p") + if type[-1] == "p": + self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ) + else: + self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ) + + self.type = type + self.nmask = nmask + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(projs, idim) for _ in range(nmask)] + ) + + if nonlinear not in ("sigmoid", "relu", "tanh", "crelu"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = nonlinear + + def forward( + self, xs: ComplexTensor, ilens: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]: + """Mask estimator forward function. + + Args: + xs: (B, F, C, T) + ilens: (B,) + Returns: + hs (torch.Tensor): The hidden vector (B, F, C, T) + masks: A tuple of the masks. (B, F, C, T) + ilens: (B,) + """ + assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0)) + _, _, C, input_length = xs.size() + # (B, F, C, T) -> (B, C, T, F) + xs = xs.permute(0, 2, 3, 1) + + # Calculate amplitude: (B, C, T, F) -> (B, C, T, F) + xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5 + # xs: (B, C, T, F) -> xs: (B * C, T, F) + xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1)) + # ilens: (B,) -> ilens_: (B * C) + ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1) + + # xs: (B * C, T, F) -> xs: (B * C, T, D) + xs, _, _ = self.brnn(xs, ilens_) + # xs: (B * C, T, D) -> xs: (B, C, T, D) + xs = xs.view(-1, C, xs.size(-2), xs.size(-1)) + + masks = [] + for linear in self.linears: + # xs: (B, C, T, D) -> mask:(B, C, T, F) + mask = linear(xs) + + if self.nonlinear == "sigmoid": + mask = torch.sigmoid(mask) + elif self.nonlinear == "relu": + mask = torch.relu(mask) + elif self.nonlinear == "tanh": + mask = torch.tanh(mask) + elif self.nonlinear == "crelu": + mask = torch.clamp(mask, min=0, max=1) + # Zero padding + mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0) + + # (B, C, T, F) -> (B, F, C, T) + mask = mask.permute(0, 3, 1, 2) + + # Take cares of multi gpu cases: If input_length > max(ilens) + if mask.size(-1) < input_length: + mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0) + masks.append(mask) + + return tuple(masks), ilens diff --git a/espnet2/enh/layers/tcn.py b/espnet2/enh/layers/tcn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c694cfd3dbd4e22ba4f6264af422ed7d371ced --- /dev/null +++ b/espnet2/enh/layers/tcn.py @@ -0,0 +1,289 @@ +# Implementation of the TCN proposed in +# Luo. et al. "Conv-tasnet: Surpassing ideal time–frequency +# magnitude masking for speech separation." +# +# The code is based on: +# https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py +# + + +import torch +import torch.nn as nn +import torch.nn.functional as F + +EPS = torch.finfo(torch.get_default_dtype()).eps + + +class TemporalConvNet(nn.Module): + def __init__( + self, N, B, H, P, X, R, C, norm_type="gLN", causal=False, mask_nonlinear="relu" + ): + """Basic Module of tasnet. 
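+
+        Stacks R repeats of X dilated 1-D conv blocks on top of a bottleneck
+        1x1-conv and predicts C masks (one per speaker) from the encoded mixture.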
+ + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 * 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + norm_type: BN, gLN, cLN + causal: causal or non-causal + mask_nonlinear: use which non-linear function to generate mask + """ + super().__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2 ** x + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock( + B, + H, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal, + ) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False) + # Put together + self.network = nn.Sequential( + layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1 + ) + + def forward(self, mixture_w): + """Keep this API same with TasNet. + + Args: + mixture_w: [M, N, K], M is batch size + + Returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + score = self.network(mixture_w) # [M, N, K] -> [M, C*N, K] + score = score.view(M, self.C, N, K) # [M, C*N, K] -> [M, C, N, K] + if self.mask_nonlinear == "softmax": + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == "relu": + est_mask = F.relu(score) + elif self.mask_nonlinear == "sigmoid": + est_mask = F.sigmoid(score) + elif self.mask_nonlinear == "tanh": + est_mask = F.tanh(score) + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + + +class TemporalBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False, + ): + super().__init__() + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = chose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv( + out_channels, + in_channels, + kernel_size, + stride, + padding, + dilation, + norm_type, + causal, + ) + # Put together + self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """Forward. + + Args: + x: [M, B, K] + + Returns: + [M, B, K] + """ + residual = x + out = self.net(x) + # TODO(Jing): when P = 3 here works fine, but when P = 2 maybe need to pad? 
+ return out + residual # look like w/o F.relu is better than w/ F.relu + # return F.relu(out + residual) + + +class DepthwiseSeparableConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False, + ): + super().__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = chose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + def forward(self, x): + """Forward. + + Args: + x: [M, H, K] + + Returns: + result: [M, B, K] + """ + return self.net(x) + + +class Chomp1d(nn.Module): + """To ensure the output length is the same as the input.""" + + def __init__(self, chomp_size): + super().__init__() + self.chomp_size = chomp_size + + def forward(self, x): + """Forward. + + Args: + x: [M, H, Kpad] + + Returns: + [M, H, K] + """ + return x[:, :, : -self.chomp_size].contiguous() + + +def check_nonlinear(nolinear_type): + if nolinear_type not in ["softmax", "relu"]: + raise ValueError("Unsupported nonlinear type") + + +def chose_norm(norm_type, channel_size): + """The input of normalization will be (M, C, K), where M is batch size. + + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size) + elif norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + else: + raise ValueError("Unsupported normalization type") + + +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN).""" + + def __init__(self, channel_size): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """Forward. + + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + + Returns: + cLN_y: [M, N, K] + """ + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return cLN_y + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN).""" + + def __init__(self, channel_size): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + def forward(self, y): + """Forward. 
+ + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + + Returns: + gLN_y: [M, N, K] + """ + mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) # [M, 1, 1] + var = ( + (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True) + ) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + return gLN_y diff --git a/espnet2/enh/separator/abs_separator.py b/espnet2/enh/separator/abs_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..8b9de6260269cc3b42d8572b28ee04b788bfd442 --- /dev/null +++ b/espnet2/enh/separator/abs_separator.py @@ -0,0 +1,22 @@ +from abc import ABC +from abc import abstractmethod +from collections import OrderedDict +from typing import Tuple + +import torch + + +class AbsSeparator(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]: + + raise NotImplementedError + + @property + @abstractmethod + def num_spk(self): + raise NotImplementedError diff --git a/espnet2/enh/separator/asteroid_models.py b/espnet2/enh/separator/asteroid_models.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec5264797be2275b33247c977000282d5fa6b2d --- /dev/null +++ b/espnet2/enh/separator/asteroid_models.py @@ -0,0 +1,156 @@ +from collections import OrderedDict +from typing import Tuple +import warnings + +import torch + +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class AsteroidModel_Converter(AbsSeparator): + def __init__( + self, + encoder_output_dim: int, + model_name: str, + num_spk: int, + pretrained_path: str = "", + loss_type: str = "si_snr", + **model_related_kwargs, + ): + """The class to convert the models from asteroid to AbsSeprator. + + Args: + encoder_output_dim: input feature dimension, deafult=1 after the NullEncoder + num_spk: number of speakers + loss_type: loss type of enhancement + model_name: Asteroid model names, e.g. ConvTasNet, DPTNet. Refers to + https://github.com/asteroid-team/asteroid/ + blob/master/asteroid/models/__init__.py + pretrained_path: the name of pretrained model from Asteroid in HF hub. + Refers to: https://github.com/asteroid-team/asteroid/ + blob/master/docs/source/readmes/pretrained_models.md and + https://huggingface.co/models?filter=asteroid + model_related_kwargs: more args towards each specific asteroid model. + """ + super(AsteroidModel_Converter, self).__init__() + + assert ( + encoder_output_dim == 1 + ), encoder_output_dim # The input should in raw-wave domain. + + # Please make sure the installation of Asteroid. + # https://github.com/asteroid-team/asteroid + from asteroid import models + + model_related_kwargs = { + k: None if v == "None" else v for k, v in model_related_kwargs.items() + } + # print('args:',model_related_kwargs) + + if pretrained_path: + model = getattr(models, model_name).from_pretrained(pretrained_path) + print("model_kwargs:", model_related_kwargs) + if model_related_kwargs: + warnings.warn( + "Pratrained model should get no args with %s" % model_related_kwargs + ) + + else: + model_name = getattr(models, model_name) + model = model_name(**model_related_kwargs) + + self.model = model + self._num_spk = num_spk + + self.loss_type = loss_type + if loss_type != "si_snr": + raise ValueError("Unsupported loss type: %s" % loss_type) + + def forward(self, input: torch.Tensor, ilens: torch.Tensor = None): + """Whole forward of asteroid models. 
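+
+        Runs the wrapped asteroid model on raw waveforms (using the model's
+        forward_wav method when available) and repacks the outputs into the
+        AbsSeparator output format.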
+ + Args: + input (torch.Tensor): Raw Waveforms [B, T] + ilens (torch.Tensor): input lengths [B] + + Returns: + estimated Waveforms(List[Union(torch.Tensor]): [(B, T), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, T), + 'mask_spk2': torch.Tensor(Batch, T), + ... + 'mask_spkn': torch.Tensor(Batch, T), + ] + """ + + if hasattr(self.model, "forward_wav"): + est_source = self.model.forward_wav(input) # B,nspk,T or nspk,T + else: + est_source = self.model(input) # B,nspk,T or nspk,T + + if input.dim() == 1: + assert est_source.size(0) == self.num_spk, est_source.size(0) + else: + assert est_source.size(1) == self.num_spk, est_source.size(1) + + est_source = [es for es in est_source.transpose(0, 1)] # List(M,T) + masks = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(self.num_spk)], est_source) + ) + return est_source, ilens, masks + + def forward_rawwav( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Output with waveforms. """ + return self.forward(input, ilens) + + @property + def num_spk(self): + return self._num_spk + + +if __name__ == "__main__": + mixture = torch.randn(3, 16000) + print("mixture shape", mixture.shape) + + net = AsteroidModel_Converter( + model_name="ConvTasNet", + encoder_output_dim=1, + num_spk=2, + loss_type="si_snr", + pretrained_path="mpariente/ConvTasNet_WHAM!_sepclean", + ) + print("model", net) + output, *__ = net(mixture) + output, *__ = net.forward_rawwav(mixture, 111) + print("output spk1 shape", output[0].shape) + + net = AsteroidModel_Converter( + encoder_output_dim=1, + num_spk=2, + model_name="ConvTasNet", + n_src=2, + loss_type="si_snr", + out_chan=None, + n_blocks=2, + n_repeats=2, + bn_chan=128, + hid_chan=512, + skip_chan=128, + conv_kernel_size=3, + norm_type="gLN", + mask_act="sigmoid", + in_chan=None, + fb_name="free", + kernel_size=16, + n_filters=512, + stride=8, + encoder_activation=None, + sample_rate=8000, + ) + print("\n\nmodel", net) + output, *__ = net(mixture) + print("output spk1 shape", output[0].shape) + print("Finished", output[0].shape) diff --git a/espnet2/enh/separator/conformer_separator.py b/espnet2/enh/separator/conformer_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..26fd6a248fe1dac2b1016e8416e4f3a21addeac0 --- /dev/null +++ b/espnet2/enh/separator/conformer_separator.py @@ -0,0 +1,162 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.conformer.encoder import ( + Encoder as ConformerEncoder, # noqa: H301 +) +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class ConformerSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + num_spk: int = 2, + adim: int = 384, + aheads: int = 4, + layers: int = 6, + linear_units: int = 1536, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + normalize_before: bool = False, + concat_after: bool = False, + dropout_rate: float = 0.1, + input_layer: str = "linear", + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + nonlinear: str = "relu", + conformer_pos_enc_layer_type: str = "rel_pos", + conformer_self_attn_layer_type: str = "rel_selfattn", + conformer_activation_type: str = "swish", + 
use_macaron_style_in_conformer: bool = True, + use_cnn_in_conformer: bool = True, + conformer_enc_kernel_size: int = 7, + padding_idx: int = -1, + ): + """Conformer separator. + + Args: + input_dim: input feature dimension + num_spk: number of speakers + adim (int): Dimention of attention. + aheads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + layers (int): The number of transformer blocks. + dropout_rate (float): Dropout rate. + input_layer (Union[str, torch.nn.Module]): Input layer type. + attention_dropout_rate (float): Dropout rate in attention. + positional_dropout_rate (float): Dropout rate after adding + positional encoding. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + conformer_pos_enc_layer_type(str): Encoder positional encoding layer type. + conformer_self_attn_layer_type (str): Encoder attention layer type. + conformer_activation_type(str): Encoder activation function type. + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of + positionwise conv1d layer. + use_macaron_style_in_conformer (bool): Whether to use macaron style for + positionwise layer. + use_cnn_in_conformer (bool): Whether to use convolution module. + conformer_enc_kernel_size(int): Kernerl size of convolution module. + padding_idx (int): Padding idx for input_layer=embed. + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + """ + super().__init__() + + self._num_spk = num_spk + + self.conformer = ConformerEncoder( + idim=input_dim, + attention_dim=adim, + attention_heads=aheads, + linear_units=linear_units, + num_blocks=layers, + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + attention_dropout_rate=attention_dropout_rate, + input_layer=input_layer, + normalize_before=normalize_before, + concat_after=concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + padding_idx=padding_idx, + ) + + self.linear = torch.nn.ModuleList( + [torch.nn.Linear(adim, input_dim) for _ in range(self.num_spk)] + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. 
masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + + # prepare pad_mask for transformer + pad_mask = make_non_pad_mask(ilens).unsqueeze(1).to(feature.device) + + x, ilens = self.conformer(feature, pad_mask) + + masks = [] + for linear in self.linear: + y = linear(x) + y = self.nonlinear(y) + masks.append(y) + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/enh/separator/dprnn_separator.py b/espnet2/enh/separator/dprnn_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..449fb3b79bc311a8dfc179f565a6ade87bfeed54 --- /dev/null +++ b/espnet2/enh/separator/dprnn_separator.py @@ -0,0 +1,118 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.layers.dprnn import DPRNN +from espnet2.enh.layers.dprnn import merge_feature +from espnet2.enh.layers.dprnn import split_feature +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class DPRNNSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + rnn_type: str = "lstm", + bidirectional: bool = True, + num_spk: int = 2, + nonlinear: str = "relu", + layer: int = 3, + unit: int = 512, + segment_size: int = 20, + dropout: float = 0.0, + ): + """Dual-Path RNN (DPRNN) Separator + + Args: + input_dim: input feature dimension + rnn_type: string, select from 'RNN', 'LSTM' and 'GRU'. + bidirectional: bool, whether the inter-chunk RNN layers are bidirectional. + num_spk: number of speakers + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + layer: int, number of stacked RNN layers. Default is 3. + unit: int, dimension of the hidden state. + segment_size: dual-path segment size + dropout: float, dropout ratio. Default is 0. + """ + super().__init__() + + self._num_spk = num_spk + + self.segment_size = segment_size + + self.dprnn = DPRNN( + rnn_type=rnn_type, + input_size=input_dim, + hidden_size=unit, + output_size=input_dim * num_spk, + dropout=dropout, + num_layers=layer, + bidirectional=bidirectional, + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... 
+ 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + + B, T, N = feature.shape + + feature = feature.transpose(1, 2) # B, N, T + segmented, rest = split_feature( + feature, segment_size=self.segment_size + ) # B, N, L, K + + processed = self.dprnn(segmented) # B, N*num_spk, L, K + + processed = merge_feature(processed, rest) # B, N*num_spk, T + + processed = processed.transpose(1, 2) # B, T, N*num_spk + processed = processed.view(B, T, N, self.num_spk) + masks = self.nonlinear(processed).unbind(dim=3) + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/enh/separator/neural_beamformer.py b/espnet2/enh/separator/neural_beamformer.py new file mode 100644 index 0000000000000000000000000000000000000000..007072b16f78cc27423728f6149e6b404000c474 --- /dev/null +++ b/espnet2/enh/separator/neural_beamformer.py @@ -0,0 +1,258 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.layers.dnn_beamformer import DNN_Beamformer +from espnet2.enh.layers.dnn_wpe import DNN_WPE +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class NeuralBeamformer(AbsSeparator): + def __init__( + self, + input_dim: int, + num_spk: int = 1, + loss_type: str = "mask_mse", + # Dereverberation options + use_wpe: bool = False, + wnet_type: str = "blstmp", + wlayers: int = 3, + wunits: int = 300, + wprojs: int = 320, + wdropout_rate: float = 0.0, + taps: int = 5, + delay: int = 3, + use_dnn_mask_for_wpe: bool = True, + wnonlinear: str = "crelu", + multi_source_wpe: bool = True, + wnormalization: bool = False, + # Beamformer options + use_beamformer: bool = True, + bnet_type: str = "blstmp", + blayers: int = 3, + bunits: int = 300, + bprojs: int = 320, + badim: int = 320, + ref_channel: int = -1, + use_noise_mask: bool = True, + bnonlinear: str = "sigmoid", + beamformer_type: str = "mvdr_souden", + rtf_iterations: int = 2, + bdropout_rate: float = 0.0, + shared_power: bool = True, + # For numerical stability + diagonal_loading: bool = True, + diag_eps_wpe: float = 1e-7, + diag_eps_bf: float = 1e-7, + mask_flooring: bool = False, + flooring_thres_wpe: float = 1e-6, + flooring_thres_bf: float = 1e-6, + use_torch_solver: bool = True, + ): + super().__init__() + + self._num_spk = num_spk + self.loss_type = loss_type + if loss_type not in ("mask_mse", "spectrum", "spectrum_log", "magnitude"): + raise ValueError("Unsupported loss type: %s" % loss_type) + + self.use_beamformer = use_beamformer + self.use_wpe = use_wpe + + if self.use_wpe: + if use_dnn_mask_for_wpe: + # Use DNN for power estimation + iterations = 1 + else: + # Performing as conventional WPE, without DNN Estimator + iterations = 2 + + self.wpe = DNN_WPE( + wtype=wnet_type, + widim=input_dim, + wlayers=wlayers, + wunits=wunits, + wprojs=wprojs, + dropout_rate=wdropout_rate, + taps=taps, + delay=delay, + use_dnn_mask=use_dnn_mask_for_wpe, + nmask=1 if multi_source_wpe else num_spk, + nonlinear=wnonlinear, + iterations=iterations, + normalization=wnormalization, + diagonal_loading=diagonal_loading, + diag_eps=diag_eps_wpe, + mask_flooring=mask_flooring, + flooring_thres=flooring_thres_wpe, + 
use_torch_solver=use_torch_solver, + ) + else: + self.wpe = None + + self.ref_channel = ref_channel + if self.use_beamformer: + self.beamformer = DNN_Beamformer( + bidim=input_dim, + btype=bnet_type, + blayers=blayers, + bunits=bunits, + bprojs=bprojs, + num_spk=num_spk, + use_noise_mask=use_noise_mask, + nonlinear=bnonlinear, + dropout_rate=bdropout_rate, + badim=badim, + ref_channel=ref_channel, + beamformer_type=beamformer_type, + rtf_iterations=rtf_iterations, + btaps=taps, + bdelay=delay, + diagonal_loading=diagonal_loading, + diag_eps=diag_eps_bf, + mask_flooring=mask_flooring, + flooring_thres=flooring_thres_bf, + use_torch_solver=use_torch_solver, + ) + else: + self.beamformer = None + + # share speech powers between WPE and beamforming (wMPDR/WPD) + self.shared_power = shared_power and use_wpe + + def forward( + self, input: ComplexTensor, ilens: torch.Tensor + ) -> Tuple[List[ComplexTensor], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (ComplexTensor): mixed speech [Batch, Frames, Channel, Freq] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + enhanced speech (single-channel): List[ComplexTensor] + output lengths + other predcited data: OrderedDict[ + 'dereverb1': ComplexTensor(Batch, Frames, Channel, Freq), + 'mask_dereverb1': torch.Tensor(Batch, Frames, Channel, Freq), + 'mask_noise1': torch.Tensor(Batch, Frames, Channel, Freq), + 'mask_spk1': torch.Tensor(Batch, Frames, Channel, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Channel, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Channel, Freq), + ] + """ + # Shape of input spectrum must be (B, T, F) or (B, T, C, F) + assert input.dim() in (3, 4), input.dim() + enhanced = input + others = OrderedDict() + + if ( + self.training + and self.loss_type is not None + and self.loss_type.startswith("mask") + ): + # Only estimating masks during training for saving memory + if self.use_wpe: + if input.dim() == 3: + mask_w, ilens = self.wpe.predict_mask(input.unsqueeze(-2), ilens) + mask_w = mask_w.squeeze(-2) + elif input.dim() == 4: + mask_w, ilens = self.wpe.predict_mask(input, ilens) + + if mask_w is not None: + if isinstance(enhanced, list): + # single-source WPE + for spk in range(self.num_spk): + others["mask_dereverb{}".format(spk + 1)] = mask_w[spk] + else: + # multi-source WPE + others["mask_dereverb1"] = mask_w + + if self.use_beamformer and input.dim() == 4: + others_b, ilens = self.beamformer.predict_mask(input, ilens) + for spk in range(self.num_spk): + others["mask_spk{}".format(spk + 1)] = others_b[spk] + if len(others_b) > self.num_spk: + others["mask_noise1"] = others_b[self.num_spk] + + return None, ilens, others + + else: + powers = None + # Performing both mask estimation and enhancement + if input.dim() == 3: + # single-channel input (B, T, F) + if self.use_wpe: + enhanced, ilens, mask_w, powers = self.wpe( + input.unsqueeze(-2), ilens + ) + if isinstance(enhanced, list): + # single-source WPE + enhanced = [enh.squeeze(-2) for enh in enhanced] + if mask_w is not None: + for spk in range(self.num_spk): + key = "dereverb{}".format(spk + 1) + others[key] = enhanced[spk] + others["mask_" + key] = mask_w[spk].squeeze(-2) + else: + # multi-source WPE + enhanced = enhanced.squeeze(-2) + if mask_w is not None: + others["dereverb1"] = enhanced + others["mask_dereverb1"] = mask_w.squeeze(-2) + else: + # multi-channel input (B, T, C, F) + # 1. 
WPE + if self.use_wpe: + enhanced, ilens, mask_w, powers = self.wpe(input, ilens) + if mask_w is not None: + if isinstance(enhanced, list): + # single-source WPE + for spk in range(self.num_spk): + key = "dereverb{}".format(spk + 1) + others[key] = enhanced[spk] + others["mask_" + key] = mask_w[spk] + else: + # multi-source WPE + others["dereverb1"] = enhanced + others["mask_dereverb1"] = mask_w.squeeze(-2) + + # 2. Beamformer + if self.use_beamformer: + if ( + not self.beamformer.beamformer_type.startswith("wmpdr") + or not self.beamformer.beamformer_type.startswith("wpd") + or not self.shared_power + or (self.wpe.nmask == 1 and self.num_spk > 1) + ): + powers = None + + # enhanced: (B, T, C, F) -> (B, T, F) + if isinstance(enhanced, list): + # outputs of single-source WPE + raise NotImplementedError( + "Single-source WPE is not supported with beamformer " + "in multi-speaker cases." + ) + else: + # output of multi-source WPE + enhanced, ilens, others_b = self.beamformer( + enhanced, ilens, powers=powers + ) + for spk in range(self.num_spk): + others["mask_spk{}".format(spk + 1)] = others_b[spk] + if len(others_b) > self.num_spk: + others["mask_noise1"] = others_b[self.num_spk] + + if not isinstance(enhanced, list): + enhanced = [enhanced] + + return enhanced, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/enh/separator/rnn_separator.py b/espnet2/enh/separator/rnn_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..6be889479d5a020aa8a5838aec45d12eaa3fab6e --- /dev/null +++ b/espnet2/enh/separator/rnn_separator.py @@ -0,0 +1,108 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet.nets.pytorch_backend.rnn.encoders import RNN +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class RNNSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + rnn_type: str = "blstm", + num_spk: int = 2, + nonlinear: str = "sigmoid", + layer: int = 3, + unit: int = 512, + dropout: float = 0.0, + ): + """RNN Separator + + Args: + input_dim: input feature dimension + rnn_type: string, select from 'blstm', 'lstm' etc. + bidirectional: bool, whether the inter-chunk RNN layers are bidirectional. + num_spk: number of speakers + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + layer: int, number of stacked RNN layers. Default is 3. + unit: int, dimension of the hidden state. + dropout: float, dropout ratio. Default is 0. + """ + super().__init__() + + self._num_spk = num_spk + + self.rnn = RNN( + idim=input_dim, + elayers=layer, + cdim=unit, + hdim=unit, + dropout=dropout, + typ=rnn_type, + ) + + self.linear = torch.nn.ModuleList( + [torch.nn.Linear(unit, input_dim) for _ in range(self.num_spk)] + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] 
+ ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + + x, ilens, _ = self.rnn(feature, ilens) + + masks = [] + + for linear in self.linear: + y = linear(x) + y = self.nonlinear(y) + masks.append(y) + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/enh/separator/tcn_separator.py b/espnet2/enh/separator/tcn_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..a59adb5453e1f68c70c35bcc5d11387f55802171 --- /dev/null +++ b/espnet2/enh/separator/tcn_separator.py @@ -0,0 +1,104 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + +from espnet2.enh.layers.tcn import TemporalConvNet +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class TCNSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + num_spk: int = 2, + layer: int = 8, + stack: int = 3, + bottleneck_dim: int = 128, + hidden_dim: int = 512, + kernel: int = 3, + causal: bool = False, + norm_type: str = "gLN", + nonlinear: str = "relu", + ): + """Temporal Convolution Separator + + Args: + input_dim: input feature dimension + num_spk: number of speakers + layer: int, number of layers in each stack. + stack: int, number of stacks + bottleneck_dim: bottleneck dimension + hidden_dim: number of convolution channel + kernel: int, kernel size. + causal: bool, defalut False. + norm_type: str, choose from 'BN', 'gLN', 'cLN' + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + """ + super().__init__() + + self._num_spk = num_spk + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.tcn = TemporalConvNet( + N=input_dim, + B=bottleneck_dim, + H=hidden_dim, + P=kernel, + X=layer, + R=stack, + C=num_spk, + norm_type=norm_type, + causal=causal, + mask_nonlinear=nonlinear, + ) + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... 
+ 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + # if complex spectrum + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + B, L, N = feature.shape + + feature = feature.transpose(1, 2) # B, N, L + + masks = self.tcn(feature) # B, num_spk, N, L + masks = masks.transpose(2, 3) # B, num_spk, L, N + masks = masks.unbind(dim=1) # List[B, L, N] + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/enh/separator/transformer_separator.py b/espnet2/enh/separator/transformer_separator.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca66d6b4025311e4223000eb8bc2c343d5a2500 --- /dev/null +++ b/espnet2/enh/separator/transformer_separator.py @@ -0,0 +1,149 @@ +from collections import OrderedDict +from typing import List +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor + + +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.encoder import ( + Encoder as TransformerEncoder, # noqa: H301 +) +from espnet2.enh.separator.abs_separator import AbsSeparator + + +class TransformerSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + num_spk: int = 2, + adim: int = 384, + aheads: int = 4, + layers: int = 6, + linear_units: int = 1536, + positionwise_layer_type: str = "linear", + positionwise_conv_kernel_size: int = 1, + normalize_before: bool = False, + concat_after: bool = False, + dropout_rate: float = 0.1, + positional_dropout_rate: float = 0.1, + attention_dropout_rate: float = 0.1, + use_scaled_pos_enc: bool = True, + nonlinear: str = "relu", + ): + """Transformer separator. + + Args: + input_dim: input feature dimension + num_spk: number of speakers + adim (int): Dimention of attention. + aheads (int): The number of heads of multi head attention. + linear_units (int): The number of units of position-wise feed forward. + layers (int): The number of transformer blocks. + dropout_rate (float): Dropout rate. + attention_dropout_rate (float): Dropout rate in attention. + positional_dropout_rate (float): Dropout rate after adding + positional encoding. + normalize_before (bool): Whether to use layer_norm before the first block. + concat_after (bool): Whether to concat attention layer's input and output. + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. x -> x + att(x) + positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". + positionwise_conv_kernel_size (int): Kernel size of + positionwise conv1d layer. 
+ use_scaled_pos_enc (bool) : use scaled positional encoding or not + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + """ + super().__init__() + + self._num_spk = num_spk + + pos_enc_class = ( + ScaledPositionalEncoding if use_scaled_pos_enc else PositionalEncoding + ) + self.transformer = TransformerEncoder( + idim=input_dim, + attention_dim=adim, + attention_heads=aheads, + linear_units=linear_units, + num_blocks=layers, + input_layer="linear", + dropout_rate=dropout_rate, + positional_dropout_rate=positional_dropout_rate, + attention_dropout_rate=attention_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=normalize_before, + concat_after=concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + + self.linear = torch.nn.ModuleList( + [torch.nn.Linear(adim, input_dim) for _ in range(self.num_spk)] + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if isinstance(input, ComplexTensor): + feature = abs(input) + else: + feature = input + + # prepare pad_mask for transformer + pad_mask = make_non_pad_mask(ilens).unsqueeze(1).to(feature.device) + + x, ilens = self.transformer(feature, pad_mask) + + masks = [] + for linear in self.linear: + y = linear(x) + y = self.nonlinear(y) + masks.append(y) + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk diff --git a/espnet2/fileio/__init__.py b/espnet2/fileio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/fileio/datadir_writer.py b/espnet2/fileio/datadir_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..bafdf984f19b704df458eff7181849a8463dda28 --- /dev/null +++ b/espnet2/fileio/datadir_writer.py @@ -0,0 +1,77 @@ +from pathlib import Path +from typing import Union +import warnings + +from typeguard import check_argument_types +from typeguard import check_return_type + + +class DatadirWriter: + """Writer class to create kaldi like data directory. + + Examples: + >>> with DatadirWriter("output") as writer: + ... # output/sub.txt is created here + ... subwriter = writer["sub.txt"] + ... # Write "uttidA some/where/a.wav" + ... subwriter["uttidA"] = "some/where/a.wav" + ... 
subwriter["uttidB"] = "some/where/b.wav" + + """ + + def __init__(self, p: Union[Path, str]): + assert check_argument_types() + self.path = Path(p) + self.chilidren = {} + self.fd = None + self.has_children = False + self.keys = set() + + def __enter__(self): + return self + + def __getitem__(self, key: str) -> "DatadirWriter": + assert check_argument_types() + if self.fd is not None: + raise RuntimeError("This writer points out a file") + + if key not in self.chilidren: + w = DatadirWriter((self.path / key)) + self.chilidren[key] = w + self.has_children = True + + retval = self.chilidren[key] + assert check_return_type(retval) + return retval + + def __setitem__(self, key: str, value: str): + assert check_argument_types() + if self.has_children: + raise RuntimeError("This writer points out a directory") + if key in self.keys: + warnings.warn(f"Duplicated: {key}") + + if self.fd is None: + self.path.parent.mkdir(parents=True, exist_ok=True) + self.fd = self.path.open("w", encoding="utf-8") + + self.keys.add(key) + self.fd.write(f"{key} {value}\n") + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + if self.has_children: + prev_child = None + for child in self.chilidren.values(): + child.close() + if prev_child is not None and prev_child.keys != child.keys: + warnings.warn( + f"Ids are mismatching between " + f"{prev_child.path} and {child.path}" + ) + prev_child = child + + elif self.fd is not None: + self.fd.close() diff --git a/espnet2/fileio/npy_scp.py b/espnet2/fileio/npy_scp.py new file mode 100644 index 0000000000000000000000000000000000000000..7cf61c468e6c118152c5bd0ff4bd5cbab96c861c --- /dev/null +++ b/espnet2/fileio/npy_scp.py @@ -0,0 +1,97 @@ +import collections.abc +from pathlib import Path +from typing import Union + +import numpy as np +from typeguard import check_argument_types + +from espnet2.fileio.read_text import read_2column_text + + +class NpyScpWriter: + """Writer class for a scp file of numpy file. + + Examples: + key1 /some/path/a.npy + key2 /some/path/b.npy + key3 /some/path/c.npy + key4 /some/path/d.npy + ... + + >>> writer = NpyScpWriter('./data/', './data/feat.scp') + >>> writer['aa'] = numpy_array + >>> writer['bb'] = numpy_array + + """ + + def __init__(self, outdir: Union[Path, str], scpfile: Union[Path, str]): + assert check_argument_types() + self.dir = Path(outdir) + self.dir.mkdir(parents=True, exist_ok=True) + scpfile = Path(scpfile) + scpfile.parent.mkdir(parents=True, exist_ok=True) + self.fscp = scpfile.open("w", encoding="utf-8") + + self.data = {} + + def get_path(self, key): + return self.data[key] + + def __setitem__(self, key, value): + assert isinstance(value, np.ndarray), type(value) + p = self.dir / f"{key}.npy" + p.parent.mkdir(parents=True, exist_ok=True) + np.save(str(p), value) + self.fscp.write(f"{key} {p}\n") + + # Store the file path + self.data[key] = str(p) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + self.fscp.close() + + +class NpyScpReader(collections.abc.Mapping): + """Reader class for a scp file of numpy file. + + Examples: + key1 /some/path/a.npy + key2 /some/path/b.npy + key3 /some/path/c.npy + key4 /some/path/d.npy + ... 
+ + >>> reader = NpyScpReader('npy.scp') + >>> array = reader['key1'] + + """ + + def __init__(self, fname: Union[Path, str]): + assert check_argument_types() + self.fname = Path(fname) + self.data = read_2column_text(fname) + + def get_path(self, key): + return self.data[key] + + def __getitem__(self, key) -> np.ndarray: + p = self.data[key] + return np.load(p) + + def __contains__(self, item): + return item + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() diff --git a/espnet2/fileio/rand_gen_dataset.py b/espnet2/fileio/rand_gen_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..bb92336a6feecc5733ad90dc530706c3a79dd251 --- /dev/null +++ b/espnet2/fileio/rand_gen_dataset.py @@ -0,0 +1,86 @@ +import collections +from pathlib import Path +from typing import Union + +import numpy as np +from typeguard import check_argument_types + +from espnet2.fileio.read_text import load_num_sequence_text + + +class FloatRandomGenerateDataset(collections.abc.Mapping): + """Generate float array from shape.txt. + + Examples: + shape.txt + uttA 123,83 + uttB 34,83 + >>> dataset = FloatRandomGenerateDataset("shape.txt") + >>> array = dataset["uttA"] + >>> assert array.shape == (123, 83) + >>> array = dataset["uttB"] + >>> assert array.shape == (34, 83) + + """ + + def __init__( + self, + shape_file: Union[Path, str], + dtype: Union[str, np.dtype] = "float32", + loader_type: str = "csv_int", + ): + assert check_argument_types() + shape_file = Path(shape_file) + self.utt2shape = load_num_sequence_text(shape_file, loader_type) + self.dtype = np.dtype(dtype) + + def __iter__(self): + return iter(self.utt2shape) + + def __len__(self): + return len(self.utt2shape) + + def __getitem__(self, item) -> np.ndarray: + shape = self.utt2shape[item] + return np.random.randn(*shape).astype(self.dtype) + + +class IntRandomGenerateDataset(collections.abc.Mapping): + """Generate float array from shape.txt + + Examples: + shape.txt + uttA 123,83 + uttB 34,83 + >>> dataset = IntRandomGenerateDataset("shape.txt", low=0, high=10) + >>> array = dataset["uttA"] + >>> assert array.shape == (123, 83) + >>> array = dataset["uttB"] + >>> assert array.shape == (34, 83) + + """ + + def __init__( + self, + shape_file: Union[Path, str], + low: int, + high: int = None, + dtype: Union[str, np.dtype] = "int64", + loader_type: str = "csv_int", + ): + assert check_argument_types() + shape_file = Path(shape_file) + self.utt2shape = load_num_sequence_text(shape_file, loader_type) + self.dtype = np.dtype(dtype) + self.low = low + self.high = high + + def __iter__(self): + return iter(self.utt2shape) + + def __len__(self): + return len(self.utt2shape) + + def __getitem__(self, item) -> np.ndarray: + shape = self.utt2shape[item] + return np.random.randint(self.low, self.high, size=shape, dtype=self.dtype) diff --git a/espnet2/fileio/read_text.py b/espnet2/fileio/read_text.py new file mode 100644 index 0000000000000000000000000000000000000000..e26e7a1c58295a3711d88c9a3cfa14c97a49a4fd --- /dev/null +++ b/espnet2/fileio/read_text.py @@ -0,0 +1,81 @@ +import logging +from pathlib import Path +from typing import Dict +from typing import List +from typing import Union + +from typeguard import check_argument_types + + +def read_2column_text(path: Union[Path, str]) -> Dict[str, str]: + """Read a text file having 2 column as dict object. 
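+
+    Each line is split on the first whitespace into a key and the remaining text;
+    a duplicated key raises RuntimeError.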
+ + Examples: + wav.scp: + key1 /some/path/a.wav + key2 /some/path/b.wav + + >>> read_2column_text('wav.scp') + {'key1': '/some/path/a.wav', 'key2': '/some/path/b.wav'} + + """ + assert check_argument_types() + + data = {} + with Path(path).open("r", encoding="utf-8") as f: + for linenum, line in enumerate(f, 1): + sps = line.rstrip().split(maxsplit=1) + if len(sps) == 1: + k, v = sps[0], "" + else: + k, v = sps + if k in data: + raise RuntimeError(f"{k} is duplicated ({path}:{linenum})") + data[k] = v + return data + + +def load_num_sequence_text( + path: Union[Path, str], loader_type: str = "csv_int" +) -> Dict[str, List[Union[float, int]]]: + """Read a text file indicating sequences of number + + Examples: + key1 1 2 3 + key2 34 5 6 + + >>> d = load_num_sequence_text('text') + >>> np.testing.assert_array_equal(d["key1"], np.array([1, 2, 3])) + """ + assert check_argument_types() + if loader_type == "text_int": + delimiter = " " + dtype = int + elif loader_type == "text_float": + delimiter = " " + dtype = float + elif loader_type == "csv_int": + delimiter = "," + dtype = int + elif loader_type == "csv_float": + delimiter = "," + dtype = float + else: + raise ValueError(f"Not supported loader_type={loader_type}") + + # path looks like: + # utta 1,0 + # uttb 3,4,5 + # -> return {'utta': np.ndarray([1, 0]), + # 'uttb': np.ndarray([3, 4, 5])} + d = read_2column_text(path) + + # Using for-loop instead of dict-comprehension for debuggability + retval = {} + for k, v in d.items(): + try: + retval[k] = [dtype(i) for i in v.split(delimiter)] + except TypeError: + logging.error(f'Error happened with path="{path}", id="{k}", value="{v}"') + raise + return retval diff --git a/espnet2/fileio/rttm.py b/espnet2/fileio/rttm.py new file mode 100644 index 0000000000000000000000000000000000000000..5b8a343f3dc7f15878e5debcc10c1fb432981d8c --- /dev/null +++ b/espnet2/fileio/rttm.py @@ -0,0 +1,98 @@ +import collections.abc +from pathlib import Path +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +import re +from typeguard import check_argument_types + + +def load_rttm_text(path: Union[Path, str]) -> Dict[str, List[Tuple[str, float, float]]]: + """Read a RTTM file + + Note: only support speaker information now + """ + + assert check_argument_types() + data = {} + with Path(path).open("r", encoding="utf-8") as f: + for linenum, line in enumerate(f, 1): + sps = re.split(" +", line.rstrip()) + + # RTTM format must have exactly 9 fields + assert len(sps) == 9, "{} does not have exactly 9 fields".format(path) + label_type, utt_id, channel, start, end, _, _, spk_id, _ = sps + + # Only support speaker label now + assert label_type in ["SPEAKER", "END"] + + spk_list, spk_event, max_duration = data.get(utt_id, ([], [], 0)) + if label_type == "END": + data[utt_id] = (spk_list, spk_event, int(end)) + continue + if spk_id not in spk_list: + spk_list.append(spk_id) + + data[utt_id] = ( + spk_list, + spk_event + [(spk_id, int(float(start)), int(float(end)))], + max_duration, + ) + + return data + + +class RttmReader(collections.abc.Mapping): + """Reader class for 'rttm.scp'. + + Examples: + SPEAKER file1 1 0 1023 spk1 + SPEAKER file1 2 4000 3023 spk2 + SPEAKER file1 3 500 4023 spk1 + END file1 4023 + + This is an extend version of standard RTTM format for espnet. + The difference including: + 1. Use sample number instead of absolute time + 2. has a END label to represent the duration of a recording + 3. 
replace duration (5th field) with end time + (For standard RTTM, + see https://catalog.ldc.upenn.edu/docs/LDC2004T12/RTTM-format-v13.pdf) + ... + + >>> reader = RttmReader('rttm') + >>> spk_label = reader["file1"] + + """ + + def __init__( + self, + fname: str, + ): + assert check_argument_types() + super().__init__() + + self.fname = fname + self.data = load_rttm_text(path=fname) + + def __getitem__(self, key): + spk_list, spk_event, max_duration = self.data[key] + spk_label = np.zeros((max_duration, len(spk_list))) + for spk_id, start, end in spk_event: + spk_label[start : end + 1, spk_list.index(spk_id)] = 1 + return spk_label + + def __contains__(self, item): + return item + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() diff --git a/espnet2/fileio/sound_scp.py b/espnet2/fileio/sound_scp.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd1c90e4e49e6634ad0e72c656bbfa55f0b9f5a --- /dev/null +++ b/espnet2/fileio/sound_scp.py @@ -0,0 +1,131 @@ +import collections.abc +from pathlib import Path +from typing import Union + +import numpy as np +import soundfile +from typeguard import check_argument_types + +from espnet2.fileio.read_text import read_2column_text + + +class SoundScpReader(collections.abc.Mapping): + """Reader class for 'wav.scp'. + + Examples: + key1 /some/path/a.wav + key2 /some/path/b.wav + key3 /some/path/c.wav + key4 /some/path/d.wav + ... + + >>> reader = SoundScpReader('wav.scp') + >>> rate, array = reader['key1'] + + """ + + def __init__( + self, + fname, + dtype=np.int16, + always_2d: bool = False, + normalize: bool = False, + ): + assert check_argument_types() + self.fname = fname + self.dtype = dtype + self.always_2d = always_2d + self.normalize = normalize + self.data = read_2column_text(fname) + + def __getitem__(self, key): + wav = self.data[key] + if self.normalize: + # soundfile.read normalizes data to [-1,1] if dtype is not given + array, rate = soundfile.read(wav, always_2d=self.always_2d) + else: + array, rate = soundfile.read( + wav, dtype=self.dtype, always_2d=self.always_2d + ) + + return rate, array + + def get_path(self, key): + return self.data[key] + + def __contains__(self, item): + return item + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + def keys(self): + return self.data.keys() + + +class SoundScpWriter: + """Writer class for 'wav.scp' + + Examples: + key1 /some/path/a.wav + key2 /some/path/b.wav + key3 /some/path/c.wav + key4 /some/path/d.wav + ... 
+ + >>> writer = SoundScpWriter('./data/', './data/feat.scp') + >>> writer['aa'] = 16000, numpy_array + >>> writer['bb'] = 16000, numpy_array + + """ + + def __init__( + self, + outdir: Union[Path, str], + scpfile: Union[Path, str], + format="wav", + dtype=None, + ): + assert check_argument_types() + self.dir = Path(outdir) + self.dir.mkdir(parents=True, exist_ok=True) + scpfile = Path(scpfile) + scpfile.parent.mkdir(parents=True, exist_ok=True) + self.fscp = scpfile.open("w", encoding="utf-8") + self.format = format + self.dtype = dtype + + self.data = {} + + def __setitem__(self, key: str, value): + rate, signal = value + assert isinstance(rate, int), type(rate) + assert isinstance(signal, np.ndarray), type(signal) + if signal.ndim not in (1, 2): + raise RuntimeError(f"Input signal must be 1 or 2 dimension: {signal.ndim}") + if signal.ndim == 1: + signal = signal[:, None] + + wav = self.dir / f"{key}.{self.format}" + wav.parent.mkdir(parents=True, exist_ok=True) + soundfile.write(str(wav), signal, rate) + + self.fscp.write(f"{key} {wav}\n") + + # Store the file path + self.data[key] = str(wav) + + def get_path(self, key): + return self.data[key] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + def close(self): + self.fscp.close() diff --git a/espnet2/iterators/__init__.py b/espnet2/iterators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/iterators/abs_iter_factory.py b/espnet2/iterators/abs_iter_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..36e4dd2c52133fdabf02c94b268a80f6bb7dfd9a --- /dev/null +++ b/espnet2/iterators/abs_iter_factory.py @@ -0,0 +1,9 @@ +from abc import ABC +from abc import abstractmethod +from typing import Iterator + + +class AbsIterFactory(ABC): + @abstractmethod + def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator: + raise NotImplementedError diff --git a/espnet2/iterators/chunk_iter_factory.py b/espnet2/iterators/chunk_iter_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..f43703f9a8238c737fcca79be9c9d6f8c8cd8cab --- /dev/null +++ b/espnet2/iterators/chunk_iter_factory.py @@ -0,0 +1,215 @@ +import logging +from typing import Any +from typing import Dict +from typing import Iterator +from typing import List +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet2.iterators.abs_iter_factory import AbsIterFactory +from espnet2.iterators.sequence_iter_factory import SequenceIterFactory +from espnet2.samplers.abs_sampler import AbsSampler + + +class ChunkIterFactory(AbsIterFactory): + """Creates chunks from a sequence + + Examples: + >>> batches = [["id1"], ["id2"], ...] + >>> batch_size = 128 + >>> chunk_length = 1000 + >>> iter_factory = ChunkIterFactory(dataset, batches, batch_size, chunk_length) + >>> it = iter_factory.build_iter(epoch) + >>> for ids, batch in it: + ... ... + + - The number of mini-batches are varied in each epochs and + we can't get the number in advance + because IterFactory doesn't be given to the length information. + - Since the first reason, "num_iters_per_epoch" can't be implemented + for this iterator. Instead of it, "num_samples_per_epoch" is implemented. 
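+        - "chunk_length" accepts a single int, a comma-separated list (e.g. "5,8"),
+          or a range (e.g. "3-5", where every length in the range becomes a
+          candidate); one candidate length is sampled per utterance.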
+ + """ + + def __init__( + self, + dataset, + batch_size: int, + batches: Union[AbsSampler, Sequence[Sequence[Any]]], + chunk_length: Union[int, str], + chunk_shift_ratio: float = 0.5, + num_cache_chunks: int = 1024, + num_samples_per_epoch: int = None, + seed: int = 0, + shuffle: bool = False, + num_workers: int = 0, + collate_fn=None, + pin_memory: bool = False, + ): + assert check_argument_types() + assert all(len(x) == 1 for x in batches), "batch-size must be 1" + + self.per_sample_iter_factory = SequenceIterFactory( + dataset=dataset, + batches=batches, + num_iters_per_epoch=num_samples_per_epoch, + seed=seed, + shuffle=shuffle, + num_workers=num_workers, + collate_fn=collate_fn, + pin_memory=pin_memory, + ) + + self.num_cache_chunks = max(num_cache_chunks, batch_size) + if isinstance(chunk_length, str): + if len(chunk_length) == 0: + raise ValueError("e.g. 5,8 or 3-5: but got empty string") + + self.chunk_lengths = [] + for x in chunk_length.split(","): + try: + sps = list(map(int, x.split("-"))) + except ValueError: + raise ValueError(f"e.g. 5,8 or 3-5: but got {chunk_length}") + + if len(sps) > 2: + raise ValueError(f"e.g. 5,8 or 3-5: but got {chunk_length}") + elif len(sps) == 2: + # Append all numbers between the range into the candidates + self.chunk_lengths += list(range(sps[0], sps[1] + 1)) + else: + self.chunk_lengths += [sps[0]] + else: + # Single candidates: Fixed chunk length + self.chunk_lengths = [chunk_length] + + self.chunk_shift_ratio = chunk_shift_ratio + self.batch_size = batch_size + self.seed = seed + self.shuffle = shuffle + + def build_iter( + self, + epoch: int, + shuffle: bool = None, + ) -> Iterator[Tuple[List[str], Dict[str, torch.Tensor]]]: + per_sample_loader = self.per_sample_iter_factory.build_iter(epoch, shuffle) + + if shuffle is None: + shuffle = self.shuffle + state = np.random.RandomState(epoch + self.seed) + + # NOTE(kamo): + # This iterator supports multiple chunk lengths and + # keep chunks for each lenghts here until collecting specified numbers + cache_chunks_dict = {} + cache_id_list_dict = {} + for ids, batch in per_sample_loader: + # Must be per-sample-loader + assert len(ids) == 1, f"Must be per-sample-loader: {len(ids)}" + assert all(len(x) == 1 for x in batch.values()) + + # Get keys of sequence data + sequence_keys = [] + for key in batch: + if key + "_lengths" in batch: + sequence_keys.append(key) + # Remove lengths data and get the first sample + batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} + id_ = ids[0] + + for key in sequence_keys: + if len(batch[key]) != len(batch[sequence_keys[0]]): + raise RuntimeError( + f"All sequences must has same length: " + f"{len(batch[key])} != {len(batch[sequence_keys[0]])}" + ) + + L = len(batch[sequence_keys[0]]) + # Select chunk length + chunk_lengths = [lg for lg in self.chunk_lengths if lg < L] + if len(chunk_lengths) == 0: + logging.warning( + f"The length of '{id_}' is {L}, but it is shorter than " + f"any candidates of chunk-length: {self.chunk_lengths}" + ) + continue + + W = int(state.choice(chunk_lengths, 1)) + cache_id_list = cache_id_list_dict.setdefault(W, []) + cache_chunks = cache_chunks_dict.setdefault(W, {}) + + # Shift width to the next chunk + S = int(W * self.chunk_shift_ratio) + # Number of chunks + N = (L - W) // S + 1 + if shuffle: + Z = state.randint(0, (L - W) % S + 1) + else: + Z = 0 + + # Split a sequence into chunks. 
+ # Note that the marginal frames divided by chunk length are discarded + for k, v in batch.items(): + if k not in cache_chunks: + cache_chunks[k] = [] + if k in sequence_keys: + # Shift chunks with overlapped length for data augmentation + cache_chunks[k] += [v[Z + i * S : Z + i * S + W] for i in range(N)] + else: + # If not sequence, use whole data instead of chunk + cache_chunks[k] += [v for _ in range(N)] + cache_id_list += [id_ for _ in range(N)] + + if len(cache_id_list) > self.num_cache_chunks: + cache_id_list, cache_chunks = yield from self._generate_mini_batches( + cache_id_list, + cache_chunks, + shuffle, + state, + ) + + cache_id_list_dict[W] = cache_id_list + cache_chunks_dict[W] = cache_chunks + + else: + for W in cache_id_list_dict: + cache_id_list = cache_id_list_dict.setdefault(W, []) + cache_chunks = cache_chunks_dict.setdefault(W, {}) + + yield from self._generate_mini_batches( + cache_id_list, + cache_chunks, + shuffle, + state, + ) + + def _generate_mini_batches( + self, + id_list: List[str], + batches: Dict[str, List[torch.Tensor]], + shuffle: bool, + state: np.random.RandomState, + ): + if shuffle: + indices = np.arange(0, len(id_list)) + state.shuffle(indices) + batches = {k: [v[i] for i in indices] for k, v in batches.items()} + id_list = [id_list[i] for i in indices] + + bs = self.batch_size + while len(id_list) >= bs: + # Make mini-batch and yield + yield ( + id_list[:bs], + {k: torch.stack(v[:bs], 0) for k, v in batches.items()}, + ) + id_list = id_list[bs:] + batches = {k: v[bs:] for k, v in batches.items()} + + return id_list, batches diff --git a/espnet2/iterators/multiple_iter_factory.py b/espnet2/iterators/multiple_iter_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..28e3d2dcb610b22f24bca1f7bfa73e216ac905a0 --- /dev/null +++ b/espnet2/iterators/multiple_iter_factory.py @@ -0,0 +1,37 @@ +import logging +from typing import Callable +from typing import Collection +from typing import Iterator + +import numpy as np +from typeguard import check_argument_types + +from espnet2.iterators.abs_iter_factory import AbsIterFactory + + +class MultipleIterFactory(AbsIterFactory): + def __init__( + self, + build_funcs: Collection[Callable[[], AbsIterFactory]], + seed: int = 0, + shuffle: bool = False, + ): + assert check_argument_types() + self.build_funcs = list(build_funcs) + self.seed = seed + self.shuffle = shuffle + + def build_iter(self, epoch: int, shuffle: bool = None) -> Iterator: + if shuffle is None: + shuffle = self.shuffle + + build_funcs = list(self.build_funcs) + + if shuffle: + np.random.RandomState(epoch + self.seed).shuffle(build_funcs) + + for i, build_func in enumerate(build_funcs): + logging.info(f"Building {i}th iter-factory...") + iter_factory = build_func() + assert isinstance(iter_factory, AbsIterFactory), type(iter_factory) + yield from iter_factory.build_iter(epoch, shuffle) diff --git a/espnet2/iterators/sequence_iter_factory.py b/espnet2/iterators/sequence_iter_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..48f61f8c7dfa57530c889bd718a5f44c8c6e1060 --- /dev/null +++ b/espnet2/iterators/sequence_iter_factory.py @@ -0,0 +1,143 @@ +from typing import Any +from typing import Sequence +from typing import Union + +import numpy as np +from torch.utils.data import DataLoader +from typeguard import check_argument_types + +from espnet2.iterators.abs_iter_factory import AbsIterFactory +from espnet2.samplers.abs_sampler import AbsSampler + + +class RawSampler(AbsSampler): + def __init__(self, 
batches): + self.batches = batches + + def __len__(self): + return len(self.batches) + + def __iter__(self): + return iter(self.batches) + + def generate(self, seed): + return list(self.batches) + + +class SequenceIterFactory(AbsIterFactory): + """Build iterator for each epoch. + + This class simply creates pytorch DataLoader except for the following points: + - The random seed is decided according to the number of epochs. This feature + guarantees reproducibility when resuming from middle of training process. + - Enable to restrict the number of samples for one epoch. This features + controls the interval number between training and evaluation. + + """ + + def __init__( + self, + dataset, + batches: Union[AbsSampler, Sequence[Sequence[Any]]], + num_iters_per_epoch: int = None, + seed: int = 0, + shuffle: bool = False, + num_workers: int = 0, + collate_fn=None, + pin_memory: bool = False, + ): + assert check_argument_types() + + if not isinstance(batches, AbsSampler): + self.sampler = RawSampler(batches) + else: + self.sampler = batches + + self.dataset = dataset + self.num_iters_per_epoch = num_iters_per_epoch + self.shuffle = shuffle + self.seed = seed + self.num_workers = num_workers + self.collate_fn = collate_fn + # https://discuss.pytorch.org/t/what-is-the-disadvantage-of-using-pin-memory/1702 + self.pin_memory = pin_memory + + def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader: + if shuffle is None: + shuffle = self.shuffle + + if self.num_iters_per_epoch is not None: + N = len(self.sampler) + # If corpus size is larger than the num_per_epoch + if self.num_iters_per_epoch < N: + N = len(self.sampler) + real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N) + + if offset >= self.num_iters_per_epoch: + current_batches = self.sampler.generate(real_epoch + self.seed) + if shuffle: + np.random.RandomState(real_epoch + self.seed).shuffle( + current_batches + ) + batches = current_batches[ + offset - self.num_iters_per_epoch : offset + ] + else: + prev_batches = self.sampler.generate(real_epoch - 1 + self.seed) + current_batches = self.sampler.generate(real_epoch + self.seed) + if shuffle: + np.random.RandomState(real_epoch - 1 + self.seed).shuffle( + prev_batches + ) + np.random.RandomState(real_epoch + self.seed).shuffle( + current_batches + ) + batches = ( + prev_batches[offset - self.num_iters_per_epoch :] + + current_batches[:offset] + ) + + # If corpus size is less than the num_per_epoch + else: + _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N) + _remain = self.num_iters_per_epoch + batches = [] + current_batches = self.sampler.generate(_epoch + self.seed) + if shuffle: + np.random.RandomState(_epoch + self.seed).shuffle(current_batches) + while _remain > 0: + + _batches = current_batches[_cursor : _cursor + _remain] + batches += _batches + if _cursor + _remain >= N: + _epoch += 1 + _cursor = 0 + current_batches = self.sampler.generate(_epoch + self.seed) + if shuffle: + np.random.RandomState(_epoch + self.seed).shuffle( + current_batches + ) + else: + _cursor = _cursor + _remain + _remain -= len(_batches) + + assert len(batches) == self.num_iters_per_epoch + + else: + batches = self.sampler.generate(epoch + self.seed) + if shuffle: + np.random.RandomState(epoch + self.seed).shuffle(batches) + + # For backward compatibility for pytorch DataLoader + if self.collate_fn is not None: + kwargs = dict(collate_fn=self.collate_fn) + else: + kwargs = {} + + return DataLoader( + dataset=self.dataset, + batch_sampler=batches, + 
num_workers=self.num_workers, + pin_memory=self.pin_memory, + **kwargs, + ) diff --git a/espnet2/layers/__init__.py b/espnet2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/layers/abs_normalize.py b/espnet2/layers/abs_normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..f2be748dd7c7e30e90cd15e05fc55f62f5728df0 --- /dev/null +++ b/espnet2/layers/abs_normalize.py @@ -0,0 +1,14 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class AbsNormalize(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # return output, output_lengths + raise NotImplementedError diff --git a/espnet2/layers/global_mvn.py b/espnet2/layers/global_mvn.py new file mode 100644 index 0000000000000000000000000000000000000000..31635cb4febac33a81b904f19f2f597d7c4de54e --- /dev/null +++ b/espnet2/layers/global_mvn.py @@ -0,0 +1,121 @@ +from pathlib import Path +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.inversible_interface import InversibleInterface + + +class GlobalMVN(AbsNormalize, InversibleInterface): + """Apply global mean and variance normalization + + TODO(kamo): Make this class portable somehow + + Args: + stats_file: npy file + norm_means: Apply mean normalization + norm_vars: Apply var normalization + eps: + """ + + def __init__( + self, + stats_file: Union[Path, str], + norm_means: bool = True, + norm_vars: bool = True, + eps: float = 1.0e-20, + ): + assert check_argument_types() + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + self.eps = eps + stats_file = Path(stats_file) + + self.stats_file = stats_file + stats = np.load(stats_file) + if isinstance(stats, np.ndarray): + # Kaldi like stats + count = stats[0].flatten()[-1] + mean = stats[0, :-1] / count + var = stats[1, :-1] / count - mean * mean + else: + # New style: Npz file + count = stats["count"] + sum_v = stats["sum"] + sum_square_v = stats["sum_square"] + mean = sum_v / count + var = sum_square_v / count - mean * mean + std = np.sqrt(np.maximum(var, eps)) + + self.register_buffer("mean", torch.from_numpy(mean)) + self.register_buffer("std", torch.from_numpy(std)) + + def extra_repr(self): + return ( + f"stats_file={self.stats_file}, " + f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + ) + + def forward( + self, x: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function + + Args: + x: (B, L, ...) 
+ ilens: (B,) + """ + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + norm_means = self.norm_means + norm_vars = self.norm_vars + self.mean = self.mean.to(x.device, x.dtype) + self.std = self.std.to(x.device, x.dtype) + mask = make_pad_mask(ilens, x, 1) + + # feat: (B, T, D) + if norm_means: + if x.requires_grad: + x = x - self.mean + else: + x -= self.mean + if x.requires_grad: + x = x.masked_fill(mask, 0.0) + else: + x.masked_fill_(mask, 0.0) + + if norm_vars: + x /= self.std + + return x, ilens + + def inverse( + self, x: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + norm_means = self.norm_means + norm_vars = self.norm_vars + self.mean = self.mean.to(x.device, x.dtype) + self.std = self.std.to(x.device, x.dtype) + mask = make_pad_mask(ilens, x, 1) + + if x.requires_grad: + x = x.masked_fill(mask, 0.0) + else: + x.masked_fill_(mask, 0.0) + + if norm_vars: + x *= self.std + + # feat: (B, T, D) + if norm_means: + x += self.mean + x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) + return x, ilens diff --git a/espnet2/layers/inversible_interface.py b/espnet2/layers/inversible_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a59399aaea3f2d9cba1cf13a63c99ee90df75f --- /dev/null +++ b/espnet2/layers/inversible_interface.py @@ -0,0 +1,14 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + + +class InversibleInterface(ABC): + @abstractmethod + def inverse( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # return output, output_lengths + raise NotImplementedError diff --git a/espnet2/layers/label_aggregation.py b/espnet2/layers/label_aggregation.py new file mode 100644 index 0000000000000000000000000000000000000000..2070a888a84849ce28a877a122fb415c245c42b6 --- /dev/null +++ b/espnet2/layers/label_aggregation.py @@ -0,0 +1,78 @@ +import torch +from typeguard import check_argument_types +from typing import Optional +from typing import Tuple + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask + + +class LabelAggregate(torch.nn.Module): + def __init__( + self, + win_length: int = 512, + hop_length: int = 128, + center: bool = True, + ): + assert check_argument_types() + super().__init__() + + self.win_length = win_length + self.hop_length = hop_length + self.center = center + + def extra_repr(self): + return ( + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """LabelAggregate forward function. + + Args: + input: (Batch, Nsamples, Label_dim) + ilens: (Batch) + Returns: + output: (Batch, Frames, Label_dim) + + """ + bs = input.size(0) + max_length = input.size(1) + label_dim = input.size(2) + + # NOTE(jiatong): + # The default behaviour of label aggregation is compatible with + # torch.stft about framing and padding. 
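+ # Illustrative numbers (hypothetical input): with the defaults
+ # win_length=512, hop_length=128 and center=True, 16000 label samples are
+ # padded to 16000 + 512 = 16512, giving
+ # nframe = (16512 - 512) // 128 + 1 = 126 frames, the same frame count
+ # torch.stft would produce for a 16000-sample signal.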
+ + # Step1: center padding + if self.center: + pad = self.win_length // 2 + max_length = max_length + 2 * pad + input = torch.nn.functional.pad(input, (0, 0, pad, pad), "constant", 0) + nframe = (max_length - self.win_length) // self.hop_length + 1 + + # Step2: framing + output = input.as_strided( + (bs, nframe, self.win_length, label_dim), + (max_length * label_dim, self.hop_length * label_dim, label_dim, 1), + ) + + # Step3: aggregate label + output = torch.gt(output.sum(dim=2, keepdim=False), self.win_length // 2) + output = output.float() + + # Step4: process lengths + if ilens is not None: + if self.center: + pad = self.win_length // 2 + ilens = ilens + 2 * pad + + olens = (ilens - self.win_length) // self.hop_length + 1 + output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) + else: + olens = None + + return output, olens diff --git a/espnet2/layers/log_mel.py b/espnet2/layers/log_mel.py new file mode 100644 index 0000000000000000000000000000000000000000..5caeadbe31e6acf04b7d2f55818e3ad4c663acf2 --- /dev/null +++ b/espnet2/layers/log_mel.py @@ -0,0 +1,83 @@ +import librosa +import torch +from typing import Tuple + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask + + +class LogMel(torch.nn.Module): + """Convert STFT to fbank feats + + The arguments is same as librosa.filters.mel + + Args: + fs: number > 0 [scalar] sampling rate of the incoming signal + n_fft: int > 0 [scalar] number of FFT components + n_mels: int > 0 [scalar] number of Mel bands to generate + fmin: float >= 0 [scalar] lowest frequency (in Hz) + fmax: float >= 0 [scalar] highest frequency (in Hz). + If `None`, use `fmax = fs / 2.0` + htk: use HTK formula instead of Slaney + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = None, + fmax: float = None, + htk: bool = False, + log_base: float = None, + ): + super().__init__() + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + _mel_options = dict( + sr=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + ) + self.mel_options = _mel_options + self.log_base = log_base + + # Note(kamo): The mel matrix of librosa is different from kaldi. 
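+ # librosa.filters.mel returns a matrix of shape (n_mels, 1 + n_fft // 2),
+ # e.g. (80, 257) for the defaults here; it is transposed below so that
+ # forward can compute (B, T, 257) x (257, 80) -> (B, T, 80).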
+ melmat = librosa.filters.mel(**_mel_options) + # melmat: (D2, D1) -> (D1, D2) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + + def extra_repr(self): + return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) + + def forward( + self, + feat: torch.Tensor, + ilens: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) + mel_feat = torch.matmul(feat, self.melmat) + mel_feat = torch.clamp(mel_feat, min=1e-10) + + if self.log_base is None: + logmel_feat = mel_feat.log() + elif self.log_base == 2.0: + logmel_feat = mel_feat.log2() + elif self.log_base == 10.0: + logmel_feat = mel_feat.log10() + else: + logmel_feat = mel_feat.log() / torch.log(self.log_base) + + # Zero padding + if ilens is not None: + logmel_feat = logmel_feat.masked_fill( + make_pad_mask(ilens, logmel_feat, 1), 0.0 + ) + else: + ilens = feat.new_full( + [feat.size(0)], fill_value=feat.size(1), dtype=torch.long + ) + return logmel_feat, ilens diff --git a/espnet2/layers/mask_along_axis.py b/espnet2/layers/mask_along_axis.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6f03aa4c09f2e1577ad6ba75b870fa12edaa1b --- /dev/null +++ b/espnet2/layers/mask_along_axis.py @@ -0,0 +1,128 @@ +import torch +from typeguard import check_argument_types +from typing import Sequence +from typing import Union + + +def mask_along_axis( + spec: torch.Tensor, + spec_lengths: torch.Tensor, + mask_width_range: Sequence[int] = (0, 30), + dim: int = 1, + num_mask: int = 2, + replace_with_zero: bool = True, +): + """Apply mask along the specified direction. + + Args: + spec: (Batch, Length, Freq) + spec_lengths: (Length): Not using lenghts in this implementation + mask_width_range: Select the width randomly between this range + """ + + org_size = spec.size() + if spec.dim() == 4: + # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq) + spec = spec.view(-1, spec.size(2), spec.size(3)) + + B = spec.shape[0] + # D = Length or Freq + D = spec.shape[dim] + # mask_length: (B, num_mask, 1) + mask_length = torch.randint( + mask_width_range[0], + mask_width_range[1], + (B, num_mask), + device=spec.device, + ).unsqueeze(2) + + # mask_pos: (B, num_mask, 1) + mask_pos = torch.randint( + 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device + ).unsqueeze(2) + + # aran: (1, 1, D) + aran = torch.arange(D, device=spec.device)[None, None, :] + # mask: (Batch, num_mask, D) + mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length)) + # Multiply masks: (Batch, num_mask, D) -> (Batch, D) + mask = mask.any(dim=1) + if dim == 1: + # mask: (Batch, Length, 1) + mask = mask.unsqueeze(2) + elif dim == 2: + # mask: (Batch, 1, Freq) + mask = mask.unsqueeze(1) + + if replace_with_zero: + value = 0.0 + else: + value = spec.mean() + + if spec.requires_grad: + spec = spec.masked_fill(mask, value) + else: + spec = spec.masked_fill_(mask, value) + spec = spec.view(*org_size) + return spec, spec_lengths + + +class MaskAlongAxis(torch.nn.Module): + def __init__( + self, + mask_width_range: Union[int, Sequence[int]] = (0, 30), + num_mask: int = 2, + dim: Union[int, str] = "time", + replace_with_zero: bool = True, + ): + assert check_argument_types() + if isinstance(mask_width_range, int): + mask_width_range = (0, mask_width_range) + if len(mask_width_range) != 2: + raise TypeError( + f"mask_width_range must be a tuple of int and int values: " + f"{mask_width_range}", + ) + + assert mask_width_range[1] > mask_width_range[0] + 
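+ # "time" masks along dim=1 (whole frames) and "freq" along dim=2
+ # (frequency bins) of a (Batch, Length, Freq) input, as in
+ # SpecAugment-style masking.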
if isinstance(dim, str): + if dim == "time": + dim = 1 + elif dim == "freq": + dim = 2 + else: + raise ValueError("dim must be int, 'time' or 'freq'") + if dim == 1: + self.mask_axis = "time" + elif dim == 2: + self.mask_axis = "freq" + else: + self.mask_axis = "unknown" + + super().__init__() + self.mask_width_range = mask_width_range + self.num_mask = num_mask + self.dim = dim + self.replace_with_zero = replace_with_zero + + def extra_repr(self): + return ( + f"mask_width_range={self.mask_width_range}, " + f"num_mask={self.num_mask}, axis={self.mask_axis}" + ) + + def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None): + """Forward function. + + Args: + spec: (Batch, Length, Freq) + """ + + return mask_along_axis( + spec, + spec_lengths, + mask_width_range=self.mask_width_range, + dim=self.dim, + num_mask=self.num_mask, + replace_with_zero=self.replace_with_zero, + ) diff --git a/espnet2/layers/sinc_conv.py b/espnet2/layers/sinc_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..33df97fbcdf856b26c6f649fc01e52df488522b6 --- /dev/null +++ b/espnet2/layers/sinc_conv.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# 2020, Technische Universität München; Ludwig Kürzinger +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Sinc convolutions.""" +import math +import torch +from typeguard import check_argument_types +from typing import Union + + +class LogCompression(torch.nn.Module): + """Log Compression Activation. + + Activation function `log(abs(x) + 1)`. + """ + + def __init__(self): + """Initialize.""" + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward. + + Applies the Log Compression function elementwise on tensor x. + """ + return torch.log(torch.abs(x) + 1) + + +class SincConv(torch.nn.Module): + """Sinc Convolution. + + This module performs a convolution using Sinc filters in time domain as kernel. + Sinc filters function as band passes in spectral domain. + The filtering is done as a convolution in time domain, and no transformation + to spectral domain is necessary. + + This implementation of the Sinc convolution is heavily inspired + by Ravanelli et al. https://github.com/mravanelli/SincNet, + and adapted for the ESpnet toolkit. + Combine Sinc convolutions with a log compression activation function, as in: + https://arxiv.org/abs/2010.07597 + + Notes: + Currently, the same filters are applied to all input channels. + The windowing function is applied on the kernel to obtained a smoother filter, + and not on the input values, which is different to traditional ASR. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + dilation: int = 1, + window_func: str = "hamming", + scale_type: str = "mel", + fs: Union[int, float] = 16000, + ): + """Initialize Sinc convolutions. + + Args: + in_channels: Number of input channels. + out_channels: Number of output channels. + kernel_size: Sinc filter kernel size (needs to be an odd number). + stride: See torch.nn.functional.conv1d. + padding: See torch.nn.functional.conv1d. + dilation: See torch.nn.functional.conv1d. + window_func: Window function on the filter, one of ["hamming", "none"]. 
+ fs (str, int, float): Sample rate of the input data + """ + assert check_argument_types() + super().__init__() + window_funcs = { + "none": self.none_window, + "hamming": self.hamming_window, + } + if window_func not in window_funcs: + raise NotImplementedError( + f"Window function has to be one of {list(window_funcs.keys())}", + ) + self.window_func = window_funcs[window_func] + scale_choices = { + "mel": MelScale, + "bark": BarkScale, + } + if scale_type not in scale_choices: + raise NotImplementedError( + f"Scale has to be one of {list(scale_choices.keys())}", + ) + self.scale = scale_choices[scale_type] + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + self.dilation = dilation + self.stride = stride + self.fs = float(fs) + if self.kernel_size % 2 == 0: + raise ValueError("SincConv: Kernel size must be odd.") + self.f = None + N = self.kernel_size // 2 + self._x = 2 * math.pi * torch.linspace(1, N, N) + self._window = self.window_func(torch.linspace(1, N, N)) + # init may get overwritten by E2E network, + # but is still required to calculate output dim + self.init_filters() + + @staticmethod + def sinc(x: torch.Tensor) -> torch.Tensor: + """Sinc function.""" + x2 = x + 1e-6 + return torch.sin(x2) / x2 + + @staticmethod + def none_window(x: torch.Tensor) -> torch.Tensor: + """Identity-like windowing function.""" + return torch.ones_like(x) + + @staticmethod + def hamming_window(x: torch.Tensor) -> torch.Tensor: + """Hamming Windowing function.""" + L = 2 * x.size(0) + 1 + x = x.flip(0) + return 0.54 - 0.46 * torch.cos(2.0 * math.pi * x / L) + + def init_filters(self): + """Initialize filters with filterbank values.""" + f = self.scale.bank(self.out_channels, self.fs) + f = torch.div(f, self.fs) + self.f = torch.nn.Parameter(f, requires_grad=True) + + def _create_filters(self, device: str): + """Calculate coefficients. + + This function (re-)calculates the filter convolutions coefficients. + """ + f_mins = torch.abs(self.f[:, 0]) + f_maxs = torch.abs(self.f[:, 0]) + torch.abs(self.f[:, 1] - self.f[:, 0]) + + self._x = self._x.to(device) + self._window = self._window.to(device) + + f_mins_x = torch.matmul(f_mins.view(-1, 1), self._x.view(1, -1)) + f_maxs_x = torch.matmul(f_maxs.view(-1, 1), self._x.view(1, -1)) + + kernel = (torch.sin(f_maxs_x) - torch.sin(f_mins_x)) / (0.5 * self._x) + kernel = kernel * self._window + + kernel_left = kernel.flip(1) + kernel_center = (2 * f_maxs - 2 * f_mins).unsqueeze(1) + filters = torch.cat([kernel_left, kernel_center, kernel], dim=1) + + filters = filters.view(filters.size(0), 1, filters.size(1)) + self.sinc_filters = filters + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + """Sinc convolution forward function. + + Args: + xs: Batch in form of torch.Tensor (B, C_in, D_in). + + Returns: + xs: Batch in form of torch.Tensor (B, C_out, D_out). 
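+ Note:
+ The sinc filter kernels are rebuilt from the trainable frequency
+ parameters (self.f) on every call, so gradients reach self.f
+ through the generated filters.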
+ """ + self._create_filters(xs.device) + xs = torch.nn.functional.conv1d( + xs, + self.sinc_filters, + padding=self.padding, + stride=self.stride, + dilation=self.dilation, + groups=self.in_channels, + ) + return xs + + def get_odim(self, idim: int) -> int: + """Obtain the output dimension of the filter.""" + D_out = idim + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1 + D_out = (D_out // self.stride) + 1 + return D_out + + +class MelScale: + """Mel frequency scale.""" + + @staticmethod + def convert(f): + """Convert Hz to mel.""" + return 1125.0 * torch.log(torch.div(f, 700.0) + 1.0) + + @staticmethod + def invert(x): + """Convert mel to Hz.""" + return 700.0 * (torch.exp(torch.div(x, 1125.0)) - 1.0) + + @classmethod + def bank(cls, channels: int, fs: float) -> torch.Tensor: + """Obtain initialization values for the mel scale. + + Args: + channels: Number of channels. + fs: Sample rate. + + Returns: + torch.Tensor: Filter start frequencíes. + torch.Tensor: Filter stop frequencies. + """ + assert check_argument_types() + # min and max bandpass edge frequencies + min_frequency = torch.tensor(30.0) + max_frequency = torch.tensor(fs * 0.5) + frequencies = torch.linspace( + cls.convert(min_frequency), cls.convert(max_frequency), channels + 2 + ) + frequencies = cls.invert(frequencies) + f1, f2 = frequencies[:-2], frequencies[2:] + return torch.stack([f1, f2], dim=1) + + +class BarkScale: + """Bark frequency scale. + + Has wider bandwidths at lower frequencies, see: + Critical bandwidth: BARK + Zwicker and Terhardt, 1980 + """ + + @staticmethod + def convert(f): + """Convert Hz to Bark.""" + b = torch.div(f, 1000.0) + b = torch.pow(b, 2.0) * 1.4 + b = torch.pow(b + 1.0, 0.69) + return b * 75.0 + 25.0 + + @staticmethod + def invert(x): + """Convert Bark to Hz.""" + f = torch.div(x - 25.0, 75.0) + f = torch.pow(f, (1.0 / 0.69)) + f = torch.div(f - 1.0, 1.4) + f = torch.pow(f, 0.5) + return f * 1000.0 + + @classmethod + def bank(cls, channels: int, fs: float) -> torch.Tensor: + """Obtain initialization values for the Bark scale. + + Args: + channels: Number of channels. + fs: Sample rate. + + Returns: + torch.Tensor: Filter start frequencíes. + torch.Tensor: Filter stop frequencíes. 
+ """ + assert check_argument_types() + # min and max BARK center frequencies by approximation + min_center_frequency = torch.tensor(70.0) + max_center_frequency = torch.tensor(fs * 0.45) + center_frequencies = torch.linspace( + cls.convert(min_center_frequency), + cls.convert(max_center_frequency), + channels, + ) + center_frequencies = cls.invert(center_frequencies) + + f1 = center_frequencies - torch.div(cls.convert(center_frequencies), 2) + f2 = center_frequencies + torch.div(cls.convert(center_frequencies), 2) + return torch.stack([f1, f2], dim=1) diff --git a/espnet2/layers/stft.py b/espnet2/layers/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c96aa25b5a9167f4fac75f028d4619cdef22c0 --- /dev/null +++ b/espnet2/layers/stft.py @@ -0,0 +1,169 @@ +from distutils.version import LooseVersion +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from torch_complex.tensor import ComplexTensor +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet2.layers.inversible_interface import InversibleInterface + + +class Stft(torch.nn.Module, InversibleInterface): + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + assert check_argument_types() + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.normalized = normalized + self.onesided = onesided + if window is not None and not hasattr(torch, f"{window}_window"): + raise ValueError(f"{window} window is not implemented") + self.window = window + + def extra_repr(self): + return ( + f"n_fft={self.n_fft}, " + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + f"normalized={self.normalized}, " + f"onesided={self.onesided}" + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT forward function. + + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + if input.dim() == 3: + multi_channel = True + # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) + input = input.transpose(1, 2).reshape(-1, input.size(1)) + else: + multi_channel = False + + # NOTE(kamo): + # The default behaviour of torch.stft is compatible with librosa.stft + # about padding and scaling. 
+ # Note that it's different from scipy.signal.stft + + # output: (Batch, Freq, Frames, 2=real_imag) + # or (Batch, Channel, Freq, Frames, 2=real_imag) + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func( + self.win_length, dtype=input.dtype, device=input.device + ) + else: + window = None + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + window=window, + normalized=self.normalized, + onesided=self.onesided, + ) + # output: (Batch, Freq, Frames, 2=real_imag) + # -> (Batch, Frames, Freq, 2=real_imag) + output = output.transpose(1, 2) + if multi_channel: + # output: (Batch * Channel, Frames, Freq, 2=real_imag) + # -> (Batch, Frame, Channel, Freq, 2=real_imag) + output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( + 1, 2 + ) + + if ilens is not None: + if self.center: + pad = self.win_length // 2 + ilens = ilens + 2 * pad + + olens = (ilens - self.win_length) // self.hop_length + 1 + output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) + else: + olens = None + + return output, olens + + def inverse( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Inverse STFT. + + Args: + input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F) + ilens: (batch,) + Returns: + wavs: (batch, samples) + ilens: (batch,) + """ + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + istft = torch.functional.istft + else: + try: + import torchaudio + except ImportError: + raise ImportError( + "Please install torchaudio>=0.3.0 or use torch>=1.6.0" + ) + + if not hasattr(torchaudio.functional, "istft"): + raise ImportError( + "Please install torchaudio>=0.3.0 or use torch>=1.6.0" + ) + istft = torchaudio.functional.istft + + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func( + self.win_length, dtype=input.dtype, device=input.device + ) + else: + window = None + + if isinstance(input, ComplexTensor): + input = torch.stack([input.real, input.imag], dim=-1) + assert input.shape[-1] == 2 + input = input.transpose(1, 2) + + wavs = istft( + input, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=window, + center=self.center, + normalized=self.normalized, + onesided=self.onesided, + length=ilens.max() if ilens is not None else ilens, + ) + + return wavs, ilens diff --git a/espnet2/layers/time_warp.py b/espnet2/layers/time_warp.py new file mode 100644 index 0000000000000000000000000000000000000000..52574aadbf98a7cbc9f2585a3ef7fe541b2af252 --- /dev/null +++ b/espnet2/layers/time_warp.py @@ -0,0 +1,94 @@ +from distutils.version import LooseVersion + +import torch + +from espnet.nets.pytorch_backend.nets_utils import pad_list + + +if LooseVersion(torch.__version__) >= LooseVersion("1.1"): + DEFAULT_TIME_WARP_MODE = "bicubic" +else: + # pytorch1.0 doesn't implement bicubic + DEFAULT_TIME_WARP_MODE = "bilinear" + + +def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE): + """Time warping using torch.interpolate. 
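+ A frame index "center" is picked at random in [window, t - window) and
+ moved to a nearby random position "warped", stretching one side of the
+ spectrogram and compressing the other via interpolation (SpecAugment
+ time warping).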
+ + Args: + x: (Batch, Time, Freq) + window: time warp parameter + mode: Interpolate mode + """ + + # bicubic supports 4D or more dimension tensor + org_size = x.size() + if x.dim() == 3: + # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq) + x = x[:, None] + + t = x.shape[2] + if t - window <= window: + return x.view(*org_size) + + center = torch.randint(window, t - window, (1,))[0] + warped = torch.randint(center - window, center + window, (1,))[0] + 1 + + # left: (Batch, Channel, warped, Freq) + # right: (Batch, Channel, time - warped, Freq) + left = torch.nn.functional.interpolate( + x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False + ) + right = torch.nn.functional.interpolate( + x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False + ) + + if x.requires_grad: + x = torch.cat([left, right], dim=-2) + else: + x[:, :, :warped] = left + x[:, :, warped:] = right + + return x.view(*org_size) + + +class TimeWarp(torch.nn.Module): + """Time warping using torch.interpolate. + + Args: + window: time warp parameter + mode: Interpolate mode + """ + + def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE): + super().__init__() + self.window = window + self.mode = mode + + def extra_repr(self): + return f"window={self.window}, mode={self.mode}" + + def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None): + """Forward function. + + Args: + x: (Batch, Time, Freq) + x_lengths: (Batch,) + """ + + if x_lengths is None or all(le == x_lengths[0] for le in x_lengths): + # Note that applying same warping for each sample + y = time_warp(x, window=self.window, mode=self.mode) + else: + # FIXME(kamo): I have no idea to batchify Timewarp + ys = [] + for i in range(x.size(0)): + _y = time_warp( + x[i][None, : x_lengths[i]], + window=self.window, + mode=self.mode, + )[0] + ys.append(_y) + y = pad_list(ys, 0.0) + + return y, x_lengths diff --git a/espnet2/layers/utterance_mvn.py b/espnet2/layers/utterance_mvn.py new file mode 100644 index 0000000000000000000000000000000000000000..a41f869f322af8db54619731f268da5e730957c0 --- /dev/null +++ b/espnet2/layers/utterance_mvn.py @@ -0,0 +1,88 @@ +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet2.layers.abs_normalize import AbsNormalize + + +class UtteranceMVN(AbsNormalize): + def __init__( + self, + norm_means: bool = True, + norm_vars: bool = False, + eps: float = 1.0e-20, + ): + assert check_argument_types() + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + self.eps = eps + + def extra_repr(self): + return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + + def forward( + self, x: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function + + Args: + x: (B, L, ...) 
+ ilens: (B,) + + """ + return utterance_mvn( + x, + ilens, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + eps=self.eps, + ) + + +def utterance_mvn( + x: torch.Tensor, + ilens: torch.Tensor = None, + norm_means: bool = True, + norm_vars: bool = False, + eps: float = 1.0e-20, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply utterance mean and variance normalization + + Args: + x: (B, T, D), assumed zero padded + ilens: (B,) + norm_means: + norm_vars: + eps: + + """ + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)]) + # Zero padding + if x.requires_grad: + x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) + else: + x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) + # mean: (B, 1, D) + mean = x.sum(dim=1, keepdim=True) / ilens_ + + if norm_means: + x -= mean + + if norm_vars: + var = x.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x = x / std.sqrt() + return x, ilens + else: + if norm_vars: + y = x - mean + y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0) + var = y.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x /= std + return x, ilens diff --git a/espnet2/lm/__init__.py b/espnet2/lm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/lm/abs_model.py b/espnet2/lm/abs_model.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5773d0126175a148be6abaf294f69008b2686c --- /dev/null +++ b/espnet2/lm/abs_model.py @@ -0,0 +1,29 @@ +from abc import ABC +from abc import abstractmethod +from typing import Tuple + +import torch + +from espnet.nets.scorer_interface import BatchScorerInterface + + +class AbsLM(torch.nn.Module, BatchScorerInterface, ABC): + """The abstract LM class + + To share the loss calculation way among different models, + We uses delegate pattern here: + The instance of this class should be passed to "LanguageModel" + + >>> from espnet2.lm.abs_model import AbsLM + >>> lm = AbsLM() + >>> model = LanguageESPnetModel(lm=lm) + + This "model" is one of mediator objects for "Task" class. + + """ + + @abstractmethod + def forward( + self, input: torch.Tensor, hidden: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/lm/espnet_model.py b/espnet2/lm/espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..db6b0f7d62dff606c9095a2379dc36793f48133e --- /dev/null +++ b/espnet2/lm/espnet_model.py @@ -0,0 +1,68 @@ +from typing import Dict +from typing import Tuple + +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet2.lm.abs_model import AbsLM +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.train.abs_espnet_model import AbsESPnetModel + + +class ESPnetLanguageModel(AbsESPnetModel): + def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0): + assert check_argument_types() + super().__init__() + self.lm = lm + self.sos = vocab_size - 1 + self.eos = vocab_size - 1 + + # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR. + self.ignore_id = ignore_id + + def nll( + self, text: torch.Tensor, text_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = text.size(0) + # For data parallel + text = text[:, : text_lengths.max()] + + # 1. 
Create a sentence pair like ' w1 w2 w3' and 'w1 w2 w3 ' + # text: (Batch, Length) -> x, y: (Batch, Length + 1) + x = F.pad(text, [1, 0], "constant", self.eos) + t = F.pad(text, [0, 1], "constant", self.ignore_id) + for i, l in enumerate(text_lengths): + t[i, l] = self.sos + x_lengths = text_lengths + 1 + + # 2. Forward Language model + # x: (Batch, Length) -> y: (Batch, Length, NVocab) + y, _ = self.lm(x, None) + + # 3. Calc negative log likelihood + # nll: (BxL,) + nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none") + # nll: (BxL,) -> (BxL,) + nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0) + # nll: (BxL,) -> (B, L) + nll = nll.view(batch_size, -1) + return nll, x_lengths + + def forward( + self, text: torch.Tensor, text_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + nll, y_lengths = self.nll(text, text_lengths) + ntokens = y_lengths.sum() + loss = nll.sum() / ntokens + stats = dict(loss=loss.detach()) + + # force_gatherable: to-device and to-tensor if scalar for DataParallel + loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device) + return loss, stats, weight + + def collect_feats( + self, text: torch.Tensor, text_lengths: torch.Tensor + ) -> Dict[str, torch.Tensor]: + return {} diff --git a/espnet2/lm/seq_rnn_lm.py b/espnet2/lm/seq_rnn_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..10b18f639f6c2e69265945eac7a93c01605dff06 --- /dev/null +++ b/espnet2/lm/seq_rnn_lm.py @@ -0,0 +1,163 @@ +"""Sequential implementation of Recurrent Neural Network Language Model.""" +from typing import Tuple +from typing import Union + +import torch +import torch.nn as nn +from typeguard import check_argument_types + +from espnet2.lm.abs_model import AbsLM + + +class SequentialRNNLM(AbsLM): + """Sequential RNNLM. + + See also: + https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py + + """ + + def __init__( + self, + vocab_size: int, + unit: int = 650, + nhid: int = None, + nlayers: int = 2, + dropout_rate: float = 0.0, + tie_weights: bool = False, + rnn_type: str = "lstm", + ignore_id: int = 0, + ): + assert check_argument_types() + super().__init__() + + ninp = unit + if nhid is None: + nhid = unit + rnn_type = rnn_type.upper() + + self.drop = nn.Dropout(dropout_rate) + self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id) + if rnn_type in ["LSTM", "GRU"]: + rnn_class = getattr(nn, rnn_type) + self.rnn = rnn_class( + ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True + ) + else: + try: + nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type] + except KeyError: + raise ValueError( + """An invalid option for `--model` was supplied, + options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""" + ) + self.rnn = nn.RNN( + ninp, + nhid, + nlayers, + nonlinearity=nonlinearity, + dropout=dropout_rate, + batch_first=True, + ) + self.decoder = nn.Linear(nhid, vocab_size) + + # Optionally tie weights as in: + # "Using the Output Embedding to Improve Language Models" + # (Press & Wolf 2016) https://arxiv.org/abs/1608.05859 + # and + # "Tying Word Vectors and Word Classifiers: + # A Loss Framework for Language Modeling" (Inan et al. 
2016) + # https://arxiv.org/abs/1611.01462 + if tie_weights: + if nhid != ninp: + raise ValueError( + "When using the tied flag, nhid must be equal to emsize" + ) + self.decoder.weight = self.encoder.weight + + self.rnn_type = rnn_type + self.nhid = nhid + self.nlayers = nlayers + + def forward( + self, input: torch.Tensor, hidden: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + emb = self.drop(self.encoder(input)) + output, hidden = self.rnn(emb, hidden) + output = self.drop(output) + decoded = self.decoder( + output.contiguous().view(output.size(0) * output.size(1), output.size(2)) + ) + return ( + decoded.view(output.size(0), output.size(1), decoded.size(1)), + hidden, + ) + + def score( + self, + y: torch.Tensor, + state: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + x: torch.Tensor, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]: + """Score new token. + + Args: + y: 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x: 2D encoder feature that generates ys. + + Returns: + Tuple of + torch.float32 scores for next token (n_vocab) + and next state for ys + + """ + y, new_state = self(y[-1].view(1, 1), state) + logp = y.log_softmax(dim=-1).view(-1) + return logp, new_state + + def batch_score( + self, ys: torch.Tensor, states: torch.Tensor, xs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, n_vocab)` + and next state list for ys. 
+ + """ + if states[0] is None: + states = None + elif isinstance(self.rnn, torch.nn.LSTM): + # states: Batch x 2 x (Nlayers, Dim) -> 2 x (Nlayers, Batch, Dim) + h = torch.stack([h for h, c in states], dim=1) + c = torch.stack([c for h, c in states], dim=1) + states = h, c + else: + # states: Batch x (Nlayers, Dim) -> (Nlayers, Batch, Dim) + states = torch.stack(states, dim=1) + + ys, states = self(ys[:, -1:], states) + # ys: (Batch, 1, Nvocab) -> (Batch, NVocab) + assert ys.size(1) == 1, ys.shape + ys = ys.squeeze(1) + logp = ys.log_softmax(dim=-1) + + # state: Change to batch first + if isinstance(self.rnn, torch.nn.LSTM): + # h, c: (Nlayers, Batch, Dim) + h, c = states + # states: Batch x 2 x (Nlayers, Dim) + states = [(h[:, i], c[:, i]) for i in range(h.size(1))] + else: + # states: (Nlayers, Batch, Dim) -> Batch x (Nlayers, Dim) + states = [states[:, i] for i in range(states.size(1))] + + return logp, states diff --git a/espnet2/lm/transformer_lm.py b/espnet2/lm/transformer_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..57df87bb11c49fba65b2fdf31157a20109f1ee17 --- /dev/null +++ b/espnet2/lm/transformer_lm.py @@ -0,0 +1,131 @@ +from typing import Any +from typing import List +from typing import Tuple + +import torch +import torch.nn as nn + +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet2.lm.abs_model import AbsLM + + +class TransformerLM(AbsLM): + def __init__( + self, + vocab_size: int, + pos_enc: str = None, + embed_unit: int = 128, + att_unit: int = 256, + head: int = 2, + unit: int = 1024, + layer: int = 4, + dropout_rate: float = 0.5, + ): + super().__init__() + if pos_enc == "sinusoidal": + pos_enc_class = PositionalEncoding + elif pos_enc is None: + + def pos_enc_class(*args, **kwargs): + return nn.Sequential() # indentity + + else: + raise ValueError(f"unknown pos-enc option: {pos_enc}") + + self.embed = nn.Embedding(vocab_size, embed_unit) + self.encoder = Encoder( + idim=embed_unit, + attention_dim=att_unit, + attention_heads=head, + linear_units=unit, + num_blocks=layer, + dropout_rate=dropout_rate, + input_layer="linear", + pos_enc_class=pos_enc_class, + ) + self.decoder = nn.Linear(att_unit, vocab_size) + + def _target_mask(self, ys_in_pad): + ys_mask = ys_in_pad != 0 + m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0) + return ys_mask.unsqueeze(-2) & m + + def forward(self, input: torch.Tensor, hidden: None) -> Tuple[torch.Tensor, None]: + """Compute LM loss value from buffer sequences. + + Args: + input (torch.Tensor): Input ids. (batch, len) + hidden (torch.Tensor): Target ids. (batch, len) + + """ + x = self.embed(input) + mask = self._target_mask(input) + h, _ = self.encoder(x, mask) + y = self.decoder(h) + return y, None + + def score( + self, y: torch.Tensor, state: Any, x: torch.Tensor + ) -> Tuple[torch.Tensor, Any]: + """Score new token. + + Args: + y (torch.Tensor): 1D torch.int64 prefix tokens. + state: Scorer state for prefix tokens + x (torch.Tensor): encoder feature that generates ys. 
+ + Returns: + tuple[torch.Tensor, Any]: Tuple of + torch.float32 scores for next token (vocab_size) + and next state for ys + + """ + y = y.unsqueeze(0) + h, _, cache = self.encoder.forward_one_step( + self.embed(y), self._target_mask(y), cache=state + ) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1).squeeze(0) + return logp, cache + + def batch_score( + self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor + ) -> Tuple[torch.Tensor, List[Any]]: + """Score new token batch. + + Args: + ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen). + states (List[Any]): Scorer states for prefix tokens. + xs (torch.Tensor): + The encoder feature that generates ys (n_batch, xlen, n_feat). + + Returns: + tuple[torch.Tensor, List[Any]]: Tuple of + batchfied scores for next token with shape of `(n_batch, vocab_size)` + and next state list for ys. + + """ + # merge states + n_batch = len(ys) + n_layers = len(self.encoder.encoders) + if states[0] is None: + batch_state = None + else: + # transpose state of [batch, layer] into [layer, batch] + batch_state = [ + torch.stack([states[b][i] for b in range(n_batch)]) + for i in range(n_layers) + ] + + # batch decoding + h, _, states = self.encoder.forward_one_step( + self.embed(ys), self._target_mask(ys), cache=batch_state + ) + h = self.decoder(h[:, -1]) + logp = h.log_softmax(dim=-1) + + # transpose state of [layer, batch] into [batch, layer] + state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)] + return logp, state_list diff --git a/espnet2/main_funcs/__init__.py b/espnet2/main_funcs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/main_funcs/average_nbest_models.py b/espnet2/main_funcs/average_nbest_models.py new file mode 100644 index 0000000000000000000000000000000000000000..e025238e80e3d6cac85840ac09a6344710a758d2 --- /dev/null +++ b/espnet2/main_funcs/average_nbest_models.py @@ -0,0 +1,103 @@ +import logging +from pathlib import Path +from typing import Sequence +from typing import Union +import warnings + +import torch +from typeguard import check_argument_types +from typing import Collection + +from espnet2.train.reporter import Reporter + + +@torch.no_grad() +def average_nbest_models( + output_dir: Path, + reporter: Reporter, + best_model_criterion: Sequence[Sequence[str]], + nbest: Union[Collection[int], int], +) -> None: + """Generate averaged model from n-best models + + Args: + output_dir: The directory contains the model file for each epoch + reporter: Reporter instance + best_model_criterion: Give criterions to decide the best model. + e.g. [("valid", "loss", "min"), ("train", "acc", "max")] + nbest: + """ + assert check_argument_types() + if isinstance(nbest, int): + nbests = [nbest] + else: + nbests = list(nbest) + if len(nbests) == 0: + warnings.warn("At least 1 nbest values are required") + nbests = [1] + # 1. 
Get nbests: List[Tuple[str, str, List[Tuple[epoch, value]]]] + nbest_epochs = [ + (ph, k, reporter.sort_epochs_and_values(ph, k, m)[: max(nbests)]) + for ph, k, m in best_model_criterion + if reporter.has(ph, k) + ] + + _loaded = {} + for ph, cr, epoch_and_values in nbest_epochs: + _nbests = [i for i in nbests if i <= len(epoch_and_values)] + if len(_nbests) == 0: + _nbests = [1] + + for n in _nbests: + if n == 0: + continue + elif n == 1: + # The averaged model is same as the best model + e, _ = epoch_and_values[0] + op = output_dir / f"{e}epoch.pth" + sym_op = output_dir / f"{ph}.{cr}.ave_1best.pth" + if sym_op.is_symlink() or sym_op.exists(): + sym_op.unlink() + sym_op.symlink_to(op.name) + else: + op = output_dir / f"{ph}.{cr}.ave_{n}best.pth" + logging.info( + f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}' + ) + + avg = None + # 2.a. Averaging model + for e, _ in epoch_and_values[:n]: + if e not in _loaded: + _loaded[e] = torch.load( + output_dir / f"{e}epoch.pth", + map_location="cpu", + ) + states = _loaded[e] + + if avg is None: + avg = states + else: + # Accumulated + for k in avg: + avg[k] = avg[k] + states[k] + for k in avg: + if str(avg[k].dtype).startswith("torch.int"): + # For int type, not averaged, but only accumulated. + # e.g. BatchNorm.num_batches_tracked + # (If there are any cases that requires averaging + # or the other reducing method, e.g. max/min, for integer type, + # please report.) + pass + else: + avg[k] = avg[k] / n + + # 2.b. Save the ave model and create a symlink + torch.save(avg, op) + + # 3. *.*.ave.pth is a symlink to the max ave model + op = output_dir / f"{ph}.{cr}.ave_{max(_nbests)}best.pth" + sym_op = output_dir / f"{ph}.{cr}.ave.pth" + if sym_op.is_symlink() or sym_op.exists(): + sym_op.unlink() + sym_op.symlink_to(op.name) diff --git a/espnet2/main_funcs/calculate_all_attentions.py b/espnet2/main_funcs/calculate_all_attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..ed53d2b89c5e15fdd1e7bf01b704ce6819d612d7 --- /dev/null +++ b/espnet2/main_funcs/calculate_all_attentions.py @@ -0,0 +1,160 @@ +from collections import defaultdict +from typing import Dict +from typing import List + +import torch + +from espnet.nets.pytorch_backend.rnn.attentions import AttAdd +from espnet.nets.pytorch_backend.rnn.attentions import AttCov +from espnet.nets.pytorch_backend.rnn.attentions import AttCovLoc +from espnet.nets.pytorch_backend.rnn.attentions import AttDot +from espnet.nets.pytorch_backend.rnn.attentions import AttForward +from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA +from espnet.nets.pytorch_backend.rnn.attentions import AttLoc +from espnet.nets.pytorch_backend.rnn.attentions import AttLoc2D +from espnet.nets.pytorch_backend.rnn.attentions import AttLocRec +from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadAdd +from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadDot +from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadLoc +from espnet.nets.pytorch_backend.rnn.attentions import AttMultiHeadMultiResLoc +from espnet.nets.pytorch_backend.rnn.attentions import NoAtt +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention + + +from espnet2.train.abs_espnet_model import AbsESPnetModel + + +@torch.no_grad() +def calculate_all_attentions( + model: AbsESPnetModel, batch: Dict[str, torch.Tensor] +) -> Dict[str, List[torch.Tensor]]: + """Derive the outputs from the all attention layers + + Args: + model: + batch: same as forward + 
Returns: + return_dict: A dict of a list of tensor. + key_names x batch x (D1, D2, ...) + + """ + bs = len(next(iter(batch.values()))) + assert all(len(v) == bs for v in batch.values()), { + k: v.shape for k, v in batch.items() + } + + # 1. Register forward_hook fn to save the output from specific layers + outputs = {} + handles = {} + for name, modu in model.named_modules(): + + def hook(module, input, output, name=name): + if isinstance(module, MultiHeadedAttention): + # NOTE(kamo): MultiHeadedAttention doesn't return attention weight + # attn: (B, Head, Tout, Tin) + outputs[name] = module.attn.detach().cpu() + elif isinstance(module, AttLoc2D): + c, w = output + # w: previous concate attentions + # w: (B, nprev, Tin) + att_w = w[:, -1].detach().cpu() + outputs.setdefault(name, []).append(att_w) + elif isinstance(module, (AttCov, AttCovLoc)): + c, w = output + assert isinstance(w, list), type(w) + # w: list of previous attentions + # w: nprev x (B, Tin) + att_w = w[-1].detach().cpu() + outputs.setdefault(name, []).append(att_w) + elif isinstance(module, AttLocRec): + # w: (B, Tin) + c, (w, (att_h, att_c)) = output + att_w = w.detach().cpu() + outputs.setdefault(name, []).append(att_w) + elif isinstance( + module, + ( + AttMultiHeadDot, + AttMultiHeadAdd, + AttMultiHeadLoc, + AttMultiHeadMultiResLoc, + ), + ): + c, w = output + # w: nhead x (B, Tin) + assert isinstance(w, list), type(w) + att_w = [_w.detach().cpu() for _w in w] + outputs.setdefault(name, []).append(att_w) + elif isinstance( + module, + ( + AttAdd, + AttDot, + AttForward, + AttForwardTA, + AttLoc, + NoAtt, + ), + ): + c, w = output + att_w = w.detach().cpu() + outputs.setdefault(name, []).append(att_w) + + handle = modu.register_forward_hook(hook) + handles[name] = handle + + # 2. Just forward one by one sample. + # Batch-mode can't be used to keep requirements small for each models. + keys = [] + for k in batch: + if not k.endswith("_lengths"): + keys.append(k) + + return_dict = defaultdict(list) + for ibatch in range(bs): + # *: (B, L, ...) -> (1, L2, ...) + _sample = { + k: batch[k][ibatch, None, : batch[k + "_lengths"][ibatch]] + if k + "_lengths" in batch + else batch[k][ibatch, None] + for k in keys + } + + # *_lengths: (B,) -> (1,) + _sample.update( + { + k + "_lengths": batch[k + "_lengths"][ibatch, None] + for k in keys + if k + "_lengths" in batch + } + ) + model(**_sample) + + # Derive the attention results + for name, output in outputs.items(): + if isinstance(output, list): + if isinstance(output[0], list): + # output: nhead x (Tout, Tin) + output = torch.stack( + [ + # Tout x (1, Tin) -> (Tout, Tin) + torch.cat([o[idx] for o in output], dim=0) + for idx in range(len(output[0])) + ], + dim=0, + ) + else: + # Tout x (1, Tin) -> (Tout, Tin) + output = torch.cat(output, dim=0) + else: + # output: (1, NHead, Tout, Tin) -> (NHead, Tout, Tin) + output = output.squeeze(0) + # output: (Tout, Tin) or (NHead, Tout, Tin) + return_dict[name].append(output) + outputs.clear() + + # 3. 
Remove all hooks + for _, handle in handles.items(): + handle.remove() + + return dict(return_dict) diff --git a/espnet2/main_funcs/collect_stats.py b/espnet2/main_funcs/collect_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..9916ae650d75e82cde716a4db4c7cfbe6d6b6838 --- /dev/null +++ b/espnet2/main_funcs/collect_stats.py @@ -0,0 +1,126 @@ +from collections import defaultdict +import logging +from pathlib import Path +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from torch.nn.parallel import data_parallel +from torch.utils.data import DataLoader +from typeguard import check_argument_types + +from espnet2.fileio.datadir_writer import DatadirWriter +from espnet2.fileio.npy_scp import NpyScpWriter +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.forward_adaptor import ForwardAdaptor +from espnet2.train.abs_espnet_model import AbsESPnetModel + + +@torch.no_grad() +def collect_stats( + model: AbsESPnetModel, + train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], + valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], + output_dir: Path, + ngpu: Optional[int], + log_interval: Optional[int], + write_collected_feats: bool, +) -> None: + """Perform on collect_stats mode. + + Running for deriving the shape information from data + and gathering statistics. + This method is used before executing train(). + + """ + assert check_argument_types() + + npy_scp_writers = {} + for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]): + if log_interval is None: + try: + log_interval = max(len(itr) // 20, 10) + except TypeError: + log_interval = 100 + + sum_dict = defaultdict(lambda: 0) + sq_dict = defaultdict(lambda: 0) + count_dict = defaultdict(lambda: 0) + + with DatadirWriter(output_dir / mode) as datadir_writer: + for iiter, (keys, batch) in enumerate(itr, 1): + batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") + + # 1. Write shape file + for name in batch: + if name.endswith("_lengths"): + continue + for i, (key, data) in enumerate(zip(keys, batch[name])): + if f"{name}_lengths" in batch: + lg = int(batch[f"{name}_lengths"][i]) + data = data[:lg] + datadir_writer[f"{name}_shape"][key] = ",".join( + map(str, data.shape) + ) + + # 2. Extract feats + if ngpu <= 1: + data = model.collect_feats(**batch) + else: + # Note that data_parallel can parallelize only "forward()" + data = data_parallel( + ForwardAdaptor(model, "collect_feats"), + (), + range(ngpu), + module_kwargs=batch, + ) + + # 3. Calculate sum and square sum + for key, v in data.items(): + for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())): + # Truncate zero-padding region + if f"{key}_lengths" in data: + length = data[f"{key}_lengths"][i] + # seq: (Length, Dim, ...) + seq = seq[:length] + else: + # seq: (Dim, ...) -> (1, Dim, ...) + seq = seq[None] + # Accumulate value, its square, and count + sum_dict[key] += seq.sum(0) + sq_dict[key] += (seq ** 2).sum(0) + count_dict[key] += len(seq) + + # 4. [Option] Write derived features as npy format file. 
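                        # Editorial note (not in the original diff): the block below is the
                        # optional feature-dump path. As used here, NpyScpWriter is expected
                        # to write one array per utterance under
                        # <output_dir>/<mode>/collect_feats/data_<key>/ and to record a
                        # "<uttid> <path>.npy" line in <key>.scp, so the dumped features can
                        # be looked up by utterance id later.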
+ if write_collected_feats: + # Instantiate NpyScpWriter for the first iteration + if (key, mode) not in npy_scp_writers: + p = output_dir / mode / "collect_feats" + npy_scp_writers[(key, mode)] = NpyScpWriter( + p / f"data_{key}", p / f"{key}.scp" + ) + # Save array as npy file + npy_scp_writers[(key, mode)][uttid] = seq + + if iiter % log_interval == 0: + logging.info(f"Niter: {iiter}") + + for key in sum_dict: + np.savez( + output_dir / mode / f"{key}_stats.npz", + count=count_dict[key], + sum=sum_dict[key], + sum_square=sq_dict[key], + ) + + # batch_keys and stats_keys are used by aggregate_stats_dirs.py + with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f: + f.write( + "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n" + ) + with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f: + f.write("\n".join(sum_dict) + "\n") diff --git a/espnet2/main_funcs/pack_funcs.py b/espnet2/main_funcs/pack_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..dade3e06764c3ae6d5da102ef080adb06ad844f5 --- /dev/null +++ b/espnet2/main_funcs/pack_funcs.py @@ -0,0 +1,302 @@ +from datetime import datetime +from io import BytesIO +from io import TextIOWrapper +import os +from pathlib import Path +import sys +import tarfile +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Union +import zipfile + +import yaml + + +class Archiver: + def __init__(self, file, mode="r"): + if Path(file).suffix == ".tar": + self.type = "tar" + elif Path(file).suffix == ".tgz" or Path(file).suffixes == [".tar", ".gz"]: + self.type = "tar" + if mode == "w": + mode = "w:gz" + elif Path(file).suffix == ".tbz2" or Path(file).suffixes == [".tar", ".bz2"]: + self.type = "tar" + if mode == "w": + mode = "w:bz2" + elif Path(file).suffix == ".txz" or Path(file).suffixes == [".tar", ".xz"]: + self.type = "tar" + if mode == "w": + mode = "w:xz" + elif Path(file).suffix == ".zip": + self.type = "zip" + else: + raise ValueError(f"Cannot detect archive format: type={file}") + + if self.type == "tar": + self.fopen = tarfile.open(file, mode=mode) + elif self.type == "zip": + + self.fopen = zipfile.ZipFile(file, mode=mode) + else: + raise ValueError(f"Not supported: type={type}") + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.fopen.close() + + def close(self): + self.fopen.close() + + def __iter__(self): + if self.type == "tar": + return iter(self.fopen) + elif self.type == "zip": + return iter(self.fopen.infolist()) + else: + raise ValueError(f"Not supported: type={self.type}") + + def add(self, filename, arcname=None, recursive: bool = True): + if arcname is not None: + print(f"adding: {arcname}") + else: + print(f"adding: {filename}") + + if recursive and Path(filename).is_dir(): + for f in Path(filename).glob("**/*"): + if f.is_dir(): + continue + + if arcname is not None: + _arcname = Path(arcname) / f + else: + _arcname = None + + self.add(f, _arcname) + return + + if self.type == "tar": + return self.fopen.add(filename, arcname) + elif self.type == "zip": + return self.fopen.write(filename, arcname) + else: + raise ValueError(f"Not supported: type={self.type}") + + def addfile(self, info, fileobj): + print(f"adding: {self.get_name_from_info(info)}") + + if self.type == "tar": + return self.fopen.addfile(info, fileobj) + elif self.type == "zip": + return self.fopen.writestr(info, fileobj.read()) + else: + raise ValueError(f"Not supported: type={self.type}") + + 
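    # Usage sketch (illustrative, not part of the original class): the same
    # Archiver front-end drives tarfile or zipfile depending on the file suffix, e.g.
    #
    #     with Archiver("exp/packed.zip", mode="w") as ar:
    #         ar.add("exp/config.yaml")              # single file
    #         ar.add("exp/model_dir", "model_dir")   # directory, added recursively
    #
    #     with Archiver("exp/packed.zip") as ar:     # reading back
    #         for info in ar:
    #             print(ar.get_name_from_info(info))
    #
    # The paths such as "exp/packed.zip" above are hypothetical.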
def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]: + """Generate TarInfo using system information""" + if self.type == "tar": + tarinfo = tarfile.TarInfo(str(name)) + if os.name == "posix": + tarinfo.gid = os.getgid() + tarinfo.uid = os.getuid() + tarinfo.mtime = datetime.now().timestamp() + tarinfo.size = size + # Keep mode as default + return tarinfo + elif self.type == "zip": + zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6]) + zipinfo.file_size = size + return zipinfo + else: + raise ValueError(f"Not supported: type={self.type}") + + def get_name_from_info(self, info): + if self.type == "tar": + assert isinstance(info, tarfile.TarInfo), type(info) + return info.name + elif self.type == "zip": + assert isinstance(info, zipfile.ZipInfo), type(info) + return info.filename + else: + raise ValueError(f"Not supported: type={self.type}") + + def extract(self, info, path=None): + if self.type == "tar": + return self.fopen.extract(info, path) + elif self.type == "zip": + return self.fopen.extract(info, path) + else: + raise ValueError(f"Not supported: type={self.type}") + + def extractfile(self, info, mode="r"): + if self.type == "tar": + f = self.fopen.extractfile(info) + if mode == "r": + return TextIOWrapper(f) + else: + return f + elif self.type == "zip": + if mode == "rb": + mode = "r" + return self.fopen.open(info, mode) + else: + raise ValueError(f"Not supported: type={self.type}") + + +def find_path_and_change_it_recursive(value, src: str, tgt: str): + if isinstance(value, dict): + return { + k: find_path_and_change_it_recursive(v, src, tgt) for k, v in value.items() + } + elif isinstance(value, (list, tuple)): + return [find_path_and_change_it_recursive(v, src, tgt) for v in value] + elif isinstance(value, str) and Path(value) == Path(src): + return tgt + else: + return value + + +def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]: + meta = Path(meta) + outpath = meta.parent.parent + if not meta.exists(): + return None + + with meta.open("r", encoding="utf-8") as f: + d = yaml.safe_load(f) + assert isinstance(d, dict), type(d) + yaml_files = d["yaml_files"] + files = d["files"] + assert isinstance(yaml_files, dict), type(yaml_files) + assert isinstance(files, dict), type(files) + + retval = {} + for key, value in list(yaml_files.items()) + list(files.items()): + if not (outpath / value).exists(): + return None + retval[key] = str(outpath / value) + return retval + + +def unpack( + input_archive: Union[Path, str], + outpath: Union[Path, str], + use_cache: bool = True, +) -> Dict[str, str]: + """Scan all files in the archive file and return as a dict of files. 
+ + Examples: + tarfile: + model.pth + some1.file + some2.file + + >>> unpack("tarfile", "out") + {'asr_model_file': 'out/model.pth'} + """ + input_archive = Path(input_archive) + outpath = Path(outpath) + + with Archiver(input_archive) as archive: + for info in archive: + if Path(archive.get_name_from_info(info)).name == "meta.yaml": + if ( + use_cache + and (outpath / Path(archive.get_name_from_info(info))).exists() + ): + retval = get_dict_from_cache( + outpath / Path(archive.get_name_from_info(info)) + ) + if retval is not None: + return retval + d = yaml.safe_load(archive.extractfile(info)) + assert isinstance(d, dict), type(d) + yaml_files = d["yaml_files"] + files = d["files"] + assert isinstance(yaml_files, dict), type(yaml_files) + assert isinstance(files, dict), type(files) + break + else: + raise RuntimeError("Format error: not found meta.yaml") + + for info in archive: + fname = archive.get_name_from_info(info) + outname = outpath / fname + outname.parent.mkdir(parents=True, exist_ok=True) + if fname in set(yaml_files.values()): + d = yaml.safe_load(archive.extractfile(info)) + # Rewrite yaml + for info2 in archive: + name = archive.get_name_from_info(info2) + d = find_path_and_change_it_recursive(d, name, str(outpath / name)) + with outname.open("w", encoding="utf-8") as f: + yaml.safe_dump(d, f) + else: + archive.extract(info, path=outpath) + + retval = {} + for key, value in list(yaml_files.items()) + list(files.items()): + retval[key] = str(outpath / value) + return retval + + +def _to_relative_or_resolve(f): + # Resolve to avoid symbolic link + p = Path(f).resolve() + try: + # Change to relative if it can + p = p.relative_to(Path(".").resolve()) + except ValueError: + pass + return str(p) + + +def pack( + files: Dict[str, Union[str, Path]], + yaml_files: Dict[str, Union[str, Path]], + outpath: Union[str, Path], + option: Iterable[Union[str, Path]] = (), +): + for v in list(files.values()) + list(yaml_files.values()) + list(option): + if not Path(v).exists(): + raise FileNotFoundError(f"No such file or directory: {v}") + + files = {k: _to_relative_or_resolve(v) for k, v in files.items()} + yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()} + option = [_to_relative_or_resolve(v) for v in option] + + meta_objs = dict( + files=files, + yaml_files=yaml_files, + timestamp=datetime.now().timestamp(), + python=sys.version, + ) + + try: + import torch + + meta_objs.update(torch=torch.__version__) + except ImportError: + pass + try: + import espnet + + meta_objs.update(espnet=espnet.__version__) + except ImportError: + pass + + Path(outpath).parent.mkdir(parents=True, exist_ok=True) + with Archiver(outpath, mode="w") as archive: + # Write packed/meta.yaml + fileobj = BytesIO(yaml.safe_dump(meta_objs).encode()) + info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes) + archive.addfile(info, fileobj=fileobj) + + for f in list(yaml_files.values()) + list(files.values()) + list(option): + archive.add(f) + + print(f"Generate: {outpath}") diff --git a/espnet2/optimizers/__init__.py b/espnet2/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/optimizers/sgd.py b/espnet2/optimizers/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0d3d1c906acff372cab11c24b679f6e67400c6 --- /dev/null +++ b/espnet2/optimizers/sgd.py @@ -0,0 +1,32 @@ +import torch +from typeguard import check_argument_types + + +class SGD(torch.optim.SGD): + 
"""Thin inheritance of torch.optim.SGD to bind the required arguments, 'lr' + + Note that + the arguments of the optimizer invoked by AbsTask.main() + must have default value except for 'param'. + + I can't understand why only SGD.lr doesn't have the default value. + """ + + def __init__( + self, + params, + lr: float = 0.1, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + ): + assert check_argument_types() + super().__init__( + params, + lr=lr, + momentum=momentum, + dampening=dampening, + weight_decay=weight_decay, + nesterov=nesterov, + ) diff --git a/espnet2/samplers/__init__.py b/espnet2/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/samplers/abs_sampler.py b/espnet2/samplers/abs_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..2f7aa539b8a14be204f18e26785d9c7e1d308f1f --- /dev/null +++ b/espnet2/samplers/abs_sampler.py @@ -0,0 +1,19 @@ +from abc import ABC +from abc import abstractmethod +from typing import Iterator +from typing import Tuple + +from torch.utils.data import Sampler + + +class AbsSampler(Sampler, ABC): + @abstractmethod + def __len__(self) -> int: + raise NotImplementedError + + @abstractmethod + def __iter__(self) -> Iterator[Tuple[str, ...]]: + raise NotImplementedError + + def generate(self, seed): + return list(self) diff --git a/espnet2/samplers/build_batch_sampler.py b/espnet2/samplers/build_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1b645b371f8aa46dc3b238a414fb6b207e933c0b --- /dev/null +++ b/espnet2/samplers/build_batch_sampler.py @@ -0,0 +1,167 @@ +from typing import List +from typing import Sequence +from typing import Tuple +from typing import Union + +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.samplers.abs_sampler import AbsSampler +from espnet2.samplers.folded_batch_sampler import FoldedBatchSampler +from espnet2.samplers.length_batch_sampler import LengthBatchSampler +from espnet2.samplers.num_elements_batch_sampler import NumElementsBatchSampler +from espnet2.samplers.sorted_batch_sampler import SortedBatchSampler +from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler + + +BATCH_TYPES = dict( + unsorted="UnsortedBatchSampler has nothing in paticular feature and " + "just creates mini-batches which has constant batch_size. " + "This sampler doesn't require any length " + "information for each feature. " + "'key_file' is just a text file which describes each sample name." + "\n\n" + " utterance_id_a\n" + " utterance_id_b\n" + " utterance_id_c\n" + "\n" + "The fist column is referred, so 'shape file' can be used, too.\n\n" + " utterance_id_a 100,80\n" + " utterance_id_b 400,80\n" + " utterance_id_c 512,80\n", + sorted="SortedBatchSampler sorts samples by the length of the first input " + " in order to make each sample in a mini-batch has close length. " + "This sampler requires a text file which describes the length for each sample " + "\n\n" + " utterance_id_a 1000\n" + " utterance_id_b 1453\n" + " utterance_id_c 1241\n" + "\n" + "The first element of feature dimensions is referred, " + "so 'shape_file' can be also used.\n\n" + " utterance_id_a 1000,80\n" + " utterance_id_b 1453,80\n" + " utterance_id_c 1241,80\n", + folded="FoldedBatchSampler supports variable batch_size. 
" + "The batch_size is decided by\n" + " batch_size = base_batch_size // (L // fold_length)\n" + "L is referred to the largest length of samples in the mini-batch. " + "This samples requires length information as same as SortedBatchSampler\n", + length="LengthBatchSampler supports variable batch_size. " + "This sampler makes mini-batches which have same number of 'bins' as possible " + "counting by the total lengths of each feature in the mini-batch. " + "This sampler requires a text file which describes the length for each sample. " + "\n\n" + " utterance_id_a 1000\n" + " utterance_id_b 1453\n" + " utterance_id_c 1241\n" + "\n" + "The first element of feature dimensions is referred, " + "so 'shape_file' can be also used.\n\n" + " utterance_id_a 1000,80\n" + " utterance_id_b 1453,80\n" + " utterance_id_c 1241,80\n", + numel="NumElementsBatchSampler supports variable batch_size. " + "Just like LengthBatchSampler, this sampler makes mini-batches" + " which have same number of 'bins' as possible " + "counting by the total number of elements of each feature " + "instead of the length. " + "Thus this sampler requires the full information of the dimension of the features. " + "\n\n" + " utterance_id_a 1000,80\n" + " utterance_id_b 1453,80\n" + " utterance_id_c 1241,80\n", +) + + +def build_batch_sampler( + type: str, + batch_size: int, + batch_bins: int, + shape_files: Union[Tuple[str, ...], List[str]], + sort_in_batch: str = "descending", + sort_batch: str = "ascending", + drop_last: bool = False, + min_batch_size: int = 1, + fold_lengths: Sequence[int] = (), + padding: bool = True, + utt2category_file: str = None, +) -> AbsSampler: + """Helper function to instantiate BatchSampler. + + Args: + type: mini-batch type. "unsorted", "sorted", "folded", "numel", or, "length" + batch_size: The mini-batch size. Used for "unsorted", "sorted", "folded" mode + batch_bins: Used for "numel" model + shape_files: Text files describing the length and dimension + of each features. e.g. uttA 1330,80 + sort_in_batch: + sort_batch: + drop_last: + min_batch_size: Used for "numel" or "folded" mode + fold_lengths: Used for "folded" mode + padding: Whether sequences are input as a padded tensor or not. 
+ used for "numel" mode + """ + assert check_argument_types() + if len(shape_files) == 0: + raise ValueError("No shape file are given") + + if type == "unsorted": + retval = UnsortedBatchSampler( + batch_size=batch_size, key_file=shape_files[0], drop_last=drop_last + ) + + elif type == "sorted": + retval = SortedBatchSampler( + batch_size=batch_size, + shape_file=shape_files[0], + sort_in_batch=sort_in_batch, + sort_batch=sort_batch, + drop_last=drop_last, + ) + + elif type == "folded": + if len(fold_lengths) != len(shape_files): + raise ValueError( + f"The number of fold_lengths must be equal to " + f"the number of shape_files: " + f"{len(fold_lengths)} != {len(shape_files)}" + ) + retval = FoldedBatchSampler( + batch_size=batch_size, + shape_files=shape_files, + fold_lengths=fold_lengths, + sort_in_batch=sort_in_batch, + sort_batch=sort_batch, + drop_last=drop_last, + min_batch_size=min_batch_size, + utt2category_file=utt2category_file, + ) + + elif type == "numel": + retval = NumElementsBatchSampler( + batch_bins=batch_bins, + shape_files=shape_files, + sort_in_batch=sort_in_batch, + sort_batch=sort_batch, + drop_last=drop_last, + padding=padding, + min_batch_size=min_batch_size, + ) + + elif type == "length": + retval = LengthBatchSampler( + batch_bins=batch_bins, + shape_files=shape_files, + sort_in_batch=sort_in_batch, + sort_batch=sort_batch, + drop_last=drop_last, + padding=padding, + min_batch_size=min_batch_size, + ) + + else: + raise ValueError(f"Not supported: {type}") + assert check_return_type(retval) + return retval diff --git a/espnet2/samplers/folded_batch_sampler.py b/espnet2/samplers/folded_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4d2e941e3d4cdddf67c1e182e72017a42932e173 --- /dev/null +++ b/espnet2/samplers/folded_batch_sampler.py @@ -0,0 +1,156 @@ +from typing import Iterator +from typing import List +from typing import Sequence +from typing import Tuple +from typing import Union + +from typeguard import check_argument_types + +from espnet2.fileio.read_text import load_num_sequence_text +from espnet2.fileio.read_text import read_2column_text +from espnet2.samplers.abs_sampler import AbsSampler + + +class FoldedBatchSampler(AbsSampler): + def __init__( + self, + batch_size: int, + shape_files: Union[Tuple[str, ...], List[str]], + fold_lengths: Sequence[int], + min_batch_size: int = 1, + sort_in_batch: str = "descending", + sort_batch: str = "ascending", + drop_last: bool = False, + utt2category_file: str = None, + ): + assert check_argument_types() + assert batch_size > 0 + if sort_batch != "ascending" and sort_batch != "descending": + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + if sort_in_batch != "descending" and sort_in_batch != "ascending": + raise ValueError( + f"sort_in_batch must be ascending or descending: {sort_in_batch}" + ) + + self.batch_size = batch_size + self.shape_files = shape_files + self.sort_in_batch = sort_in_batch + self.sort_batch = sort_batch + self.drop_last = drop_last + + # utt2shape: (Length, ...) + # uttA 100,... + # uttB 201,... 
+ utt2shapes = [ + load_num_sequence_text(s, loader_type="csv_int") for s in shape_files + ] + + first_utt2shape = utt2shapes[0] + for s, d in zip(shape_files, utt2shapes): + if set(d) != set(first_utt2shape): + raise RuntimeError( + f"keys are mismatched between {s} != {shape_files[0]}" + ) + + # Sort samples in ascending order + # (shape order should be like (Length, Dim)) + keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0]) + if len(keys) == 0: + raise RuntimeError(f"0 lines found: {shape_files[0]}") + + category2utt = {} + if utt2category_file is not None: + utt2category = read_2column_text(utt2category_file) + if set(utt2category) != set(first_utt2shape): + raise RuntimeError( + "keys are mismatched between " + f"{utt2category_file} != {shape_files[0]}" + ) + for k in keys: + category2utt.setdefault(utt2category[k], []).append(k) + else: + category2utt["default_category"] = keys + + self.batch_list = [] + for d, v in category2utt.items(): + category_keys = v + # Decide batch-sizes + start = 0 + batch_sizes = [] + while True: + k = category_keys[start] + factor = max(int(d[k][0] / m) for d, m in zip(utt2shapes, fold_lengths)) + bs = max(min_batch_size, int(batch_size / (1 + factor))) + if self.drop_last and start + bs > len(category_keys): + # This if-block avoids 0-batches + if len(self.batch_list) > 0: + break + + bs = min(len(category_keys) - start, bs) + batch_sizes.append(bs) + start += bs + if start >= len(category_keys): + break + + if len(batch_sizes) == 0: + # Maybe we can't reach here + raise RuntimeError("0 batches") + + # If the last batch-size is smaller than minimum batch_size, + # the samples are redistributed to the other mini-batches + if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size: + for i in range(batch_sizes.pop(-1)): + batch_sizes[-(i % len(batch_sizes)) - 2] += 1 + + if not self.drop_last: + # Bug check + assert sum(batch_sizes) == len( + category_keys + ), f"{sum(batch_sizes)} != {len(category_keys)}" + + # Set mini-batch + cur_batch_list = [] + start = 0 + for bs in batch_sizes: + assert len(category_keys) >= start + bs, "Bug" + minibatch_keys = category_keys[start : start + bs] + start += bs + if sort_in_batch == "descending": + minibatch_keys.reverse() + elif sort_in_batch == "ascending": + # Key are already sorted in ascending + pass + else: + raise ValueError( + "sort_in_batch must be ascending or " + f"descending: {sort_in_batch}" + ) + cur_batch_list.append(tuple(minibatch_keys)) + + if sort_batch == "ascending": + pass + elif sort_batch == "descending": + cur_batch_list.reverse() + else: + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + self.batch_list.extend(cur_batch_list) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"N-batch={len(self)}, " + f"batch_size={self.batch_size}, " + f"shape_files={self.shape_files}, " + f"sort_in_batch={self.sort_in_batch}, " + f"sort_batch={self.sort_batch})" + ) + + def __len__(self): + return len(self.batch_list) + + def __iter__(self) -> Iterator[Tuple[str, ...]]: + return iter(self.batch_list) diff --git a/espnet2/samplers/length_batch_sampler.py b/espnet2/samplers/length_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..522a4b49e145c4f92cb03ca65f7d454df960033c --- /dev/null +++ b/espnet2/samplers/length_batch_sampler.py @@ -0,0 +1,143 @@ +from typing import Iterator +from typing import List +from typing import Tuple +from typing import Union + +from typeguard import check_argument_types + 
+from espnet2.fileio.read_text import load_num_sequence_text +from espnet2.samplers.abs_sampler import AbsSampler + + +class LengthBatchSampler(AbsSampler): + def __init__( + self, + batch_bins: int, + shape_files: Union[Tuple[str, ...], List[str]], + min_batch_size: int = 1, + sort_in_batch: str = "descending", + sort_batch: str = "ascending", + drop_last: bool = False, + padding: bool = True, + ): + assert check_argument_types() + assert batch_bins > 0 + if sort_batch != "ascending" and sort_batch != "descending": + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + if sort_in_batch != "descending" and sort_in_batch != "ascending": + raise ValueError( + f"sort_in_batch must be ascending or descending: {sort_in_batch}" + ) + + self.batch_bins = batch_bins + self.shape_files = shape_files + self.sort_in_batch = sort_in_batch + self.sort_batch = sort_batch + self.drop_last = drop_last + + # utt2shape: (Length, ...) + # uttA 100,... + # uttB 201,... + utt2shapes = [ + load_num_sequence_text(s, loader_type="csv_int") for s in shape_files + ] + + first_utt2shape = utt2shapes[0] + for s, d in zip(shape_files, utt2shapes): + if set(d) != set(first_utt2shape): + raise RuntimeError( + f"keys are mismatched between {s} != {shape_files[0]}" + ) + + # Sort samples in ascending order + # (shape order should be like (Length, Dim)) + keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0]) + if len(keys) == 0: + raise RuntimeError(f"0 lines found: {shape_files[0]}") + + # Decide batch-sizes + batch_sizes = [] + current_batch_keys = [] + for key in keys: + current_batch_keys.append(key) + # shape: (Length, dim1, dim2, ...) + if padding: + # bins = bs x max_length + bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes) + else: + # bins = sum of lengths + bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes) + + if bins > batch_bins and len(current_batch_keys) >= min_batch_size: + batch_sizes.append(len(current_batch_keys)) + current_batch_keys = [] + else: + if len(current_batch_keys) != 0 and ( + not self.drop_last or len(batch_sizes) == 0 + ): + batch_sizes.append(len(current_batch_keys)) + + if len(batch_sizes) == 0: + # Maybe we can't reach here + raise RuntimeError("0 batches") + + # If the last batch-size is smaller than minimum batch_size, + # the samples are redistributed to the other mini-batches + if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size: + for i in range(batch_sizes.pop(-1)): + batch_sizes[-(i % len(batch_sizes)) - 1] += 1 + + if not self.drop_last: + # Bug check + assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}" + + # Set mini-batch + self.batch_list = [] + iter_bs = iter(batch_sizes) + bs = next(iter_bs) + minibatch_keys = [] + for key in keys: + minibatch_keys.append(key) + if len(minibatch_keys) == bs: + if sort_in_batch == "descending": + minibatch_keys.reverse() + elif sort_in_batch == "ascending": + # Key are already sorted in ascending + pass + else: + raise ValueError( + "sort_in_batch must be ascending" + f" or descending: {sort_in_batch}" + ) + self.batch_list.append(tuple(minibatch_keys)) + minibatch_keys = [] + try: + bs = next(iter_bs) + except StopIteration: + break + + if sort_batch == "ascending": + pass + elif sort_batch == "descending": + self.batch_list.reverse() + else: + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"N-batch={len(self)}, " + 
f"batch_bins={self.batch_bins}, " + f"sort_in_batch={self.sort_in_batch}, " + f"sort_batch={self.sort_batch})" + ) + + def __len__(self): + return len(self.batch_list) + + def __iter__(self) -> Iterator[Tuple[str, ...]]: + return iter(self.batch_list) diff --git a/espnet2/samplers/num_elements_batch_sampler.py b/espnet2/samplers/num_elements_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..46ff177b8f332c2a95fc283ca80914eb56a6bfa0 --- /dev/null +++ b/espnet2/samplers/num_elements_batch_sampler.py @@ -0,0 +1,160 @@ +from typing import Iterator +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +from typeguard import check_argument_types + +from espnet2.fileio.read_text import load_num_sequence_text +from espnet2.samplers.abs_sampler import AbsSampler + + +class NumElementsBatchSampler(AbsSampler): + def __init__( + self, + batch_bins: int, + shape_files: Union[Tuple[str, ...], List[str]], + min_batch_size: int = 1, + sort_in_batch: str = "descending", + sort_batch: str = "ascending", + drop_last: bool = False, + padding: bool = True, + ): + assert check_argument_types() + assert batch_bins > 0 + if sort_batch != "ascending" and sort_batch != "descending": + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + if sort_in_batch != "descending" and sort_in_batch != "ascending": + raise ValueError( + f"sort_in_batch must be ascending or descending: {sort_in_batch}" + ) + + self.batch_bins = batch_bins + self.shape_files = shape_files + self.sort_in_batch = sort_in_batch + self.sort_batch = sort_batch + self.drop_last = drop_last + + # utt2shape: (Length, ...) + # uttA 100,... + # uttB 201,... + utt2shapes = [ + load_num_sequence_text(s, loader_type="csv_int") for s in shape_files + ] + + first_utt2shape = utt2shapes[0] + for s, d in zip(shape_files, utt2shapes): + if set(d) != set(first_utt2shape): + raise RuntimeError( + f"keys are mismatched between {s} != {shape_files[0]}" + ) + + # Sort samples in ascending order + # (shape order should be like (Length, Dim)) + keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0]) + if len(keys) == 0: + raise RuntimeError(f"0 lines found: {shape_files[0]}") + if padding: + # If padding case, the feat-dim must be same over whole corpus, + # therefore the first sample is referred + feat_dims = [np.prod(d[keys[0]][1:]) for d in utt2shapes] + else: + feat_dims = None + + # Decide batch-sizes + batch_sizes = [] + current_batch_keys = [] + for key in keys: + current_batch_keys.append(key) + # shape: (Length, dim1, dim2, ...) 
+ if padding: + for d, s in zip(utt2shapes, shape_files): + if tuple(d[key][1:]) != tuple(d[keys[0]][1:]): + raise RuntimeError( + "If padding=True, the " + f"feature dimension must be unified: {s}", + ) + bins = sum( + len(current_batch_keys) * sh[key][0] * d + for sh, d in zip(utt2shapes, feat_dims) + ) + else: + bins = sum( + np.prod(d[k]) for k in current_batch_keys for d in utt2shapes + ) + + if bins > batch_bins and len(current_batch_keys) >= min_batch_size: + batch_sizes.append(len(current_batch_keys)) + current_batch_keys = [] + else: + if len(current_batch_keys) != 0 and ( + not self.drop_last or len(batch_sizes) == 0 + ): + batch_sizes.append(len(current_batch_keys)) + + if len(batch_sizes) == 0: + # Maybe we can't reach here + raise RuntimeError("0 batches") + + # If the last batch-size is smaller than minimum batch_size, + # the samples are redistributed to the other mini-batches + if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size: + for i in range(batch_sizes.pop(-1)): + batch_sizes[-(i % len(batch_sizes)) - 1] += 1 + + if not self.drop_last: + # Bug check + assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}" + + # Set mini-batch + self.batch_list = [] + iter_bs = iter(batch_sizes) + bs = next(iter_bs) + minibatch_keys = [] + for key in keys: + minibatch_keys.append(key) + if len(minibatch_keys) == bs: + if sort_in_batch == "descending": + minibatch_keys.reverse() + elif sort_in_batch == "ascending": + # Key are already sorted in ascending + pass + else: + raise ValueError( + "sort_in_batch must be ascending" + f" or descending: {sort_in_batch}" + ) + + self.batch_list.append(tuple(minibatch_keys)) + minibatch_keys = [] + try: + bs = next(iter_bs) + except StopIteration: + break + + if sort_batch == "ascending": + pass + elif sort_batch == "descending": + self.batch_list.reverse() + else: + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"N-batch={len(self)}, " + f"batch_bins={self.batch_bins}, " + f"sort_in_batch={self.sort_in_batch}, " + f"sort_batch={self.sort_batch})" + ) + + def __len__(self): + return len(self.batch_list) + + def __iter__(self) -> Iterator[Tuple[str, ...]]: + return iter(self.batch_list) diff --git a/espnet2/samplers/sorted_batch_sampler.py b/espnet2/samplers/sorted_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..4649f9a4fd748f8911cdaab94b8f07326905643f --- /dev/null +++ b/espnet2/samplers/sorted_batch_sampler.py @@ -0,0 +1,95 @@ +import logging +from typing import Iterator +from typing import Tuple + +from typeguard import check_argument_types + +from espnet2.fileio.read_text import load_num_sequence_text +from espnet2.samplers.abs_sampler import AbsSampler + + +class SortedBatchSampler(AbsSampler): + """BatchSampler with sorted samples by length. + + Args: + batch_size: + shape_file: + sort_in_batch: 'descending', 'ascending' or None. + sort_batch: + """ + + def __init__( + self, + batch_size: int, + shape_file: str, + sort_in_batch: str = "descending", + sort_batch: str = "ascending", + drop_last: bool = False, + ): + assert check_argument_types() + assert batch_size > 0 + self.batch_size = batch_size + self.shape_file = shape_file + self.sort_in_batch = sort_in_batch + self.sort_batch = sort_batch + self.drop_last = drop_last + + # utt2shape: (Length, ...) + # uttA 100,... + # uttB 201,... 
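        # Editorial note: sort_in_batch="descending" matters for RNN models because
        # torch.nn.utils.rnn.pack_padded_sequence historically required batches to be
        # ordered by decreasing length (newer PyTorch also accepts enforce_sorted=False).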
+ utt2shape = load_num_sequence_text(shape_file, loader_type="csv_int") + if sort_in_batch == "descending": + # Sort samples in descending order (required by RNN) + keys = sorted(utt2shape, key=lambda k: -utt2shape[k][0]) + elif sort_in_batch == "ascending": + # Sort samples in ascending order + keys = sorted(utt2shape, key=lambda k: utt2shape[k][0]) + else: + raise ValueError( + f"sort_in_batch must be either one of " + f"ascending, descending, or None: {sort_in_batch}" + ) + if len(keys) == 0: + raise RuntimeError(f"0 lines found: {shape_file}") + + # Apply max(, 1) to avoid 0-batches + N = max(len(keys) // batch_size, 1) + if not self.drop_last: + # Split keys evenly as possible as. Note that If N != 1, + # the these batches always have size of batch_size at minimum. + self.batch_list = [ + keys[i * len(keys) // N : (i + 1) * len(keys) // N] for i in range(N) + ] + else: + self.batch_list = [ + tuple(keys[i * batch_size : (i + 1) * batch_size]) for i in range(N) + ] + + if len(self.batch_list) == 0: + logging.warning(f"{shape_file} is empty") + + if sort_in_batch != sort_batch: + if sort_batch not in ("ascending", "descending"): + raise ValueError( + f"sort_batch must be ascending or descending: {sort_batch}" + ) + self.batch_list.reverse() + + if len(self.batch_list) == 0: + raise RuntimeError("0 batches") + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"N-batch={len(self)}, " + f"batch_size={self.batch_size}, " + f"shape_file={self.shape_file}, " + f"sort_in_batch={self.sort_in_batch}, " + f"sort_batch={self.sort_batch})" + ) + + def __len__(self): + return len(self.batch_list) + + def __iter__(self) -> Iterator[Tuple[str, ...]]: + return iter(self.batch_list) diff --git a/espnet2/samplers/unsorted_batch_sampler.py b/espnet2/samplers/unsorted_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..33a22090ac273868edd1c8eeec7a47093f0450cb --- /dev/null +++ b/espnet2/samplers/unsorted_batch_sampler.py @@ -0,0 +1,91 @@ +import logging +from typing import Iterator +from typing import Tuple + +from typeguard import check_argument_types + +from espnet2.fileio.read_text import read_2column_text +from espnet2.samplers.abs_sampler import AbsSampler + + +class UnsortedBatchSampler(AbsSampler): + """BatchSampler with constant batch-size. + + Any sorting is not done in this class, + so no length information is required, + This class is convenient for decoding mode, + or not seq2seq learning e.g. classification. 
+ + Args: + batch_size: + key_file: + """ + + def __init__( + self, + batch_size: int, + key_file: str, + drop_last: bool = False, + utt2category_file: str = None, + ): + assert check_argument_types() + assert batch_size > 0 + self.batch_size = batch_size + self.key_file = key_file + self.drop_last = drop_last + + # utt2shape: + # uttA + # uttB + utt2any = read_2column_text(key_file) + if len(utt2any) == 0: + logging.warning(f"{key_file} is empty") + # In this case the, the first column in only used + keys = list(utt2any) + if len(keys) == 0: + raise RuntimeError(f"0 lines found: {key_file}") + + category2utt = {} + if utt2category_file is not None: + utt2category = read_2column_text(utt2category_file) + if set(utt2category) != set(keys): + raise RuntimeError( + f"keys are mismatched between {utt2category_file} != {key_file}" + ) + for k, v in utt2category.items(): + category2utt.setdefault(v, []).append(k) + else: + category2utt["default_category"] = keys + + self.batch_list = [] + for d, v in category2utt.items(): + category_keys = v + # Apply max(, 1) to avoid 0-batches + N = max(len(category_keys) // batch_size, 1) + if not self.drop_last: + # Split keys evenly as possible as. Note that If N != 1, + # the these batches always have size of batch_size at minimum. + cur_batch_list = [ + category_keys[i * len(keys) // N : (i + 1) * len(keys) // N] + for i in range(N) + ] + else: + cur_batch_list = [ + tuple(category_keys[i * batch_size : (i + 1) * batch_size]) + for i in range(N) + ] + self.batch_list.extend(cur_batch_list) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"N-batch={len(self)}, " + f"batch_size={self.batch_size}, " + f"key_file={self.key_file}, " + ) + + def __len__(self): + return len(self.batch_list) + + def __iter__(self) -> Iterator[Tuple[str, ...]]: + return iter(self.batch_list) diff --git a/espnet2/schedulers/__init__.py b/espnet2/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/schedulers/abs_scheduler.py b/espnet2/schedulers/abs_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..de8087e417e45d639295ef2a9628baafa4c6238f --- /dev/null +++ b/espnet2/schedulers/abs_scheduler.py @@ -0,0 +1,82 @@ +from abc import ABC +from abc import abstractmethod +from distutils.version import LooseVersion + +import torch +import torch.optim.lr_scheduler as L + + +class AbsScheduler(ABC): + @abstractmethod + def step(self, epoch: int = None): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state): + pass + + +# If you need to define custom scheduler, please inherit these classes +class AbsBatchStepScheduler(AbsScheduler): + @abstractmethod + def step(self, epoch: int = None): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state): + pass + + +class AbsEpochStepScheduler(AbsScheduler): + @abstractmethod + def step(self, epoch: int = None): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state): + pass + + +class AbsValEpochStepScheduler(AbsEpochStepScheduler): + @abstractmethod + def step(self, val, epoch: int = None): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state): + pass + + +# Create alias type to check the type +# Note(kamo): Currently PyTorch doesn't provide the base class +# to judge these 
classes. +AbsValEpochStepScheduler.register(L.ReduceLROnPlateau) +for s in [ + L.ReduceLROnPlateau, + L.LambdaLR, + L.StepLR, + L.MultiStepLR, + L.MultiStepLR, + L.ExponentialLR, + L.CosineAnnealingLR, +]: + AbsEpochStepScheduler.register(s) +if LooseVersion(torch.__version__) >= LooseVersion("1.3.0"): + for s in [L.CyclicLR, L.OneCycleLR, L.CosineAnnealingWarmRestarts]: + AbsBatchStepScheduler.register(s) diff --git a/espnet2/schedulers/noam_lr.py b/espnet2/schedulers/noam_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..e893c0481694bfc4d0844e7a214aafb4cea44449 --- /dev/null +++ b/espnet2/schedulers/noam_lr.py @@ -0,0 +1,67 @@ +from distutils.version import LooseVersion +from typing import Union +import warnings + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from typeguard import check_argument_types + +from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler + + +class NoamLR(_LRScheduler, AbsBatchStepScheduler): + """The LR scheduler proposed by Noam + + Ref: + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class, + thus the behaviour isn't guaranteed at forward PyTorch version. + + NOTE(kamo): The "model_size" in original implementation is derived from + the model, but in this implementation, this parameter is a constant value. + You need to change it if the model is changed. + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + model_size: Union[int, float] = 320, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, + ): + if LooseVersion(torch.__version__) < LooseVersion("1.1.0"): + raise NotImplementedError(f"Require PyTorch>=1.1.0: {torch.__version__}") + assert check_argument_types() + self.model_size = model_size + self.warmup_steps = warmup_steps + + lr = list(optimizer.param_groups)[0]["lr"] + new_lr = self.lr_for_WarmupLR(lr) + warnings.warn( + f"NoamLR is deprecated. 
" + f"Use WarmupLR(warmup_steps={warmup_steps}) with Optimizer(lr={new_lr})", + ) + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def lr_for_WarmupLR(self, lr: float) -> float: + return lr / self.model_size ** 0.5 / self.warmup_steps ** 0.5 + + def __repr__(self): + return ( + f"{self.__class__.__name__}(model_size={self.model_size}, " + f"warmup_steps={self.warmup_steps})" + ) + + def get_lr(self): + step_num = self.last_epoch + 1 + return [ + lr + * self.model_size ** -0.5 + * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) + for lr in self.base_lrs + ] diff --git a/espnet2/schedulers/warmup_lr.py b/espnet2/schedulers/warmup_lr.py new file mode 100644 index 0000000000000000000000000000000000000000..2217d8b37614723e2c2bde5fa2496515ceca187b --- /dev/null +++ b/espnet2/schedulers/warmup_lr.py @@ -0,0 +1,53 @@ +from distutils.version import LooseVersion +from typing import Union + +import torch +from torch.optim.lr_scheduler import _LRScheduler +from typeguard import check_argument_types + +from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler + + +class WarmupLR(_LRScheduler, AbsBatchStepScheduler): + """The WarmupLR scheduler + + This scheduler is almost same as NoamLR Scheduler except for following difference: + + NoamLR: + lr = optimizer.lr * model_size ** -0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + WarmupLR: + lr = optimizer.lr * warmup_step ** 0.5 + * min(step ** -0.5, step * warmup_step ** -1.5) + + Note that the maximum lr equals to optimizer.lr in this scheduler. + + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: Union[int, float] = 25000, + last_epoch: int = -1, + ): + if LooseVersion(torch.__version__) < LooseVersion("1.1.0"): + raise NotImplementedError(f"Require PyTorch>=1.1.0: {torch.__version__}") + + assert check_argument_types() + self.warmup_steps = warmup_steps + + # __init__() must be invoked before setting field + # because step() is also invoked in __init__() + super().__init__(optimizer, last_epoch) + + def __repr__(self): + return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})" + + def get_lr(self): + step_num = self.last_epoch + 1 + return [ + lr + * self.warmup_steps ** 0.5 + * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5) + for lr in self.base_lrs + ] diff --git a/espnet2/tasks/__init__.py b/espnet2/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/tasks/abs_task.py b/espnet2/tasks/abs_task.py new file mode 100644 index 0000000000000000000000000000000000000000..a836633a6bbb929f31f63cbcff8c52ee0b2214f1 --- /dev/null +++ b/espnet2/tasks/abs_task.py @@ -0,0 +1,1782 @@ +from abc import ABC +from abc import abstractmethod +import argparse +from dataclasses import dataclass +from distutils.version import LooseVersion +import functools +import logging +import os +from pathlib import Path +import sys +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import humanfriendly +import numpy as np +import torch +import torch.multiprocessing +import torch.nn +import torch.optim +from torch.utils.data import DataLoader +from typeguard import check_argument_types +from typeguard import check_return_type +import wandb +import 
yaml + +from espnet import __version__ +from espnet.utils.cli_utils import get_commandline_args +from espnet2.iterators.abs_iter_factory import AbsIterFactory +from espnet2.iterators.chunk_iter_factory import ChunkIterFactory +from espnet2.iterators.multiple_iter_factory import MultipleIterFactory +from espnet2.iterators.sequence_iter_factory import SequenceIterFactory +from espnet2.main_funcs.collect_stats import collect_stats +from espnet2.optimizers.sgd import SGD +from espnet2.samplers.build_batch_sampler import BATCH_TYPES +from espnet2.samplers.build_batch_sampler import build_batch_sampler +from espnet2.samplers.unsorted_batch_sampler import UnsortedBatchSampler +from espnet2.schedulers.noam_lr import NoamLR +from espnet2.schedulers.warmup_lr import WarmupLR +from espnet2.torch_utils.load_pretrained_model import load_pretrained_model +from espnet2.torch_utils.model_summary import model_summary +from espnet2.torch_utils.pytorch_version import pytorch_cudnn_version +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet2.train.class_choices import ClassChoices +from espnet2.train.dataset import AbsDataset +from espnet2.train.dataset import DATA_TYPES +from espnet2.train.dataset import ESPnetDataset +from espnet2.train.distributed_utils import DistributedOption +from espnet2.train.distributed_utils import free_port +from espnet2.train.distributed_utils import get_master_port +from espnet2.train.distributed_utils import get_node_rank +from espnet2.train.distributed_utils import get_num_nodes +from espnet2.train.distributed_utils import resolve_distributed_mode +from espnet2.train.iterable_dataset import IterableESPnetDataset +from espnet2.train.trainer import Trainer +from espnet2.utils.build_dataclass import build_dataclass +from espnet2.utils import config_argparse +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import humanfriendly_parse_size_or_none +from espnet2.utils.types import int_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str2triple_str +from espnet2.utils.types import str_or_int +from espnet2.utils.types import str_or_none +from espnet2.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump + +if LooseVersion(torch.__version__) >= LooseVersion("1.5.0"): + from torch.multiprocessing.spawn import ProcessContext +else: + from torch.multiprocessing.spawn import SpawnContext as ProcessContext + + +optim_classes = dict( + adam=torch.optim.Adam, + sgd=SGD, + adadelta=torch.optim.Adadelta, + adagrad=torch.optim.Adagrad, + adamax=torch.optim.Adamax, + asgd=torch.optim.ASGD, + lbfgs=torch.optim.LBFGS, + rmsprop=torch.optim.RMSprop, + rprop=torch.optim.Rprop, +) +if LooseVersion(torch.__version__) >= LooseVersion("1.2.0"): + optim_classes["adamw"] = torch.optim.AdamW +try: + import torch_optimizer + + optim_classes.update( + accagd=torch_optimizer.AccSGD, + adabound=torch_optimizer.AdaBound, + adamod=torch_optimizer.AdaMod, + diffgrad=torch_optimizer.DiffGrad, + lamb=torch_optimizer.Lamb, + novograd=torch_optimizer.NovoGrad, + pid=torch_optimizer.PID, + # torch_optimizer<=0.0.1a10 doesn't support + # qhadam=torch_optimizer.QHAdam, + qhm=torch_optimizer.QHM, + radam=torch_optimizer.RAdam, + sgdw=torch_optimizer.SGDW, + yogi=torch_optimizer.Yogi, + ) + del torch_optimizer +except ImportError: + pass +try: + import apex + + optim_classes.update( + 
fusedadam=apex.optimizers.FusedAdam, + fusedlamb=apex.optimizers.FusedLAMB, + fusednovograd=apex.optimizers.FusedNovoGrad, + fusedsgd=apex.optimizers.FusedSGD, + ) + del apex +except ImportError: + pass +try: + import fairscale +except ImportError: + fairscale = None + + +scheduler_classes = dict( + ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, + lambdalr=torch.optim.lr_scheduler.LambdaLR, + steplr=torch.optim.lr_scheduler.StepLR, + multisteplr=torch.optim.lr_scheduler.MultiStepLR, + exponentiallr=torch.optim.lr_scheduler.ExponentialLR, + CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR, +) +if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): + scheduler_classes.update( + noamlr=NoamLR, + warmuplr=WarmupLR, + ) +if LooseVersion(torch.__version__) >= LooseVersion("1.3.0"): + CosineAnnealingWarmRestarts = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts + scheduler_classes.update( + cycliclr=torch.optim.lr_scheduler.CyclicLR, + onecyclelr=torch.optim.lr_scheduler.OneCycleLR, + CosineAnnealingWarmRestarts=CosineAnnealingWarmRestarts, + ) +# To lower keys +optim_classes = {k.lower(): v for k, v in optim_classes.items()} +scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()} + + +@dataclass +class IteratorOptions: + preprocess_fn: callable + collate_fn: callable + data_path_and_name_and_type: list + shape_files: list + batch_size: int + batch_bins: int + batch_type: str + max_cache_size: float + max_cache_fd: int + distributed: bool + num_batches: Optional[int] + num_iters_per_epoch: Optional[int] + train: bool + + +class AbsTask(ABC): + # Use @staticmethod, or @classmethod, + # instead of instance method to avoid God classes + + # If you need more than one optimizers, change this value in inheritance + num_optimizers: int = 1 + trainer = Trainer + class_choices_list: List[ClassChoices] = [] + + def __init__(self): + raise RuntimeError("This class can't be instantiated.") + + @classmethod + @abstractmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + pass + + @classmethod + @abstractmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]: + """Return "collate_fn", which is a callable object and given to DataLoader. + + >>> from torch.utils.data import DataLoader + >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...) + + In many cases, you can use our common collate_fn. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + raise NotImplementedError + + @classmethod + @abstractmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Define the required names by Task + + This function is used by + >>> cls.check_task_requirements() + If your model is defined as following, + + >>> from espnet2.train.abs_espnet_model import AbsESPnetModel + >>> class Model(AbsESPnetModel): + ... 
def forward(self, input, output, opt=None): pass + + then "required_data_names" should be as + + >>> required_data_names = ('input', 'output') + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Define the optional names by Task + + This function is used by + >>> cls.check_task_requirements() + If your model is defined as follows, + + >>> from espnet2.train.abs_espnet_model import AbsESPnetModel + >>> class Model(AbsESPnetModel): + ... def forward(self, input, output, opt=None): pass + + then "optional_data_names" should be as + + >>> optional_data_names = ('opt',) + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: + raise NotImplementedError + + @classmethod + def get_parser(cls) -> config_argparse.ArgumentParser: + assert check_argument_types() + + class ArgumentDefaultsRawTextHelpFormatter( + argparse.RawTextHelpFormatter, + argparse.ArgumentDefaultsHelpFormatter, + ): + pass + + parser = config_argparse.ArgumentParser( + description="base parser", + formatter_class=ArgumentDefaultsRawTextHelpFormatter, + ) + + # NOTE(kamo): Use '_' instead of '-' to avoid confusion. + # I think '-' looks really confusing if it's written in yaml. + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + parser.set_defaults(required=["output_dir"]) + + group = parser.add_argument_group("Common configuration") + + group.add_argument( + "--print_config", + action="store_true", + help="Print the config file and exit", + ) + group.add_argument( + "--log_level", + type=lambda x: x.upper(), + default="INFO", + choices=("ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"), + help="The verbose level of logging", + ) + group.add_argument( + "--dry_run", + type=str2bool, + default=False, + help="Perform process without training", + ) + group.add_argument( + "--iterator_type", + type=str, + choices=["sequence", "chunk", "task", "none"], + default="sequence", + help="Specify iterator type", + ) + + group.add_argument("--output_dir", type=str_or_none, default=None) + group.add_argument( + "--ngpu", + type=int, + default=0, + help="The number of gpus. 0 indicates CPU mode", + ) + group.add_argument("--seed", type=int, default=0, help="Random seed") + group.add_argument( + "--num_workers", + type=int, + default=1, + help="The number of workers used for DataLoader", + ) + group.add_argument( + "--num_att_plot", + type=int, + default=3, + help="The number images to plot the outputs from attention. 
" + "This option makes sense only when attention-based model", + ) + + group = parser.add_argument_group("distributed training related") + group.add_argument( + "--dist_backend", + default="nccl", + type=str, + help="distributed backend", + ) + group.add_argument( + "--dist_init_method", + type=str, + default="env://", + help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", ' + '"WORLD_SIZE", and "RANK" are referred.', + ) + group.add_argument( + "--dist_world_size", + default=None, + type=int_or_none, + help="number of nodes for distributed training", + ) + group.add_argument( + "--dist_rank", + type=int_or_none, + default=None, + help="node rank for distributed training", + ) + group.add_argument( + # Not starting with "dist_" for compatibility to launch.py + "--local_rank", + type=int_or_none, + default=None, + help="local rank for distributed training. This option is used if " + "--multiprocessing_distributed=false", + ) + group.add_argument( + "--dist_master_addr", + default=None, + type=str_or_none, + help="The master address for distributed training. " + "This value is used when dist_init_method == 'env://'", + ) + group.add_argument( + "--dist_master_port", + default=None, + type=int_or_none, + help="The master port for distributed training" + "This value is used when dist_init_method == 'env://'", + ) + group.add_argument( + "--dist_launcher", + default=None, + type=str_or_none, + choices=["slurm", "mpi", None], + help="The launcher type for distributed training", + ) + group.add_argument( + "--multiprocessing_distributed", + default=False, + type=str2bool, + help="Use multi-processing distributed training to launch " + "N processes per node, which has N GPUs. This is the " + "fastest way to use PyTorch for either single node or " + "multi node data parallel training", + ) + group.add_argument( + "--unused_parameters", + type=str2bool, + default=False, + help="Whether to use the find_unused_parameters in " + "torch.nn.parallel.DistributedDataParallel ", + ) + group.add_argument( + "--sharded_ddp", + default=False, + type=str2bool, + help="Enable sharded training provided by fairscale", + ) + + group = parser.add_argument_group("cudnn mode related") + group.add_argument( + "--cudnn_enabled", + type=str2bool, + default=torch.backends.cudnn.enabled, + help="Enable CUDNN", + ) + group.add_argument( + "--cudnn_benchmark", + type=str2bool, + default=torch.backends.cudnn.benchmark, + help="Enable cudnn-benchmark mode", + ) + group.add_argument( + "--cudnn_deterministic", + type=str2bool, + default=True, + help="Enable cudnn-deterministic mode", + ) + + group = parser.add_argument_group("collect stats mode related") + group.add_argument( + "--collect_stats", + type=str2bool, + default=False, + help='Perform on "collect stats" mode', + ) + group.add_argument( + "--write_collected_feats", + type=str2bool, + default=False, + help='Write the output features from the model when "collect stats" mode', + ) + + group = parser.add_argument_group("Trainer related") + group.add_argument( + "--max_epoch", + type=int, + default=40, + help="The maximum number epoch to train", + ) + group.add_argument( + "--patience", + type=int_or_none, + default=None, + help="Number of epochs to wait without improvement " + "before stopping the training", + ) + group.add_argument( + "--val_scheduler_criterion", + type=str, + nargs=2, + default=("valid", "loss"), + help="The criterion used for the value given to the lr scheduler. 
" + 'Give a pair referring the phase, "train" or "valid",' + 'and the criterion name. The mode specifying "min" or "max" can ' + "be changed by --scheduler_conf", + ) + group.add_argument( + "--early_stopping_criterion", + type=str, + nargs=3, + default=("valid", "loss", "min"), + help="The criterion used for judging of early stopping. " + 'Give a pair referring the phase, "train" or "valid",' + 'the criterion name and the mode, "min" or "max", e.g. "acc,max".', + ) + group.add_argument( + "--best_model_criterion", + type=str2triple_str, + nargs="+", + default=[ + ("train", "loss", "min"), + ("valid", "loss", "min"), + ("train", "acc", "max"), + ("valid", "acc", "max"), + ], + help="The criterion used for judging of the best model. " + 'Give a pair referring the phase, "train" or "valid",' + 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".', + ) + group.add_argument( + "--keep_nbest_models", + type=int, + nargs="+", + default=[10], + help="Remove previous snapshots excluding the n-best scored epochs", + ) + group.add_argument( + "--grad_clip", + type=float, + default=5.0, + help="Gradient norm threshold to clip", + ) + group.add_argument( + "--grad_clip_type", + type=float, + default=2.0, + help="The type of the used p-norm for gradient clip. Can be inf", + ) + group.add_argument( + "--grad_noise", + type=str2bool, + default=False, + help="The flag to switch to use noise injection to " + "gradients during training", + ) + group.add_argument( + "--accum_grad", + type=int, + default=1, + help="The number of gradient accumulation", + ) + group.add_argument( + "--no_forward_run", + type=str2bool, + default=False, + help="Just only iterating data loading without " + "model forwarding and training", + ) + group.add_argument( + "--resume", + type=str2bool, + default=False, + help="Enable resuming if checkpoint is existing", + ) + group.add_argument( + "--train_dtype", + default="float32", + choices=["float16", "float32", "float64"], + help="Data type for training.", + ) + group.add_argument( + "--use_amp", + type=str2bool, + default=False, + help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6", + ) + group.add_argument( + "--log_interval", + type=int_or_none, + default=None, + help="Show the logs every the number iterations in each epochs at the " + "training phase. If None is given, it is decided according the number " + "of training samples automatically .", + ) + group.add_argument( + "--use_tensorboard", + type=str2bool, + default=True, + help="Enable tensorboard logging", + ) + group.add_argument( + "--use_wandb", + type=str2bool, + default=False, + help="Enable wandb logging", + ) + group.add_argument( + "--wandb_project", + type=str, + default=None, + help="Specify wandb project", + ) + group.add_argument( + "--wandb_id", + type=str, + default=None, + help="Specify wandb id", + ) + group.add_argument( + "--detect_anomaly", + type=str2bool, + default=False, + help="Set torch.autograd.set_detect_anomaly", + ) + + group = parser.add_argument_group("Pretraining model related") + group.add_argument("--pretrain_path", help="This option is obsoleted") + group.add_argument( + "--init_param", + type=str, + default=[], + nargs="*", + help="Specify the file path used for initialization of parameters. 
" + "The format is ':::', " + "where file_path is the model file path, " + "src_key specifies the key of model states to be used in the model file, " + "dst_key specifies the attribute of the model to be initialized, " + "and exclude_keys excludes keys of model states for the initialization." + "e.g.\n" + " # Load all parameters" + " --init_param some/where/model.pth\n" + " # Load only decoder parameters" + " --init_param some/where/model.pth:decoder:decoder\n" + " # Load only decoder parameters excluding decoder.embed" + " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n" + " --init_param some/where/model.pth:decoder:decoder:decoder.embed\n", + ) + group.add_argument( + "--freeze_param", + type=str, + default=[], + nargs="*", + help="Freeze parameters", + ) + + group = parser.add_argument_group("BatchSampler related") + group.add_argument( + "--num_iters_per_epoch", + type=int_or_none, + default=None, + help="Restrict the number of iterations for training per epoch", + ) + group.add_argument( + "--batch_size", + type=int, + default=20, + help="The mini-batch size used for training. Used if batch_type='unsorted'," + " 'sorted', or 'folded'.", + ) + group.add_argument( + "--valid_batch_size", + type=int_or_none, + default=None, + help="If not given, the value of --batch_size is used", + ) + group.add_argument( + "--batch_bins", + type=int, + default=1000000, + help="The number of batch bins. Used if batch_type='length' or 'numel'", + ) + group.add_argument( + "--valid_batch_bins", + type=int_or_none, + default=None, + help="If not given, the value of --batch_bins is used", + ) + + group.add_argument("--train_shape_file", type=str, action="append", default=[]) + group.add_argument("--valid_shape_file", type=str, action="append", default=[]) + + group = parser.add_argument_group("Sequence iterator related") + _batch_type_help = "" + for key, value in BATCH_TYPES.items(): + _batch_type_help += f'"{key}":\n{value}\n' + group.add_argument( + "--batch_type", + type=str, + default="folded", + choices=list(BATCH_TYPES), + help=_batch_type_help, + ) + group.add_argument( + "--valid_batch_type", + type=str_or_none, + default=None, + choices=list(BATCH_TYPES) + [None], + help="If not given, the value of --batch_type is used", + ) + group.add_argument("--fold_length", type=int, action="append", default=[]) + group.add_argument( + "--sort_in_batch", + type=str, + default="descending", + choices=["descending", "ascending"], + help="Sort the samples in each mini-batches by the sample " + 'lengths. To enable this, "shape_file" must have the length information.', + ) + group.add_argument( + "--sort_batch", + type=str, + default="descending", + choices=["descending", "ascending"], + help="Sort mini-batches by the sample lengths", + ) + group.add_argument( + "--multiple_iterator", + type=str2bool, + default=False, + help="Use multiple iterator mode", + ) + + group = parser.add_argument_group("Chunk iterator related") + group.add_argument( + "--chunk_length", + type=str_or_int, + default=500, + help="Specify chunk length. e.g. '300', '300,400,500', or '300-400'." + "If multiple numbers separated by command are given, " + "one of them is selected randomly for each samples. " + "If two numbers are given with '-', it indicates the range of the choices. " + "Note that if the sequence length is shorter than the all chunk_lengths, " + "the sample is discarded. ", + ) + group.add_argument( + "--chunk_shift_ratio", + type=float, + default=0.5, + help="Specify the shift width of chunks. 
If it's less than 1, " + "the chunks overlap, and if it's bigger than 1, there are gaps " + "between each chunk.", + ) + group.add_argument( + "--num_cache_chunks", + type=int, + default=1024, + help="Shuffle in the specified number of chunks and generate mini-batches. " + "The larger this value, the more randomness can be obtained.", + ) + + group = parser.add_argument_group("Dataset related") + _data_path_and_name_and_type_help = ( + "Give three words separated by commas. It's used for the training data. " + "e.g. '--train_data_path_and_name_and_type some/path/a.scp,foo,sound'. " + "The first value, some/path/a.scp, indicates the file path, " + "and the second, foo, is the key name used for the mini-batch data, " + "and the last, sound, decides the file type. " + "This option is repeatable, so you can input any number of features " + "for your task. Supported file types are as follows:\n\n" + ) + for key, dic in DATA_TYPES.items(): + _data_path_and_name_and_type_help += f'"{key}":\n{dic["help"]}\n\n' + + group.add_argument( + "--train_data_path_and_name_and_type", + type=str2triple_str, + action="append", + default=[], + help=_data_path_and_name_and_type_help, + ) + group.add_argument( + "--valid_data_path_and_name_and_type", + type=str2triple_str, + action="append", + default=[], + ) + group.add_argument( + "--allow_variable_data_keys", + type=str2bool, + default=False, + help="Allow arbitrary keys for the mini-batch, ignoring " + "the task requirements", + ) + group.add_argument( + "--max_cache_size", + type=humanfriendly.parse_size, + default=0.0, + help="The maximum cache size for data loader. e.g. 10MB, 20GB.", + ) + group.add_argument( + "--max_cache_fd", + type=int, + default=32, + help="The maximum number of file descriptors to be kept " + "open for ark files. " + "This feature is only valid when data type is 'kaldi_ark'.", + ) + group.add_argument( + "--valid_max_cache_size", + type=humanfriendly_parse_size_or_none, + default=None, + help="The maximum cache size for validation data loader. e.g. 10MB, 20GB. " + "If None, 5 percent of --max_cache_size", + ) + + group = parser.add_argument_group("Optimizer related") + for i in range(1, cls.num_optimizers + 1): + suf = "" if i == 1 else str(i) + group.add_argument( + f"--optim{suf}", + type=lambda x: x.lower(), + default="adadelta", + choices=list(optim_classes), + help="The optimizer type", + ) + group.add_argument( + f"--optim{suf}_conf", + action=NestedDictAction, + default=dict(), + help="The keyword arguments for optimizer", + ) + group.add_argument( + f"--scheduler{suf}", + type=lambda x: str_or_none(x.lower()), + default=None, + choices=list(scheduler_classes) + [None], + help="The lr scheduler type", + ) + group.add_argument( + f"--scheduler{suf}_conf", + action=NestedDictAction, + default=dict(), + help="The keyword arguments for lr scheduler", + ) + + cls.trainer.add_arguments(parser) + cls.add_task_arguments(parser) + + assert check_return_type(parser) + return parser + + @classmethod + def build_optimizers( + cls, + args: argparse.Namespace, + model: torch.nn.Module, + ) -> List[torch.optim.Optimizer]: + if cls.num_optimizers != 1: + raise RuntimeError( + "build_optimizers() must be overridden if num_optimizers != 1" + ) + + optim_class = optim_classes.get(args.optim) + if optim_class is None: + raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}") + if args.sharded_ddp: + if fairscale is None: + raise RuntimeError("Requiring fairscale. 
Do 'pip install fairscale'") + optim = fairscale.optim.oss.OSS( + params=model.parameters(), optim=optim_class, **args.optim_conf + ) + else: + optim = optim_class(model.parameters(), **args.optim_conf) + + optimizers = [optim] + return optimizers + + @classmethod + def exclude_opts(cls) -> Tuple[str, ...]: + """The options not to be shown by --print_config""" + return "required", "print_config", "config", "ngpu" + + @classmethod + def get_default_config(cls) -> Dict[str, Any]: + """Return the configuration as dict. + + This method is used by print_config() + """ + + def get_class_type(name: str, classes: dict): + _cls = classes.get(name) + if _cls is None: + raise ValueError(f"must be one of {list(classes)}: {name}") + return _cls + + # This method is used only for --print_config + assert check_argument_types() + parser = cls.get_parser() + args, _ = parser.parse_known_args() + config = vars(args) + # Excludes the options not to be shown + for k in AbsTask.exclude_opts(): + config.pop(k) + + for i in range(1, cls.num_optimizers + 1): + suf = "" if i == 1 else str(i) + name = config[f"optim{suf}"] + optim_class = get_class_type(name, optim_classes) + conf = get_default_kwargs(optim_class) + # Overwrite the default by the arguments, + conf.update(config[f"optim{suf}_conf"]) + # and set it again + config[f"optim{suf}_conf"] = conf + + name = config[f"scheduler{suf}"] + if name is not None: + scheduler_class = get_class_type(name, scheduler_classes) + conf = get_default_kwargs(scheduler_class) + # Overwrite the default by the arguments, + conf.update(config[f"scheduler{suf}_conf"]) + # and set it again + config[f"scheduler{suf}_conf"] = conf + + for class_choices in cls.class_choices_list: + if getattr(args, class_choices.name) is not None: + class_obj = class_choices.get_class(getattr(args, class_choices.name)) + conf = get_default_kwargs(class_obj) + name = class_choices.name + # Overwrite the default by the arguments, + conf.update(config[f"{name}_conf"]) + # and set it again + config[f"{name}_conf"] = conf + return config + + @classmethod + def check_required_command_args(cls, args: argparse.Namespace): + assert check_argument_types() + for k in vars(args): + if "-" in k: + raise RuntimeError(f'Use "_" instead of "-": parser.get_parser("{k}")') + + required = ", ".join( + f"--{a}" for a in args.required if getattr(args, a) is None + ) + + if len(required) != 0: + parser = cls.get_parser() + parser.print_help(file=sys.stderr) + p = Path(sys.argv[0]).name + print(file=sys.stderr) + print( + f"{p}: error: the following arguments are required: " f"{required}", + file=sys.stderr, + ) + sys.exit(2) + + @classmethod + def check_task_requirements( + cls, + dataset: Union[AbsDataset, IterableESPnetDataset], + allow_variable_data_keys: bool, + train: bool, + inference: bool = False, + ) -> None: + """Check if the dataset satisfy the requirement of current Task""" + assert check_argument_types() + mes = ( + f"If you intend to use an additional input, modify " + f'"{cls.__name__}.required_data_names()" or ' + f'"{cls.__name__}.optional_data_names()". ' + f"Otherwise you need to set --allow_variable_data_keys true " + ) + + for k in cls.required_data_names(train, inference): + if not dataset.has_name(k): + raise RuntimeError( + f'"{cls.required_data_names(train, inference)}" are required for' + f' {cls.__name__}. 
but "{dataset.names()}" are input.\n{mes}' + ) + if not allow_variable_data_keys: + task_keys = cls.required_data_names( + train, inference + ) + cls.optional_data_names(train, inference) + for k in dataset.names(): + if k not in task_keys: + raise RuntimeError( + f"The data-name must be one of {task_keys} " + f'for {cls.__name__}: "{k}" is not allowed.\n{mes}' + ) + + @classmethod + def print_config(cls, file=sys.stdout) -> None: + assert check_argument_types() + # Shows the config: e.g. python train.py asr --print_config + config = cls.get_default_config() + file.write(yaml_no_alias_safe_dump(config, indent=4, sort_keys=False)) + + @classmethod + def main(cls, args: argparse.Namespace = None, cmd: Sequence[str] = None): + assert check_argument_types() + print(get_commandline_args(), file=sys.stderr) + if args is None: + parser = cls.get_parser() + args = parser.parse_args(cmd) + args.version = __version__ + if args.pretrain_path is not None: + raise RuntimeError("--pretrain_path is deprecated. Use --init_param") + if args.print_config: + cls.print_config() + sys.exit(0) + cls.check_required_command_args(args) + + # "distributed" is decided using the other command args + resolve_distributed_mode(args) + if not args.distributed or not args.multiprocessing_distributed: + cls.main_worker(args) + + else: + assert args.ngpu > 1, args.ngpu + # Multi-processing distributed mode: e.g. 2node-4process-4GPU + # | Host1 | Host2 | + # | Process1 | Process2 | <= Spawn processes + # |Child1|Child2|Child1|Child2| + # |GPU1 |GPU2 |GPU1 |GPU2 | + + # See also the following usage of --multiprocessing-distributed: + # https://github.com/pytorch/examples/blob/master/imagenet/main.py + num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher) + if num_nodes == 1: + args.dist_master_addr = "localhost" + args.dist_rank = 0 + # Single node distributed training with multi-GPUs + if ( + args.dist_init_method == "env://" + and get_master_port(args.dist_master_port) is None + ): + # Get the unused port + args.dist_master_port = free_port() + + # Assume that nodes use same number of GPUs each other + args.dist_world_size = args.ngpu * num_nodes + node_rank = get_node_rank(args.dist_rank, args.dist_launcher) + + # The following block is copied from: + # https://github.com/pytorch/pytorch/blob/master/torch/multiprocessing/spawn.py + error_queues = [] + processes = [] + mp = torch.multiprocessing.get_context("spawn") + for i in range(args.ngpu): + # Copy args + local_args = argparse.Namespace(**vars(args)) + + local_args.local_rank = i + local_args.dist_rank = args.ngpu * node_rank + i + local_args.ngpu = 1 + + process = mp.Process( + target=cls.main_worker, + args=(local_args,), + daemon=False, + ) + process.start() + processes.append(process) + error_queues.append(mp.SimpleQueue()) + # Loop on join until it returns True or raises an exception. + while not ProcessContext(processes, error_queues).join(): + pass + + @classmethod + def main_worker(cls, args: argparse.Namespace): + assert check_argument_types() + + # 0. Init distributed process + distributed_option = build_dataclass(DistributedOption, args) + # Setting distributed_option.dist_rank, etc. 
+ distributed_option.init_options() + + # NOTE(kamo): Don't use logging before invoking logging.basicConfig() + if not distributed_option.distributed or distributed_option.dist_rank == 0: + if not distributed_option.distributed: + _rank = "" + else: + _rank = ( + f":{distributed_option.dist_rank}/" + f"{distributed_option.dist_world_size}" + ) + + # NOTE(kamo): + # logging.basicConfig() is invoked in main_worker() instead of main() + # because it can be invoked only once in a process. + # FIXME(kamo): Should we use logging.getLogger()? + logging.basicConfig( + level=args.log_level, + format=f"[{os.uname()[1].split('.')[0]}{_rank}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + else: + # Suppress logging if RANK != 0 + logging.basicConfig( + level="ERROR", + format=f"[{os.uname()[1].split('.')[0]}" + f":{distributed_option.dist_rank}/{distributed_option.dist_world_size}]" + f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", + ) + # Invoking torch.distributed.init_process_group + distributed_option.init_torch_distributed() + + # 1. Set random-seed + set_all_random_seed(args.seed) + torch.backends.cudnn.enabled = args.cudnn_enabled + torch.backends.cudnn.benchmark = args.cudnn_benchmark + torch.backends.cudnn.deterministic = args.cudnn_deterministic + if args.detect_anomaly: + logging.info("Invoking torch.autograd.set_detect_anomaly(True)") + torch.autograd.set_detect_anomaly(args.detect_anomaly) + + # 2. Build model + model = cls.build_model(args=args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + model = model.to( + dtype=getattr(torch, args.train_dtype), + device="cuda" if args.ngpu > 0 else "cpu", + ) + for t in args.freeze_param: + for k, p in model.named_parameters(): + if k.startswith(t + ".") or k == t: + logging.info(f"Setting {k}.requires_grad = False") + p.requires_grad = False + + # 3. Build optimizer + optimizers = cls.build_optimizers(args, model=model) + + # 4. Build schedulers + schedulers = [] + for i, optim in enumerate(optimizers, 1): + suf = "" if i == 1 else str(i) + name = getattr(args, f"scheduler{suf}") + conf = getattr(args, f"scheduler{suf}_conf") + if name is not None: + cls_ = scheduler_classes.get(name) + if cls_ is None: + raise ValueError( + f"must be one of {list(scheduler_classes)}: {name}" + ) + scheduler = cls_(optim, **conf) + else: + scheduler = None + + schedulers.append(scheduler) + + logging.info(pytorch_cudnn_version()) + logging.info(model_summary(model)) + for i, (o, s) in enumerate(zip(optimizers, schedulers), 1): + suf = "" if i == 1 else str(i) + logging.info(f"Optimizer{suf}:\n{o}") + logging.info(f"Scheduler{suf}: {s}") + + # 5. Dump "args" to config.yaml + # NOTE(kamo): "args" should be saved after object-buildings are done + # because they are allowed to modify "args". + output_dir = Path(args.output_dir) + if not distributed_option.distributed or distributed_option.dist_rank == 0: + output_dir.mkdir(parents=True, exist_ok=True) + with (output_dir / "config.yaml").open("w", encoding="utf-8") as f: + logging.info( + f'Saving the configuration in {output_dir / "config.yaml"}' + ) + yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False) + + # 6. 
Loads pre-trained model + for p in args.init_param: + logging.info(f"Loading pretrained params from {p}") + load_pretrained_model( + model=model, + init_param=p, + # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 + # in PyTorch<=1.4 + map_location=f"cuda:{torch.cuda.current_device()}" + if args.ngpu > 0 + else "cpu", + ) + + if args.dry_run: + pass + elif args.collect_stats: + # Perform on collect_stats mode. This mode has two roles + # - Derive the length and dimension of all input data + # - Accumulate feats, square values, and the length for whitening + logging.info(args) + + if args.valid_batch_size is None: + args.valid_batch_size = args.batch_size + + if len(args.train_shape_file) != 0: + train_key_file = args.train_shape_file[0] + else: + train_key_file = None + if len(args.valid_shape_file) != 0: + valid_key_file = args.valid_shape_file[0] + else: + valid_key_file = None + + collect_stats( + model=model, + train_iter=cls.build_streaming_iterator( + data_path_and_name_and_type=args.train_data_path_and_name_and_type, + key_file=train_key_file, + batch_size=args.batch_size, + dtype=args.train_dtype, + num_workers=args.num_workers, + allow_variable_data_keys=args.allow_variable_data_keys, + ngpu=args.ngpu, + preprocess_fn=cls.build_preprocess_fn(args, train=False), + collate_fn=cls.build_collate_fn(args, train=False), + ), + valid_iter=cls.build_streaming_iterator( + data_path_and_name_and_type=args.valid_data_path_and_name_and_type, + key_file=valid_key_file, + batch_size=args.valid_batch_size, + dtype=args.train_dtype, + num_workers=args.num_workers, + allow_variable_data_keys=args.allow_variable_data_keys, + ngpu=args.ngpu, + preprocess_fn=cls.build_preprocess_fn(args, train=False), + collate_fn=cls.build_collate_fn(args, train=False), + ), + output_dir=output_dir, + ngpu=args.ngpu, + log_interval=args.log_interval, + write_collected_feats=args.write_collected_feats, + ) + else: + + # 7. Build iterator factories + if args.multiple_iterator: + train_iter_factory = cls.build_multiple_iter_factory( + args=args, + distributed_option=distributed_option, + mode="train", + ) + else: + train_iter_factory = cls.build_iter_factory( + args=args, + distributed_option=distributed_option, + mode="train", + ) + valid_iter_factory = cls.build_iter_factory( + args=args, + distributed_option=distributed_option, + mode="valid", + ) + if args.num_att_plot != 0: + plot_attention_iter_factory = cls.build_iter_factory( + args=args, + distributed_option=distributed_option, + mode="plot_att", + ) + else: + plot_attention_iter_factory = None + + # 8. Start training + if args.use_wandb: + if ( + not distributed_option.distributed + or distributed_option.dist_rank == 0 + ): + if args.wandb_project is None: + project = ( + "ESPnet_" + + cls.__name__ + + str(Path(".").resolve()).replace("/", "_") + ) + else: + project = args.wandb_project + if args.wandb_id is None: + wandb_id = str(output_dir).replace("/", "_") + else: + wandb_id = args.wandb_id + + wandb.init( + project=project, + dir=output_dir, + id=wandb_id, + resume="allow", + ) + wandb.config.update(args) + else: + # wandb also supports grouping for distributed training, + # but we only logs aggregated data, + # so it's enough to perform on rank0 node. + args.use_wandb = False + + # Don't give args to trainer.run() directly!!! + # Instead of it, define "Options" object and build here. 
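+ # NOTE(editor): illustrative sketch, not part of the original patch; names prefixed "My" are hypothetical. + # A task that needs extra trainer settings would typically subclass Trainer and extend its options + # dataclass, along the lines of: + # @dataclasses.dataclass + # class MyTrainerOptions(TrainerOptions): + #     my_flag: bool = False + # class MyTrainer(Trainer): + #     @classmethod + #     def build_options(cls, args): + #         return build_dataclass(MyTrainerOptions, args) + # class MyTask(AbsTask): + #     trainer = MyTrainer + # build_options() snapshots only the fields the trainer actually needs, which is why the raw + # argparse.Namespace is never passed to trainer.run() directly.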
+ trainer_options = cls.trainer.build_options(args) + cls.trainer.run( + model=model, + optimizers=optimizers, + schedulers=schedulers, + train_iter_factory=train_iter_factory, + valid_iter_factory=valid_iter_factory, + plot_attention_iter_factory=plot_attention_iter_factory, + trainer_options=trainer_options, + distributed_option=distributed_option, + ) + + @classmethod + def build_iter_options( + cls, + args: argparse.Namespace, + distributed_option: DistributedOption, + mode: str, + ): + if mode == "train": + preprocess_fn = cls.build_preprocess_fn(args, train=True) + collate_fn = cls.build_collate_fn(args, train=True) + data_path_and_name_and_type = args.train_data_path_and_name_and_type + shape_files = args.train_shape_file + batch_size = args.batch_size + batch_bins = args.batch_bins + batch_type = args.batch_type + max_cache_size = args.max_cache_size + max_cache_fd = args.max_cache_fd + distributed = distributed_option.distributed + num_batches = None + num_iters_per_epoch = args.num_iters_per_epoch + train = True + + elif mode == "valid": + preprocess_fn = cls.build_preprocess_fn(args, train=False) + collate_fn = cls.build_collate_fn(args, train=False) + data_path_and_name_and_type = args.valid_data_path_and_name_and_type + shape_files = args.valid_shape_file + + if args.valid_batch_type is None: + batch_type = args.batch_type + else: + batch_type = args.valid_batch_type + if args.valid_batch_size is None: + batch_size = args.batch_size + else: + batch_size = args.valid_batch_size + if args.valid_batch_bins is None: + batch_bins = args.batch_bins + else: + batch_bins = args.valid_batch_bins + if args.valid_max_cache_size is None: + # Cache 5% of maximum size for validation loader + max_cache_size = 0.05 * args.max_cache_size + else: + max_cache_size = args.valid_max_cache_size + max_cache_fd = args.max_cache_fd + distributed = distributed_option.distributed + num_batches = None + num_iters_per_epoch = None + train = False + + elif mode == "plot_att": + preprocess_fn = cls.build_preprocess_fn(args, train=False) + collate_fn = cls.build_collate_fn(args, train=False) + data_path_and_name_and_type = args.valid_data_path_and_name_and_type + shape_files = args.valid_shape_file + batch_type = "unsorted" + batch_size = 1 + batch_bins = 0 + num_batches = args.num_att_plot + max_cache_fd = args.max_cache_fd + # num_att_plot should be a few sample ~ 3, so cache all data. + max_cache_size = np.inf if args.max_cache_size != 0.0 else 0.0 + # always False because plot_attention performs on RANK0 + distributed = False + num_iters_per_epoch = None + train = False + else: + raise NotImplementedError(f"mode={mode}") + + return IteratorOptions( + preprocess_fn=preprocess_fn, + collate_fn=collate_fn, + data_path_and_name_and_type=data_path_and_name_and_type, + shape_files=shape_files, + batch_type=batch_type, + batch_size=batch_size, + batch_bins=batch_bins, + num_batches=num_batches, + max_cache_size=max_cache_size, + max_cache_fd=max_cache_fd, + distributed=distributed, + num_iters_per_epoch=num_iters_per_epoch, + train=train, + ) + + @classmethod + def build_iter_factory( + cls, + args: argparse.Namespace, + distributed_option: DistributedOption, + mode: str, + kwargs: dict = None, + ) -> AbsIterFactory: + """Build a factory object of mini-batch iterator. + + This object is invoked at every epochs to build the iterator for each epoch + as following: + + >>> iter_factory = cls.build_iter_factory(...) + >>> for epoch in range(1, max_epoch): + ... 
for keys, batch in iter_factory.build_iter(epoch): + ... model(**batch) + + The mini-batches for each epoch are fully controlled by this class. + Note that the random seed used for shuffling is decided as "seed + epoch" and + the generated mini-batches can be reproduced when resuming. + + Note that the definition of "epoch" doesn't always mean + a full pass over the whole training corpus. + The "--num_iters_per_epoch" option restricts the number of iterations for each epoch + and the rest of the samples of the original epoch are left for the next epoch. + e.g. If the number of mini-batches equals 4, the following two are the same: + + - 1 epoch without "--num_iters_per_epoch" + - 4 epochs with "--num_iters_per_epoch" == 1 + + """ + assert check_argument_types() + iter_options = cls.build_iter_options(args, distributed_option, mode) + + # Overwrite iter_options if any kwargs are given + if kwargs is not None: + for k, v in kwargs.items(): + setattr(iter_options, k, v) + + if args.iterator_type == "sequence": + return cls.build_sequence_iter_factory( + args=args, + iter_options=iter_options, + mode=mode, + ) + elif args.iterator_type == "chunk": + return cls.build_chunk_iter_factory( + args=args, + iter_options=iter_options, + mode=mode, + ) + elif args.iterator_type == "task": + return cls.build_task_iter_factory( + args=args, + iter_options=iter_options, + mode=mode, + ) + else: + raise RuntimeError(f"Not supported: iterator_type={args.iterator_type}") + + @classmethod + def build_sequence_iter_factory( + cls, args: argparse.Namespace, iter_options: IteratorOptions, mode: str + ) -> AbsIterFactory: + assert check_argument_types() + + dataset = ESPnetDataset( + iter_options.data_path_and_name_and_type, + float_dtype=args.train_dtype, + preprocess=iter_options.preprocess_fn, + max_cache_size=iter_options.max_cache_size, + max_cache_fd=iter_options.max_cache_fd, + ) + cls.check_task_requirements( + dataset, args.allow_variable_data_keys, train=iter_options.train + ) + + if Path( + Path(iter_options.data_path_and_name_and_type[0][0]).parent, "utt2category" + ).exists(): + utt2category_file = str( + Path( + Path(iter_options.data_path_and_name_and_type[0][0]).parent, + "utt2category", + ) + ) + else: + utt2category_file = None + batch_sampler = build_batch_sampler( + type=iter_options.batch_type, + shape_files=iter_options.shape_files, + fold_lengths=args.fold_length, + batch_size=iter_options.batch_size, + batch_bins=iter_options.batch_bins, + sort_in_batch=args.sort_in_batch, + sort_batch=args.sort_batch, + drop_last=False, + min_batch_size=torch.distributed.get_world_size() + if iter_options.distributed + else 1, + utt2category_file=utt2category_file, + ) + + batches = list(batch_sampler) + if iter_options.num_batches is not None: + batches = batches[: iter_options.num_batches] + + bs_list = [len(batch) for batch in batches] + + logging.info(f"[{mode}] dataset:\n{dataset}") + logging.info(f"[{mode}] Batch sampler: {batch_sampler}") + logging.info( + f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, " + f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}" + ) + + if iter_options.distributed: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + for batch in batches: + if len(batch) < world_size: + raise RuntimeError( + f"The batch-size must be equal to or greater than world_size: " + f"{len(batch)} < {world_size}" + ) + batches = [batch[rank::world_size] for batch in batches] + + return SequenceIterFactory( + dataset=dataset, + 
batches=batches, + seed=args.seed, + num_iters_per_epoch=iter_options.num_iters_per_epoch, + shuffle=iter_options.train, + num_workers=args.num_workers, + collate_fn=iter_options.collate_fn, + pin_memory=args.ngpu > 0, + ) + + @classmethod + def build_chunk_iter_factory( + cls, + args: argparse.Namespace, + iter_options: IteratorOptions, + mode: str, + ) -> AbsIterFactory: + assert check_argument_types() + + dataset = ESPnetDataset( + iter_options.data_path_and_name_and_type, + float_dtype=args.train_dtype, + preprocess=iter_options.preprocess_fn, + max_cache_size=iter_options.max_cache_size, + max_cache_fd=iter_options.max_cache_fd, + ) + cls.check_task_requirements( + dataset, args.allow_variable_data_keys, train=iter_options.train + ) + + if len(iter_options.shape_files) == 0: + key_file = iter_options.data_path_and_name_and_type[0][0] + else: + key_file = iter_options.shape_files[0] + + batch_sampler = UnsortedBatchSampler(batch_size=1, key_file=key_file) + batches = list(batch_sampler) + if iter_options.num_batches is not None: + batches = batches[: iter_options.num_batches] + logging.info(f"[{mode}] dataset:\n{dataset}") + + if iter_options.distributed: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + if len(batches) < world_size: + raise RuntimeError("Number of samples is smaller than world_size") + if iter_options.batch_size < world_size: + raise RuntimeError("batch_size must be equal or more than world_size") + + if rank < iter_options.batch_size % world_size: + batch_size = iter_options.batch_size // world_size + 1 + else: + batch_size = iter_options.batch_size // world_size + num_cache_chunks = args.num_cache_chunks // world_size + # NOTE(kamo): Split whole corpus by sample numbers without considering + # each of the lengths, therefore the number of iteration counts are not + # always equal to each other and the iterations are limitted + # by the fewest iterations. + # i.e. the samples over the counts are discarded. + batches = batches[rank::world_size] + else: + batch_size = iter_options.batch_size + num_cache_chunks = args.num_cache_chunks + + return ChunkIterFactory( + dataset=dataset, + batches=batches, + seed=args.seed, + batch_size=batch_size, + # For chunk iterator, + # --num_iters_per_epoch doesn't indicate the number of iterations, + # but indicates the number of samples. + num_samples_per_epoch=iter_options.num_iters_per_epoch, + shuffle=iter_options.train, + num_workers=args.num_workers, + collate_fn=iter_options.collate_fn, + pin_memory=args.ngpu > 0, + chunk_length=args.chunk_length, + chunk_shift_ratio=args.chunk_shift_ratio, + num_cache_chunks=num_cache_chunks, + ) + + # NOTE(kamo): Not abstract class + @classmethod + def build_task_iter_factory( + cls, + args: argparse.Namespace, + iter_options: IteratorOptions, + mode: str, + ) -> AbsIterFactory: + """Build task specific iterator factory + + Example: + + >>> class YourTask(AbsTask): + ... @classmethod + ... def add_task_arguments(cls, parser: argparse.ArgumentParser): + ... parser.set_defaults(iterator_type="task") + ... + ... @classmethod + ... def build_task_iter_factory( + ... cls, + ... args: argparse.Namespace, + ... iter_options: IteratorOptions, + ... mode: str, + ... ): + ... return FooIterFactory(...) + ... + ... @classmethod + ... def build_iter_options( + .... args: argparse.Namespace, + ... distributed_option: DistributedOption, + ... mode: str + ... ): + ... 
# if you need to customize options object + """ + raise NotImplementedError + + @classmethod + def build_multiple_iter_factory( + cls, args: argparse.Namespace, distributed_option: DistributedOption, mode: str + ): + assert check_argument_types() + iter_options = cls.build_iter_options(args, distributed_option, mode) + assert len(iter_options.data_path_and_name_and_type) > 0, len( + iter_options.data_path_and_name_and_type + ) + + # 1. Sanity check + num_splits = None + for path in [ + path for path, _, _ in iter_options.data_path_and_name_and_type + ] + list(iter_options.shape_files): + if not Path(path).is_dir(): + raise RuntimeError(f"{path} is not a directory") + p = Path(path) / "num_splits" + if not p.exists(): + raise FileNotFoundError(f"{p} is not found") + with p.open() as f: + _num_splits = int(f.read()) + if num_splits is not None and num_splits != _num_splits: + raise RuntimeError( + f"Number of splits are mismathed: " + f"{iter_options.data_path_and_name_and_type[0][0]} and {path}" + ) + num_splits = _num_splits + + for i in range(num_splits): + p = Path(path) / f"split.{i}" + if not p.exists(): + raise FileNotFoundError(f"{p} is not found") + + # 2. Create functions to build an iter factory for each splits + data_path_and_name_and_type_list = [ + [ + (str(Path(p) / f"split.{i}"), n, t) + for p, n, t in iter_options.data_path_and_name_and_type + ] + for i in range(num_splits) + ] + shape_files_list = [ + [str(Path(s) / f"split.{i}") for s in iter_options.shape_files] + for i in range(num_splits) + ] + num_iters_per_epoch_list = [ + (iter_options.num_iters_per_epoch + i) // num_splits + if iter_options.num_iters_per_epoch is not None + else None + for i in range(num_splits) + ] + max_cache_size = iter_options.max_cache_size / num_splits + + # Note that iter-factories are built for each epoch at runtime lazily. + build_funcs = [ + functools.partial( + cls.build_iter_factory, + args, + distributed_option, + mode, + kwargs=dict( + data_path_and_name_and_type=_data_path_and_name_and_type, + shape_files=_shape_files, + num_iters_per_epoch=_num_iters_per_epoch, + max_cache_size=max_cache_size, + ), + ) + for ( + _data_path_and_name_and_type, + _shape_files, + _num_iters_per_epoch, + ) in zip( + data_path_and_name_and_type_list, + shape_files_list, + num_iters_per_epoch_list, + ) + ] + + # 3. 
Build MultipleIterFactory + return MultipleIterFactory( + build_funcs=build_funcs, shuffle=iter_options.train, seed=args.seed + ) + + @classmethod + def build_streaming_iterator( + cls, + data_path_and_name_and_type, + preprocess_fn, + collate_fn, + key_file: str = None, + batch_size: int = 1, + dtype: str = np.float32, + num_workers: int = 1, + allow_variable_data_keys: bool = False, + ngpu: int = 0, + inference: bool = False, + ) -> DataLoader: + """Build DataLoader using iterable dataset""" + assert check_argument_types() + # For backward compatibility for pytorch DataLoader + if collate_fn is not None: + kwargs = dict(collate_fn=collate_fn) + else: + kwargs = {} + + # IterableDataset is supported from pytorch=1.2 + if LooseVersion(torch.__version__) >= LooseVersion("1.2"): + dataset = IterableESPnetDataset( + data_path_and_name_and_type, + float_dtype=dtype, + preprocess=preprocess_fn, + key_file=key_file, + ) + if dataset.apply_utt2category: + kwargs.update(batch_size=1) + else: + kwargs.update(batch_size=batch_size) + else: + dataset = ESPnetDataset( + data_path_and_name_and_type, + float_dtype=dtype, + preprocess=preprocess_fn, + ) + if key_file is None: + key_file = data_path_and_name_and_type[0][0] + batch_sampler = UnsortedBatchSampler( + batch_size=batch_size, + key_file=key_file, + drop_last=False, + ) + kwargs.update(batch_sampler=batch_sampler) + + cls.check_task_requirements( + dataset, allow_variable_data_keys, train=False, inference=inference + ) + + return DataLoader( + dataset=dataset, + pin_memory=ngpu > 0, + num_workers=num_workers, + **kwargs, + ) + + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ + @classmethod + def build_model_from_file( + cls, + config_file: Union[Path, str], + model_file: Union[Path, str] = None, + device: str = "cpu", + ) -> Tuple[AbsESPnetModel, argparse.Namespace]: + """This method is used for inference or fine-tuning. + + Args: + config_file: The yaml file saved when training. + model_file: The model file saved when training. 
+ device: + + """ + assert check_argument_types() + config_file = Path(config_file) + + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + args = argparse.Namespace(**args) + model = cls.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + model.to(device) + if model_file is not None: + if device == "cuda": + # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 + # in PyTorch<=1.4 + device = f"cuda:{torch.cuda.current_device()}" + model.load_state_dict(torch.load(model_file, map_location=device), strict=False) # TC Marker + + return model, args diff --git a/espnet2/tasks/asr.py b/espnet2/tasks/asr.py new file mode 100644 index 0000000000000000000000000000000000000000..46c4d521d233a08226d7cdf8f706dc101b388777 --- /dev/null +++ b/espnet2/tasks/asr.py @@ -0,0 +1,435 @@ +import argparse +import logging +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import ( + DynamicConvolution2DTransformerDecoder, # noqa: H301 +) +from espnet2.asr.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder +from espnet2.asr.decoder.transformer_decoder import ( + LightweightConvolution2DTransformerDecoder, # noqa: H301 +) +from espnet2.asr.decoder.transformer_decoder import ( + LightweightConvolutionTransformerDecoder, # noqa: H301 +) +from espnet2.asr.decoder.transformer_decoder import TransformerDecoder +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.encoder.conformer_encoder import ConformerEncoder +from espnet2.asr.encoder.rnn_encoder import RNNEncoder +from espnet2.asr.encoder.transformer_encoder import TransformerEncoder +from espnet2.asr.encoder.contextual_block_transformer_encoder import ( + ContextualBlockTransformerEncoder, # noqa: H301 +) +from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder +from espnet2.asr.encoder.wav2vec2_encoder import FairSeqWav2Vec2Encoder +from espnet2.asr.espnet_model import ESPnetASRModel +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.frontend.windowing import SlidingWindow +from espnet2.asr.preencoder.abs_preencoder import AbsPreEncoder +from espnet2.asr.preencoder.sinc import LightweightSincConvs +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.asr.specaug.specaug import SpecAug +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.global_mvn import GlobalMVN +from espnet2.layers.utterance_mvn import UtteranceMVN +from espnet2.tasks.abs_task import AbsTask +from espnet2.torch_utils.initialize import initialize +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import float_or_none +from espnet2.utils.types import 
int_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + +frontend_choices = ClassChoices( + name="frontend", + classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow), + type_check=AbsFrontend, + default="default", +) +specaug_choices = ClassChoices( + name="specaug", + classes=dict(specaug=SpecAug), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default="utterance_mvn", + optional=True, +) +preencoder_choices = ClassChoices( + name="preencoder", + classes=dict( + sinc=LightweightSincConvs, + ), + type_check=AbsPreEncoder, + default=None, + optional=True, +) +encoder_choices = ClassChoices( + "encoder", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + contextual_block_transformer=ContextualBlockTransformerEncoder, + vgg_rnn=VGGRNNEncoder, + rnn=RNNEncoder, + wav2vec2=FairSeqWav2Vec2Encoder, + ), + type_check=AbsEncoder, + default="rnn", +) +decoder_choices = ClassChoices( + "decoder", + classes=dict( + transformer=TransformerDecoder, + lightweight_conv=LightweightConvolutionTransformerDecoder, + lightweight_conv2d=LightweightConvolution2DTransformerDecoder, + dynamic_conv=DynamicConvolutionTransformerDecoder, + dynamic_conv2d=DynamicConvolution2DTransformerDecoder, + rnn=RNNDecoder, + ), + type_check=AbsDecoder, + default="rnn", +) + + +class ASRTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --preencoder and --preencoder_conf + preencoder_choices, + # --encoder and --encoder_conf + encoder_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. 
Instead of it, do as + required = parser.get_default("required") + required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group.add_argument( + "--ctc_conf", + action=NestedDictAction, + default=get_default_kwargs(CTC), + help="The keyword arguments for CTC class.", + ) + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetASRModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized " "in the specified level token", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file of sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=[None, "g2p_en", "pyopenjtalk", "pyopenjtalk_kana"], + default=None, + help="Specify g2p method if --token_type=phn", + ) + parser.add_argument( + "--speech_volume_normalize", + type=float_or_none, + default=None, + help="Scale the maximum amplitude to the given value.", + ) + parser.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + parser.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + parser.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + parser.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + parser.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of noise decibel level.", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None, + rir_apply_prob=args.rir_apply_prob + if hasattr(args, "rir_apply_prob") + else 1.0, + noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None, + noise_apply_prob=args.noise_apply_prob + if hasattr(args, "noise_apply_prob") + else 1.0, + noise_db_range=args.noise_db_range + if hasattr(args, "noise_db_range") + else "13_15", + speech_volume_normalize=args.speech_volume_normalize + if hasattr(args, "rir_scp") + else None, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "text") + else: + # Recognition mode + retval = ("speech",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetASRModel: + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. 
Pre-encoder input block + # NOTE(kan-bayashi): Use getattr to keep the compatibility + if getattr(args, "preencoder", None) is not None: + preencoder_class = preencoder_choices.get_class(args.preencoder) + preencoder = preencoder_class(**args.preencoder_conf) + input_size = preencoder.output_size() + else: + preencoder = None + + # 4. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 5. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder.output_size(), + **args.decoder_conf, + ) + + # 6. CTC + ctc = CTC( + odim=vocab_size, encoder_output_sizse=encoder.output_size(), **args.ctc_conf + ) + + # 7. RNN-T Decoder (Not implemented) + rnnt_decoder = None + + # 8. Build model + model = ESPnetASRModel( + vocab_size=vocab_size, + frontend=frontend, + specaug=specaug, + normalize=normalize, + preencoder=preencoder, + encoder=encoder, + decoder=decoder, + ctc=ctc, + rnnt_decoder=rnnt_decoder, + token_list=token_list, + **args.model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 9. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/espnet2/tasks/diar.py b/espnet2/tasks/diar.py new file mode 100644 index 0000000000000000000000000000000000000000..fcea8a56ebc64599c7fcd113ff9647a1c8f9dfe0 --- /dev/null +++ b/espnet2/tasks/diar.py @@ -0,0 +1,256 @@ +import argparse +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.encoder.conformer_encoder import ConformerEncoder +from espnet2.asr.encoder.rnn_encoder import RNNEncoder +from espnet2.asr.encoder.transformer_encoder import TransformerEncoder +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.frontend.windowing import SlidingWindow +from espnet2.diar.decoder.abs_decoder import AbsDecoder +from espnet2.diar.decoder.linear_decoder import LinearDecoder +from espnet2.diar.espnet_model import ESPnetDiarizationModel +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.global_mvn import GlobalMVN +from espnet2.layers.label_aggregation import LabelAggregate +from espnet2.layers.utterance_mvn import UtteranceMVN +from espnet2.tasks.abs_task import AbsTask +from espnet2.torch_utils.initialize import initialize +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import int_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + +frontend_choices = ClassChoices( + name="frontend", + classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow), + type_check=AbsFrontend, + default="default", +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + 
type_check=AbsNormalize, + default="utterance_mvn", + optional=True, +) +label_aggregator_choices = ClassChoices( + "label_aggregator", + classes=dict(label_aggregator=LabelAggregate), + default="label_aggregator", +) +encoder_choices = ClassChoices( + "encoder", + classes=dict( + conformer=ConformerEncoder, + transformer=TransformerEncoder, + rnn=RNNEncoder, + ), + type_check=AbsEncoder, + default="rnn", +) +decoder_choices = ClassChoices( + "decoder", + classes=dict(linear=LinearDecoder), + type_check=AbsDecoder, + default="linear", +) + + +class DiarizationTask(AbsTask): + # If you need more than one optimizer, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --frontend and --frontend_conf + frontend_choices, + # --normalize and --normalize_conf + normalize_choices, + # --encoder and --encoder_conf + encoder_choices, + # --decoder and --decoder_conf + decoder_choices, + # --label_aggregator and --label_aggregator_conf + label_aggregator_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + group.add_argument( + "--num_spk", + type=int_or_none, + default=None, + help="The number fo speakers (for each recording) used in system training", + ) + + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetDiarizationModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
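# Sketch of what each ClassChoices entry in class_choices_list provides: a
# --{name} option restricted to the registered classes, a --{name}_conf option
# for its keyword arguments, and a get_class() lookup used later in
# build_model(). The argv values and input_size=80 below are illustrative.
import argparse

from espnet2.tasks.diar import encoder_choices

parser = argparse.ArgumentParser()
group = parser.add_argument_group("encoder")
encoder_choices.add_arguments(group)  # adds --encoder and --encoder_conf

args = parser.parse_args(["--encoder", "transformer"])
encoder_class = encoder_choices.get_class(args.encoder)  # -> TransformerEncoder
encoder = encoder_class(input_size=80, **args.encoder_conf)  # --encoder_conf defaults to {}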
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + # FIXME (jiatong): add more arugment here + retval = CommonPreprocessor(train=train) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech", "spk_labels") + else: + # Recognition mode + retval = ("speech",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + # (Note: jiatong): no optional data names for now + retval = () + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetDiarizationModel: + assert check_argument_types() + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 3. Label Aggregator layer + label_aggregator_class = label_aggregator_choices.get_class( + args.label_aggregator + ) + label_aggregator = label_aggregator_class(**args.label_aggregator_conf) + + # 3. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + # Note(jiatong): Diarization may not use subsampling when processing + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 4. Decoder + decoder_class = decoder_choices.get_class(args.decoder) + decoder = decoder_class( + num_spk=args.num_spk, + encoder_output_size=encoder.output_size(), + **args.decoder_conf, + ) + + # 5. Build model + model = ESPnetDiarizationModel( + frontend=frontend, + normalize=normalize, + label_aggregator=label_aggregator, + encoder=encoder, + decoder=decoder, + **args.model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 6. 
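# The data-name contract above in action: speaker labels are only required for
# training/validation, inference needs the waveform alone, and this task
# declares no optional inputs.
from espnet2.tasks.diar import DiarizationTask

assert DiarizationTask.required_data_names(inference=False) == ("speech", "spk_labels")
assert DiarizationTask.required_data_names(inference=True) == ("speech",)
assert DiarizationTask.optional_data_names() == ()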
Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/espnet2/tasks/enh.py b/espnet2/tasks/enh.py new file mode 100644 index 0000000000000000000000000000000000000000..eba489c9c14ce9d4cd551c2c2c3db4d452259839 --- /dev/null +++ b/espnet2/tasks/enh.py @@ -0,0 +1,195 @@ +import argparse +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.enh.decoder.abs_decoder import AbsDecoder +from espnet2.enh.decoder.conv_decoder import ConvDecoder +from espnet2.enh.decoder.null_decoder import NullDecoder +from espnet2.enh.decoder.stft_decoder import STFTDecoder +from espnet2.enh.encoder.abs_encoder import AbsEncoder +from espnet2.enh.encoder.conv_encoder import ConvEncoder +from espnet2.enh.encoder.null_encoder import NullEncoder +from espnet2.enh.encoder.stft_encoder import STFTEncoder +from espnet2.enh.espnet_model import ESPnetEnhancementModel +from espnet2.enh.separator.abs_separator import AbsSeparator +from espnet2.enh.separator.asteroid_models import AsteroidModel_Converter +from espnet2.enh.separator.conformer_separator import ConformerSeparator +from espnet2.enh.separator.dprnn_separator import DPRNNSeparator +from espnet2.enh.separator.neural_beamformer import NeuralBeamformer +from espnet2.enh.separator.rnn_separator import RNNSeparator +from espnet2.enh.separator.tcn_separator import TCNSeparator +from espnet2.enh.separator.transformer_separator import TransformerSeparator +from espnet2.tasks.abs_task import AbsTask +from espnet2.torch_utils.initialize import initialize +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + +encoder_choices = ClassChoices( + name="encoder", + classes=dict(stft=STFTEncoder, conv=ConvEncoder, same=NullEncoder), + type_check=AbsEncoder, + default="stft", +) + +separator_choices = ClassChoices( + name="separator", + classes=dict( + rnn=RNNSeparator, + tcn=TCNSeparator, + dprnn=DPRNNSeparator, + transformer=TransformerSeparator, + conformer=ConformerSeparator, + wpe_beamformer=NeuralBeamformer, + asteroid=AsteroidModel_Converter, + ), + type_check=AbsSeparator, + default="rnn", +) + +decoder_choices = ClassChoices( + name="decoder", + classes=dict(stft=STFTDecoder, conv=ConvDecoder, same=NullDecoder), + type_check=AbsDecoder, + default="stft", +) + +MAX_REFERENCE_NUM = 100 + + +class EnhancementTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + class_choices_list = [ + # --encoder and --encoder_conf + encoder_choices, + # --separator and --separator_conf + separator_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. 
Instead of it, do as + # required = parser.get_default("required") + + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetEnhancementModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=False, + help="Apply preprocessing to data or not", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + + return CommonCollateFn(float_pad_value=0.0, int_pad_value=0) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech_mix", "speech_ref1") + else: + # Recognition mode + retval = ("speech_mix",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] + retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)] + retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] + retval = tuple(retval) + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel: + assert check_argument_types() + + encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf) + separator = separator_choices.get_class(args.separator)( + encoder.output_dim, **args.separator_conf + ) + decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf) + + # 1. Build model + model = ESPnetEnhancementModel( + encoder=encoder, separator=separator, decoder=decoder, **args.model_conf + ) + + # FIXME(kamo): Should be done in model? + # 2. 
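# The comprehensions in optional_data_names() above expand into up to
# MAX_REFERENCE_NUM numbered keys per kind; a dataset only needs to provide
# the ones that actually exist.
from espnet2.tasks.enh import EnhancementTask

names = EnhancementTask.optional_data_names()
print(names[:3])                                      # ('dereverb_ref1', 'dereverb_ref2', 'dereverb_ref3')
print("speech_ref2" in names, "noise_ref1" in names)  # True True
print(len(names))                                     # 100 + 99 + 100 = 299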
Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/espnet2/tasks/enh_asr.py b/espnet2/tasks/enh_asr.py new file mode 100644 index 0000000000000000000000000000000000000000..169b2066323a5368f57e9b15a94f8f6d0005cd4a --- /dev/null +++ b/espnet2/tasks/enh_asr.py @@ -0,0 +1,368 @@ +import argparse +import logging +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.asr.ctc import CTC +from espnet2.asr.decoder.abs_decoder import AbsDecoder +from espnet2.asr.decoder.rnn_decoder import RNNDecoder +from espnet2.asr.decoder.transformer_decoder import TransformerDecoder +from espnet2.asr.encoder.abs_encoder import AbsEncoder +from espnet2.asr.encoder.rnn_encoder import RNNEncoder +from espnet2.asr.encoder.transformer_encoder import TransformerEncoder +from espnet2.asr.encoder.vgg_rnn_encoder import VGGRNNEncoder +from espnet2.asr.espnet_joint_model import ESPnetEnhASRModel +from espnet2.asr.espnet_model import ESPnetASRModel +from espnet2.asr.frontend.abs_frontend import AbsFrontend +from espnet2.asr.frontend.default import DefaultFrontend +from espnet2.asr.specaug.abs_specaug import AbsSpecAug +from espnet2.asr.specaug.specaug import SpecAug +from espnet2.enh.abs_enh import AbsEnhancement +from espnet2.enh.espnet_model import ESPnetEnhancementModel +from espnet2.enh.nets.beamformer_net import BeamformerNet +from espnet2.enh.nets.tasnet import TasNet +from espnet2.enh.nets.tf_mask_net import TFMaskingNet +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.global_mvn import GlobalMVN +from espnet2.layers.utterance_mvn import UtteranceMVN +from espnet2.tasks.abs_task import AbsTask +from espnet2.torch_utils.initialize import initialize +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor_multi +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import int_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + +enh_choices = ClassChoices( + name="enh", + classes=dict(tf_masking=TFMaskingNet, tasnet=TasNet, wpe_beamformer=BeamformerNet), + type_check=AbsEnhancement, + default="tf_masking", +) +frontend_choices = ClassChoices( + name="frontend", + classes=dict(default=DefaultFrontend), + type_check=AbsFrontend, + default="default", +) +specaug_choices = ClassChoices( + name="specaug", + classes=dict(specaug=SpecAug), + type_check=AbsSpecAug, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict( + global_mvn=GlobalMVN, + utterance_mvn=UtteranceMVN, + ), + type_check=AbsNormalize, + default="utterance_mvn", + optional=True, +) +encoder_choices = ClassChoices( + "encoder", + classes=dict( + transformer=TransformerEncoder, + vgg_rnn=VGGRNNEncoder, + rnn=RNNEncoder, + ), + type_check=AbsEncoder, + default="rnn", +) +decoder_choices = ClassChoices( + "decoder", + classes=dict(transformer=TransformerDecoder, rnn=RNNDecoder), + type_check=AbsDecoder, + default="rnn", +) + +MAX_REFERENCE_NUM = 100 + + +class 
ASRTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [ + # --enh and --enh_conf + enh_choices, + # --frontend and --frontend_conf + frontend_choices, + # --specaug and --specaug_conf + specaug_choices, + # --normalize and --normalize_conf + normalize_choices, + # --encoder and --encoder_conf + encoder_choices, + # --decoder and --decoder_conf + decoder_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + required = parser.get_default("required") + required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--input_size", + type=int_or_none, + default=None, + help="The number of input dimension of the feature", + ) + + group.add_argument( + "--ctc_conf", + action=NestedDictAction, + default=get_default_kwargs(CTC), + help="The keyword arguments for CTC class.", + ) + group.add_argument( + "--asr_model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetASRModel), + help="The keyword arguments for model class.", + ) + + group.add_argument( + "--enh_model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetEnhancementModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=False, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized " "in the specified level token", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file of sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=[None, "g2p_en", "pyopenjtalk", "pyopenjtalk_kana"], + default=None, + help="Specify g2p method if --token_type=phn", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol + return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + # TODO(Jing): ask Kamo if it ok to support several args, + # like text_name = 'text_ref1' and 'text_ref2' + if args.use_preprocessor: + retval = CommonPreprocessor_multi( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_name=["text_ref1", "text_ref2"], + text_cleaner=args.cleaner, + g2p_type=args.g2p, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech_mix", "speech_ref1", "text_ref1") + else: + # Recognition mode + retval = ("speech_mix",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = ["dereverb_ref"] + retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)] + retval += ["text_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)] + retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] + retval = tuple(retval) + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetEnhASRModel: + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # Overwriting token_list to keep it as "portable". + args.token_list = list(token_list) + elif isinstance(args.token_list, (tuple, list)): + token_list = list(args.token_list) + else: + raise RuntimeError("token_list must be str or list") + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 0. Build pre enhancement model + enh_model = enh_choices.get_class(args.enh)(**args.enh_conf) + + # 1. frontend + if args.input_size is None: + # Extract features in the model + frontend_class = frontend_choices.get_class(args.frontend) + frontend = frontend_class(**args.frontend_conf) + input_size = frontend.output_size() + else: + # Give features from data-loader + args.frontend = None + args.frontend_conf = {} + frontend = None + input_size = args.input_size + + # 2. Data augmentation for spectrogram + if args.specaug is not None: + specaug_class = specaug_choices.get_class(args.specaug) + specaug = specaug_class(**args.specaug_conf) + else: + specaug = None + + # 3. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 4. Encoder + encoder_class = encoder_choices.get_class(args.encoder) + encoder = encoder_class(input_size=input_size, **args.encoder_conf) + + # 5. 
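# Rough shape of one two-speaker training example for this joint task,
# following the required/optional data names above; arrays and transcripts are
# dummies. CommonPreprocessor_multi tokenizes both text_ref1 and text_ref2,
# while only speech_mix, speech_ref1 and text_ref1 are strictly required.
import numpy as np

example = {
    "speech_mix":  np.zeros(16000, dtype=np.float32),  # observed mixture
    "speech_ref1": np.zeros(16000, dtype=np.float32),  # clean source 1
    "speech_ref2": np.zeros(16000, dtype=np.float32),  # clean source 2 (optional)
    "text_ref1":   "transcript of speaker one",
    "text_ref2":   "transcript of speaker two",
}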
Decoder + decoder_class = decoder_choices.get_class(args.decoder) + + decoder = decoder_class( + vocab_size=vocab_size, + encoder_output_size=encoder.output_size(), + **args.decoder_conf, + ) + + # 6. CTC + ctc = CTC( + odim=vocab_size, encoder_output_sizse=encoder.output_size(), **args.ctc_conf + ) + + # 7. RNN-T Decoder (Not implemented) + rnnt_decoder = None + + # 8. Build model + model = ESPnetEnhASRModel( + vocab_size=vocab_size, + enh=enh_model, + frontend=frontend, + specaug=specaug, + normalize=normalize, + encoder=encoder, + decoder=decoder, + ctc=ctc, + rnnt_decoder=rnnt_decoder, + token_list=token_list, + **args.asr_model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 9. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/espnet2/tasks/lm.py b/espnet2/tasks/lm.py new file mode 100644 index 0000000000000000000000000000000000000000..282778244a4c91f2d23c2d2ce9f3c3d25d6570bb --- /dev/null +++ b/espnet2/tasks/lm.py @@ -0,0 +1,214 @@ +import argparse +import logging +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.lm.abs_model import AbsLM +from espnet2.lm.espnet_model import ESPnetLanguageModel +from espnet2.lm.seq_rnn_lm import SequentialRNNLM +from espnet2.lm.transformer_lm import TransformerLM +from espnet2.tasks.abs_task import AbsTask +from espnet2.torch_utils.initialize import initialize +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor +from espnet2.train.trainer import Trainer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + + +lm_choices = ClassChoices( + "lm", + classes=dict( + seq_rnn=SequentialRNNLM, + transformer=TransformerLM, + ), + type_check=AbsLM, + default="seq_rnn", +) + + +class LMTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + # Add variable objects configurations + class_choices_list = [lm_choices] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + # NOTE(kamo): Use '_' instead of '-' to avoid confusion + assert check_argument_types() + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. 
Instead of it, do as + required = parser.get_default("required") + required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetLanguageModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="bpe", + choices=["bpe", "char", "word"], + help="", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file fo sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=[None, "g2p_en", "pyopenjtalk", "pyopenjtalk_kana"], + default=None, + help="Specify g2p method if --token_type=phn", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + assert check_return_type(parser) + return parser + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + return CommonCollateFn(int_pad_value=0) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + non_linguistic_symbols=args.non_linguistic_symbols, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = ("text",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = () + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel: + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # "args" is saved as it is in a yaml file by BaseTask.main(). + # Overwriting token_list to keep it as "portable". + args.token_list = token_list.copy() + elif isinstance(args.token_list, (tuple, list)): + token_list = args.token_list.copy() + else: + raise RuntimeError("token_list must be str or dict") + + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. 
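# The token_list argument above is either an in-memory list or a path to a
# plain-text file with one token per line; vocab_size is simply the line
# count. A typical file (symbol names follow the usual ESPnet conventions and
# are an assumption here) looks like:
#
#     <blank>
#     <unk>
#     a
#     b
#     ...
#     <sos/eos>
#
# "token_list.txt" below is a hypothetical path:
with open("token_list.txt", encoding="utf-8") as f:
    token_list = [line.rstrip() for line in f]
vocab_size = len(token_list)
# The language model built below treats the last entry as the shared
# sos/eos symbol, i.e. its id is vocab_size - 1.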
Build LM model + lm_class = lm_choices.get_class(args.lm) + lm = lm_class(vocab_size=vocab_size, **args.lm_conf) + + # 2. Build ESPnetModel + # Assume the last-id is sos_and_eos + model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf) + + # FIXME(kamo): Should be done in model? + # 3. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/espnet2/tasks/tts.py b/espnet2/tasks/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..127039dbed387f8a679ba00baf9faf67926211cf --- /dev/null +++ b/espnet2/tasks/tts.py @@ -0,0 +1,361 @@ +import argparse +import logging +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.global_mvn import GlobalMVN +from espnet2.tasks.abs_task import AbsTask +from espnet2.train.class_choices import ClassChoices +from espnet2.train.collate_fn import CommonCollateFn +from espnet2.train.preprocessor import CommonPreprocessor +from espnet2.train.trainer import Trainer +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.espnet_model import ESPnetTTSModel +from espnet2.tts.fastspeech import FastSpeech +from espnet2.tts.fastspeech2 import FastSpeech2 +from espnet2.tts.fastespeech import FastESpeech +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract +from espnet2.tts.feats_extract.dio import Dio +from espnet2.tts.feats_extract.energy import Energy +from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank +from espnet2.tts.feats_extract.log_spectrogram import LogSpectrogram +from espnet2.tts.tacotron2 import Tacotron2 +from espnet2.tts.transformer import Transformer +from espnet2.utils.get_default_kwargs import get_default_kwargs +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import int_or_none +from espnet2.utils.types import str2bool +from espnet2.utils.types import str_or_none + +feats_extractor_choices = ClassChoices( + "feats_extract", + classes=dict(fbank=LogMelFbank, spectrogram=LogSpectrogram), + type_check=AbsFeatsExtract, + default="fbank", +) +pitch_extractor_choices = ClassChoices( + "pitch_extract", + classes=dict(dio=Dio), + type_check=AbsFeatsExtract, + default=None, + optional=True, +) +energy_extractor_choices = ClassChoices( + "energy_extract", + classes=dict(energy=Energy), + type_check=AbsFeatsExtract, + default=None, + optional=True, +) +normalize_choices = ClassChoices( + "normalize", + classes=dict(global_mvn=GlobalMVN), + type_check=AbsNormalize, + default="global_mvn", + optional=True, +) +pitch_normalize_choices = ClassChoices( + "pitch_normalize", + classes=dict(global_mvn=GlobalMVN), + type_check=AbsNormalize, + default=None, + optional=True, +) +energy_normalize_choices = ClassChoices( + "energy_normalize", + classes=dict(global_mvn=GlobalMVN), + type_check=AbsNormalize, + default=None, + optional=True, +) +tts_choices = ClassChoices( + "tts", + classes=dict( + tacotron2=Tacotron2, + transformer=Transformer, + fastspeech=FastSpeech, + fastspeech2=FastSpeech2, + fastespeech=FastESpeech, + ), + type_check=AbsTTS, + default="tacotron2", +) + + +class TTSTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int 
= 1 + + # Add variable objects configurations + class_choices_list = [ + # --feats_extractor and --feats_extractor_conf + feats_extractor_choices, + # --normalize and --normalize_conf + normalize_choices, + # --tts and --tts_conf + tts_choices, + # --pitch_extract and --pitch_extract_conf + pitch_extractor_choices, + # --pitch_normalize and --pitch_normalize_conf + pitch_normalize_choices, + # --energy_extract and --energy_extract_conf + energy_extractor_choices, + # --energy_normalize and --energy_normalize_conf + energy_normalize_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + # NOTE(kamo): Use '_' instead of '-' to avoid confusion + assert check_argument_types() + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. Instead of it, do as + required = parser.get_default("required") + required += ["token_list"] + + group.add_argument( + "--token_list", + type=str_or_none, + default=None, + help="A text mapping int-id to token", + ) + group.add_argument( + "--odim", + type=int_or_none, + default=None, + help="The number of dimension of output feature", + ) + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetTTSModel), + help="The keyword arguments for model class.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--use_preprocessor", + type=str2bool, + default=True, + help="Apply preprocessing to data or not", + ) + group.add_argument( + "--token_type", + type=str, + default="phn", + choices=["bpe", "char", "word", "phn"], + help="The text will be tokenized in the specified level token", + ) + group.add_argument( + "--bpemodel", + type=str_or_none, + default=None, + help="The model file of sentencepiece", + ) + parser.add_argument( + "--non_linguistic_symbols", + type=str_or_none, + help="non_linguistic_symbols file path", + ) + parser.add_argument( + "--cleaner", + type=str_or_none, + choices=[None, "tacotron", "jaconv", "vietnamese"], + default=None, + help="Apply text cleaning", + ) + parser.add_argument( + "--g2p", + type=str_or_none, + choices=[ + None, + "g2p_en", + "g2p_en_no_space", + "pyopenjtalk", + "pyopenjtalk_kana", + "pyopenjtalk_accent", + "pyopenjtalk_accent_with_pause", + "pypinyin_g2p", + "pypinyin_g2p_phone", + "espeak_ng_arabic", + ], + default=None, + help="Specify g2p method if --token_type=phn", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. 
--encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + return CommonCollateFn( + float_pad_value=0.0, int_pad_value=0, not_sequence=["spembs"] + ) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + if args.use_preprocessor: + retval = CommonPreprocessor( + train=train, + token_type=args.token_type, + token_list=args.token_list, + bpemodel=args.bpemodel, + non_linguistic_symbols=args.non_linguistic_symbols, + text_cleaner=args.cleaner, + g2p_type=args.g2p, + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("text", "speech") + else: + # Inference mode + retval = ("text",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("spembs", "durations", "pitch", "energy") + else: + # Inference mode + retval = ("spembs", "speech", "durations") + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetTTSModel: + assert check_argument_types() + if isinstance(args.token_list, str): + with open(args.token_list, encoding="utf-8") as f: + token_list = [line.rstrip() for line in f] + + # "args" is saved as it is in a yaml file by BaseTask.main(). + # Overwriting token_list to keep it as "portable". + args.token_list = token_list.copy() + elif isinstance(args.token_list, (tuple, list)): + token_list = args.token_list.copy() + else: + raise RuntimeError("token_list must be str or dict") + + vocab_size = len(token_list) + logging.info(f"Vocabulary size: {vocab_size }") + + # 1. feats_extract + if args.odim is None: + # Extract features in the model + feats_extract_class = feats_extractor_choices.get_class(args.feats_extract) + feats_extract = feats_extract_class(**args.feats_extract_conf) + odim = feats_extract.output_size() + else: + # Give features from data-loader + args.feats_extract = None + args.feats_extract_conf = None + feats_extract = None + odim = args.odim + + # 2. Normalization layer + if args.normalize is not None: + normalize_class = normalize_choices.get_class(args.normalize) + normalize = normalize_class(**args.normalize_conf) + else: + normalize = None + + # 3. TTS + tts_class = tts_choices.get_class(args.tts) + tts = tts_class(idim=vocab_size, odim=odim, **args.tts_conf) + + # 4. 
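# Minimal sketch of the odim logic above when features are extracted inside
# the model: the TTS network's output dimension comes from the feature
# extractor (n_mels=80 is an assumed, typical setting).
from espnet2.tts.feats_extract.log_mel_fbank import LogMelFbank

feats_extract = LogMelFbank(n_mels=80)
odim = feats_extract.output_size()  # 80, later passed to the TTS class as odim
# tts = tts_choices.get_class("tacotron2")(idim=vocab_size, odim=odim, **tts_conf)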
Extra components + pitch_extract = None + energy_extract = None + pitch_normalize = None + energy_normalize = None + if getattr(args, "pitch_extract", None) is not None: + pitch_extract_class = pitch_extractor_choices.get_class(args.pitch_extract) + if args.pitch_extract_conf.get("reduction_factor", None) is not None: + assert args.pitch_extract_conf.get( + "reduction_factor", None + ) == args.tts_conf.get("reduction_factor", 1) + else: + args.pitch_extract_conf["reduction_factor"] = args.tts_conf.get( + "reduction_factor", 1 + ) + pitch_extract = pitch_extract_class(**args.pitch_extract_conf) + if getattr(args, "energy_extract", None) is not None: + if args.energy_extract_conf.get("reduction_factor", None) is not None: + assert args.energy_extract_conf.get( + "reduction_factor", None + ) == args.tts_conf.get("reduction_factor", 1) + else: + args.energy_extract_conf["reduction_factor"] = args.tts_conf.get( + "reduction_factor", 1 + ) + energy_extract_class = energy_extractor_choices.get_class( + args.energy_extract + ) + energy_extract = energy_extract_class(**args.energy_extract_conf) + if getattr(args, "pitch_normalize", None) is not None: + pitch_normalize_class = pitch_normalize_choices.get_class( + args.pitch_normalize + ) + pitch_normalize = pitch_normalize_class(**args.pitch_normalize_conf) + if getattr(args, "energy_normalize", None) is not None: + energy_normalize_class = energy_normalize_choices.get_class( + args.energy_normalize + ) + energy_normalize = energy_normalize_class(**args.energy_normalize_conf) + + # 5. Build model + model = ESPnetTTSModel( + feats_extract=feats_extract, + pitch_extract=pitch_extract, + energy_extract=energy_extract, + normalize=normalize, + pitch_normalize=pitch_normalize, + energy_normalize=energy_normalize, + tts=tts, + **args.model_conf, + ) + + # AR prior training + # for mod, param in model.named_parameters(): + # if not mod.startswith("tts.prosody_encoder.ar_prior"): + # print(f"Setting {mod}.requires_grad = False") + # param.requires_grad = False + + assert check_return_type(model) + return model diff --git a/espnet2/text/__init__.py b/espnet2/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/text/abs_tokenizer.py b/espnet2/text/abs_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2ccb3c3694fef0fc4d4bc7576c355c7712fee4 --- /dev/null +++ b/espnet2/text/abs_tokenizer.py @@ -0,0 +1,14 @@ +from abc import ABC +from abc import abstractmethod +from typing import Iterable +from typing import List + + +class AbsTokenizer(ABC): + @abstractmethod + def text2tokens(self, line: str) -> List[str]: + raise NotImplementedError + + @abstractmethod + def tokens2text(self, tokens: Iterable[str]) -> str: + raise NotImplementedError diff --git a/espnet2/text/build_tokenizer.py b/espnet2/text/build_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..66ca54d455cbf5d41e8e3d4b4eddebee8263c958 --- /dev/null +++ b/espnet2/text/build_tokenizer.py @@ -0,0 +1,62 @@ +from pathlib import Path +from typing import Iterable +from typing import Union + +from typeguard import check_argument_types + +from espnet2.text.abs_tokenizer import AbsTokenizer +from espnet2.text.char_tokenizer import CharTokenizer +from espnet2.text.phoneme_tokenizer import PhonemeTokenizer +from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer +from espnet2.text.word_tokenizer import WordTokenizer + + +def build_tokenizer( + 
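# The pitch/energy wiring earlier in build_model keeps the extractor frame
# rate in sync with the TTS network: if the extractor config does not set
# reduction_factor it inherits the TTS value, and if it does set one it must
# match. A pure-dict restatement with illustrative values:
tts_conf = {"reduction_factor": 2}
pitch_extract_conf = {}  # nothing set by the user

if pitch_extract_conf.get("reduction_factor") is not None:
    assert pitch_extract_conf["reduction_factor"] == tts_conf.get("reduction_factor", 1)
else:
    pitch_extract_conf["reduction_factor"] = tts_conf.get("reduction_factor", 1)

print(pitch_extract_conf)  # {'reduction_factor': 2}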
token_type: str, + bpemodel: Union[Path, str, Iterable[str]] = None, + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + remove_non_linguistic_symbols: bool = False, + space_symbol: str = "", + delimiter: str = None, + g2p_type: str = None, +) -> AbsTokenizer: + """A helper function to instantiate Tokenizer""" + assert check_argument_types() + if token_type == "bpe": + if bpemodel is None: + raise ValueError('bpemodel is required if token_type = "bpe"') + + if remove_non_linguistic_symbols: + raise RuntimeError( + "remove_non_linguistic_symbols is not implemented for token_type=bpe" + ) + return SentencepiecesTokenizer(bpemodel) + + elif token_type == "word": + if remove_non_linguistic_symbols and non_linguistic_symbols is not None: + return WordTokenizer( + delimiter=delimiter, + non_linguistic_symbols=non_linguistic_symbols, + remove_non_linguistic_symbols=True, + ) + else: + return WordTokenizer(delimiter=delimiter) + + elif token_type == "char": + return CharTokenizer( + non_linguistic_symbols=non_linguistic_symbols, + space_symbol=space_symbol, + remove_non_linguistic_symbols=remove_non_linguistic_symbols, + ) + + elif token_type == "phn": + return PhonemeTokenizer( + g2p_type=g2p_type, + non_linguistic_symbols=non_linguistic_symbols, + space_symbol=space_symbol, + remove_non_linguistic_symbols=remove_non_linguistic_symbols, + ) + else: + raise ValueError( + f"token_mode must be one of bpe, word, char or phn: " f"{token_type}" + ) diff --git a/espnet2/text/char_tokenizer.py b/espnet2/text/char_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5900352862453db4f2c07ca4b850635cb1bb91f8 --- /dev/null +++ b/espnet2/text/char_tokenizer.py @@ -0,0 +1,57 @@ +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Union + +from typeguard import check_argument_types + +from espnet2.text.abs_tokenizer import AbsTokenizer + + +class CharTokenizer(AbsTokenizer): + def __init__( + self, + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + space_symbol: str = "", + remove_non_linguistic_symbols: bool = False, + ): + assert check_argument_types() + self.space_symbol = space_symbol + if non_linguistic_symbols is None: + self.non_linguistic_symbols = set() + elif isinstance(non_linguistic_symbols, (Path, str)): + non_linguistic_symbols = Path(non_linguistic_symbols) + with non_linguistic_symbols.open("r", encoding="utf-8") as f: + self.non_linguistic_symbols = set(line.rstrip() for line in f) + else: + self.non_linguistic_symbols = set(non_linguistic_symbols) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f'space_symbol="{self.space_symbol}"' + f'non_linguistic_symbols="{self.non_linguistic_symbols}"' + f")" + ) + + def text2tokens(self, line: str) -> List[str]: + tokens = [] + while len(line) != 0: + for w in self.non_linguistic_symbols: + if line.startswith(w): + if not self.remove_non_linguistic_symbols: + tokens.append(line[: len(w)]) + line = line[len(w) :] + break + else: + t = line[0] + if t == " ": + t = "" + tokens.append(t) + line = line[1:] + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + tokens = [t if t != self.space_symbol else " " for t in tokens] + return "".join(tokens) diff --git a/espnet2/text/cleaner.py b/espnet2/text/cleaner.py new file mode 100644 index 0000000000000000000000000000000000000000..714eb7cbdb4b430af3c31390fefd9e904ad41ad9 --- /dev/null +++ 
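# Example of the build_tokenizer dispatch above for the simplest case; "char"
# needs no external model or G2P backend, so this runs as-is.
from espnet2.text.build_tokenizer import build_tokenizer

tokenizer = build_tokenizer(token_type="char")
print(tokenizer.text2tokens("abc"))            # ['a', 'b', 'c']
print(tokenizer.tokens2text(["a", "b", "c"]))  # 'abc'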
b/espnet2/text/cleaner.py @@ -0,0 +1,46 @@ +from typing import Collection + +from jaconv import jaconv +import tacotron_cleaner.cleaners +from typeguard import check_argument_types + +try: + from vietnamese_cleaner import vietnamese_cleaners +except ImportError: + vietnamese_cleaners = None + + +class TextCleaner: + """Text cleaner. + + Examples: + >>> cleaner = TextCleaner("tacotron") + >>> cleaner("(Hello-World); & jr. & dr.") + 'HELLO WORLD, AND JUNIOR AND DOCTOR' + + """ + + def __init__(self, cleaner_types: Collection[str] = None): + assert check_argument_types() + + if cleaner_types is None: + self.cleaner_types = [] + elif isinstance(cleaner_types, str): + self.cleaner_types = [cleaner_types] + else: + self.cleaner_types = list(cleaner_types) + + def __call__(self, text: str) -> str: + for t in self.cleaner_types: + if t == "tacotron": + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + elif t == "jaconv": + text = jaconv.normalize(text) + elif t == "vietnamese": + if vietnamese_cleaners is None: + raise RuntimeError("Please install underthesea") + text = vietnamese_cleaners.vietnamese_cleaner(text) + else: + raise RuntimeError(f"Not supported: type={t}") + + return text diff --git a/espnet2/text/phoneme_tokenizer.py b/espnet2/text/phoneme_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a1298a18cc65587220397362df679fed68272194 --- /dev/null +++ b/espnet2/text/phoneme_tokenizer.py @@ -0,0 +1,218 @@ +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Optional +from typing import Union + +import g2p_en +from typeguard import check_argument_types + +from espnet2.text.abs_tokenizer import AbsTokenizer + + +def split_by_space(text) -> List[str]: + return text.split(" ") + + +def pyopenjtalk_g2p(text) -> List[str]: + import pyopenjtalk + + # phones is a str object separated by space + phones = pyopenjtalk.g2p(text, kana=False) + phones = phones.split(" ") + return phones + + +def pyopenjtalk_g2p_accent(text) -> List[str]: + import pyopenjtalk + import re + + phones = [] + for labels in pyopenjtalk.run_frontend(text)[1]: + p = re.findall(r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9])", labels) + if len(p) == 1: + phones += [p[0][0], p[0][2], p[0][1]] + return phones + + +def pyopenjtalk_g2p_accent_with_pause(text) -> List[str]: + import pyopenjtalk + import re + + phones = [] + for labels in pyopenjtalk.run_frontend(text)[1]: + if labels.split("-")[1].split("+")[0] == "pau": + phones += ["pau"] + continue + p = re.findall(r"\-(.*?)\+.*?\/A:([0-9\-]+).*?\/F:.*?_([0-9])", labels) + if len(p) == 1: + phones += [p[0][0], p[0][2], p[0][1]] + return phones + + +def pyopenjtalk_g2p_kana(text) -> List[str]: + import pyopenjtalk + + kanas = pyopenjtalk.g2p(text, kana=True) + return list(kanas) + + +def pypinyin_g2p(text) -> List[str]: + from pypinyin import pinyin + from pypinyin import Style + + phones = [phone[0] for phone in pinyin(text, style=Style.TONE3)] + return phones + + +def pypinyin_g2p_phone(text) -> List[str]: + from pypinyin import pinyin + from pypinyin import Style + from pypinyin.style._utils import get_finals + from pypinyin.style._utils import get_initials + + phones = [ + p + for phone in pinyin(text, style=Style.TONE3) + for p in [ + get_initials(phone[0], strict=True), + get_finals(phone[0], strict=True), + ] + if len(p) != 0 + ] + return phones + + +class G2p_en: + """On behalf of g2p_en.G2p. 
+ + g2p_en.G2p isn't pickalable and it can't be copied to the other processes + via multiprocessing module. + As a workaround, g2p_en.G2p is instantiated upon calling this class. + + """ + + def __init__(self, no_space: bool = False): + self.no_space = no_space + self.g2p = None + + def __call__(self, text) -> List[str]: + if self.g2p is None: + self.g2p = g2p_en.G2p() + + phones = self.g2p(text) + if self.no_space: + # remove space which represents word serapater + phones = list(filter(lambda s: s != " ", phones)) + return phones + + +class Phonemizer: + """Phonemizer module for various languages. + + This is wrapper module of https://github.com/bootphon/phonemizer. + You can define various g2p modules by specifying options for phonemizer. + + See available options: + https://github.com/bootphon/phonemizer/blob/master/phonemizer/phonemize.py#L32 + + """ + + def __init__( + self, + word_separator: Optional[str] = None, + syllable_separator: Optional[str] = None, + **phonemize_kwargs, + ): + # delayed import + from phonemizer import phonemize + from phonemizer.separator import Separator + + self.phonemize = phonemize + self.separator = Separator( + word=word_separator, syllable=syllable_separator, phone=" " + ) + self.phonemize_kwargs = phonemize_kwargs + + def __call__(self, text) -> List[str]: + return self.phonemize( + text, + separator=self.separator, + **self.phonemize_kwargs, + ).split() + + +class PhonemeTokenizer(AbsTokenizer): + def __init__( + self, + g2p_type: Union[None, str], + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + space_symbol: str = "", + remove_non_linguistic_symbols: bool = False, + ): + assert check_argument_types() + if g2p_type is None: + self.g2p = split_by_space + elif g2p_type == "g2p_en": + self.g2p = G2p_en(no_space=False) + elif g2p_type == "g2p_en_no_space": + self.g2p = G2p_en(no_space=True) + elif g2p_type == "pyopenjtalk": + self.g2p = pyopenjtalk_g2p + elif g2p_type == "pyopenjtalk_kana": + self.g2p = pyopenjtalk_g2p_kana + elif g2p_type == "pyopenjtalk_accent": + self.g2p = pyopenjtalk_g2p_accent + elif g2p_type == "pyopenjtalk_accent_with_pause": + self.g2p = pyopenjtalk_g2p_accent_with_pause + elif g2p_type == "pypinyin_g2p": + self.g2p = pypinyin_g2p + elif g2p_type == "pypinyin_g2p_phone": + self.g2p = pypinyin_g2p_phone + elif g2p_type == "espeak_ng_arabic": + self.g2p = Phonemizer(language="ar", backend="espeak", with_stress=True) + else: + raise NotImplementedError(f"Not supported: g2p_type={g2p_type}") + + self.g2p_type = g2p_type + self.space_symbol = space_symbol + if non_linguistic_symbols is None: + self.non_linguistic_symbols = set() + elif isinstance(non_linguistic_symbols, (Path, str)): + non_linguistic_symbols = Path(non_linguistic_symbols) + with non_linguistic_symbols.open("r", encoding="utf-8") as f: + self.non_linguistic_symbols = set(line.rstrip() for line in f) + else: + self.non_linguistic_symbols = set(non_linguistic_symbols) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f'g2p_type="{self.g2p_type}", ' + f'space_symbol="{self.space_symbol}", ' + f'non_linguistic_symbols="{self.non_linguistic_symbols}"' + f")" + ) + + def text2tokens(self, line: str) -> List[str]: + tokens = [] + while len(line) != 0: + for w in self.non_linguistic_symbols: + if line.startswith(w): + if not self.remove_non_linguistic_symbols: + tokens.append(line[: len(w)]) + line = line[len(w) :] + break + else: + t = line[0] + tokens.append(t) + line 
= line[1:] + + line = "".join(tokens) + tokens = self.g2p(line) + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + # phoneme type is not invertible + return "".join(tokens) diff --git a/espnet2/text/sentencepiece_tokenizer.py b/espnet2/text/sentencepiece_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0db7110760c34ccc1f4ad57e3272835712b89111 --- /dev/null +++ b/espnet2/text/sentencepiece_tokenizer.py @@ -0,0 +1,38 @@ +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Union + +import sentencepiece as spm +from typeguard import check_argument_types + +from espnet2.text.abs_tokenizer import AbsTokenizer + + +class SentencepiecesTokenizer(AbsTokenizer): + def __init__(self, model: Union[Path, str]): + assert check_argument_types() + self.model = str(model) + # NOTE(kamo): + # Don't build SentencePieceProcessor in __init__() + # because it's not picklable and it may cause following error, + # "TypeError: can't pickle SwigPyObject objects", + # when giving it as argument of "multiprocessing.Process()". + self.sp = None + + def __repr__(self): + return f'{self.__class__.__name__}(model="{self.model}")' + + def _build_sentence_piece_processor(self): + # Build SentencePieceProcessor lazily. + if self.sp is None: + self.sp = spm.SentencePieceProcessor() + self.sp.load(self.model) + + def text2tokens(self, line: str) -> List[str]: + self._build_sentence_piece_processor() + return self.sp.EncodeAsPieces(line) + + def tokens2text(self, tokens: Iterable[str]) -> str: + self._build_sentence_piece_processor() + return self.sp.DecodePieces(list(tokens)) diff --git a/espnet2/text/token_id_converter.py b/espnet2/text/token_id_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..c9a6b28638b75a547708a368ad10389ad0198f38 --- /dev/null +++ b/espnet2/text/token_id_converter.py @@ -0,0 +1,60 @@ +from pathlib import Path +from typing import Dict +from typing import Iterable +from typing import List +from typing import Union + +import numpy as np +from typeguard import check_argument_types + + +class TokenIDConverter: + def __init__( + self, + token_list: Union[Path, str, Iterable[str]], + unk_symbol: str = "", + ): + assert check_argument_types() + + if isinstance(token_list, (Path, str)): + token_list = Path(token_list) + self.token_list_repr = str(token_list) + self.token_list: List[str] = [] + + with token_list.open("r", encoding="utf-8") as f: + for idx, line in enumerate(f): + line = line.rstrip() + self.token_list.append(line) + + else: + self.token_list: List[str] = list(token_list) + self.token_list_repr = "" + for i, t in enumerate(self.token_list): + if i == 3: + break + self.token_list_repr += f"{t}, " + self.token_list_repr += f"... 
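# Usage sketch for the lazily-constructed SentencePiece wrapper above;
# "bpe.model" is a placeholder path to a trained sentencepiece model, and the
# token output shown is only an example.
from espnet2.text.sentencepiece_tokenizer import SentencepiecesTokenizer

tokenizer = SentencepiecesTokenizer("bpe.model")
tokens = tokenizer.text2tokens("hello world")  # e.g. ['▁hello', '▁world']
text = tokenizer.tokens2text(tokens)           # 'hello world'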
(NVocab={(len(self.token_list))})" + + self.token2id: Dict[str, int] = {} + for i, t in enumerate(self.token_list): + if t in self.token2id: + raise RuntimeError(f'Symbol "{t}" is duplicated') + self.token2id[t] = i + + self.unk_symbol = unk_symbol + if self.unk_symbol not in self.token2id: + raise RuntimeError( + f"Unknown symbol '{unk_symbol}' doesn't exist in the token_list" + ) + self.unk_id = self.token2id[self.unk_symbol] + + def get_num_vocabulary_size(self) -> int: + return len(self.token_list) + + def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: + if isinstance(integers, np.ndarray) and integers.ndim != 1: + raise ValueError(f"Must be 1 dim ndarray, but got {integers.ndim}") + return [self.token_list[i] for i in integers] + + def tokens2ids(self, tokens: Iterable[str]) -> List[int]: + return [self.token2id.get(i, self.unk_id) for i in tokens] diff --git a/espnet2/text/word_tokenizer.py b/espnet2/text/word_tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b729ecf91de21cb0552b3a6a1096ab9a20522dc --- /dev/null +++ b/espnet2/text/word_tokenizer.py @@ -0,0 +1,54 @@ +from pathlib import Path +from typing import Iterable +from typing import List +from typing import Union +import warnings + +from typeguard import check_argument_types + +from espnet2.text.abs_tokenizer import AbsTokenizer + + +class WordTokenizer(AbsTokenizer): + def __init__( + self, + delimiter: str = None, + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + remove_non_linguistic_symbols: bool = False, + ): + assert check_argument_types() + self.delimiter = delimiter + + if not remove_non_linguistic_symbols and non_linguistic_symbols is not None: + warnings.warn( + "non_linguistic_symbols is only used " + "when remove_non_linguistic_symbols = True" + ) + + if non_linguistic_symbols is None: + self.non_linguistic_symbols = set() + elif isinstance(non_linguistic_symbols, (Path, str)): + non_linguistic_symbols = Path(non_linguistic_symbols) + with non_linguistic_symbols.open("r", encoding="utf-8") as f: + self.non_linguistic_symbols = set(line.rstrip() for line in f) + else: + self.non_linguistic_symbols = set(non_linguistic_symbols) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + + def __repr__(self): + return f'{self.__class__.__name__}(delimiter="{self.delimiter}")' + + def text2tokens(self, line: str) -> List[str]: + tokens = [] + for t in line.split(self.delimiter): + if self.remove_non_linguistic_symbols and t in self.non_linguistic_symbols: + continue + tokens.append(t) + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + if self.delimiter is None: + delimiter = " " + else: + delimiter = self.delimiter + return delimiter.join(tokens) diff --git a/espnet2/torch_utils/__init__.py b/espnet2/torch_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/torch_utils/add_gradient_noise.py b/espnet2/torch_utils/add_gradient_noise.py new file mode 100644 index 0000000000000000000000000000000000000000..dd488dd75b9d58d1cc7c944670e1eb66af4269fa --- /dev/null +++ b/espnet2/torch_utils/add_gradient_noise.py @@ -0,0 +1,31 @@ +import torch + + +def add_gradient_noise( + model: torch.nn.Module, + iteration: int, + duration: float = 100, + eta: float = 1.0, + scale_factor: float = 0.55, +): + """Adds noise from a standard normal distribution to the gradients. 
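# Round-trip sketch for TokenIDConverter above; unknown tokens fall back to
# the unk id. The toy token list and symbol names are illustrative.
from espnet2.text.token_id_converter import TokenIDConverter

converter = TokenIDConverter(token_list=["<blank>", "<unk>", "a", "b"], unk_symbol="<unk>")
print(converter.tokens2ids(["a", "b", "zzz"]))  # [2, 3, 1]  ('zzz' -> unk)
print(converter.ids2tokens([2, 3, 1]))          # ['a', 'b', '<unk>']
print(converter.get_num_vocabulary_size())      # 4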
+ + The standard deviation (`sigma`) is controlled + by the three hyper-parameters below. + `sigma` goes to zero (no noise) with more iterations. + + Args: + model: Model. + iteration: Number of iterations. + duration: {100, 1000}: Number of durations to control + the interval of the `sigma` change. + eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`. + scale_factor: {0.55}: The scale of `sigma`. + """ + interval = (iteration // duration) + 1 + sigma = eta / interval ** scale_factor + for param in model.parameters(): + if param.grad is not None: + _shape = param.grad.size() + noise = sigma * torch.randn(_shape).to(param.device) + param.grad += noise diff --git a/espnet2/torch_utils/device_funcs.py b/espnet2/torch_utils/device_funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..7919e7d923200e0deb67bc1a8f83ad56c599d274 --- /dev/null +++ b/espnet2/torch_utils/device_funcs.py @@ -0,0 +1,71 @@ +import dataclasses +import warnings + +import numpy as np +import torch + + +def to_device(data, device=None, dtype=None, non_blocking=False, copy=False): + """Change the device of object recursively""" + if isinstance(data, dict): + return { + k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items() + } + elif dataclasses.is_dataclass(data) and not isinstance(data, type): + return type(data)( + *[ + to_device(v, device, dtype, non_blocking, copy) + for v in dataclasses.astuple(data) + ] + ) + # maybe namedtuple. I don't know the correct way to judge namedtuple. + elif isinstance(data, tuple) and type(data) is not tuple: + return type(data)( + *[to_device(o, device, dtype, non_blocking, copy) for o in data] + ) + elif isinstance(data, (list, tuple)): + return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data) + elif isinstance(data, np.ndarray): + return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy) + elif isinstance(data, torch.Tensor): + return data.to(device, dtype, non_blocking, copy) + else: + return data + + +def force_gatherable(data, device): + """Change object to gatherable in torch.nn.DataParallel recursively + + The difference from to_device() is changing to torch.Tensor if float or int + value is found. + + The restriction to the returned value in DataParallel: + The object must be + - torch.cuda.Tensor + - 1 or more dimension. 0-dimension-tensor sends warning. + or a list, tuple, dict. 
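+ + Example (a sketch of the conversions below): force_gatherable({"loss": 1.5, "count": 3}, "cpu") + returns {"loss": tensor([1.5000]), "count": tensor([3])}, because scalar floats and ints + are wrapped into 1-dim tensors so that DataParallel can gather them.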
+ + """ + if isinstance(data, dict): + return {k: force_gatherable(v, device) for k, v in data.items()} + # DataParallel can't handle NamedTuple well + elif isinstance(data, tuple) and type(data) is not tuple: + return type(data)(*[force_gatherable(o, device) for o in data]) + elif isinstance(data, (list, tuple, set)): + return type(data)(force_gatherable(v, device) for v in data) + elif isinstance(data, np.ndarray): + return force_gatherable(torch.from_numpy(data), device) + elif isinstance(data, torch.Tensor): + if data.dim() == 0: + # To 1-dim array + data = data[None] + return data.to(device) + elif isinstance(data, float): + return torch.tensor([data], dtype=torch.float, device=device) + elif isinstance(data, int): + return torch.tensor([data], dtype=torch.long, device=device) + elif data is None: + return None + else: + warnings.warn(f"{type(data)} may not be gatherable by DataParallel") + return data diff --git a/espnet2/torch_utils/forward_adaptor.py b/espnet2/torch_utils/forward_adaptor.py new file mode 100644 index 0000000000000000000000000000000000000000..114af785113b6b081b8874db99e41fd9a7221ee3 --- /dev/null +++ b/espnet2/torch_utils/forward_adaptor.py @@ -0,0 +1,33 @@ +import torch +from typeguard import check_argument_types + + +class ForwardAdaptor(torch.nn.Module): + """Wrapped module to parallelize specified method + + torch.nn.DataParallel parallelizes only "forward()" + and, maybe, the method having the other name can't be applied + except for wrapping the module just like this class. + + Examples: + >>> class A(torch.nn.Module): + ... def foo(self, x): + ... ... + >>> model = A() + >>> model = ForwardAdaptor(model, "foo") + >>> model = torch.nn.DataParallel(model, device_ids=[0, 1]) + >>> x = torch.randn(2, 10) + >>> model(x) + """ + + def __init__(self, module: torch.nn.Module, name: str): + assert check_argument_types() + super().__init__() + self.module = module + self.name = name + if not hasattr(module, name): + raise ValueError(f"{module} doesn't have {name}") + + def forward(self, *args, **kwargs): + func = getattr(self.module, self.name) + return func(*args, **kwargs) diff --git a/espnet2/torch_utils/initialize.py b/espnet2/torch_utils/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..b8d4a0711600d97696d7c29af55e67a5afdf9de0 --- /dev/null +++ b/espnet2/torch_utils/initialize.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +"""Initialize modules for espnet2 neural networks.""" + +import math +import torch +from typeguard import check_argument_types + + +def initialize(model: torch.nn.Module, init: str): + """Initialize weights of a neural network module. + + Parameters are initialized using the given method or distribution. + + Custom initialization routines can be implemented into submodules + as function `espnet_initialization_fn` within the custom module. + + Args: + model: Target. + init: Method of initialization. + """ + assert check_argument_types() + + if init == "chainer": + # 1. lecun_normal_init_parameters + for p in model.parameters(): + data = p.data + if data.dim() == 1: + # bias + data.zero_() + elif data.dim() == 2: + # linear weight + n = data.size(1) + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() in (3, 4): + # conv weight + n = data.size(1) + for k in data.size()[2:]: + n *= k + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + else: + raise NotImplementedError + + for mod in model.modules(): + # 2. 
embed weight ~ Normal(0, 1) + if isinstance(mod, torch.nn.Embedding): + mod.weight.data.normal_(0, 1) + # 3. forget-bias = 1.0 + elif isinstance(mod, torch.nn.RNNCellBase): + n = mod.bias_ih.size(0) + mod.bias_ih.data[n // 4 : n // 2].fill_(1.0) + elif isinstance(mod, torch.nn.RNNBase): + for name, param in mod.named_parameters(): + if "bias" in name: + n = param.size(0) + param.data[n // 4 : n // 2].fill_(1.0) + if hasattr(mod, "espnet_initialization_fn"): + mod.espnet_initialization_fn() + + else: + # weight init + for p in model.parameters(): + if p.dim() > 1: + if init == "xavier_uniform": + torch.nn.init.xavier_uniform_(p.data) + elif init == "xavier_normal": + torch.nn.init.xavier_normal_(p.data) + elif init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") + elif init == "kaiming_normal": + torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") + else: + raise ValueError("Unknown initialization: " + init) + # bias init + for p in model.parameters(): + if p.dim() == 1: + p.data.zero_() + + # reset some modules with default init + for m in model.modules(): + if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)): + m.reset_parameters() + if hasattr(m, "espnet_initialization_fn"): + m.espnet_initialization_fn() + + # TODO(xkc): Hacking wav2vec2 initialization + if getattr(model, "encoder", None) and getattr( + model.encoder, "reload_pretrained_parameters", None + ): + model.encoder.reload_pretrained_parameters() diff --git a/espnet2/torch_utils/load_pretrained_model.py b/espnet2/torch_utils/load_pretrained_model.py new file mode 100644 index 0000000000000000000000000000000000000000..dbce4d95c5128abd5176000d02dd986bc9a6916b --- /dev/null +++ b/espnet2/torch_utils/load_pretrained_model.py @@ -0,0 +1,81 @@ +from typing import Any + +import torch +import torch.nn +import torch.optim + + +def load_pretrained_model( + init_param: str, + model: torch.nn.Module, + map_location: str = "cpu", +): + """Load a model state and set it to the model. + + Args: + init_param: ::: + + Examples: + >>> load_pretrained_model("somewhere/model.pth", model) + >>> load_pretrained_model("somewhere/model.pth:decoder:decoder", model) + >>> load_pretrained_model("somewhere/model.pth:decoder:decoder:", model) + >>> load_pretrained_model( + ... "somewhere/model.pth:decoder:decoder:decoder.embed", model + ... ) + >>> load_pretrained_model("somewhere/decoder.pth::decoder", model) + """ + sps = init_param.split(":", 4) + if len(sps) == 4: + path, src_key, dst_key, excludes = sps + elif len(sps) == 3: + path, src_key, dst_key = sps + excludes = None + elif len(sps) == 2: + path, src_key = sps + dst_key, excludes = None, None + else: + (path,) = sps + src_key, dst_key, excludes = None, None, None + if src_key == "": + src_key = None + if dst_key == "": + dst_key = None + + if dst_key is None: + obj = model + else: + + def get_attr(obj: Any, key: str): + """Get an nested attribute. + + >>> class A(torch.nn.Module): + ... def __init__(self): + ... super().__init__() + ... 
self.linear = torch.nn.Linear(10, 10) + >>> a = A() + >>> assert A.linear.weight is get_attr(A, 'linear.weight') + + """ + if key.strip() == "": + return obj + for k in key.split("."): + obj = getattr(obj, k) + return obj + + obj = get_attr(model, dst_key) + + src_state = torch.load(path, map_location=map_location) + if excludes is not None: + for e in excludes.split(","): + src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} + + if src_key is not None: + src_state = { + k[len(src_key) + 1 :]: v + for k, v in src_state.items() + if k.startswith(src_key) + } + + dst_state = obj.state_dict() + dst_state.update(src_state) + obj.load_state_dict(dst_state) diff --git a/espnet2/torch_utils/model_summary.py b/espnet2/torch_utils/model_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..df34b0b9a7f6f34fc0be124e744f044a68aee5d5 --- /dev/null +++ b/espnet2/torch_utils/model_summary.py @@ -0,0 +1,70 @@ +import humanfriendly +import numpy as np +import torch + + +def get_human_readable_count(number: int) -> str: + """Return human_readable_count + + Originated from: + https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py + + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1 K' + >>> get_human_readable_count(2e6) # (two million) + '2 M' + >>> get_human_readable_count(3e9) # (three billion) + '3 B' + >>> get_human_readable_count(4e12) # (four trillion) + '4 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. 
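+ Note that the implementation below keeps two decimals, e.g. 1234 -> '1.23 K'.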
+ """ + assert number >= 0 + labels = [" ", "K", "M", "B", "T"] + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) # don't abbreviate beyond trillions + shift = -3 * (num_groups - 1) + number = number * (10 ** shift) + index = num_groups - 1 + return f"{number:.2f} {labels[index]}" + + +def to_bytes(dtype) -> int: + # torch.float16 -> 16 + return int(str(dtype)[-2:]) // 8 + + +def model_summary(model: torch.nn.Module) -> str: + message = "Model structure:\n" + message += str(model) + tot_params = sum(p.numel() for p in model.parameters()) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) + tot_params = get_human_readable_count(tot_params) + num_params = get_human_readable_count(num_params) + message += "\n\nModel summary:\n" + message += f" Class Name: {model.__class__.__name__}\n" + message += f" Total Number of model parameters: {tot_params}\n" + message += ( + f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n" + ) + num_bytes = humanfriendly.format_size( + sum( + p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad + ) + ) + message += f" Size: {num_bytes}\n" + dtype = next(iter(model.parameters())).dtype + message += f" Type: {dtype}" + return message diff --git a/espnet2/torch_utils/pytorch_version.py b/espnet2/torch_utils/pytorch_version.py new file mode 100644 index 0000000000000000000000000000000000000000..01f17cc748e3af444d551a14600728205cd6f61d --- /dev/null +++ b/espnet2/torch_utils/pytorch_version.py @@ -0,0 +1,16 @@ +import torch + + +def pytorch_cudnn_version() -> str: + message = ( + f"pytorch.version={torch.__version__}, " + f"cuda.available={torch.cuda.is_available()}, " + ) + + if torch.backends.cudnn.enabled: + message += ( + f"cudnn.version={torch.backends.cudnn.version()}, " + f"cudnn.benchmark={torch.backends.cudnn.benchmark}, " + f"cudnn.deterministic={torch.backends.cudnn.deterministic}" + ) + return message diff --git a/espnet2/torch_utils/recursive_op.py b/espnet2/torch_utils/recursive_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c3b70fb3fa5b29530629769d59946bd72323e56f --- /dev/null +++ b/espnet2/torch_utils/recursive_op.py @@ -0,0 +1,53 @@ +from distutils.version import LooseVersion + +import torch + +if torch.distributed.is_available(): + if LooseVersion(torch.__version__) > LooseVersion("1.0.1"): + from torch.distributed import ReduceOp + else: + from torch.distributed import reduce_op as ReduceOp +else: + ReduceOp = None + + +def recursive_sum(obj, weight: torch.Tensor, distributed: bool = False): + assert weight.dim() == 1, weight.size() + if isinstance(obj, (tuple, list)): + return type(obj)(recursive_sum(v, weight, distributed) for v in obj) + elif isinstance(obj, dict): + return {k: recursive_sum(v, weight, distributed) for k, v in obj.items()} + elif isinstance(obj, torch.Tensor): + assert obj.size() == weight.size(), (obj.size(), weight.size()) + obj = (obj * weight.type(obj.dtype)).sum() + if distributed: + torch.distributed.all_reduce(obj, op=ReduceOp.SUM) + return obj + elif obj is None: + return None + else: + raise ValueError(type(obj)) + + +def recursive_divide(a, b: torch.Tensor): + if isinstance(a, (tuple, list)): + return type(a)(recursive_divide(v, b) for v in a) + elif isinstance(a, dict): + return {k: recursive_divide(v, b) for k, v in a.items()} + elif 
isinstance(a, torch.Tensor): + assert a.size() == b.size(), (a.size(), b.size()) + return a / b.type(a.dtype) + elif a is None: + return None + else: + raise ValueError(type(a)) + + +def recursive_average(obj, weight: torch.Tensor, distributed: bool = False): + obj = recursive_sum(obj, weight, distributed) + weight = weight.sum() + if distributed: + torch.distributed.all_reduce(weight, op=ReduceOp.SUM) + # Normalize weight to be sum-to-1 + obj = recursive_divide(obj, weight) + return obj, weight diff --git a/espnet2/torch_utils/set_all_random_seed.py b/espnet2/torch_utils/set_all_random_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..ebdca3f537aac53bdc6e6cea168c49805bdf2d2f --- /dev/null +++ b/espnet2/torch_utils/set_all_random_seed.py @@ -0,0 +1,10 @@ +import random + +import numpy as np +import torch + + +def set_all_random_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) diff --git a/espnet2/train/__init__.py b/espnet2/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/train/abs_espnet_model.py b/espnet2/train/abs_espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..6fd50603680e8370a8fc71de5c4ff6cb3c7e5c02 --- /dev/null +++ b/espnet2/train/abs_espnet_model.py @@ -0,0 +1,42 @@ +from abc import ABC +from abc import abstractmethod +from typing import Dict +from typing import Tuple + +import torch + + +class AbsESPnetModel(torch.nn.Module, ABC): + """The common abstract class among each tasks + + "ESPnetModel" is referred to a class which inherits torch.nn.Module, + and makes the dnn-models forward as its member field, + a.k.a delegate pattern, + and defines "loss", "stats", and "weight" for the task. + + If you intend to implement new task in ESPNet, + the model must inherit this class. + In other words, the "mediator" objects between + our training system and the your task class are + just only these three values, loss, stats, and weight. + + Example: + >>> from espnet2.tasks.abs_task import AbsTask + >>> class YourESPnetModel(AbsESPnetModel): + ... def forward(self, input, input_lengths): + ... ... + ... return loss, stats, weight + >>> class YourTask(AbsTask): + ... @classmethod + ... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel: + """ + + @abstractmethod + def forward( + self, **batch: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/train/class_choices.py b/espnet2/train/class_choices.py new file mode 100644 index 0000000000000000000000000000000000000000..821bab8121b5105625616eb6cb4962576accc68f --- /dev/null +++ b/espnet2/train/class_choices.py @@ -0,0 +1,95 @@ +from typing import Mapping +from typing import Optional +from typing import Tuple + +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.utils.nested_dict_action import NestedDictAction +from espnet2.utils.types import str_or_none + + +class ClassChoices: + """Helper class to manage the options for variable objects and its configuration. + + Example: + + >>> class A: + ... def __init__(self, foo=3): pass + >>> class B: + ... 
def __init__(self, bar="aaaa"): pass + >>> choices = ClassChoices("var", dict(a=A, b=B), default="a") + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> choices.add_arguments(parser) + >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4") + >>> args.var + a + >>> args.var_conf + {"foo": 4} + >>> class_obj = choices.get_class(args.var) + >>> a_object = class_obj(**args.var_conf) + + """ + + def __init__( + self, + name: str, + classes: Mapping[str, type], + type_check: type = None, + default: str = None, + optional: bool = False, + ): + assert check_argument_types() + self.name = name + self.base_type = type_check + self.classes = {k.lower(): v for k, v in classes.items()} + if "none" in self.classes or "nil" in self.classes or "null" in self.classes: + raise ValueError('"none", "nil", and "null" are reserved.') + if type_check is not None: + for v in self.classes.values(): + if not issubclass(v, type_check): + raise ValueError(f"must be {type_check.__name__}, but got {v}") + + self.optional = optional + self.default = default + if default is None: + self.optional = True + + def choices(self) -> Tuple[Optional[str], ...]: + retval = tuple(self.classes) + if self.optional: + return retval + (None,) + else: + return retval + + def get_class(self, name: Optional[str]) -> Optional[type]: + assert check_argument_types() + if name is None or (self.optional and name.lower() == ("none", "null", "nil")): + retval = None + elif name.lower() in self.classes: + class_obj = self.classes[name] + assert check_return_type(class_obj) + retval = class_obj + else: + raise ValueError( + f"--{self.name} must be one of {self.choices()}: " + f"--{self.name} {name.lower()}" + ) + + return retval + + def add_arguments(self, parser): + parser.add_argument( + f"--{self.name}", + type=lambda x: str_or_none(x.lower()), + default=self.default, + choices=self.choices(), + help=f"The {self.name} type", + ) + parser.add_argument( + f"--{self.name}_conf", + action=NestedDictAction, + default=dict(), + help=f"The keyword arguments for {self.name}", + ) diff --git a/espnet2/train/collate_fn.py b/espnet2/train/collate_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a5bbb7792d28bf0fe1f08622d6961e05e65e16 --- /dev/null +++ b/espnet2/train/collate_fn.py @@ -0,0 +1,104 @@ +from typing import Collection +from typing import Dict +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet.nets.pytorch_backend.nets_utils import pad_list + + +class CommonCollateFn: + """Functor class of common_collate_fn()""" + + def __init__( + self, + float_pad_value: Union[float, int] = 0.0, + int_pad_value: int = -32768, + not_sequence: Collection[str] = (), + ): + assert check_argument_types() + self.float_pad_value = float_pad_value + self.int_pad_value = int_pad_value + self.not_sequence = set(not_sequence) + + def __repr__(self): + return ( + f"{self.__class__}(float_pad_value={self.float_pad_value}, " + f"int_pad_value={self.float_pad_value})" + ) + + def __call__( + self, data: Collection[Tuple[str, Dict[str, np.ndarray]]] + ) -> Tuple[List[str], Dict[str, torch.Tensor]]: + return common_collate_fn( + data, + float_pad_value=self.float_pad_value, + int_pad_value=self.int_pad_value, + not_sequence=self.not_sequence, + ) + + +def common_collate_fn( + data: Collection[Tuple[str, Dict[str, np.ndarray]]], + float_pad_value: Union[float, 
int] = 0.0, + int_pad_value: int = -32768, + not_sequence: Collection[str] = (), +) -> Tuple[List[str], Dict[str, torch.Tensor]]: + """Concatenate ndarray-list to an array and convert to torch.Tensor. + + Examples: + >>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler, + >>> import espnet2.tasks.abs_task + >>> from espnet2.train.dataset import ESPnetDataset + >>> sampler = ConstantBatchSampler(...) + >>> dataset = ESPnetDataset(...) + >>> keys = next(iter(sampler) + >>> batch = [dataset[key] for key in keys] + >>> batch = common_collate_fn(batch) + >>> model(**batch) + + Note that the dict-keys of batch are propagated from + that of the dataset as they are. + + """ + assert check_argument_types() + uttids = [u for u, _ in data] + data = [d for _, d in data] + + assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching" + assert all( + not k.endswith("_lengths") for k in data[0] + ), f"*_lengths is reserved: {list(data[0])}" + + output = {} + for key in data[0]: + # NOTE(kamo): + # Each models, which accepts these values finally, are responsible + # to repaint the pad_value to the desired value for each tasks. + if data[0][key].dtype.kind == "i": + pad_value = int_pad_value + else: + pad_value = float_pad_value + + array_list = [d[key] for d in data] + + # Assume the first axis is length: + # tensor_list: Batch x (Length, ...) + tensor_list = [torch.from_numpy(a) for a in array_list] + # tensor: (Batch, Length, ...) + tensor = pad_list(tensor_list, pad_value) + output[key] = tensor + + # lens: (Batch,) + if key not in not_sequence: + lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long) + output[key + "_lengths"] = lens + + output = (uttids, output) + assert check_return_type(output) + return output diff --git a/espnet2/train/dataset.py b/espnet2/train/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5601d399f8b9e9b0bdcb7d9f0502d1c31f1e08ed --- /dev/null +++ b/espnet2/train/dataset.py @@ -0,0 +1,453 @@ +from abc import ABC +from abc import abstractmethod +import collections +import copy +import functools +import logging +import numbers +import re +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Mapping +from typing import Tuple +from typing import Union + +import h5py +import humanfriendly +import kaldiio +import numpy as np +import torch +from torch.utils.data.dataset import Dataset +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.fileio.npy_scp import NpyScpReader +from espnet2.fileio.rand_gen_dataset import FloatRandomGenerateDataset +from espnet2.fileio.rand_gen_dataset import IntRandomGenerateDataset +from espnet2.fileio.read_text import load_num_sequence_text +from espnet2.fileio.read_text import read_2column_text +from espnet2.fileio.rttm import RttmReader +from espnet2.fileio.sound_scp import SoundScpReader +from espnet2.utils.sized_dict import SizedDict + + +class AdapterForSoundScpReader(collections.abc.Mapping): + def __init__(self, loader, dtype=None): + assert check_argument_types() + self.loader = loader + self.dtype = dtype + self.rate = None + + def keys(self): + return self.loader.keys() + + def __len__(self): + return len(self.loader) + + def __iter__(self): + return iter(self.loader) + + def __getitem__(self, key: str) -> np.ndarray: + retval = self.loader[key] + + if isinstance(retval, tuple): + assert len(retval) == 2, len(retval) + if 
isinstance(retval[0], int) and isinstance(retval[1], np.ndarray): + # sound scp case + rate, array = retval + elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray): + # Extended ark format case + array, rate = retval + else: + raise RuntimeError( + f"Unexpected type: {type(retval[0])}, {type(retval[1])}" + ) + + if self.rate is not None and self.rate != rate: + raise RuntimeError( + f"Sampling rates are mismatched: {self.rate} != {rate}" + ) + self.rate = rate + # Multichannel wave file + # array: (NSample, Channel) or (Nsample) + if self.dtype is not None: + array = array.astype(self.dtype) + + else: + # Normal ark case + assert isinstance(retval, np.ndarray), type(retval) + array = retval + if self.dtype is not None: + array = array.astype(self.dtype) + + assert isinstance(array, np.ndarray), type(array) + return array + + +class H5FileWrapper: + def __init__(self, path: str): + self.path = path + self.h5_file = h5py.File(path, "r") + + def __repr__(self) -> str: + return str(self.h5_file) + + def __len__(self) -> int: + return len(self.h5_file) + + def __iter__(self): + return iter(self.h5_file) + + def __getitem__(self, key) -> np.ndarray: + value = self.h5_file[key] + return value[()] + + +def sound_loader(path, float_dtype=None): + # The file is as follows: + # utterance_id_A /some/where/a.wav + # utterance_id_B /some/where/a.flac + + # NOTE(kamo): SoundScpReader doesn't support pipe-fashion + # like Kaldi e.g. "cat a.wav |". + # NOTE(kamo): The audio signal is normalized to [-1,1] range. + loader = SoundScpReader(path, normalize=True, always_2d=False) + + # SoundScpReader.__getitem__() returns Tuple[int, ndarray], + # but ndarray is desired, so Adapter class is inserted here + return AdapterForSoundScpReader(loader, float_dtype) + + +def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0): + loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd) + return AdapterForSoundScpReader(loader, float_dtype) + + +def rand_int_loader(filepath, loader_type): + # e.g. rand_int_3_10 + try: + low, high = map(int, loader_type[len("rand_int_") :].split("_")) + except ValueError: + raise RuntimeError(f"e.g. rand_int_3_10: but got {loader_type}") + return IntRandomGenerateDataset(filepath, low, high) + + +DATA_TYPES = { + "sound": dict( + func=sound_loader, + kwargs=["float_dtype"], + help="Audio format types which are supported by sndfile: wav, flac, etc." + "\n\n" + " utterance_id_a a.wav\n" + " utterance_id_b b.wav\n" + " ...", + ), + "kaldi_ark": dict( + func=kaldi_loader, + kwargs=["max_cache_fd"], + help="Kaldi-ark file type." + "\n\n" + " utterance_id_A /some/where/a.ark:123\n" + " utterance_id_B /some/where/a.ark:456\n" + " ...", + ), + "npy": dict( + func=NpyScpReader, + kwargs=[], + help="Npy file format." + "\n\n" + " utterance_id_A /some/where/a.npy\n" + " utterance_id_B /some/where/b.npy\n" + " ...", + ), + "text_int": dict( + func=functools.partial(load_num_sequence_text, loader_type="text_int"), + kwargs=[], + help="A text file in which is written a sequence of integer numbers " + "separated by space." + "\n\n" + " utterance_id_A 12 0 1 3\n" + " utterance_id_B 3 3 1\n" + " ...", + ), + "csv_int": dict( + func=functools.partial(load_num_sequence_text, loader_type="csv_int"), + kwargs=[], + help="A text file in which is written a sequence of integer numbers " + "separated by comma."
+ "\n\n" + " utterance_id_A 100,80\n" + " utterance_id_B 143,80\n" + " ...", + ), + "text_float": dict( + func=functools.partial(load_num_sequence_text, loader_type="text_float"), + kwargs=[], + help="A text file in which is written a sequence of float numbers " + "separated by space." + "\n\n" + " utterance_id_A 12. 3.1 3.4 4.4\n" + " utterance_id_B 3. 3.12 1.1\n" + " ...", + ), + "csv_float": dict( + func=functools.partial(load_num_sequence_text, loader_type="csv_float"), + kwargs=[], + help="A text file in which is written a sequence of float numbers " + "separated by comma." + "\n\n" + " utterance_id_A 12.,3.1,3.4,4.4\n" + " utterance_id_B 3.,3.12,1.1\n" + " ...", + ), + "text": dict( + func=read_2column_text, + kwargs=[], + help="Return text as is. The text must be converted to ndarray " + "by 'preprocess'." + "\n\n" + " utterance_id_A hello world\n" + " utterance_id_B foo bar\n" + " ...", + ), + "hdf5": dict( + func=H5FileWrapper, + kwargs=[], + help="A HDF5 file which contains arrays at the first level or the second level." + " >>> f = h5py.File('file.h5')\n" + " >>> array1 = f['utterance_id_A']\n" + " >>> array2 = f['utterance_id_B']\n", + ), + "rand_float": dict( + func=FloatRandomGenerateDataset, + kwargs=[], + help="Generate random float-ndarray which has the given shapes " + "in the file." + "\n\n" + " utterance_id_A 3,4\n" + " utterance_id_B 10,4\n" + " ...", + ), + "rand_int_\\d+_\\d+": dict( + func=rand_int_loader, + kwargs=["loader_type"], + help="e.g. 'rand_int_0_10'. Generate random int-ndarray which has the given " + "shapes in the path. " + "Give the lower and upper value by the file type. e.g. " + "rand_int_0_10 -> Generate integers from 0 to 10." + "\n\n" + " utterance_id_A 3,4\n" + " utterance_id_B 10,4\n" + " ...", + ), + "rttm": dict( + func=RttmReader, + kwargs=[], + help="rttm file loader, currently support for speaker diarization" + "\n\n" + " SPEAKER file1 1 0 1023 spk1 " + " SPEAKER file1 2 4000 3023 spk2 " + " SPEAKER file1 3 500 4023 spk1 " + " END file1 4023 " + " ...", + ), +} + + +class AbsDataset(Dataset, ABC): + @abstractmethod + def has_name(self, name) -> bool: + raise NotImplementedError + + @abstractmethod + def names(self) -> Tuple[str, ...]: + raise NotImplementedError + + @abstractmethod + def __getitem__(self, uid) -> Tuple[Any, Dict[str, np.ndarray]]: + raise NotImplementedError + + +class ESPnetDataset(AbsDataset): + """Pytorch Dataset class for ESPNet. + + Examples: + >>> dataset = ESPnetDataset([('wav.scp', 'input', 'sound'), + ... ('token_int', 'output', 'text_int')], + ... ) + ... 
uttid, data = dataset['uttid'] + {'input': per_utt_array, 'output': per_utt_array} + """ + + def __init__( + self, + path_name_type_list: Collection[Tuple[str, str, str]], + preprocess: Callable[ + [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] + ] = None, + float_dtype: str = "float32", + int_dtype: str = "long", + max_cache_size: Union[float, int, str] = 0.0, + max_cache_fd: int = 0, + ): + assert check_argument_types() + if len(path_name_type_list) == 0: + raise ValueError( + '1 or more elements are required for "path_name_type_list"' + ) + + path_name_type_list = copy.deepcopy(path_name_type_list) + self.preprocess = preprocess + + self.float_dtype = float_dtype + self.int_dtype = int_dtype + self.max_cache_fd = max_cache_fd + + self.loader_dict = {} + self.debug_info = {} + for path, name, _type in path_name_type_list: + if name in self.loader_dict: + raise RuntimeError(f'"{name}" is duplicated for data-key') + + loader = self._build_loader(path, _type) + self.loader_dict[name] = loader + self.debug_info[name] = path, _type + if len(self.loader_dict[name]) == 0: + raise RuntimeError(f"{path} has no samples") + + # TODO(kamo): Should check consistency of each utt-keys? + + if isinstance(max_cache_size, str): + max_cache_size = humanfriendly.parse_size(max_cache_size) + self.max_cache_size = max_cache_size + if max_cache_size > 0: + self.cache = SizedDict(shared=True) + else: + self.cache = None + + def _build_loader( + self, path: str, loader_type: str + ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]: + """Helper function to instantiate Loader. + + Args: + path: The file path + loader_type: loader_type. sound, npy, text_int, text_float, etc + """ + for key, dic in DATA_TYPES.items(): + # e.g. loader_type="sound" + # -> return DATA_TYPES["sound"]["func"](path) + if re.match(key, loader_type): + kwargs = {} + for key2 in dic["kwargs"]: + if key2 == "loader_type": + kwargs["loader_type"] = loader_type + elif key2 == "float_dtype": + kwargs["float_dtype"] = self.float_dtype + elif key2 == "int_dtype": + kwargs["int_dtype"] = self.int_dtype + elif key2 == "max_cache_fd": + kwargs["max_cache_fd"] = self.max_cache_fd + else: + raise RuntimeError(f"Not implemented keyword argument: {key2}") + + func = dic["func"] + try: + return func(path, **kwargs) + except Exception: + if hasattr(func, "__name__"): + name = func.__name__ + else: + name = str(func) + logging.error(f"An error happend with {name}({path})") + raise + else: + raise RuntimeError(f"Not supported: loader_type={loader_type}") + + def has_name(self, name) -> bool: + return name in self.loader_dict + + def names(self) -> Tuple[str, ...]: + return tuple(self.loader_dict) + + def __iter__(self): + return iter(next(iter(self.loader_dict.values()))) + + def __repr__(self): + _mes = self.__class__.__name__ + _mes += "(" + for name, (path, _type) in self.debug_info.items(): + _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' + _mes += f"\n preprocess: {self.preprocess})" + return _mes + + def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]: + assert check_argument_types() + + # Change integer-id to string-id + if isinstance(uid, int): + d = next(iter(self.loader_dict.values())) + uid = list(d)[uid] + + if self.cache is not None and uid in self.cache: + data = self.cache[uid] + return uid, data + + data = {} + # 1. 
Load data from each loaders + for name, loader in self.loader_dict.items(): + try: + value = loader[uid] + if isinstance(value, (list, tuple)): + value = np.array(value) + if not isinstance( + value, (np.ndarray, torch.Tensor, str, numbers.Number) + ): + raise TypeError( + f"Must be ndarray, torch.Tensor, str or Number: {type(value)}" + ) + except Exception: + path, _type = self.debug_info[name] + logging.error( + f"Error happened with path={path}, type={_type}, id={uid}" + ) + raise + + # torch.Tensor is converted to ndarray + if isinstance(value, torch.Tensor): + value = value.numpy() + elif isinstance(value, numbers.Number): + value = np.array([value]) + data[name] = value + + # 2. [Option] Apply preprocessing + # e.g. espnet2.train.preprocessor:CommonPreprocessor + if self.preprocess is not None: + data = self.preprocess(uid, data) + + # 3. Force data-precision + for name in data: + value = data[name] + if not isinstance(value, np.ndarray): + raise RuntimeError( + f"All values must be converted to np.ndarray object " + f'by preprocessing, but "{name}" is still {type(value)}.' + ) + + # Cast to desired type + if value.dtype.kind == "f": + value = value.astype(self.float_dtype) + elif value.dtype.kind == "i": + value = value.astype(self.int_dtype) + else: + raise NotImplementedError(f"Not supported dtype: {value.dtype}") + data[name] = value + + if self.cache is not None and self.cache.size < self.max_cache_size: + self.cache[uid] = data + + retval = uid, data + assert check_return_type(retval) + return retval diff --git a/espnet2/train/distributed_utils.py b/espnet2/train/distributed_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f2c56c69f281951f211903acd67481f42a19ba9 --- /dev/null +++ b/espnet2/train/distributed_utils.py @@ -0,0 +1,370 @@ +import dataclasses +import os +import socket +from typing import Optional + +import torch +import torch.distributed + + +@dataclasses.dataclass +class DistributedOption: + # Enable distributed Training + distributed: bool = False + # torch.distributed.Backend: "nccl", "mpi", "gloo", or "tcp" + dist_backend: str = "nccl" + # if init_method="env://", + # env values of "MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", and "RANK" are referred. + dist_init_method: str = "env://" + dist_world_size: Optional[int] = None + dist_rank: Optional[int] = None + local_rank: Optional[int] = None + ngpu: int = 0 + dist_master_addr: Optional[str] = None + dist_master_port: Optional[int] = None + dist_launcher: Optional[str] = None + multiprocessing_distributed: bool = True + + def init_options(self): + if self.distributed: + if self.dist_init_method == "env://": + if get_master_addr(self.dist_master_addr, self.dist_launcher) is None: + raise RuntimeError( + "--dist_master_addr or MASTER_ADDR must be set " + "if --dist_init_method == 'env://'" + ) + if get_master_port(self.dist_master_port) is None: + raise RuntimeError( + "--dist_master_port or MASTER_PORT must be set " + "if --dist_init_port == 'env://'" + ) + + # About priority order: + # If --dist_* is specified: + # Use the value of --dist_rank and overwrite it environ just in case. 
+ # elif environ is set: + # Use the value of environ and set it to self + self.dist_rank = get_rank(self.dist_rank, self.dist_launcher) + self.dist_world_size = get_world_size( + self.dist_world_size, self.dist_launcher + ) + self.local_rank = get_local_rank(self.local_rank, self.dist_launcher) + + if self.local_rank is not None: + if self.ngpu > 1: + raise RuntimeError(f"Assuming 1GPU in this case: ngpu={self.ngpu}") + if "CUDA_VISIBLE_DEVICES" in os.environ: + cvd = os.environ["CUDA_VISIBLE_DEVICES"] + if self.local_rank >= len(cvd.split(",")): + raise RuntimeError( + f"LOCAL_RANK={self.local_rank} is bigger " + f"than the number of visible devices: {cvd}" + ) + + if ( + self.dist_rank is not None + and self.dist_world_size is not None + and self.dist_rank >= self.dist_world_size + ): + raise RuntimeError( + f"RANK >= WORLD_SIZE: {self.dist_rank} >= {self.dist_world_size}" + ) + + if self.dist_init_method == "env://": + self.dist_master_addr = get_master_addr( + self.dist_master_addr, self.dist_launcher + ) + self.dist_master_port = get_master_port(self.dist_master_port) + if ( + self.dist_master_addr is not None + and self.dist_master_port is not None + ): + self.dist_init_method = ( + f"tcp://{self.dist_master_addr}:{self.dist_master_port}" + ) + + def init_torch_distributed(self): + if self.distributed: + # See: + # https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/env.html + os.environ.setdefault("NCCL_DEBUG", "INFO") + + # See: + # https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group + os.environ.setdefault("NCCL_BLOCKING_WAIT", "1") + + torch.distributed.init_process_group( + backend=self.dist_backend, + init_method=self.dist_init_method, + world_size=self.dist_world_size, + rank=self.dist_rank, + ) + + # About distributed model: + # if self.local_rank is not None and ngpu == 1 + # => Distributed with n-Process and n-GPU + # if self.local_rank is None and ngpu >= 1 + # => Distributed with 1-Process and n-GPU + if self.local_rank is not None and self.ngpu > 0: + torch.cuda.set_device(self.local_rank) + + +def resolve_distributed_mode(args): + # Note that args.distributed is set by only this function. + # and ArgumentParser doesn't have such option + + if args.multiprocessing_distributed: + num_nodes = get_num_nodes(args.dist_world_size, args.dist_launcher) + # a. multi-node + if num_nodes > 1: + args.distributed = True + # b. single-node and multi-gpu with multiprocessing_distributed mode + elif args.ngpu > 1: + args.distributed = True + # c. single-node and single-gpu + else: + args.distributed = False + + if args.ngpu <= 1: + # Disable multiprocessing_distributed mode if 1process per node or cpu mode + args.multiprocessing_distributed = False + if args.ngpu == 1: + # If the number of GPUs equals to 1 with multiprocessing_distributed mode, + # LOCAL_RANK is always 0 + args.local_rank = 0 + + if num_nodes > 1 and get_node_rank(args.dist_rank, args.dist_launcher) is None: + raise RuntimeError( + "--dist_rank or RANK must be set " + "if --multiprocessing_distributed == true" + ) + + # Note that RANK, LOCAL_RANK, and WORLD_SIZE is automatically set, + # so we don't need to check here + else: + # d. multiprocess and multi-gpu with external launcher + # e.g. torch.distributed.launch + if get_world_size(args.dist_world_size, args.dist_launcher) > 1: + args.distributed = True + # e. 
single-process + else: + args.distributed = False + + if args.distributed and args.ngpu > 0: + if get_local_rank(args.local_rank, args.dist_launcher) is None: + raise RuntimeError( + "--local_rank or LOCAL_RANK must be set " + "if --multiprocessing_distributed == false" + ) + if args.distributed: + if get_node_rank(args.dist_rank, args.dist_launcher) is None: + raise RuntimeError( + "--dist_rank or RANK must be set " + "if --multiprocessing_distributed == false" + ) + if args.distributed and args.dist_launcher == "slurm" and not is_in_slurm_step(): + raise RuntimeError("Launch by 'srun' command if --dist_launcher='slurm'") + + +def is_in_slurm_job() -> bool: + return "SLURM_PROCID" in os.environ and "SLURM_NTASKS" in os.environ + + +def is_in_slurm_step() -> bool: + return ( + is_in_slurm_job() + and "SLURM_STEP_NUM_NODES" in os.environ + and "SLURM_STEP_NODELIST" in os.environ + ) + + +def _int_or_none(x: Optional[str]) -> Optional[int]: + if x is None: + return x + return int(x) + + +def free_port(): + """Find free port using bind(). + + There are some interval between finding this port and using it + and the other process might catch the port by that time. + Thus it is not guaranteed that the port is really empty. + + """ + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + +def get_rank(prior=None, launcher: str = None) -> Optional[int]: + if prior is None: + if launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + prior = os.environ["SLURM_PROCID"] + elif launcher == "mpi": + raise RuntimeError( + "launcher=mpi is used for 'multiprocessing-distributed' mode" + ) + elif launcher is not None: + raise RuntimeError(f"launcher='{launcher}' is not supported") + + if prior is not None: + return int(prior) + else: + # prior is None and RANK is None -> RANK = None + return _int_or_none(os.environ.get("RANK")) + + +def get_world_size(prior=None, launcher: str = None) -> int: + if prior is None: + if launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + prior = int(os.environ["SLURM_NTASKS"]) + elif launcher == "mpi": + raise RuntimeError( + "launcher=mpi is used for 'multiprocessing-distributed' mode" + ) + elif launcher is not None: + raise RuntimeError(f"launcher='{launcher}' is not supported") + + if prior is not None: + return int(prior) + else: + # prior is None and WORLD_SIZE is None -> WORLD_SIZE = 1 + return int(os.environ.get("WORLD_SIZE", "1")) + + +def get_local_rank(prior=None, launcher: str = None) -> Optional[int]: + # LOCAL_RANK is same as GPU device id + + if prior is None: + if launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + + prior = int(os.environ["SLURM_LOCALID"]) + elif launcher == "mpi": + raise RuntimeError( + "launcher=mpi is used for 'multiprocessing-distributed' mode" + ) + elif launcher is not None: + raise RuntimeError(f"launcher='{launcher}' is not supported") + + if prior is not None: + return int(prior) + + elif "LOCAL_RANK" in os.environ: + return int(os.environ["LOCAL_RANK"]) + + elif "CUDA_VISIBLE_DEVICES" in os.environ: + # There are two possibility: + # - "CUDA_VISIBLE_DEVICES" is set to multiple GPU ids. e.g. "0.1,2" + # => This intends to specify multiple devices to to be used exactly + # and local_rank information is possibly insufficient. 
+ # - "CUDA_VISIBLE_DEVICES" is set to an id. e.g. "1" + # => This could be used for LOCAL_RANK + cvd = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if len(cvd) == 1 and "LOCAL_RANK" not in os.environ: + # If CUDA_VISIBLE_DEVICES is set and LOCAL_RANK is not set, + # then use it as LOCAL_RANK. + + # Unset CUDA_VISIBLE_DEVICES + # because the other device must be visible to communicate + return int(os.environ.pop("CUDA_VISIBLE_DEVICES")) + else: + return None + else: + return None + + +def get_master_addr(prior=None, launcher: str = None) -> Optional[str]: + if prior is None: + if launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + + # e.g nodelist = foo[1-10],bar[3-8] or foo4,bar[2-10] + nodelist = os.environ["SLURM_STEP_NODELIST"] + prior = nodelist.split(",")[0].split("-")[0].replace("[", "") + + if prior is not None: + return str(prior) + else: + return os.environ.get("MASTER_ADDR") + + +def get_master_port(prior=None) -> Optional[int]: + if prior is not None: + return prior + else: + return _int_or_none(os.environ.get("MASTER_PORT")) + + +def get_node_rank(prior=None, launcher: str = None) -> Optional[int]: + """Get Node Rank. + + Use for "multiprocessing distributed" mode. + The initial RANK equals to the Node id in this case and + the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed. + + """ + if prior is not None: + return prior + elif launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + + # Assume ntasks_per_node == 1 + if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]: + raise RuntimeError( + "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true" + ) + return int(os.environ["SLURM_NODEID"]) + elif launcher == "mpi": + # Use mpi4py only for initialization and not using for communication + from mpi4py import MPI + + comm = MPI.COMM_WORLD + # Assume ntasks_per_node == 1 (We can't check whether it is or not) + return comm.Get_rank() + elif launcher is not None: + raise RuntimeError(f"launcher='{launcher}' is not supported") + else: + return _int_or_none(os.environ.get("RANK")) + + +def get_num_nodes(prior=None, launcher: str = None) -> Optional[int]: + """Get the number of nodes. + + Use for "multiprocessing distributed" mode. + RANK equals to the Node id in this case and + the real Rank is set as (nGPU * NodeID) + LOCAL_RANK in torch.distributed. 
+ + """ + if prior is not None: + return prior + elif launcher == "slurm": + if not is_in_slurm_step(): + raise RuntimeError("This process seems not to be launched by 'srun'") + + # Assume ntasks_per_node == 1 + if os.environ["SLURM_STEP_NUM_NODES"] != os.environ["SLURM_NTASKS"]: + raise RuntimeError( + "Run with --ntasks_per_node=1 if mutliprocessing_distributed=true" + ) + return int(os.environ["SLURM_STEP_NUM_NODES"]) + elif launcher == "mpi": + # Use mpi4py only for initialization and not using for communication + from mpi4py import MPI + + comm = MPI.COMM_WORLD + # Assume ntasks_per_node == 1 (We can't check whether it is or not) + return comm.Get_size() + elif launcher is not None: + raise RuntimeError(f"launcher='{launcher}' is not supported") + else: + # prior is None -> NUM_NODES = 1 + return int(os.environ.get("WORLD_SIZE", 1)) diff --git a/espnet2/train/iterable_dataset.py b/espnet2/train/iterable_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cf1ccd33cc5d6574f47658eaa7024e548da1e114 --- /dev/null +++ b/espnet2/train/iterable_dataset.py @@ -0,0 +1,241 @@ +import copy +from distutils.version import LooseVersion +from io import StringIO +from pathlib import Path +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Tuple +from typing import Union + +import kaldiio +import numpy as np +import soundfile +import torch +from typeguard import check_argument_types + +from espnet2.train.dataset import ESPnetDataset + +if LooseVersion(torch.__version__) >= LooseVersion("1.2"): + from torch.utils.data.dataset import IterableDataset +else: + from torch.utils.data.dataset import Dataset as IterableDataset + + +def load_kaldi(input): + retval = kaldiio.load_mat(input) + if isinstance(retval, tuple): + assert len(retval) == 2, len(retval) + if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray): + # sound scp case + rate, array = retval + elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray): + # Extended ark format case + array, rate = retval + else: + raise RuntimeError(f"Unexpected type: {type(retval[0])}, {type(retval[1])}") + + # Multichannel wave fie + # array: (NSample, Channel) or (Nsample) + + else: + # Normal ark case + assert isinstance(retval, np.ndarray), type(retval) + array = retval + return array + + +DATA_TYPES = { + "sound": lambda x: soundfile.read(x)[0], + "kaldi_ark": load_kaldi, + "npy": np.load, + "text_int": lambda x: np.loadtxt( + StringIO(x), ndmin=1, dtype=np.long, delimiter=" " + ), + "csv_int": lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=","), + "text_float": lambda x: np.loadtxt( + StringIO(x), ndmin=1, dtype=np.float32, delimiter=" " + ), + "csv_float": lambda x: np.loadtxt( + StringIO(x), ndmin=1, dtype=np.float32, delimiter="," + ), + "text": lambda x: x, +} + + +class IterableESPnetDataset(IterableDataset): + """Pytorch Dataset class for ESPNet. + + Examples: + >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'), + ... ('token_int', 'output', 'text_int')], + ... ) + >>> for uid, data in dataset: + ... 
data + {'input': per_utt_array, 'output': per_utt_array} + """ + + def __init__( + self, + path_name_type_list: Collection[Tuple[str, str, str]], + preprocess: Callable[ + [str, Dict[str, np.ndarray]], Dict[str, np.ndarray] + ] = None, + float_dtype: str = "float32", + int_dtype: str = "long", + key_file: str = None, + ): + assert check_argument_types() + if len(path_name_type_list) == 0: + raise ValueError( + '1 or more elements are required for "path_name_type_list"' + ) + + path_name_type_list = copy.deepcopy(path_name_type_list) + self.preprocess = preprocess + + self.float_dtype = float_dtype + self.int_dtype = int_dtype + self.key_file = key_file + + self.debug_info = {} + non_iterable_list = [] + self.path_name_type_list = [] + + for path, name, _type in path_name_type_list: + if name in self.debug_info: + raise RuntimeError(f'"{name}" is duplicated for data-key') + self.debug_info[name] = path, _type + if _type not in DATA_TYPES: + non_iterable_list.append((path, name, _type)) + else: + self.path_name_type_list.append((path, name, _type)) + + if len(non_iterable_list) != 0: + # Some types doesn't support iterable mode + self.non_iterable_dataset = ESPnetDataset( + path_name_type_list=non_iterable_list, + preprocess=preprocess, + float_dtype=float_dtype, + int_dtype=int_dtype, + ) + else: + self.non_iterable_dataset = None + + if Path(Path(path_name_type_list[0][0]).parent, "utt2category").exists(): + self.apply_utt2category = True + else: + self.apply_utt2category = False + + def has_name(self, name) -> bool: + return name in self.debug_info + + def names(self) -> Tuple[str, ...]: + return tuple(self.debug_info) + + def __repr__(self): + _mes = self.__class__.__name__ + _mes += "(" + for name, (path, _type) in self.debug_info.items(): + _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}' + _mes += f"\n preprocess: {self.preprocess})" + return _mes + + def __iter__(self) -> Iterable[Tuple[Union[str, int], Dict[str, np.ndarray]]]: + if self.key_file is not None: + uid_iter = ( + line.rstrip().split(maxsplit=1)[0] + for line in open(self.key_file, encoding="utf-8") + ) + elif len(self.path_name_type_list) != 0: + uid_iter = ( + line.rstrip().split(maxsplit=1)[0] + for line in open(self.path_name_type_list[0][0], encoding="utf-8") + ) + else: + uid_iter = iter(self.non_iterable_dataset) + + files = [open(lis[0], encoding="utf-8") for lis in self.path_name_type_list] + + worker_info = torch.utils.data.get_worker_info() + + linenum = 0 + count = 0 + for count, uid in enumerate(uid_iter, 1): + # If num_workers>=1, split keys + if worker_info is not None: + if (count - 1) % worker_info.num_workers != worker_info.id: + continue + + # 1. Read a line from each file + while True: + keys = [] + values = [] + for f in files: + linenum += 1 + try: + line = next(f) + except StopIteration: + raise RuntimeError(f"{uid} is not found in the files") + sps = line.rstrip().split(maxsplit=1) + if len(sps) != 2: + raise RuntimeError( + f"This line doesn't include a space:" + f" {f}:L{linenum}: {line})" + ) + key, value = sps + keys.append(key) + values.append(value) + + for k_idx, k in enumerate(keys): + if k != keys[0]: + raise RuntimeError( + f"Keys are mismatched. Text files (idx={k_idx}) is " + f"not sorted or not having same keys at L{linenum}" + ) + + # If the key is matched, break the loop + if len(keys) == 0 or keys[0] == uid: + break + + # 2. Load the entry from each line and create a dict + data = {} + # 2.a. 
Load data streamingly + for value, (path, name, _type) in zip(values, self.path_name_type_list): + func = DATA_TYPES[_type] + # Load entry + array = func(value) + data[name] = array + if self.non_iterable_dataset is not None: + # 2.b. Load data from non-iterable dataset + _, from_non_iterable = self.non_iterable_dataset[uid] + data.update(from_non_iterable) + + # 3. [Option] Apply preprocessing + # e.g. espnet2.train.preprocessor:CommonPreprocessor + if self.preprocess is not None: + data = self.preprocess(uid, data) + + # 4. Force data-precision + for name in data: + value = data[name] + if not isinstance(value, np.ndarray): + raise RuntimeError( + f"All values must be converted to np.ndarray object " + f'by preprocessing, but "{name}" is still {type(value)}.' + ) + + # Cast to desired type + if value.dtype.kind == "f": + value = value.astype(self.float_dtype) + elif value.dtype.kind == "i": + value = value.astype(self.int_dtype) + else: + raise NotImplementedError(f"Not supported dtype: {value.dtype}") + data[name] = value + + yield uid, data + + if count == 0: + raise RuntimeError("No iteration") diff --git a/espnet2/train/preprocessor.py b/espnet2/train/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..0197990e64d222fec6996e48372078503dc8384c --- /dev/null +++ b/espnet2/train/preprocessor.py @@ -0,0 +1,377 @@ +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Collection +from typing import Dict +from typing import Iterable +from typing import Union + +import numpy as np +import scipy.signal +import soundfile +from typeguard import check_argument_types +from typeguard import check_return_type + +from espnet2.text.build_tokenizer import build_tokenizer +from espnet2.text.cleaner import TextCleaner +from espnet2.text.token_id_converter import TokenIDConverter + + +class AbsPreprocessor(ABC): + def __init__(self, train: bool): + self.train = train + + @abstractmethod + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + raise NotImplementedError + + +def framing( + x, + frame_length: int = 512, + frame_shift: int = 256, + centered: bool = True, + padded: bool = True, +): + if x.size == 0: + raise ValueError("Input array size is zero") + if frame_length < 1: + raise ValueError("frame_length must be a positive integer") + if frame_length > x.shape[-1]: + raise ValueError("frame_length is greater than input length") + if 0 >= frame_shift: + raise ValueError("frame_shift must be greater than 0") + + if centered: + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [ + (frame_length // 2, frame_length // 2) + ] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + if padded: + # Pad to integer number of windowed segments + # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep, + # with integer nseg + nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + # Created strided array of data segments + if frame_length == 1 and frame_length == frame_shift: + result = x[..., None] + else: + shape = x.shape[:-1] + ( + (x.shape[-1] - frame_length) // frame_shift + 1, + frame_length, + ) + strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1]) + result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return result + + +def detect_non_silence( + x: np.ndarray, + threshold: float = 
0.01, + frame_length: int = 1024, + frame_shift: int = 512, + window: str = "boxcar", +) -> np.ndarray: + """Power based voice activity detection. + + Args: + x: (Channel, Time) + >>> x = np.random.randn(1000) + >>> detect = detect_non_silence(x) + >>> assert x.shape == detect.shape + >>> assert detect.dtype == np.bool + """ + if x.shape[-1] < frame_length: + return np.full(x.shape, fill_value=True, dtype=np.bool) + + if x.dtype.kind == "i": + x = x.astype(np.float64) + # framed_w: (C, T, F) + framed_w = framing( + x, + frame_length=frame_length, + frame_shift=frame_shift, + centered=False, + padded=True, + ) + framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype) + # power: (C, T) + power = (framed_w ** 2).mean(axis=-1) + # mean_power: (C,) + mean_power = power.mean(axis=-1) + if np.all(mean_power == 0): + return np.full(x.shape, fill_value=True, dtype=np.bool) + # detect_frames: (C, T) + detect_frames = power / mean_power > threshold + # detects: (C, T, F) + detects = np.broadcast_to( + detect_frames[..., None], detect_frames.shape + (frame_shift,) + ) + # detects: (C, TF) + detects = detects.reshape(*detect_frames.shape[:-1], -1) + # detects: (C, TF) + return np.pad( + detects, + [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])], + mode="edge", + ) + + +class CommonPreprocessor(AbsPreprocessor): + def __init__( + self, + train: bool, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "<unk>", + space_symbol: str = "<space>", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: str = "text", + ): + super().__init__(train) + self.train = train + self.speech_name = speech_name + self.text_name = text_name + self.speech_volume_normalize = speech_volume_normalize + self.rir_apply_prob = rir_apply_prob + self.noise_apply_prob = noise_apply_prob + + if token_type is not None: + if token_list is None: + raise ValueError("token_list is required if token_type is not None") + self.text_cleaner = TextCleaner(text_cleaner) + + self.tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, + ) + self.token_id_converter = TokenIDConverter( + token_list=token_list, + unk_symbol=unk_symbol, + ) + else: + self.text_cleaner = None + self.tokenizer = None + self.token_id_converter = None + + if train and rir_scp is not None: + self.rirs = [] + with open(rir_scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.rirs.append(sps[0]) + else: + self.rirs.append(sps[1]) + else: + self.rirs = None + + if train and noise_scp is not None: + self.noises = [] + with open(noise_scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.noises.append(sps[0]) + else: + self.noises.append(sps[1]) + sps = noise_db_range.split("_") + if len(sps) == 1: + self.noise_db_low = self.noise_db_high = float(sps[0]) + elif len(sps) == 2: + self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1]) + else: + raise 
ValueError( + "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]" + ) + else: + self.noises = None + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + assert check_argument_types() + + if self.speech_name in data: + if self.train and self.rirs is not None and self.noises is not None: + speech = data[self.speech_name] + nsamples = len(speech) + + # speech: (Nmic, Time) + if speech.ndim == 1: + speech = speech[None, :] + else: + speech = speech.T + # Calc power on non shlence region + power = (speech[detect_non_silence(speech)] ** 2).mean() + + # 1. Convolve RIR + if self.rirs is not None and self.rir_apply_prob >= np.random.random(): + rir_path = np.random.choice(self.rirs) + if rir_path is not None: + rir, _ = soundfile.read( + rir_path, dtype=np.float64, always_2d=True + ) + + # rir: (Nmic, Time) + rir = rir.T + + # speech: (Nmic, Time) + # Note that this operation doesn't change the signal length + speech = scipy.signal.convolve(speech, rir, mode="full")[ + :, : speech.shape[1] + ] + # Reverse mean power to the original power + power2 = (speech[detect_non_silence(speech)] ** 2).mean() + speech = np.sqrt(power / max(power2, 1e-10)) * speech + + # 2. Add Noise + if ( + self.noises is not None + and self.rir_apply_prob >= np.random.random() + ): + noise_path = np.random.choice(self.noises) + if noise_path is not None: + noise_db = np.random.uniform( + self.noise_db_low, self.noise_db_high + ) + with soundfile.SoundFile(noise_path) as f: + if f.frames == nsamples: + noise = f.read(dtype=np.float64, always_2d=True) + elif f.frames < nsamples: + offset = np.random.randint(0, nsamples - f.frames) + # noise: (Time, Nmic) + noise = f.read(dtype=np.float64, always_2d=True) + # Repeat noise + noise = np.pad( + noise, + [(offset, nsamples - f.frames - offset), (0, 0)], + mode="wrap", + ) + else: + offset = np.random.randint(0, f.frames - nsamples) + f.seek(offset) + # noise: (Time, Nmic) + noise = f.read( + nsamples, dtype=np.float64, always_2d=True + ) + if len(noise) != nsamples: + raise RuntimeError(f"Something wrong: {noise_path}") + # noise: (Nmic, Time) + noise = noise.T + + noise_power = (noise ** 2).mean() + scale = ( + 10 ** (-noise_db / 20) + * np.sqrt(power) + / np.sqrt(max(noise_power, 1e-10)) + ) + speech = speech + scale * noise + + speech = speech.T + ma = np.max(np.abs(speech)) + if ma > 1.0: + speech /= ma + data[self.speech_name] = speech + + if self.speech_volume_normalize is not None: + speech = data[self.speech_name] + ma = np.max(np.abs(speech)) + data[self.speech_name] = speech * self.speech_volume_normalize / ma + + if self.text_name in data and self.tokenizer is not None: + text = data[self.text_name] + text = self.text_cleaner(text) + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + data[self.text_name] = np.array(text_ints, dtype=np.int64) + assert check_return_type(data) + return data + + +class CommonPreprocessor_multi(AbsPreprocessor): + def __init__( + self, + train: bool, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + speech_name: str = "speech", + text_name: list = ["text"], + ): + super().__init__(train) + self.train = train + self.speech_name = speech_name + 
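# Worked example (plain NumPy, illustration only) of the noise-scaling rule used in
# "2. Add Noise" above: scale = 10**(-snr_db/20) * sqrt(P_speech) / sqrt(P_noise) places the
# scaled noise snr_db below the speech power. In the preprocessor, P_speech is measured on
# non-silent samples only.
import numpy as np

rng = np.random.default_rng(0)
speech = rng.standard_normal(16000)
noise = 0.3 * rng.standard_normal(16000)
snr_db = 10.0                                   # e.g. a value drawn from noise_db_range="3_10"

speech_power = (speech ** 2).mean()
noise_power = (noise ** 2).mean()
scale = 10 ** (-snr_db / 20) * np.sqrt(speech_power) / np.sqrt(max(noise_power, 1e-10))
noisy = speech + scale * noise

achieved = 10 * np.log10(speech_power / ((scale * noise) ** 2).mean())
print(round(achieved, 3))                       # 10.0: the mixture sits at the requested SNR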
self.text_name = text_name + + if token_type is not None: + if token_list is None: + raise ValueError("token_list is required if token_type is not None") + self.text_cleaner = TextCleaner(text_cleaner) + + self.tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, + ) + self.token_id_converter = TokenIDConverter( + token_list=token_list, + unk_symbol=unk_symbol, + ) + else: + self.text_cleaner = None + self.tokenizer = None + self.token_id_converter = None + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + assert check_argument_types() + + if self.speech_name in data: + # Nothing now: candidates: + # - STFT + # - Fbank + # - CMVN + # - Data augmentation + pass + + for text_n in self.text_name: + if text_n in data and self.tokenizer is not None: + text = data[text_n] + text = self.text_cleaner(text) + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + data[text_n] = np.array(text_ints, dtype=np.int64) + assert check_return_type(data) + return data diff --git a/espnet2/train/reporter.py b/espnet2/train/reporter.py new file mode 100644 index 0000000000000000000000000000000000000000..61218ec5871bf97e3cdf614df8becf7f829eb220 --- /dev/null +++ b/espnet2/train/reporter.py @@ -0,0 +1,571 @@ +from collections import defaultdict +from contextlib import contextmanager +import dataclasses +import datetime +from distutils.version import LooseVersion +import logging +from pathlib import Path +import time +from typing import ContextManager +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +import warnings + +import humanfriendly +import numpy as np +import torch +from typeguard import check_argument_types +from typeguard import check_return_type +import wandb + +if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): + from torch.utils.tensorboard import SummaryWriter +else: + from tensorboardX import SummaryWriter + +Num = Union[float, int, complex, torch.Tensor, np.ndarray] + + +_reserved = {"time", "total_count"} + + +def to_reported_value(v: Num, weight: Num = None) -> "ReportedValue": + assert check_argument_types() + if isinstance(v, (torch.Tensor, np.ndarray)): + if np.prod(v.shape) != 1: + raise ValueError(f"v must be 0 or 1 dimension: {len(v.shape)}") + v = v.item() + + if isinstance(weight, (torch.Tensor, np.ndarray)): + if np.prod(weight.shape) != 1: + raise ValueError(f"weight must be 0 or 1 dimension: {len(weight.shape)}") + weight = weight.item() + + if weight is not None: + retval = WeightedAverage(v, weight) + else: + retval = Average(v) + assert check_return_type(retval) + return retval + + +def aggregate(values: Sequence["ReportedValue"]) -> Num: + assert check_argument_types() + + for v in values: + if not isinstance(v, type(values[0])): + raise ValueError( + f"Can't use different Reported type together: " + f"{type(v)} != {type(values[0])}" + ) + + if len(values) == 0: + warnings.warn("No stats found") + retval = np.nan + + elif isinstance(values[0], Average): + retval = np.nanmean([v.value for v in values]) + + elif isinstance(values[0], WeightedAverage): + # Excludes non finite values + invalid_indices = set() + for i, v in enumerate(values): + if not np.isfinite(v.value) or not np.isfinite(v.weight): + 
invalid_indices.add(i) + values = [v for i, v in enumerate(values) if i not in invalid_indices] + + if len(values) != 0: + # Calc weighed average. Weights are changed to sum-to-1. + sum_weights = sum(v.weight for i, v in enumerate(values)) + sum_value = sum(v.value * v.weight for i, v in enumerate(values)) + if sum_weights == 0: + warnings.warn("weight is zero") + retval = np.nan + else: + retval = sum_value / sum_weights + else: + warnings.warn("No valid stats found") + retval = np.nan + + else: + raise NotImplementedError(f"type={type(values[0])}") + assert check_return_type(retval) + return retval + + +class ReportedValue: + pass + + +@dataclasses.dataclass(frozen=True) +class Average(ReportedValue): + value: Num + + +@dataclasses.dataclass(frozen=True) +class WeightedAverage(ReportedValue): + value: Tuple[Num, Num] + weight: Num + + +class SubReporter: + """This class is used in Reporter. + + See the docstring of Reporter for the usage. + """ + + def __init__(self, key: str, epoch: int, total_count: int): + assert check_argument_types() + self.key = key + self.epoch = epoch + self.start_time = time.perf_counter() + self.stats = defaultdict(list) + self._finished = False + self.total_count = total_count + self.count = 0 + self._seen_keys_in_the_step = set() + + def get_total_count(self) -> int: + """Returns the number of iterations over all epochs.""" + return self.total_count + + def get_epoch(self) -> int: + return self.epoch + + def next(self): + """Close up this step and reset state for the next step""" + for key, stats_list in self.stats.items(): + if key not in self._seen_keys_in_the_step: + # Fill nan value if the key is not registered in this step + if isinstance(stats_list[0], WeightedAverage): + stats_list.append(to_reported_value(np.nan, 0)) + elif isinstance(stats_list[0], Average): + stats_list.append(to_reported_value(np.nan)) + else: + raise NotImplementedError(f"type={type(stats_list[0])}") + + assert len(stats_list) == self.count, (len(stats_list), self.count) + + self._seen_keys_in_the_step = set() + + def register( + self, + stats: Dict[str, Optional[Union[Num, Dict[str, Num]]]], + weight: Num = None, + ) -> None: + assert check_argument_types() + if self._finished: + raise RuntimeError("Already finished") + if len(self._seen_keys_in_the_step) == 0: + # Increment count as the first register in this step + self.total_count += 1 + self.count += 1 + + for key2, v in stats.items(): + if key2 in _reserved: + raise RuntimeError(f"{key2} is reserved.") + if key2 in self._seen_keys_in_the_step: + raise RuntimeError(f"{key2} is registered twice.") + if v is None: + v = np.nan + r = to_reported_value(v, weight) + + if key2 not in self.stats: + # If it's the first time to register the key, + # append nan values in front of the the value + # to make it same length to the other stats + # e.g. 
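# Small illustration (not part of the patched module) of how weighted stats are aggregated
# above: non-finite entries are dropped, then a weight-normalized mean is taken. Assumes the
# reporter module's dependencies (torch, typeguard, humanfriendly, wandb) are installed.
import numpy as np

from espnet2.train.reporter import aggregate, to_reported_value

values = [
    to_reported_value(0.5, weight=2.0),
    to_reported_value(0.2, weight=1.0),
    to_reported_value(np.nan, weight=1.0),   # excluded: non-finite value
]
print(aggregate(values))                     # (0.5*2 + 0.2*1) / (2 + 1) ≈ 0.4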
+ # stat A: [0.4, 0.3, 0.5] + # stat B: [nan, nan, 0.2] + nan = to_reported_value(np.nan, None if weight is None else 0) + self.stats[key2].extend( + r if i == self.count - 1 else nan for i in range(self.count) + ) + else: + self.stats[key2].append(r) + self._seen_keys_in_the_step.add(key2) + + def log_message(self, start: int = None, end: int = None) -> str: + if self._finished: + raise RuntimeError("Already finished") + if start is None: + start = 0 + if start < 0: + start = self.count + start + if end is None: + end = self.count + + if self.count == 0 or start == end: + return "" + + message = f"{self.epoch}epoch:{self.key}:" f"{start + 1}-{end}batch: " + + for idx, (key2, stats_list) in enumerate(self.stats.items()): + assert len(stats_list) == self.count, (len(stats_list), self.count) + # values: List[ReportValue] + values = stats_list[start:end] + if idx != 0 and idx != len(stats_list): + message += ", " + + v = aggregate(values) + if abs(v) > 1.0e3: + message += f"{key2}={v:.3e}" + elif abs(v) > 1.0e-3: + message += f"{key2}={v:.3f}" + else: + message += f"{key2}={v:.3e}" + return message + + def tensorboard_add_scalar(self, summary_writer: SummaryWriter, start: int = None): + if start is None: + start = 0 + if start < 0: + start = self.count + start + + for key2, stats_list in self.stats.items(): + assert len(stats_list) == self.count, (len(stats_list), self.count) + # values: List[ReportValue] + values = stats_list[start:] + v = aggregate(values) + summary_writer.add_scalar(key2, v, self.total_count) + + def wandb_log(self, start: int = None, commit: bool = True): + if start is None: + start = 0 + if start < 0: + start = self.count + start + + d = {} + for key2, stats_list in self.stats.items(): + assert len(stats_list) == self.count, (len(stats_list), self.count) + # values: List[ReportValue] + values = stats_list[start:] + v = aggregate(values) + d[key2] = v + d["iteration"] = self.total_count + wandb.log(d, commit=commit) + + def finished(self) -> None: + self._finished = True + + @contextmanager + def measure_time(self, name: str): + start = time.perf_counter() + yield start + t = time.perf_counter() - start + self.register({name: t}) + + def measure_iter_time(self, iterable, name: str): + iterator = iter(iterable) + while True: + try: + start = time.perf_counter() + retval = next(iterator) + t = time.perf_counter() - start + self.register({name: t}) + yield retval + except StopIteration: + break + + +class Reporter: + """Reporter class. + + Examples: + + >>> reporter = Reporter() + >>> with reporter.observe('train') as sub_reporter: + ... for batch in iterator: + ... stats = dict(loss=0.2) + ... sub_reporter.register(stats) + + """ + + def __init__(self, epoch: int = 0): + assert check_argument_types() + if epoch < 0: + raise ValueError(f"epoch must be 0 or more: {epoch}") + self.epoch = epoch + # stats: Dict[int, Dict[str, Dict[str, float]]] + # e.g. 
self.stats[epoch]['train']['loss'] + self.stats = {} + + def get_epoch(self) -> int: + return self.epoch + + def set_epoch(self, epoch: int) -> None: + if epoch < 0: + raise ValueError(f"epoch must be 0 or more: {epoch}") + self.epoch = epoch + + @contextmanager + def observe(self, key: str, epoch: int = None) -> ContextManager[SubReporter]: + sub_reporter = self.start_epoch(key, epoch) + yield sub_reporter + # Receive the stats from sub_reporter + self.finish_epoch(sub_reporter) + + def start_epoch(self, key: str, epoch: int = None) -> SubReporter: + if epoch is not None: + if epoch < 0: + raise ValueError(f"epoch must be 0 or more: {epoch}") + self.epoch = epoch + + if self.epoch - 1 not in self.stats or key not in self.stats[self.epoch - 1]: + # If the previous epoch doesn't exist for some reason, + # maybe due to bug, this case also indicates 0-count. + if self.epoch - 1 != 0: + warnings.warn( + f"The stats of the previous epoch={self.epoch - 1}" + f"doesn't exist." + ) + total_count = 0 + else: + total_count = self.stats[self.epoch - 1][key]["total_count"] + + sub_reporter = SubReporter(key, self.epoch, total_count) + # Clear the stats for the next epoch if it exists + self.stats.pop(epoch, None) + return sub_reporter + + def finish_epoch(self, sub_reporter: SubReporter) -> None: + if self.epoch != sub_reporter.epoch: + raise RuntimeError( + f"Don't change epoch during observation: " + f"{self.epoch} != {sub_reporter.epoch}" + ) + + # Calc mean of current stats and set it as previous epochs stats + stats = {} + for key2, values in sub_reporter.stats.items(): + v = aggregate(values) + stats[key2] = v + + stats["time"] = datetime.timedelta( + seconds=time.perf_counter() - sub_reporter.start_time + ) + stats["total_count"] = sub_reporter.total_count + if LooseVersion(torch.__version__) >= LooseVersion("1.4.0"): + if torch.cuda.is_initialized(): + stats["gpu_max_cached_mem_GB"] = ( + torch.cuda.max_memory_reserved() / 2 ** 30 + ) + else: + if torch.cuda.is_available() and torch.cuda.max_memory_cached() > 0: + stats["gpu_cached_mem_GB"] = torch.cuda.max_memory_cached() / 2 ** 30 + + self.stats.setdefault(self.epoch, {})[sub_reporter.key] = stats + sub_reporter.finished() + + def sort_epochs_and_values( + self, key: str, key2: str, mode: str + ) -> List[Tuple[int, float]]: + """Return the epoch which resulted the best value. 
+ + Example: + >>> val = reporter.sort_epochs_and_values('eval', 'loss', 'min') + >>> e_1best, v_1best = val[0] + >>> e_2best, v_2best = val[1] + """ + if mode not in ("min", "max"): + raise ValueError(f"mode must min or max: {mode}") + if not self.has(key, key2): + raise KeyError(f"{key}.{key2} is not found: {self.get_all_keys()}") + + # iterate from the last epoch + values = [(e, self.stats[e][key][key2]) for e in self.stats] + + if mode == "min": + values = sorted(values, key=lambda x: x[1]) + else: + values = sorted(values, key=lambda x: -x[1]) + return values + + def sort_epochs(self, key: str, key2: str, mode: str) -> List[int]: + return [e for e, v in self.sort_epochs_and_values(key, key2, mode)] + + def sort_values(self, key: str, key2: str, mode: str) -> List[float]: + return [v for e, v in self.sort_epochs_and_values(key, key2, mode)] + + def get_best_epoch(self, key: str, key2: str, mode: str, nbest: int = 0) -> int: + return self.sort_epochs(key, key2, mode)[nbest] + + def check_early_stopping( + self, + patience: int, + key1: str, + key2: str, + mode: str, + epoch: int = None, + logger=None, + ) -> bool: + if logger is None: + logger = logging + if epoch is None: + epoch = self.get_epoch() + + best_epoch = self.get_best_epoch(key1, key2, mode) + if epoch - best_epoch > patience: + logger.info( + f"[Early stopping] {key1}.{key2} has not been " + f"improved {epoch - best_epoch} epochs continuously. " + f"The training was stopped at {epoch}epoch" + ) + return True + else: + return False + + def has(self, key: str, key2: str, epoch: int = None) -> bool: + if epoch is None: + epoch = self.get_epoch() + return ( + epoch in self.stats + and key in self.stats[epoch] + and key2 in self.stats[epoch][key] + ) + + def log_message(self, epoch: int = None) -> str: + if epoch is None: + epoch = self.get_epoch() + + message = "" + for key, d in self.stats[epoch].items(): + _message = "" + for key2, v in d.items(): + if v is not None: + if len(_message) != 0: + _message += ", " + if isinstance(v, float): + if abs(v) > 1.0e3: + _message += f"{key2}={v:.3e}" + elif abs(v) > 1.0e-3: + _message += f"{key2}={v:.3f}" + else: + _message += f"{key2}={v:.3e}" + elif isinstance(v, datetime.timedelta): + _v = humanfriendly.format_timespan(v) + _message += f"{key2}={_v}" + else: + _message += f"{key2}={v}" + if len(_message) != 0: + if len(message) == 0: + message += f"{epoch}epoch results: " + else: + message += ", " + message += f"[{key}] {_message}" + return message + + def get_value(self, key: str, key2: str, epoch: int = None): + if not self.has(key, key2): + raise KeyError(f"{key}.{key2} is not found in stats: {self.get_all_keys()}") + if epoch is None: + epoch = self.get_epoch() + return self.stats[epoch][key][key2] + + def get_keys(self, epoch: int = None) -> Tuple[str, ...]: + """Returns keys1 e.g. train,eval.""" + if epoch is None: + epoch = self.get_epoch() + return tuple(self.stats[epoch]) + + def get_keys2(self, key: str, epoch: int = None) -> Tuple[str, ...]: + """Returns keys2 e.g. 
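# Minimal end-to-end sketch (illustration only) of the Reporter/SubReporter API defined
# above; the loss values are toy numbers and the module needs wandb/tensorboard importable.
from espnet2.train.reporter import Reporter

reporter = Reporter()
for epoch in range(1, 4):
    reporter.set_epoch(epoch)
    with reporter.observe("valid") as sub:
        for step in range(3):                       # pretend three mini-batches per epoch
            sub.register({"loss": 1.0 / (epoch + step)})
            sub.next()

print(reporter.log_message())                       # "3epoch results: [valid] loss=..."
print(reporter.get_best_epoch("valid", "loss", "min"))      # 3: lowest mean loss so far
print(reporter.check_early_stopping(2, "valid", "loss", "min"))  # False: still improving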
loss,acc.""" + if epoch is None: + epoch = self.get_epoch() + d = self.stats[epoch][key] + keys2 = tuple(k for k in d if k not in ("time", "total_count")) + return keys2 + + def get_all_keys(self, epoch: int = None) -> Tuple[Tuple[str, str], ...]: + if epoch is None: + epoch = self.get_epoch() + all_keys = [] + for key in self.stats[epoch]: + for key2 in self.stats[epoch][key]: + all_keys.append((key, key2)) + return tuple(all_keys) + + def matplotlib_plot(self, output_dir: Union[str, Path]): + """Plot stats using Matplotlib and save images.""" + keys2 = set.union(*[set(self.get_keys2(k)) for k in self.get_keys()]) + for key2 in keys2: + keys = [k for k in self.get_keys() if key2 in self.get_keys2(k)] + plt = self._plot_stats(keys, key2) + p = output_dir / f"{key2}.png" + p.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(p) + + def _plot_stats(self, keys: Sequence[str], key2: str): + assert check_argument_types() + # str is also Sequence[str] + if isinstance(keys, str): + raise TypeError(f"Input as [{keys}]") + + import matplotlib + + matplotlib.use("agg") + import matplotlib.pyplot as plt + import matplotlib.ticker as ticker + + plt.clf() + + epochs = np.arange(1, self.get_epoch() + 1) + for key in keys: + y = [ + self.stats[e][key][key2] + if e in self.stats + and key in self.stats[e] + and key2 in self.stats[e][key] + else np.nan + for e in epochs + ] + assert len(epochs) == len(y), "Bug?" + + plt.plot(epochs, y, label=key, marker="x") + plt.legend() + plt.title(f"epoch vs {key2}") + # Force integer tick for x-axis + plt.gca().get_xaxis().set_major_locator(ticker.MaxNLocator(integer=True)) + plt.xlabel("epoch") + plt.ylabel(key2) + plt.grid() + + return plt + + def tensorboard_add_scalar(self, summary_writer: SummaryWriter, epoch: int = None): + if epoch is None: + epoch = self.get_epoch() + + for key1 in self.get_keys(epoch): + for key2 in self.stats[epoch][key1]: + if key2 in ("time", "total_count"): + continue + summary_writer.add_scalar( + f"{key1}_{key2}_epoch", + self.stats[epoch][key1][key2], + epoch, + ) + + def wandb_log(self, epoch: int = None, commit: bool = True): + if epoch is None: + epoch = self.get_epoch() + + d = {} + for key1 in self.get_keys(epoch): + for key2 in self.stats[epoch][key1]: + if key2 in ("time", "total_count"): + continue + d[f"{key1}_{key2}_epoch"] = self.stats[epoch][key1][key2] + d["epoch"] = epoch + wandb.log(d, commit=commit) + + def state_dict(self): + return {"stats": self.stats, "epoch": self.epoch} + + def load_state_dict(self, state_dict: dict): + self.epoch = state_dict["epoch"] + self.stats = state_dict["stats"] diff --git a/espnet2/train/trainer.py b/espnet2/train/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d86bbb7d45b07807b8d2a8bfb782d7715928191 --- /dev/null +++ b/espnet2/train/trainer.py @@ -0,0 +1,771 @@ +import argparse +from contextlib import contextmanager +import dataclasses +from dataclasses import is_dataclass +from distutils.version import LooseVersion +import logging +from pathlib import Path +import time +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import humanfriendly +import numpy as np +import torch +import torch.nn +import torch.optim +from typeguard import check_argument_types + +from espnet2.iterators.abs_iter_factory import AbsIterFactory +from espnet2.main_funcs.average_nbest_models import average_nbest_models +from 
espnet2.main_funcs.calculate_all_attentions import calculate_all_attentions +from espnet2.schedulers.abs_scheduler import AbsBatchStepScheduler +from espnet2.schedulers.abs_scheduler import AbsEpochStepScheduler +from espnet2.schedulers.abs_scheduler import AbsScheduler +from espnet2.schedulers.abs_scheduler import AbsValEpochStepScheduler +from espnet2.torch_utils.add_gradient_noise import add_gradient_noise +from espnet2.torch_utils.device_funcs import to_device +from espnet2.torch_utils.recursive_op import recursive_average +from espnet2.torch_utils.set_all_random_seed import set_all_random_seed +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet2.train.distributed_utils import DistributedOption +from espnet2.train.reporter import Reporter +from espnet2.train.reporter import SubReporter +from espnet2.utils.build_dataclass import build_dataclass + +if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): + from torch.utils.tensorboard import SummaryWriter +else: + from tensorboardX import SummaryWriter +if torch.distributed.is_available(): + if LooseVersion(torch.__version__) > LooseVersion("1.0.1"): + from torch.distributed import ReduceOp + else: + from torch.distributed import reduce_op as ReduceOp +else: + ReduceOp = None + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast + from torch.cuda.amp import GradScaler +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + GradScaler = None + +try: + import fairscale +except ImportError: + fairscale = None + + +@dataclasses.dataclass +class TrainerOptions: + ngpu: int + resume: bool + use_amp: bool + train_dtype: str + grad_noise: bool + accum_grad: int + grad_clip: float + grad_clip_type: float + log_interval: Optional[int] + no_forward_run: bool + use_tensorboard: bool + use_wandb: bool + output_dir: Union[Path, str] + max_epoch: int + seed: int + sharded_ddp: bool + patience: Optional[int] + keep_nbest_models: Union[int, List[int]] + early_stopping_criterion: Sequence[str] + best_model_criterion: Sequence[Sequence[str]] + val_scheduler_criterion: Sequence[str] + unused_parameters: bool + + +class Trainer: + """Trainer having a optimizer. + + If you'd like to use multiple optimizers, then inherit this class + and override the methods if necessary - at least "train_one_epoch()" + + >>> class TwoOptimizerTrainer(Trainer): + ... @classmethod + ... def add_arguments(cls, parser): + ... ... + ... + ... @classmethod + ... def train_one_epoch(cls, model, optimizers, ...): + ... loss1 = model.model1(...) + ... loss1.backward() + ... optimizers[0].step() + ... + ... loss2 = model.model2(...) + ... loss2.backward() + ... 
optimizers[1].step() + + """ + + def __init__(self): + raise RuntimeError("This class can't be instantiated.") + + @classmethod + def build_options(cls, args: argparse.Namespace) -> TrainerOptions: + """Build options consumed by train(), eval(), and plot_attention()""" + assert check_argument_types() + return build_dataclass(TrainerOptions, args) + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + """Reserved for future development of another Trainer""" + pass + + @staticmethod + def resume( + checkpoint: Union[str, Path], + model: torch.nn.Module, + reporter: Reporter, + optimizers: Sequence[torch.optim.Optimizer], + schedulers: Sequence[Optional[AbsScheduler]], + scaler: Optional[GradScaler], + ngpu: int = 0, + ): + states = torch.load( + checkpoint, + map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu", + ) + model.load_state_dict(states["model"]) + reporter.load_state_dict(states["reporter"]) + for optimizer, state in zip(optimizers, states["optimizers"]): + optimizer.load_state_dict(state) + for scheduler, state in zip(schedulers, states["schedulers"]): + if scheduler is not None: + scheduler.load_state_dict(state) + if scaler is not None: + if states["scaler"] is None: + logging.warning("scaler state is not found") + else: + scaler.load_state_dict(states["scaler"]) + + logging.info(f"The training was resumed using {checkpoint}") + + @classmethod + def run( + cls, + model: AbsESPnetModel, + optimizers: Sequence[torch.optim.Optimizer], + schedulers: Sequence[Optional[AbsScheduler]], + train_iter_factory: AbsIterFactory, + valid_iter_factory: AbsIterFactory, + plot_attention_iter_factory: Optional[AbsIterFactory], + trainer_options, + distributed_option: DistributedOption, + ) -> None: + """Perform training. This method performs the main process of training.""" + assert check_argument_types() + # NOTE(kamo): Don't check the type more strictly as far trainer_options + assert is_dataclass(trainer_options), type(trainer_options) + assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers)) + + if isinstance(trainer_options.keep_nbest_models, int): + keep_nbest_models = trainer_options.keep_nbest_models + else: + if len(trainer_options.keep_nbest_models) == 0: + logging.warning("No keep_nbest_models is given. Change to [1]") + trainer_options.keep_nbest_models = [1] + keep_nbest_models = max(trainer_options.keep_nbest_models) + + output_dir = Path(trainer_options.output_dir) + reporter = Reporter() + if trainer_options.use_amp: + if LooseVersion(torch.__version__) < LooseVersion("1.6.0"): + raise RuntimeError( + "Require torch>=1.6.0 for Automatic Mixed Precision" + ) + if trainer_options.sharded_ddp: + if fairscale is None: + raise RuntimeError( + "Requiring fairscale. 
Do 'pip install fairscale'" + ) + scaler = fairscale.optim.grad_scaler.ShardedGradScaler() + else: + scaler = GradScaler() + else: + scaler = None + + if trainer_options.resume and (output_dir / "checkpoint.pth").exists(): + cls.resume( + checkpoint=output_dir / "checkpoint.pth", + model=model, + optimizers=optimizers, + schedulers=schedulers, + reporter=reporter, + scaler=scaler, + ngpu=trainer_options.ngpu, + ) + + start_epoch = reporter.get_epoch() + 1 + if start_epoch == trainer_options.max_epoch + 1: + logging.warning( + f"The training has already reached at max_epoch: {start_epoch}" + ) + + if distributed_option.distributed: + if trainer_options.sharded_ddp: + dp_model = fairscale.nn.data_parallel.ShardedDataParallel( + module=model, + sharded_optimizer=optimizers, + ) + else: + dp_model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=( + # Perform multi-Process with multi-GPUs + [torch.cuda.current_device()] + if distributed_option.ngpu == 1 + # Perform single-Process with multi-GPUs + else None + ), + output_device=( + torch.cuda.current_device() + if distributed_option.ngpu == 1 + else None + ), + find_unused_parameters=trainer_options.unused_parameters, + ) + elif distributed_option.ngpu > 1: + dp_model = torch.nn.parallel.DataParallel( + model, + device_ids=list(range(distributed_option.ngpu)), + ) + else: + # NOTE(kamo): DataParallel also should work with ngpu=1, + # but for debuggability it's better to keep this block. + dp_model = model + + if trainer_options.use_tensorboard and ( + not distributed_option.distributed or distributed_option.dist_rank == 0 + ): + summary_writer = SummaryWriter(str(output_dir / "tensorboard")) + else: + summary_writer = None + + start_time = time.perf_counter() + for iepoch in range(start_epoch, trainer_options.max_epoch + 1): + if iepoch != start_epoch: + logging.info( + "{}/{}epoch started. Estimated time to finish: {}".format( + iepoch, + trainer_options.max_epoch, + humanfriendly.format_timespan( + (time.perf_counter() - start_time) + / (iepoch - start_epoch) + * (trainer_options.max_epoch - iepoch + 1) + ), + ) + ) + else: + logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started") + set_all_random_seed(trainer_options.seed + iepoch) + + reporter.set_epoch(iepoch) + # 1. Train and validation for one-epoch + with reporter.observe("train") as sub_reporter: + all_steps_are_invalid = cls.train_one_epoch( + model=dp_model, + optimizers=optimizers, + schedulers=schedulers, + iterator=train_iter_factory.build_iter(iepoch), + reporter=sub_reporter, + scaler=scaler, + summary_writer=summary_writer, + options=trainer_options, + distributed_option=distributed_option, + ) + + with reporter.observe("valid") as sub_reporter: + cls.validate_one_epoch( + model=dp_model, + iterator=valid_iter_factory.build_iter(iepoch), + reporter=sub_reporter, + options=trainer_options, + distributed_option=distributed_option, + ) + + if not distributed_option.distributed or distributed_option.dist_rank == 0: + # att_plot doesn't support distributed + if plot_attention_iter_factory is not None: + with reporter.observe("att_plot") as sub_reporter: + cls.plot_attention( + model=model, + output_dir=output_dir / "att_ws", + summary_writer=summary_writer, + iterator=plot_attention_iter_factory.build_iter(iepoch), + reporter=sub_reporter, + options=trainer_options, + ) + + # 2. 
LR Scheduler step + for scheduler in schedulers: + if isinstance(scheduler, AbsValEpochStepScheduler): + scheduler.step( + reporter.get_value(*trainer_options.val_scheduler_criterion) + ) + elif isinstance(scheduler, AbsEpochStepScheduler): + scheduler.step() + if trainer_options.sharded_ddp: + for optimizer in optimizers: + if isinstance(optimizer, fairscale.optim.oss.OSS): + optimizer.consolidate_state_dict() + + if not distributed_option.distributed or distributed_option.dist_rank == 0: + # 3. Report the results + logging.info(reporter.log_message()) + reporter.matplotlib_plot(output_dir / "images") + if summary_writer is not None: + reporter.tensorboard_add_scalar(summary_writer) + if trainer_options.use_wandb: + reporter.wandb_log() + + # 4. Save/Update the checkpoint + torch.save( + { + "model": model.state_dict(), + "reporter": reporter.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "schedulers": [ + s.state_dict() if s is not None else None + for s in schedulers + ], + "scaler": scaler.state_dict() if scaler is not None else None, + }, + output_dir / "checkpoint.pth", + ) + + # 5. Save the model and update the link to the best model + torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pth") + + # Creates a sym link latest.pth -> {iepoch}epoch.pth + p = output_dir / "latest.pth" + if p.is_symlink() or p.exists(): + p.unlink() + p.symlink_to(f"{iepoch}epoch.pth") + + _improved = [] + for _phase, k, _mode in trainer_options.best_model_criterion: + # e.g. _phase, k, _mode = "train", "loss", "min" + if reporter.has(_phase, k): + best_epoch = reporter.get_best_epoch(_phase, k, _mode) + # Creates sym links if it's the best result + if best_epoch == iepoch: + p = output_dir / f"{_phase}.{k}.best.pth" + if p.is_symlink() or p.exists(): + p.unlink() + p.symlink_to(f"{iepoch}epoch.pth") + _improved.append(f"{_phase}.{k}") + if len(_improved) == 0: + logging.info("There are no improvements in this epoch") + else: + logging.info( + "The best model has been updated: " + ", ".join(_improved) + ) + + # 6. Remove the model files excluding n-best epoch and latest epoch + _removed = [] + # Get the union set of the n-best among multiple criterion + nbests = set().union( + *[ + set(reporter.sort_epochs(ph, k, m)[:keep_nbest_models]) + for ph, k, m in trainer_options.best_model_criterion + if reporter.has(ph, k) + ] + ) + for e in range(1, iepoch): + p = output_dir / f"{e}epoch.pth" + if p.exists() and e not in nbests: + p.unlink() + _removed.append(str(p)) + if len(_removed) != 0: + logging.info("The model files were removed: " + ", ".join(_removed)) + + # 7. If any updating haven't happened, stops the training + if all_steps_are_invalid: + logging.warning( + f"The gradients at all steps are invalid in this epoch. " + f"Something seems wrong. This training was stopped at {iepoch}epoch" + ) + break + + # 8. 
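# Sketch (illustration only) of the checkpoint layout written in step 4 above, using a toy
# torch module in place of the real ESPnet model; the reporter and scheduler states are
# stubbed here, whereas the real loop stores Reporter.state_dict() and scheduler state dicts.
import torch

toy_model = torch.nn.Linear(4, 2)
toy_optimizer = torch.optim.Adam(toy_model.parameters())
torch.save(
    {
        "model": toy_model.state_dict(),
        "reporter": {},                             # Reporter.state_dict() in the real loop
        "optimizers": [toy_optimizer.state_dict()],
        "schedulers": [None],
        "scaler": None,                             # GradScaler state when use_amp is enabled
    },
    "checkpoint.pth",
)

# Trainer.resume() reads the same keys back:
states = torch.load("checkpoint.pth", map_location="cpu")
toy_model.load_state_dict(states["model"])
toy_optimizer.load_state_dict(states["optimizers"][0])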
Check early stopping + if trainer_options.patience is not None: + if reporter.check_early_stopping( + trainer_options.patience, *trainer_options.early_stopping_criterion + ): + break + + else: + logging.info( + f"The training was finished at {trainer_options.max_epoch} epochs " + ) + + if not distributed_option.distributed or distributed_option.dist_rank == 0: + # Generated n-best averaged model + average_nbest_models( + reporter=reporter, + output_dir=output_dir, + best_model_criterion=trainer_options.best_model_criterion, + nbest=keep_nbest_models, + ) + + @classmethod + def train_one_epoch( + cls, + model: torch.nn.Module, + iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], + optimizers: Sequence[torch.optim.Optimizer], + schedulers: Sequence[Optional[AbsScheduler]], + scaler: Optional[GradScaler], + reporter: SubReporter, + summary_writer: Optional[SummaryWriter], + options: TrainerOptions, + distributed_option: DistributedOption, + ) -> bool: + assert check_argument_types() + + grad_noise = options.grad_noise + accum_grad = options.accum_grad + grad_clip = options.grad_clip + grad_clip_type = options.grad_clip_type + log_interval = options.log_interval + no_forward_run = options.no_forward_run + ngpu = options.ngpu + use_wandb = options.use_wandb + distributed = distributed_option.distributed + + if log_interval is None: + try: + log_interval = max(len(iterator) // 20, 10) + except TypeError: + log_interval = 100 + + model.train() + all_steps_are_invalid = True + # [For distributed] Because iteration counts are not always equals between + # processes, send stop-flag to the other processes if iterator is finished + iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") + + start_time = time.perf_counter() + for iiter, (_, batch) in enumerate( + reporter.measure_iter_time(iterator, "iter_time"), 1 + ): + assert isinstance(batch, dict), type(batch) + + if distributed: + torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) + if iterator_stop > 0: + break + + batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") + if no_forward_run: + all_steps_are_invalid = False + continue + + with autocast(scaler is not None): + with reporter.measure_time("forward_time"): + retval = model(**batch) + + # Note(kamo): + # Supporting two patterns for the returned value from the model + # a. dict type + if isinstance(retval, dict): + loss = retval["loss"] + stats = retval["stats"] + weight = retval["weight"] + optim_idx = retval.get("optim_idx") + if optim_idx is not None and not isinstance(optim_idx, int): + if not isinstance(optim_idx, torch.Tensor): + raise RuntimeError( + "optim_idx must be int or 1dim torch.Tensor, " + f"but got {type(optim_idx)}" + ) + if optim_idx.dim() >= 2: + raise RuntimeError( + "optim_idx must be int or 1dim torch.Tensor, " + f"but got {optim_idx.dim()}dim tensor" + ) + if optim_idx.dim() == 1: + for v in optim_idx: + if v != optim_idx[0]: + raise RuntimeError( + "optim_idx must be 1dim tensor " + "having same values for all entries" + ) + optim_idx = optim_idx[0].item() + else: + optim_idx = optim_idx.item() + + # b. 
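# Hypothetical model (illustration only) showing the dict-style return value that the branch
# above unpacks; a plain (loss, stats, weight) tuple is handled by branch "b." below.
import torch

class ToyDictReturnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def forward(self, feats: torch.Tensor, targets: torch.Tensor):
        loss = torch.nn.functional.mse_loss(self.linear(feats), targets)
        stats = {"loss": loss.detach()}                  # later logged via SubReporter.register()
        weight = torch.tensor(float(feats.size(0)))      # batch size, used for weighted averaging
        return {"loss": loss, "stats": stats, "weight": weight}

batch = {"feats": torch.randn(4, 8), "targets": torch.randn(4, 1)}
retval = ToyDictReturnModel()(**batch)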
tuple or list type + else: + loss, stats, weight = retval + optim_idx = None + + stats = {k: v for k, v in stats.items() if v is not None} + if ngpu > 1 or distributed: + # Apply weighted averaging for loss and stats + loss = (loss * weight.type(loss.dtype)).sum() + + # if distributed, this method can also apply all_reduce() + stats, weight = recursive_average(stats, weight, distributed) + + # Now weight is summation over all workers + loss /= weight + if distributed: + # NOTE(kamo): Multiply world_size because DistributedDataParallel + # automatically normalizes the gradient by world_size. + loss *= torch.distributed.get_world_size() + + loss /= accum_grad + + reporter.register(stats, weight) + + with reporter.measure_time("backward_time"): + if scaler is not None: + # Scales loss. Calls backward() on scaled loss + # to create scaled gradients. + # Backward passes under autocast are not recommended. + # Backward ops run in the same dtype autocast chose + # for corresponding forward ops. + scaler.scale(loss).backward() + else: + loss.backward() + + if iiter % accum_grad == 0: + if scaler is not None: + # Unscales the gradients of optimizer's assigned params in-place + for iopt, optimizer in enumerate(optimizers): + if optim_idx is not None and iopt != optim_idx: + continue + scaler.unscale_(optimizer) + + # gradient noise injection + if grad_noise: + add_gradient_noise( + model, + reporter.get_total_count(), + duration=100, + eta=1.0, + scale_factor=0.55, + ) + + # compute the gradient norm to check if it is normal or not + grad_norm = torch.nn.utils.clip_grad_norm_( + model.parameters(), + max_norm=grad_clip, + norm_type=grad_clip_type, + ) + # PyTorch<=1.4, clip_grad_norm_ returns float value + if not isinstance(grad_norm, torch.Tensor): + grad_norm = torch.tensor(grad_norm) + + if not torch.isfinite(grad_norm): + logging.warning( + f"The grad norm is {grad_norm}. Skipping updating the model." + ) + + # Must invoke scaler.update() if unscale_() is used in the iteration + # to avoid the following error: + # RuntimeError: unscale_() has already been called + # on this optimizer since the last update(). + # Note that if the gradient has inf/nan values, + # scaler.step skips optimizer.step(). + if scaler is not None: + for iopt, optimizer in enumerate(optimizers): + if optim_idx is not None and iopt != optim_idx: + continue + scaler.step(optimizer) + scaler.update() + + else: + all_steps_are_invalid = False + with reporter.measure_time("optim_step_time"): + for iopt, (optimizer, scheduler) in enumerate( + zip(optimizers, schedulers) + ): + if optim_idx is not None and iopt != optim_idx: + continue + if scaler is not None: + # scaler.step() first unscales the gradients of + # the optimizer's assigned params. + scaler.step(optimizer) + # Updates the scale for next iteration. 
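# Condensed plain-PyTorch sketch (toy model, illustration only) of the update ordering used
# above when AMP is enabled: scale -> (every accum_grad steps) unscale_ -> clip ->
# scaler.step -> scaler.update -> zero_grad. scaler.step() skips the optimizer internally if
# the unscaled gradients contain inf/NaN; on CPU-only machines the scaler is a pass-through.
import torch

toy = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(toy.parameters(), lr=0.1)
amp_scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
accum_grad, grad_clip = 2, 1.0

for iiter in range(1, 9):
    x = torch.randn(4, 8)
    loss = toy(x).pow(2).mean() / accum_grad
    amp_scaler.scale(loss).backward()
    if iiter % accum_grad == 0:
        amp_scaler.unscale_(opt)                    # gradients become inspectable/clippable
        grad_norm = torch.as_tensor(
            torch.nn.utils.clip_grad_norm_(toy.parameters(), grad_clip)
        )
        if not torch.isfinite(grad_norm):
            print(f"non-finite grad norm {grad_norm}; the step below becomes a no-op")
        amp_scaler.step(opt)                        # optimizer.step() unless grads overflowed
        amp_scaler.update()                         # required after unscale_() this iteration
        opt.zero_grad()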
+ scaler.update() + else: + optimizer.step() + if isinstance(scheduler, AbsBatchStepScheduler): + scheduler.step() + optimizer.zero_grad() + + # Register lr and train/load time[sec/step], + # where step refers to accum_grad * mini-batch + reporter.register( + dict( + { + f"optim{i}_lr{j}": pg["lr"] + for i, optimizer in enumerate(optimizers) + for j, pg in enumerate(optimizer.param_groups) + if "lr" in pg + }, + train_time=time.perf_counter() - start_time, + ), + ) + start_time = time.perf_counter() + + # NOTE(kamo): Call log_message() after next() + reporter.next() + if iiter % log_interval == 0: + logging.info(reporter.log_message(-log_interval)) + if summary_writer is not None: + reporter.tensorboard_add_scalar(summary_writer, -log_interval) + if use_wandb: + reporter.wandb_log() + + else: + if distributed: + iterator_stop.fill_(1) + torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) + + return all_steps_are_invalid + + @classmethod + @torch.no_grad() + def validate_one_epoch( + cls, + model: torch.nn.Module, + iterator: Iterable[Dict[str, torch.Tensor]], + reporter: SubReporter, + options: TrainerOptions, + distributed_option: DistributedOption, + ) -> None: + assert check_argument_types() + ngpu = options.ngpu + no_forward_run = options.no_forward_run + distributed = distributed_option.distributed + + model.eval() + + # [For distributed] Because iteration counts are not always equals between + # processes, send stop-flag to the other processes if iterator is finished + iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu") + for (_, batch) in iterator: + assert isinstance(batch, dict), type(batch) + if distributed: + torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) + if iterator_stop > 0: + break + + batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") + if no_forward_run: + continue + + retval = model(**batch) + if isinstance(retval, dict): + stats = retval["stats"] + weight = retval["weight"] + else: + _, stats, weight = retval + if ngpu > 1 or distributed: + # Apply weighted averaging for stats. + # if distributed, this method can also apply all_reduce() + stats, weight = recursive_average(stats, weight, distributed) + + reporter.register(stats, weight) + reporter.next() + + else: + if distributed: + iterator_stop.fill_(1) + torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM) + + @classmethod + @torch.no_grad() + def plot_attention( + cls, + model: torch.nn.Module, + output_dir: Optional[Path], + summary_writer: Optional[SummaryWriter], + iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]], + reporter: SubReporter, + options: TrainerOptions, + ) -> None: + assert check_argument_types() + import matplotlib + + ngpu = options.ngpu + no_forward_run = options.no_forward_run + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + from matplotlib.ticker import MaxNLocator + + model.eval() + for ids, batch in iterator: + assert isinstance(batch, dict), type(batch) + assert len(next(iter(batch.values()))) == len(ids), ( + len(next(iter(batch.values()))), + len(ids), + ) + batch = to_device(batch, "cuda" if ngpu > 0 else "cpu") + if no_forward_run: + continue + + # 1. Forwarding model and gathering all attentions + # calculate_all_attentions() uses single gpu only. + att_dict = calculate_all_attentions(model, batch) + + # 2. 
Plot attentions: This part is slow due to matplotlib + for k, att_list in att_dict.items(): + assert len(att_list) == len(ids), (len(att_list), len(ids)) + for id_, att_w in zip(ids, att_list): + + if isinstance(att_w, torch.Tensor): + att_w = att_w.detach().cpu().numpy() + + if att_w.ndim == 2: + att_w = att_w[None] + elif att_w.ndim > 3 or att_w.ndim == 1: + raise RuntimeError(f"Must be 2 or 3 dimension: {att_w.ndim}") + + w, h = plt.figaspect(1.0 / len(att_w)) + fig = plt.Figure(figsize=(w * 1.3, h * 1.3)) + axes = fig.subplots(1, len(att_w)) + if len(att_w) == 1: + axes = [axes] + + for ax, aw in zip(axes, att_w): + ax.imshow(aw.astype(np.float32), aspect="auto") + ax.set_title(f"{k}_{id_}") + ax.set_xlabel("Input") + ax.set_ylabel("Output") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + + if output_dir is not None: + p = output_dir / id_ / f"{k}.{reporter.get_epoch()}ep.png" + p.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(p) + + if summary_writer is not None: + summary_writer.add_figure( + f"{k}_{id_}", fig, reporter.get_epoch() + ) + reporter.next() diff --git a/espnet2/tts/__init__.py b/espnet2/tts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/tts/abs_tts.py b/espnet2/tts/abs_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..d226b678069327fc5b1581e073140a08ad5e1c04 --- /dev/null +++ b/espnet2/tts/abs_tts.py @@ -0,0 +1,30 @@ +from abc import ABC +from abc import abstractmethod +from typing import Dict +from typing import Tuple + +import torch + + +class AbsTTS(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + spembs: torch.Tensor = None, + spcs: torch.Tensor = None, + spcs_lengths: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + raise NotImplementedError + + @abstractmethod + def inference( + self, + text: torch.Tensor, + spembs: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/tts/duration_calculator.py b/espnet2/tts/duration_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..82a31498e5f9051c84f07cf84f6fe0120f0eca83 --- /dev/null +++ b/espnet2/tts/duration_calculator.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Duration calculator for ESPnet2.""" + +from typing import Tuple + +import torch + + +class DurationCalculator(torch.nn.Module): + """Duration calculator module.""" + + def __init__(self): + """Initilize duration calculator.""" + super().__init__() + + @torch.no_grad() + def forward(self, att_ws: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert attention weight to durations. + + Args: + att_ws (Tesnor): Attention weight tensor (L, T) or (#layers, #heads, L, T). + + Returns: + LongTensor: Duration of each input (T,). + Tensor: Focus rate value. 
+ + """ + duration = self._calculate_duration(att_ws) + focus_rate = self._calculate_focus_rete(att_ws) + + return duration, focus_rate + + @staticmethod + def _calculate_focus_rete(att_ws): + if len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + return att_ws.max(dim=-1)[0].mean() + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + return att_ws.max(dim=-1)[0].mean(dim=-1).max() + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + + @staticmethod + def _calculate_duration(att_ws): + if len(att_ws.shape) == 2: + # tacotron 2 case -> (L, T) + pass + elif len(att_ws.shape) == 4: + # transformer case -> (#layers, #heads, L, T) + # get the most diagonal head according to focus rate + att_ws = torch.cat( + [att_w for att_w in att_ws], dim=0 + ) # (#heads * #layers, L, T) + diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1) # (#heads * #layers,) + diagonal_head_idx = diagonal_scores.argmax() + att_ws = att_ws[diagonal_head_idx] # (L, T) + else: + raise ValueError("att_ws should be 2 or 4 dimensional tensor.") + # calculate duration from 2d attention weight + durations = torch.stack( + [att_ws.argmax(-1).eq(i).sum() for i in range(att_ws.shape[1])] + ) + return durations.view(-1) diff --git a/espnet2/tts/espnet_model.py b/espnet2/tts/espnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..217d76e011509649c111778ac7f439cf1e8c8e98 --- /dev/null +++ b/espnet2/tts/espnet_model.py @@ -0,0 +1,218 @@ +from contextlib import contextmanager +from distutils.version import LooseVersion +from typing import Dict +from typing import Optional +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet2.layers.abs_normalize import AbsNormalize +from espnet2.layers.inversible_interface import InversibleInterface +from espnet2.train.abs_espnet_model import AbsESPnetModel +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract + +if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + from torch.cuda.amp import autocast +else: + # Nothing to do if torch<1.6.0 + @contextmanager + def autocast(enabled=True): + yield + + +class ESPnetTTSModel(AbsESPnetModel): + def __init__( + self, + feats_extract: Optional[AbsFeatsExtract], + pitch_extract: Optional[AbsFeatsExtract], + energy_extract: Optional[AbsFeatsExtract], + normalize: Optional[AbsNormalize and InversibleInterface], + pitch_normalize: Optional[AbsNormalize and InversibleInterface], + energy_normalize: Optional[AbsNormalize and InversibleInterface], + tts: AbsTTS, + ): + assert check_argument_types() + super().__init__() + self.feats_extract = feats_extract + self.pitch_extract = pitch_extract + self.energy_extract = energy_extract + self.normalize = normalize + self.pitch_normalize = pitch_normalize + self.energy_normalize = energy_normalize + self.tts = tts + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + durations: torch.Tensor = None, + durations_lengths: torch.Tensor = None, + pitch: torch.Tensor = None, + pitch_lengths: torch.Tensor = None, + energy: torch.Tensor = None, + energy_lengths: torch.Tensor = None, + spembs: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + with autocast(False): + # Extract features + if self.feats_extract is not None: + feats, feats_lengths = self.feats_extract(speech, speech_lengths) + else: + feats, 
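# Usage sketch (illustration only) for the DurationCalculator above, with a toy
# Tacotron-2-style attention matrix of shape (L=4 output frames, T=3 input tokens).
import torch

from espnet2.tts.duration_calculator import DurationCalculator

att_ws = torch.tensor(
    [[0.90, 0.05, 0.05],
     [0.80, 0.10, 0.10],
     [0.10, 0.70, 0.20],
     [0.10, 0.20, 0.70]]
)
durations, focus_rate = DurationCalculator()(att_ws)
print(durations)    # tensor([2, 1, 1]): frames whose argmax falls on each input token
print(focus_rate)   # tensor(0.7750): mean of the per-frame attention maxima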
feats_lengths = speech, speech_lengths + + # Extract auxiliary features + if self.pitch_extract is not None and pitch is None: + pitch, pitch_lengths = self.pitch_extract( + speech, + speech_lengths, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + ) + if self.energy_extract is not None and energy is None: + energy, energy_lengths = self.energy_extract( + speech, + speech_lengths, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + ) + + # Normalize + if self.normalize is not None: + feats, feats_lengths = self.normalize(feats, feats_lengths) + if self.pitch_normalize is not None: + pitch, pitch_lengths = self.pitch_normalize(pitch, pitch_lengths) + if self.energy_normalize is not None: + energy, energy_lengths = self.energy_normalize(energy, energy_lengths) + + # Update kwargs for additional auxiliary inputs + if spembs is not None: + kwargs.update(spembs=spembs) + if durations is not None: + kwargs.update(durations=durations, durations_lengths=durations_lengths) + if self.pitch_extract is not None and pitch is not None: + kwargs.update(pitch=pitch, pitch_lengths=pitch_lengths) + if self.energy_extract is not None and energy is not None: + kwargs.update(energy=energy, energy_lengths=energy_lengths) + + return self.tts( + text=text, + text_lengths=text_lengths, + speech=feats, + speech_lengths=feats_lengths, + **kwargs, + ) + + def collect_feats( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + durations: torch.Tensor = None, + durations_lengths: torch.Tensor = None, + pitch: torch.Tensor = None, + pitch_lengths: torch.Tensor = None, + energy: torch.Tensor = None, + energy_lengths: torch.Tensor = None, + spembs: torch.Tensor = None, + ) -> Dict[str, torch.Tensor]: + if self.feats_extract is not None: + feats, feats_lengths = self.feats_extract(speech, speech_lengths) + else: + feats, feats_lengths = speech, speech_lengths + feats_dict = {"feats": feats, "feats_lengths": feats_lengths} + + if self.pitch_extract is not None: + pitch, pitch_lengths = self.pitch_extract( + speech, + speech_lengths, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + ) + if self.energy_extract is not None: + energy, energy_lengths = self.energy_extract( + speech, + speech_lengths, + feats_lengths=feats_lengths, + durations=durations, + durations_lengths=durations_lengths, + ) + if pitch is not None: + feats_dict.update(pitch=pitch, pitch_lengths=pitch_lengths) + if energy is not None: + feats_dict.update(energy=energy, energy_lengths=energy_lengths) + + return feats_dict + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + durations: torch.Tensor = None, + pitch: torch.Tensor = None, + energy: torch.Tensor = None, + **decode_config, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + kwargs = {} + # TC marker, oorspr false + if decode_config["use_teacher_forcing"] or getattr(self.tts, "use_gst", False): + if speech is None: + raise RuntimeError("missing required argument: 'speech'") + if self.feats_extract is not None: + feats = self.feats_extract(speech[None])[0][0] + else: + feats = speech + if self.normalize is not None: + feats = self.normalize(feats[None])[0][0] + kwargs["speech"] = feats + + if decode_config["use_teacher_forcing"]: + if durations is not None: + kwargs["durations"] = durations + + if self.pitch_extract is not 
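# Minimal composition sketch (illustration only): ESPnetTTSModel with every optional
# front-end disabled and a hypothetical stub AbsTTS, to show the (loss, stats, weight)
# contract the wrapper forwards to the trainer; the stub is not part of the patch.
import torch

from espnet2.tts.abs_tts import AbsTTS
from espnet2.tts.espnet_model import ESPnetTTSModel

class StubTTS(AbsTTS):
    def forward(self, text, text_lengths, speech, speech_lengths, **kwargs):
        loss = speech.mean() * 0.0                  # placeholder loss
        return loss, {"loss": loss.detach()}, torch.tensor(float(text.size(0)))

    def inference(self, text, **kwargs):
        raise NotImplementedError

model = ESPnetTTSModel(
    feats_extract=None, pitch_extract=None, energy_extract=None,
    normalize=None, pitch_normalize=None, energy_normalize=None,
    tts=StubTTS(),
)
loss, stats, weight = model(
    text=torch.randint(1, 10, (2, 5)),
    text_lengths=torch.tensor([5, 4]),
    speech=torch.randn(2, 30, 80),                  # treated as already-extracted features
    speech_lengths=torch.tensor([30, 25]),
)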
None: + pitch = self.pitch_extract( + speech[None], + feats_lengths=torch.LongTensor([len(feats)]), + durations=durations[None], + )[0][0] + if self.pitch_normalize is not None: + pitch = self.pitch_normalize(pitch[None])[0][0] + if pitch is not None: + kwargs["pitch"] = pitch + + if self.energy_extract is not None: + energy = self.energy_extract( + speech[None], + feats_lengths=torch.LongTensor([len(feats)]), + durations=durations[None], + )[0][0] + if self.energy_normalize is not None: + energy = self.energy_normalize(energy[None])[0][0] + if energy is not None: + kwargs["energy"] = energy + + if spembs is not None: + kwargs["spembs"] = spembs + + outs, probs, att_ws, ref_embs, ar_prior_loss = self.tts.inference( + text=text, + **kwargs, + **decode_config + ) + + if self.normalize is not None: + # NOTE: normalize.inverse is in-place operation + outs_denorm = self.normalize.inverse(outs.clone()[None])[0][0] + else: + outs_denorm = outs + return outs, outs_denorm, probs, att_ws, ref_embs, ar_prior_loss diff --git a/espnet2/tts/fastespeech.py b/espnet2/tts/fastespeech.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbbc35668f42a13a0e45b907ea18146302d6ac5 --- /dev/null +++ b/espnet2/tts/fastespeech.py @@ -0,0 +1,710 @@ +""" FastESpeech """ + +from typing import Dict +from typing import Sequence +from typing import Tuple + +import torch +import torch.nn.functional as F + +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.e2e_tts_fastspeech import ( + FeedForwardTransformerLoss as FastSpeechLoss, # NOQA +) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import ( + Encoder as TransformerEncoder, # noqa: H301 +) + +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.torch_utils.initialize import initialize +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.prosody_encoder import ProsodyEncoder + + +class FastESpeech(AbsTTS): + """FastESpeech module. + + This module adds a VQ-VAE prosody encoder to the FastSpeech model, and + takes cues from FastSpeech 2 for training. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/abs/1905.09263 + .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`: + https://arxiv.org/abs/2006.04558 + + Args: + idim (int): Dimension of the input -> size of the phoneme vocabulary. + odim (int): Dimension of the output -> dimension of the mel-spectrograms. + adim (int, optional): Dimension of the phoneme embeddings, dimension of the + prosody embedding, the hidden size of the self-attention, 1D convolution + in the FFT block. + aheads (int, optional): Number of attention heads. + elayers (int, optional): Number of encoder layers/blocks. + eunits (int, optional): Number of encoder hidden units + -> The number of units of position-wise feed forward layer. + dlayers (int, optional): Number of decoder layers/blocks. 
+ dunits (int, optional): Number of decoder hidden units + -> The number of units of position-wise feed forward layer. + positionwise_layer_type (str, optional): Type of position-wise feed forward + layer - linear or conv1d. + positionwise_conv_kernel_size (int, optional): kernel size of positionwise + conv1d layer. + use_scaled_pos_enc (bool, optional): + Whether to use trainable scaled positional encoding. + encoder_normalize_before (bool, optional): + Whether to perform layer normalization before encoder block. + decoder_normalize_before (bool, optional): + Whether to perform layer normalization before decoder block. + encoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in encoder. + decoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in decoder. + duration_predictor_layers (int, optional): Number of duration predictor layers. + duration_predictor_chans (int, optional): Number of duration predictor channels. + duration_predictor_kernel_size (int, optional): + Kernel size of duration predictor. + reduction_factor (int, optional): Factor to multiply with output dimension. + encoder_type (str, optional): Encoder architecture type. + decoder_type (str, optional): Decoder architecture type. + # spk_embed_dim (int, optional): Number of speaker embedding dimensions. + # spk_embed_integration_type: How to integrate speaker embedding. + ref_enc_conv_layers (int, optional): + The number of conv layers in the reference encoder. + ref_enc_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the referece encoder. + ref_enc_conv_kernel_size (int, optional): + Kernal size of conv layers in the reference encoder. + ref_enc_conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + ref_enc_gru_layers (int, optional): + The number of GRU layers in the reference encoder. + ref_enc_gru_units (int, optional): + The number of GRU units in the reference encoder. + ref_emb_integration_type: How to integrate reference embedding. + # reduction_factor (int, optional): Reduction factor. + prosody_num_embs (int, optional): The higher this value, the higher the + capacity in the information bottleneck. + prosody_hidden_dim (int, optional): Number of hidden channels. + prosody_emb_integration_type: How to integrate prosody embedding. + transformer_enc_dropout_rate (float, optional): + Dropout rate in encoder except attention & positional encoding. + transformer_enc_positional_dropout_rate (float, optional): + Dropout rate after encoder positional encoding. + transformer_enc_attn_dropout_rate (float, optional): + Dropout rate in encoder self-attention module. + transformer_dec_dropout_rate (float, optional): + Dropout rate in decoder except attention & positional encoding. + transformer_dec_positional_dropout_rate (float, optional): + Dropout rate after decoder positional encoding. + transformer_dec_attn_dropout_rate (float, optional): + Dropout rate in decoder self-attention module. + duration_predictor_dropout_rate (float, optional): + Dropout rate in duration predictor. + init_type (str, optional): + How to initialize transformer parameters. + init_enc_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the encoder. + init_dec_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the decoder. + use_masking (bool, optional): + Whether to apply masking for padded part in loss calculation. 
+ use_weighted_masking (bool, optional): + Whether to apply weighted masking in loss calculation. + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + adim: int = 384, + aheads: int = 4, + elayers: int = 6, + eunits: int = 1536, + dlayers: int = 6, + dunits: int = 1536, + postnet_layers: int = 0, # 5 + postnet_chans: int = 512, + postnet_filts: int = 5, + positionwise_layer_type: str = "conv1d", + positionwise_conv_kernel_size: int = 1, + use_scaled_pos_enc: bool = True, + use_batch_norm: bool = True, + encoder_normalize_before: bool = True, + decoder_normalize_before: bool = True, + encoder_concat_after: bool = False, + decoder_concat_after: bool = False, + duration_predictor_layers: int = 2, + duration_predictor_chans: int = 384, + duration_predictor_kernel_size: int = 3, + reduction_factor: int = 1, + encoder_type: str = "transformer", + decoder_type: str = "transformer", + # # only for conformer + # conformer_pos_enc_layer_type: str = "rel_pos", + # conformer_self_attn_layer_type: str = "rel_selfattn", + # conformer_activation_type: str = "swish", + # use_macaron_style_in_conformer: bool = True, + # use_cnn_in_conformer: bool = True, + # conformer_enc_kernel_size: int = 7, + # conformer_dec_kernel_size: int = 31, + # # pretrained spk emb + # spk_embed_dim: int = None, + # spk_embed_integration_type: str = "add", + # reference encoder + ref_enc_conv_layers: int = 2, + ref_enc_conv_chans_list: Sequence[int] = (32, 32), + ref_enc_conv_kernel_size: int = 3, + ref_enc_conv_stride: int = 1, + ref_enc_gru_layers: int = 1, + ref_enc_gru_units: int = 32, + ref_emb_integration_type: str = "add", + # prosody encoder + prosody_num_embs: int = 256, + prosody_hidden_dim: int = 128, + prosody_emb_integration_type: str = "add", + # training related + transformer_enc_dropout_rate: float = 0.1, + transformer_enc_positional_dropout_rate: float = 0.1, + transformer_enc_attn_dropout_rate: float = 0.1, + transformer_dec_dropout_rate: float = 0.1, + transformer_dec_positional_dropout_rate: float = 0.1, + transformer_dec_attn_dropout_rate: float = 0.1, + duration_predictor_dropout_rate: float = 0.1, + postnet_dropout_rate: float = 0.5, + init_type: str = "xavier_uniform", + init_enc_alpha: float = 1.0, + init_dec_alpha: float = 1.0, + use_masking: bool = False, + use_weighted_masking: bool = False, + ): + """Initialize FastESpeech module.""" + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.idim = idim + self.odim = odim + self.eos = idim - 1 + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.use_scaled_pos_enc = use_scaled_pos_enc + self.prosody_emb_integration_type = prosody_emb_integration_type + # self.spk_embed_dim = spk_embed_dim + # if self.spk_embed_dim is not None: + # self.spk_embed_integration_type = spk_embed_integration_type + + # use idx 0 as padding idx, see: + # https://stackoverflow.com/questions/61172400/what-does-padding-idx-do-in-nn-embeddings + self.padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # define encoder + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx + ) + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + 
input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + # elif encoder_type == "conformer": + # self.encoder = ConformerEncoder( + # idim=idim, + # attention_dim=adim, + # attention_heads=aheads, + # linear_units=eunits, + # num_blocks=elayers, + # input_layer=encoder_input_layer, + # dropout_rate=transformer_enc_dropout_rate, + # positional_dropout_rate=transformer_enc_positional_dropout_rate, + # attention_dropout_rate=transformer_enc_attn_dropout_rate, + # normalize_before=encoder_normalize_before, + # concat_after=encoder_concat_after, + # positionwise_layer_type=positionwise_layer_type, + # positionwise_conv_kernel_size=positionwise_conv_kernel_size, + # macaron_style=use_macaron_style_in_conformer, + # pos_enc_layer_type=conformer_pos_enc_layer_type, + # selfattention_layer_type=conformer_self_attn_layer_type, + # activation_type=conformer_activation_type, + # use_cnn_module=use_cnn_in_conformer, + # cnn_module_kernel=conformer_enc_kernel_size, + # ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define additional projection for prosody embedding + if self.prosody_emb_integration_type == "concat": + self.prosody_projection = torch.nn.Linear( + adim * 2, adim + ) + + # define prosody encoder + self.prosody_encoder = ProsodyEncoder( + odim, + adim=adim, + num_embeddings=prosody_num_embs, + hidden_dim=prosody_hidden_dim, + ref_enc_conv_layers=ref_enc_conv_layers, + ref_enc_conv_chans_list=ref_enc_conv_chans_list, + ref_enc_conv_kernel_size=ref_enc_conv_kernel_size, + ref_enc_conv_stride=ref_enc_conv_stride, + global_enc_gru_layers=ref_enc_gru_layers, + global_enc_gru_units=ref_enc_gru_units, + global_emb_integration_type=ref_emb_integration_type, + ) + + # # define additional projection for speaker embedding + # if self.spk_embed_dim is not None: + # if self.spk_embed_integration_type == "add": + # self.projection = torch.nn.Linear(self.spk_embed_dim, adim) + # else: + # self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, + ) + + # define length regulator + self.length_regulator = LengthRegulator() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + # elif decoder_type == "conformer": + # self.decoder = ConformerEncoder( + # idim=0, + # attention_dim=adim, + # 
attention_heads=aheads, + # linear_units=dunits, + # num_blocks=dlayers, + # input_layer=None, + # dropout_rate=transformer_dec_dropout_rate, + # positional_dropout_rate=transformer_dec_positional_dropout_rate, + # attention_dropout_rate=transformer_dec_attn_dropout_rate, + # normalize_before=decoder_normalize_before, + # concat_after=decoder_concat_after, + # positionwise_layer_type=positionwise_layer_type, + # positionwise_conv_kernel_size=positionwise_conv_kernel_size, + # macaron_style=use_macaron_style_in_conformer, + # pos_enc_layer_type=conformer_pos_enc_layer_type, + # selfattention_layer_type=conformer_self_attn_layer_type, + # activation_type=conformer_activation_type, + # use_cnn_module=use_cnn_in_conformer, + # cnn_module_kernel=conformer_dec_kernel_size, + # ) + else: + raise ValueError(f"{decoder_type} is not supported.") + + # define final projection + self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) + + # define postnet + self.postnet = ( + None + if postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate, + ) + ) + + # initialize parameters + self._reset_parameters( + init_type=init_type, + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_dec_alpha, + ) + + # define criterions + self.criterion = FastSpeechLoss( + use_masking=use_masking, use_weighted_masking=use_weighted_masking + ) + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + durations: torch.Tensor, + durations_lengths: torch.Tensor, + spembs: torch.Tensor = None, + train_ar_prior: bool = False, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Calculate forward propagation. + + Args: + text (LongTensor): Batch of padded token ids (B, Tmax). + text_lengths (LongTensor): Batch of lengths of each input (B,). + speech (Tensor): Batch of padded target features (B, Lmax, odim). + speech_lengths (LongTensor): Batch of the lengths of each target (B,). + durations (LongTensor): Batch of padded durations (B, Tmax + 1). + durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1). + spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Loss scalar value. + Dict: Statistics to be monitored. + Tensor: Weight value. 
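# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The forward pass below expands phoneme-level hidden states to frame level
# with the FastSpeech-style LengthRegulator, using the ground-truth durations
# ``ds`` during training. As a reading aid, here is a minimal, unbatched
# sketch of that idea; the function name and shapes are arbitrary and this is
# not ESPnet's actual LengthRegulator implementation.
import torch

def length_regulate(hs: torch.Tensor, durations: torch.Tensor) -> torch.Tensor:
    """Repeat each token-level vector durations[i] times along the time axis."""
    # hs: (Tmax, adim), durations: (Tmax,) of non-negative integers
    return hs.repeat_interleave(durations, dim=0)  # (sum(durations), adim)

# Example: 4 phonemes with durations [2, 0, 3, 1] yield 6 frame-level vectors.
hs = torch.randn(4, 8)
durations = torch.tensor([2, 0, 3, 1])
assert length_regulate(hs, durations).shape == (6, 8)
# -----------------------------------------------------------------------------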
+ + """ + # train_ar_prior = True # TC marker + text = text[:, : text_lengths.max()] # for data-parallel + speech = speech[:, : speech_lengths.max()] # for data-parallel + durations = durations[:, : durations_lengths.max()] # for data-parallel + + batch_size = text.size(0) + + # Add eos at the last of sequence + xs = F.pad(text, [0, 1], "constant", self.padding_idx) + for i, l in enumerate(text_lengths): + xs[i, l] = self.eos + ilens = text_lengths + 1 + + ys, ds = speech, durations + olens = speech_lengths + + # forward propagation + before_outs, after_outs, d_outs, ref_embs, \ + vq_loss, ar_prior_loss, perplexity = self._forward( + xs, + ilens, + ys, + olens, + ds, + spembs=spembs, + is_inference=False, + train_ar_prior=train_ar_prior + ) + + # modify mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + + if self.postnet is None: + after_outs = None + + # calculate loss TODO: refactor if freezing works + l1_loss, duration_loss = self.criterion( + after_outs, before_outs, d_outs, ys, ds, ilens, olens + ) + if train_ar_prior: + loss = ar_prior_loss + stats = dict( + l1_loss=l1_loss.item(), + duration_loss=duration_loss.item(), + vq_loss=vq_loss.item(), + ar_prior_loss=ar_prior_loss.item(), + loss=loss.item(), + perplexity=perplexity.item(), + ) + else : + loss = l1_loss + duration_loss + vq_loss + stats = dict( + l1_loss=l1_loss.item(), + duration_loss=duration_loss.item(), + vq_loss=vq_loss.item(), + loss=loss.item(), + perplexity=perplexity.item() + ) + + # report extra information + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + encoder_alpha=self.encoder.embed[-1].alpha.data.item(), + ) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + decoder_alpha=self.decoder.embed[-1].alpha.data.item(), + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _forward( + self, + xs: torch.Tensor, + ilens: torch.Tensor, + ys: torch.Tensor = None, + olens: torch.Tensor = None, + ds: torch.Tensor = None, + spembs: torch.Tensor = None, + ref_embs: torch.Tensor = None, + is_inference: bool = False, + train_ar_prior: bool = False, + ar_prior_inference: bool = False, + alpha: float = 1.0, + fg_inds: torch.Tensor = None, + ) -> Sequence[torch.Tensor]: + # forward encoder + x_masks = self._source_mask(ilens) + hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) + + # # integrate speaker embedding + # if self.spk_embed_dim is not None: + # hs = self._integrate_with_spk_embed(hs, spembs) + + # integrate with prosody encoder + # (B, Tmax, adim) + p_embs, vq_loss, ar_prior_loss, perplexity, ref_embs = self.prosody_encoder( + ys, + ds, + hs, + global_embs=ref_embs, + train_ar_prior=train_ar_prior, + ar_prior_inference=ar_prior_inference, + fg_inds=fg_inds, + ) + + hs = self._integrate_with_prosody_embs(hs, p_embs) + + # forward duration predictor + d_masks = make_pad_mask(ilens).to(xs.device) + + if is_inference: + print('predicted durations') + d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) + hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) + else: + d_outs = self.duration_predictor(hs, d_masks) + # use groundtruth in training + hs = self.length_regulator(hs, ds) # (B, Lmax, adim) + + # forward decoder + if olens is not None and not is_inference: + if self.reduction_factor > 1: + olens_in = olens.new([olen // 
self.reduction_factor for olen in olens]) + else: + olens_in = olens + h_masks = self._source_mask(olens_in) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) + before_outs = self.feat_out(zs).view( + zs.size(0), -1, self.odim + ) # (B, Lmax, odim) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + return before_outs, after_outs, d_outs, ref_embs, vq_loss, ar_prior_loss, \ + perplexity + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + durations: torch.Tensor = None, + ref_embs: torch.Tensor = None, + alpha: float = 1.0, + use_teacher_forcing: bool = False, + ar_prior_inference: bool = False, + fg_inds: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text (LongTensor): Input sequence of characters (T,). + speech (Tensor, optional): Feature sequence to extract style (B, idim). + spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). + durations (LongTensor, optional): Groundtruth of duration (T + 1,). + ref_embs (Tensor, optional): Reference embedding vector (B, gru_units). + alpha (float, optional): Alpha to control the speed. + use_teacher_forcing (bool, optional): Whether to use teacher forcing. + If true, groundtruth of duration will be used. + + Returns: + Tensor: Output sequence of features (L, odim). + None: Dummy for compatibility. + None: Dummy for compatibility. + + """ + x, y = text, speech + spemb, d = spembs, durations + + # add eos at the last of sequence + x = F.pad(x, [0, 1], "constant", self.eos) + + # setup batch axis + ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) + xs, ys = x.unsqueeze(0), None + if y is not None: + ys = y.unsqueeze(0) + if spemb is not None: + spembs = spemb.unsqueeze(0) + if ref_embs is not None: + ref_embs = ref_embs.unsqueeze(0) + + if use_teacher_forcing: + # use groundtruth of duration + ds = d.unsqueeze(0) + _, after_outs, _, ref_embs, _, ar_prior_loss, _ = self._forward( + xs, + ilens, + ys, + ds=ds, + spembs=spembs, + ref_embs=ref_embs, + ar_prior_inference=ar_prior_inference, + ) # (1, L, odim) + else: + _, after_outs, _, ref_embs, _, ar_prior_loss, _ = self._forward( + xs, + ilens, + ys, + spembs=spembs, + ref_embs=ref_embs, + is_inference=True, + alpha=alpha, + ar_prior_inference=ar_prior_inference, + fg_inds=fg_inds, + ) # (1, L, odim) + + return after_outs[0], None, None, ref_embs, ar_prior_loss + + # def _integrate_with_spk_embed( + # self, hs: torch.Tensor, spembs: torch.Tensor + # ) -> torch.Tensor: + # """Integrate speaker embedding with hidden states. + + # Args: + # hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + # spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + # Returns: + # Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). 
+ + # """ + # if self.spk_embed_integration_type == "add": + # # apply projection and then add to hidden states + # spembs = self.projection(F.normalize(spembs)) + # hs = hs + spembs.unsqueeze(1) + # elif self.spk_embed_integration_type == "concat": + # # concat hidden states with spk embeds and then apply projection + # spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + # hs = self.projection(torch.cat([hs, spembs], dim=-1)) + # else: + # raise NotImplementedError("support only add or concat.") + + # return hs + + def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _integrate_with_prosody_embs( + self, hs: torch.Tensor, p_embs: torch.Tensor + ) -> torch.Tensor: + """Integrate prosody embeddings with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + p_embs (Tensor): Batch of prosody embeddings (B, Tmax, adim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). + + """ + if self.prosody_emb_integration_type == "add": + # apply projection and then add to hidden states + # (B, Tmax, adim) + hs = hs + p_embs + elif self.prosody_emb_integration_type == "concat": + # concat hidden states with prosody embeds and then apply projection + # (B, Tmax, adim) + hs = self.prosody_projection(torch.cat([hs, p_embs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _reset_parameters( + self, init_type: str, init_enc_alpha: float, init_dec_alpha: float + ): + # initialize parameters + if init_type != "pytorch": + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) diff --git a/espnet2/tts/fastspeech.py b/espnet2/tts/fastspeech.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2bd9c005f3933a4804c1f8ca570f20bce3bdcc --- /dev/null +++ b/espnet2/tts/fastspeech.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Fastspeech related modules for ESPnet2.""" + +import logging + +from typing import Dict +from typing import Sequence +from typing import Tuple + +import torch +import torch.nn.functional as F + +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.conformer.encoder import ( + Encoder as ConformerEncoder, # noqa: H301 +) +from espnet.nets.pytorch_backend.e2e_tts_fastspeech import ( + FeedForwardTransformerLoss as FastSpeechLoss, # NOQA +) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils 
import make_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import ( + Encoder as TransformerEncoder, # noqa: H301 +) + +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.torch_utils.initialize import initialize +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.gst.style_encoder import StyleEncoder + + +class FastSpeech(AbsTTS): + """FastSpeech module for end-to-end text-to-speech. + + This is a module of FastSpeech, feed-forward Transformer with duration predictor + described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which + does not require any auto-regressive processing during inference, resulting in + fast decoding compared with auto-regressive Transformer. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + elayers (int, optional): Number of encoder layers. + eunits (int, optional): Number of encoder hidden units. + dlayers (int, optional): Number of decoder layers. + dunits (int, optional): Number of decoder hidden units. + use_scaled_pos_enc (bool, optional): + Whether to use trainable scaled positional encoding. + encoder_normalize_before (bool, optional): + Whether to perform layer normalization before encoder block. + decoder_normalize_before (bool, optional): + Whether to perform layer normalization before decoder block. + encoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in encoder. + decoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in decoder. + duration_predictor_layers (int, optional): Number of duration predictor layers. + duration_predictor_chans (int, optional): Number of duration predictor channels. + duration_predictor_kernel_size (int, optional): + Kernel size of duration predictor. + spk_embed_dim (int, optional): Number of speaker embedding dimensions. + spk_embed_integration_type: How to integrate speaker embedding. + use_gst (str, optional): Whether to use global style token. + gst_tokens (int, optional): The number of GST embeddings. + gst_heads (int, optional): The number of heads in GST multihead attention. + gst_conv_layers (int, optional): The number of conv layers in GST. + gst_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in GST. + gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. + gst_conv_stride (int, optional): Stride size of conv layers in GST. + gst_gru_layers (int, optional): The number of GRU layers in GST. + gst_gru_units (int, optional): The number of GRU units in GST. + reduction_factor (int, optional): Reduction factor. + transformer_enc_dropout_rate (float, optional): + Dropout rate in encoder except attention & positional encoding. + transformer_enc_positional_dropout_rate (float, optional): + Dropout rate after encoder positional encoding. + transformer_enc_attn_dropout_rate (float, optional): + Dropout rate in encoder self-attention module. + transformer_dec_dropout_rate (float, optional): + Dropout rate in decoder except attention & positional encoding. 
+ transformer_dec_positional_dropout_rate (float, optional): + Dropout rate after decoder positional encoding. + transformer_dec_attn_dropout_rate (float, optional): + Dropout rate in deocoder self-attention module. + init_type (str, optional): + How to initialize transformer parameters. + init_enc_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the encoder. + init_dec_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the decoder. + use_masking (bool, optional): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool, optional): + Whether to apply weighted masking in loss calculation. + + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + adim: int = 384, + aheads: int = 4, + elayers: int = 6, + eunits: int = 1536, + dlayers: int = 6, + dunits: int = 1536, + postnet_layers: int = 5, + postnet_chans: int = 512, + postnet_filts: int = 5, + positionwise_layer_type: str = "conv1d", + positionwise_conv_kernel_size: int = 1, + use_scaled_pos_enc: bool = True, + use_batch_norm: bool = True, + encoder_normalize_before: bool = True, + decoder_normalize_before: bool = True, + encoder_concat_after: bool = False, + decoder_concat_after: bool = False, + duration_predictor_layers: int = 2, + duration_predictor_chans: int = 384, + duration_predictor_kernel_size: int = 3, + reduction_factor: int = 1, + encoder_type: str = "transformer", + decoder_type: str = "transformer", + # only for conformer + conformer_rel_pos_type: str = "legacy", + conformer_pos_enc_layer_type: str = "rel_pos", + conformer_self_attn_layer_type: str = "rel_selfattn", + conformer_activation_type: str = "swish", + use_macaron_style_in_conformer: bool = True, + use_cnn_in_conformer: bool = True, + conformer_enc_kernel_size: int = 7, + conformer_dec_kernel_size: int = 31, + zero_triu: bool = False, + # pretrained spk emb + spk_embed_dim: int = None, + spk_embed_integration_type: str = "add", + # GST + use_gst: bool = False, + gst_tokens: int = 10, + gst_heads: int = 4, + gst_conv_layers: int = 6, + gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int = 3, + gst_conv_stride: int = 2, + gst_gru_layers: int = 1, + gst_gru_units: int = 128, + # training related + transformer_enc_dropout_rate: float = 0.1, + transformer_enc_positional_dropout_rate: float = 0.1, + transformer_enc_attn_dropout_rate: float = 0.1, + transformer_dec_dropout_rate: float = 0.1, + transformer_dec_positional_dropout_rate: float = 0.1, + transformer_dec_attn_dropout_rate: float = 0.1, + duration_predictor_dropout_rate: float = 0.1, + postnet_dropout_rate: float = 0.5, + init_type: str = "xavier_uniform", + init_enc_alpha: float = 1.0, + init_dec_alpha: float = 1.0, + use_masking: bool = False, + use_weighted_masking: bool = False, + ): + """Initialize FastSpeech module.""" + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.idim = idim + self.odim = odim + self.eos = idim - 1 + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.use_scaled_pos_enc = use_scaled_pos_enc + self.use_gst = use_gst + self.spk_embed_dim = spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = spk_embed_integration_type + + # use idx 0 as padding idx + self.padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else 
PositionalEncoding + ) + + # check relative positional encoding compatibility + if "conformer" in [encoder_type, decoder_type]: + if conformer_rel_pos_type == "legacy": + if conformer_pos_enc_layer_type == "rel_pos": + conformer_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." + ) + if conformer_self_attn_layer_type == "rel_selfattn": + conformer_self_attn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Fallback to " + "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." + ) + elif conformer_rel_pos_type == "latest": + assert conformer_pos_enc_layer_type != "legacy_rel_pos" + assert conformer_self_attn_layer_type != "legacy_rel_selfattn" + else: + raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}") + + # define encoder + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx + ) + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif encoder_type == "conformer": + self.encoder = ConformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define GST + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=adim, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, + ) + + # define additional projection for speaker embedding + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, adim) + else: + self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + 
n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, + ) + + # define length regulator + self.length_regulator = LengthRegulator() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif decoder_type == "conformer": + self.decoder = ConformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size, + ) + else: + raise ValueError(f"{decoder_type} is not supported.") + + # define final projection + self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) + + # define postnet + self.postnet = ( + None + if postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate, + ) + ) + + # initialize parameters + self._reset_parameters( + init_type=init_type, + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_dec_alpha, + ) + + # define criterions + self.criterion = FastSpeechLoss( + use_masking=use_masking, use_weighted_masking=use_weighted_masking + ) + + def _forward( + self, + xs: torch.Tensor, + ilens: torch.Tensor, + ys: torch.Tensor = None, + olens: torch.Tensor = None, + ds: torch.Tensor = None, + spembs: torch.Tensor = None, + is_inference: bool = False, + alpha: float = 1.0, + ) -> Sequence[torch.Tensor]: + # forward encoder + x_masks = self._source_mask(ilens) + hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(ys) + hs = hs + style_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # forward duration predictor and length regulator + d_masks = make_pad_mask(ilens).to(xs.device) + if is_inference: + d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) + hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) + else: + d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax) + hs = self.length_regulator(hs, ds) # (B, Lmax, adim) + + # forward decoder + if olens is not None and not is_inference: + if 
self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + h_masks = self._source_mask(olens_in) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) + before_outs = self.feat_out(zs).view( + zs.size(0), -1, self.odim + ) # (B, Lmax, odim) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + return before_outs, after_outs, d_outs + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + durations: torch.Tensor, + durations_lengths: torch.Tensor, + spembs: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Calculate forward propagation. + + Args: + text (LongTensor): Batch of padded character ids (B, Tmax). + text_lengths (LongTensor): Batch of lengths of each input (B,). + speech (Tensor): Batch of padded target features (B, Lmax, odim). + speech_lengths (LongTensor): Batch of the lengths of each target (B,). + durations (LongTensor): Batch of padded durations (B, Tmax + 1). + durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1). + spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Loss scalar value. + Dict: Statistics to be monitored. + Tensor: Weight value. + + """ + text = text[:, : text_lengths.max()] # for data-parallel + speech = speech[:, : speech_lengths.max()] # for data-parallel + durations = durations[:, : durations_lengths.max()] # for data-parallel + + batch_size = text.size(0) + + # Add eos at the last of sequence + xs = F.pad(text, [0, 1], "constant", self.padding_idx) + for i, l in enumerate(text_lengths): + xs[i, l] = self.eos + ilens = text_lengths + 1 + + ys, ds = speech, durations + olens = speech_lengths + + # forward propagation + before_outs, after_outs, d_outs = self._forward( + xs, ilens, ys, olens, ds, spembs=spembs, is_inference=False + ) + + # modifiy mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + + # calculate loss + if self.postnet is None: + after_outs = None + l1_loss, duration_loss = self.criterion( + after_outs, before_outs, d_outs, ys, ds, ilens, olens + ) + loss = l1_loss + duration_loss + + stats = dict( + l1_loss=l1_loss.item(), + duration_loss=duration_loss.item(), + loss=loss.item(), + ) + + # report extra information + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + encoder_alpha=self.encoder.embed[-1].alpha.data.item(), + ) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + decoder_alpha=self.decoder.embed[-1].alpha.data.item(), + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + durations: torch.Tensor = None, + alpha: float = 1.0, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text (LongTensor): Input sequence of characters (T,). + speech (Tensor, optional): Feature sequence to extract style (N, idim). 
+ spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). + durations (LongTensor, optional): Groundtruth of duration (T + 1,). + alpha (float, optional): Alpha to control the speed. + use_teacher_forcing (bool, optional): Whether to use teacher forcing. + If true, groundtruth of duration, pitch and energy will be used. + + Returns: + Tensor: Output sequence of features (L, odim). + None: Dummy for compatibility. + None: Dummy for compatibility. + + """ + x, y = text, speech + spemb, d = spembs, durations + + # add eos at the last of sequence + x = F.pad(x, [0, 1], "constant", self.eos) + + # setup batch axis + ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) + xs, ys = x.unsqueeze(0), None + if y is not None: + ys = y.unsqueeze(0) + if spemb is not None: + spembs = spemb.unsqueeze(0) + + if use_teacher_forcing: + # use groundtruth of duration, pitch, and energy + ds = d.unsqueeze(0) + _, outs, *_ = self._forward( + xs, + ilens, + ys, + ds=ds, + spembs=spembs, + ) # (1, L, odim) + else: + # inference + _, outs, _ = self._forward( + xs, + ilens, + ys, + spembs=spembs, + is_inference=True, + alpha=alpha, + ) # (1, L, odim) + + return outs[0], None, None + + def _integrate_with_spk_embed( + self, hs: torch.Tensor, spembs: torch.Tensor + ) -> torch.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. 
+ dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _reset_parameters( + self, init_type: str, init_enc_alpha: float, init_dec_alpha: float + ): + # initialize parameters + if init_type != "pytorch": + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) diff --git a/espnet2/tts/fastspeech2.py b/espnet2/tts/fastspeech2.py new file mode 100644 index 0000000000000000000000000000000000000000..de6c5657dea1c31386fef2a7ddbebc4cf6767a76 --- /dev/null +++ b/espnet2/tts/fastspeech2.py @@ -0,0 +1,803 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Fastspeech2 related modules for ESPnet2.""" + +import logging + +from typing import Dict +from typing import Sequence +from typing import Tuple + +import torch +import torch.nn.functional as F + +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.conformer.encoder import ( + Encoder as ConformerEncoder, # noqa: H301 +) +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import DurationPredictor +from espnet.nets.pytorch_backend.fastspeech.duration_predictor import ( + DurationPredictorLoss, # noqa: H301 +) +from espnet.nets.pytorch_backend.fastspeech.length_regulator import LengthRegulator +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import ( + Encoder as TransformerEncoder, # noqa: H301 +) + +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.torch_utils.initialize import initialize +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.gst.style_encoder import StyleEncoder +from espnet2.tts.variance_predictor import VariancePredictor + + +class FastSpeech2(AbsTTS): + """FastSpeech2 module. + + This is a module of FastSpeech2 described in `FastSpeech 2: Fast and + High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and + energy, we use token-averaged value introduced in `FastPitch: Parallel + Text-to-speech with Pitch Prediction`_. + + .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`: + https://arxiv.org/abs/2006.04558 + .. 
_`FastPitch: Parallel Text-to-speech with Pitch Prediction`: + https://arxiv.org/abs/2006.06873 + + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + adim: int = 384, + aheads: int = 4, + elayers: int = 6, + eunits: int = 1536, + dlayers: int = 6, + dunits: int = 1536, + postnet_layers: int = 5, + postnet_chans: int = 512, + postnet_filts: int = 5, + positionwise_layer_type: str = "conv1d", + positionwise_conv_kernel_size: int = 1, + use_scaled_pos_enc: bool = True, + use_batch_norm: bool = True, + encoder_normalize_before: bool = True, + decoder_normalize_before: bool = True, + encoder_concat_after: bool = False, + decoder_concat_after: bool = False, + reduction_factor: int = 1, + encoder_type: str = "transformer", + decoder_type: str = "transformer", + # only for conformer + conformer_rel_pos_type: str = "legacy", + conformer_pos_enc_layer_type: str = "rel_pos", + conformer_self_attn_layer_type: str = "rel_selfattn", + conformer_activation_type: str = "swish", + use_macaron_style_in_conformer: bool = True, + use_cnn_in_conformer: bool = True, + zero_triu: bool = False, + conformer_enc_kernel_size: int = 7, + conformer_dec_kernel_size: int = 31, + # duration predictor + duration_predictor_layers: int = 2, + duration_predictor_chans: int = 384, + duration_predictor_kernel_size: int = 3, + # energy predictor + energy_predictor_layers: int = 2, + energy_predictor_chans: int = 384, + energy_predictor_kernel_size: int = 3, + energy_predictor_dropout: float = 0.5, + energy_embed_kernel_size: int = 9, + energy_embed_dropout: float = 0.5, + stop_gradient_from_energy_predictor: bool = False, + # pitch predictor + pitch_predictor_layers: int = 2, + pitch_predictor_chans: int = 384, + pitch_predictor_kernel_size: int = 3, + pitch_predictor_dropout: float = 0.5, + pitch_embed_kernel_size: int = 9, + pitch_embed_dropout: float = 0.5, + stop_gradient_from_pitch_predictor: bool = False, + # pretrained spk emb + spk_embed_dim: int = None, + spk_embed_integration_type: str = "add", + # GST + use_gst: bool = False, + gst_tokens: int = 10, + gst_heads: int = 4, + gst_conv_layers: int = 6, + gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int = 3, + gst_conv_stride: int = 2, + gst_gru_layers: int = 1, + gst_gru_units: int = 128, + # training related + transformer_enc_dropout_rate: float = 0.1, + transformer_enc_positional_dropout_rate: float = 0.1, + transformer_enc_attn_dropout_rate: float = 0.1, + transformer_dec_dropout_rate: float = 0.1, + transformer_dec_positional_dropout_rate: float = 0.1, + transformer_dec_attn_dropout_rate: float = 0.1, + duration_predictor_dropout_rate: float = 0.1, + postnet_dropout_rate: float = 0.5, + init_type: str = "xavier_uniform", + init_enc_alpha: float = 1.0, + init_dec_alpha: float = 1.0, + use_masking: bool = False, + use_weighted_masking: bool = False, + ): + """Initialize FastSpeech2 module.""" + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.idim = idim + self.odim = odim + self.eos = idim - 1 + self.reduction_factor = reduction_factor + self.encoder_type = encoder_type + self.decoder_type = decoder_type + self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor + self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor + self.use_scaled_pos_enc = use_scaled_pos_enc + self.use_gst = use_gst + self.spk_embed_dim = spk_embed_dim + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = 
spk_embed_integration_type + + # use idx 0 as padding idx + self.padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # check relative positional encoding compatibility + if "conformer" in [encoder_type, decoder_type]: + if conformer_rel_pos_type == "legacy": + if conformer_pos_enc_layer_type == "rel_pos": + conformer_pos_enc_layer_type = "legacy_rel_pos" + logging.warning( + "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." + ) + if conformer_self_attn_layer_type == "rel_selfattn": + conformer_self_attn_layer_type = "legacy_rel_selfattn" + logging.warning( + "Fallback to " + "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " + "due to the compatibility. If you want to use the new one, " + "please use conformer_pos_enc_layer_type = 'latest'." + ) + elif conformer_rel_pos_type == "latest": + assert conformer_pos_enc_layer_type != "legacy_rel_pos" + assert conformer_self_attn_layer_type != "legacy_rel_selfattn" + else: + raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}") + + # define encoder + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx + ) + if encoder_type == "transformer": + self.encoder = TransformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif encoder_type == "conformer": + self.encoder = ConformerEncoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_enc_kernel_size, + zero_triu=zero_triu, + ) + else: + raise ValueError(f"{encoder_type} is not supported.") + + # define GST + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=adim, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, + ) + + # define additional projection for speaker embedding + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, adim) + else: + 
self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) + + # define duration predictor + self.duration_predictor = DurationPredictor( + idim=adim, + n_layers=duration_predictor_layers, + n_chans=duration_predictor_chans, + kernel_size=duration_predictor_kernel_size, + dropout_rate=duration_predictor_dropout_rate, + ) + + # define pitch predictor + self.pitch_predictor = VariancePredictor( + idim=adim, + n_layers=pitch_predictor_layers, + n_chans=pitch_predictor_chans, + kernel_size=pitch_predictor_kernel_size, + dropout_rate=pitch_predictor_dropout, + ) + # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg + self.pitch_embed = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=1, + out_channels=adim, + kernel_size=pitch_embed_kernel_size, + padding=(pitch_embed_kernel_size - 1) // 2, + ), + torch.nn.Dropout(pitch_embed_dropout), + ) + + # define energy predictor + self.energy_predictor = VariancePredictor( + idim=adim, + n_layers=energy_predictor_layers, + n_chans=energy_predictor_chans, + kernel_size=energy_predictor_kernel_size, + dropout_rate=energy_predictor_dropout, + ) + # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg + self.energy_embed = torch.nn.Sequential( + torch.nn.Conv1d( + in_channels=1, + out_channels=adim, + kernel_size=energy_embed_kernel_size, + padding=(energy_embed_kernel_size - 1) // 2, + ), + torch.nn.Dropout(energy_embed_dropout), + ) + + # define length regulator + self.length_regulator = LengthRegulator() + + # define decoder + # NOTE: we use encoder as decoder + # because fastspeech's decoder is the same as encoder + if decoder_type == "transformer": + self.decoder = TransformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + elif decoder_type == "conformer": + self.decoder = ConformerEncoder( + idim=0, + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + input_layer=None, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + attention_dropout_rate=transformer_dec_attn_dropout_rate, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + macaron_style=use_macaron_style_in_conformer, + pos_enc_layer_type=conformer_pos_enc_layer_type, + selfattention_layer_type=conformer_self_attn_layer_type, + activation_type=conformer_activation_type, + use_cnn_module=use_cnn_in_conformer, + cnn_module_kernel=conformer_dec_kernel_size, + ) + else: + raise ValueError(f"{decoder_type} is not supported.") + + # define final projection + self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) + + # define postnet + self.postnet = ( + None + if postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate, + ) + ) + + # initialize parameters + self._reset_parameters( + 
init_type=init_type, + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_dec_alpha, + ) + + # define criterions + self.criterion = FastSpeech2Loss( + use_masking=use_masking, use_weighted_masking=use_weighted_masking + ) + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + durations: torch.Tensor, + durations_lengths: torch.Tensor, + pitch: torch.Tensor, + pitch_lengths: torch.Tensor, + energy: torch.Tensor, + energy_lengths: torch.Tensor, + spembs: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Calculate forward propagation. + + Args: + text (LongTensor): Batch of padded token ids (B, Tmax). + text_lengths (LongTensor): Batch of lengths of each input (B,). + speech (Tensor): Batch of padded target features (B, Lmax, odim). + speech_lengths (LongTensor): Batch of the lengths of each target (B,). + durations (LongTensor): Batch of padded durations (B, Tmax + 1). + durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1). + pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1). + pitch_lengths (LongTensor): Batch of pitch lengths (B, Tmax + 1). + energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1). + energy_lengths (LongTensor): Batch of energy lengths (B, Tmax + 1). + spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Loss scalar value. + Dict: Statistics to be monitored. + Tensor: Weight value. + + """ + text = text[:, : text_lengths.max()] # for data-parallel + speech = speech[:, : speech_lengths.max()] # for data-parallel + durations = durations[:, : durations_lengths.max()] # for data-parallel + pitch = pitch[:, : pitch_lengths.max()] # for data-parallel + energy = energy[:, : energy_lengths.max()] # for data-parallel + + batch_size = text.size(0) + + # Add eos at the last of sequence + xs = F.pad(text, [0, 1], "constant", self.padding_idx) + for i, l in enumerate(text_lengths): + xs[i, l] = self.eos + ilens = text_lengths + 1 + + ys, ds, ps, es = speech, durations, pitch, energy + olens = speech_lengths + + # forward propagation + before_outs, after_outs, d_outs, p_outs, e_outs = self._forward( + xs, ilens, ys, olens, ds, ps, es, spembs=spembs, is_inference=False + ) + + # modify mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + + # calculate loss + if self.postnet is None: + after_outs = None + + # calculate loss + l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( + after_outs=after_outs, + before_outs=before_outs, + d_outs=d_outs, + p_outs=p_outs, + e_outs=e_outs, + ys=ys, + ds=ds, + ps=ps, + es=es, + ilens=ilens, + olens=olens, + ) + loss = l1_loss + duration_loss + pitch_loss + energy_loss + + stats = dict( + l1_loss=l1_loss.item(), + duration_loss=duration_loss.item(), + pitch_loss=pitch_loss.item(), + energy_loss=energy_loss.item(), + loss=loss.item(), + ) + + # report extra information + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + encoder_alpha=self.encoder.embed[-1].alpha.data.item(), + ) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + stats.update( + decoder_alpha=self.decoder.embed[-1].alpha.data.item(), + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _forward( + 
self, + xs: torch.Tensor, + ilens: torch.Tensor, + ys: torch.Tensor = None, + olens: torch.Tensor = None, + ds: torch.Tensor = None, + ps: torch.Tensor = None, + es: torch.Tensor = None, + spembs: torch.Tensor = None, + is_inference: bool = False, + alpha: float = 1.0, + ) -> Sequence[torch.Tensor]: + # forward encoder + x_masks = self._source_mask(ilens) + hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(ys) + hs = hs + style_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # forward duration predictor and variance predictors + d_masks = make_pad_mask(ilens).to(xs.device) + + if self.stop_gradient_from_pitch_predictor: + p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1)) + else: + p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1)) + if self.stop_gradient_from_energy_predictor: + e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1)) + else: + e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1)) + + if is_inference: + d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) + # use prediction in inference + p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2) + e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2) + hs = hs + e_embs + p_embs + hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) + else: + d_outs = self.duration_predictor(hs, d_masks) + # use groundtruth in training + p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2) + e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2) + hs = hs + e_embs + p_embs + hs = self.length_regulator(hs, ds) # (B, Lmax, adim) + + # forward decoder + if olens is not None and not is_inference: + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + h_masks = self._source_mask(olens_in) + else: + h_masks = None + zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) + before_outs = self.feat_out(zs).view( + zs.size(0), -1, self.odim + ) # (B, Lmax, odim) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + return before_outs, after_outs, d_outs, p_outs, e_outs + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + durations: torch.Tensor = None, + pitch: torch.Tensor = None, + energy: torch.Tensor = None, + alpha: float = 1.0, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text (LongTensor): Input sequence of characters (T,). + speech (Tensor, optional): Feature sequence to extract style (N, idim). + spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). + durations (LongTensor, optional): Groundtruth of duration (T + 1,). + pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1). + energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1). + alpha (float, optional): Alpha to control the speed. + use_teacher_forcing (bool, optional): Whether to use teacher forcing. + If true, groundtruth of duration, pitch and energy will be used. + + Returns: + Tensor: Output sequence of features (L, odim). + None: Dummy for compatibility. 
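# A minimal, self-contained sketch of what the length regulator call in _forward does:
# each phoneme-level hidden vector is repeated according to its integer duration,
# producing frame-level hiddens. The actual espnet LengthRegulator also supports the
# `alpha` speed-control factor, which is omitted here for brevity.
import torch
from torch.nn.utils.rnn import pad_sequence

def length_regulate(hs: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
    """hs: (B, Tmax, adim) phoneme hiddens, ds: (B, Tmax) integer durations."""
    expanded = [h.repeat_interleave(d, dim=0) for h, d in zip(hs, ds)]
    return pad_sequence(expanded, batch_first=True)  # (B, Lmax, adim)

hs = torch.randn(2, 3, 4)
ds = torch.tensor([[1, 2, 3], [2, 2, 0]])
print(length_regulate(hs, ds).shape)  # torch.Size([2, 6, 4])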
+ None: Dummy for compatibility. + + """ + x, y = text, speech + spemb, d, p, e = spembs, durations, pitch, energy + + # add eos at the last of sequence + x = F.pad(x, [0, 1], "constant", self.eos) + + # setup batch axis + ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) + xs, ys = x.unsqueeze(0), None + if y is not None: + ys = y.unsqueeze(0) + if spemb is not None: + spembs = spemb.unsqueeze(0) + + if use_teacher_forcing: + # use groundtruth of duration, pitch, and energy + ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0) + _, outs, *_ = self._forward( + xs, + ilens, + ys, + ds=ds, + ps=ps, + es=es, + spembs=spembs, + ) # (1, L, odim) + else: + _, outs, *_ = self._forward( + xs, + ilens, + ys, + spembs=spembs, + is_inference=True, + alpha=alpha, + ) # (1, L, odim) + + return outs[0], None, None + + def _integrate_with_spk_embed( + self, hs: torch.Tensor, spembs: torch.Tensor + ) -> torch.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds and then apply projection + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = self.projection(torch.cat([hs, spembs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return hs + + def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _reset_parameters( + self, init_type: str, init_enc_alpha: float, init_dec_alpha: float + ): + # initialize parameters + if init_type != "pytorch": + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.encoder_type == "transformer" and self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + if self.decoder_type == "transformer" and self.use_scaled_pos_enc: + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) + + +class FastSpeech2Loss(torch.nn.Module): + """Loss function module for FastSpeech2.""" + + def __init__(self, use_masking: bool = True, use_weighted_masking: bool = False): + """Initialize feed-forward Transformer loss module. + + Args: + use_masking (bool): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool): + Whether to weighted masking in loss calculation. 
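# A standalone equivalent of `_source_mask` above, assuming `make_non_pad_mask`
# returns True for valid (non-padded) positions; the trailing unsqueeze gives the
# (B, 1, Tmax) shape expected by the self-attention modules.
import torch

def source_mask(ilens) -> torch.Tensor:
    ilens = torch.as_tensor(ilens)
    max_len = int(ilens.max())
    non_pad = torch.arange(max_len).unsqueeze(0) < ilens.unsqueeze(1)  # (B, Tmax)
    return non_pad.unsqueeze(-2)  # (B, 1, Tmax)

print(source_mask([5, 3]).int())
# -> batch 0: [1, 1, 1, 1, 1], batch 1: [1, 1, 1, 0, 0]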
+ + """ + assert check_argument_types() + super().__init__() + + assert (use_masking != use_weighted_masking) or not use_masking + self.use_masking = use_masking + self.use_weighted_masking = use_weighted_masking + + # define criterions + reduction = "none" if self.use_weighted_masking else "mean" + self.l1_criterion = torch.nn.L1Loss(reduction=reduction) + self.mse_criterion = torch.nn.MSELoss(reduction=reduction) + self.duration_criterion = DurationPredictorLoss(reduction=reduction) + + def forward( + self, + after_outs: torch.Tensor, + before_outs: torch.Tensor, + d_outs: torch.Tensor, + p_outs: torch.Tensor, + e_outs: torch.Tensor, + ys: torch.Tensor, + ds: torch.Tensor, + ps: torch.Tensor, + es: torch.Tensor, + ilens: torch.Tensor, + olens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Calculate forward propagation. + + Args: + after_outs (Tensor): Batch of outputs after postnets (B, Lmax, odim). + before_outs (Tensor): Batch of outputs before postnets (B, Lmax, odim). + d_outs (LongTensor): Batch of outputs of duration predictor (B, Tmax). + p_outs (Tensor): Batch of outputs of pitch predictor (B, Tmax, 1). + e_outs (Tensor): Batch of outputs of energy predictor (B, Tmax, 1). + ys (Tensor): Batch of target features (B, Lmax, odim). + ds (LongTensor): Batch of durations (B, Tmax). + ps (Tensor): Batch of target token-averaged pitch (B, Tmax, 1). + es (Tensor): Batch of target token-averaged energy (B, Tmax, 1). + ilens (LongTensor): Batch of the lengths of each input (B,). + olens (LongTensor): Batch of the lengths of each target (B,). + + Returns: + Tensor: L1 loss value. + Tensor: Duration predictor loss value. + Tensor: Pitch predictor loss value. + Tensor: Energy predictor loss value. + + """ + # apply mask to remove padded part + if self.use_masking: + out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + before_outs = before_outs.masked_select(out_masks) + if after_outs is not None: + after_outs = after_outs.masked_select(out_masks) + ys = ys.masked_select(out_masks) + duration_masks = make_non_pad_mask(ilens).to(ys.device) + d_outs = d_outs.masked_select(duration_masks) + ds = ds.masked_select(duration_masks) + pitch_masks = make_non_pad_mask(ilens).unsqueeze(-1).to(ys.device) + p_outs = p_outs.masked_select(pitch_masks) + e_outs = e_outs.masked_select(pitch_masks) + ps = ps.masked_select(pitch_masks) + es = es.masked_select(pitch_masks) + + # calculate loss + l1_loss = self.l1_criterion(before_outs, ys) + if after_outs is not None: + l1_loss += self.l1_criterion(after_outs, ys) + duration_loss = self.duration_criterion(d_outs, ds) + pitch_loss = self.mse_criterion(p_outs, ps) + energy_loss = self.mse_criterion(e_outs, es) + + # make weighted mask and apply it + if self.use_weighted_masking: + out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device) + out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float() + out_weights /= ys.size(0) * ys.size(2) + duration_masks = make_non_pad_mask(ilens).to(ys.device) + duration_weights = ( + duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float() + ) + duration_weights /= ds.size(0) + + # apply weight + l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum() + duration_loss = ( + duration_loss.mul(duration_weights).masked_select(duration_masks).sum() + ) + pitch_masks = duration_masks.unsqueeze(-1) + pitch_weights = duration_weights.unsqueeze(-1) + pitch_loss = pitch_loss.mul(pitch_weights).masked_select(pitch_masks).sum() + 
energy_loss = ( + energy_loss.mul(pitch_weights).masked_select(pitch_masks).sum() + ) + + return l1_loss, duration_loss, pitch_loss, energy_loss diff --git a/espnet2/tts/feats_extract/__init__.py b/espnet2/tts/feats_extract/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/tts/feats_extract/abs_feats_extract.py b/espnet2/tts/feats_extract/abs_feats_extract.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a459e5be7235026a880f1b776efdcd5ed8825d --- /dev/null +++ b/espnet2/tts/feats_extract/abs_feats_extract.py @@ -0,0 +1,23 @@ +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Dict + +import torch +from typing import Tuple + + +class AbsFeatsExtract(torch.nn.Module, ABC): + @abstractmethod + def output_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_parameters(self) -> Dict[str, Any]: + raise NotImplementedError + + @abstractmethod + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError diff --git a/espnet2/tts/feats_extract/dio.py b/espnet2/tts/feats_extract/dio.py new file mode 100644 index 0000000000000000000000000000000000000000..48f7249aa1f9c8253ccdf48ac007375aca4c1829 --- /dev/null +++ b/espnet2/tts/feats_extract/dio.py @@ -0,0 +1,187 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""F0 extractor using DIO + Stonemask algorithm.""" + +import logging + +from typing import Any +from typing import Dict +from typing import Tuple +from typing import Union + +import humanfriendly +import numpy as np +import pyworld +import torch +import torch.nn.functional as F + +from scipy.interpolate import interp1d +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract + + +class Dio(AbsFeatsExtract): + """F0 estimation with dio + stonemask algortihm. + + This is f0 extractor based on dio + stonmask algorithm introduced in `WORLD: + a vocoder-based high-quality speech synthesis system for real-time applications`_. + + .. _`WORLD: a vocoder-based high-quality speech synthesis system for real-time + applications`: https://doi.org/10.1587/transinf.2015EDP7457 + + Note: + This module is based on NumPy implementation. Therefore, the computational graph + is not connected. + + Todo: + Replace this module with PyTorch-based implementation. 
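# A rough end-to-end sketch of what this extractor computes, calling pyworld directly:
# coarse DIO F0, StoneMask refinement, log-F0 on voiced frames, then FastSpeech2-style
# averaging over each token's duration. The waveform and the duration values below are
# placeholders for illustration only.
import numpy as np
import pyworld
import torch

fs, hop_length = 22050, 256
x = np.random.randn(fs).astype(np.double)          # 1 s of (fake) audio

f0, t = pyworld.dio(x, fs, f0_floor=80, f0_ceil=400,
                    frame_period=1000 * hop_length / fs)
f0 = pyworld.stonemask(x, f0, t, fs)               # refine the coarse DIO estimate
f0[f0 > 0] = np.log(f0[f0 > 0])                    # log-F0 on voiced frames only

# average per token, given frame-level durations that roughly sum to len(f0)
f0 = torch.from_numpy(f0).float()
durations = torch.tensor([20, 30, 36])             # hypothetical token durations
starts = torch.nn.functional.pad(durations.cumsum(0), (1, 0))
token_f0 = torch.stack([
    f0[s:e][f0[s:e] > 0].mean() if (f0[s:e] > 0).any() else f0.new_tensor(0.0)
    for s, e in zip(starts[:-1], starts[1:])
])
print(token_f0.shape)                              # (3,) one value per token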
+ + """ + + def __init__( + self, + fs: Union[int, str] = 22050, + n_fft: int = 1024, + hop_length: int = 256, + f0min: int = 80, + f0max: int = 400, + use_token_averaged_f0: bool = True, + use_continuous_f0: bool = True, + use_log_f0: bool = True, + reduction_factor: int = None, + ): + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + self.fs = fs + self.n_fft = n_fft + self.hop_length = hop_length + self.frame_period = 1000 * hop_length / fs + self.f0min = f0min + self.f0max = f0max + self.use_token_averaged_f0 = use_token_averaged_f0 + self.use_continuous_f0 = use_continuous_f0 + self.use_log_f0 = use_log_f0 + if use_token_averaged_f0: + assert reduction_factor >= 1 + self.reduction_factor = reduction_factor + + def output_size(self) -> int: + return 1 + + def get_parameters(self) -> Dict[str, Any]: + return dict( + fs=self.fs, + n_fft=self.n_fft, + hop_length=self.hop_length, + f0min=self.f0min, + f0max=self.f0max, + use_token_averaged_f0=self.use_token_averaged_f0, + use_continuous_f0=self.use_continuous_f0, + use_log_f0=self.use_log_f0, + reduction_factor=self.reduction_factor, + ) + + def forward( + self, + input: torch.Tensor, + input_lengths: torch.Tensor = None, + feats_lengths: torch.Tensor = None, + durations: torch.Tensor = None, + durations_lengths: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If not provide, we assume that the inputs have the same length + if input_lengths is None: + input_lengths = ( + input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1] + ) + + # F0 extraction + pitch = [self._calculate_f0(x[:xl]) for x, xl in zip(input, input_lengths)] + + # (Optional): Adjust length to match with the mel-spectrogram + if feats_lengths is not None: + pitch = [ + self._adjust_num_frames(p, fl).view(-1) + for p, fl in zip(pitch, feats_lengths) + ] + + # (Optional): Average by duration to calculate token-wise f0 + if self.use_token_averaged_f0: + durations = durations * self.reduction_factor + pitch = [ + self._average_by_duration(p, d).view(-1) + for p, d in zip(pitch, durations) + ] + pitch_lengths = durations_lengths + else: + pitch_lengths = input.new_tensor([len(p) for p in pitch], dtype=torch.long) + + # Padding + pitch = pad_list(pitch, 0.0) + + # Return with the shape (B, T, 1) + return pitch.unsqueeze(-1), pitch_lengths + + def _calculate_f0(self, input: torch.Tensor) -> torch.Tensor: + x = input.cpu().numpy().astype(np.double) + f0, timeaxis = pyworld.dio( + x, + self.fs, + f0_floor=self.f0min, + f0_ceil=self.f0max, + frame_period=self.frame_period, + ) + f0 = pyworld.stonemask(x, f0, timeaxis, self.fs) + if self.use_continuous_f0: + f0 = self._convert_to_continuous_f0(f0) + if self.use_log_f0: + nonzero_idxs = np.where(f0 != 0)[0] + f0[nonzero_idxs] = np.log(f0[nonzero_idxs]) + return input.new_tensor(f0.reshape(-1), dtype=torch.float) + + @staticmethod + def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor: + if num_frames > len(x): + x = F.pad(x, (0, num_frames - len(x))) + elif num_frames < len(x): + x = x[:num_frames] + return x + + @staticmethod + def _convert_to_continuous_f0(f0: np.array) -> np.array: + if (f0 == 0).all(): + logging.warn("All frames seems to be unvoiced.") + return f0 + + # padding start and end of f0 sequence + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + start_idx = np.where(f0 == start_f0)[0][0] + end_idx = np.where(f0 == end_f0)[0][-1] + f0[:start_idx] = start_f0 + f0[end_idx:] = end_f0 + + # get 
non-zero frame index + nonzero_idxs = np.where(f0 != 0)[0] + + # perform linear interpolation + interp_fn = interp1d(nonzero_idxs, f0[nonzero_idxs]) + f0 = interp_fn(np.arange(0, f0.shape[0])) + + return f0 + + def _average_by_duration(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: + assert 0 <= len(x) - d.sum() < self.reduction_factor + d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) + x_avg = [ + x[start:end].masked_select(x[start:end].gt(0.0)).mean(dim=0) + if len(x[start:end].masked_select(x[start:end].gt(0.0))) != 0 + else x.new_tensor(0.0) + for start, end in zip(d_cumsum[:-1], d_cumsum[1:]) + ] + return torch.stack(x_avg) diff --git a/espnet2/tts/feats_extract/energy.py b/espnet2/tts/feats_extract/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..d80f3af53b5e59f44b427b3b0f9189ee20861baf --- /dev/null +++ b/espnet2/tts/feats_extract/energy.py @@ -0,0 +1,143 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Energy extractor.""" + +from typing import Any +from typing import Dict +from typing import Tuple +from typing import Union + +import humanfriendly +import torch +import torch.nn.functional as F + +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.nets_utils import pad_list +from espnet2.layers.stft import Stft +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract + + +class Energy(AbsFeatsExtract): + """Energy extractor.""" + + def __init__( + self, + fs: Union[int, str] = 22050, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: str = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + use_token_averaged_energy: bool = True, + reduction_factor: int = None, + ): + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + + self.fs = fs + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.use_token_averaged_energy = use_token_averaged_energy + if use_token_averaged_energy: + assert reduction_factor >= 1 + self.reduction_factor = reduction_factor + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + def output_size(self) -> int: + return 1 + + def get_parameters(self) -> Dict[str, Any]: + return dict( + fs=self.fs, + n_fft=self.n_fft, + hop_length=self.hop_length, + window=self.window, + win_length=self.win_length, + center=self.stft.center, + normalized=self.stft.normalized, + use_token_averaged_energy=self.use_token_averaged_energy, + reduction_factor=self.reduction_factor, + ) + + def forward( + self, + input: torch.Tensor, + input_lengths: torch.Tensor = None, + feats_lengths: torch.Tensor = None, + durations: torch.Tensor = None, + durations_lengths: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # If not provide, we assume that the inputs have the same length + if input_lengths is None: + input_lengths = ( + input.new_ones(input.shape[0], dtype=torch.long) * input.shape[1] + ) + + # Domain-conversion: e.g. 
Stft: time -> time-freq + input_stft, energy_lengths = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + assert input_stft.shape[-1] == 2, input_stft.shape + + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + # sum over frequency (B, N, F) -> (B, N) + energy = torch.sqrt(torch.clamp(input_power.sum(dim=2), min=1.0e-10)) + + # (Optional): Adjust length to match with the mel-spectrogram + if feats_lengths is not None: + energy = [ + self._adjust_num_frames(e[:el].view(-1), fl) + for e, el, fl in zip(energy, energy_lengths, feats_lengths) + ] + energy_lengths = feats_lengths + + # (Optional): Average by duration to calculate token-wise energy + if self.use_token_averaged_energy: + durations = durations * self.reduction_factor + energy = [ + self._average_by_duration(e[:el].view(-1), d) + for e, el, d in zip(energy, energy_lengths, durations) + ] + energy_lengths = durations_lengths + + # Padding + if isinstance(energy, list): + energy = pad_list(energy, 0.0) + + # Return with the shape (B, T, 1) + return energy.unsqueeze(-1), energy_lengths + + def _average_by_duration(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor: + assert 0 <= len(x) - d.sum() < self.reduction_factor + d_cumsum = F.pad(d.cumsum(dim=0), (1, 0)) + x_avg = [ + x[start:end].mean() if len(x[start:end]) != 0 else x.new_tensor(0.0) + for start, end in zip(d_cumsum[:-1], d_cumsum[1:]) + ] + return torch.stack(x_avg) + + @staticmethod + def _adjust_num_frames(x: torch.Tensor, num_frames: torch.Tensor) -> torch.Tensor: + if num_frames > len(x): + x = F.pad(x, (0, num_frames - len(x))) + elif num_frames < len(x): + x = x[:num_frames] + return x diff --git a/espnet2/tts/feats_extract/log_mel_fbank.py b/espnet2/tts/feats_extract/log_mel_fbank.py new file mode 100644 index 0000000000000000000000000000000000000000..e760ceab61fce646a7e8c5a9382e98b6d81fb685 --- /dev/null +++ b/espnet2/tts/feats_extract/log_mel_fbank.py @@ -0,0 +1,105 @@ +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import humanfriendly +import torch +from typeguard import check_argument_types + +from espnet2.layers.log_mel import LogMel +from espnet2.layers.stft import Stft +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract + + +class LogMelFbank(AbsFeatsExtract): + """Conventional frontend structure for ASR + + Stft -> amplitude-spec -> Log-Mel-Fbank + """ + + def __init__( + self, + fs: Union[int, str] = 16000, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: Optional[int] = 80, + fmax: Optional[int] = 7600, + htk: bool = False, + ): + assert check_argument_types() + super().__init__() + if isinstance(fs, str): + fs = humanfriendly.parse_size(fs) + + self.fs = fs + self.n_mels = n_mels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.fmin = fmin + self.fmax = fmax + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self.logmel = LogMel( + fs=fs, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + htk=htk, + log_base=10.0, + ) + + def output_size(self) -> int: + return self.n_mels + + def get_parameters(self) -> 
Dict[str, Any]: + """Return the parameters required by Vocoder""" + return dict( + fs=self.fs, + n_fft=self.n_fft, + n_shift=self.hop_length, + window=self.window, + n_mels=self.n_mels, + win_length=self.win_length, + fmin=self.fmin, + fmax=self.fmax, + ) + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # NOTE(kamo): We use different definition for log-spec between TTS and ASR + # TTS: log_10(abs(stft)) + # ASR: log_e(power(stft)) + + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + input_amp = torch.sqrt(torch.clamp(input_power, min=1.0e-10)) + input_feats, _ = self.logmel(input_amp, feats_lens) + return input_feats, feats_lens diff --git a/espnet2/tts/feats_extract/log_spectrogram.py b/espnet2/tts/feats_extract/log_spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..fa00ea435f179e6621b67d3ff28579bbf94a39b1 --- /dev/null +++ b/espnet2/tts/feats_extract/log_spectrogram.py @@ -0,0 +1,76 @@ +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple + +import torch +from typeguard import check_argument_types + +from espnet2.layers.stft import Stft +from espnet2.tts.feats_extract.abs_feats_extract import AbsFeatsExtract + + +class LogSpectrogram(AbsFeatsExtract): + """Conventional frontend structure for ASR + + Stft -> log-amplitude-spec + """ + + def __init__( + self, + n_fft: int = 1024, + win_length: int = None, + hop_length: int = 256, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + assert check_argument_types() + super().__init__() + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + self.n_fft = n_fft + + def output_size(self) -> int: + return self.n_fft // 2 + 1 + + def get_parameters(self) -> Dict[str, Any]: + """Return the parameters required by Vocoder""" + return dict( + n_fft=self.n_fft, + n_shift=self.hop_length, + win_length=self.win_length, + window=self.window, + ) + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. 
Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # NOTE(kamo): We use different definition for log-spec between TTS and ASR + # TTS: log_10(abs(stft)) + # ASR: log_e(power(stft)) + + # STFT -> Power spectrum + # input_stft: (..., F, 2) -> (..., F) + input_power = input_stft[..., 0] ** 2 + input_stft[..., 1] ** 2 + log_amp = 0.5 * torch.log10(torch.clamp(input_power, min=1.0e-10)) + return log_amp, feats_lens diff --git a/espnet2/tts/gst/__init__.py b/espnet2/tts/gst/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/espnet2/tts/gst/style_encoder.py b/espnet2/tts/gst/style_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..33ac0210b292584dcfdbc57e613b6a1a6ad30cf6 --- /dev/null +++ b/espnet2/tts/gst/style_encoder.py @@ -0,0 +1,272 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Style encoder of GST-Tacotron.""" + +from typeguard import check_argument_types +from typing import Sequence + +import torch + +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention as BaseMultiHeadedAttention, # NOQA +) + + +class StyleEncoder(torch.nn.Module): + """Style encoder. + + This module is style encoder introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + + Args: + idim (int, optional): Dimension of the input mel-spectrogram. + gst_tokens (int, optional): The number of GST embeddings. + gst_token_dim (int, optional): Dimension of each GST embedding. + gst_heads (int, optional): The number of heads in GST multihead attention. + conv_layers (int, optional): The number of conv layers in the reference encoder. + conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the referece encoder. + conv_kernel_size (int, optional): + Kernal size of conv layers in the reference encoder. + conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + gru_layers (int, optional): The number of GRU layers in the reference encoder. + gru_units (int, optional): The number of GRU units in the reference encoder. + + Todo: + * Support manual weight specification in inference. + + """ + + def __init__( + self, + idim: int = 80, + gst_tokens: int = 10, + gst_token_dim: int = 256, + gst_heads: int = 4, + conv_layers: int = 6, + conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + conv_kernel_size: int = 3, + conv_stride: int = 2, + gru_layers: int = 1, + gru_units: int = 128, + ): + """Initilize global style encoder module.""" + assert check_argument_types() + super(StyleEncoder, self).__init__() + + self.ref_enc = ReferenceEncoder( + idim=idim, + conv_layers=conv_layers, + conv_chans_list=conv_chans_list, + conv_kernel_size=conv_kernel_size, + conv_stride=conv_stride, + gru_layers=gru_layers, + gru_units=gru_units, + ) + self.stl = StyleTokenLayer( + ref_embed_dim=gru_units, + gst_tokens=gst_tokens, + gst_token_dim=gst_token_dim, + gst_heads=gst_heads, + ) + + def forward(self, speech: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. 
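# A compact sketch of the TTS-style log-spectrogram above, using torch.stft in place of
# espnet2.layers.stft.Stft; the per-frame energy used by the Energy extractor falls out
# of the same power spectrum. Input waveform is a placeholder.
import torch

wav = torch.randn(1, 22050)                                   # (B, samples)
window = torch.hann_window(1024)
spec = torch.stft(wav, n_fft=1024, hop_length=256, window=window,
                  return_complex=True)                        # (B, F, T), F = n_fft//2 + 1
power = spec.real ** 2 + spec.imag ** 2
log_amp = 0.5 * torch.log10(power.clamp(min=1.0e-10))         # log10 of the amplitude
energy = power.sum(dim=1).clamp(min=1.0e-10).sqrt()           # (B, T) frame-wise energy
print(log_amp.shape, energy.shape)                            # (1, 513, T) (1, T)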
+ + Args: + speech (Tensor): Batch of padded target features (B, Lmax, odim). + + Returns: + Tensor: Style token embeddings (B, token_dim). + + """ + ref_embs = self.ref_enc(speech) + style_embs = self.stl(ref_embs) + + return style_embs + + +class ReferenceEncoder(torch.nn.Module): + """Reference encoder module. + + This module is refernece encoder introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + + Args: + idim (int, optional): Dimension of the input mel-spectrogram. + conv_layers (int, optional): The number of conv layers in the reference encoder. + conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the referece encoder. + conv_kernel_size (int, optional): + Kernal size of conv layers in the reference encoder. + conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + gru_layers (int, optional): The number of GRU layers in the reference encoder. + gru_units (int, optional): The number of GRU units in the reference encoder. + + """ + + def __init__( + self, + idim=80, + conv_layers: int = 6, + conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + conv_kernel_size: int = 3, + conv_stride: int = 2, + gru_layers: int = 1, + gru_units: int = 128, + ): + """Initilize reference encoder module.""" + assert check_argument_types() + super(ReferenceEncoder, self).__init__() + + # check hyperparameters are valid + assert conv_kernel_size % 2 == 1, "kernel size must be odd." + assert ( + len(conv_chans_list) == conv_layers + ), "the number of conv layers and length of channels list must be the same." + + convs = [] + padding = (conv_kernel_size - 1) // 2 + for i in range(conv_layers): + conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1] + conv_out_chans = conv_chans_list[i] + convs += [ + torch.nn.Conv2d( + conv_in_chans, + conv_out_chans, + kernel_size=conv_kernel_size, + stride=conv_stride, + padding=padding, + # Do not use bias due to the following batch norm + bias=False, + ), + torch.nn.BatchNorm2d(conv_out_chans), + torch.nn.ReLU(inplace=True), + ] + self.convs = torch.nn.Sequential(*convs) + + self.conv_layers = conv_layers + self.kernel_size = conv_kernel_size + self.stride = conv_stride + self.padding = padding + + # get the number of GRU input units + gru_in_units = idim + for i in range(conv_layers): + gru_in_units = ( + gru_in_units - conv_kernel_size + 2 * padding + ) // conv_stride + 1 + gru_in_units *= conv_out_chans + self.gru = torch.nn.GRU(gru_in_units, gru_units, gru_layers, batch_first=True) + + def forward(self, speech: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + speech (Tensor): Batch of padded target features (B, Lmax, idim). + + Returns: + Tensor: Reference embedding (B, gru_units) + + """ + batch_size = speech.size(0) + xs = speech.unsqueeze(1) # (B, 1, Lmax, idim) + hs = self.convs(xs).transpose(1, 2) # (B, Lmax', conv_out_chans, idim') + # NOTE(kan-bayashi): We need to care the length? + time_length = hs.size(1) + hs = hs.contiguous().view(batch_size, time_length, -1) # (B, Lmax', gru_units) + self.gru.flatten_parameters() + _, ref_embs = self.gru(hs) # (gru_layers, batch_size, gru_units) + ref_embs = ref_embs[-1] # (batch_size, gru_units) + + return ref_embs + + +class StyleTokenLayer(torch.nn.Module): + """Style token layer module. 
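# A quick check of how the `gru_in_units` loop above plays out for the default GST
# settings: six stride-2 convolutions shrink the 80-bin frequency axis down to 2 bins,
# and the 128 channels of the last conv give 2 * 128 = 256 GRU input units.
def conv_output_bins(idim=80, conv_layers=6, kernel_size=3, stride=2, last_chans=128):
    padding = (kernel_size - 1) // 2
    units = idim
    for _ in range(conv_layers):
        units = (units - kernel_size + 2 * padding) // stride + 1
    return units * last_chans

print(conv_output_bins())  # 256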
+ + This module is style token layer introduced in `Style Tokens: Unsupervised Style + Modeling, Control and Transfer in End-to-End Speech Synthesis`. + + .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End + Speech Synthesis`: https://arxiv.org/abs/1803.09017 + + Args: + ref_embed_dim (int, optional): Dimension of the input reference embedding. + gst_tokens (int, optional): The number of GST embeddings. + gst_token_dim (int, optional): Dimension of each GST embedding. + gst_heads (int, optional): The number of heads in GST multihead attention. + dropout_rate (float, optional): Dropout rate in multi-head attention. + + """ + + def __init__( + self, + ref_embed_dim: int = 128, + gst_tokens: int = 10, + gst_token_dim: int = 256, + gst_heads: int = 4, + dropout_rate: float = 0.0, + ): + """Initilize style token layer module.""" + assert check_argument_types() + super(StyleTokenLayer, self).__init__() + + gst_embs = torch.randn(gst_tokens, gst_token_dim // gst_heads) + self.register_parameter("gst_embs", torch.nn.Parameter(gst_embs)) + self.mha = MultiHeadedAttention( + q_dim=ref_embed_dim, + k_dim=gst_token_dim // gst_heads, + v_dim=gst_token_dim // gst_heads, + n_head=gst_heads, + n_feat=gst_token_dim, + dropout_rate=dropout_rate, + ) + + def forward(self, ref_embs: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + ref_embs (Tensor): Reference embeddings (B, ref_embed_dim). + + Returns: + Tensor: Style token embeddings (B, gst_token_dim). + + """ + batch_size = ref_embs.size(0) + # (num_tokens, token_dim) -> (batch_size, num_tokens, token_dim) + gst_embs = torch.tanh(self.gst_embs).unsqueeze(0).expand(batch_size, -1, -1) + # NOTE(kan-bayashi): Shoule we apply Tanh? + ref_embs = ref_embs.unsqueeze(1) # (batch_size, 1 ,ref_embed_dim) + style_embs = self.mha(ref_embs, gst_embs, gst_embs, None) + + return style_embs.squeeze(1) + + +class MultiHeadedAttention(BaseMultiHeadedAttention): + """Multi head attention module with different input dimension.""" + + def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0): + """Initialize multi head attention module.""" + # NOTE(kan-bayashi): Do not use super().__init__() here since we want to + # overwrite BaseMultiHeadedAttention.__init__() method. 
+ torch.nn.Module.__init__(self) + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = torch.nn.Linear(q_dim, n_feat) + self.linear_k = torch.nn.Linear(k_dim, n_feat) + self.linear_v = torch.nn.Linear(v_dim, n_feat) + self.linear_out = torch.nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = torch.nn.Dropout(p=dropout_rate) diff --git a/espnet2/tts/prosody_encoder.py b/espnet2/tts/prosody_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..60714585165fa8f86bcc077449b49978dacb0ba2 --- /dev/null +++ b/espnet2/tts/prosody_encoder.py @@ -0,0 +1,611 @@ +from typing import Sequence + +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from typeguard import check_argument_types + + +class VectorQuantizer(nn.Module): + """ + Reference: + [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py + """ + def __init__(self, + num_embeddings: int, + hidden_dim: int, + beta: float = 0.25): + super().__init__() + self.K = num_embeddings + self.D = hidden_dim + self.beta = 0.05 # beta override + + self.embedding = nn.Embedding(self.K, self.D) + self.embedding.weight.data.normal_(0.8, 0.1) # override + + def forward(self, latents: torch.Tensor) -> torch.Tensor: + # latents = latents.permute(0, 2, 1).contiguous() # (B, D, L) -> (B, L, D) + latents_shape = latents.shape + flat_latents = latents.view(-1, self.D) # (BL, D) + + # Compute L2 distance between latents and embedding weights + dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - \ + 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # (BL, K) + + # Get the encoding that has the min distance + encoding_inds = torch.argmin(dist, dim=1) # (BL) + output_inds = encoding_inds.view(latents_shape[0], latents_shape[1]) # (B, L) + encoding_inds = encoding_inds.unsqueeze(1) # (BL, 1) + + # Convert to one-hot encodings + device = latents.device + encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) + encoding_one_hot.scatter_(1, encoding_inds, 1) # (BL, K) + + # Quantize the latents + # (BL, D) + quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) + quantized_latents = quantized_latents.view(latents_shape) # (B, L, D) + + # Compute the VQ Losses + commitment_loss = F.mse_loss(quantized_latents.detach(), latents) + embedding_loss = F.mse_loss(quantized_latents, latents.detach()) + + vq_loss = commitment_loss * self.beta + embedding_loss + + # Add the residue back to the latents + quantized_latents = latents + (quantized_latents - latents).detach() + + # print(output_inds) + # print(quantized_latents) + + # The perplexity a useful value to track during training. + # It indicates how many codes are 'active' on average. + avg_probs = torch.mean(encoding_one_hot, dim=0) + # Exponential entropy + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return quantized_latents, vq_loss, output_inds, self.embedding, perplexity + + +class ProsodyEncoder(nn.Module): + """VQ-VAE prosody encoder module. + + Args: + odim (int): Number of input channels (mel spectrogram channels). + ref_enc_conv_layers (int, optional): + The number of conv layers in the reference encoder. + ref_enc_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the referece encoder. 
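# The core of the VectorQuantizer above, as a short standalone sketch: nearest codebook
# entry by L2 distance (argmin over torch.cdist is equivalent to argmin over the squared
# distances computed above), the commitment/embedding losses, and the straight-through
# trick that lets gradients bypass the non-differentiable argmin. Toy sizes; `beta`
# matches the constructor default.
import torch
import torch.nn.functional as F

K, D, beta = 10, 3, 0.25
codebook = torch.nn.Embedding(K, D)
latents = torch.randn(2, 5, D, requires_grad=True)        # (B, L, D)

flat = latents.reshape(-1, D)                             # (B*L, D)
inds = torch.cdist(flat, codebook.weight).argmin(dim=1)   # nearest code per latent
quantized = codebook(inds).view_as(latents)               # (B, L, D)

vq_loss = beta * F.mse_loss(quantized.detach(), latents) \
    + F.mse_loss(quantized, latents.detach())

# straight-through estimator: forward uses the quantized values, backward sees identity
quantized = latents + (quantized - latents).detach()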
+ ref_enc_conv_kernel_size (int, optional): + Kernal size of conv layers in the reference encoder. + ref_enc_conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + ref_enc_gru_layers (int, optional): + The number of GRU layers in the reference encoder. + ref_enc_gru_units (int, optional): + The number of GRU units in the reference encoder. + ref_emb_integration_type: How to integrate reference embedding. + adim (int, optional): This value is not that important. + This will not change the capacity in the information-bottleneck. + num_embeddings (int, optional): The higher this value, the higher the + capacity in the information bottleneck. + FG (int, optional): Number of hidden channels. + """ + def __init__( + self, + odim: int, + adim: int = 64, + num_embeddings: int = 10, + hidden_dim: int = 3, + beta: float = 0.25, + ref_enc_conv_layers: int = 2, + ref_enc_conv_chans_list: Sequence[int] = (32, 32), + ref_enc_conv_kernel_size: int = 3, + ref_enc_conv_stride: int = 1, + global_enc_gru_layers: int = 1, + global_enc_gru_units: int = 32, + global_emb_integration_type: str = "add", + ) -> None: + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.global_emb_integration_type = global_emb_integration_type + + padding = (ref_enc_conv_kernel_size - 1) // 2 + + self.ref_encoder = RefEncoder( + ref_enc_conv_layers=ref_enc_conv_layers, + ref_enc_conv_chans_list=ref_enc_conv_chans_list, + ref_enc_conv_kernel_size=ref_enc_conv_kernel_size, + ref_enc_conv_stride=ref_enc_conv_stride, + ref_enc_conv_padding=padding, + ) + + # get the number of ref enc output units + ref_enc_output_units = odim + for i in range(ref_enc_conv_layers): + ref_enc_output_units = ( + ref_enc_output_units - ref_enc_conv_kernel_size + 2 * padding + ) // ref_enc_conv_stride + 1 + ref_enc_output_units *= ref_enc_conv_chans_list[-1] + + self.fg_encoder = FGEncoder( + ref_enc_output_units + global_enc_gru_units, + hidden_dim=hidden_dim, + ) + + self.global_encoder = GlobalEncoder( + ref_enc_output_units, + global_enc_gru_layers=global_enc_gru_layers, + global_enc_gru_units=global_enc_gru_units, + ) + + # define a projection for the global embeddings + if self.global_emb_integration_type == "add": + self.global_projection = nn.Linear(global_enc_gru_units, adim) + else: + self.global_projection = nn.Linear( + adim + global_enc_gru_units, adim + ) + + self.ar_prior = ARPrior( + adim, + num_embeddings=num_embeddings, + hidden_dim=hidden_dim, + ) + + self.vq_layer = VectorQuantizer(num_embeddings, hidden_dim, beta) + + # define a projection for the quantized fine-grained embeddings + self.qfg_projection = nn.Linear(hidden_dim, adim) + + def forward( + self, + ys: torch.Tensor, + ds: torch.Tensor, + hs: torch.Tensor, + global_embs: torch.Tensor = None, + train_ar_prior: bool = False, + ar_prior_inference: bool = False, + fg_inds: torch.Tensor = None, + ) -> Sequence[torch.Tensor]: + """Calculate forward propagation. + + Args: + ys (Tensor): Batch of padded target features (B, Lmax, odim). + ds (LongTensor): Batch of padded durations (B, Tmax). + hs (Tensor): Batch of phoneme embeddings (B, Tmax, D). + global_embs (Tensor, optional): Global embeddings (B, D) + + Returns: + Tensor: Fine-grained quantized prosody embeddings (B, Tmax, adim). + Tensor: VQ loss. 
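# The two integration modes used by `_integrate_with_global_embs` (and by
# `_integrate_with_spk_embed` in fastspeech2.py, which additionally L2-normalizes the
# embedding first), shown side by side with toy shapes: one utterance-level vector is
# either projected and added to every time step, or broadcast along time, concatenated,
# and projected back down.
import torch

B, T, adim, edim = 2, 7, 64, 32
hs = torch.randn(B, T, adim)                              # phoneme-level hiddens
emb = torch.randn(B, edim)                                # one global embedding per utterance

add_proj = torch.nn.Linear(edim, adim)
hs_add = hs + add_proj(emb).unsqueeze(1)                  # "add": (B, T, adim)

cat_proj = torch.nn.Linear(adim + edim, adim)
emb_time = emb.unsqueeze(1).expand(-1, T, -1)             # (B, T, edim)
hs_cat = cat_proj(torch.cat([hs, emb_time], dim=-1))      # "concat": (B, T, adim)
print(hs_add.shape, hs_cat.shape)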
+ Tensor: Global prosody embeddings (B, ref_enc_gru_units) + """ + if ys is not None: + print('generating global_embs') + ref_embs = self.ref_encoder(ys) # (B, L', ref_enc_output_units) + global_embs = self.global_encoder(ref_embs) # (B, ref_enc_gru_units) + + if ar_prior_inference: + print('Using ar prior') + hs_integrated = self._integrate_with_global_embs(hs, global_embs) + qs, top_inds = self.ar_prior.inference( + hs_integrated, fg_inds, self.vq_layer.embedding + ) + + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + return p_embs, 0, 0, 0, top_inds # (B, Tmax, adim) + + # concat global embs to ref embs + global_embs_expanded = global_embs.unsqueeze(1).expand(-1, ref_embs.size(1), -1) + # (B, Tmax, D) + ref_embs_integrated = torch.cat([ref_embs, global_embs_expanded], dim=-1) + + # (B, Tmax, hidden_dim) + fg_embs = self.fg_encoder(ref_embs_integrated, ds, ys.size(1)) + + # (B, Tmax, hidden_dim) + qs, vq_loss, inds, codebook, perplexity = self.vq_layer(fg_embs) + # Vector quantization should maintain length + assert hs.size(1) == qs.size(1) + + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + ar_prior_loss = 0 + if train_ar_prior: + # (B, Tmax, adim) + hs_integrated = self._integrate_with_global_embs(hs, global_embs) + qs, ar_prior_loss = self.ar_prior(hs_integrated, inds, codebook) + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + return p_embs, vq_loss, ar_prior_loss, perplexity, global_embs + + def _integrate_with_global_embs( + self, + qs: torch.Tensor, + global_embs: torch.Tensor + ) -> torch.Tensor: + """Integrate ref embedding with spectrogram hidden states. + + Args: + qs (Tensor): Batch of quantized FG embeddings (B, Tmax, adim). + global_embs (Tensor): Batch of global embeddings (B, global_enc_gru_units). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). + """ + if self.global_emb_integration_type == "add": + # apply projection to hidden states + global_embs = self.global_projection(global_embs) + res = qs + global_embs.unsqueeze(1) + elif self.global_emb_integration_type == "concat": + # concat hidden states with prosody embeds and then apply projection + # (B, Tmax, ref_emb_dim) + global_embs = global_embs.unsqueeze(1).expand(-1, qs.size(1), -1) + # (B, Tmax, D) + res = self.prosody_projection(torch.cat([qs, global_embs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return res + + +class RefEncoder(nn.Module): + def __init__( + self, + ref_enc_conv_layers: int = 2, + ref_enc_conv_chans_list: Sequence[int] = (32, 32), + ref_enc_conv_kernel_size: int = 3, + ref_enc_conv_stride: int = 1, + ref_enc_conv_padding: int = 1, + ): + """Initilize reference encoder module.""" + assert check_argument_types() + super().__init__() + + # check hyperparameters are valid + assert ref_enc_conv_kernel_size % 2 == 1, "kernel size must be odd." + assert ( + len(ref_enc_conv_chans_list) == ref_enc_conv_layers + ), "the number of conv layers and length of channels list must be the same." 
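# The codebook-usage perplexity returned by the VQ layer above, reproduced on a toy
# batch of code indices: it is exp(entropy) of the average assignment distribution, so
# it approaches K when every code is used equally and drops towards 1 when a single
# code dominates.
import torch
import torch.nn.functional as F

K = 10
inds = torch.randint(0, K, (100,))                        # one code index per latent
avg_probs = F.one_hot(inds, K).float().mean(dim=0)        # empirical usage frequencies
perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())
print(perplexity)                                         # close to 10 for uniform usage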
+ + convs = [] + for i in range(ref_enc_conv_layers): + conv_in_chans = 1 if i == 0 else ref_enc_conv_chans_list[i - 1] + conv_out_chans = ref_enc_conv_chans_list[i] + convs += [ + nn.Conv2d( + conv_in_chans, + conv_out_chans, + kernel_size=ref_enc_conv_kernel_size, + stride=ref_enc_conv_stride, + padding=ref_enc_conv_padding, + ), + nn.ReLU(inplace=True), + + ] + self.convs = nn.Sequential(*convs) + + def forward(self, ys: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + ys (Tensor): Batch of padded target features (B, Lmax, odim). + + Returns: + Tensor: Batch of spectrogram hiddens (B, L', ref_enc_output_units) + + """ + B = ys.size(0) + ys = ys.unsqueeze(1) # (B, 1, Lmax, odim) + hs = self.convs(ys) # (B, conv_out_chans, L', odim') + hs = hs.transpose(1, 2) # (B, L', conv_out_chans, odim') + L = hs.size(1) + # (B, L', ref_enc_output_units) -> "flatten" + hs = hs.contiguous().view(B, L, -1) + + return hs + + +class GlobalEncoder(nn.Module): + """Module that creates a global embedding from a hidden spectrogram sequence. + + Args: + """ + def __init__( + self, + ref_enc_output_units: int, + global_enc_gru_layers: int = 1, + global_enc_gru_units: int = 32, + ): + super().__init__() + self.gru = torch.nn.GRU(ref_enc_output_units, global_enc_gru_units, + global_enc_gru_layers, batch_first=True) + + def forward( + self, + hs: torch.Tensor, + ): + """Calculate forward propagation. + + Args: + hs (Tensor): Batch of spectrogram hiddens (B, L', ref_enc_output_units). + + Returns: + Tensor: Reference embedding (B, ref_enc_gru_units). + """ + self.gru.flatten_parameters() + _, global_embs = self.gru(hs) # (gru_layers, B, ref_enc_gru_units) + global_embs = global_embs[-1] # (B, ref_enc_gru_units) + + return global_embs + + +class FGEncoder(nn.Module): + """Spectrogram to phoneme alignment module. + + Args: + """ + def __init__( + self, + input_units: int, + hidden_dim: int = 3, + ): + assert check_argument_types() + super().__init__() + + self.projection = nn.Sequential( + nn.Sequential( + nn.Linear(input_units, input_units // 2), + nn.ReLU(), + nn.Dropout(p=0.2), + ), + nn.Sequential( + nn.Linear(input_units // 2, hidden_dim), + nn.ReLU(), + nn.Dropout(p=0.2), + ) + ) + + def forward( + self, + hs: torch.Tensor, + ds: torch.Tensor, + Lmax: int + ): + """Calculate forward propagation. + + Args: + hs (Tensor): Batch of spectrogram hiddens + (B, L', ref_enc_output_units + global_enc_gru_units). + ds (LongTensor): Batch of padded durations (B, Tmax). + + Returns: + Tensor: aligned spectrogram hiddens (B, Tmax, hidden_dim). + """ + # (B, Tmax, ref_enc_output_units + global_enc_gru_units) + hs = self._align_durations(hs, ds, Lmax) + hs = self.projection(hs) # (B, Tmax, hidden_dim) + + return hs + + def _align_durations(self, hs, ds, Lmax): + """Transform the spectrogram hiddens according to the ground-truth durations + so that there's only one hidden per phoneme hidden. + + Args: + # (B, L', ref_enc_output_units + global_enc_gru_units) + hs (Tensor): Batch of spectrogram hidden state sequences . + ds (LongTensor): Batch of padded durations (B, Tmax) + + Returns: + # (B, Tmax, ref_enc_output_units + global_enc_gru_units) + Tensor: Batch of averaged spectrogram hidden state sequences. 
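# A condensed sketch of `_align_durations`: frame-level hiddens are pooled into one
# vector per phoneme by walking the duration sequence, with the durations rescaled by
# L' / Lmax to account for any temporal downsampling in the reference conv stack. Toy
# tensors are used here, and zero-duration tokens keep a zero vector.
import math
import torch

hs = torch.randn(1, 50, 8)                 # (B, L', D) conv outputs
ds = torch.tensor([[10, 25, 15]])          # (B, Tmax) frame durations (sum = Lmax)
Lmax = int(ds.sum())

multiplier = hs.size(1) / Lmax
out = hs.new_zeros(1, ds.size(1), hs.size(2))
i = 0
for t, d in enumerate(ds[0]):
    if d.item() > 0:
        span = max(math.floor(d.item() * multiplier), 1)
        out[0, t] = hs[0, i:i + span].mean(dim=0)
        i += span
print(out.shape)                            # (1, 3, 8): one pooled hidden per phoneme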
+ """ + B = hs.size(0) + L = hs.size(1) + D = hs.size(2) + + Tmax = ds.size(1) # -1 if Tmax + 1 + + device = hs.device + hs_res = torch.zeros( + [B, Tmax, D], + device=device + ) # (B, Tmax, D) + + with torch.no_grad(): + for b_i in range(B): + durations = ds[b_i] + multiplier = L / Lmax + i = 0 + for d_i in range(Tmax): + # take into account downsampling because of conv layers + d = max(math.floor(durations[d_i].item() * multiplier), 1) + if durations[d_i].item() > 0: + hs_slice = hs[b_i, i:i + d, :] # (d, D) + hs_res[b_i, d_i, :] = torch.mean(hs_slice, 0) + i += d + hs_res.requires_grad_(hs.requires_grad) + return hs_res + + +class ARPrior(nn.Module): + # torch.topk(decoder_output, beam_width) + """Autoregressive prior. + + This module is inspired by the AR prior described in `Generating diverse and + natural text-to-speech samples using a quantized fine-grained VAE and + auto-regressive prosody prior`. This prior is fit in the continuous latent space. + """ + def __init__( + self, + adim: int, + num_embeddings: int = 10, + hidden_dim: int = 3, + ): + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.adim = adim + self.hidden_dim = hidden_dim + self.num_embeddings = num_embeddings + + self.qs_projection = nn.Linear(hidden_dim, adim) + + self.lstm = nn.LSTMCell( + self.adim, + self.num_embeddings, + ) + + self.criterion = nn.NLLLoss() + + def inds_to_embs(self, inds, codebook, device): + """Returns the quantized embeddings from the codebook, + corresponding to the indices. + + Args: + inds (Tensor): Batch of indices (B, Tmax, 1). + codebook (Embedding): (num_embeddings, D). + + Returns: + Tensor: Quantized embeddings (B, Tmax, D). + """ + flat_inds = torch.flatten(inds).unsqueeze(1) # (BL, 1) + + # Convert to one-hot encodings + encoding_one_hot = torch.zeros( + flat_inds.size(0), + self.num_embeddings, + device=device + ) + encoding_one_hot.scatter_(1, flat_inds, 1) # (BL, K) + + # Quantize the latents + # (BL, D) + quantized_embs = torch.matmul(encoding_one_hot, codebook.weight) + # (B, L, D) + quantized_embs = quantized_embs.view( + inds.size(0), inds.size(1), self.hidden_dim + ) + + return quantized_embs + + def top_embeddings(self, emb_scores: torch.Tensor, codebook): + """Returns the top quantized embeddings from the codebook using the scores. + + Args: + emb_scores (Tensor): Batch of embedding scores (B, Tmax, num_embeddings). + codebook (Embedding): (num_embeddings, D). + + Returns: + Tensor: Top quantized embeddings (B, Tmax, D). + Tensor: Top 3 inds (B, Tmax, 3). 
+ """ + _, top_inds = emb_scores.topk(1, dim=-1) # (B, L, 1) + quantized_embs = self.inds_to_embs( + top_inds, + codebook, + emb_scores.device, + ) + _, top3_inds = emb_scores.topk(3, dim=-1) # (B, L, 1) + return quantized_embs, top3_inds + + def _forward(self, hs_ref_embs, codebook, fg_inds=None): + inds = [] + scores = [] + embs = [] + + if fg_inds is not None: + init_embs = self.inds_to_embs(fg_inds, codebook, hs_ref_embs.device) + embs = [init_emb.unsqueeze(1) for init_emb in init_embs.transpose(1, 0)] + + start = fg_inds.size(1) if fg_inds is not None else 0 + hidden = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size) + cell = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size) + + for i in range(start, hs_ref_embs.size(1)): + # (B, adim) + input = hs_ref_embs[:, i] + if i != 0: + # (B, 1, adim) + qs = self.qs_projection(embs[-1]) + # (B, adim) + input = hs_ref_embs[:, i] + qs.squeeze() + hidden, cell = self.lstm(input, (hidden, cell)) # (B, K) + out = hidden.unsqueeze(1) # (B, 1, K) + # (B, 1, K) + emb_scores = F.log_softmax(out, dim=2) + quantized_embs, top_inds = self.top_embeddings(emb_scores, codebook) + # (B, 1, hidden_dim) + embs.append(quantized_embs) + scores.append(emb_scores) + inds.append(top_inds) + + out_embs = torch.cat(embs, dim=1) # (B, L, hidden_dim) + assert(out_embs.size(0) == hs_ref_embs.size(0)) + assert(out_embs.size(1) == hs_ref_embs.size(1)) + out_emb_scores = torch.cat(scores, dim=1) if start < hs_ref_embs.size(1) else scores + out_inds = torch.cat(inds, dim=1) if start < hs_ref_embs.size(1) else fg_inds + + return out_embs, out_emb_scores, out_inds + + def forward(self, hs_ref_embs, inds, codebook): + """Calculate forward propagation. + + Args: + hs_p_embs (Tensor): Batch of phoneme embeddings + with integrated global prosody embeddings (B, Tmax, D). + inds (Tensor): Batch of ground-truth codebook indices + (B, Tmax). + + Returns: + Tensor: Batch of predicted quantized latents (B, Tmax, D). + Tensor: Cross entropy loss value. + + """ + quantized_embs, emb_scores, _ = self._forward(hs_ref_embs, codebook) + emb_scores = emb_scores.permute(0, 2, 1).contiguous() # (B, num_embeddings, L) + loss = self.criterion(emb_scores, inds) + return quantized_embs, loss + + def inference(self, hs_ref_embs, fg_inds, codebook): + """Inference duration. + + Args: + hs_p_embs (Tensor): Batch of phoneme embeddings + with integrated global prosody embeddings (B, Tmax, D). + + Returns: + Tensor: Batch of predicted quantized latents (B, Tmax, D). 
+ + """ + # Random sampling + # fg_inds = torch.rand(hs_ref_embs.size(0), hs_ref_embs.size(1)) + # fg_inds *= codebook.weight.size(0) - 1 + # fg_inds = torch.round(fg_inds) + # fg_inds = fg_inds.long() + + quantized_embs, _, top_inds = self._forward(hs_ref_embs, codebook, fg_inds) + return quantized_embs, top_inds diff --git a/espnet2/tts/tacotron2.py b/espnet2/tts/tacotron2.py new file mode 100644 index 0000000000000000000000000000000000000000..d5c8b3cc71482dbd7a4357a120a1b43c21115d69 --- /dev/null +++ b/espnet2/tts/tacotron2.py @@ -0,0 +1,463 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Tacotron 2 related modules for ESPnet2.""" + +import logging +from typing import Dict +from typing import Sequence +from typing import Tuple + +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import GuidedAttentionLoss +from espnet.nets.pytorch_backend.e2e_tts_tacotron2 import Tacotron2Loss +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.rnn.attentions import AttForward +from espnet.nets.pytorch_backend.rnn.attentions import AttForwardTA +from espnet.nets.pytorch_backend.rnn.attentions import AttLoc +from espnet.nets.pytorch_backend.tacotron2.decoder import Decoder +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.gst.style_encoder import StyleEncoder + + +class Tacotron2(AbsTTS): + """Tacotron2 module for end-to-end text-to-speech. + + This is a module of Spectrogram prediction network in Tacotron2 described + in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_, + which converts the sequence of characters into the sequence of Mel-filterbanks. + + .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: + https://arxiv.org/abs/1712.05884 + + Args: + idim (int): Dimension of the inputs. + odim: (int) Dimension of the outputs. + spk_embed_dim (int, optional): Dimension of the speaker embedding. + embed_dim (int, optional): Dimension of character embedding. + elayers (int, optional): The number of encoder blstm layers. + eunits (int, optional): The number of encoder blstm units. + econv_layers (int, optional): The number of encoder conv layers. + econv_filts (int, optional): The number of encoder conv filter size. + econv_chans (int, optional): The number of encoder conv filter channels. + dlayers (int, optional): The number of decoder lstm layers. + dunits (int, optional): The number of decoder lstm units. + prenet_layers (int, optional): The number of prenet layers. + prenet_units (int, optional): The number of prenet units. + postnet_layers (int, optional): The number of postnet layers. + postnet_filts (int, optional): The number of postnet filter size. + postnet_chans (int, optional): The number of postnet filter channels. + output_activation (str, optional): The name of activation function for outputs. + adim (int, optional): The number of dimension of mlp in attention. + aconv_chans (int, optional): The number of attention conv filter channels. + aconv_filts (int, optional): The number of attention conv filter size. + cumulate_att_w (bool, optional): Whether to cumulate previous attention weight. + use_batch_norm (bool, optional): Whether to use batch normalization. 
+ use_concate (bool, optional): Whether to concatenate encoder embedding with + decoder lstm outputs. + reduction_factor (int, optional): Reduction factor. + spk_embed_dim (int, optional): Number of speaker embedding dimenstions. + spk_embed_integration_type (str, optional): How to integrate speaker embedding. + use_gst (str, optional): Whether to use global style token. + gst_tokens (int, optional): The number of GST embeddings. + gst_heads (int, optional): The number of heads in GST multihead attention. + gst_conv_layers (int, optional): The number of conv layers in GST. + gst_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in GST. + gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. + gst_conv_stride (int, optional): Stride size of conv layers in GST. + gst_gru_layers (int, optional): The number of GRU layers in GST. + gst_gru_units (int, optional): The number of GRU units in GST. + dropout_rate (float, optional): Dropout rate. + zoneout_rate (float, optional): Zoneout rate. + use_masking (bool, optional): Whether to mask padded part in loss calculation. + use_weighted_masking (bool, optional): Whether to apply weighted masking in + loss calculation. + bce_pos_weight (float, optional): Weight of positive sample of stop token + (only for use_masking=True). + loss_type (str, optional): How to calculate loss. + use_guided_attn_loss (bool, optional): Whether to use guided attention loss. + guided_attn_loss_sigma (float, optional): Sigma in guided attention loss. + guided_attn_loss_lamdba (float, optional): Lambda in guided attention loss. + + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + embed_dim: int = 512, + elayers: int = 1, + eunits: int = 512, + econv_layers: int = 3, + econv_chans: int = 512, + econv_filts: int = 5, + atype: str = "location", + adim: int = 512, + aconv_chans: int = 32, + aconv_filts: int = 15, + cumulate_att_w: bool = True, + dlayers: int = 2, + dunits: int = 1024, + prenet_layers: int = 2, + prenet_units: int = 256, + postnet_layers: int = 5, + postnet_chans: int = 512, + postnet_filts: int = 5, + output_activation: str = None, + use_batch_norm: bool = True, + use_concate: bool = True, + use_residual: bool = False, + reduction_factor: int = 1, + spk_embed_dim: int = None, + spk_embed_integration_type: str = "concat", + use_gst: bool = False, + gst_tokens: int = 10, + gst_heads: int = 4, + gst_conv_layers: int = 6, + gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int = 3, + gst_conv_stride: int = 2, + gst_gru_layers: int = 1, + gst_gru_units: int = 128, + # training related + dropout_rate: float = 0.5, + zoneout_rate: float = 0.1, + use_masking: bool = True, + use_weighted_masking: bool = False, + bce_pos_weight: float = 5.0, + loss_type: str = "L1+L2", + use_guided_attn_loss: bool = True, + guided_attn_loss_sigma: float = 0.4, + guided_attn_loss_lambda: float = 1.0, + ): + """Initialize Tacotron2 module.""" + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.idim = idim + self.odim = odim + self.eos = idim - 1 + self.spk_embed_dim = spk_embed_dim + self.cumulate_att_w = cumulate_att_w + self.reduction_factor = reduction_factor + self.use_gst = use_gst + self.use_guided_attn_loss = use_guided_attn_loss + self.loss_type = loss_type + if self.spk_embed_dim is not None: + self.spk_embed_integration_type = spk_embed_integration_type + + # define activation function for the final output + 
if output_activation is None: + self.output_activation_fn = None + elif hasattr(F, output_activation): + self.output_activation_fn = getattr(F, output_activation) + else: + raise ValueError( + f"there is no such an activation function. " f"({output_activation})" + ) + + # set padding idx + padding_idx = 0 + self.padding_idx = padding_idx + + # define network modules + self.enc = Encoder( + idim=idim, + embed_dim=embed_dim, + elayers=elayers, + eunits=eunits, + econv_layers=econv_layers, + econv_chans=econv_chans, + econv_filts=econv_filts, + use_batch_norm=use_batch_norm, + use_residual=use_residual, + dropout_rate=dropout_rate, + padding_idx=padding_idx, + ) + + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=eunits, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, + ) + + if spk_embed_dim is None: + dec_idim = eunits + elif spk_embed_integration_type == "concat": + dec_idim = eunits + spk_embed_dim + elif spk_embed_integration_type == "add": + dec_idim = eunits + self.projection = torch.nn.Linear(self.spk_embed_dim, eunits) + else: + raise ValueError(f"{spk_embed_integration_type} is not supported.") + + if atype == "location": + att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts) + elif atype == "forward": + att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled " + "in forward attention." + ) + self.cumulate_att_w = False + elif atype == "forward_ta": + att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts, odim) + if self.cumulate_att_w: + logging.warning( + "cumulation of attention weights is disabled " + "in forward attention." + ) + self.cumulate_att_w = False + else: + raise NotImplementedError("Support only location or forward") + self.dec = Decoder( + idim=dec_idim, + odim=odim, + att=att, + dlayers=dlayers, + dunits=dunits, + prenet_layers=prenet_layers, + prenet_units=prenet_units, + postnet_layers=postnet_layers, + postnet_chans=postnet_chans, + postnet_filts=postnet_filts, + output_activation_fn=self.output_activation_fn, + cumulate_att_w=self.cumulate_att_w, + use_batch_norm=use_batch_norm, + use_concate=use_concate, + dropout_rate=dropout_rate, + zoneout_rate=zoneout_rate, + reduction_factor=reduction_factor, + ) + self.taco2_loss = Tacotron2Loss( + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, + bce_pos_weight=bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.attn_loss = GuidedAttentionLoss( + sigma=guided_attn_loss_sigma, + alpha=guided_attn_loss_lambda, + ) + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + spembs: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Calculate forward propagation. + + Args: + text (LongTensor): Batch of padded character ids (B, Tmax). + text_lengths (LongTensor): Batch of lengths of each input batch (B,). + speech (Tensor): Batch of padded target features (B, Lmax, odim). + speech_lengths (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Loss scalar value. + Dict: Statistics to be monitored. 
+ Tensor: Weight value. + + """ + text = text[:, : text_lengths.max()] # for data-parallel + speech = speech[:, : speech_lengths.max()] # for data-parallel + + batch_size = text.size(0) + + # Add eos at the last of sequence + xs = F.pad(text, [0, 1], "constant", self.padding_idx) + for i, l in enumerate(text_lengths): + xs[i, l] = self.eos + ilens = text_lengths + 1 + + ys = speech + olens = speech_lengths + + # make labels for stop prediction + labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype) + labels = F.pad(labels, [0, 1], "constant", 1.0) + + # calculate tacotron2 outputs + after_outs, before_outs, logits, att_ws = self._forward( + xs, ilens, ys, olens, spembs + ) + + # modify mod part of groundtruth + if self.reduction_factor > 1: + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_out = max(olens) + ys = ys[:, :max_out] + labels = labels[:, :max_out] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + + # calculate taco2 loss + l1_loss, mse_loss, bce_loss = self.taco2_loss( + after_outs, before_outs, logits, ys, labels, olens + ) + if self.loss_type == "L1+L2": + loss = l1_loss + mse_loss + bce_loss + elif self.loss_type == "L1": + loss = l1_loss + bce_loss + elif self.loss_type == "L2": + loss = mse_loss + bce_loss + else: + raise ValueError(f"unknown --loss-type {self.loss_type}") + + stats = dict( + l1_loss=l1_loss.item(), + mse_loss=mse_loss.item(), + bce_loss=bce_loss.item(), + ) + + # calculate attention loss + if self.use_guided_attn_loss: + # NOTE(kan-bayashi): length of output for auto-regressive + # input will be changed when r > 1 + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + olens_in = olens + attn_loss = self.attn_loss(att_ws, ilens, olens_in) + loss = loss + attn_loss + stats.update(attn_loss=attn_loss.item()) + + stats.update(loss=loss.item()) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _forward( + self, + xs: torch.Tensor, + ilens: torch.Tensor, + ys: torch.Tensor, + olens: torch.Tensor, + spembs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hs, hlens = self.enc(xs, ilens) + if self.use_gst: + style_embs = self.gst(ys) + hs = hs + style_embs.unsqueeze(1) + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + return self.dec(hs, hlens, ys) + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 10.0, + use_att_constraint: bool = False, + backward_window: int = 1, + forward_window: int = 3, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text (LongTensor): Input sequence of characters (T,). + speech (Tensor, optional): Feature sequence to extract style (N, idim). + spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). + threshold (float, optional): Threshold in inference. + minlenratio (float, optional): Minimum length ratio in inference. + maxlenratio (float, optional): Maximum length ratio in inference. + use_att_constraint (bool, optional): Whether to apply attention constraint. + backward_window (int, optional): Backward window in attention constraint. + forward_window (int, optional): Forward window in attention constraint. 
+ use_teacher_forcing (bool, optional): Whether to use teacher forcing. + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Attention weights (L, T). + + """ + x = text + y = speech + spemb = spembs + + # add eos at the last of sequence + x = F.pad(x, [0, 1], "constant", self.eos) + + # inference with teacher forcing + if use_teacher_forcing: + assert speech is not None, "speech must be provided with teacher forcing." + + xs, ys = x.unsqueeze(0), y.unsqueeze(0) + spembs = None if spemb is None else spemb.unsqueeze(0) + ilens = x.new_tensor([xs.size(1)]).long() + olens = y.new_tensor([ys.size(1)]).long() + outs, _, _, att_ws = self._forward(xs, ilens, ys, olens, spembs) + + return outs[0], None, att_ws[0] + + # inference + h = self.enc.inference(x) + if self.use_gst: + style_emb = self.gst(y.unsqueeze(0)) + h = h + style_emb + if self.spk_embed_dim is not None: + hs, spembs = h.unsqueeze(0), spemb.unsqueeze(0) + h = self._integrate_with_spk_embed(hs, spembs)[0] + outs, probs, att_ws = self.dec.inference( + h, + threshold=threshold, + minlenratio=minlenratio, + maxlenratio=maxlenratio, + use_att_constraint=use_att_constraint, + backward_window=backward_window, + forward_window=forward_window, + ) + + return outs, probs, att_ws + + def _integrate_with_spk_embed( + self, hs: torch.Tensor, spembs: torch.Tensor + ) -> torch.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, eunits). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, eunits) if + integration_type is "add" else (B, Tmax, eunits + spk_embed_dim). + + """ + if self.spk_embed_integration_type == "add": + # apply projection and then add to hidden states + spembs = self.projection(F.normalize(spembs)) + hs = hs + spembs.unsqueeze(1) + elif self.spk_embed_integration_type == "concat": + # concat hidden states with spk embeds + spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1) + hs = torch.cat([hs, spembs], dim=-1) + else: + raise NotImplementedError("support only add or concat.") + + return hs diff --git a/espnet2/tts/thesis_text.py b/espnet2/tts/thesis_text.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2a1d02b2484457b4c6ed0dbd6953a50c20c679 --- /dev/null +++ b/espnet2/tts/thesis_text.py @@ -0,0 +1,611 @@ +from typing import Sequence + +import math + +import torch +from torch import nn +from torch.nn import functional as F + +from typeguard import check_argument_types + + +class VectorQuantizer(nn.Module): + """ + Reference: + [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py + """ + def __init__(self, + num_embeddings: int, + hidden_dim: int, + beta: float = 0.25): + super().__init__() + self.K = num_embeddings + self.D = hidden_dim + self.beta = 0.05 # beta override + + self.embedding = nn.Embedding(self.K, self.D) + self.embedding.weight.data.normal_(0.8, 0.1) # override + + def forward(self, latents: torch.Tensor) -> torch.Tensor: + # latents = latents.permute(0, 2, 1).contiguous() # (B, D, L) -> (B, L, D) + latents_shape = latents.shape + flat_latents = latents.view(-1, self.D) # (BL, D) + + # Compute L2 distance between latents and embedding weights + dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - \ + 2 * torch.matmul(flat_latents, self.embedding.weight.t()) # 
(BL, K) + + # Get the encoding that has the min distance + encoding_inds = torch.argmin(dist, dim=1) # (BL) + output_inds = encoding_inds.view(latents_shape[0], latents_shape[1]) # (B, L) + encoding_inds = encoding_inds.unsqueeze(1) # (BL, 1) + + # Convert to one-hot encodings + device = latents.device + encoding_one_hot = torch.zeros(encoding_inds.size(0), self.K, device=device) + encoding_one_hot.scatter_(1, encoding_inds, 1) # (BL, K) + + # Quantize the latents + # (BL, D) + quantized_latents = torch.matmul(encoding_one_hot, self.embedding.weight) + quantized_latents = quantized_latents.view(latents_shape) # (B, L, D) + + # Compute the VQ Losses + commitment_loss = F.mse_loss(quantized_latents.detach(), latents) + embedding_loss = F.mse_loss(quantized_latents, latents.detach()) + + vq_loss = commitment_loss * self.beta + embedding_loss + + # Add the residue back to the latents + quantized_latents = latents + (quantized_latents - latents).detach() + + # print(output_inds) + # print(quantized_latents) + + # The perplexity a useful value to track during training. + # It indicates how many codes are 'active' on average. + avg_probs = torch.mean(encoding_one_hot, dim=0) + # Exponential entropy + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + return quantized_latents, vq_loss, output_inds, self.embedding, perplexity + + +class ProsodyEncoder(nn.Module): + """VQ-VAE prosody encoder module. + + Args: + odim (int): Number of input channels (mel spectrogram channels). + ref_enc_conv_layers (int, optional): + The number of conv layers in the reference encoder. + ref_enc_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in the referece encoder. + ref_enc_conv_kernel_size (int, optional): + Kernal size of conv layers in the reference encoder. + ref_enc_conv_stride (int, optional): + Stride size of conv layers in the reference encoder. + ref_enc_gru_layers (int, optional): + The number of GRU layers in the reference encoder. + ref_enc_gru_units (int, optional): + The number of GRU units in the reference encoder. + ref_emb_integration_type: How to integrate reference embedding. + adim (int, optional): This value is not that important. + This will not change the capacity in the information-bottleneck. + num_embeddings (int, optional): The higher this value, the higher the + capacity in the information bottleneck. + hidden_dim (int, optional): Number of hidden channels. 
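A quick usage sketch of the quantizer above (toy sizes). The straight-through line `latents + (quantized_latents - latents).detach()` keeps the nearest codebook vectors in the forward pass while letting gradients flow back into the unquantized latents:

vq = VectorQuantizer(num_embeddings=10, hidden_dim=3)
latents = torch.randn(2, 7, 3, requires_grad=True)   # (B, L, D) fine-grained prosody latents
quantized, vq_loss, inds, codebook, perplexity = vq(latents)
quantized.sum().backward()                           # gradients reach `latents` despite the argmin lookup
print(inds.shape, perplexity.item())                 # torch.Size([2, 7]); perplexity ~ number of active codes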
+ """ + def __init__( + self, + odim: int, + adim: int = 64, + num_embeddings: int = 10, + hidden_dim: int = 3, + beta: float = 0.25, + ref_enc_conv_layers: int = 2, + ref_enc_conv_chans_list: Sequence[int] = (32, 32), + ref_enc_conv_kernel_size: int = 3, + ref_enc_conv_stride: int = 1, + global_enc_gru_layers: int = 1, + global_enc_gru_units: int = 32, + global_emb_integration_type: str = "add", + ) -> None: + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.global_emb_integration_type = global_emb_integration_type + + padding = (ref_enc_conv_kernel_size - 1) // 2 + + self.ref_encoder = RefEncoder( + ref_enc_conv_layers=ref_enc_conv_layers, + ref_enc_conv_chans_list=ref_enc_conv_chans_list, + ref_enc_conv_kernel_size=ref_enc_conv_kernel_size, + ref_enc_conv_stride=ref_enc_conv_stride, + ref_enc_conv_padding=padding, + ) + + # get the number of ref enc output units + ref_enc_output_units = odim + for i in range(ref_enc_conv_layers): + ref_enc_output_units = ( + ref_enc_output_units - ref_enc_conv_kernel_size + 2 * padding + ) // ref_enc_conv_stride + 1 + ref_enc_output_units *= ref_enc_conv_chans_list[-1] + + self.fg_encoder = FGEncoder( + ref_enc_output_units + global_enc_gru_units, + hidden_dim=hidden_dim, + ) + + self.global_encoder = GlobalEncoder( + ref_enc_output_units, + global_enc_gru_layers=global_enc_gru_layers, + global_enc_gru_units=global_enc_gru_units, + ) + + # define a projection for the global embeddings + if self.global_emb_integration_type == "add": + self.global_projection = nn.Linear(global_enc_gru_units, adim) + else: + self.global_projection = nn.Linear( + adim + global_enc_gru_units, adim + ) + + self.ar_prior = ARPrior( + adim, + num_embeddings=num_embeddings, + hidden_dim=hidden_dim, + ) + + self.vq_layer = VectorQuantizer(num_embeddings, hidden_dim, beta) + + # define a projection for the quantized fine-grained embeddings + self.qfg_projection = nn.Linear(hidden_dim, adim) + + def forward( + self, + ys: torch.Tensor, + ds: torch.Tensor, + hs: torch.Tensor, + global_embs: torch.Tensor = None, + train_ar_prior: bool = False, + ar_prior_inference: bool = False, + fg_inds: torch.Tensor = None, + ) -> Sequence[torch.Tensor]: + """Calculate forward propagation. + + Args: + ys (Tensor): Batch of padded target features (B, Lmax, odim). + ds (LongTensor): Batch of padded durations (B, Tmax). + hs (Tensor): Batch of phoneme embeddings (B, Tmax, D). + global_embs (Tensor, optional): Global embeddings (B, D) + + Returns: + Tensor: Fine-grained quantized prosody embeddings (B, Tmax, adim). + Tensor: VQ loss. 
+ Tensor: Global prosody embeddings (B, ref_enc_gru_units) + """ + if ys is not None: + print('generating global_embs') + ref_embs = self.ref_encoder(ys) # (B, L', ref_enc_output_units) + global_embs = self.global_encoder(ref_embs) # (B, ref_enc_gru_units) + + if ar_prior_inference: + print('Using ar prior') + hs_integrated = self._integrate_with_global_embs(hs, global_embs) + qs, top_inds = self.ar_prior.inference( + hs_integrated, fg_inds, self.vq_layer.embedding + ) + + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + return p_embs, 0, 0, 0, top_inds # (B, Tmax, adim) + + # concat global embs to ref embs + global_embs_expanded = global_embs.unsqueeze(1).expand(-1, ref_embs.size(1), -1) + # (B, Tmax, D) + ref_embs_integrated = torch.cat([ref_embs, global_embs_expanded], dim=-1) + + # (B, Tmax, hidden_dim) + fg_embs = self.fg_encoder(ref_embs_integrated, ds, ys.size(1)) + + # (B, Tmax, hidden_dim) + qs, vq_loss, inds, codebook, perplexity = self.vq_layer(fg_embs) + # Vector quantization should maintain length + assert hs.size(1) == qs.size(1) + + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + ar_prior_loss = 0 + if train_ar_prior: + # (B, Tmax, adim) + hs_integrated = self._integrate_with_global_embs(hs, global_embs) + qs, ar_prior_loss = self.ar_prior(hs_integrated, inds, codebook) + qs = self.qfg_projection(qs) # (B, Tmax, adim) + assert hs.size(2) == qs.size(2) + + p_embs = self._integrate_with_global_embs(qs, global_embs) + assert hs.shape == p_embs.shape + + return p_embs, vq_loss, ar_prior_loss, perplexity, global_embs + + def _integrate_with_global_embs( + self, + qs: torch.Tensor, + global_embs: torch.Tensor + ) -> torch.Tensor: + """Integrate ref embedding with spectrogram hidden states. + + Args: + qs (Tensor): Batch of quantized FG embeddings (B, Tmax, adim). + global_embs (Tensor): Batch of global embeddings (B, global_enc_gru_units). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). + """ + if self.global_emb_integration_type == "add": + # apply projection to hidden states + global_embs = self.global_projection(global_embs) + res = qs + global_embs.unsqueeze(1) + elif self.global_emb_integration_type == "concat": + # concat hidden states with prosody embeds and then apply projection + # (B, Tmax, ref_emb_dim) + global_embs = global_embs.unsqueeze(1).expand(-1, qs.size(1), -1) + # (B, Tmax, D) + res = self.prosody_projection(torch.cat([qs, global_embs], dim=-1)) + else: + raise NotImplementedError("support only add or concat.") + + return res + + +class RefEncoder(nn.Module): + def __init__( + self, + ref_enc_conv_layers: int = 2, + ref_enc_conv_chans_list: Sequence[int] = (32, 32), + ref_enc_conv_kernel_size: int = 3, + ref_enc_conv_stride: int = 1, + ref_enc_conv_padding: int = 1, + ): + """Initilize reference encoder module.""" + assert check_argument_types() + super().__init__() + + # check hyperparameters are valid + assert ref_enc_conv_kernel_size % 2 == 1, "kernel size must be odd." + assert ( + len(ref_enc_conv_chans_list) == ref_enc_conv_layers + ), "the number of conv layers and length of channels list must be the same." 
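Putting the pieces together, a rough usage sketch of the prosody encoder above in its training-style path (toy sizes; durations are spectrogram frames per phoneme). With the defaults the mel axis keeps its size through the reference convs, so ref_enc_output_units = 80 * 32 = 2560:

enc = ProsodyEncoder(odim=80, adim=64)
ys = torch.randn(2, 40, 80)                   # padded target mels (B, Lmax, odim)
ds = torch.full((2, 5), 8, dtype=torch.long)  # 5 phonemes x 8 frames each, summing to Lmax
hs = torch.randn(2, 5, 64)                    # phoneme hiddens (B, Tmax, adim)
p_embs, vq_loss, ar_loss, perplexity, g_embs = enc(ys, ds, hs)
# p_embs: (2, 5, 64) per-phoneme prosody embeddings, g_embs: (2, 32) utterance-level summary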
+ + convs = [] + for i in range(ref_enc_conv_layers): + conv_in_chans = 1 if i == 0 else ref_enc_conv_chans_list[i - 1] + conv_out_chans = ref_enc_conv_chans_list[i] + convs += [ + nn.Conv2d( + conv_in_chans, + conv_out_chans, + kernel_size=ref_enc_conv_kernel_size, + stride=ref_enc_conv_stride, + padding=ref_enc_conv_padding, + ), + nn.ReLU(inplace=True), + + ] + self.convs = nn.Sequential(*convs) + + def forward(self, ys: torch.Tensor) -> torch.Tensor: + """Calculate forward propagation. + + Args: + ys (Tensor): Batch of padded target features (B, Lmax, odim). + + Returns: + Tensor: Batch of spectrogram hiddens (B, L', ref_enc_output_units) + + """ + B = ys.size(0) + ys = ys.unsqueeze(1) # (B, 1, Lmax, odim) + hs = self.convs(ys) # (B, conv_out_chans, L', odim') + hs = hs.transpose(1, 2) # (B, L', conv_out_chans, odim') + L = hs.size(1) + # (B, L', ref_enc_output_units) -> "flatten" + hs = hs.contiguous().view(B, L, -1) + + return hs + + +class GlobalEncoder(nn.Module): + """Module that creates a global embedding from a hidden spectrogram sequence. + + Args: + """ + def __init__( + self, + ref_enc_output_units: int, + global_enc_gru_layers: int = 1, + global_enc_gru_units: int = 32, + ): + super().__init__() + self.gru = torch.nn.GRU(ref_enc_output_units, global_enc_gru_units, + global_enc_gru_layers, batch_first=True) + + def forward( + self, + hs: torch.Tensor, + ): + """Calculate forward propagation. + + Args: + hs (Tensor): Batch of spectrogram hiddens (B, L', ref_enc_output_units). + + Returns: + Tensor: Reference embedding (B, ref_enc_gru_units). + """ + self.gru.flatten_parameters() + _, global_embs = self.gru(hs) # (gru_layers, B, ref_enc_gru_units) + global_embs = global_embs[-1] # (B, ref_enc_gru_units) + + return global_embs + + +class FGEncoder(nn.Module): + """Spectrogram to phoneme alignment module. + + Args: + """ + def __init__( + self, + input_units: int, + hidden_dim: int = 3, + ): + assert check_argument_types() + super().__init__() + + self.projection = nn.Sequential( + nn.Sequential( + nn.Linear(input_units, input_units // 2), + nn.ReLU(), + nn.Dropout(p=0.2), + ), + nn.Sequential( + nn.Linear(input_units // 2, hidden_dim), + nn.ReLU(), + nn.Dropout(p=0.2), + ) + ) + + def forward( + self, + hs: torch.Tensor, + ds: torch.Tensor, + Lmax: int + ): + """Calculate forward propagation. + + Args: + hs (Tensor): Batch of spectrogram hiddens + (B, L', ref_enc_output_units + global_enc_gru_units). + ds (LongTensor): Batch of padded durations (B, Tmax). + + Returns: + Tensor: aligned spectrogram hiddens (B, Tmax, hidden_dim). + """ + # (B, Tmax, ref_enc_output_units + global_enc_gru_units) + hs = self._align_durations(hs, ds, Lmax) + hs = self.projection(hs) # (B, Tmax, hidden_dim) + + return hs + + def _align_durations(self, hs, ds, Lmax): + """Transform the spectrogram hiddens according to the ground-truth durations + so that there's only one hidden per phoneme hidden. + + Args: + # (B, L', ref_enc_output_units + global_enc_gru_units) + hs (Tensor): Batch of spectrogram hidden state sequences . + ds (LongTensor): Batch of padded durations (B, Tmax) + + Returns: + # (B, Tmax, ref_enc_output_units + global_enc_gru_units) + Tensor: Batch of averaged spectrogram hidden state sequences. 
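The global encoder above keeps only the GRU's final hidden state as the utterance-level prosody summary; in plain PyTorch terms:

gru = torch.nn.GRU(input_size=2560, hidden_size=32, num_layers=1, batch_first=True)
frames = torch.randn(2, 40, 2560)   # flattened reference-encoder features (B, L', units)
_, h_n = gru(frames)                # h_n: (num_layers, B, 32)
global_emb = h_n[-1]                # (B, 32): the last layer's final state summarizes the whole utterance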
+ """ + B = hs.size(0) + L = hs.size(1) + D = hs.size(2) + + Tmax = ds.size(1) # -1 if Tmax + 1 + + device = hs.device + hs_res = torch.zeros( + [B, Tmax, D], + device=device + ) # (B, Tmax, D) + + with torch.no_grad(): + for b_i in range(B): + durations = ds[b_i] + multiplier = L / Lmax + i = 0 + for d_i in range(Tmax): + # take into account downsampling because of conv layers + d = max(math.floor(durations[d_i].item() * multiplier), 1) + if durations[d_i].item() > 0: + hs_slice = hs[b_i, i:i + d, :] # (d, D) + hs_res[b_i, d_i, :] = torch.mean(hs_slice, 0) + i += d + hs_res.requires_grad_(hs.requires_grad) + return hs_res + + +class ARPrior(nn.Module): + # torch.topk(decoder_output, beam_width) + """Autoregressive prior. + + This module is inspired by the AR prior described in `Generating diverse and + natural text-to-speech samples using a quantized fine-grained VAE and + auto-regressive prosody prior`. This prior is fit in the continuous latent space. + """ + def __init__( + self, + adim: int, + num_embeddings: int = 10, + hidden_dim: int = 3, + ): + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.adim = adim + self.hidden_dim = hidden_dim + self.num_embeddings = num_embeddings + + self.qs_projection = nn.Linear(hidden_dim, adim) + + self.lstm = nn.LSTMCell( + self.adim, + self.num_embeddings, + ) + + self.criterion = nn.NLLLoss() + + def inds_to_embs(self, inds, codebook, device): + """Returns the quantized embeddings from the codebook, + corresponding to the indices. + + Args: + inds (Tensor): Batch of indices (B, Tmax, 1). + codebook (Embedding): (num_embeddings, D). + + Returns: + Tensor: Quantized embeddings (B, Tmax, D). + """ + flat_inds = torch.flatten(inds).unsqueeze(1) # (BL, 1) + + # Convert to one-hot encodings + encoding_one_hot = torch.zeros( + flat_inds.size(0), + self.num_embeddings, + device=device + ) + encoding_one_hot.scatter_(1, flat_inds, 1) # (BL, K) + + # Quantize the latents + # (BL, D) + quantized_embs = torch.matmul(encoding_one_hot, codebook.weight) + # (B, L, D) + quantized_embs = quantized_embs.view( + inds.size(0), inds.size(1), self.hidden_dim + ) + + return quantized_embs + + def top_embeddings(self, emb_scores: torch.Tensor, codebook): + """Returns the top quantized embeddings from the codebook using the scores. + + Args: + emb_scores (Tensor): Batch of embedding scores (B, Tmax, num_embeddings). + codebook (Embedding): (num_embeddings, D). + + Returns: + Tensor: Top quantized embeddings (B, Tmax, D). + Tensor: Top 3 inds (B, Tmax, 3). 
+ """ + _, top_inds = emb_scores.topk(1, dim=-1) # (B, L, 1) + quantized_embs = self.inds_to_embs( + top_inds, + codebook, + emb_scores.device, + ) + _, top3_inds = emb_scores.topk(3, dim=-1) # (B, L, 1) + return quantized_embs, top3_inds + + def _forward(self, hs_ref_embs, codebook, fg_inds=None): + inds = [] + scores = [] + embs = [] + + if fg_inds is not None: + init_embs = self.inds_to_embs(fg_inds, codebook, hs_ref_embs.device) + embs = [init_emb.unsqueeze(1) for init_emb in init_embs.transpose(1, 0)] + + start = fg_inds.size(1) if fg_inds is not None else 0 + hidden = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size) + cell = hs_ref_embs.new_zeros(hs_ref_embs.size(0), self.lstm.hidden_size) + + for i in range(start, hs_ref_embs.size(1)): + # (B, adim) + input = hs_ref_embs[:, i] + if i != 0: + # (B, 1, adim) + qs = self.qs_projection(embs[-1]) + # (B, adim) + input = hs_ref_embs[:, i] + qs.squeeze() + hidden, cell = self.lstm(input, (hidden, cell)) # (B, K) + out = hidden.unsqueeze(1) # (B, 1, K) + # (B, 1, K) + emb_scores = F.log_softmax(out, dim=2) + quantized_embs, top_inds = self.top_embeddings(emb_scores, codebook) + # (B, 1, hidden_dim) + embs.append(quantized_embs) + scores.append(emb_scores) + inds.append(top_inds) + + out_embs = torch.cat(embs, dim=1) # (B, L, hidden_dim) + assert(out_embs.size(0) == hs_ref_embs.size(0)) + assert(out_embs.size(1) == hs_ref_embs.size(1)) + out_emb_scores = torch.cat(scores, dim=1) if start < hs_ref_embs.size(1) else scores + out_inds = torch.cat(inds, dim=1) if start < hs_ref_embs.size(1) else fg_inds + + return out_embs, out_emb_scores, out_inds + + def forward(self, hs_ref_embs, inds, codebook): + """Calculate forward propagation. + + Args: + hs_p_embs (Tensor): Batch of phoneme embeddings + with integrated global prosody embeddings (B, Tmax, D). + inds (Tensor): Batch of ground-truth codebook indices + (B, Tmax). + + Returns: + Tensor: Batch of predicted quantized latents (B, Tmax, D). + Tensor: Cross entropy loss value. + + """ + quantized_embs, emb_scores, _ = self._forward(hs_ref_embs, codebook) + emb_scores = emb_scores.permute(0, 2, 1).contiguous() # (B, num_embeddings, L) + loss = self.criterion(emb_scores, inds) + return quantized_embs, loss + + def inference(self, hs_ref_embs, fg_inds, codebook): + """Inference duration. + + Args: + hs_p_embs (Tensor): Batch of phoneme embeddings + with integrated global prosody embeddings (B, Tmax, D). + + Returns: + Tensor: Batch of predicted quantized latents (B, Tmax, D). 
+ + """ + # Random sampling + # fg_inds = torch.rand(hs_ref_embs.size(0), hs_ref_embs.size(1)) + # fg_inds *= codebook.weight.size(0) - 1 + # fg_inds = torch.round(fg_inds) + # fg_inds = fg_inds.long() + + quantized_embs, _, top_inds = self._forward(hs_ref_embs, codebook, fg_inds) + return quantized_embs, top_inds diff --git a/espnet2/tts/transformer.py b/espnet2/tts/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..15e9085352f1ccdab2fde88c6bdd3d68df4e6b0a --- /dev/null +++ b/espnet2/tts/transformer.py @@ -0,0 +1,775 @@ +# Copyright 2020 Nagoya University (Tomoki Hayashi) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""TTS-Transformer related modules.""" + +from typing import Dict +from typing import Sequence +from typing import Tuple + +import torch +import torch.nn.functional as F +from typeguard import check_argument_types + +from espnet.nets.pytorch_backend.e2e_tts_transformer import GuidedMultiHeadAttentionLoss +from espnet.nets.pytorch_backend.e2e_tts_transformer import TransformerLoss +from espnet.nets.pytorch_backend.nets_utils import make_non_pad_mask +from espnet.nets.pytorch_backend.nets_utils import make_pad_mask +from espnet.nets.pytorch_backend.tacotron2.decoder import Postnet +from espnet.nets.pytorch_backend.tacotron2.decoder import Prenet as DecoderPrenet +from espnet.nets.pytorch_backend.tacotron2.encoder import Encoder as EncoderPrenet +from espnet.nets.pytorch_backend.transformer.attention import MultiHeadedAttention +from espnet.nets.pytorch_backend.transformer.decoder import Decoder +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding +from espnet.nets.pytorch_backend.transformer.embedding import ScaledPositionalEncoding +from espnet.nets.pytorch_backend.transformer.encoder import Encoder +from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask +from espnet2.torch_utils.device_funcs import force_gatherable +from espnet2.torch_utils.initialize import initialize +from espnet2.tts.abs_tts import AbsTTS +from espnet2.tts.gst.style_encoder import StyleEncoder + + +class Transformer(AbsTTS): + """TTS-Transformer module. + + This is a module of text-to-speech Transformer described in `Neural Speech Synthesis + with Transformer Network`_, which convert the sequence of tokens into the sequence + of Mel-filterbanks. + + .. _`Neural Speech Synthesis with Transformer Network`: + https://arxiv.org/pdf/1809.08895.pdf + + Args: + idim (int): Dimension of the inputs. + odim (int): Dimension of the outputs. + embed_dim (int, optional): Dimension of character embedding. + eprenet_conv_layers (int, optional): + Number of encoder prenet convolution layers. + eprenet_conv_chans (int, optional): + Number of encoder prenet convolution channels. + eprenet_conv_filts (int, optional): + Filter size of encoder prenet convolution. + dprenet_layers (int, optional): Number of decoder prenet layers. + dprenet_units (int, optional): Number of decoder prenet hidden units. + elayers (int, optional): Number of encoder layers. + eunits (int, optional): Number of encoder hidden units. + adim (int, optional): Number of attention transformation dimensions. + aheads (int, optional): Number of heads for multi head attention. + dlayers (int, optional): Number of decoder layers. + dunits (int, optional): Number of decoder hidden units. + postnet_layers (int, optional): Number of postnet layers. + postnet_chans (int, optional): Number of postnet channels. + postnet_filts (int, optional): Filter size of postnet. 
+ use_scaled_pos_enc (bool, optional): + Whether to use trainable scaled positional encoding. + use_batch_norm (bool, optional): + Whether to use batch normalization in encoder prenet. + encoder_normalize_before (bool, optional): + Whether to perform layer normalization before encoder block. + decoder_normalize_before (bool, optional): + Whether to perform layer normalization before decoder block. + encoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in encoder. + decoder_concat_after (bool, optional): Whether to concatenate attention + layer's input and output in decoder. + positionwise_layer_type (str, optional): + Position-wise operation type. + positionwise_conv_kernel_size (int, optional): + Kernel size in position wise conv 1d. + reduction_factor (int, optional): Reduction factor. + spk_embed_dim (int, optional): Number of speaker embedding dimenstions. + spk_embed_integration_type (str, optional): How to integrate speaker embedding. + use_gst (str, optional): Whether to use global style token. + gst_tokens (int, optional): The number of GST embeddings. + gst_heads (int, optional): The number of heads in GST multihead attention. + gst_conv_layers (int, optional): The number of conv layers in GST. + gst_conv_chans_list: (Sequence[int], optional): + List of the number of channels of conv layers in GST. + gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. + gst_conv_stride (int, optional): Stride size of conv layers in GST. + gst_gru_layers (int, optional): The number of GRU layers in GST. + gst_gru_units (int, optional): The number of GRU units in GST. + transformer_lr (float, optional): Initial value of learning rate. + transformer_warmup_steps (int, optional): Optimizer warmup steps. + transformer_enc_dropout_rate (float, optional): + Dropout rate in encoder except attention and positional encoding. + transformer_enc_positional_dropout_rate (float, optional): + Dropout rate after encoder positional encoding. + transformer_enc_attn_dropout_rate (float, optional): + Dropout rate in encoder self-attention module. + transformer_dec_dropout_rate (float, optional): + Dropout rate in decoder except attention & positional encoding. + transformer_dec_positional_dropout_rate (float, optional): + Dropout rate after decoder positional encoding. + transformer_dec_attn_dropout_rate (float, optional): + Dropout rate in deocoder self-attention module. + transformer_enc_dec_attn_dropout_rate (float, optional): + Dropout rate in encoder-deocoder attention module. + init_type (str, optional): + How to initialize transformer parameters. + init_enc_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the encoder. + init_dec_alpha (float, optional): + Initial value of alpha in scaled pos encoding of the decoder. + eprenet_dropout_rate (float, optional): Dropout rate in encoder prenet. + dprenet_dropout_rate (float, optional): Dropout rate in decoder prenet. + postnet_dropout_rate (float, optional): Dropout rate in postnet. + use_masking (bool, optional): + Whether to apply masking for padded part in loss calculation. + use_weighted_masking (bool, optional): + Whether to apply weighted masking in loss calculation. + bce_pos_weight (float, optional): Positive sample weight in bce calculation + (only for use_masking=true). + loss_type (str, optional): How to calculate loss. + use_guided_attn_loss (bool, optional): Whether to use guided attention loss. 
+ num_heads_applied_guided_attn (int, optional): + Number of heads in each layer to apply guided attention loss. + num_layers_applied_guided_attn (int, optional): + Number of layers to apply guided attention loss. + modules_applied_guided_attn (Sequence[str], optional): + List of module names to apply guided attention loss. + guided_attn_loss_sigma (float, optional) Sigma in guided attention loss. + guided_attn_loss_lambda (float, optional): Lambda in guided attention loss. + + """ + + def __init__( + self, + # network structure related + idim: int, + odim: int, + embed_dim: int = 512, + eprenet_conv_layers: int = 3, + eprenet_conv_chans: int = 256, + eprenet_conv_filts: int = 5, + dprenet_layers: int = 2, + dprenet_units: int = 256, + elayers: int = 6, + eunits: int = 1024, + adim: int = 512, + aheads: int = 4, + dlayers: int = 6, + dunits: int = 1024, + postnet_layers: int = 5, + postnet_chans: int = 256, + postnet_filts: int = 5, + positionwise_layer_type: str = "conv1d", + positionwise_conv_kernel_size: int = 1, + use_scaled_pos_enc: bool = True, + use_batch_norm: bool = True, + encoder_normalize_before: bool = True, + decoder_normalize_before: bool = True, + encoder_concat_after: bool = False, + decoder_concat_after: bool = False, + reduction_factor: int = 1, + spk_embed_dim: int = None, + spk_embed_integration_type: str = "add", + use_gst: bool = False, + gst_tokens: int = 10, + gst_heads: int = 4, + gst_conv_layers: int = 6, + gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), + gst_conv_kernel_size: int = 3, + gst_conv_stride: int = 2, + gst_gru_layers: int = 1, + gst_gru_units: int = 128, + # training related + transformer_enc_dropout_rate: float = 0.1, + transformer_enc_positional_dropout_rate: float = 0.1, + transformer_enc_attn_dropout_rate: float = 0.1, + transformer_dec_dropout_rate: float = 0.1, + transformer_dec_positional_dropout_rate: float = 0.1, + transformer_dec_attn_dropout_rate: float = 0.1, + transformer_enc_dec_attn_dropout_rate: float = 0.1, + eprenet_dropout_rate: float = 0.5, + dprenet_dropout_rate: float = 0.5, + postnet_dropout_rate: float = 0.5, + init_type: str = "xavier_uniform", + init_enc_alpha: float = 1.0, + init_dec_alpha: float = 1.0, + use_masking: bool = False, + use_weighted_masking: bool = False, + bce_pos_weight: float = 5.0, + loss_type: str = "L1", + use_guided_attn_loss: bool = True, + num_heads_applied_guided_attn: int = 2, + num_layers_applied_guided_attn: int = 2, + modules_applied_guided_attn: Sequence[str] = ("encoder-decoder"), + guided_attn_loss_sigma: float = 0.4, + guided_attn_loss_lambda: float = 1.0, + ): + """Initialize Transformer module.""" + assert check_argument_types() + super().__init__() + + # store hyperparameters + self.idim = idim + self.odim = odim + self.eos = idim - 1 + self.spk_embed_dim = spk_embed_dim + self.reduction_factor = reduction_factor + self.use_gst = use_gst + self.use_guided_attn_loss = use_guided_attn_loss + self.use_scaled_pos_enc = use_scaled_pos_enc + self.loss_type = loss_type + self.use_guided_attn_loss = use_guided_attn_loss + if self.use_guided_attn_loss: + if num_layers_applied_guided_attn == -1: + self.num_layers_applied_guided_attn = elayers + else: + self.num_layers_applied_guided_attn = num_layers_applied_guided_attn + if num_heads_applied_guided_attn == -1: + self.num_heads_applied_guided_attn = aheads + else: + self.num_heads_applied_guided_attn = num_heads_applied_guided_attn + self.modules_applied_guided_attn = modules_applied_guided_attn + if self.spk_embed_dim is not None: 
+ self.spk_embed_integration_type = spk_embed_integration_type + + # use idx 0 as padding idx + self.padding_idx = 0 + + # get positional encoding class + pos_enc_class = ( + ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding + ) + + # define transformer encoder + if eprenet_conv_layers != 0: + # encoder prenet + encoder_input_layer = torch.nn.Sequential( + EncoderPrenet( + idim=idim, + embed_dim=embed_dim, + elayers=0, + econv_layers=eprenet_conv_layers, + econv_chans=eprenet_conv_chans, + econv_filts=eprenet_conv_filts, + use_batch_norm=use_batch_norm, + dropout_rate=eprenet_dropout_rate, + padding_idx=self.padding_idx, + ), + torch.nn.Linear(eprenet_conv_chans, adim), + ) + else: + encoder_input_layer = torch.nn.Embedding( + num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx + ) + self.encoder = Encoder( + idim=idim, + attention_dim=adim, + attention_heads=aheads, + linear_units=eunits, + num_blocks=elayers, + input_layer=encoder_input_layer, + dropout_rate=transformer_enc_dropout_rate, + positional_dropout_rate=transformer_enc_positional_dropout_rate, + attention_dropout_rate=transformer_enc_attn_dropout_rate, + pos_enc_class=pos_enc_class, + normalize_before=encoder_normalize_before, + concat_after=encoder_concat_after, + positionwise_layer_type=positionwise_layer_type, + positionwise_conv_kernel_size=positionwise_conv_kernel_size, + ) + + # define GST + if self.use_gst: + self.gst = StyleEncoder( + idim=odim, # the input is mel-spectrogram + gst_tokens=gst_tokens, + gst_token_dim=adim, + gst_heads=gst_heads, + conv_layers=gst_conv_layers, + conv_chans_list=gst_conv_chans_list, + conv_kernel_size=gst_conv_kernel_size, + conv_stride=gst_conv_stride, + gru_layers=gst_gru_layers, + gru_units=gst_gru_units, + ) + + # define projection layer + if self.spk_embed_dim is not None: + if self.spk_embed_integration_type == "add": + self.projection = torch.nn.Linear(self.spk_embed_dim, adim) + else: + self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) + + # define transformer decoder + if dprenet_layers != 0: + # decoder prenet + decoder_input_layer = torch.nn.Sequential( + DecoderPrenet( + idim=odim, + n_layers=dprenet_layers, + n_units=dprenet_units, + dropout_rate=dprenet_dropout_rate, + ), + torch.nn.Linear(dprenet_units, adim), + ) + else: + decoder_input_layer = "linear" + self.decoder = Decoder( + odim=odim, # odim is needed when no prenet is used + attention_dim=adim, + attention_heads=aheads, + linear_units=dunits, + num_blocks=dlayers, + dropout_rate=transformer_dec_dropout_rate, + positional_dropout_rate=transformer_dec_positional_dropout_rate, + self_attention_dropout_rate=transformer_dec_attn_dropout_rate, + src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate, + input_layer=decoder_input_layer, + use_output_layer=False, + pos_enc_class=pos_enc_class, + normalize_before=decoder_normalize_before, + concat_after=decoder_concat_after, + ) + + # define final projection + self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) + self.prob_out = torch.nn.Linear(adim, reduction_factor) + + # define postnet + self.postnet = ( + None + if postnet_layers == 0 + else Postnet( + idim=idim, + odim=odim, + n_layers=postnet_layers, + n_chans=postnet_chans, + n_filts=postnet_filts, + use_batch_norm=use_batch_norm, + dropout_rate=postnet_dropout_rate, + ) + ) + + # define loss function + self.criterion = TransformerLoss( + use_masking=use_masking, + use_weighted_masking=use_weighted_masking, + 
bce_pos_weight=bce_pos_weight, + ) + if self.use_guided_attn_loss: + self.attn_criterion = GuidedMultiHeadAttentionLoss( + sigma=guided_attn_loss_sigma, + alpha=guided_attn_loss_lambda, + ) + + # initialize parameters + self._reset_parameters( + init_type=init_type, + init_enc_alpha=init_enc_alpha, + init_dec_alpha=init_enc_alpha, + ) + + def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0): + # initialize parameters + if init_type != "pytorch": + initialize(self, init_type) + + # initialize alpha in scaled positional encoding + if self.use_scaled_pos_enc: + self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) + self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) + + def forward( + self, + text: torch.Tensor, + text_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + spembs: torch.Tensor = None, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Calculate forward propagation. + + Args: + text (LongTensor): Batch of padded character ids (B, Tmax). + text_lengths (LongTensor): Batch of lengths of each input batch (B,). + speech (Tensor): Batch of padded target features (B, Lmax, odim). + speech_lengths (LongTensor): Batch of the lengths of each target (B,). + spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Loss scalar value. + Dict: Statistics to be monitored. + Tensor: Weight value. + + """ + text = text[:, : text_lengths.max()] # for data-parallel + speech = speech[:, : speech_lengths.max()] # for data-parallel + batch_size = text.size(0) + + # Add eos at the last of sequence + xs = F.pad(text, [0, 1], "constant", self.padding_idx) + for i, l in enumerate(text_lengths): + xs[i, l] = self.eos + ilens = text_lengths + 1 + + ys = speech + olens = speech_lengths + + # make labels for stop prediction + labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype) + labels = F.pad(labels, [0, 1], "constant", 1.0) + + # calculate transformer outputs + after_outs, before_outs, logits = self._forward(xs, ilens, ys, olens, spembs) + + # modifiy mod part of groundtruth + olens_in = olens + if self.reduction_factor > 1: + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + olens = olens.new([olen - olen % self.reduction_factor for olen in olens]) + max_olen = max(olens) + ys = ys[:, :max_olen] + labels = labels[:, :max_olen] + labels[:, -1] = 1.0 # make sure at least one frame has 1 + + # caluculate loss values + l1_loss, l2_loss, bce_loss = self.criterion( + after_outs, before_outs, logits, ys, labels, olens + ) + if self.loss_type == "L1": + loss = l1_loss + bce_loss + elif self.loss_type == "L2": + loss = l2_loss + bce_loss + elif self.loss_type == "L1+L2": + loss = l1_loss + l2_loss + bce_loss + else: + raise ValueError("unknown --loss-type " + self.loss_type) + + stats = dict( + l1_loss=l1_loss.item(), + l2_loss=l2_loss.item(), + bce_loss=bce_loss.item(), + ) + + # calculate guided attention loss + if self.use_guided_attn_loss: + # calculate for encoder + if "encoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.encoder.encoders))) + ): + att_ws += [ + self.encoder.encoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_in, T_in) + enc_attn_loss = self.attn_criterion(att_ws, ilens, ilens) + loss = loss + enc_attn_loss + 
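The reduction-factor bookkeeping above trims the targets to a multiple of r and keeps a separate per-r length for the decoder inputs; a worked example with toy lengths:

r = 2
olens = torch.tensor([120, 95])
olens_in = olens // r       # decoder-side lengths: [60, 47]
olens = olens - olens % r   # trimmed target lengths: [120, 94]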
stats.update(enc_attn_loss=enc_attn_loss.item()) + # calculate for decoder + if "decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].self_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_out) + dec_attn_loss = self.attn_criterion(att_ws, olens_in, olens_in) + loss = loss + dec_attn_loss + stats.update(dec_attn_loss=dec_attn_loss.item()) + # calculate for encoder-decoder + if "encoder-decoder" in self.modules_applied_guided_attn: + att_ws = [] + for idx, layer_idx in enumerate( + reversed(range(len(self.decoder.decoders))) + ): + att_ws += [ + self.decoder.decoders[layer_idx].src_attn.attn[ + :, : self.num_heads_applied_guided_attn + ] + ] + if idx + 1 == self.num_layers_applied_guided_attn: + break + att_ws = torch.cat(att_ws, dim=1) # (B, H*L, T_out, T_in) + enc_dec_attn_loss = self.attn_criterion(att_ws, ilens, olens_in) + loss = loss + enc_dec_attn_loss + stats.update(enc_dec_attn_loss=enc_dec_attn_loss.item()) + + stats.update(loss=loss.item()) + + # report extra information + if self.use_scaled_pos_enc: + stats.update( + encoder_alpha=self.encoder.embed[-1].alpha.data.item(), + decoder_alpha=self.decoder.embed[-1].alpha.data.item(), + ) + + loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) + return loss, stats, weight + + def _forward( + self, + xs: torch.Tensor, + ilens: torch.Tensor, + ys: torch.Tensor, + olens: torch.Tensor, + spembs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # forward encoder + x_masks = self._source_mask(ilens) + hs, h_masks = self.encoder(xs, x_masks) + + # integrate with GST + if self.use_gst: + style_embs = self.gst(ys) + hs = hs + style_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + hs = self._integrate_with_spk_embed(hs, spembs) + + # thin out frames for reduction factor (B, Lmax, odim) -> (B, Lmax//r, odim) + if self.reduction_factor > 1: + ys_in = ys[:, self.reduction_factor - 1 :: self.reduction_factor] + olens_in = olens.new([olen // self.reduction_factor for olen in olens]) + else: + ys_in, olens_in = ys, olens + + # add first zero frame and remove last frame for auto-regressive + ys_in = self._add_first_frame_and_remove_last_frame(ys_in) + + # forward decoder + y_masks = self._target_mask(olens_in) + zs, _ = self.decoder(ys_in, y_masks, hs, h_masks) + # (B, Lmax//r, odim * r) -> (B, Lmax//r * r, odim) + before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) + # (B, Lmax//r, r) -> (B, Lmax//r * r) + logits = self.prob_out(zs).view(zs.size(0), -1) + + # postnet -> (B, Lmax//r * r, odim) + if self.postnet is None: + after_outs = before_outs + else: + after_outs = before_outs + self.postnet( + before_outs.transpose(1, 2) + ).transpose(1, 2) + + return after_outs, before_outs, logits + + def inference( + self, + text: torch.Tensor, + speech: torch.Tensor = None, + spembs: torch.Tensor = None, + threshold: float = 0.5, + minlenratio: float = 0.0, + maxlenratio: float = 10.0, + use_teacher_forcing: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate the sequence of features given the sequences of characters. + + Args: + text (LongTensor): Input sequence of characters (T,). + speech (Tensor, optional): Feature sequence to extract style (N, idim). 
+ spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). + threshold (float, optional): Threshold in inference. + minlenratio (float, optional): Minimum length ratio in inference. + maxlenratio (float, optional): Maximum length ratio in inference. + use_teacher_forcing (bool, optional): Whether to use teacher forcing. + + Returns: + Tensor: Output sequence of features (L, odim). + Tensor: Output sequence of stop probabilities (L,). + Tensor: Encoder-decoder (source) attention weights (#layers, #heads, L, T). + + """ + x = text + y = speech + spemb = spembs + + # add eos at the last of sequence + x = F.pad(x, [0, 1], "constant", self.eos) + + # inference with teacher forcing + if use_teacher_forcing: + assert speech is not None, "speech must be provided with teacher forcing." + + # get teacher forcing outputs + xs, ys = x.unsqueeze(0), y.unsqueeze(0) + spembs = None if spemb is None else spemb.unsqueeze(0) + ilens = x.new_tensor([xs.size(1)]).long() + olens = y.new_tensor([ys.size(1)]).long() + outs, *_ = self._forward(xs, ilens, ys, olens, spembs) + + # get attention weights + att_ws = [] + for i in range(len(self.decoder.decoders)): + att_ws += [self.decoder.decoders[i].src_attn.attn] + att_ws = torch.stack(att_ws, dim=1) # (B, L, H, T_out, T_in) + + return outs[0], None, att_ws[0] + + # forward encoder + xs = x.unsqueeze(0) + hs, _ = self.encoder(xs, None) + + # integrate GST + if self.use_gst: + style_embs = self.gst(y.unsqueeze(0)) + hs = hs + style_embs.unsqueeze(1) + + # integrate speaker embedding + if self.spk_embed_dim is not None: + spembs = spemb.unsqueeze(0) + hs = self._integrate_with_spk_embed(hs, spembs) + + # set limits of length + maxlen = int(hs.size(1) * maxlenratio / self.reduction_factor) + minlen = int(hs.size(1) * minlenratio / self.reduction_factor) + + # initialize + idx = 0 + ys = hs.new_zeros(1, 1, self.odim) + outs, probs = [], [] + + # forward decoder step-by-step + z_cache = self.decoder.init_state(x) + while True: + # update index + idx += 1 + + # calculate output and stop prob at idx-th step + y_masks = subsequent_mask(idx).unsqueeze(0).to(x.device) + z, z_cache = self.decoder.forward_one_step( + ys, y_masks, hs, cache=z_cache + ) # (B, adim) + outs += [ + self.feat_out(z).view(self.reduction_factor, self.odim) + ] # [(r, odim), ...] + probs += [torch.sigmoid(self.prob_out(z))[0]] # [(r), ...] + + # update next inputs + ys = torch.cat( + (ys, outs[-1][-1].view(1, 1, self.odim)), dim=1 + ) # (1, idx + 1, odim) + + # get attention weights + att_ws_ = [] + for name, m in self.named_modules(): + if isinstance(m, MultiHeadedAttention) and "src" in name: + att_ws_ += [m.attn[0, :, -1].unsqueeze(1)] # [(#heads, 1, T),...] + if idx == 1: + att_ws = att_ws_ + else: + # [(#heads, l, T), ...] 
+ att_ws = [ + torch.cat([att_w, att_w_], dim=1) + for att_w, att_w_ in zip(att_ws, att_ws_) + ] + + # check whether to finish generation + if int(sum(probs[-1] >= threshold)) > 0 or idx >= maxlen: + # check minimum length + if idx < minlen: + continue + outs = ( + torch.cat(outs, dim=0).unsqueeze(0).transpose(1, 2) + ) # (L, odim) -> (1, L, odim) -> (1, odim, L) + if self.postnet is not None: + outs = outs + self.postnet(outs) # (1, odim, L) + outs = outs.transpose(2, 1).squeeze(0) # (L, odim) + probs = torch.cat(probs, dim=0) + break + + # concatenate attention weights -> (#layers, #heads, L, T) + att_ws = torch.stack(att_ws, dim=0) + + return outs, probs, att_ws + + def _add_first_frame_and_remove_last_frame(self, ys: torch.Tensor) -> torch.Tensor: + ys_in = torch.cat( + [ys.new_zeros((ys.shape[0], 1, ys.shape[2])), ys[:, :-1]], dim=1 + ) + return ys_in + + def _source_mask(self, ilens): + """Make masks for self-attention. + + Args: + ilens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> ilens = [5, 3] + >>> self._source_mask(ilens) + tensor([[[1, 1, 1, 1, 1]], + [[1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) + return x_masks.unsqueeze(-2) + + def _target_mask(self, olens: torch.Tensor) -> torch.Tensor: + """Make masks for masked self-attention. + + Args: + olens (LongTensor): Batch of lengths (B,). + + Returns: + Tensor: Mask tensor for masked self-attention. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + >>> olens = [5, 3] + >>> self._target_mask(olens) + tensor([[[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 1, 0], + [1, 1, 1, 1, 1]], + [[1, 0, 0, 0, 0], + [1, 1, 0, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0]]], dtype=torch.uint8) + + """ + y_masks = make_non_pad_mask(olens).to(next(self.parameters()).device) + s_masks = subsequent_mask(y_masks.size(-1), device=y_masks.device).unsqueeze(0) + return y_masks.unsqueeze(-2) & s_masks + + def _integrate_with_spk_embed( + self, hs: torch.Tensor, spembs: torch.Tensor + ) -> torch.Tensor: + """Integrate speaker embedding with hidden states. + + Args: + hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). + spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). + + Returns: + Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
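For intuition, the target mask built above is just the AND of a padding mask and a causal (lower-triangular) mask; in plain PyTorch:

olens = torch.tensor([5, 3])
pad = torch.arange(5).unsqueeze(0) < olens.unsqueeze(1)   # (B, Lmax), True on real frames
causal = torch.tril(torch.ones(5, 5, dtype=torch.bool))   # (Lmax, Lmax)
y_masks = pad.unsqueeze(-2) & causal.unsqueeze(0)         # (B, Lmax, Lmax), matching the docstring example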
+
+        """
+        if self.spk_embed_integration_type == "add":
+            # apply projection and then add to hidden states
+            spembs = self.projection(F.normalize(spembs))
+            hs = hs + spembs.unsqueeze(1)
+        elif self.spk_embed_integration_type == "concat":
+            # concat hidden states with spk embeds and then apply projection
+            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
+            hs = self.projection(torch.cat([hs, spembs], dim=-1))
+        else:
+            raise NotImplementedError("only 'add' or 'concat' integration is supported.")
+
+        return hs
diff --git a/espnet2/tts/variance_predictor.py b/espnet2/tts/variance_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc8f99cf30b3f3057e39bb61f16ab62867aa1ed
--- /dev/null
+++ b/espnet2/tts/variance_predictor.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+
+# Copyright 2020 Tomoki Hayashi
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Variance predictor related modules."""
+
+import torch
+
+from typeguard import check_argument_types
+
+from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
+
+
+class VariancePredictor(torch.nn.Module):
+    """Variance predictor module.
+
+    This is a module of the variance predictor described in `FastSpeech 2:
+    Fast and High-Quality End-to-End Text to Speech`_.
+
+    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
+        https://arxiv.org/abs/2006.04558
+
+    """
+
+    def __init__(
+        self,
+        idim: int,
+        n_layers: int = 2,
+        n_chans: int = 384,
+        kernel_size: int = 3,
+        bias: bool = True,
+        dropout_rate: float = 0.5,
+    ):
+        """Initialize variance predictor module.
+
+        Args:
+            idim (int): Input dimension.
+            n_layers (int, optional): Number of convolutional layers.
+            n_chans (int, optional): Number of channels of convolutional layers.
+            kernel_size (int, optional): Kernel size of convolutional layers.
+            dropout_rate (float, optional): Dropout rate.
+
+        """
+        assert check_argument_types()
+        super().__init__()
+        self.conv = torch.nn.ModuleList()
+        for idx in range(n_layers):
+            in_chans = idim if idx == 0 else n_chans
+            self.conv += [
+                torch.nn.Sequential(
+                    torch.nn.Conv1d(
+                        in_chans,
+                        n_chans,
+                        kernel_size,
+                        stride=1,
+                        padding=(kernel_size - 1) // 2,
+                        bias=bias,
+                    ),
+                    torch.nn.ReLU(),
+                    LayerNorm(n_chans, dim=1),
+                    torch.nn.Dropout(dropout_rate),
+                )
+            ]
+        self.linear = torch.nn.Linear(n_chans, 1)
+
+    def forward(self, xs: torch.Tensor, x_masks: torch.Tensor = None) -> torch.Tensor:
+        """Calculate forward propagation.
+
+        Args:
+            xs (Tensor): Batch of input sequences (B, Tmax, idim).
+            x_masks (ByteTensor, optional):
+                Batch of masks indicating padded part (B, Tmax).
+
+        Returns:
+            Tensor: Batch of predicted sequences (B, Tmax, 1).
+
+        """
+        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
+        for f in self.conv:
+            xs = f(xs)  # (B, C, Tmax)
+
+        xs = self.linear(xs.transpose(1, 2))  # (B, Tmax, 1)
+
+        if x_masks is not None:
+            xs = xs.masked_fill(x_masks, 0.0)
+
+        return xs
diff --git a/espnet2/utils/__init__.py b/espnet2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/espnet2/utils/build_dataclass.py b/espnet2/utils/build_dataclass.py
new file mode 100644
index 0000000000000000000000000000000000000000..6675c99a014d31ee85d13dc3c4e553979c597b38
--- /dev/null
+++ b/espnet2/utils/build_dataclass.py
@@ -0,0 +1,17 @@
+import argparse
+import dataclasses
+
+from typeguard import check_type
+
+
+def build_dataclass(dataclass, args: argparse.Namespace):
+    """Helper function to build a dataclass from 'args'."""
+    kwargs = {}
+    for field in dataclasses.fields(dataclass):
+        if not hasattr(args, field.name):
+            raise ValueError(
+                f"args doesn't have {field.name}. You need to add it to the ArgumentParser."
+            )
+        check_type(field.name, getattr(args, field.name), field.type)
+        kwargs[field.name] = getattr(args, field.name)
+    return dataclass(**kwargs)
diff --git a/espnet2/utils/config_argparse.py b/espnet2/utils/config_argparse.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d7197a74715297d986c9da8597da062d178680
--- /dev/null
+++ b/espnet2/utils/config_argparse.py
@@ -0,0 +1,47 @@
+import argparse
+from pathlib import Path
+
+import yaml
+
+
+class ArgumentParser(argparse.ArgumentParser):
+    """Simple implementation of ArgumentParser supporting a yaml config file.
+
+    This class originated from https://github.com/bw2/ConfigArgParse,
+    but it is deliberately simpler than that library:
+
+    - Multiple config files are not supported
+    - "--config" is added as an option automatically
+    - No formats other than yaml are supported
+    - Argument types are not checked
+
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.add_argument("--config", help="Give config file in yaml format")
+
+    def parse_known_args(self, args=None, namespace=None):
+        # Parse once to pick up a config file given via "--config"
+        _args, _ = super().parse_known_args(args, namespace)
+        if _args.config is not None:
+            if not Path(_args.config).exists():
+                self.error(f"No such file: {_args.config}")
+
+            with open(_args.config, "r", encoding="utf-8") as f:
+                d = yaml.safe_load(f)
+            if not isinstance(d, dict):
+                self.error(f"Config file has non dict value: {_args.config}")
+
+            for key in d:
+                for action in self._actions:
+                    if key == action.dest:
+                        break
+                else:
+                    self.error(f"unrecognized arguments: {key} (from {_args.config})")
+
+            # NOTE(kamo): Ignore "--config" from a config file
+            # NOTE(kamo): Unlike "configargparse", this module doesn't check type,
+            #   i.e. we can set any type of value regardless of the argument type.
+            self.set_defaults(**d)
+        return super().parse_known_args(args, namespace)
diff --git a/espnet2/utils/get_default_kwargs.py b/espnet2/utils/get_default_kwargs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f11e8af43ef38cad69c530824be702dbfed5981
--- /dev/null
+++ b/espnet2/utils/get_default_kwargs.py
@@ -0,0 +1,57 @@
+import inspect
+
+
+class Invalid:
+    """Marker object for a non-serializable object"""
+
+
+def get_default_kwargs(func):
+    """Get the default values of the input function.
+ + Examples: + >>> def func(a, b=3): pass + >>> get_default_kwargs(func) + {'b': 3} + + """ + + def yaml_serializable(value): + # isinstance(x, tuple) includes namedtuple, so type is used here + if type(value) is tuple: + return yaml_serializable(list(value)) + elif isinstance(value, set): + return yaml_serializable(list(value)) + elif isinstance(value, dict): + if not all(isinstance(k, str) for k in value): + return Invalid + retval = {} + for k, v in value.items(): + v2 = yaml_serializable(v) + # Register only valid object + if v2 not in (Invalid, inspect.Parameter.empty): + retval[k] = v2 + return retval + elif isinstance(value, list): + retval = [] + for v in value: + v2 = yaml_serializable(v) + # If any elements in the list are invalid, + # the list also becomes invalid + if v2 is Invalid: + return Invalid + else: + retval.append(v2) + return retval + elif value in (inspect.Parameter.empty, None): + return value + elif isinstance(value, (float, int, complex, bool, str, bytes)): + return value + else: + return Invalid + + # params: An ordered mapping of inspect.Parameter + params = inspect.signature(func).parameters + data = {p.name: p.default for p in params.values()} + # Remove not yaml-serializable object + data = yaml_serializable(data) + return data diff --git a/espnet2/utils/griffin_lim.py b/espnet2/utils/griffin_lim.py new file mode 100644 index 0000000000000000000000000000000000000000..a56f7fc683d1429a307e437c0f0c8b1be5953c79 --- /dev/null +++ b/espnet2/utils/griffin_lim.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 + +"""Griffin-Lim related modules.""" + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +import logging + +from distutils.version import LooseVersion +from functools import partial +from typeguard import check_argument_types +from typing import Optional + +import librosa +import numpy as np + +EPS = 1e-10 + + +def logmel2linear( + lmspc: np.ndarray, + fs: int, + n_fft: int, + n_mels: int, + fmin: int = None, + fmax: int = None, +) -> np.ndarray: + """Convert log Mel filterbank to linear spectrogram. + + Args: + lmspc: Log Mel filterbank (T, n_mels). + fs: Sampling frequency. + n_fft: The number of FFT points. + n_mels: The number of mel basis. + f_min: Minimum frequency to analyze. + f_max: Maximum frequency to analyze. + + Returns: + Linear spectrogram (T, n_fft // 2 + 1). + + """ + assert lmspc.shape[1] == n_mels + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + mspc = np.power(10.0, lmspc) + mel_basis = librosa.filters.mel(fs, n_fft, n_mels, fmin, fmax) + inv_mel_basis = np.linalg.pinv(mel_basis) + return np.maximum(EPS, np.dot(inv_mel_basis, mspc.T).T) + + +def griffin_lim( + spc: np.ndarray, + n_fft: int, + n_shift: int, + win_length: int = None, + window: Optional[str] = "hann", + n_iter: Optional[int] = 32, +) -> np.ndarray: + """Convert linear spectrogram into waveform using Griffin-Lim. + + Args: + spc: Linear spectrogram (T, n_fft // 2 + 1). + n_fft: The number of FFT points. + n_shift: Shift size in points. + win_length: Window length in points. + window: Window function type. + n_iter: The number of iterations. + + Returns: + Reconstructed waveform (N,). 
+
+    """
+    # assert the size of input linear spectrogram
+    assert spc.shape[1] == n_fft // 2 + 1
+
+    if LooseVersion(librosa.__version__) >= LooseVersion("0.7.0"):
+        # use librosa's fast Griffin-Lim algorithm
+        spc = np.abs(spc.T)
+        y = librosa.griffinlim(
+            S=spc,
+            n_iter=n_iter,
+            hop_length=n_shift,
+            win_length=win_length,
+            window=window,
+            center=True if spc.shape[1] > 1 else False,
+        )
+    else:
+        # use the slower version of the Griffin-Lim algorithm
+        logging.warning(
+            "librosa version is old. Using the slow Griffin-Lim algorithm. "
+            "If you want to use the fast Griffin-Lim, please update librosa via "
+            "`source ./path.sh && pip install librosa==0.7.0`."
+        )
+        cspc = np.abs(spc).astype(complex).T
+        angles = np.exp(2j * np.pi * np.random.rand(*cspc.shape))
+        y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
+        for i in range(n_iter):
+            angles = np.exp(
+                1j
+                * np.angle(librosa.stft(y, n_fft, n_shift, win_length, window=window))
+            )
+            y = librosa.istft(cspc * angles, n_shift, win_length, window=window)
+
+    return y
+
+
+# TODO(kan-bayashi): write as torch.nn.Module
+class Spectrogram2Waveform(object):
+    """Spectrogram to waveform conversion module."""
+
+    def __init__(
+        self,
+        n_fft: int,
+        n_shift: int,
+        fs: int = None,
+        n_mels: int = None,
+        win_length: int = None,
+        window: Optional[str] = "hann",
+        fmin: int = None,
+        fmax: int = None,
+        griffin_lim_iters: Optional[int] = 32,
+    ):
+        """Initialize module.
+
+        Args:
+            fs: Sampling frequency.
+            n_fft: The number of FFT points.
+            n_shift: Shift size in points.
+            n_mels: The number of mel basis.
+            win_length: Window length in points.
+            window: Window function type.
+            fmin: Minimum frequency to analyze.
+            fmax: Maximum frequency to analyze.
+            griffin_lim_iters: The number of iterations.
+
+        """
+        assert check_argument_types()
+        self.fs = fs
+        self.logmel2linear = (
+            partial(
+                logmel2linear, fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax
+            )
+            if n_mels is not None
+            else None
+        )
+        self.griffin_lim = partial(
+            griffin_lim,
+            n_fft=n_fft,
+            n_shift=n_shift,
+            win_length=win_length,
+            window=window,
+            n_iter=griffin_lim_iters,
+        )
+        self.params = dict(
+            n_fft=n_fft,
+            n_shift=n_shift,
+            win_length=win_length,
+            window=window,
+            n_iter=griffin_lim_iters,
+        )
+        if n_mels is not None:
+            self.params.update(fs=fs, n_mels=n_mels, fmin=fmin, fmax=fmax)
+
+    def __repr__(self):
+        retval = f"{self.__class__.__name__}("
+        for k, v in self.params.items():
+            retval += f"{k}={v}, "
+        retval += ")"
+        return retval
+
+    def __call__(self, spc):
+        """Convert spectrogram to waveform.
+
+        Args:
+            spc: Log Mel filterbank (T, n_mels)
+                or linear spectrogram (T, n_fft // 2 + 1).
+
+        Returns:
+            Reconstructed waveform (N,).
+
+        """
+        if self.logmel2linear is not None:
+            spc = self.logmel2linear(spc)
+        return self.griffin_lim(spc)
diff --git a/espnet2/utils/nested_dict_action.py b/espnet2/utils/nested_dict_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ec57b31d0a6997ccf276c07dc3ba95ee1b7f78
--- /dev/null
+++ b/espnet2/utils/nested_dict_action.py
@@ -0,0 +1,106 @@
+import argparse
+import copy
+
+import yaml
+
+
+class NestedDictAction(argparse.Action):
+    """Action class to append items to dict object.
+
+    Examples:
+        >>> parser = argparse.ArgumentParser()
+        >>> _ = parser.add_argument('--conf', action=NestedDictAction,
+        ...
default={'a': 4}) + >>> parser.parse_args(['--conf', 'a=3', '--conf', 'c=4']) + Namespace(conf={'a': 3, 'c': 4}) + >>> parser.parse_args(['--conf', 'c.d=4']) + Namespace(conf={'a': 4, 'c': {'d': 4}}) + >>> parser.parse_args(['--conf', 'c.d=4', '--conf', 'c=2']) + Namespace(conf={'a': 4, 'c': 2}) + >>> parser.parse_args(['--conf', '{d: 5, e: 9}']) + Namespace(conf={'d': 5, 'e': 9}) + + """ + + _syntax = """Syntax: + {op} = + {op} .= + {op} + {op} +e.g. + {op} a=4 + {op} a.b={{c: true}} + {op} {{"c": True}} + {op} {{a: 34.5}} +""" + + def __init__( + self, + option_strings, + dest, + nargs=None, + default=None, + choices=None, + required=False, + help=None, + metavar=None, + ): + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + default=copy.deepcopy(default), + type=None, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) + + def __call__(self, parser, namespace, values, option_strings=None): + # --{option} a.b=3 -> {'a': {'b': 3}} + if "=" in values: + indict = copy.deepcopy(getattr(namespace, self.dest, {})) + key, value = values.split("=", maxsplit=1) + if not value.strip() == "": + value = yaml.load(value, Loader=yaml.Loader) + if not isinstance(indict, dict): + indict = {} + + keys = key.split(".") + d = indict + for idx, k in enumerate(keys): + if idx == len(keys) - 1: + d[k] = value + else: + if not isinstance(d.setdefault(k, {}), dict): + # Remove the existing value and recreates as empty dict + d[k] = {} + d = d[k] + + # Update the value + setattr(namespace, self.dest, indict) + else: + try: + # At the first, try eval(), i.e. Python syntax dict. + # e.g. --{option} "{'a': 3}" -> {'a': 3} + # This is workaround for internal behaviour of configargparse. + value = eval(values, {}, {}) + if not isinstance(value, dict): + syntax = self._syntax.format(op=option_strings) + mes = f"must be interpreted as dict: but got {values}\n{syntax}" + raise argparse.ArgumentTypeError(self, mes) + except Exception: + # and the second, try yaml.load + value = yaml.load(values, Loader=yaml.Loader) + if not isinstance(value, dict): + syntax = self._syntax.format(op=option_strings) + mes = f"must be interpreted as dict: but got {values}\n{syntax}" + raise argparse.ArgumentError(self, mes) + + d = getattr(namespace, self.dest, None) + if isinstance(d, dict): + d.update(value) + else: + # Remove existing params, and overwrite + setattr(namespace, self.dest, value) diff --git a/espnet2/utils/sized_dict.py b/espnet2/utils/sized_dict.py new file mode 100644 index 0000000000000000000000000000000000000000..105d8c3985e9ab0ea27f2cf2856b25fe47904408 --- /dev/null +++ b/espnet2/utils/sized_dict.py @@ -0,0 +1,75 @@ +import collections +import sys + +from torch import multiprocessing + + +def get_size(obj, seen=None): + """Recursively finds size of objects + + Taken from https://github.com/bosswissam/pysize + + """ + + size = sys.getsizeof(obj) + if seen is None: + seen = set() + + obj_id = id(obj) + if obj_id in seen: + return 0 + + # Important mark as seen *before* entering recursion to gracefully handle + # self-referential objects + seen.add(obj_id) + + if isinstance(obj, dict): + size += sum([get_size(v, seen) for v in obj.values()]) + size += sum([get_size(k, seen) for k in obj.keys()]) + elif hasattr(obj, "__dict__"): + size += get_size(obj.__dict__, seen) + elif isinstance(obj, (list, set, tuple)): + size += sum([get_size(i, seen) for i in obj]) + + return size + + +class SizedDict(collections.abc.MutableMapping): + def __init__(self, shared: bool = False, 
data: dict = None): + if data is None: + data = {} + + if shared: + # NOTE(kamo): Don't set manager as a field because Manager, which includes + # weakref object, causes following error with method="spawn", + # "TypeError: can't pickle weakref objects" + self.cache = multiprocessing.Manager().dict(**data) + else: + self.manager = None + self.cache = dict(**data) + self.size = 0 + + def __setitem__(self, key, value): + if key in self.cache: + self.size -= get_size(self.cache[key]) + else: + self.size += sys.getsizeof(key) + self.size += get_size(value) + self.cache[key] = value + + def __getitem__(self, key): + return self.cache[key] + + def __delitem__(self, key): + self.size -= get_size(self.cache[key]) + self.size -= sys.getsizeof(key) + del self.cache[key] + + def __iter__(self): + return iter(self.cache) + + def __contains__(self, key): + return key in self.cache + + def __len__(self): + return len(self.cache) diff --git a/espnet2/utils/types.py b/espnet2/utils/types.py new file mode 100644 index 0000000000000000000000000000000000000000..6b36f9c4b87ed9258a5d1e254ba298ed5dbc01d2 --- /dev/null +++ b/espnet2/utils/types.py @@ -0,0 +1,149 @@ +from distutils.util import strtobool +from typing import Optional +from typing import Tuple +from typing import Union + +import humanfriendly + + +def str2bool(value: str) -> bool: + return bool(strtobool(value)) + + +def remove_parenthesis(value: str): + value = value.strip() + if value.startswith("(") and value.endswith(")"): + value = value[1:-1] + elif value.startswith("[") and value.endswith("]"): + value = value[1:-1] + return value + + +def remove_quotes(value: str): + value = value.strip() + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.startswith("'") and value.endswith("'"): + value = value[1:-1] + return value + + +def int_or_none(value: str) -> Optional[int]: + """int_or_none. + + Examples: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', type=int_or_none) + >>> parser.parse_args(['--foo', '456']) + Namespace(foo=456) + >>> parser.parse_args(['--foo', 'none']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'null']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'nil']) + Namespace(foo=None) + + """ + if value.strip().lower() in ("none", "null", "nil"): + return None + return int(value) + + +def float_or_none(value: str) -> Optional[float]: + """float_or_none. + + Examples: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', type=float_or_none) + >>> parser.parse_args(['--foo', '4.5']) + Namespace(foo=4.5) + >>> parser.parse_args(['--foo', 'none']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'null']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'nil']) + Namespace(foo=None) + + """ + if value.strip().lower() in ("none", "null", "nil"): + return None + return float(value) + + +def humanfriendly_parse_size_or_none(value) -> Optional[float]: + if value.strip().lower() in ("none", "null", "nil"): + return None + return humanfriendly.parse_size(value) + + +def str_or_int(value: str) -> Union[str, int]: + try: + return int(value) + except ValueError: + return value + + +def str_or_none(value: str) -> Optional[str]: + """str_or_none. 
+ + Examples: + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', type=str_or_none) + >>> parser.parse_args(['--foo', 'aaa']) + Namespace(foo='aaa') + >>> parser.parse_args(['--foo', 'none']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'null']) + Namespace(foo=None) + >>> parser.parse_args(['--foo', 'nil']) + Namespace(foo=None) + + """ + if value.strip().lower() in ("none", "null", "nil"): + return None + return value + + +def str2pair_str(value: str) -> Tuple[str, str]: + """str2pair_str. + + Examples: + >>> import argparse + >>> str2pair_str('abc,def ') + ('abc', 'def') + >>> parser = argparse.ArgumentParser() + >>> _ = parser.add_argument('--foo', type=str2pair_str) + >>> parser.parse_args(['--foo', 'abc,def']) + Namespace(foo=('abc', 'def')) + + """ + value = remove_parenthesis(value) + a, b = value.split(",") + + # Workaround for configargparse issues: + # If the list values are given from yaml file, + # the value givent to type() is shaped as python-list, + # e.g. ['a', 'b', 'c'], + # so we need to remove double quotes from it. + return remove_quotes(a), remove_quotes(b) + + +def str2triple_str(value: str) -> Tuple[str, str, str]: + """str2triple_str. + + Examples: + >>> str2triple_str('abc,def ,ghi') + ('abc', 'def', 'ghi') + """ + value = remove_parenthesis(value) + a, b, c = value.split(",") + + # Workaround for configargparse issues: + # If the list values are given from yaml file, + # the value givent to type() is shaped as python-list, + # e.g. ['a', 'b', 'c'], + # so we need to remove quotes from it. + return remove_quotes(a), remove_quotes(b), remove_quotes(c) diff --git a/espnet2/utils/yaml_no_alias_safe_dump.py b/espnet2/utils/yaml_no_alias_safe_dump.py new file mode 100644 index 0000000000000000000000000000000000000000..70a7b0e40be7ecaaaa86a1cae86f146d83116876 --- /dev/null +++ b/espnet2/utils/yaml_no_alias_safe_dump.py @@ -0,0 +1,14 @@ +import yaml + + +class NoAliasSafeDumper(yaml.SafeDumper): + # Disable anchor/alias in yaml because looks ugly + def ignore_aliases(self, data): + return True + + +def yaml_no_alias_safe_dump(data, stream=None, **kwargs): + """Safe-dump in yaml with no anchor/alias""" + return yaml.dump( + data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs + ) diff --git a/feats_stats.npz b/feats_stats.npz new file mode 100644 index 0000000000000000000000000000000000000000..3d3c12a872c91167c791fa08d384ac5d58401e49 Binary files /dev/null and b/feats_stats.npz differ diff --git a/model.pth b/model.pth new file mode 100644 index 0000000000000000000000000000000000000000..164ffe0208bfe5a47230ef713635150bc5e5af99 --- /dev/null +++ b/model.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:318a3fd35b3ec41260b511239f824675e649f72e75bd0004e0cbc89e442c2f88 +size 46702723 diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..c00790c0eb2810df141cfb7448c762dd6e59520c --- /dev/null +++ b/packages.txt @@ -0,0 +1,3 @@ +espeak-ng +mbrola +mbrola-us1 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..5005151f1c2c814160c375b53ec5bc3dae5d78ba --- /dev/null +++ b/requirements.txt @@ -0,0 +1,26 @@ +setuptools>=38.5.1 +configargparse>=1.2.1 +typeguard>=2.7.0 +humanfriendly +scipy>=1.4.1 +matplotlib==3.1.0 +pillow>=6.1.0 +editdistance==0.5.2 +ctc-segmentation<1.6>=1.4.0 +wandb +filelock +torch==1.8.1 +tensorboard>=1.14 
+librosa>=0.8.0 +sentencepiece<0.1.90>=0.1.82 +nltk>=3.4.5 +PyYAML>=5.1.2 +soundfile>=0.10.2 +h5py>=2.10.0 +kaldiio>=2.17.0 +pyworld>=0.2.10 +espnet_tts_frontend +nara_wpe>=0.0.5 +torch_complex +pytorch_wpe +parallel_wavegan==0.5.3 \ No newline at end of file diff --git a/style1.pt b/style1.pt new file mode 100644 index 0000000000000000000000000000000000000000..5c07523c615dae1e68d146a3831b880be7f554c2 --- /dev/null +++ b/style1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5a26e17f6cbd480ee46c9f567e16d363e97546253acafbc67b8266f85ae90881 +size 469 diff --git a/style2.pt b/style2.pt new file mode 100644 index 0000000000000000000000000000000000000000..4efd1918de87940187829929f5049c39892c8514 --- /dev/null +++ b/style2.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c0efbde0e443c584f738908cdbe502b2a0c9859e4cade410cd215f04b715937 +size 469 diff --git a/style3.pt b/style3.pt new file mode 100644 index 0000000000000000000000000000000000000000..46e2f1e674c99e274caf942c2f4ca6e274745d1b --- /dev/null +++ b/style3.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:93fe47a86cd1d4322761160149368030cbb9430942aa10e14cd8f9850cc3608a +size 469 diff --git a/style4.pt b/style4.pt new file mode 100644 index 0000000000000000000000000000000000000000..f2baa88da6359534051e720b541553ca2aa6abb0 --- /dev/null +++ b/style4.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:226882523b20d73825114f6cda0549dd1edc931c7c9547ba2c72d2501d0fe023 +size 469 diff --git a/style5.pt b/style5.pt new file mode 100644 index 0000000000000000000000000000000000000000..6c99592f067da8a06e863872dd0fcb289db3b231 --- /dev/null +++ b/style5.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:300b2ad466ea00dbf3a295bb73a2d8bbe2e3c340bdc2e6acf224557ee114b9ee +size 469 diff --git a/style6.pt b/style6.pt new file mode 100644 index 0000000000000000000000000000000000000000..75f64bd46b39fe7588ecbed1f08e0e18928bd700 --- /dev/null +++ b/style6.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4b392aa37ac0c7711336f67801c7922885ef7e5e0e85fad10e05ee36e86880e5 +size 469
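The `_source_mask`/`_target_mask` helpers in the Transformer-TTS code above combine a padding mask with a causal (subsequent) mask for the decoder's masked self-attention. Below is a minimal stand-alone sketch of the same idea in plain `torch`, without the espnet helpers; the function names (`padding_mask`, `causal_target_mask`) are ours, not part of the codebase.

```python
import torch

def padding_mask(lengths):
    # True for real positions, False for padding (same idea as make_non_pad_mask).
    maxlen = max(lengths)
    steps = torch.arange(maxlen)
    return steps.unsqueeze(0) < torch.tensor(lengths).unsqueeze(1)  # (B, Tmax)

def causal_target_mask(olens):
    # Padding mask ANDed with a lower-triangular mask, as used for masked self-attention.
    pad = padding_mask(olens)                                         # (B, Tmax)
    tril = torch.tril(torch.ones(pad.size(1), pad.size(1), dtype=torch.bool))
    return pad.unsqueeze(-2) & tril.unsqueeze(0)                      # (B, Tmax, Tmax)

print(causal_target_mask([5, 3]).int())
```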
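The autoregressive `inference()` loop above emits a small number of frames per decoder step and stops once the predicted stop probability crosses `threshold` (or `maxlen` is reached), feeding the last generated frame back in as the next input. The toy sketch below mirrors only that control flow; `step_fn` is a dummy stand-in for the real decoder and nothing here is part of the repository.

```python
import torch

def toy_decode(step_fn, odim, threshold=0.5, minlen=0, maxlen=100):
    # step_fn(ys) -> (new frames (r, odim), stop probability in [0, 1])
    ys = torch.zeros(1, 1, odim)          # all-zero "go" frame
    outs, probs, idx = [], [], 0
    while True:
        idx += 1
        frames, prob = step_fn(ys)
        outs.append(frames)
        probs.append(prob)
        # feed the last generated frame back in, like the real loop does
        ys = torch.cat((ys, frames[-1].view(1, 1, odim)), dim=1)
        if (prob >= threshold or idx >= maxlen) and idx >= minlen:
            return torch.cat(outs, dim=0), torch.stack(probs)  # (L, odim), (L,)

# Dummy decoder step: emits one random frame and "stops" after 10 steps.
def step_fn(ys):
    return torch.randn(1, 4), torch.tensor(1.0 if ys.size(1) > 10 else 0.0)

feats, probs = toy_decode(step_fn, odim=4)
print(feats.shape, probs.shape)   # torch.Size([11, 4]) torch.Size([11])
```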
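`espnet2/tts/variance_predictor.py` above implements the FastSpeech2-style predictor used for quantities such as pitch and energy. A rough usage sketch: the dimensions are made up, and the mask is shaped `(B, Tmax, 1)` so that the predictor's `masked_fill` zeroes predictions on padded frames.

```python
import torch
from espnet2.tts.variance_predictor import VariancePredictor

predictor = VariancePredictor(idim=384, n_layers=2, n_chans=384, kernel_size=3)

xs = torch.randn(2, 50, 384)                        # (B, Tmax, idim) encoder states
x_masks = torch.zeros(2, 50, 1, dtype=torch.bool)   # True marks a padded position
x_masks[1, 40:] = True                              # pretend item 2 is only 40 frames long

preds = predictor(xs, x_masks)                      # (B, Tmax, 1), zeroed on padding
print(preds.shape, preds[1, 45])                    # second value prints tensor([0.])
```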
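`espnet2/utils/build_dataclass.py` copies matching attributes from an `argparse.Namespace` into a dataclass and type-checks them on the way in. A small sketch with a hypothetical `OptimConfig` dataclass (not part of the repo):

```python
import argparse
import dataclasses

from espnet2.utils.build_dataclass import build_dataclass

@dataclasses.dataclass
class OptimConfig:            # hypothetical config container, for illustration only
    lr: float
    momentum: float

args = argparse.Namespace(lr=0.01, momentum=0.9, extra="ignored")
conf = build_dataclass(OptimConfig, args)   # attributes not in the dataclass are skipped
print(conf)                                 # OptimConfig(lr=0.01, momentum=0.9)
```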
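`espnet2/utils/config_argparse.py` lets a yaml file supply argument defaults while explicit command-line flags still take precedence. A sketch, assuming a throwaway `demo_config.yaml` written on the fly (the file name and option names are invented for the example):

```python
from espnet2.utils.config_argparse import ArgumentParser

# Any yaml mapping of option names to values works here.
with open("demo_config.yaml", "w", encoding="utf-8") as f:
    f.write("lr: 0.001\nbatch_size: 16\n")

parser = ArgumentParser()
parser.add_argument("--lr", type=float, default=0.01)
parser.add_argument("--batch_size", type=int, default=32)

args = parser.parse_args(["--config", "demo_config.yaml", "--lr", "0.1"])
print(args)   # lr=0.1 (CLI wins), batch_size=16 (from the yaml), config='demo_config.yaml'
```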
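`espnet2/utils/griffin_lim.py` wraps log-mel-to-linear conversion plus Griffin-Lim phase reconstruction in `Spectrogram2Waveform`. A usage sketch with placeholder feature settings (the real values live in the model's `config.yaml`) and a random log-mel input, so the audio is noise but the shapes are representative; this assumes the pinned librosa range from `requirements.txt`.

```python
import numpy as np
from espnet2.utils.griffin_lim import Spectrogram2Waveform

# Placeholder analysis settings, not taken from config.yaml.
spc2wav = Spectrogram2Waveform(
    n_fft=1024, n_shift=256, fs=22050, n_mels=80,
    win_length=1024, window="hann", fmin=80, fmax=7600, griffin_lim_iters=32,
)

lmspc = np.random.randn(120, 80).astype(np.float32)  # fake log-mel, (T, n_mels)
wav = spc2wav(lmspc)                                  # noise here, speech for real features
print(wav.shape)                                      # roughly T * n_shift samples
```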
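`espnet2/utils/sized_dict.py` provides a dict that tracks the approximate memory held by its entries (and can be backed by a `multiprocessing.Manager` dict when `shared=True`). A quick sketch:

```python
import numpy as np
from espnet2.utils.sized_dict import SizedDict

cache = SizedDict(shared=False)
cache["utt1"] = np.zeros(16000, dtype=np.float32)   # ~64 kB of samples
cache["utt2"] = "a cached transcription"
print(len(cache), cache.size)                       # entry count and approximate bytes held

del cache["utt1"]
print(len(cache), cache.size)                       # size shrinks when entries are removed
```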
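`espnet2/utils/types.py` supplies converters meant to be passed as `type=` to `argparse`, so that strings such as `"none"` or `"a,b,c"` become `None` or tuples. A sketch; the option names below are chosen only for illustration:

```python
import argparse
from espnet2.utils.types import int_or_none, str2bool, str2triple_str

parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int_or_none, default=None)
parser.add_argument("--use_gst", type=str2bool, default=False)
parser.add_argument("--train_data", type=str2triple_str, action="append")

args = parser.parse_args([
    "--seed", "none",
    "--use_gst", "true",
    "--train_data", "dump/train/wav.scp,speech,sound",
])
print(args.seed, args.use_gst, args.train_data)
# None True [('dump/train/wav.scp', 'speech', 'sound')]
```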