|
|
|
|
|
|
|
|
|
|
|
import random |
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
from utils.data_utils import * |
|
from tqdm import tqdm |
|
from g2p_en import G2p |
|
import librosa |
|
from torch.utils.data import Dataset |
|
import pandas as pd |
|
import time |
|
import io |
|
|
|
# Target sampling rate (Hz) passed to librosa.load when reading audio.
SAMPLE_RATE = 16000


from .g2p_processor import G2pProcessor

# Module-level grapheme-to-phoneme converter, constructed once and shared;
# used in VALLEDataset.__getitem__ to phonemize normalized transcripts.
phonemizer_g2p = G2pProcessor()
|
|
|
|
|
class VALLEDataset(Dataset):
    """LibriTTS-style dataset for VALL-E TTS training.

    Scans the dataset splits listed in ``args.dataset_list`` under
    ``args.data_dir``, loads the ``*.book.tsv`` alignment metadata and
    ``*.trans.tsv`` transcript files into pandas caches indexed by utterance
    ID, derives per-utterance durations from the aligned start/end times, and
    keeps only utterances between 3.0 and 15.0 seconds long.

    Each item is a dict with:
        "speech": 1-D waveform loaded at ``SAMPLE_RATE`` Hz
        "phone":  output of ``phonemizer_g2p(text, "en")[1]`` (presumably the
                  phoneme token sequence — verify against G2pProcessor)
    """

    def __init__(self, args):
        """Build the metadata/transcript caches for the requested splits.

        Args:
            args: config object providing ``dataset_list`` (split names) and
                ``data_dir`` (corpus root directory).
        """
        print("Initializing VALLEDataset")
        self.dataset_list = args.dataset_list

        print(f"using sampling rate {SAMPLE_RATE}")

        # Column layout of the *.book.tsv alignment files.
        book_col_name = [
            "ID",
            "Original_text",
            "Normalized_text",
            "Aligned_or_not",
            "Start_time",
            "End_time",
            "Signal_to_noise_ratio",
        ]
        # Column layout of the *.trans.tsv transcript files.
        trans_col_name = [
            "ID",
            "Original_text",
            "Normalized_text",
            "Dir_path",
            "Duration",
        ]
        self.metadata_cache = pd.DataFrame(columns=book_col_name)
        self.trans_cache = pd.DataFrame(columns=trans_col_name)

        # Split name -> directory (LibriTTS layout).
        self.dataset2dir = {
            "dev-clean": f"{args.data_dir}/dev-clean",
            "dev-other": f"{args.data_dir}/dev-other",
            "test-clean": f"{args.data_dir}/test-clean",
            "test-other": f"{args.data_dir}/test-other",
            "train-clean-100": f"{args.data_dir}/train-clean-100",
            "train-clean-360": f"{args.data_dir}/train-clean-360",
            "train-other-500": f"{args.data_dir}/train-other-500",
        }

        # Accumulate TSV paths across *all* requested splits. (A plain
        # assignment per iteration would keep only the last split's files.)
        self.book_files_list = []
        self.trans_files_list = []
        for dataset_name in self.dataset_list:
            print("Initializing dataset: ", dataset_name)
            dataset_dir = self.dataset2dir[dataset_name]
            self.book_files_list.extend(self.get_metadata_files(dataset_dir))
            self.trans_files_list.extend(self.get_trans_files(dataset_dir))

        # Merge all alignment metadata, then index by utterance ID exactly
        # once at the end — calling set_index inside the loop would remove
        # "ID" as a column and corrupt subsequent concats.
        print("reading paths for dataset...")
        for book_path in tqdm(self.book_files_list):
            tmp_cache = pd.read_csv(
                book_path, sep="\t", names=book_col_name, quoting=3
            )
            self.metadata_cache = pd.concat(
                [self.metadata_cache, tmp_cache], ignore_index=True
            )
        self.metadata_cache.set_index("ID", inplace=True)

        # Same for the transcripts; remember each transcript file's directory
        # so __getitem__ can reconstruct the audio path next to it.
        print("creating transcripts for dataset...")
        for trans_path in tqdm(self.trans_files_list):
            tmp_cache = pd.read_csv(
                trans_path, sep="\t", names=trans_col_name, quoting=3
            )
            tmp_cache["Dir_path"] = os.path.dirname(trans_path)
            self.trans_cache = pd.concat(
                [self.trans_cache, tmp_cache], ignore_index=True
            )
        self.trans_cache.set_index("ID", inplace=True)

        # Duration is derived from the aligned start/end times in the book
        # metadata, joined on utterance ID.
        self.trans_cache["Duration"] = (
            self.metadata_cache.End_time[self.trans_cache.index]
            - self.metadata_cache.Start_time[self.trans_cache.index]
        )

        # Keep only utterances with a trainable length.
        print("Filtering files with duration between 3.0 and 15.0 seconds")
        print(f"Before filtering: {len(self.trans_cache)}")
        self.trans_cache = self.trans_cache[
            (self.trans_cache["Duration"] >= 3.0)
            & (self.trans_cache["Duration"] <= 15.0)
        ]
        print(f"After filtering: {len(self.trans_cache)}")

    def _collect_files(self, directory, suffix):
        """Walk ``directory`` and return full paths of non-hidden files whose
        names end in ``suffix``."""
        matches = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith(suffix) and not file.startswith("."):
                    matches.append(os.path.join(root, file))
        return matches

    def get_metadata_files(self, directory):
        """Return paths of all ``*.book.tsv`` alignment files under ``directory``."""
        return self._collect_files(directory, ".book.tsv")

    def get_trans_files(self, directory):
        """Return paths of all ``*.trans.tsv`` transcript files under ``directory``."""
        return self._collect_files(directory, ".trans.tsv")

    def get_audio_files(self, directory):
        """Return audio file paths under ``directory``.

        NOTE(review): unlike the TSV helpers, these paths are *relative* to
        ``directory`` (via os.path.relpath), and this method is not called
        anywhere else in this file.
        """
        audio_files = []
        for root, _, files in os.walk(directory):
            for file in files:
                if file.endswith((".flac", ".wav", ".opus")):
                    rel_path = os.path.relpath(os.path.join(root, file), directory)
                    audio_files.append(rel_path)
        return audio_files

    def get_num_frames(self, index):
        """Return the approximate number of frames for utterance ``index``.

        Reads the ``Duration`` column, which is populated on ``trans_cache``
        during ``__init__`` (the original referenced a nonexistent
        ``self.meta_data_cache`` attribute, which always raised
        AttributeError). Assumes 75 frames per second of audio — TODO confirm
        against the codec's actual frame rate.
        """
        duration = self.trans_cache["Duration"][index]
        return int(duration * 75)

    def __len__(self):
        """Number of utterances that survived duration filtering."""
        return len(self.trans_cache)

    def __getitem__(self, idx):
        """Load the idx-th utterance: waveform plus phoneme sequence.

        Returns:
            dict with keys "speech" (waveform at SAMPLE_RATE) and "phone"
            (phonemized normalized transcript).
        """
        file_dir_path = self.trans_cache["Dir_path"].iloc[idx]
        uid = self.trans_cache.index[idx]
        # Audio is expected next to its transcript file.
        # NOTE(review): assumes a .wav extension — confirm for corpora that
        # ship .flac/.opus (cf. get_audio_files).
        file_name = uid + ".wav"
        full_file_path = os.path.join(file_dir_path, file_name)

        # Phonemize the normalized transcript; index [1] presumably selects
        # the phoneme token sequence from the processor output — verify
        # against G2pProcessor.
        phone = self.trans_cache["Normalized_text"][uid]
        phone = phonemizer_g2p(phone, "en")[1]

        # Load (and resample if needed) the waveform to the target rate.
        speech, _ = librosa.load(full_file_path, sr=SAMPLE_RATE)

        inputs = {}
        inputs["speech"] = speech
        inputs["phone"] = phone
        return inputs
