# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import numpy as np

from fairseq.data import FairseqDataset

from . import data_utils
from .collaters import Seq2SeqCollater


class AsrDataset(FairseqDataset):
    """
    A dataset representing speech and corresponding transcription.

    Args:
        aud_paths (List[str]): A list of str with paths to audio files.
        aud_durations_ms (List[int]): A list of int containing the durations
            of audio files.
        tgt (List[torch.LongTensor]): A list of LongTensors containing the
            indices of target transcriptions.
        tgt_dict (~fairseq.data.Dictionary): target vocabulary.
        ids (List[str]): A list of utterance IDs.
        speakers (List[str]): A list of speakers corresponding to utterances.
        num_mel_bins (int): Number of triangular mel-frequency bins
            (default: 80).
        frame_length (float): Frame length in milliseconds (default: 25.0).
        frame_shift (float): Frame shift in milliseconds (default: 10.0).
    """

    def __init__(
        self,
        aud_paths,
        aud_durations_ms,
        tgt,
        tgt_dict,
        ids,
        speakers,
        num_mel_bins=80,
        frame_length=25.0,
        frame_shift=10.0,
    ):
        assert frame_length > 0
        assert frame_shift > 0
        assert all(x > frame_length for x in aud_durations_ms)
        # One frame at the start of each utterance, plus one more frame for
        # every full frame_shift of remaining audio.
        self.frame_sizes = [
            int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
        ]
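        # Worked example with the defaults: a 1000 ms utterance yields
        # int(1 + (1000 - 25) / 10) = 98 frames.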
        assert len(aud_paths) > 0
        assert len(aud_paths) == len(aud_durations_ms)
        assert len(aud_paths) == len(tgt)
        assert len(aud_paths) == len(ids)
        assert len(aud_paths) == len(speakers)
        self.aud_paths = aud_paths
        self.tgt_dict = tgt_dict
        self.tgt = tgt
        self.ids = ids
        self.speakers = speakers
        self.num_mel_bins = num_mel_bins
        self.frame_length = frame_length
        self.frame_shift = frame_shift

        self.s2s_collater = Seq2SeqCollater(
            0,
            1,
            pad_index=self.tgt_dict.pad(),
            eos_index=self.tgt_dict.eos(),
            move_eos_to_beginning=True,
        )

    def __getitem__(self, index):
        # Imported lazily so that importing this module does not require
        # torchaudio; it is only needed when items are actually loaded.
        import torchaudio
        import torchaudio.compliance.kaldi as kaldi

        tgt_item = self.tgt[index] if self.tgt is not None else None

        path = self.aud_paths[index]
        if not os.path.exists(path):
            raise FileNotFoundError("Audio file not found: {}".format(path))
        # load_wav returns the waveform without amplitude normalization,
        # which is the input scale kaldi.fbank expects.
        sound, sample_rate = torchaudio.load_wav(path)
        output = kaldi.fbank(
            sound,
            num_mel_bins=self.num_mel_bins,
            frame_length=self.frame_length,
            frame_shift=self.frame_shift,
        )
        # Apply utterance-level mean/variance normalization (CMVN).
        output_cmvn = data_utils.apply_mv_norm(output)

        return {"id": index, "data": [output_cmvn.detach(), tgt_item]}

    def __len__(self):
        return len(self.aud_paths)

    def collater(self, samples):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate

        Returns:
            dict: a mini-batch suitable for forwarding with a Model
        """
        return self.s2s_collater.collate(samples)

    def num_tokens(self, index):
        return self.frame_sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``."""
        return (
            self.frame_sizes[index],
            len(self.tgt[index]) if self.tgt is not None else 0,
        )

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        return np.arange(len(self))
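

# A minimal usage sketch, not part of the original module. The WAV path is
# hypothetical, and the toy target uses fairseq's standard Dictionary API;
# the audio file is only read when an item is fetched, so construction and
# size queries run without it.
if __name__ == "__main__":
    import torch

    from fairseq.data import Dictionary

    tgt_dict = Dictionary()  # pre-populated with bos/pad/eos/unk symbols
    tokens = torch.LongTensor([tgt_dict.add_symbol("hello"), tgt_dict.eos()])

    dataset = AsrDataset(
        aud_paths=["/path/to/utt1.wav"],  # hypothetical 16 kHz mono WAV
        aud_durations_ms=[1000],
        tgt=[tokens],
        tgt_dict=tgt_dict,
        ids=["utt1"],
        speakers=["spk1"],
    )
    # 98 frames for a 1000 ms clip: int(1 + (1000 - 25) / 10).
    print(len(dataset), dataset.num_tokens(0), dataset.size(0))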