Spaces:
Paused
Paused
import argparse | |
import importlib | |
import os | |
from argparse import RawTextHelpFormatter | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader | |
from tqdm import tqdm | |
from TTS.config import load_config | |
from TTS.tts.datasets.TTSDataset import TTSDataset | |
from TTS.tts.models import setup_model | |
from TTS.tts.utils.text.characters import make_symbols, phonemes, symbols | |
from TTS.utils.audio import AudioProcessor | |
from TTS.utils.io import load_checkpoint | |
def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including ``"False"`` -- evaluates to ``True``. This
    converter maps the usual spellings explicitly instead.

    Args:
        value: The raw argument (a string from the command line, or an actual
            ``bool`` when supplied as a default).

    Returns:
        bool: The parsed flag.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def get_attention_path(wav_path):
    """Return the path of the attention mask file that belongs to ``wav_path``.

    Only the extension of the final path component is swapped for
    ``"_attn.npy"``; directory components are left untouched even when they
    happen to contain the file name.

    Args:
        wav_path (str): Path of the input audio file.

    Returns:
        str: ``wav_path`` with its extension replaced by ``"_attn.npy"``.
    """
    return os.path.splitext(wav_path)[0] + "_attn.npy"


if __name__ == "__main__":
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
        """Each attention mask is written next to the input wav file with the "_attn.npy" suffix.
(e.g. path/bla.wav (wav file) --> path/bla_attn.npy (attention mask))\n"""
        """
Example run:
    CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py
        --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth
        --config_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json
        --dataset_metafile metadata.csv
        --data_path /root/LJSpeech-1.1/
        --batch_size 32
        --dataset ljspeech
        --use_cuda True
""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,
        help="Path to Tacotron/Tacotron2 config file.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Target dataset processor name from TTS.tts.dataset.preprocess.",
    )
    parser.add_argument(
        "--dataset_metafile",
        type=str,
        required=True,
        help="Dataset metafile including file paths with transcripts.",
    )
    parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
    # type=bool would turn "--use_cuda False" into True; parse the text explicitly.
    parser.add_argument("--use_cuda", type=str2bool, default=False, help="enable/disable cuda.")
    parser.add_argument(
        "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
    )
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # load the model
    # TODO: handle multi-speaker
    model = setup_model(C)
    model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True)

    # data loader
    preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = TTSDataset(
        model.decoder.r,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        characters=C.characters if "characters" in C.keys() else None,
        add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars,
    )

    dataset.sort_and_filter_items(C.get("sort_by_audio_len", default=False))
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        collate_fn=dataset.collate_fn,
        shuffle=False,
        drop_last=False,
    )

    # The alignment is stretched by this factor along the decoder-step axis below.
    reduction_factor = model.decoder.r

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (only the batch fields this script uses)
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            model_outputs = model.forward(text_input, text_lengths, mel_input)

            alignments = model_outputs["alignments"].detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = (
                    torch.nn.functional.interpolate(
                        alignment.transpose(0, 1).unsqueeze(0),
                        size=None,
                        scale_factor=reduction_factor,
                        mode="nearest",
                        align_corners=None,
                        recompute_scale_factor=None,
                    )
                    .squeeze(0)
                    .transpose(0, 1)
                )
                # remove paddings
                alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
                # set file paths; derived from the full path so directory
                # names that contain the wav file name are not rewritten
                file_path = get_attention_path(item_idx)
                # save output
                wav_file_abs_path = os.path.abspath(item_idx)
                file_abs_path = os.path.abspath(file_path)
                file_paths.append([wav_file_abs_path, file_abs_path])
                np.save(file_path, alignment)

    # output metafile: one "<wav path>|<attention mask path>" line per item
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w", encoding="utf-8") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")