|
from typing import Callable, Optional, Union |
|
from tqdm import tqdm |
|
import os |
|
import torch |
|
import torchaudio |
|
import torchaudio.functional as F |
|
from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split |
|
from pytorch_lightning import LightningDataModule |
|
import webdataset |
|
|
|
|
|
class VLSP2020Dataset(Dataset):
    """Map-style dataset over a flat directory of VLSP2020 samples.

    Each sample is a pair of files sharing a base name:
    ``<name>.txt`` (transcript) plus one audio file ``<name>.<ext>``
    (any non-.txt extension is treated as audio).
    """

    def __init__(self, root: str, sample_rate: int = 16000):
        """
        Args:
            root: Directory containing paired transcript/audio files.
            sample_rate: Target sample rate; audio is resampled to this on load.

        Raises:
            ValueError: If any base name has a duplicate or missing
                transcript/audio file.
        """
        super().__init__()

        self.sample_rate = sample_rate
        self.memory = self._prepare_data(root)
        # Sort by base name so sample indices are deterministic:
        # os.scandir yields entries in filesystem-dependent order.
        self._memory = tuple(
            (v["transcript"], v["audio"]) for _, v in sorted(self.memory.items())
        )

    @staticmethod
    def _prepare_data(root: str) -> dict:
        """Scan ``root`` and map each base name to its transcript/audio paths.

        Returns:
            dict: ``{base_name: {"transcript": path, "audio": path}}``

        Raises:
            ValueError: On a duplicate transcript/audio file for a base name,
                or a base name missing either half of the pair.
        """
        memory: dict = {}

        for entry in os.scandir(root):
            # Skip subdirectories and other non-file entries; previously a
            # directory would have been misclassified as an audio file.
            if not entry.is_file():
                continue

            file_name, file_ext = os.path.splitext(entry.name)
            kind = "transcript" if file_ext == ".txt" else "audio"

            record = memory.setdefault(file_name, {})
            if kind in record:
                raise ValueError(f"Duplicate {kind} for {entry.path}")
            record[kind] = entry.path

        # Every sample must be a complete (transcript, audio) pair.
        for key, value in memory.items():
            if "audio" not in value:
                raise ValueError(f"Missing audio for {key}")
            if "transcript" not in value:
                raise ValueError(f"Missing transcript for {key}")

        return memory

    def __len__(self):
        return len(self._memory)

    def __getitem__(self, index: int):
        """Return ``(transcript_text, waveform)`` for sample ``index``.

        The waveform is resampled to ``self.sample_rate``.
        """
        transcript_path, audio_path = self._memory[index]

        # Explicit encoding: transcripts are expected to be UTF-8 (Vietnamese
        # text); the platform default encoding would break on e.g. Windows.
        with open(transcript_path, "r", encoding="utf-8") as f:
            transcript = f.read()

        audio, sample_rate = torchaudio.load(audio_path)
        audio = F.resample(audio, sample_rate, self.sample_rate)

        return transcript, audio
|
|
|
|
|
class VLSP2020TarDataset:
    """Converts a VLSP2020Dataset into a WebDataset tar shard and loads it back."""

    def __init__(self, outpath: str):
        # Path of the .tar shard that convert() writes and load() reads.
        self.outpath = outpath

    def convert(self, dataset: VLSP2020Dataset):
        """Serialize ``dataset`` into the tar shard at ``self.outpath``.

        Each sample is written under a zero-padded numeric key with a
        ``txt`` field (transcript string) and a ``pth`` field (audio tensor).
        """
        writer = webdataset.TarWriter(self.outpath)

        # try/finally so the writer is closed even if loading a sample or
        # writing fails mid-conversion; otherwise the tar handle leaks and
        # the archive is left unterminated.
        try:
            for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
                writer.write(
                    {
                        "__key__": f"{idx:08d}",
                        "txt": transcript,
                        "pth": audio,
                    }
                )
        finally:
            writer.close()

    def load(self) -> webdataset.WebDataset:
        """Open the tar shard and return a pipeline of (transcript, audio) tuples.

        The result is also cached on ``self.data``.
        """
        self.data = (
            webdataset.WebDataset(self.outpath)
            .decode(
                # txt entries are raw bytes in the tar; decode to str.
                webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
                webdataset.torch_audio,
            )
            .to_tuple("txt", "pth")
        )

        return self.data
|
|
|
|
|
def get_dataloader(
    dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
    return_transcript: bool = False,
    target_transform: Optional[Callable] = None,
    batch_size: int = 32,
    num_workers: int = 2,
):
    """Wrap ``dataset`` in a DataLoader yielding tuples of 1-D audio tensors.

    Args:
        dataset: Source of ``(transcript, audio)`` pairs, where audio is a
            mono waveform of shape ``(1, num_samples)``.
        return_transcript: If True, each batch is ``(transcripts, audios)``;
            otherwise just the tuple of audio tensors.
        target_transform: Optional callable applied to each transcript
            (only used when ``return_transcript`` is True).
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes.

    Returns:
        DataLoader whose batches are tuples of variable-length 1-D audio
        tensors (no padding is applied).

    Raises:
        ValueError: From the collate function, if a sample's audio is not a
            mono ``(1, num_samples)`` tensor.
    """

    def collate_fn(batch):
        def get_audio(item):
            audio = item[1]

            # Validate with an explicit raise: ``assert`` is stripped under
            # ``python -O`` and would silently admit malformed samples.
            if not (
                isinstance(audio, torch.Tensor)
                and audio.ndim == 2
                and audio.size(0) == 1
            ):
                raise ValueError(
                    "expected mono audio tensor of shape (1, num_samples)"
                )

            # Drop the channel dimension -> (num_samples,).
            return audio.squeeze(0)

        audio = tuple(get_audio(item) for item in batch)

        if not return_transcript:
            return audio

        if target_transform is not None:
            transcript = tuple(target_transform(item[0]) for item in batch)
        else:
            transcript = tuple(item[0] for item in batch)

        return transcript, audio

    return DataLoader(
        dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
    )
|
|