|
|
|
|
|
|
|
import torch |
|
from data.field.mini_torchtext.field import RawField |
|
from data.field.mini_torchtext.vocab import Vocab |
|
from collections import Counter |
|
import types |
|
|
|
|
|
class EdgeLabelField(RawField):
    """Field that turns labelled edge lists into a dense multi-hot label tensor.

    Judging from how ``numericalize`` and ``pad`` index it, an example appears
    to be a list ``[n_rows, n_cols, [(row, col, label), ...]]`` -- TODO confirm
    against the dataset code. A ``label`` of ``None`` marks an edge whose label
    is missing; such edges are flagged in the returned mask.
    """

    def process(self, edges, device=None):
        """Numericalize ``edges`` and materialize them as tensors.

        Args:
            edges: raw example (see the class docstring).
            device: device for the returned tensors (``None`` = default device).

        Returns:
            ``(tensor, mask_tensor)`` as produced by :meth:`pad`.
        """
        edges, masks = self.numericalize(edges)
        tensor, mask_tensor = self.pad(edges, masks, device)
        return tensor, mask_tensor

    def pad(self, edges, masks, device):
        """Build the dense label tensor and the missing-label mask.

        Args:
            edges: numericalized example
                ``[n_rows, n_cols, [(row, col, label_idx), ...]]``.
            masks: companion structure
                ``[n_rows, n_cols, [(row, col, is_missing), ...]]``.
            device: device for the returned tensors (``None`` = default device).

        Returns:
            tensor: ``long`` tensor of shape ``(n_rows, n_cols, len(self.vocab))``
                holding a 1 at ``[row, col, label_idx]`` for every edge entry.
            mask_tensor: ``bool`` tensor of shape ``(n_rows, n_cols)``; a cell
                holds the flag of the *last* mask entry referring to it, and
                ``False`` when no entry refers to it.
        """
        n_rows, n_cols = edges[0], edges[1]
        n_labels = len(self.vocab)
        tensor = torch.zeros(n_rows, n_cols, n_labels, dtype=torch.long, device=device)
        mask_tensor = torch.zeros(n_rows, n_cols, dtype=torch.bool, device=device)

        # A single vectorized write replaces a per-edge Python loop, where every
        # scalar assignment was its own tensor op (a kernel launch on GPU).
        # Writing the constant 1 makes duplicate indices harmless.
        entries = list(edges[-1])
        if entries:
            index = torch.tensor(entries, dtype=torch.long, device=device)
            tensor[index[:, 0], index[:, 1], index[:, 2]] = 1

        # The label dimension lets one cell carry several labels, so a cell can
        # appear in several mask entries. index_put_ leaves the winner undefined
        # for duplicate indices, so collapse duplicates first, keeping the last
        # flag exactly as sequential assignment would. Cells are keyed by their
        # literal (row, col) pair, which assumes non-negative indices.
        flags = {(row, col): flag for row, col, flag in masks[-1]}
        if flags:
            cells = torch.tensor(list(flags), dtype=torch.long, device=device)
            values = torch.tensor(list(flags.values()), dtype=torch.bool, device=device)
            mask_tensor[cells[:, 0], cells[:, 1]] = values

        return tensor, mask_tensor

    def numericalize(self, arr):
        """Map edge labels to vocabulary indices.

        ``arr`` is walked recursively: lists are mapped element-wise, tuples
        ``(row, col, label, ...)`` become ``(row, col, f(label))``, and any
        other value (e.g. the size entries) is passed through unchanged.

        Returns:
            ``(arr, mask)`` with the same nesting as the input. In ``arr`` each
            label is replaced by its index in ``self.vocab.stoi`` (0 if absent);
            in ``mask`` each label is replaced by ``label is None``.
        """

        def multi_map(node, function):
            if isinstance(node, tuple):
                return (node[0], node[1], function(node[2]))
            if isinstance(node, list):
                return [multi_map(child, function) for child in node]
            return node

        mask = multi_map(arr, lambda label: label is None)
        # NOTE(review): labels absent from stoi (None is never counted by
        # build_vocab) fall back to index 0. Since build_vocab passes
        # specials=[], index 0 is presumably an ordinary label rather than an
        # <unk> slot -- confirm this silent mapping is intended.
        arr = multi_map(arr, lambda label: self.vocab.stoi.get(label, 0))
        return arr, mask

    def build_vocab(self, *args):
        """Build ``self.vocab`` from the labels found in ``args``.

        Each argument is either a ``torch.utils.data.Dataset``, in which case
        ``arg.get_examples(name)`` is collected for every field name bound to
        this field, or raw data that is used as-is. Labels are taken from the
        third element of every tuple reachable through nested lists and
        generators; ``None`` labels are skipped. The vocabulary is created with
        ``specials=[]``.
        """

        def iter_labels(node):
            # Yield the label slot of each (row, col, label) tuple; any other
            # node type (ints, strings, ...) contributes nothing.
            if isinstance(node, tuple):
                yield node[2]
            elif isinstance(node, (list, types.GeneratorType)):
                for child in node:
                    yield from iter_labels(child)

        sources = []
        for arg in args:
            if isinstance(arg, torch.utils.data.Dataset):
                sources.extend(
                    arg.get_examples(name)
                    for name, field in arg.fields.items()
                    if field is self
                )
            else:
                sources.append(arg)

        counter = Counter(
            label for label in iter_labels(sources) if label is not None
        )
        self.vocab = Vocab(counter, specials=[])
|
|