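# NOTE: the following is residue from the web page this file was copied from,
# not part of the original source: uploader avatar "JustinLin610's picture";
# commit ee21b96 ("first commit"); page links "raw" / "history blame";
# file size 4.17 kB.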
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
import contextlib
import numpy as np
import torch
from fairseq.data import FairseqDataset, data_utils
logger = logging.getLogger(__name__)
class ExtractedFeaturesDataset(FairseqDataset):
    """Dataset of variable-length feature sequences packed into one array.

    Three sibling files sharing the prefix ``os.path.join(path, split)`` are
    read:

    * ``<prefix>.npy``      -- all sequences concatenated along the first
      axis, opened memory-mapped and read-only.
    * ``<prefix>.lengths``  -- one integer per line: the number of rows each
      sequence occupies in the array, in order.
    * ``<prefix>.<labels>`` -- optional; one whitespace-separated label line
      per sequence, in the same order as the lengths file.

    Sequences whose length falls outside ``[min_length, max_length]`` are
    dropped from the index (their rows remain in the array but are never
    returned).
    """

    def __init__(
        self,
        path,
        split,
        min_length=3,
        max_length=None,
        labels=None,
        label_dict=None,
        shuffle=True,
        sort_by_length=True,
    ):
        """Build the index of usable sequences.

        Args:
            path: directory containing the data files.
            split: file-name prefix of the split inside ``path``.
            min_length: sequences shorter than this are skipped.
            max_length: sequences longer than this are skipped; ``None``
                disables the upper bound.
            labels: file extension of the label file, or ``None`` for
                unlabeled data. If the label file does not exist the
                dataset silently falls back to unlabeled mode.
            label_dict: object used to encode and pad label lines (its
                ``encode_line`` and ``pad`` methods are called); required
                whenever ``labels`` is given.
            shuffle: whether ``ordered_indices`` starts from a random
                permutation.
            sort_by_length: whether ``ordered_indices`` sorts by sequence
                length.
        """
        super().__init__()

        self.min_length = min_length
        self.max_length = max_length
        self.shuffle = shuffle
        self.sort_by_length = sort_by_length
        self.label_dict = label_dict

        if labels is not None:
            assert label_dict is not None

        self.sizes = []
        self.offsets = []
        self.labels = []

        prefix = os.path.join(path, split)
        self.data = np.load(prefix + ".npy", mmap_mode="r")

        # A missing label file is tolerated: the dataset becomes unlabeled.
        if labels is not None and not os.path.exists(prefix + f".{labels}"):
            labels = None

        offset = 0
        skipped = 0

        # ExitStack owns both files so the optional label file needs no
        # conditional expression inside the ``with`` header.
        with contextlib.ExitStack() as stack:
            len_f = stack.enter_context(open(prefix + ".lengths", "r"))
            lbl_f = None
            if labels is not None:
                lbl_f = stack.enter_context(open(prefix + f".{labels}", "r"))

            for line in len_f:
                length = int(line.rstrip())
                # Label lines are consumed in lockstep with length lines,
                # including for sequences that end up being skipped.
                lbl = None if lbl_f is None else next(lbl_f).rstrip().split()

                too_long = max_length is not None and length > max_length
                if length >= min_length and not too_long:
                    self.sizes.append(length)
                    self.offsets.append(offset)
                    if lbl is not None:
                        self.labels.append(lbl)
                else:
                    # Previously never counted, so the log always said 0.
                    skipped += 1

                # Advance past this sequence's rows whether or not it was
                # kept, so later offsets stay aligned with the array.
                offset += length

        self.sizes = np.asarray(self.sizes)
        self.offsets = np.asarray(self.offsets)

        logger.info(f"loaded {len(self.offsets)}, skipped {skipped} samples")

    def __getitem__(self, index):
        """Return the sample at ``index``.

        The result is a dict with ``"id"`` (the index), ``"features"`` (a
        float tensor copied out of the memory-mapped array) and, when labels
        were loaded, ``"target"`` (the label line encoded by ``label_dict``
        without an appended EOS).
        """
        offset = self.offsets[index]
        end = self.sizes[index] + offset
        # Copy the slice: the memory map is read-only and torch needs an
        # owned buffer.
        feats = torch.from_numpy(self.data[offset:end].copy()).float()

        res = {"id": index, "features": feats}

        if len(self.labels) > 0:
            res["target"] = self.label_dict.encode_line(
                self.labels[index],
                # Labels were already split into tokens when loaded.
                line_tokenizer=lambda x: x,
                append_eos=False,
            )

        return res

    def __len__(self):
        """Return the number of sequences kept after length filtering."""
        return len(self.sizes)

    def collater(self, samples):
        """Merge a list of samples into one zero-padded batch.

        Returns an empty dict for an empty list. Otherwise returns a dict
        with ``"id"`` (LongTensor of sample ids), ``"net_input"`` holding
        ``"features"`` of shape ``(batch, max_len, feat_dim)`` and a boolean
        ``"padding_mask"`` of shape ``(batch, max_len)`` that is ``True`` at
        padded positions, plus ``"target"`` (right-padded with
        ``label_dict.pad()``) when labels were loaded.
        """
        if len(samples) == 0:
            return {}

        features = [s["features"] for s in samples]
        sizes = [len(s) for s in features]

        target_size = max(sizes)

        # Assumes each feature tensor is 2-D (time, feat_dim) -- the slice
        # assignment below relies on it.
        collated_features = features[0].new_zeros(
            len(features), target_size, features[0].size(-1)
        )
        padding_mask = torch.BoolTensor(collated_features.shape[:-1]).fill_(False)
        for i, (f, size) in enumerate(zip(features, sizes)):
            collated_features[i, :size] = f
            padding_mask[i, size:] = True

        res = {
            "id": torch.LongTensor([s["id"] for s in samples]),
            "net_input": {"features": collated_features, "padding_mask": padding_mask},
        }

        if len(self.labels) > 0:
            target = data_utils.collate_tokens(
                [s["target"] for s in samples],
                pad_idx=self.label_dict.pad(),
                left_pad=False,
            )
            res["target"] = target

        return res

    def num_tokens(self, index):
        """Return the length of the sequence at ``index``."""
        return self.size(index)

    def size(self, index):
        """Return the length of the sequence at ``index``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order.

        With ``sort_by_length`` the indices are ordered by descending
        sequence length; the (possibly shuffled) base order only serves as
        the tie-breaker among equal lengths. Without it, the base order is
        returned as is.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]

        if self.sort_by_length:
            # np.lexsort treats the LAST key as the primary sort key.
            order.append(self.sizes)
            return np.lexsort(order)[::-1]
        else:
            return order[0]