Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import tempfile
import zipfile
from dataclasses import dataclass
from itertools import groupby
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

from examples.speech_to_text.data_utils import load_tsv_to_dicts
from fairseq.data.audio.audio_utils import TTSMelScale, TTSSpectrogram
def trim_or_pad_to_target_length(
    data_1d_or_2d: np.ndarray, target_length: int
) -> np.ndarray:
    """Force the first axis of an array to have exactly ``target_length`` rows.

    Longer inputs are cut at the end; shorter inputs get zeros appended at
    the end (for 2-D inputs the zero block keeps the second dimension).

    Args:
        data_1d_or_2d: 1-D or 2-D array; the first axis is the one resized.
        target_length: Desired size of the first axis.

    Returns:
        The trimmed slice of the input, or a new zero-padded array.
    """
    n_dims = len(data_1d_or_2d.shape)
    assert n_dims in {1, 2}

    n_missing = target_length - data_1d_or_2d.shape[0]
    if n_missing <= 0:
        # Already long enough: keep only the leading entries.
        return data_1d_or_2d[:target_length]

    if n_dims == 1:
        pad_shape = (n_missing,)
    else:
        pad_shape = (n_missing, data_1d_or_2d.shape[1])
    return np.concatenate([data_1d_or_2d, np.zeros(pad_shape)], axis=0)
def extract_logmel_spectrogram(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, win_length: int = 1024,
    hop_length: int = 256, n_fft: int = 1024,
    win_fn: callable = torch.hann_window, n_mels: int = 80,
    f_min: float = 0., f_max: float = 8000, eps: float = 1e-5,
    overwrite: bool = False, target_length: Optional[int] = None
):
    """Compute a log-mel spectrogram and save it to disk or return it.

    The waveform is passed through ``TTSSpectrogram`` and ``TTSMelScale``,
    clamped from below at ``eps`` and log-compressed. The result is laid out
    as (frames, mel bins).

    Args:
        waveform: Input audio; the transforms must yield a tensor whose
            leading dimension is 1 (asserted below).
        sample_rate: Sampling rate forwarded to the mel filterbank.
        output_path: If given, the features are written there with
            ``np.save`` and nothing is returned.
        win_length, hop_length, n_fft, win_fn: Spectrogram parameters.
        n_mels, f_min, f_max: Mel filterbank parameters.
        eps: Lower clamp applied before taking the logarithm.
        overwrite: Recompute even when ``output_path`` already exists.
        target_length: If given, the frame axis is trimmed or zero-padded
            to exactly this many frames.

    Returns:
        ``None`` when ``output_path`` is given (including when the file
        already exists and ``overwrite`` is false); otherwise the T x D
        features. Zero-padding goes through ``np.concatenate``, so a padded
        result is a NumPy array rather than a tensor.
    """
    if output_path is not None and output_path.is_file() and not overwrite:
        return

    spectrogram_transform = TTSSpectrogram(
        n_fft=n_fft, win_length=win_length, hop_length=hop_length,
        window_fn=win_fn
    )
    mel_scale_transform = TTSMelScale(
        n_mels=n_mels, sample_rate=sample_rate, f_min=f_min, f_max=f_max,
        n_stft=n_fft // 2 + 1
    )
    spectrogram = spectrogram_transform(waveform)
    mel_spec = mel_scale_transform(spectrogram)
    logmel_spec = torch.clamp(mel_spec, min=eps).log()
    assert len(logmel_spec.shape) == 3 and logmel_spec.shape[0] == 1
    # Drop only the leading singleton axis: a bare squeeze() would also
    # remove the mel or time axis whenever its size happens to be 1.
    logmel_spec = logmel_spec.squeeze(0).t()  # D x T -> T x D
    if target_length is not None:
        # The helper returns a new object; its result must be kept.
        logmel_spec = trim_or_pad_to_target_length(logmel_spec, target_length)

    if output_path is not None:
        np.save(output_path.as_posix(), logmel_spec)
    else:
        return logmel_spec
