File size: 4,499 Bytes
6680682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from typing import Any, Dict, List, Optional
import dataclasses
import glob
import os
import sys
import json

import spacy
from spacy.language import Language

from sftp import SpanPredictor


@dataclasses.dataclass
class FrameAnnotation:
    """Base per-sentence annotation: the tokens and their POS tags.

    Both lists are expected to be index-aligned (one POS tag per token);
    this is not enforced here. Each instance gets its own fresh lists.
    """

    # Surface strings of the sentence's tokens, in order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # Part-of-speech tag for each entry of ``tokens``.
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation where each token may carry several frame labels.

    Extends ``FrameAnnotation`` with, per token, a list of frame/role labels
    and an optional lexical-unit string. All per-token lists are indexed in
    parallel with ``tokens``.
    """

    # Per-token list of labels (e.g. as built by ``convert_to_seq_labels``).
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # Per-token lexical unit, or None when absent.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one space-separated text line per token.

        Columns are: token, POS tag, the token's labels joined with ``|``
        (``_`` when there are none), and the lexical unit (``_`` when it is
        None or empty).
        """
        for position, token in enumerate(self.tokens):
            tag = self.pos[position]
            joined_labels = "|".join(self.frame_list[position]) or "_"
            lexical_unit = self.lu_list[position] or "_"
            yield f"{token} {tag} {joined_labels} {lexical_unit}"