|
|
|
|
|
def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): |
|
if len(batch) == 0: |
|
return 0 |
|
if len(batch) == max_sentences: |
|
return 1 |
|
if num_tokens > max_tokens: |
|
return 1 |
|
return 0 |
|
|
|
|
|
def batch_by_size(
    indices,
    num_tokens_fn,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).

    Returns:
        List[List[int]]: the batched indices.
    """
    bsz_mult = required_batch_size_multiple

    # The docstring allows None for "no limit", but comparing an int against
    # None raises TypeError in Python 3 — treat None as infinity instead.
    if max_tokens is None:
        max_tokens = float("inf")
    if max_sentences is None:
        max_sentences = float("inf")

    sample_len = 0      # max token count seen in the current batch
    sample_lens = []    # token counts of samples in the current batch
    batch = []
    batches = []
    for idx in indices:
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        # A single over-long sample can never fit in any batch.
        assert (
            sample_len <= max_tokens
        ), "sentence at index {} of size {} exceeds max_tokens " "limit of {}!".format(
            idx, sample_len, max_tokens
        )
        # Projected cost if this sample joins the batch: batches are padded
        # to the longest member, so cost is (count + 1) * longest.
        num_tokens = (len(batch) + 1) * sample_len

        # Close the current batch when adding this sample would overflow
        # either limit (equivalent to the module-level _is_batch_full check,
        # inlined here).
        is_full = len(batch) > 0 and (
            len(batch) == max_sentences or num_tokens > max_tokens
        )
        if is_full:
            # Emit a prefix whose length is a multiple of bsz_mult when
            # possible; carry the remainder over into the next batch.
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
|
|
|
|
|
def test():
    """Manual smoke test: build the dataset from the example AR config and
    inspect the resulting caches interactively."""
    from utils.util import load_config

    cfg = load_config("./egs/tts/VALLE_V2/exp_ar_libritts.json")
    ds = VALLEDataset(cfg.dataset)
    meta, trans = ds.metadata_cache, ds.trans_cache
    print(trans.head(10))

    # Drop into the debugger so the caches can be explored by hand.
    breakpoint()
|
|
|
|
|
# Allow running this module directly as a manual smoke test.
if __name__ == "__main__":
    test()
|
|