def extract_pitch(
    waveform: torch.Tensor, sample_rate: int,
    output_path: Optional[Path] = None, hop_length: int = 256,
    log_scale: bool = True, phoneme_durations: Optional[List[int]] = None
):
    """Estimate a pitch contour with PyWORLD and save it or return it.

    The contour is computed with ``pyworld.dio`` and refined with
    ``pyworld.stonemask``. When ``phoneme_durations`` is given, the contour
    is resized to ``sum(phoneme_durations)`` frames, zero-valued frames are
    filled by linear interpolation between non-zero frames (edges take the
    nearest non-zero value), and the result is averaged per phoneme.

    Args:
        waveform: Audio tensor; ``squeeze(0)`` is applied before analysis.
        sample_rate: Sampling rate in Hz.
        output_path: If given, the result is written there with ``np.save``
            and nothing is returned. An existing file short-circuits the call.
        hop_length: Hop size in samples; converted to a frame period in ms.
        log_scale: If true, return ``log(pitch + 1)``.
        phoneme_durations: Per-phoneme frame counts used for averaging.

    Returns:
        ``None`` when ``output_path`` is given, otherwise a NumPy array.

    Raises:
        ImportError: If PyWORLD is missing, or SciPy is missing while
            ``phoneme_durations`` is given.
    """
    if output_path is not None and output_path.is_file():
        return

    try:
        import pyworld
    except ImportError as err:
        raise ImportError(
            "Please install PyWORLD: pip install pyworld"
        ) from err

    _waveform = waveform.squeeze(0).double().numpy()
    pitch, t = pyworld.dio(
        _waveform, sample_rate, frame_period=hop_length / sample_rate * 1000
    )
    pitch = pyworld.stonemask(_waveform, pitch, t, sample_rate)

    if phoneme_durations is not None:
        pitch = trim_or_pad_to_target_length(pitch, sum(phoneme_durations))
        try:
            from scipy.interpolate import interp1d
        except ImportError as err:
            raise ImportError(
                "Please install SciPy: pip install scipy"
            ) from err
        nonzero_ids = np.where(pitch != 0)[0]
        if len(nonzero_ids) >= 2:
            interp_fn = interp1d(
                nonzero_ids,
                pitch[nonzero_ids],
                fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
                bounds_error=False,
            )
            pitch = interp_fn(np.arange(0, len(pitch)))
        elif len(nonzero_ids) == 1:
            # interp1d needs at least two points. With a single non-zero
            # frame both edge fill values equal it, so the contour is flat.
            pitch = np.full_like(pitch, pitch[nonzero_ids[0]])
        # With no non-zero frame there is nothing to interpolate from; the
        # all-zero contour is kept instead of failing on nonzero_ids[0].
        d_cumsum = np.cumsum(np.concatenate([np.array([0]), phoneme_durations]))
        pitch = np.array(
            [
                np.mean(pitch[d_cumsum[i - 1]: d_cumsum[i]])
                for i in range(1, len(d_cumsum))
            ]
        )
        assert len(pitch) == len(phoneme_durations)

    if log_scale:
        pitch = np.log(pitch + 1)

    if output_path is not None:
        np.save(output_path.as_posix(), pitch)
    else:
        return pitch
def extract_energy(
    waveform: torch.Tensor, output_path: Optional[Path] = None,
    hop_length: int = 256, n_fft: int = 1024, log_scale: bool = True,
    phoneme_durations: Optional[List[int]] = None
):
    """Compute frame-level energy as the L2 norm of the STFT magnitude.

    The STFT is computed by convolving the reflect-padded waveform with the
    real and imaginary rows of a DFT matrix. When ``phoneme_durations`` is
    given, the energy is resized to ``sum(phoneme_durations)`` frames and
    averaged per phoneme.

    Args:
        waveform: Tensor of shape (1, n_samples) (asserted).
        output_path: If given, the result is written there with ``np.save``
            and nothing is returned. An existing file short-circuits the call.
        hop_length: Stride, in samples, between analysis frames.
        n_fft: Analysis window / DFT size.
        log_scale: If true, return ``log(energy + 1)``.
        phoneme_durations: Per-phoneme frame counts used for averaging.

    Returns:
        ``None`` when ``output_path`` is given, otherwise a NumPy array.
    """
    if output_path is not None and output_path.is_file():
        return

    assert len(waveform.shape) == 2 and waveform.shape[0] == 1
    half_window = n_fft // 2
    # Reflect padding is applied on a 4-D view, then the extra axis is dropped.
    signal = waveform.view(1, 1, waveform.shape[1]).unsqueeze(1)
    signal = F.pad(signal, [half_window, half_window, 0, 0], mode="reflect")
    signal = signal.squeeze(1)

    # Keep only the non-redundant DFT rows; stack real over imaginary parts
    # so a single conv1d yields both components.
    n_bins = int((n_fft / 2 + 1))
    dft_rows = np.fft.fft(np.eye(n_fft))[:n_bins, :]
    stacked = np.vstack([np.real(dft_rows), np.imag(dft_rows)])
    kernels = torch.FloatTensor(stacked[:, None, :])

    projected = F.conv1d(signal, kernels, stride=hop_length, padding=0)
    real_sq = projected[:, :n_bins, :] ** 2
    imag_sq = projected[:, n_bins:, :] ** 2
    magnitude = torch.sqrt(real_sq + imag_sq)
    energy = torch.norm(magnitude, dim=1).squeeze(0).numpy()

    if phoneme_durations is not None:
        energy = trim_or_pad_to_target_length(energy, sum(phoneme_durations))
        boundaries = np.cumsum(
            np.concatenate([np.array([0]), phoneme_durations])
        )
        energy = np.array(
            [
                np.mean(energy[lo:hi])
                for lo, hi in zip(boundaries[:-1], boundaries[1:])
            ]
        )
        assert len(energy) == len(phoneme_durations)

    if log_scale:
        energy = np.log(energy + 1)

    if output_path is None:
        return energy
    np.save(output_path.as_posix(), energy)
