Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from pathlib import Path | |
import soundfile as sf | |
import sys | |
import torch | |
import torchaudio | |
from fairseq import checkpoint_utils, options, tasks, utils | |
from fairseq.logging import progress_bar | |
from fairseq.tasks.text_to_speech import plot_tts_output | |
from fairseq.data.audio.text_to_speech_dataset import TextToSpeechDataset | |
# Attach a default stderr handler to the root logger (a no-op if the root
# logger already has handlers) and force the root level to INFO. A second
# ``logging.basicConfig(level=...)`` call would be ignored here because the
# root logger always has handlers after the first call, so the level is set
# explicitly instead.
logging.basicConfig()
logging.root.setLevel(logging.INFO)
logger = logging.getLogger(__name__)
def make_parser():
    """Create the command-line parser for TTS generation.

    Starts from fairseq's speech-generation parser and adds switches that
    select which per-utterance artifacts get written, plus audio output
    options (sample rate and container format).

    Returns:
        The extended ``argparse`` parser.
    """
    dump_flags = (
        "--dump-features",
        "--dump-waveforms",
        "--dump-attentions",
        "--dump-eos-probs",
        "--dump-plots",
        "--dump-target",
    )
    parser = options.get_speech_generation_parser()
    # One boolean switch per artifact kind, registered in a fixed order.
    for flag in dump_flags:
        parser.add_argument(flag, action="store_true")
    parser.add_argument("--output-sample-rate", default=22050, type=int)
    parser.add_argument("--teacher-forcing", action="store_true")
    parser.add_argument(
        "--audio-format", type=str, default="wav", choices=["wav", "flac"]
    )
    return parser
def postprocess_results(
    dataset: "TextToSpeechDataset", sample, hypos, resample_fn, dump_target
):
    """Turn one batch of generator outputs into per-utterance numpy records.

    Args:
        dataset: the dataset the batch was drawn from; ``dataset.ids`` maps
            the integer indices in ``sample["id"]`` to utterance ids.
        sample: batch dict. ``sample["id"]`` is an integer tensor;
            ``sample["src_texts"]`` (optional) holds the input texts.
        hypos: generator outputs aligned with ``sample["id"]``. Keys read:
            ``attn``, ``feature``, ``waveform``, optionally ``eos_prob``, and
            ``targ_feature`` / ``targ_waveform`` when ``dump_target`` is set.
        resample_fn: callable applied to every non-None waveform tensor.
        dump_target: whether to also extract ground-truth features/waveforms.

    Returns:
        A lazy ``zip`` of ``(sample_id, text, attn, eos_prob, feat_pred,
        wave_pred, feat_targ, wave_targ)`` tuples. Tensor entries are CPU
        numpy arrays, or None when the value is absent.
    """
    def to_np(x):
        return None if x is None else x.detach().cpu().numpy()

    def wave_to_np(wave):
        # A missing waveform stays None instead of being handed to the
        # resampler (sox resampling would fail on None).
        return None if wave is None else to_np(resample_fn(wave))

    sample_ids = [dataset.ids[i] for i in sample["id"].tolist()]
    # Batches without raw texts get empty strings (used only as plot titles).
    if "src_texts" in sample:
        texts = sample["src_texts"]
    else:
        texts = ["" for _ in hypos]
    attns = [to_np(hypo["attn"]) for hypo in hypos]
    eos_probs = [to_np(hypo.get("eos_prob", None)) for hypo in hypos]
    feat_preds = [to_np(hypo["feature"]) for hypo in hypos]
    wave_preds = [wave_to_np(hypo["waveform"]) for hypo in hypos]
    if dump_target:
        feat_targs = [to_np(hypo["targ_feature"]) for hypo in hypos]
        wave_targs = [wave_to_np(hypo["targ_waveform"]) for hypo in hypos]
    else:
        feat_targs = [None for _ in hypos]
        wave_targs = [None for _ in hypos]
    return zip(sample_ids, texts, attns, eos_probs, feat_preds, wave_preds,
               feat_targs, wave_targs)
