mrfakename's picture
Sync from GitHub repo
fededd1 verified
raw
history blame
10.4 kB
import json
import random
from importlib.resources import files
import torch
import torch.nn.functional as F
import torchaudio
from datasets import Dataset as Dataset_
from datasets import load_from_disk
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm
from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default
class HFDataset(Dataset):
    """Map-style dataset that converts Hugging Face audio rows into mel spectrograms.

    Every row is read through ``row["audio"]["array"]``, ``row["audio"]["sampling_rate"]``
    and ``row["text"]``. Clips longer than 30 s or shorter than 0.3 s are never returned;
    the next usable row (wrapping around at the end of the dataset) is returned instead.
    """

    def __init__(
        self,
        hf_dataset: Dataset,
        target_sample_rate=24_000,
        n_mel_channels=100,
        hop_length=256,
        n_fft=1024,
        win_length=1024,
        mel_spec_type="vocos",
    ):
        """Store the wrapped dataset and build the mel-spectrogram module.

        Args:
            hf_dataset: indexable dataset whose rows carry ``audio`` and ``text`` entries.
            target_sample_rate: sample rate the audio is resampled to before the mel transform.
            n_mel_channels, hop_length, n_fft, win_length, mel_spec_type: forwarded to ``MelSpec``.
        """
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length

        self.mel_spectrogram = MelSpec(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        )

    def get_frame_len(self, index):
        """Return the (fractional) number of mel frames of row *index* at the target sample rate."""
        row = self.data[index]
        audio = row["audio"]["array"]
        sample_rate = row["audio"]["sampling_rate"]
        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{"mel_spec": ..., "text": ...}`` for the first usable row at or after *index*.

        Raises:
            RuntimeError: if no row in the dataset has a duration within [0.3, 30] seconds.
        """
        total = len(self.data)

        # Walk forward (with wrap-around) until a clip of acceptable duration is found.
        # An explicit bounded loop is used instead of recursion so that long runs of
        # filtered clips cannot exhaust the interpreter's recursion limit.
        for _ in range(max(total, 1)):
            row = self.data[index]
            audio = row["audio"]["array"]
            sample_rate = row["audio"]["sampling_rate"]
            duration = audio.shape[-1] / sample_rate
            if not (duration > 30 or duration < 0.3):
                break
            index = (index + 1) % total
        else:
            raise RuntimeError("HFDataset contains no audio clip with a duration between 0.3 and 30 seconds")

        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            audio_tensor = resampler(audio_tensor)

        audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t'

        mel_spec = self.mel_spectrogram(audio_tensor)
        mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'

        return dict(
            mel_spec=mel_spec,
            text=row["text"],
        )
class CustomDataset(Dataset):
def __init__(
self,
custom_dataset: Dataset,
durations=None,
target_sample_rate=24_000,
hop_length=256,
n_mel_channels=100,
n_fft=1024,
win_length=1024,
mel_spec_type="vocos",
preprocessed_mel=False,
mel_spec_module: nn.Module | None = None,
):
self.data = custom_dataset
self.durations = durations
self.target_sample_rate = target_sample_rate
self.hop_length = hop_length
self.n_fft = n_fft
self.win_length = win_length
self.mel_spec_type = mel_spec_type
self.preprocessed_mel = preprocessed_mel
if not preprocessed_mel:
self.mel_spectrogram = default(
mel_spec_module,
MelSpec(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
)
def get_frame_len(self, index):
if (
self.durations is not None
): # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
return self.durations[index] * self.target_sample_rate / self.hop_length
return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
def __len__(self):
return len(self.data)
def __getitem__(self, index):
row = self.data[index]
audio_path = row["audio_path"]
text = row["text"]
duration = row["duration"]
if self.preprocessed_mel:
mel_spec = torch.tensor(row["mel_spec"])
else:
audio, source_sample_rate = torchaudio.load(audio_path)
if audio.shape[0] > 1:
audio = torch.mean(audio, dim=0, keepdim=True)
if duration > 30 or duration < 0.3:
return self.__getitem__((index + 1) % len(self.data))
if source_sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
audio = resampler(audio)
mel_spec = self.mel_spectrogram(audio)
mel_spec = mel_spec.squeeze(0) # '1 d t -> d t')
return dict(
mel_spec=mel_spec,
text=text,
)
# Dynamic Batch Sampler
class DynamicBatchSampler(Sampler[list[int]]):
    """Batch sampler with a frame budget instead of a fixed batch size.

    Indices drawn from the wrapped sampler are sorted by frame length and packed greedily, so that
    1. the summed frame length of every batch stays within ``frames_threshold`` (and, when
       ``max_samples`` is non-zero, no batch holds more than ``max_samples`` items), and
    2. items of similar length share a batch, which keeps padding low.

    Items whose own frame length exceeds ``frames_threshold`` are left out of every batch.
    """

    def __init__(
        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
    ):
        self.sampler = sampler
        self.frames_threshold = frames_threshold
        self.max_samples = max_samples

        dataset = self.sampler.data_source

        # Pair every sampled index with its frame length, then order by length (stable sort).
        sized = [
            (sample_idx, dataset.get_frame_len(sample_idx))
            for sample_idx in tqdm(
                self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
            )
        ]
        sized.sort(key=lambda elem: elem[1])

        packed = []
        current = []
        current_frames = 0
        for sample_idx, n_frames in tqdm(
            sized, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
        ):
            accepts = current_frames + n_frames <= self.frames_threshold and (
                max_samples == 0 or len(current) < max_samples
            )
            if accepts:
                current.append(sample_idx)
                current_frames += n_frames
                continue

            # The open batch is full: close it and start a new one with this item,
            # unless the item alone is already over the frame budget.
            if current:
                packed.append(current)
            if n_frames <= self.frames_threshold:
                current, current_frames = [sample_idx], n_frames
            else:
                current, current_frames = [], 0

        # Keep the trailing, possibly under-filled batch unless the caller asked to drop it.
        if not drop_last and current:
            packed.append(current)

        del sized

        # To get a different batch order per epoch, a seed can be set and logged in the checkpoint
        # (e.g. random_seed + n for epoch n). In multi-GPU training the per-GPU batches themselves
        # stay the same between epochs, but the combined minibatch across GPUs still differs.
        # NOTE: this seeds the module-level `random` generator, not a private one.
        random.seed(random_seed)
        random.shuffle(packed)
        self.batches = packed

    def __iter__(self):
        """Iterate over the precomputed batches (each a list of dataset indices)."""
        return iter(self.batches)

    def __len__(self):
        """Return the number of precomputed batches."""
        return len(self.batches)
# Load dataset
def _load_raw_split(data_dir):
    """Load the ``raw`` split under *data_dir*, falling back to ``raw.arrow`` if ``raw`` cannot be loaded."""
    try:
        return load_from_disk(f"{data_dir}/raw")
    except Exception:  # best effort: any load failure falls through to the single-file layout
        return Dataset_.from_file(f"{data_dir}/raw.arrow")


def _read_durations(data_dir):
    """Return the ``"duration"`` entry of ``duration.json`` under *data_dir*."""
    with open(f"{data_dir}/duration.json", "r", encoding="utf-8") as f:
        return json.load(f)["duration"]


def load_dataset(
    dataset_name: str,
    tokenizer: str = "pinyin",
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_module: nn.Module | None = None,
    mel_spec_kwargs: dict | None = None,
) -> CustomDataset | HFDataset:
    """Build the training dataset.

    Args:
        dataset_name: dataset name (``"CustomDataset"``), directory path (``"CustomDatasetPath"``)
            or a ``"<pre>_<post>"`` identifier (``"HFDataset"``).
        tokenizer: only used to build the default data path ``data/{dataset_name}_{tokenizer}``.
        dataset_type:
            - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
            - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset
              without relying on tokenizer
            - "HFDataset" to load through Hugging Face ``datasets.load_dataset``
        audio_type: ``"raw"`` or ``"mel"``; only consulted for ``"CustomDataset"``.
        mel_spec_module: optional mel-spectrogram module forwarded to ``CustomDataset``.
        mel_spec_kwargs: extra keyword arguments forwarded to ``CustomDataset``; ``None`` means none.

    Returns:
        The constructed ``CustomDataset`` or ``HFDataset``.

    Raises:
        ValueError: if ``dataset_type`` or ``audio_type`` is not one of the supported values.
    """
    # Avoid a shared mutable default; treat None as "no extra keyword arguments".
    mel_spec_kwargs = {} if mel_spec_kwargs is None else mel_spec_kwargs

    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
        if audio_type == "raw":
            train_dataset = _load_raw_split(rel_data_path)
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
            preprocessed_mel = True
        else:
            raise ValueError(f"Unsupported audio_type {audio_type!r}; expected 'raw' or 'mel'")
        durations = _read_durations(rel_data_path)
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=preprocessed_mel,
            mel_spec_module=mel_spec_module,
            **mel_spec_kwargs,
        )

    elif dataset_type == "CustomDatasetPath":
        train_dataset = _load_raw_split(dataset_name)
        durations = _read_durations(dataset_name)
        # This branch always loads raw audio, so mels are never preprocessed here.
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=False,
            mel_spec_module=mel_spec_module,
            **mel_spec_kwargs,
        )

    elif dataset_type == "HFDataset":
        # This function shadows `datasets.load_dataset`, so import the Hugging Face loader
        # under a distinct name instead of accidentally recursing into ourselves.
        from datasets import load_dataset as hf_load_dataset

        print(
            "Should manually modify the path of huggingface dataset to your need.\n"
            + "May also the corresponding script cuz different dataset may have different format."
        )
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(
            hf_load_dataset(
                f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))
            ),
        )

    else:
        raise ValueError(
            f"Unsupported dataset_type {dataset_type!r}; expected 'CustomDataset', 'CustomDatasetPath' or 'HFDataset'"
        )

    return train_dataset
# collation
def collate_fn(batch):
    """Collate dataset items into one padded batch.

    Each item supplies a ``"mel_spec"`` tensor and a ``"text"`` entry. Spectrograms are
    right-padded with zeros along the last dimension up to the longest one in the batch and
    stacked; texts are returned as a plain list.

    Returns:
        dict with ``mel`` (stacked padded spectrograms), ``mel_lengths`` (unpadded last-dim sizes),
        ``text`` (list of the items' texts) and ``text_lengths`` (``len`` of every text).
    """
    specs = [sample["mel_spec"].squeeze(0) for sample in batch]
    mel_lengths = torch.LongTensor([s.shape[-1] for s in specs])
    longest = mel_lengths.amax()

    # TODO. maybe records mask for attention here
    mel = torch.stack([F.pad(s, (0, longest - s.size(-1)), value=0) for s in specs])

    texts = [sample["text"] for sample in batch]
    text_lengths = torch.LongTensor([len(t) for t in texts])

    return {
        "mel": mel,
        "mel_lengths": mel_lengths,
        "text": texts,
        "text_lengths": text_lengths,
    }