from collections import Counter, OrderedDict

import torch

from data.field.mini_torchtext.field import Field as TorchTextField


# Small change to vocab building so that it works with our version of Dataset.
class Field(TorchTextField):
    def build_vocab(self, *args, **kwargs):
        """Build the vocabulary from Dataset objects and/or raw iterables of examples."""
        counter = Counter()

        # Collect every data source that this field is responsible for.
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources += [arg.get_examples(name) for name, field in arg.fields.items() if field is self]
            else:
                sources.append(arg)

        # Count token occurrences; non-sequential values are wrapped so that the
        # whole value is counted as a single token.
        for data in sources:
            for x in data:
                if not self.sequential:
                    x = [x]
                counter.update(x)

        # Deduplicate special tokens while preserving their order.
        specials = list(
            OrderedDict.fromkeys(
                tok
                for tok in [self.unk_token, self.pad_token, self.init_token, self.eos_token]
                + kwargs.pop("specials", [])
                if tok is not None
            )
        )
        self.vocab = self.vocab_cls(counter, specials=specials, **kwargs)

    def process(self, example, device=None):
        # Attach the example length when lengths are requested, then numericalize.
        if self.include_lengths:
            example = example, len(example)
        tensor = self.numericalize(example, device=device)
        return tensor

    def numericalize(self, ex, device=None):
        """Turn a single (already preprocessed) example into a tensor."""
        if self.include_lengths and not isinstance(ex, tuple):
            raise ValueError(
                "Field has include_lengths set to True, but input data is not a tuple of "
                "(data batch, batch lengths)."
            )
        if isinstance(ex, tuple):
            ex, lengths = ex
            lengths = torch.tensor(lengths, dtype=self.dtype, device=device)

        if self.use_vocab:
            # Map tokens (or the single value) to vocabulary indices.
            if self.sequential:
                ex = [self.vocab.stoi[x] for x in ex]
            else:
                ex = self.vocab.stoi[ex]

            if self.postprocessing is not None:
                ex = self.postprocessing(ex, self.vocab)
        else:
            # Without a vocabulary, cast raw string values to the field's dtype.
            numericalization_func = self.dtypes[self.dtype]

            if not self.sequential:
                ex = numericalization_func(ex) if isinstance(ex, str) else ex
            if self.postprocessing is not None:
                ex = self.postprocessing(ex, None)

        var = torch.tensor(ex, dtype=self.dtype, device=device)

        if self.sequential and not self.batch_first:
            var.t_()
        if self.sequential:
            var = var.contiguous()

        if self.include_lengths:
            return var, lengths
        return var
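

# Minimal usage sketch (illustrative, not part of the module): it assumes the
# mini_torchtext Field constructor accepts the standard torchtext keyword
# arguments (sequential, include_lengths, batch_first) and that vocab_cls
# builds from a Counter plus specials; the toy data below is hypothetical.
if __name__ == "__main__":
    field = Field(sequential=True, include_lengths=True, batch_first=True)
    toy_examples = [["a", "b", "c"], ["a", "c"]]  # already-tokenized dummy examples
    field.build_vocab(toy_examples)               # counts tokens and adds specials
    tensor, length = field.process(toy_examples[0])
    print(tensor, length)                         # index tensor and its length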