def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into one list of string labels per token.

    For every ``(struct_id, struct)`` in ``structures``:

    * each token in the inclusive ``struct["target"]`` span gets
      ``T:<frame>@<id>``;
    * for each role in ``struct["roles"]``, the first token of its inclusive
      ``"boundary"`` span gets ``B:<frame>:<label>@<id>`` and the remaining
      tokens get ``I:<frame>:<label>@<id>``.

    ``<id>`` is ``struct_id`` formatted with ``:02``. Within a token's list,
    labels appear in structure iteration order, target label before role
    labels.

    Returns:
        A list with one (possibly empty) label list per element of
        ``sentence``.
    """
    per_token = [list() for _token in sentence]

    for struct_id, struct in structures.items():
        target = struct["target"]
        frame = struct["frame"]

        first, last = target[0], target[1]
        for idx in range(first, last + 1):
            per_token[idx].append(f"T:{frame}@{struct_id:02}")

        for role in struct["roles"]:
            boundary = role["boundary"]
            role_name = role["label"]
            role_start = boundary[0]
            for idx in range(role_start, boundary[1] + 1):
                marker = "B" if idx == role_start else "I"
                per_token[idx].append(f"{marker}:{frame}:{role_name}@{struct_id:02}")

    return per_token


def predict_combined(
    spacy_model: Language,
    sentences: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> List[MultiLabelAnnotation]:
    """Run the target/frame/boundary/argument pipeline over raw sentences.

    Each sentence is stripped and tokenized with ``spacy_model``; the tokens
    are then passed through the four predictors (see
    ``_predict_frame_structures``) and the resulting structures are turned
    into per-token labels with ``convert_to_seq_labels``.

    Args:
        spacy_model: Loaded spaCy pipeline used for tokenization and POS tags.
        sentences: Raw sentences (e.g. lines of a file, newline included).
        tgt_predictor: Predicts target spans.
        frm_predictor: Assigns a frame label to a target span.
        bnd_predictor: Predicts argument spans given a target span and frame.
        arg_predictor: Assigns role labels to the argument spans.

    Returns:
        One ``MultiLabelAnnotation`` per input sentence, in input order;
        ``lu_list`` is all ``None``.
    """
    annotations_out = []

    for sent_idx, raw_sent in enumerate(sentences):
        sent = raw_sent.strip()

        print(f"Processing sent with idx={sent_idx}: {sent}")

        doc = spacy_model(sent)
        sent_tokens = [t.text for t in doc]

        # Blank lines tokenize to nothing. Don't hand an empty token sequence
        # to the predictors (not guaranteed to be supported); still emit an
        # empty annotation so the output stays aligned with the input.
        if sent_tokens:
            frame_structures = _predict_frame_structures(
                sent_tokens, tgt_predictor, frm_predictor, bnd_predictor, arg_predictor
            )
        else:
            frame_structures = {}

        annotations_out.append(MultiLabelAnnotation(
            tokens=sent_tokens,
            pos=[t.pos_ for t in doc],
            frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
            lu_list=[None for _ in sent_tokens],
        ))
    return annotations_out


def _predict_frame_structures(
    sent_tokens: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> Dict[int, Dict[str, Any]]:
    """Predict frame structures for one non-empty tokenized sentence.

    Returns a dict keyed by the target's index in ``tgt_predictor``'s output
    (skipped targets keep their index gap), each value being
    ``{"target": span, "frame": label, "roles": [{"boundary", "label"}, ...]}``.
    """
    # NOTE(review): the original only skipped "@@VIRTUAL_ROOT@@@" (three
    # trailing "@"), which looks like a typo for the AllenNLP-style
    # "@@VIRTUAL_ROOT@@". Both spellings are skipped here; confirm against the
    # sftp label vocabulary. A tuple is used so comparison is by ==, as before.
    virtual_root_labels = ("@@VIRTUAL_ROOT@@", "@@VIRTUAL_ROOT@@@")

    tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)

    frame_structures = {}
    for i, span in enumerate(tgt_spans):
        span = tuple(span)
        _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
        frame = fr_labels[0]
        if frame in virtual_root_labels:
            continue

        boundaries, _, _ = bnd_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame)
        _, arg_labels, _ = arg_predictor.force_decode(
            sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries
        )

        frame_structures[i] = {
            "target": span,
            "frame": frame,
            # Arguments labelled "Target" are dropped (presumably the
            # predicate span itself rather than a role filler).
            "roles": [
                {"boundary": bnd, "label": label}
                for bnd, label in zip(boundaries, arg_labels)
                if label != "Target"
            ],
        }
    return frame_structures


def main(input_folder, output_folder="../../data-out"):
    """Annotate every ``*.txt`` file in ``input_folder`` and write the results.

    Each input file is read as one sentence per line and annotated with
    ``predict_combined``. For an input ``<name>.txt`` two files are written
    to ``output_folder`` (created if missing):

    * ``<name>.combined_zs_ev.tc_bilstm.txt`` - one token per line, a blank
      line after each sentence;
    * ``<name>.combined_zs_ev.tc_bilstm.json`` - the annotations as a JSON
      list of dicts.

    Args:
        input_folder: Directory scanned (non-recursively) for ``*.txt`` files.
        output_folder: Destination directory; defaults to the previously
            hard-coded ``../../data-out`` (relative to the working directory).
    """
    print("Loading spaCy model ...")
    nlp = spacy.load("it_core_news_md")

    print("Loading predictors ...")
    # zs_predictor only predicts targets; ev_predictor handles frames,
    # boundaries and arguments (see the predict_combined call below).
    zs_predictor = SpanPredictor.from_path("/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", cuda_device=0)
    ev_predictor = SpanPredictor.from_path("/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", cuda_device=0)

    os.makedirs(output_folder, exist_ok=True)

    print("Reading input files ...")
    # glob order is arbitrary; sort for reproducible processing order.
    for file in sorted(glob.glob(os.path.join(input_folder, "*.txt"))):
        print(file)
        with open(file, encoding="utf-8") as f:
            sentences = list(f)

        annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor)

        out_name = os.path.splitext(os.path.basename(file))[0]
        out_base = os.path.join(output_folder, f"{out_name}.combined_zs_ev.tc_bilstm")
        _write_annotations(annotations, out_base)


def _write_annotations(annotations, out_base):
    """Write ``annotations`` to ``<out_base>.txt`` and ``<out_base>.json``."""
    # Text-mode files already translate "\n" to the platform line ending;
    # writing os.linesep would yield "\r\r\n" on Windows.
    with open(f"{out_base}.txt", "w", encoding="utf-8") as f_out:
        for ann in annotations:
            for line in ann.to_txt():
                f_out.write(line + "\n")
            f_out.write("\n")

    with open(f"{out_base}.json", "w", encoding="utf-8") as f_out:
        json.dump([dataclasses.asdict(ann) for ann in annotations], f_out)


if __name__ == "__main__":
    # Expects the input folder as the first command-line argument; fail with
    # a usage message instead of a bare IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} INPUT_FOLDER")
    main(sys.argv[1])