# Source file: artst-demo-asr/artst/data/speech_to_class_dataset.py
# Hosting-page metadata: amupd · initial commit · 8b33290 · 8.8 kB
# --------------------------------------------------------
# ArTST: Arabic Text and Speech Transformer (https://arxiv.org/abs/2310.16621)
# Github source: https://github.com/mbzuai-nlp/ArTST
# Based on speecht5, fairseq and espnet code bases
# https://github.com/microsoft/SpeechT5/tree/main/SpeechT5; https://github.com/pytorch/fairseq; https://github.com/espnet/espnet
# --------------------------------------------------------
import logging
import os
from typing import Any, List, Optional
import numpy as np
import torch
import torch.nn.functional as F
from fairseq.data import data_utils, Dictionary
from fairseq.data.fairseq_dataset import FairseqDataset
# Module-level logger named after this module; used by the manifest loader
# and the dataset class below for informational messages and warnings.
logger = logging.getLogger(__name__)
def load_audio(manifest_path, max_keep, min_keep):
    """Read a TSV manifest and keep the utterances whose length is in range.

    The first line of the manifest is the audio root directory. Every
    following line is tab-separated: ``wav_path``, ``wav_nframe`` and an
    optional ``wav_class``.

    Args:
        manifest_path (str): path of the manifest file.
        max_keep (int or None): entries with more frames than this are
            skipped; ``None`` disables the upper bound.
        min_keep (int or None): entries with fewer frames than this are
            skipped; ``None`` disables the lower bound.

    Returns:
        tuple: ``(root, names, inds, tot, sizes, classes)`` where ``root`` is
        the first manifest line, ``names``/``inds``/``sizes`` describe the
        kept entries (relative path, 0-based line index after the root line,
        frame count), ``tot`` is the number of entry lines read, and
        ``classes`` holds the third column of the kept entries that have one.

    Raises:
        AssertionError: if an entry line has fewer than two columns.
        ValueError: if the frame-count column is not an integer.
    """
    n_long, n_short = 0, 0
    names, inds, sizes, classes = [], [], [], []
    # Initialised up front so a manifest holding only the root line yields
    # tot == 0 instead of failing on an unbound loop variable.
    tot = 0
    with open(manifest_path) as f:
        root = f.readline().strip()
        for ind, line in enumerate(f):
            tot = ind + 1
            items = line.strip().split("\t")
            assert len(items) >= 2, line
            sz = int(items[1])
            if min_keep is not None and sz < min_keep:
                n_short += 1
            elif max_keep is not None and sz > max_keep:
                n_long += 1
            else:
                names.append(items[0])
                if len(items) > 2:
                    classes.append(items[2])
                inds.append(ind)
                sizes.append(sz)

    # max()/min() raise on an empty sequence; report None when nothing was kept.
    longest = max(sizes) if sizes else None
    shortest = min(sizes) if sizes else None
    logger.info(
        (
            f"max_keep={max_keep}, min_keep={min_keep}, "
            f"loaded {len(names)}, skipped {n_short} short and {n_long} long, "
            f"longest-loaded={longest}, shortest-loaded={shortest}"
        )
    )
    if not classes:
        # Logger.warn is a deprecated alias; use Logger.warning.
        logger.warning("no classes loaded only if inference")
    return root, names, inds, tot, sizes, classes
def sample_from_feature(x: np.ndarray, max_segment_length: int = 300):
    """Randomly crop a contiguous segment of at most ``max_segment_length`` frames.

    Args:
        x (np.ndarray): feature or waveform (frames[, features]), e.g., log mel
            filter bank or waveform.
        max_segment_length (int, optional): maximum segment length.
            Defaults to 300.

    Returns:
        np.ndarray: ``x`` itself when it has no more than
        ``max_segment_length`` frames; otherwise a slice of exactly
        ``max_segment_length`` consecutive frames starting at a random offset.
    """
    n_frames = len(x)
    if n_frames <= max_segment_length:
        return x
    # np.random.randint excludes its upper bound, so "+ 1" is required for the
    # final valid offset (n_frames - max_segment_length) to be reachable.
    start = np.random.randint(0, n_frames - max_segment_length + 1)
    return x[start: start + max_segment_length]
