#!/usr/bin/env python3
# coding=utf-8

from data.parser.to_mrp.abstract_parser import AbstractParser


class SequentialParser(AbstractParser):
    def parse(self, prediction):
        output = {}

        output["id"] = self.dataset.id_field.vocab.itos[prediction["id"].item()]
        output["nodes"] = self.create_nodes(prediction)
        output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=True, mode="anchors")
        output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="source anchors")
        output["nodes"] = self.create_anchors(prediction, output["nodes"], join_contiguous=True, at_least_one=False, mode="target anchors")
        output["edges"], output["nodes"] = self.create_targets_sources(output["nodes"])

        return output

    def create_targets_sources(self, nodes):
        edges, new_nodes = [], []
        for i, node in enumerate(nodes):
            # Original nodes keep their list positions as ids, so freshly
            # created nodes are numbered after the last original node.
            new_node_id = len(nodes) + len(new_nodes)

            # Split off a "Source" node carrying the source anchors, linked
            # to its parent node by an unlabelled edge.
            if len(node["source anchors"]) > 0:
                new_nodes.append({"id": new_node_id, "label": "Source", "anchors": node["source anchors"]})
                edges.append({"source": i, "target": new_node_id, "label": ""})
                new_node_id += 1
            del node["source anchors"]

            # Likewise split off a "Target" node for the target anchors.
            if len(node["target anchors"]) > 0:
                new_nodes.append({"id": new_node_id, "label": "Target", "anchors": node["target anchors"]})
                edges.append({"source": i, "target": new_node_id, "label": ""})
            del node["target anchors"]

        return edges, nodes + new_nodes
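

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original file). It assumes
# hand-built node dicts; in the real pipeline they come from create_nodes()
# and create_anchors() applied to a model prediction. It demonstrates how
# create_targets_sources() splits the "source anchors" / "target anchors"
# fields into separate Source/Target nodes joined by unlabelled edges.
if __name__ == "__main__":
    demo_nodes = [
        {
            "id": 0,
            "label": "give",
            "anchors": [{"from": 0, "to": 4}],
            "source anchors": [{"from": 5, "to": 8}],
            "target anchors": [{"from": 9, "to": 13}],
        }
    ]
    # create_targets_sources never touches self, so for this sketch it can be
    # called through the class with a placeholder receiver.
    edges, all_nodes = SequentialParser.create_targets_sources(None, demo_nodes)
    print(edges)
    # [{'source': 0, 'target': 1, 'label': ''},
    #  {'source': 0, 'target': 2, 'label': ''}]
    print(all_nodes)
    # node 0 with its anchor fields removed, plus
    # {'id': 1, 'label': 'Source', 'anchors': [{'from': 5, 'to': 8}]} and
    # {'id': 2, 'label': 'Target', 'anchors': [{'from': 9, 'to': 13}]}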