|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torchaudio |
|
import json |
|
import os |
|
import numpy as np |
|
import librosa |
|
from torch.nn.utils.rnn import pad_sequence |
|
from modules import whisper_extractor as whisper |
|
|
|
|
|
class TorchaudioDataset(torch.utils.data.Dataset):
    """Dataset that decodes each utterance's audio with ``torchaudio``.

    Every item is a tuple ``(utt_info, wav, length)``:

    - ``utt_info`` is the metadata entry at that index;
    - ``wav`` is a 1-D mono tensor at ``self.sr``;
    - ``length`` is ``wav.shape[0]``, the number of samples.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        """
        Args:
            cfg: config. Only read when ``metadata`` is None. In that case
                ``cfg.preprocess.processed_dir``, ``cfg.preprocess.train_file``
                and ``cfg.preprocess.valid_file`` locate the metadata JSON files.
            dataset: dataset name, used as a subdirectory of ``processed_dir``.
            sr: target sample rate. Audio stored at another rate is resampled.
            accelerator: optional object whose ``.device`` becomes
                ``self.device``. Without it, CUDA is used when available,
                otherwise CPU.
            metadata: optional pre-loaded metadata. When given, no metadata
                files are read. Each entry must have a ``"Path"`` key (read by
                ``__getitem__``); presumably a list of dicts.
        """
        assert isinstance(dataset, str)

        self.sr = sr
        self.cfg = cfg

        if metadata is None:
            self.train_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
            )
            self.valid_metadata_path = os.path.join(
                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
            )
            self.metadata = self.get_metadata()
        else:
            self.metadata = metadata

        if accelerator is not None:
            self.device = accelerator.device
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

    def get_metadata(self):
        """Load and concatenate the train and valid metadata JSON files.

        Returns:
            list: the train entries followed by the valid entries.
        """
        metadata = []
        with open(self.train_metadata_path, "r", encoding="utf-8") as t:
            metadata.extend(json.load(t))
        with open(self.valid_metadata_path, "r", encoding="utf-8") as v:
            metadata.extend(json.load(v))
        return metadata

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index):
        """Load the utterance at ``index`` as mono audio at ``self.sr``.

        Returns:
            tuple: ``(utt_info, wav, length)``, where ``wav`` is a 1-D tensor
            and ``length == wav.shape[0]``.
        """
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # torchaudio.load returns a (channels, frames) tensor plus the file's
        # native sample rate.
        wav, sr = torchaudio.load(wav_path)

        # Downmix to mono *before* resampling. Averaging and resampling are
        # both linear, so the result matches (up to float rounding), but only
        # a single channel has to be resampled.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        assert wav.shape[0] == 1

        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)

        wav = wav.squeeze(0)
        length = wav.shape[0]
        return utt_info, wav, length
|
|
|
|
|
class LibrosaDataset(TorchaudioDataset):
    """``TorchaudioDataset`` variant that decodes audio with ``librosa``.

    ``librosa.load`` is asked for ``self.sr`` directly, so any resampling
    happens inside librosa. With librosa's default ``mono=True`` the decoded
    signal should already be 1-D.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        # Pure delegation to the base constructor.
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        """Return ``(utt_info, wav, length)`` for the utterance at ``index``."""
        entry = self.metadata[index]
        signal, _unused_sr = librosa.load(entry["Path"], sr=self.sr)
        waveform = torch.from_numpy(signal)
        return entry, waveform, waveform.shape[0]
|
|
|
|
|
class FFmpegDataset(TorchaudioDataset):
    """``TorchaudioDataset`` variant that decodes audio with the whisper
    extractor's ``load_audio`` helper.

    Items have the same ``(utt_info, wav, length)`` layout as the base class.
    """

    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
        # Pure delegation to the base constructor.
        super().__init__(cfg, dataset, sr, accelerator, metadata)

    def __getitem__(self, index):
        """Return ``(utt_info, wav, length)`` for the utterance at ``index``."""
        utt_info = self.metadata[index]
        wav_path = utt_info["Path"]

        # Pass the dataset's target rate explicitly. Otherwise ``load_audio``
        # falls back to its own default rate (16 kHz in the upstream OpenAI
        # whisper implementation; confirm for this copy) and ``self.sr`` would
        # be silently ignored, unlike in the torchaudio/librosa variants.
        wav = whisper.load_audio(wav_path, sr=self.sr)

        wav = torch.from_numpy(wav)
        length = wav.shape[0]
        return utt_info, wav, length
|
|
|
|
|
def collate_batch(batch_list):
    """Collate dataset samples into one zero-padded batch.

    Args:
        batch_list: list of ``(metadata, wav, length)`` tuples, as produced by
            the datasets' ``__getitem__``.

    Returns:
        tuple: ``(metadata, wavs, lens)``.
            ``metadata`` is the per-sample metadata, in order.
            ``wavs`` is the waveforms right-padded with zeros and stacked
            batch-first by ``pad_sequence``; for 1-D inputs its shape is
            ``(batch, max_len)``.
            ``lens`` is the list of original (unpadded) lengths.
    """
    utterances = [sample[0] for sample in batch_list]
    waveforms = [sample[1] for sample in batch_list]
    lengths = [sample[2] for sample in batch_list]
    padded = pad_sequence(waveforms, batch_first=True)
    return utterances, padded, lengths
|
|