class SpeechToClassDataset(FairseqDataset):
    """Dataset yielding raw waveforms and optional class labels ("s2c" task).

    Entries come from a TSV manifest parsed by ``load_audio``. Each item is a
    dict with the keys ``id``, ``source`` (1-D float tensor) and ``label``.
    """

    def __init__(
        self,
        manifest_path: str,
        sample_rate: float,
        label_processors: Optional[List[Any]] = None,
        max_keep_sample_size: Optional[int] = None,
        min_keep_sample_size: Optional[int] = None,
        shuffle: bool = True,
        normalize: bool = False,
        tgt_dict: Optional[Dictionary] = None,
        max_length: Optional[int] = None
    ):
        """Load the manifest and remember the dataset options.

        Args:
            manifest_path: manifest file handed to ``load_audio``.
            sample_rate: sample rate every audio file must have.
            label_processors: applied to each raw class string in
                ``get_label``. NOTE(review): annotated as a list but invoked
                as a single callable there — confirm with the caller.
            max_keep_sample_size: upper length bound passed to ``load_audio``.
            min_keep_sample_size: lower length bound passed to ``load_audio``.
            shuffle: whether ``ordered_indices`` uses a random tie-breaker.
            normalize: whether ``postprocess`` layer-normalises the waveform.
            tgt_dict: dictionary supplying the pad/eos indices used when
                collating.
            max_length: if set, waveforms are randomly cropped to this length.
        """
        manifest = load_audio(
            manifest_path, max_keep_sample_size, min_keep_sample_size
        )
        # The kept line indices and the total line count are not used here.
        root, names, _inds, _tot, frame_counts, class_names = manifest
        self.audio_root = root
        self.audio_names = names
        self.wav_sizes = frame_counts
        self.wav_classes = class_names

        self.sample_rate = sample_rate
        self.shuffle = shuffle
        self.label_processors = label_processors
        self.normalize = normalize
        self.tgt_dict = tgt_dict
        self.max_length = max_length

        logger.info(
            f"max_length={max_length}, normalize={normalize}"
        )

    def get_audio(self, index):
        """Read, optionally crop, and post-process the waveform at ``index``."""
        import soundfile as sf

        path = os.path.join(self.audio_root, self.audio_names[index])
        signal, file_rate = sf.read(path)
        if self.max_length is not None:
            # Cropping happens on the raw array, before channel averaging.
            signal = sample_from_feature(signal, self.max_length)
        tensor = torch.from_numpy(signal).float()
        return self.postprocess(tensor, file_rate)

    def get_label(self, index):
        """Return the class entry at ``index``, run through the label processor if set."""
        raw_label = self.wav_classes[index]
        if self.label_processors is None:
            return raw_label
        return self.label_processors(raw_label)

    def __getitem__(self, index):
        """Build the sample dict; ``label`` is None unless every entry has a class."""
        source = self.get_audio(index)
        fully_labelled = len(self.wav_classes) == len(self.audio_names)
        label = self.get_label(index) if fully_labelled else None
        return {"id": index, "source": source, "label": label}

    def __len__(self):
        """Number of entries kept from the manifest."""
        return len(self.wav_sizes)

    def collater(self, samples):
        """Merge samples into a batch dict; returns ``{}`` if none has a source.

        Targets are only built when the first usable sample carries a label.
        """
        usable = [item for item in samples if item["source"] is not None]
        if not usable:
            return {}

        waveforms = [item["source"] for item in usable]
        longest = max([len(w) for w in waveforms])
        collated_audios, padding_mask = self.collater_audio(waveforms, longest)

        decoder_target = None
        decoder_target_lengths = None
        if usable[0]["label"] is not None:
            # A single label stream, wrapped in a list for collater_label.
            label_streams = [[item["label"] for item in usable]]
            targets_list, lengths_list, _ = self.collater_label(label_streams)
            padded, lengths = targets_list[0], lengths_list[0]
            # Strip the padding again so each row keeps its true length.
            decoder_label = [
                (padded[row, :lengths[row]]).long()
                for row in range(padded.size(0))
            ]
            decoder_target = data_utils.collate_tokens(
                decoder_label,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            decoder_target_lengths = torch.tensor(
                [t.size(0) for t in decoder_label], dtype=torch.long
            )

        # One placeholder token per sample, collated with
        # move_eos_to_beginning=True (presumably this leaves a single EOS per
        # row — confirm against fairseq's collate_tokens).
        placeholders = [torch.LongTensor([-1]) for _ in usable]
        prev_output_tokens = data_utils.collate_tokens(
            placeholders,
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        )

        return {
            "id": torch.LongTensor([item["id"] for item in usable]),
            "net_input": {
                "source": collated_audios,
                "padding_mask": padding_mask,
                "prev_output_tokens": prev_output_tokens,
                "task_name": "s2c",
            },
            "target": decoder_target,
            "target_lengths": decoder_target_lengths,
            "task_name": "s2c",
            "ntokens": len(usable),
        }

    def collater_audio(self, audios, audio_size):
        """Zero-pad every waveform to ``audio_size`` and mark the padded tail.

        Returns:
            (collated_audios, padding_mask): both of shape
            ``(len(audios), audio_size)``; the mask is True on padded positions.

        Raises:
            Exception: if a waveform is longer than ``audio_size``.
        """
        collated_audios = audios[0].new_zeros(len(audios), audio_size)
        padding_mask = torch.BoolTensor(collated_audios.shape).fill_(False)
        for row, audio in enumerate(audios):
            missing = audio_size - len(audio)
            if missing < 0:
                raise Exception("Diff should not be larger than 0")
            if missing == 0:
                collated_audios[row] = audio
                continue
            filler = audio.new_full((missing,), 0.0)
            collated_audios[row] = torch.cat([audio, filler])
            padding_mask[row, -missing:] = True
        return collated_audios, padding_mask

    def collater_seq_label(self, targets, pad):
        """Right-pad ``targets`` with ``pad``; also return lengths and their sum."""
        lengths = torch.LongTensor([len(t) for t in targets])
        ntokens = lengths.sum().item()
        padded = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False)
        return padded, lengths, ntokens

    def collater_label(self, targets_by_label):
        """Collate each label stream, returning parallel lists of results.

        Only one pad index is supplied, so at most the first stream is used.
        """
        pad_indices = [self.tgt_dict.pad()]
        collated = [
            self.collater_seq_label(stream, pad)
            for stream, pad in zip(targets_by_label, pad_indices)
        ]
        targets_list = [entry[0] for entry in collated]
        lengths_list = [entry[1] for entry in collated]
        ntokens_list = [entry[2] for entry in collated]
        return targets_list, lengths_list, ntokens_list

    def num_tokens(self, index):
        """Same value as ``size(index)``."""
        return self.size(index)

    def size(self, index):
        """Frame count recorded in the manifest for entry ``index``."""
        return self.wav_sizes[index]

    @property
    def sizes(self):
        """Manifest frame counts as a NumPy array (rebuilt on every access)."""
        return np.array(self.wav_sizes)

    def ordered_indices(self):
        """Indices sorted by descending manifest size.

        The secondary sort key is a random permutation when ``shuffle`` is
        set, otherwise the plain index.
        """
        count = len(self)
        secondary = np.random.permutation(count) if self.shuffle else np.arange(count)
        # np.lexsort treats the LAST key as the primary one.
        return np.lexsort([secondary, self.wav_sizes])[::-1]

    def postprocess(self, wav, cur_sample_rate):
        """Average channels, verify the sample rate, optionally normalise.

        Raises:
            AssertionError: if the waveform is not 1-D after channel averaging.
            Exception: if ``cur_sample_rate`` differs from ``self.sample_rate``.
        """
        if wav.dim() == 2:
            wav = wav.mean(-1)
        assert wav.dim() == 1, wav.dim()

        if cur_sample_rate != self.sample_rate:
            raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}")

        if not self.normalize:
            return wav
        with torch.no_grad():
            return F.layer_norm(wav, wav.shape)