# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
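"""Prepare feature manifests and a data config for fairseq text-to-speech training.

For each requested split this script:
  1. reads the audio manifest ``<split>.audio.tsv`` (columns: id, audio,
     speaker, src_text, tgt_text),
  2. optionally loads MFA TextGrid or unit alignments,
  3. extracts log-Mel spectrogram features (plus pitch and energy when
     FastSpeech 2 targets are requested), packs them into ZIP files and
     computes global CMVN statistics,
  4. writes per-split TSV manifests, a vocabulary (token counts or a
     character SentencePiece model), a speaker list and a config YAML.
"""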
import argparse
import logging
from pathlib import Path
import shutil
from tempfile import NamedTemporaryFile
from collections import Counter, defaultdict

import pandas as pd
import torchaudio
from tqdm import tqdm

from fairseq.data.audio.audio_utils import convert_waveform
from examples.speech_to_text.data_utils import (
    create_zip,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_tsv_to_dicts,
    save_df_to_tsv
)
from examples.speech_synthesis.data_utils import (
    extract_logmel_spectrogram, extract_pitch, extract_energy, get_global_cmvn,
    ipa_phonemize, get_mfa_alignment, get_unit_alignment
)


log = logging.getLogger(__name__)


def process(args):
    assert "train" in args.splits
    out_root = Path(args.output_root).absolute()
    out_root.mkdir(exist_ok=True)

    print("Fetching data...")
    audio_manifest_root = Path(args.audio_manifest_root).absolute()
    samples = []
    for s in args.splits:
        for e in load_tsv_to_dicts(audio_manifest_root / f"{s}.audio.tsv"):
            e["split"] = s
            samples.append(e)
    sample_ids = [s["id"] for s in samples]

    # Get alignment info
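    # Alignments come either from MFA TextGrids (zipped) or from a
    # pre-computed unit-sequence TSV; they provide per-token frame durations
    # and, for MFA, utterance start/end times.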
    id_to_alignment = None
    if args.textgrid_zip is not None:
        assert args.id_to_units_tsv is None
        id_to_alignment = get_mfa_alignment(
            args.textgrid_zip, sample_ids, args.sample_rate, args.hop_length
        )
    elif args.id_to_units_tsv is not None:
        # assume identical hop length on the unit sequence
        id_to_alignment = get_unit_alignment(args.id_to_units_tsv, sample_ids)

    # Extract features and pack features into ZIP
    feature_name = "logmelspec80"
    zip_path = out_root / f"{feature_name}.zip"
    pitch_zip_path = out_root / "pitch.zip"
    energy_zip_path = out_root / "energy.zip"
    gcmvn_npz_path = out_root / "gcmvn_stats.npz"
    if zip_path.exists() and gcmvn_npz_path.exists():
        print(f"{zip_path} and {gcmvn_npz_path} exist.")
    else:
        feature_root = out_root / feature_name
        feature_root.mkdir(exist_ok=True)
        pitch_root = out_root / "pitch"
        energy_root = out_root / "energy"
        if args.add_fastspeech_targets:
            pitch_root.mkdir(exist_ok=True)
            energy_root.mkdir(exist_ok=True)
        print("Extracting Mel spectrogram features...")
        for sample in tqdm(samples):
            waveform, sample_rate = torchaudio.load(sample["audio"])
            waveform, sample_rate = convert_waveform(
                waveform, sample_rate, normalize_volume=args.normalize_volume,
                to_sample_rate=args.sample_rate
            )
            sample_id = sample["id"]
            target_length = None
            if id_to_alignment is not None:
                a = id_to_alignment[sample_id]
                target_length = sum(a.frame_durations)
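                # trim the waveform to the aligned segment when the alignment
                # provides utterance start/end times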
                if a.start_sec is not None and a.end_sec is not None:
                    start_frame = int(a.start_sec * sample_rate)
                    end_frame = int(a.end_sec * sample_rate)
                    waveform = waveform[:, start_frame: end_frame]
            extract_logmel_spectrogram(
                waveform, sample_rate, feature_root / f"{sample_id}.npy",
                win_length=args.win_length, hop_length=args.hop_length,
                n_fft=args.n_fft, n_mels=args.n_mels, f_min=args.f_min,
                f_max=args.f_max, target_length=target_length
            )
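            # FastSpeech 2 variance targets: pitch and energy tracks, aligned
            # to the per-phoneme frame durations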
            if args.add_fastspeech_targets:
                assert id_to_alignment is not None
                extract_pitch(
                    waveform, sample_rate, pitch_root / f"{sample_id}.npy",
                    hop_length=args.hop_length, log_scale=True,
                    phoneme_durations=id_to_alignment[sample_id].frame_durations
                )
                extract_energy(
                    waveform, energy_root / f"{sample_id}.npy",
                    hop_length=args.hop_length, n_fft=args.n_fft,
                    log_scale=True,
                    phoneme_durations=id_to_alignment[sample_id].frame_durations
                )
        print("ZIPing features...")
        create_zip(feature_root, zip_path)
        get_global_cmvn(feature_root, gcmvn_npz_path)
        shutil.rmtree(feature_root)
        if args.add_fastspeech_targets:
            create_zip(pitch_root, pitch_zip_path)
            shutil.rmtree(pitch_root)
            create_zip(energy_root, energy_zip_path)
            shutil.rmtree(energy_root)

    print("Fetching ZIP manifest...")
    audio_paths, audio_lengths = get_zip_manifest(zip_path)
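    # pitch/energy manifests are only available when FastSpeech 2 targets
    # were extracted above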
    pitch_paths, pitch_lengths, energy_paths, energy_lengths = [None] * 4
    if args.add_fastspeech_targets:
        pitch_paths, pitch_lengths = get_zip_manifest(pitch_zip_path)
        energy_paths, energy_lengths = get_zip_manifest(energy_zip_path)

    # Generate TSV manifest
    print("Generating manifest...")
    manifest_by_split = {split: defaultdict(list) for split in args.splits}
    for sample in tqdm(samples):
        sample_id, split = sample["id"], sample["split"]
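        # target text: alignment tokens if available, else IPA phonemization,
        # else the raw normalized text from the audio manifest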
        normalized_utt = sample["tgt_text"]
        if id_to_alignment is not None:
            normalized_utt = " ".join(id_to_alignment[sample_id].tokens)
        elif args.ipa_vocab:
            normalized_utt = ipa_phonemize(
                normalized_utt, lang=args.lang, use_g2p=args.use_g2p
            )
        manifest_by_split[split]["id"].append(sample_id)
        manifest_by_split[split]["audio"].append(audio_paths[sample_id])
        manifest_by_split[split]["n_frames"].append(audio_lengths[sample_id])
        manifest_by_split[split]["tgt_text"].append(normalized_utt)
        manifest_by_split[split]["speaker"].append(sample["speaker"])
        manifest_by_split[split]["src_text"].append(sample["src_text"])
        if args.add_fastspeech_targets:
            assert id_to_alignment is not None
            duration = " ".join(
                str(d) for d in id_to_alignment[sample_id].frame_durations
            )
            manifest_by_split[split]["duration"].append(duration)
            manifest_by_split[split]["pitch"].append(pitch_paths[sample_id])
            manifest_by_split[split]["energy"].append(energy_paths[sample_id])
    for split in args.splits:
        save_df_to_tsv(
            pd.DataFrame.from_dict(manifest_by_split[split]),
            out_root / f"{split}.tsv"
        )

    # Generate vocab
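    # token-level vocabulary with counts when alignments or IPA phonemes are
    # used; otherwise a character-level SentencePiece model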
    vocab_name, spm_filename = None, None
    if id_to_alignment is not None or args.ipa_vocab:
        vocab = Counter()
        for t in manifest_by_split["train"]["tgt_text"]:
            vocab.update(t.split(" "))
        vocab_name = "vocab.txt"
        with open(out_root / vocab_name, "w") as f:
            for s, c in vocab.most_common():
                f.write(f"{s} {c}\n")
    else:
        spm_filename_prefix = "spm_char"
        spm_filename = f"{spm_filename_prefix}.model"
        with NamedTemporaryFile(mode="w") as f:
            for t in manifest_by_split["train"]["tgt_text"]:
                f.write(t + "\n")
            f.flush()  # needed to ensure gen_vocab sees dumped text
            gen_vocab(Path(f.name), out_root / spm_filename_prefix, "char")

    # Generate speaker list
    speakers = sorted({sample["speaker"] for sample in samples})
    speakers_path = out_root / "speakers.txt"
    with open(speakers_path, "w") as f:
        for speaker in speakers:
            f.write(f"{speaker}\n")

    # Generate config YAML
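    # the data config records the full feature-extraction recipe (window/hop
    # in samples and seconds, Mel filterbank settings) and the path to the
    # global CMVN stats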
    win_len_t = args.win_length / args.sample_rate
    hop_len_t = args.hop_length / args.sample_rate
    extra = {
        "sample_rate": args.sample_rate,
        "features": {
            "type": "spectrogram+melscale+log",
            "eps": 1e-2, "n_mels": args.n_mels, "n_fft": args.n_fft,
            "window_fn": "hann", "win_length": args.win_length,
            "hop_length": args.hop_length, "sample_rate": args.sample_rate,
            "win_len_t": win_len_t, "hop_len_t": hop_len_t,
            "f_min": args.f_min, "f_max": args.f_max,
            "n_stft": args.n_fft // 2 + 1
        }
    }
    if len(speakers) > 1:
        extra["speaker_set_filename"] = "speakers.txt"
    gen_config_yaml(
        out_root, spm_filename=spm_filename, vocab_name=vocab_name,
        audio_root=out_root.as_posix(), input_channels=None,
        input_feat_per_channel=None, specaugment_policy=None,
        cmvn_type="global", gcmvn_path=gcmvn_npz_path, extra=extra
    )


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--audio-manifest-root", "-m", required=True, type=str)
    parser.add_argument("--output-root", "-o", required=True, type=str)
    parser.add_argument("--splits", "-s", type=str, nargs="+",
                        default=["train", "dev", "test"])
    parser.add_argument("--ipa-vocab", action="store_true")
    parser.add_argument("--use-g2p", action="store_true")
    parser.add_argument("--lang", type=str, default="en-us")
    parser.add_argument("--win-length", type=int, default=1024)
    parser.add_argument("--hop-length", type=int, default=256)
    parser.add_argument("--n-fft", type=int, default=1024)
    parser.add_argument("--n-mels", type=int, default=80)
    parser.add_argument("--f-min", type=int, default=20)
    parser.add_argument("--f-max", type=int, default=8000)
    parser.add_argument("--sample-rate", type=int, default=22050)
    parser.add_argument("--normalize-volume", "-n", action="store_true")
    parser.add_argument("--textgrid-zip", type=str, default=None)
    parser.add_argument("--id-to-units-tsv", type=str, default=None)
    parser.add_argument("--add-fastspeech-targets", action="store_true")
    args = parser.parse_args()
    process(args)


if __name__ == "__main__":
    main()