import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset


class TokenizedDataset(Dataset):
    """
    Converts a dataset of text samples into a dataset of token sequences,
    as produced by a supplied tokenizer.  The tokens come with position ids
    and attention masks, so they can be supplied directly to the model.
    """

    def __init__(self, text_dataset, tokenizer=None, maxlen=None, field="text"):
        self.text_dataset = text_dataset
        self.field = field
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        if hasattr(text_dataset, "info"):
            self.info = text_dataset.info

    def __len__(self):
        return len(self.text_dataset)

    def __getitem__(self, i):
        text = self.text_dataset[i]
        if self.field is not None:
            text = text[self.field]
        token_list = self.tokenizer.encode(
            text, truncation=True, max_length=self.maxlen
        )
        position_ids = list(range(len(token_list)))
        attention_mask = [1] * len(token_list)
        return dict(
            input_ids=torch.tensor(token_list),
            position_ids=torch.tensor(position_ids),
            attention_mask=torch.tensor(attention_mask),
        )
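

# Illustrative usage sketch (not part of the original module).  The tokenizer
# and dataset below are assumptions for demonstration; any text dataset with a
# "text" field and any Hugging Face tokenizer with an .encode() method should
# behave the same way.
#
#     from datasets import load_dataset
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     wiki = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
#     tds = TokenizedDataset(wiki, tokenizer=tokenizer, maxlen=128)
#     sample = tds[0]  # dict of input_ids, position_ids, attention_mask tensors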


def dict_to_(data, device):
    """
    Moves a dictionary of tensors to the specified device.
    """
    for k in data:
        data[k] = data[k].to(device)
    return data
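
# Illustrative: move a whole batch dict onto an accelerator in place
# (the device string is only an example).
#
#     batch = dict_to_(batch, "cuda:0")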


def length_collation(token_size):
    """
    Sorts a batch of sequences by decreasing length and breaks it up into
    subbatches of same-sized (padded) sequences.  Each subbatch has no more
    than token_size total tokens (or a single sequence, if that sequence
    alone exceeds token_size).
    """

    def collate_fn(items):
        items = sorted(items, key=lambda x: -len(x["input_ids"]))
        batches = []
        batch = []
        batch_width = 0
        for item in items:
            item_width = len(item["input_ids"])
            if item_width == 0:
                # Items are sorted by length, so everything remaining is empty.
                break
            if batch_width * (len(batch) + 1) > token_size:
                # Adding another sequence (padded to batch_width) would exceed
                # the token budget, so close out the current subbatch.
                batches.append(make_padded_batch(batch))
                batch = []
                batch_width = 0
            if not batch:
                # The first (longest) sequence sets the padded width of the subbatch.
                batch_width = item_width
            batch.append(item)
        if len(batch):
            batches.append(make_padded_batch(batch))
        return batches

    return collate_fn
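
# Illustrative sketch of how the collator might be used (not part of the
# original module; "tds" and "model" are assumed to be defined elsewhere,
# e.g. the TokenizedDataset above and a Hugging Face causal LM).  Note that
# each DataLoader batch arrives as a *list* of padded subbatches, so the
# training loop nests one level deeper than usual.
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(
#         tds, batch_size=32, shuffle=True, collate_fn=length_collation(2048)
#     )
#     for batch_group in loader:
#         for subbatch in batch_group:
#             subbatch = dict_to_(subbatch, "cuda")
#             outputs = model(**subbatch)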


def make_padded_batch(items):
    """
    Pads sequences in a batch so they are all the same length as the longest.
    """
    max_len = max(len(d["input_ids"]) for d in items)
    if max_len == 0:
        return {k: torch.zeros((0, 0), dtype=torch.long) for k in items[0]}
    return {
        k: pad_sequence([d[k] for d in items if len(d["input_ids"])], batch_first=True)
        for k in items[0]
    }
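
# Illustrative: padding two unequal-length items by hand (values are examples).
#
#     a = {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 1])}
#     b = {"input_ids": torch.tensor([4, 5]), "attention_mask": torch.tensor([1, 1])}
#     padded = make_padded_batch([a, b])
#     padded["input_ids"].shape  # torch.Size([2, 3]); the shorter row is zero-padded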


def flatten_masked_batch(data, mask):
    """
    Flattens feature data, ignoring items that are masked out of attention.
    """
    flat_data = data.view(-1, data.size(-1))
    attended_tokens = mask.view(-1).nonzero()[:, 0]
    return flat_data[attended_tokens]
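
# Illustrative: keep only the feature vectors of attended (non-padding) tokens.
# "hidden" stands in for a model activation of shape (batch, seq_len, hidden_dim).
#
#     hidden = torch.randn(2, 3, 768)
#     mask = torch.tensor([[1, 1, 0], [1, 0, 0]])
#     flat = flatten_masked_batch(hidden, mask)  # shape (3, 768): one row per attended token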