ltg
/

File size: 1,582 Bytes
c45d283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
# coding=utf-8

import torch
from data.parser.json_parser import example_from_json


class AbstractParser(torch.utils.data.Dataset):
    """Map-style dataset built from a mapping of raw (JSON-like) records.

    Each value of ``data`` is turned into an example object by
    ``example_from_json``. Examples are ordered by their key in ``data``.
    Indexing the dataset calls ``process`` on every non-``None`` field for the
    matching attribute of the example. The field objects are presumably
    torchtext-style ``Field`` instances; confirm against callers.
    """

    def __init__(self, fields, data, filter_pred=None):
        """Build the example list and the ``name -> field`` lookup table.

        Args:
            fields: Either a dict whose values are ``(name, field)`` pairs or
                lists of such pairs, or an iterable of ``(name, field)``
                pairs. It is passed to ``example_from_json`` unchanged and is
                flattened only afterwards. A ``name`` may itself be a tuple of
                names. Its ``field`` must then be an iterable of fields, and
                the pair is unpacked into one entry per name.
            data: Mapping whose values are the raw records. It is iterated in
                sorted order of its items, which is effectively by key because
                keys are unique.
            filter_pred: Optional predicate. When given, only examples for
                which it returns a truthy value are kept.
        """
        super().__init__()

        # Sort so that example order is deterministic and independent of the
        # insertion order of ``data``.
        self.examples = [example_from_json(d, fields) for _, d in sorted(data.items())]

        # Flatten a field dict into a single list of (name, field) pairs; each
        # dict value is either one pair or a list of pairs.
        if isinstance(fields, dict):
            fields, field_dict = [], fields
            for field in field_dict.values():
                if isinstance(field, list):
                    fields.extend(field)
                else:
                    fields.append(field)

        # ``self.examples`` is always a list at this point, so filtering
        # straight into a new list is all that is needed.
        if filter_pred is not None:
            self.examples = [example for example in self.examples if filter_pred(example)]

        self.fields = dict(fields)

        # Unpack field tuples: {(n1, n2): (f1, f2)} -> {n1: f1, n2: f2}.
        # Iterate over a snapshot because the dict is mutated in the loop.
        for name, field in list(self.fields.items()):
            if isinstance(name, tuple):
                self.fields.update(zip(name, field))
                del self.fields[name]

    def __getitem__(self, i):
        """Return ``{name: field.process(example.<name>, device=None)}`` for example ``i``.

        Fields registered as ``None`` are skipped. Raises ``IndexError`` for an
        out-of-range ``i``. Raises ``AttributeError`` if the example lacks an
        attribute named after a registered field.
        """
        example = self.examples[i]
        return {
            name: field.process(getattr(example, name), device=None)
            for name, field in self.fields.items()
            if field is not None
        }

    def __len__(self):
        """Return the number of (post-filter) examples."""
        return len(self.examples)

    def get_examples(self, attr):
        """Yield ``getattr(example, attr)`` for every example, in dataset order.

        Yields nothing when ``attr`` is not a registered field name. The
        membership check runs lazily, on the first iteration of the generator.
        """
        if attr in self.fields:
            for example in self.examples:
                yield getattr(example, attr)