|
|
|
|
|
|
|
import torch |
|
from data.field.mini_torchtext.field import NestedField as TorchTextNestedField |
|
|
|
|
|
class NestedField(TorchTextNestedField):
    """Nested field that delegates padding and numericalization to ``nesting_field``.

    ``pad`` and ``numericalize`` forward the example to ``self.nesting_field``
    and, when ``self.include_lengths`` is set, additionally carry the outer
    length (``len(example)``) and the lengths reported by the nesting field.
    ``build_vocab`` builds the vocabulary on the nesting field and then makes
    both fields share a single vocabulary object.
    """

    def pad(self, example):
        """Pad ``example`` with the nesting field.

        Args:
            example: The example to pad; it is handed unchanged to
                ``self.nesting_field.pad``.

        Returns:
            The nesting field's padded result when ``self.include_lengths`` is
            false; otherwise the tuple
            ``(padded, sentence_length, word_lengths)`` where
            ``sentence_length`` is ``len(example)`` before padding and
            ``word_lengths`` is the second value returned by the nesting field.
        """
        # Keep the nesting field's flag in sync so that its pad() returns the
        # (padded, lengths) pair exactly when this field wants lengths.
        self.nesting_field.include_lengths = self.include_lengths
        if not self.include_lengths:
            return self.nesting_field.pad(example)

        sentence_length = len(example)
        example, word_lengths = self.nesting_field.pad(example)
        return example, sentence_length, word_lengths

    def numericalize(self, arr, device=None):
        """Numericalize a padded example with the nesting field.

        Args:
            arr: The output of ``pad``: the padded example, or the triple
                ``(padded, sentence_length, word_lengths)`` when
                ``self.include_lengths`` is set.
            device: Forwarded to the nesting field's ``numericalize`` and to
                ``torch.tensor`` for the length tensors.

        Returns:
            The nesting field's numericalized result, or the tuple
            ``(numericalized, sentence_length, word_lengths)`` with both
            lengths converted to tensors of dtype ``self.dtype`` when
            ``self.include_lengths`` is set.
        """
        if self.include_lengths:
            arr, sentence_length, word_lengths = arr

        # The nesting field must numericalize the bare data, so its flag is
        # switched off for the duration of the call. The reset lives in a
        # ``finally`` so a failure inside the nested call cannot leave the
        # flag stuck at False. It is reset to True unconditionally (as
        # before); pad() overwrites it with self.include_lengths before use.
        self.nesting_field.include_lengths = False
        try:
            numericalized = self.nesting_field.numericalize(arr, device=device)
        finally:
            self.nesting_field.include_lengths = True

        if not self.include_lengths:
            return numericalized

        sentence_length = torch.tensor(sentence_length, dtype=self.dtype, device=device)
        word_lengths = torch.tensor(word_lengths, dtype=self.dtype, device=device)
        return (numericalized, sentence_length, word_lengths)

    def build_vocab(self, *args, **kwargs):
        """Build one vocabulary shared by this field and its nesting field.

        Args:
            *args: Data sources. For a ``torch.utils.data.Dataset`` the
                examples of every dataset field that *is* this field are
                collected through ``arg.get_examples(name)``; any other
                argument is used as a source as-is.
            **kwargs: Forwarded to ``self.nesting_field.build_vocab``.
        """
        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        # Flatten one level: each element of a source becomes its own
        # positional argument for the nesting field.
        flattened = [item for source in sources for item in source]

        self.nesting_field.build_vocab(*flattened, **kwargs)
        # Deliberately skip TorchTextNestedField.build_vocab and call the next
        # implementation in the MRO with no sources, then fill the resulting
        # vocabulary from the nesting field's.
        super(TorchTextNestedField, self).build_vocab()
        self.vocab.extend(self.nesting_field.vocab)
        self.vocab.freqs = self.nesting_field.vocab.freqs.copy()
        # From here on both fields refer to the same vocabulary object.
        self.nesting_field.vocab = self.vocab
|
|