def get_global_cmvn(feature_root: Path, output_path: Optional[Path] = None):
    """Compute mean and standard deviation over all ``*.npy`` features.

    Every ``*.npy`` file directly under ``feature_root`` is loaded and
    squeezed, and its first axis is treated as the frame axis. Sums of values
    and of squared values are accumulated to get the mean and variance; the
    variance is floored at 1e-10 before the square root.

    Args:
        feature_root: Directory containing the ``.npy`` feature files.
        output_path: If given, ``mean`` and ``std`` are written there with
            ``np.savez`` and nothing is returned.

    Returns:
        ``None`` when ``output_path`` is given, otherwise a dict with keys
        ``"mean"`` and ``"std"``.

    Raises:
        ValueError: If ``feature_root`` contains no ``.npy`` file.
    """
    sum_x, sum_x2, n_frames = None, None, 0
    for p in tqdm(feature_root.glob("*.npy")):
        with open(p, 'rb') as f:
            frames = np.load(f).squeeze()

        n_frames += frames.shape[0]

        cur_sum_x = frames.sum(axis=0)
        cur_sum_x2 = (frames ** 2).sum(axis=0)
        if sum_x is None:
            sum_x, sum_x2 = cur_sum_x, cur_sum_x2
        else:
            sum_x += cur_sum_x
            sum_x2 += cur_sum_x2

    if sum_x is None:
        # Without this guard the division below fails with an opaque
        # TypeError on None.
        raise ValueError(f"No .npy feature files found in {feature_root}")

    # Out-of-place division so integer-typed accumulators are promoted to
    # float instead of failing on an in-place true divide.
    mean_x = sum_x / n_frames
    mean_x2 = sum_x2 / n_frames
    var_x = mean_x2 - mean_x ** 2
    std_x = np.sqrt(np.maximum(var_x, 1e-10))

    if output_path is not None:
        with open(output_path, 'wb') as f:
            np.savez(f, mean=mean_x, std=std_x)
    else:
        return {"mean": mean_x, "std": std_x}
def ipa_phonemize(text, lang="en-us", use_g2p=False):
    """Convert text to a space-separated phoneme string.

    Args:
        text: Input text to phonemize.
        lang: Language code; must be ``"en-us"`` when ``use_g2p`` is true.
        use_g2p: If true, use ``g2p_en`` and render word boundaries (spaces
            in its output) as ``"|"``. Otherwise use ``phonemizer`` with the
            espeak backend, ``"| "`` between words and ``" "`` between phones.

    Returns:
        The phonemized text as produced by the selected backend.

    Raises:
        ImportError: If the selected backend package is not installed.
    """
    if use_g2p:
        assert lang == "en-us", "g2pE phonemizer only works for en-us"
        # Only the import is guarded, so an ImportError raised while the
        # phonemizer runs is not misreported as "package not installed".
        try:
            from g2p_en import G2p
        except ImportError as err:
            raise ImportError(
                "Please install g2p_en: pip install g2p_en"
            ) from err
        g2p = G2p()
        return " ".join("|" if p == " " else p for p in g2p(text))

    try:
        from phonemizer import phonemize
        from phonemizer.separator import Separator
    except ImportError as err:
        raise ImportError(
            "Please install phonemizer: pip install phonemizer"
        ) from err
    return phonemize(
        text, backend='espeak', language=lang,
        separator=Separator(word="| ", phone=" ")
    )
# The decorator generates the keyword-accepting __init__ that this file's
# constructors rely on; a plain annotated class would reject those arguments.
@dataclass
class ForceAlignmentInfo(object):
    """Alignment of one utterance: tokens with their per-token frame counts."""

    # Token sequence (phones or collapsed units in this file's constructors).
    tokens: List[str]
    # Number of frames covered by each token, parallel to ``tokens``.
    frame_durations: List[int]
    # Start/end time in seconds of the aligned span; None when not available.
    start_sec: Optional[float]
    end_sec: Optional[float]
