# OFA-OCR / data /nlu_data /cola_dataset.py
# JustinLin610's picture
# first commit
# ee21b96
# raw
# history blame
# 4.49 kB
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
import logging
import warnings
import torch
import numpy as np
from data import data_utils
from data.ofa_dataset import OFADataset
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Silence UserWarnings whose message matches "(Possibly )?corrupt EXIF data".
# NOTE(review): presumably these come from an image-loading library used
# elsewhere in the project; nothing in this module reads images -- confirm.
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
def collate(samples, pad_idx, eos_idx):
    """Merge a list of per-example dicts into a single mini-batch dict.

    Args:
        samples: sequence of example dicts. Each must hold a ``"source"``
            tensor. ``"target"``, ``"prev_output_tokens"``, ``"ref_dict"``
            and ``"constraint_mask"`` are optional; whether an optional key
            is present is decided by inspecting the first sample only.
        pad_idx: padding index. Forwarded to ``data_utils.collate_tokens``
            and excluded when counting tokens.
        eos_idx: end-of-sentence index, forwarded to
            ``data_utils.collate_tokens``.

    Returns:
        ``{}`` when ``samples`` is empty. Otherwise a dict with the keys
        ``nsentences``, ``ntokens``, ``net_input`` (holding ``src_tokens``,
        ``src_lengths`` and ``prev_output_tokens``), ``ref_dict``,
        ``constraint_masks`` and ``target``. Optional entries are ``None``
        when the samples do not provide them.
    """
    if not samples:
        return {}

    first = samples[0]

    def has(key):
        # Presence of optional fields is judged from the first sample only.
        return first.get(key, None) is not None

    def merge(key):
        # Batch the per-sample tensors stored under `key`.
        # NOTE(review): presumably pads them to a common length -- confirm
        # against data_utils.collate_tokens.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    def non_pad_lengths(key):
        # Per-sample count of entries that differ from pad_idx.
        return torch.LongTensor(
            [s[key].ne(pad_idx).long().sum() for s in samples]
        )

    src_tokens = merge("source")
    src_lengths = non_pad_lengths("source")

    ref_dict = None
    if has("ref_dict"):
        ref_dict = np.array([s['ref_dict'] for s in samples])

    constraint_masks = None
    if has("constraint_mask"):
        constraint_masks = merge("constraint_mask")

    prev_output_tokens = None
    target = None
    if has("target"):
        target = merge("target")
        ntokens = non_pad_lengths("target").sum().item()

        # Decoder inputs are only collated when a target is present.
        if has("prev_output_tokens"):
            prev_output_tokens = merge("prev_output_tokens")
    else:
        # Without targets, report the number of non-pad source tokens.
        ntokens = src_lengths.sum().item()

    batch = {
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "prev_output_tokens": prev_output_tokens
        },
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "target": target,
    }
    return batch
class COLADataset(OFADataset):
    """Dataset that turns (sentence, label) pairs into yes/no text examples.

    Each item asks whether a sentence is grammatically correct and expects
    the single-token answer ``yes`` (label ``'1'``) or ``no`` (label ``'0'``).

    NOTE(review): ``encode_text``, ``bos_item``, ``eos_item``, ``tgt_dict``,
    ``dataset``, ``pad`` and ``eos`` are not defined here; presumably they are
    provided by ``OFADataset`` -- confirm.
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=512,
        max_tgt_length=30,
        constraint_trie=None,
        prompt_type="none"
    ):
        """Store the task configuration.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``OFADataset.__init__``.
            max_src_length: maximum number of whitespace-separated words
                kept from each input sentence.
            max_tgt_length: stored on the instance; not read anywhere in
                this class.
            constraint_trie: optional object exposing ``get_next_layer``;
                when given, each example also carries a ``constraint_mask``.
            prompt_type: one of ``'none'``, ``'src'`` or ``'prev_output'``;
                selects how the decoder input is built in ``__getitem__``.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length

        self.constraint_trie = constraint_trie
        self.prompt_type = prompt_type

    def __getitem__(self, index):
        """Build the example dict for the item at ``index``.

        Returns:
            dict with ``source``, ``target``, ``prev_output_tokens`` and
            ``ref_dict``, plus ``constraint_mask`` when a constraint trie
            was supplied.

        Raises:
            NotImplementedError: if the label is neither ``'0'`` nor ``'1'``
                or ``prompt_type`` is not a supported value.
        """
        sentence, label = self.dataset[index]
        if label == '0':
            label = 'no'
        elif label == '1':
            label = 'yes'
        else:
            raise NotImplementedError(
                "unsupported label: {!r}".format(label)
            )

        # Lower-case, collapse whitespace and keep at most max_src_length words.
        sentence = ' '.join(sentence.lower().strip().split()[:self.max_src_length])
        src_item = self.encode_text(' is the text " {} " grammatically correct?'.format(sentence))
        tgt_item = self.encode_text(" {}".format(label))
        # The answer must encode to exactly one token.
        assert tgt_item.size(0) == 1, \
            "answer {!r} did not encode to a single token".format(label)
        ref_dict = {label: 1.0}

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        if self.prompt_type == 'none':
            # Decoder input is only BOS; the target is the answer token.
            prev_output_item = self.bos_item
            target_item = tgt_item
        elif self.prompt_type == 'src':
            # Decoder input is a copy of the whole source (BOS ... EOS).
            prev_output_item = src_item.clone()
            target_item = torch.cat([prev_output_item[1:], tgt_item])
        elif self.prompt_type == 'prev_output':
            # Decoder input is the source without its trailing EOS.
            prev_output_item = src_item[:-1].clone()
            target_item = torch.cat([prev_output_item[1:], tgt_item])
        else:
            raise NotImplementedError(
                "unsupported prompt_type: {!r}".format(self.prompt_type)
            )
        # Overwrite every target position except the last (the answer) with
        # the padding index.
        target_item[:-1] = self.tgt_dict.pad()

        example = {
            "source": src_item,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "ref_dict": ref_dict,
        }
        if self.constraint_trie is not None:
            # Boolean mask of shape (decoder length, target vocabulary size);
            # only the last decoder position gets allowed entries, namely the
            # indices the trie returns for the BOS prefix.
            constraint_mask = torch.zeros((len(prev_output_item), len(self.tgt_dict))).bool()
            constraint_nodes = self.constraint_trie.get_next_layer(self.bos_item.tolist())
            constraint_mask[-1][constraint_nodes] = True
            example["constraint_mask"] = constraint_mask
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length: accepted but not used by this implementation.

        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)