Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse | |
from collections import defaultdict | |
from itertools import chain | |
from pathlib import Path | |
import numpy as np | |
import torchaudio | |
import torchaudio.sox_effects as ta_sox | |
import yaml | |
from tqdm import tqdm | |
from examples.speech_to_text.data_utils import load_tsv_to_dicts | |
from examples.speech_synthesis.preprocessing.speaker_embedder import SpkrEmbedder | |
def extract_embedding(audio_path, embedder):
    """Compute a speaker embedding for a single audio file.

    Args:
        audio_path: Location of the audio file; passed to ``torchaudio.load``.
        embedder: Callable speaker embedder exposing a ``RATE`` attribute.
            The waveform is resampled to ``RATE`` when the file's sample
            rate differs from it.

    Returns:
        The embedder output converted to a NumPy array, or ``None`` when the
        embedder raised ``RuntimeError`` for this file.
    """
    wav, sr = torchaudio.load(audio_path)  # 2D
    if sr != embedder.RATE:
        # Resample so the waveform matches the rate declared by the embedder.
        wav, sr = ta_sox.apply_effects_tensor(
            wav, sr, [["rate", str(embedder.RATE)]]
        )
    try:
        # Only the first row of the loaded tensor is embedded; the input is
        # moved to the GPU, so a CUDA device is required.
        emb = embedder([wav[0].cuda().float()]).cpu().numpy()
    except RuntimeError as err:
        # Best effort: the caller skips samples that yield ``None``. Report
        # the failure instead of hiding it, so that a systematic problem can
        # be told apart from a single bad file.
        print(f"WARNING: failed to extract embedding for {audio_path}: {err}")
        emb = None
    return emb
def process(args):
    """Average speaker embeddings per speaker and save them as a matrix.

    Loads ``<raw_manifest_root>/<split>.tsv`` for every split in
    ``args.splits``, extracts an embedding from each sample's ``audio`` entry
    and averages the embeddings per ``speaker``. The matrix is saved as
    ``speaker_emb.npy`` under the config's ``audio_root``; its row index for a
    speaker is that speaker's line number in the speaker-set file named by
    the config. Speakers without any successful embedding keep an all-zero
    row. The config, extended with ``speaker_emb_filename``, is written to
    ``args.new_config``.

    Args:
        args: Parsed command-line namespace providing ``raw_manifest_root``,
            ``splits``, ``config``, ``new_config`` and ``ckpt``.

    Raises:
        RuntimeError: If no embedding could be extracted from any sample.
        KeyError: If a sample's speaker is absent from the speaker-set file.
    """
    print("Fetching data...")
    raw_manifest_root = Path(args.raw_manifest_root).absolute()
    samples = list(
        chain.from_iterable(
            load_tsv_to_dicts(raw_manifest_root / (s + ".tsv"))
            for s in args.splits
        )
    )

    with open(args.config, "r") as f:
        # NOTE(review): yaml.load should only be given trusted config files;
        # consider yaml.safe_load if the config holds plain data only.
        config = yaml.load(f, Loader=yaml.FullLoader)
    with open(f"{config['audio_root']}/{config['speaker_set_filename']}") as f:
        # One speaker name per line; the line number becomes the row index.
        speaker_to_id = {r.strip(): i for i, r in enumerate(f)}

    embedder = SpkrEmbedder(args.ckpt).cuda()
    speaker_to_cnt = defaultdict(float)
    # Starts at 0.0 per speaker; adding the first array turns it into an array.
    speaker_to_emb = defaultdict(float)
    for sample in tqdm(samples, desc="extract emb"):
        emb = extract_embedding(sample["audio"], embedder)
        if emb is not None:
            speaker_to_cnt[sample["speaker"]] += 1
            speaker_to_emb[sample["speaker"]] += emb

    if len(speaker_to_emb) != len(speaker_to_id):
        missed = set(speaker_to_id) - set(speaker_to_emb.keys())
        print(
            f"WARNING: missing embeddings for {len(missed)} speaker:\n{missed}"
        )
    if not speaker_to_emb:
        raise RuntimeError(
            "no speaker embedding could be extracted from any sample; "
            "cannot build the speaker embedding matrix"
        )

    # Take the embedding size from an accumulated embedding rather than from
    # the loop variable, which is None when the last sample failed and is
    # undefined when there were no samples at all.
    emb_dim = len(next(iter(speaker_to_emb.values())))
    speaker_emb_mat = np.zeros((len(speaker_to_id), emb_dim), float)
    for speaker, emb_sum in speaker_to_emb.items():
        idx = speaker_to_id[speaker]
        speaker_emb_mat[idx, :] = emb_sum / speaker_to_cnt[speaker]

    speaker_emb_name = "speaker_emb.npy"
    speaker_emb_path = f"{config['audio_root']}/{speaker_emb_name}"
    np.save(speaker_emb_path, speaker_emb_mat)

    # Record the new file in the config and write it to the requested path.
    config["speaker_emb_filename"] = speaker_emb_name
    with open(args.new_config, "w") as f:
        yaml.dump(config, f)
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--raw-manifest-root", "-m", required=True, type=str) | |
parser.add_argument("--splits", "-s", type=str, nargs="+", | |
default=["train"]) | |
parser.add_argument("--config", "-c", required=True, type=str) | |
parser.add_argument("--new-config", "-n", required=True, type=str) | |
parser.add_argument("--ckpt", required=True, type=str, | |
help="speaker embedder checkpoint") | |
args = parser.parse_args() | |
process(args) | |
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()