|
|
|
|
|
|
|
import torch |
|
from data.parser.json_parser import example_from_json |
|
|
|
|
|
class AbstractParser(torch.utils.data.Dataset):
    """Indexable dataset of examples paired with per-attribute fields.

    Attributes:
        examples: List of example objects, one per entry of the input
            mapping, ordered by the mapping's sorted keys.
        fields: Flat dict mapping an attribute name to the field object
            used to process that attribute (or ``None`` to skip it).
    """

    def __init__(self, fields, data, filter_pred=None):
        """Build the examples and normalise the field specification.

        Args:
            fields: Either a dict whose values are ``(name, field)`` pairs
                or lists of such pairs, or any iterable of ``(name, field)``
                pairs accepted by ``dict()``. A pair whose name is a tuple
                is expanded into one entry per name.
            data: Mapping whose values are passed, together with ``fields``,
                to ``example_from_json``. NOTE(review): the expected shape
                of each value is defined by that helper, which is not
                visible here.
            filter_pred: Optional predicate; when given, only examples for
                which it returns a truthy value are kept.
        """
        super(AbstractParser, self).__init__()

        # Iterate in sorted-key order so the example order is independent of
        # the insertion order of ``data``.
        examples = [example_from_json(d, fields) for _, d in sorted(data.items())]
        if filter_pred is not None:
            examples = [example for example in examples if filter_pred(example)]

        self.examples = examples
        self.fields = self._normalize_fields(fields)

    @staticmethod
    def _normalize_fields(fields):
        """Return a flat ``{name: field}`` dict built from ``fields``."""
        if isinstance(fields, dict):
            # Dict values are either one (name, field) pair or a list of them.
            pairs = []
            for entry in fields.values():
                if isinstance(entry, list):
                    pairs.extend(entry)
                else:
                    pairs.append(entry)
        else:
            pairs = fields

        normalized = dict(pairs)

        # A tuple key groups several names with a parallel sequence of fields;
        # split it into individual entries. Iterate over a snapshot because
        # the dict is mutated inside the loop.
        for names, group in list(normalized.items()):
            if isinstance(names, tuple):
                normalized.update(zip(names, group))
                del normalized[names]
        return normalized

    def __getitem__(self, i):
        """Return ``{name: processed value}`` for the example at index ``i``.

        Names whose field is ``None`` are left out of the result.
        """
        item = self.examples[i]
        return {
            name: field.process(getattr(item, name), device=None)
            for name, field in self.fields.items()
            if field is not None
        }

    def __len__(self):
        """Return the number of examples."""
        return len(self.examples)

    def get_examples(self, attr):
        """Yield attribute ``attr`` of every example.

        Yields nothing when ``attr`` is not a known field name.
        """
        if attr in self.fields:
            for example in self.examples:
                yield getattr(example, attr)
|
|