# NOTE: extracted from a hosting web page (Space "OFA-OCR-dedao-demo001", page status "Runtime error").
# File: fairseq/examples/wav2vec/unsupervised/data/extracted_features_dataset.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os

import numpy as np
import torch

from fairseq.data import FairseqDataset, data_utils

# Module-level logger; used to report how many samples were loaded/skipped.
logger = logging.getLogger(__name__)
class ExtractedFeaturesDataset(FairseqDataset):
    """Dataset over pre-extracted features stored in a single ``.npy`` array.

    The constructor reads three files that share the stem ``<path>/<split>``:

    * ``<stem>.npy`` -- the feature array, opened memory-mapped and read-only;
      the rows of all samples are stored back to back.
    * ``<stem>.lengths`` -- one integer per line giving the number of rows
      that belong to each sample, in storage order.
    * ``<stem>.<labels>`` (optional) -- one whitespace-separated label
      sequence per line, aligned line by line with the lengths file.

    Samples whose length falls outside ``[min_length, max_length]`` are
    skipped, but their rows still advance the running offset so that the
    remaining samples keep pointing at the right rows of the array.
    """

    def __init__(
        self,
        path,
        split,
        min_length=3,
        max_length=None,
        labels=None,
        label_dict=None,
        shuffle=True,
        sort_by_length=True,
    ):
        """Index the samples of one split.

        Args:
            path: directory that contains the split files.
            split: file stem (without extension) of the split to load.
            min_length: samples with fewer rows than this are skipped.
            max_length: if not ``None``, samples with more rows are skipped.
            labels: extension of the label file; if that file does not
                exist the dataset silently falls back to having no labels.
            label_dict: object used to encode label lines (its
                ``encode_line`` and ``pad`` methods are called); must be
                given whenever ``labels`` is given.
            shuffle: whether ``ordered_indices`` starts from a random
                permutation instead of the natural order.
            sort_by_length: whether ``ordered_indices`` sorts by sample
                length.

        Raises:
            AssertionError: if ``labels`` is given without ``label_dict``.
        """
        super().__init__()

        self.min_length = min_length
        self.max_length = max_length
        self.shuffle = shuffle
        self.sort_by_length = sort_by_length
        self.label_dict = label_dict

        if labels is not None:
            assert label_dict is not None

        self.sizes = []
        self.offsets = []
        self.labels = []

        path = os.path.join(path, split)
        data_path = path
        # Memory-map the array so rows are only read when a sample is fetched.
        self.data = np.load(data_path + ".npy", mmap_mode="r")

        offset = 0
        skipped = 0

        # A missing label file disables labels rather than raising.
        if labels is not None and not os.path.exists(path + f".{labels}"):
            labels = None

        with contextlib.ExitStack() as stack:
            len_f = stack.enter_context(open(data_path + ".lengths", "r"))
            lbl_f = None
            if labels is not None:
                lbl_f = stack.enter_context(open(path + f".{labels}", "r"))

            for line in len_f:
                length = int(line.rstrip())
                # Always consume the label line so both files stay aligned,
                # even when the sample itself ends up being skipped.
                lbl = None if lbl_f is None else next(lbl_f).rstrip().split()

                if length >= min_length and (
                    max_length is None or length <= max_length
                ):
                    self.sizes.append(length)
                    self.offsets.append(offset)
                    if lbl is not None:
                        self.labels.append(lbl)
                else:
                    # Previously never counted, so the log always said 0.
                    skipped += 1

                # Skipped samples still occupy rows in the array.
                offset += length

        self.sizes = np.asarray(self.sizes)
        self.offsets = np.asarray(self.offsets)

        logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples")

    def __getitem__(self, index):
        """Return sample ``index`` as a dict with ``id`` and ``features``.

        A ``target`` entry (the encoded label line) is added when labels
        were loaded.
        """
        offset = self.offsets[index]
        end = self.sizes[index] + offset
        # Copy the slice out of the read-only memory map before handing it
        # to torch, then convert to float32.
        feats = torch.from_numpy(self.data[offset:end].copy()).float()

        res = {"id": index, "features": feats}
        if len(self.labels) > 0:
            res["target"] = self.label_dict.encode_line(
                self.labels[index],
                # Labels were already split into tokens in __init__.
                line_tokenizer=lambda x: x,
                append_eos=False,
            )

        return res

    def __len__(self):
        """Return the number of samples that passed the length filter."""
        return len(self.sizes)

    def collater(self, samples):
        """Merge a list of samples into a zero-padded mini-batch.

        Returns an empty dict for an empty list. Otherwise returns a dict
        with ``id`` (LongTensor of sample ids) and ``net_input`` holding
        ``features`` (padded to the longest sample in the batch) and
        ``padding_mask`` (``True`` at padded positions). When labels were
        loaded, ``target`` holds the right-padded encoded labels.
        """
        if len(samples) == 0:
            return {}

        features = [s["features"] for s in samples]
        sizes = [len(s) for s in features]

        target_size = max(sizes)

        collated_features = features[0].new_zeros(
            len(features), target_size, features[0].size(-1)
        )
        padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
        for i, (f, size) in enumerate(zip(features, sizes)):
            collated_features[i, :size] = f
            padding_mask[i, size:] = True

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"features": collated_features, "padding_mask": padding_mask},
        }

        if len(self.labels) > 0:
            target = data_utils.collate_tokens(
                [s["target"] for s in samples],
                pad_idx=self.label_dict.pad(),
                left_pad=False,
            )
            res["target"] = target
        return res

    def num_tokens(self, index):
        """Return the length of sample ``index`` (same value as ``size``)."""
        return self.size(index)

    def size(self, index):
        """Return the number of feature rows of sample ``index``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        if self.sort_by_length:
            # lexsort uses the LAST key as the primary one: sort by length,
            # break ties with the (possibly shuffled) base order, then
            # reverse so the longest samples come first.
            order.append(self.sizes)
            return np.lexsort(order)[::-1]
        else:
            return order[0]