def dump_result(
    is_na_model,
    args,
    vocoder,
    sample_id,
    text,
    attn,
    eos_prob,
    feat_pred,
    wave_pred,
    feat_targ,
    wave_targ,
):
    """Write the requested artifacts for one utterance under ``args.results_path``.

    The ``args.dump_*`` flags select the outputs:
      * features -> ``feat/<id>.npy`` (and ``feat_tgt/<id>.npy`` with
        ``dump_target``)
      * attentions -> ``attn/<id>.npy`` (skipped when ``attn`` is None)
      * EOS probabilities -> ``eos/<id>.npy`` (autoregressive models only)
      * plots -> ``plot/<id>.png``
      * waveforms -> ``<ext>_<rate>hz_<vocoder>/<id>.<ext>`` (and a ``_tgt``
        sibling directory with ``dump_target``); None waveforms are skipped.

    Args:
        is_na_model: True for non-autoregressive models. For these, EOS
            probabilities are not written, and ``attn`` is passed to the
            plot instead of ``eos_prob``.
        args: parsed CLI namespace (see ``make_parser``).
        vocoder: vocoder name; only used in the waveform directory name.
        sample_id: utterance id, used as the output file stem.
        text: input text, used as the plot title.
        attn, eos_prob, feat_pred, wave_pred, feat_targ, wave_targ: numpy
            arrays or None, as produced by ``postprocess_results``.
    """
    sample_rate = args.output_sample_rate
    out_root = Path(args.results_path)

    def save_array(subdir, array):
        # Write ``array`` to ``<results_path>/<subdir>/<sample_id>.npy``.
        array_dir = out_root / subdir
        array_dir.mkdir(exist_ok=True, parents=True)
        np.save(array_dir / f"{sample_id}.npy", array)

    if args.dump_features:
        save_array("feat", feat_pred)
        if args.dump_target:
            save_array("feat_tgt", feat_targ)

    if args.dump_attentions and attn is not None:
        # ``attn`` is already a numpy array (postprocess_results converts the
        # tensors), so it is saved as-is. Calling ``.numpy()`` on an ndarray
        # raises AttributeError.
        save_array("attn", attn)

    if args.dump_eos_probs and not is_na_model:
        save_array("eos", eos_prob)

    if args.dump_plots:
        images = [feat_pred.T] if is_na_model else [feat_pred.T, attn]
        names = ["output"] if is_na_model else ["output", "alignment"]
        if feat_targ is not None:
            images = [feat_targ.T] + images
            names = [f"target (idx={sample_id})"] + names
        if is_na_model:
            plot_tts_output(images, names, attn, "alignment", suptitle=text)
        else:
            plot_tts_output(images, names, eos_prob, "eos prob", suptitle=text)
        plot_dir = out_root / "plot"
        plot_dir.mkdir(exist_ok=True, parents=True)
        plt.savefig(plot_dir / f"{sample_id}.png")
        # Close the current figure so figures don't accumulate across calls.
        plt.close()

    if args.dump_waveforms:
        ext = args.audio_format

        def save_audio(subdir, wave):
            # Write ``wave`` to ``<results_path>/<subdir>/<sample_id>.<ext>``.
            audio_dir = out_root / subdir
            audio_dir.mkdir(exist_ok=True, parents=True)
            sf.write(audio_dir / f"{sample_id}.{ext}", wave, sample_rate)

        if wave_pred is not None:
            save_audio(f"{ext}_{sample_rate}hz_{vocoder}", wave_pred)
        if args.dump_target and wave_targ is not None:
            save_audio(f"{ext}_{sample_rate}hz_{vocoder}_tgt", wave_targ)
def _build_resample_fn(input_rate, output_rate):
    """Return a callable that converts a waveform tensor from input_rate to output_rate.

    If the rates are equal, the identity function is returned. Otherwise the
    tensor is detached, moved to CPU, given a leading channel dimension, run
    through sox's ``rate`` effect, and the channel dimension is removed again.
    The input is assumed to be 1-D ``(time,)``; confirm against the generator.
    """
    if output_rate == input_rate:
        return lambda x: x

    def resample(x):
        return torchaudio.sox_effects.apply_effects_tensor(
            x.detach().cpu().unsqueeze(0), input_rate,
            [["rate", str(output_rate)]]
        )[0].squeeze(0)

    return resample


def main(args):
    """Run TTS generation over ``args.gen_subset`` and dump the requested artifacts.

    Loads the checkpoint at ``args.path``, builds the task's generator, walks
    the subset in dataset order (no shuffling) and, for every utterance,
    writes the outputs selected by the ``--dump-*`` flags under
    ``args.results_path``.

    Raises:
        ValueError: if no dump flag is set, since nothing would be written.
    """
    # Explicit check rather than ``assert``: asserts are stripped under -O.
    if not (args.dump_features or args.dump_waveforms or args.dump_attentions
            or args.dump_eos_probs or args.dump_plots):
        raise ValueError(
            "nothing to dump: pass at least one of --dump-features, "
            "--dump-waveforms, --dump-attentions, --dump-eos-probs, "
            "--dump-plots"
        )
    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 8000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu
    task = tasks.setup_task(args)
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        task=task,
    )
    model = models[0].cuda() if use_cuda else models[0]
    # Generate with the n_frames_per_step stored in the checkpoint's task
    # config rather than whatever the command line supplied.
    task.args.n_frames_per_step = saved_cfg.task.n_frames_per_step
    task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task)

    data_cfg = task.data_cfg
    sample_rate = data_cfg.config.get("features", {}).get("sample_rate", 22050)
    resample_fn = _build_resample_fn(sample_rate, args.output_sample_rate)
    if args.output_sample_rate != sample_rate:
        logger.info(f"resampling to {args.output_sample_rate}Hz")

    generator = task.build_generator([model], args)
    dataset = task.dataset(args.gen_subset)
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)

    Path(args.results_path).mkdir(exist_ok=True, parents=True)
    is_na_model = getattr(model, "NON_AUTOREGRESSIVE", False)
    vocoder = task.args.vocoder
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            hypos = generator.generate(model, sample, has_targ=args.dump_target)
            for result in postprocess_results(
                dataset, sample, hypos, resample_fn, args.dump_target
            ):
                dump_result(is_na_model, args, vocoder, *result)
def cli_main():
    """Command-line entry point: build the parser, parse argv, and run ``main``."""
    main(options.parse_args_and_arch(make_parser()))


if __name__ == "__main__":
    cli_main()