#!/usr/bin/env python3 # coding=utf-8 import torch from data.field.mini_torchtext.field import RawField class AnchorField(RawField): def process(self, batch, device=None): tensors, masks = self.pad(batch, device) return tensors, masks def pad(self, anchors, device): tensor = torch.zeros(anchors[0], anchors[1], dtype=torch.long, device=device) for anchor in anchors[-1]: tensor[anchor[0], anchor[1]] = 1 mask = tensor.sum(-1) == 0 return tensor, mask