def get_mfa_alignment_by_sample_id(
    textgrid_zip_path: str, sample_id: str, sample_rate: int,
    hop_length: int, silence_phones: List[str] = ("sil", "sp", "spn")
) -> ForceAlignmentInfo:
    """Read one sample's TextGrid from a zip and convert it to frame durations.

    The ``"phones"`` tier of ``<sample_id>.TextGrid`` is traversed in order.
    Leading and trailing intervals whose text is in ``silence_phones`` are
    dropped; silences between other phones are kept. Each interval duration
    is expressed in frames of ``hop_length`` samples.

    Args:
        textgrid_zip_path: Zip archive containing ``<sample_id>.TextGrid``.
        sample_id: Identifier used to build the archive member name.
        sample_rate: Audio sampling rate in Hz.
        hop_length: Hop size in samples of the target frame grid.
        silence_phones: Labels treated as silence when trimming the edges.

    Returns:
        A ``ForceAlignmentInfo`` with the kept phones, their frame counts,
        the start time of the first kept phone and the end time of the last
        non-silence phone, both in seconds.

    Raises:
        ImportError: If TextGridTools (``tgt``) is not installed.
    """
    try:
        import tgt
    except ImportError as err:
        raise ImportError(
            "Please install TextGridTools: pip install tgt"
        ) from err

    filename = f"{sample_id}.TextGrid"
    # A private temporary directory avoids clashes between concurrent runs
    # and is removed even when extraction or parsing raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with zipfile.ZipFile(textgrid_zip_path) as f_zip:
            tgt_path = f_zip.extract(filename, path=tmp_dir)
        textgrid = tgt.io.read_textgrid(tgt_path)

    frames_per_sec = sample_rate / hop_length
    phones, frame_durations = [], []
    start_sec, end_sec, end_idx = 0, 0, 0
    for t in textgrid.get_tier_by_name("phones")._objects:
        s, e, p = t.start_time, t.end_time, t.text
        # Trim leading silences
        if len(phones) == 0:
            if p in silence_phones:
                continue
            start_sec = s
        phones.append(p)
        if p not in silence_phones:
            # Remember the last non-silence phone so trailing silences can
            # be cut off after the loop.
            end_sec = e
            end_idx = len(phones)
        # Rounding both boundaries separately keeps consecutive intervals
        # contiguous on the frame grid.
        frame_durations.append(
            int(np.round(e * frames_per_sec) - np.round(s * frames_per_sec))
        )
    # Trim trailing silences
    phones = phones[:end_idx]
    frame_durations = frame_durations[:end_idx]

    return ForceAlignmentInfo(
        tokens=phones, frame_durations=frame_durations, start_sec=start_sec,
        end_sec=end_sec
    )
def get_mfa_alignment(
    textgrid_zip_path: str, sample_ids: List[str], sample_rate: int,
    hop_length: int
) -> Dict[str, ForceAlignmentInfo]:
    """Collect the MFA alignment of every requested sample.

    Args:
        textgrid_zip_path: Zip archive holding one TextGrid per sample.
        sample_ids: Samples to look up; progress is reported with tqdm.
        sample_rate: Audio sampling rate in Hz.
        hop_length: Hop size in samples of the target frame grid.

    Returns:
        A mapping from sample id to its ``ForceAlignmentInfo``.
    """
    alignments = {}
    for sample_id in tqdm(sample_ids):
        alignments[sample_id] = get_mfa_alignment_by_sample_id(
            textgrid_zip_path, sample_id, sample_rate, hop_length
        )
    return alignments
def get_unit_alignment(
    id_to_unit_tsv_path: str, sample_ids: List[str]
) -> Dict[str, ForceAlignmentInfo]:
    """Build alignments from whitespace-separated unit sequences in a TSV.

    The TSV rows are read through their ``"id"`` and ``"units"`` columns.
    For each requested sample, runs of identical consecutive units are
    collapsed into one token whose duration is the run length.

    Args:
        id_to_unit_tsv_path: TSV file with ``id`` and ``units`` columns.
        sample_ids: Samples to include; each must be present in the TSV.

    Returns:
        A mapping from sample id to a ``ForceAlignmentInfo`` with no
        start/end times.
    """
    units_by_id = {
        row["id"]: row["units"]
        for row in load_tsv_to_dicts(id_to_unit_tsv_path)
    }

    alignments = {}
    for sample_id in sample_ids:
        # One pass over the runs yields both the collapsed unit and its length.
        runs = [
            (unit, sum(1 for _ in run))
            for unit, run in groupby(units_by_id[sample_id].split())
        ]
        alignments[sample_id] = ForceAlignmentInfo(
            tokens=[unit for unit, _ in runs],
            frame_durations=[length for _, length in runs],
            start_sec=None, end_sec=None
        )
    return alignments