Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import random | |
from typing import List | |
from fairseq.data import BaseWrapperDataset, data_utils | |
class RandomInputDataset(BaseWrapperDataset):
    """Dataset wrapper that overwrites one entry of every item with a random draw.

    Each item fetched from the wrapped ``dataset`` has the entry addressed by
    ``input_key_path`` replaced with an element picked by ``random.choice``
    from ``random_input_dataset``. At collation time the per-sample random
    entries are batched with ``data_utils.collate_tokens`` and written into
    the collated batch under the same key path.
    """

    def __init__(
        self,
        dataset,
        random_input_dataset,
        input_key_path: List[str],
        add_to_input,
        pad_idx,
    ):
        """Store the wrapped dataset and the configuration for the random entry.

        Args:
            dataset: dataset being wrapped; forwarded to the base class.
            random_input_dataset: collection the random entries are drawn from.
            input_key_path: sequence of keys locating the entry to overwrite;
                a bare string is treated as a one-element path.
            add_to_input: when truthy, the collated random entries are placed
                under ``collated["net_input"]`` instead of the top level.
            pad_idx: padding index handed to ``data_utils.collate_tokens``.
        """
        super().__init__(dataset)
        self.random_input_dataset = random_input_dataset

        # Accept a single key as shorthand for a path of length one.
        key_path = (
            [input_key_path] if isinstance(input_key_path, str) else input_key_path
        )
        assert len(key_path) > 0

        self.input_key_path = key_path
        self.add_to_input = add_to_input
        self.pad_idx = pad_idx

    def get_target(self, item):
        """Return ``(last_key, container)`` for the configured key path.

        ``container`` is the object reached by indexing ``item`` with every
        key of the path except the last one, so ``container[last_key]`` is the
        addressed entry.
        """
        path = self.input_key_path
        leaf_key = path[-1]
        container = item
        for key in path[:-1]:
            container = container[key]
        return leaf_key, container

    def get_target_value(self, item):
        """Return the value currently stored at the configured key path of ``item``."""
        leaf_key, container = self.get_target(item)
        return container[leaf_key]

    def __getitem__(self, index):
        """Fetch item ``index`` from the wrapped dataset with its target entry randomized.

        The item returned by the wrapped dataset is modified in place.
        """
        sample = self.dataset[index]
        leaf_key, container = self.get_target(sample)
        # Resolve the path first so a bad path fails before the RNG is consumed.
        replacement = random.choice(self.random_input_dataset)
        container[leaf_key] = replacement
        return sample

    def collater(self, samples):
        """Collate ``samples`` with the wrapped dataset, then attach the random entries.

        Only samples whose ``"id"`` appears in the collated batch's ``"id"``
        field contribute a random entry. An empty collated result is returned
        unchanged.
        """
        batch = self.dataset.collater(samples)
        if len(batch) == 0:
            return batch

        # Ids that are present in the collated batch.
        kept_ids = set(batch["id"].tolist())

        values = []
        for sample in samples:
            if sample["id"] in kept_ids:
                values.append(self.get_target_value(sample))

        padded = data_utils.collate_tokens(
            values,
            pad_idx=self.pad_idx,
            left_pad=False,
        )

        if self.add_to_input:
            root = batch["net_input"]
        else:
            root = batch

        leaf_key, container = self.get_target(root)
        container[leaf_key] = padded
        return batch