# data/field/field.py
import torch
from data.field.mini_torchtext.field import Field as TorchTextField
from collections import Counter, OrderedDict


# Small change to vocab building so that it corresponds to our version of Dataset.
class Field(TorchTextField):
    def build_vocab(self, *args, **kwargs):
        # Collect token counts from all sources. Positional arguments may be
        # Datasets (from which every column bound to this field is pulled via
        # get_examples) or plain iterables of preprocessed examples.
        counter = Counter()
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
            else:
                sources.append(arg)

        for data in sources:
            for x in data:
                if not self.sequential:
                    x = [x]
                counter.update(x)

        # Special tokens go to the front of the vocabulary, deduplicated while
        # preserving their order.
        specials = list(
            OrderedDict.fromkeys(
                tok
                for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token] + kwargs.pop("specials", [])
                if tok is not None
            )
        )
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

    def process(self, example, device=None):
        # Numericalize a single example; when include_lengths is set, the
        # example is first paired with its length.
        if self.include_lengths:
            example = example, len(example)
        tensor = self.numericalize(example, device=device)
        return tensor

    def numericalize(self, ex, device=None):
        if self.include_lengths and not isinstance(ex, tuple):
            raise ValueError(
                "Field has include_lengths set to True, but input data is not a tuple of (data batch, batch lengths)."
            )

        if isinstance(ex, tuple):
            ex, lengths = ex
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)

        if self.use_vocab:
            # Map tokens to vocabulary indices.
            if self.sequential:
                ex = [self.vocab.stoi[x] for x in ex]
            else:
                ex = self.vocab.stoi[ex]

            if self.postprocessing is not None:
                ex = self.postprocessing(ex, self.vocab)
        else:
            # No vocabulary: convert raw values with the conversion function
            # registered for this field's dtype.
            numericalization_func = self.dtypes[self.dtype]
            if not self.sequential:
                ex = numericalization_func(ex) if isinstance(ex, str) else ex
            if self.postprocessing is not None:
                ex = self.postprocessing(ex, None)

        var = torch.tensor(ex, dtype=self.dtype, device=device)

        if self.sequential and not self.batch_first:
            var.t_()
        if self.sequential:
            var = var.contiguous()

        if self.include_lengths:
            return var, lengths
        return var
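

# Minimal usage sketch (not part of the original module). It assumes the
# vendored mini_torchtext Field keeps the classic torchtext constructor
# arguments (sequential, include_lengths, batch_first, default <unk>/<pad>
# specials); the sentences below are illustrative only.
if __name__ == "__main__":
    text_field = Field(sequential=True, include_lengths=True, batch_first=True)

    # Plain iterables of tokenized examples exercise the non-Dataset branch of
    # build_vocab above.
    sentences = [["the", "cat", "sat"], ["the", "dog", "ran"]]
    text_field.build_vocab(sentences)

    # process numericalizes a single example and, because include_lengths is
    # True, also returns its length.
    tensor, length = text_field.process(["the", "cat", "ran"])
    print(tensor, length)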