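"""Frame-semantic tagging of plain-text Italian input.

For every *.txt file in the given input folder, runs a four-stage sftp
SpanPredictor pipeline (target identification, frame classification, argument
boundary detection, argument labeling) on each sentence, using spaCy's
it_core_news_md model for tokenization and POS tagging, and writes the
annotations to ../../data-out/ as CoNLL-style text and as JSON.
"""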
import dataclasses
import glob
import json
import os
import sys
from typing import Any, Dict, List, Optional

import spacy
from spacy.language import Language

from sftp import SpanPredictor

@dataclasses.dataclass
class FrameAnnotation:
    """Per-sentence token annotation: surface tokens and their POS tags."""

    tokens: List[str] = dataclasses.field(default_factory=list)
    pos: List[str] = dataclasses.field(default_factory=list)

@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Extends FrameAnnotation with per-token frame/role labels and lexical-unit labels."""

    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        for i, tok in enumerate(self.tokens):
            yield f"{tok} {self.pos[i]} {'|'.join(self.frame_list[i]) or '_'} {self.lu_list[i] or '_'}"
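
# Example of one to_txt() line (the token values and the "Motion" frame are illustrative,
# not taken from the original script):
#   corre VERB T:Motion@00 _
# i.e. token, POS tag, "|"-joined frame/role labels (or "_"), lexical unit (or "_").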

def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Convert predicted frame structures into per-token label lists.

    Target tokens get "T:<frame>@<struct_id>"; role tokens get BIO-style
    "B:<frame>:<role>@<struct_id>" / "I:<frame>:<role>@<struct_id>" labels.
    """
    labels: List[List[str]] = [[] for _ in sentence]
    for struct_id, struct in structures.items():
        tgt_span = struct["target"]
        frame = struct["frame"]
        for i in range(tgt_span[0], tgt_span[1] + 1):
            labels[i].append(f"T:{frame}@{struct_id:02}")
        for role in struct["roles"]:
            role_span = role["boundary"]
            role_label = role["label"]
            for i in range(role_span[0], role_span[1] + 1):
                prefix = "B" if i == role_span[0] else "I"
                labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
    return labels
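
# Illustrative sketch of the labelling scheme (the sentence, frame, and role below
# are assumed example values, not outputs of the original script):
#   convert_to_seq_labels(
#       ["Maria", "corre", "nel", "parco"],
#       {0: {"target": (1, 1), "frame": "Motion",
#            "roles": [{"boundary": (2, 3), "label": "Place"}]}},
#   )
#   -> [[], ["T:Motion@00"], ["B:Motion:Place@00"], ["I:Motion:Place@00"]]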

def predict_combined(
    spacy_model: Language,
    sentences: List[str],
    tgt_predictor: SpanPredictor,
    frm_predictor: SpanPredictor,
    bnd_predictor: SpanPredictor,
    arg_predictor: SpanPredictor,
) -> List[MultiLabelAnnotation]:
    """Run target identification, frame classification, boundary detection, and argument labeling."""
    annotations_out = []
    for sent_idx, sent in enumerate(sentences):
        sent = sent.strip()
        print(f"Processing sent with idx={sent_idx}: {sent}")
        doc = spacy_model(sent)
        sent_tokens = [t.text for t in doc]
        # Stage 1: identify frame-evoking target spans.
        tgt_spans, _, _ = tgt_predictor.force_decode(sent_tokens)
        frame_structures = {}
        for i, span in enumerate(tgt_spans):
            span = tuple(span)

            # Stage 2: classify the frame evoked by this target span.
            _, fr_labels, _ = frm_predictor.force_decode(sent_tokens, child_spans=[span])
            frame = fr_labels[0]
            # Skip the span finder's virtual-root pseudo-label.
            if frame == "@@VIRTUAL_ROOT@@":
                continue
            # Stage 3: predict argument boundaries for this frame instance.
            boundaries, _, _ = bnd_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame)
            # Stage 4: label each predicted argument span with its frame element.
            _, arg_labels, _ = arg_predictor.force_decode(sent_tokens, parent_span=span, parent_label=frame, child_spans=boundaries)
            frame_structures[i] = {
                "target": span,
                "frame": frame,
                "roles": [
                    # Keep only role spans; drop spans labelled as the target itself.
                    {"boundary": bnd, "label": label}
                    for bnd, label in zip(boundaries, arg_labels)
                    if label != "Target"
                ],
            }
        annotations_out.append(MultiLabelAnnotation(
            tokens=sent_tokens,
            pos=[t.pos_ for t in doc],
            frame_list=convert_to_seq_labels(sent_tokens, frame_structures),
            lu_list=[None for _ in sent_tokens],
        ))
    return annotations_out

def main(input_folder):
    print("Loading spaCy model ...")
    nlp = spacy.load("it_core_news_md")

    print("Loading predictors ...")
    zs_predictor = SpanPredictor.from_path("/data/p289731/cloned/lome-models/models/spanfinder/model.mod.tar.gz", cuda_device=0)
    ev_predictor = SpanPredictor.from_path("/scratch/p289731/lome-training-files/train-evalita-plus-fn-vanilla/model.tar.gz", cuda_device=0)
print("Reading input files ...") | |
for file in glob.glob(os.path.join(input_folder, "*.txt")): | |
print(file) | |
with open(file, encoding="utf-8") as f: | |
sentences = list(f) | |
annotations = predict_combined(nlp, sentences, zs_predictor, ev_predictor, ev_predictor, ev_predictor) | |
out_name = os.path.splitext(os.path.basename(file))[0] | |
with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.txt", "w", encoding="utf-8") as f_out: | |
for ann in annotations: | |
for line in ann.to_txt(): | |
f_out.write(line + os.linesep) | |
f_out.write(os.linesep) | |
with open(f"../../data-out/{out_name}.combined_zs_ev.tc_bilstm.json", "w", encoding="utf-8") as f_out: | |
json.dump([dataclasses.asdict(ann) for ann in annotations], f_out) | |
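
# Expected invocation (the script name is a placeholder; the folder must contain
# the *.txt files to annotate):
#   python <this_script>.py <input_folder>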

if __name__ == "__main__":
    main(sys.argv[1])