from pathlib import Path
from typing import List, Optional, Tuple, Type, Union
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose
import modeling.transforms as transform_module
from modeling.transforms import (
LabelsFromTxt,
OneHotEncode,
ParentMultilabel,
Preprocess,
Transform,
)
from modeling.utils import CLASSES, get_wav_files, init_obj, init_transforms


class IRMASDataset(Dataset):
"""Dataset class for IRMAS dataset.
:param audio_dir: Directory containing the audio files
:type audio_dir: Union[str, Path]
:param preprocess: Preprocessing method to apply to the audio files
:type preprocess: Type[Preprocess]
:param signal_augments: Signal augmentation method to apply to the audio files, defaults to None
:type signal_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
:param transforms: Transform method to apply to the audio files, defaults to None
:type transforms: Optional[Union[Type[Compose], Type[Transform]]], optional
:param spec_augments: Spectrogram augmentation method to apply to the audio files, defaults to None
:type spec_augments: Optional[Union[Type[Compose], Type[Transform]]], optional
:param subset: Subset of the data to load (train, valid, or test), defaults to "train"
:type subset: str, optional
:raises AssertionError: Raises an assertion error if subset is not train, valid or test
:raises OSError: Raises an OS error if test_songs.txt is not found in the data folder
"""
def __init__(
self,
audio_dir: Union[str, Path],
preprocess: Type[Preprocess],
signal_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
transforms: Optional[Union[Type[Compose], Type[Transform]]] = None,
spec_augments: Optional[Union[Type[Compose], Type[Transform]]] = None,
subset: str = "train",
):
        assert subset in ["train", "valid", "test"], "Subset can only be train, valid or test"
        self.files = get_wav_files(audio_dir)
self.subset = subset
if self.subset != "train":
try:
test_songs = np.genfromtxt("../data/test_songs.txt", dtype=str, ndmin=1, delimiter="\n")
except OSError as e:
print("Error: {e}")
print("test_songs.txt not found in data/. Please generate a split before training")
raise e
if self.subset == "valid":
self.files = [file for file in self.files if Path(file).stem not in test_songs]
if self.subset == "test":
self.files = [file for file in self.files if Path(file).stem in test_songs]
self.preprocess = preprocess
self.transforms = transforms
self.signal_augments = signal_augments
self.spec_augments = spec_augments

    def __len__(self):
"""Return the length of the dataset.
:return: The length of the dataset
:rtype: int
"""
return len(self.files)

    def __getitem__(self, index: int):
"""Get an item from the dataset.
:param index: The index of the item to get
:type index: int
:return: A tuple of the preprocessed audio signal and the corresponding one-hot encoded label
:rtype: Tuple[Tensor, Tensor]
"""
sample_path = self.files[index]
signal = self.preprocess(sample_path)
if self.subset == "train":
target_transforms = Compose([ParentMultilabel(sep="-"), OneHotEncode(CLASSES)])
else:
target_transforms = Compose([LabelsFromTxt(), OneHotEncode(CLASSES)])
label = target_transforms(sample_path)
if self.signal_augments is not None and self.subset == "train":
signal = self.signal_augments(signal)
if self.transforms is not None:
signal = self.transforms(signal)
if self.spec_augments is not None and self.subset == "train":
signal = self.spec_augments(signal)
return signal, label.float()
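

# Usage sketch (hedged): the path and the concrete Preprocess instance below
# are placeholders; in this project they are built from the config by
# get_loader further down.
#
#   dataset = IRMASDataset(
#       audio_dir="../data/train",     # hypothetical path
#       preprocess=my_preprocess,      # a Preprocess instance
#       subset="train",
#   )
#   signal, label = dataset[0]         # spectrogram tensor, one-hot label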


def collate_fn(
    data: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Function to collate a batch of audio signals and their corresponding labels.
:param data: A list of tuples containing the audio signals and their corresponding labels.
:type data: List[Tuple[torch.Tensor, torch.Tensor]]
:return: A tuple containing the batch of audio signals and their corresponding labels.
:rtype: Tuple[torch.Tensor, torch.Tensor]
"""
features, labels = zip(*data)
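    # Squeeze the channel dim and put time first (assuming inputs shaped
    # (1, n_mels, time)) so pad_sequence pads along the time axis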
features = [item.squeeze().T for item in features]
    # Pad all items in the batch to the length of the longest one
features = pad_sequence(features, batch_first=True)
labels = torch.stack(labels)
return features, labels


def get_loader(config, subset: str) -> DataLoader:
"""
Function to create a PyTorch DataLoader for a given subset of the IRMAS dataset.
:param config: A configuration object.
:type config: Any
:param subset: The subset of the dataset to use. Can be "train" or "valid".
:type subset: str
:return: A PyTorch DataLoader for the specified subset of the dataset.
:rtype: torch.utils.data.DataLoader
"""
dst = IRMASDataset(
config.train_dir if subset == "train" else config.valid_dir,
preprocess=init_obj(config.preprocess, transform_module),
transforms=init_obj(config.transforms, transform_module),
signal_augments=init_transforms(config.signal_augments, transform_module),
spec_augments=init_transforms(config.spec_augments, transform_module),
subset=subset,
)
return DataLoader(
dst,
batch_size=config.batch_size,
        shuffle=(subset == "train"),
        pin_memory=torch.cuda.is_available(),
        num_workers=torch.get_num_threads() - 1,  # keep one thread free for the main process
collate_fn=collate_fn,
)
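

if __name__ == "__main__":
    # Minimal self-check of collate_fn with dummy "spectrograms" of unequal
    # length. Sizes are illustrative (64 mel bins; IRMAS has 11 instrument
    # classes, hence len(CLASSES) one-hot slots). Building a full DataLoader
    # additionally needs a project config object, e.g.
    # loader = get_loader(config, subset="train").
    dummy_batch = [
        (torch.rand(1, 64, 100), torch.zeros(len(CLASSES))),
        (torch.rand(1, 64, 80), torch.zeros(len(CLASSES))),
    ]
    feats, labs = collate_fn(dummy_batch)
    print(feats.shape)  # torch.Size([2, 100, 64]) -- padded along time
    print(labs.shape)   # torch.Size([2, len(CLASSES)])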