#!/usr/bin/env python3
# coding=utf-8
import torch

from data.parser.json_parser import example_from_json


class AbstractParser(torch.utils.data.Dataset):
    """Dataset that builds examples from JSON records and processes them with field objects."""

    def __init__(self, fields, data, filter_pred=None):
        super(AbstractParser, self).__init__()
        # Parse one example per record, iterating the records in key order.
        self.examples = [example_from_json(d, fields) for _, d in sorted(data.items())]
        # If `fields` is a dict, flatten its values into a list of (name, field) pairs.
        if isinstance(fields, dict):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)
        # Optionally drop examples that fail the filter predicate.
        if filter_pred is not None:
            make_list = isinstance(self.examples, list)
            self.examples = filter(filter_pred, self.examples)
            if make_list:
                self.examples = list(self.examples)
        self.fields = dict(fields)
        # Unpack field tuples: a key of k names mapped to k fields becomes
        # k separate entries, one per name.
        for n, f in list(self.fields.items()):
            if isinstance(n, tuple):
                self.fields.update(zip(n, f))
                del self.fields[n]

    def __getitem__(self, i):
        # Process every named attribute of the example with its corresponding field.
        item = self.examples[i]
        processed_item = {}
        for (name, field) in self.fields.items():
            if field is not None:
                processed_item[name] = field.process(getattr(item, name), device=None)
        return processed_item

    def __len__(self):
        return len(self.examples)

    def get_examples(self, attr):
        # Yield the raw (unprocessed) value of `attr` from every example.
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)
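
# Usage sketch (illustrative only): the record layout expected by
# example_from_json and the concrete field classes are defined elsewhere in
# the repository, so the names below (TextField, records) are hypothetical
# placeholders, shown only to indicate how the class is meant to be driven.
#
#     fields = {"src": ("src", TextField()), "tgt": ("tgt", TextField())}
#     records = {0: {"src": "...", "tgt": "..."}, 1: {"src": "...", "tgt": "..."}}
#     dataset = AbstractParser(fields, records,
#                              filter_pred=lambda ex: len(ex.src) > 0)
#     item = dataset[0]                                # dict of field-processed values
#     raw_sources = list(dataset.get_examples("src"))  # raw, unprocessed attributes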