# Spaces: Build error
from sftp import SpanPredictor | |
import spacy | |
import sys | |
import dataclasses | |
from typing import List, Optional, Dict, Any | |
# Both models are loaded once, at import time; analyze() below reuses these
# module-level objects for every sentence instead of reloading them per call.

# Span/frame predictor restored from a local model archive.
predictor = SpanPredictor.from_path("model.mod.tar.gz")

# spaCy pipeline used by make_prediction() for tokenisation and POS tags.
nlp = spacy.load(name="xx_sent_ud_sm")
@dataclasses.dataclass
class FrameAnnotation:
    """Per-token annotation of one sentence: surface tokens and their POS tags.

    The ``@dataclass`` decorator is required. Without it the
    ``dataclasses.field(...)`` declarations below are never turned into
    constructor parameters, so keyword construction and ``dataclasses.asdict``
    both fail.
    """

    # Surface form of each token, in sentence order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # Part-of-speech tag of each token; parallel to ``tokens``.
    pos: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Annotation in which every token may carry several frame/role labels.

    All list fields are expected to be parallel to ``tokens`` (one entry per token).
    """

    # Labels attached to each token (e.g. "T:<frame>@<id>", "B:<frame>:<role>@<id>").
    # An empty inner list means the token carries no label.
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # One optional string per token (presumably the lexical unit); None if absent.
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one ``"<token> <pos> <labels> <lu>"`` line per token.

        Multiple labels are joined with ``|``. An empty label list, or a falsy
        ``lu`` entry, is rendered as ``_``. Indexing (rather than ``zip``) is kept
        on purpose: if a parallel list is shorter than ``tokens``, this raises
        ``IndexError`` instead of silently truncating the output.
        """
        for i, tok in enumerate(self.tokens):
            labels = "|".join(self.frame_list[i]) or "_"
            lu = self.lu_list[i] or "_"
            yield f"{tok} {self.pos[i]} {labels} {lu}"
# reused from "combine_predictions.py" (cloned/lome/src/spanfinder/sociolome)
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Flatten frame structures into one list of string labels per token.

    Each structure is a dict with a ``"target"`` span, a ``"frame"`` name and a
    list of ``"roles"`` (each with a ``"boundary"`` span and a ``"label"``).
    Span ends are inclusive. Labels are appended in this order:

    * every target token gets ``"T:<frame>@<id>"``;
    * the first token of a role span gets ``"B:<frame>:<role>@<id>"``, and the
      remaining tokens of that span get ``"I:<frame>:<role>@<id>"``.

    ``<id>`` is the structure's key formatted with ``:02``. ``sentence`` is only
    used to size the output (one inner list per element).
    """
    per_token: List[List[str]] = [[] for _ in sentence]

    for sid, frame_struct in structures.items():
        target = frame_struct["target"]
        frame_name = frame_struct["frame"]

        for idx in range(target[0], target[1] + 1):
            per_token[idx].append(f"T:{frame_name}@{sid:02}")

        for role in frame_struct["roles"]:
            boundary = role["boundary"]
            role_name = role["label"]
            first, last = boundary[0], boundary[1]
            for idx in range(first, last + 1):
                # BIO-style tagging: only the span's first token is "B".
                tag = "I" if idx != first else "B"
                per_token[idx].append(f"{tag}:{frame_name}:{role_name}@{sid:02}")

    return per_token
def make_prediction(sentence, spacy_model, predictor):
    """Tokenise ``sentence`` and attach predicted frames and their roles.

    ``spacy_model`` is called on the raw sentence; its tokens' ``text`` and
    ``pos_`` become the annotation's tokens and POS tags.

    ``predictor.force_decode`` is called twice over:

    * once on the bare tokens, to obtain target spans and their frame labels;
    * once per target, with that span/label as parent, to obtain argument spans.

    Argument labels equal to ``"Target"`` are discarded. The third value returned
    by ``force_decode`` is ignored.

    Returns a ``MultiLabelAnnotation`` whose ``lu_list`` is all ``None``.
    """
    doc = spacy_model(sentence)
    words = [tok.text for tok in doc]
    targets, frames, _ = predictor.force_decode(words)

    # Number frames left to right by the start index of their target span.
    ordered = sorted(zip(targets, frames), key=lambda pair: pair[0][0])

    structures = {}
    for frame_idx, (target_span, frame_label) in enumerate(ordered):
        spans, labels, _ = predictor.force_decode(
            words, parent_span=target_span, parent_label=frame_label
        )
        roles = []
        for boundary, role_label in zip(spans, labels):
            if role_label == "Target":
                continue
            roles.append({"boundary": boundary, "label": role_label})
        structures[frame_idx] = {
            "target": target_span,
            "frame": frame_label,
            "roles": roles,
        }

    return MultiLabelAnnotation(
        tokens=words,
        pos=[tok.pos_ for tok in doc],
        frame_list=convert_to_seq_labels(words, structures),
        lu_list=[None] * len(words),
    )
def analyze(text):
    """Run ``make_prediction`` on every ``"\\n"``-separated line of ``text``.

    Uses the module-level ``nlp`` pipeline and ``predictor``. Lines are not
    filtered, so an empty line (e.g. from a trailing newline) still produces an
    entry.

    NOTE(review): splitting on ``"\\n"`` only means CRLF input keeps a trailing
    ``"\\r"`` on each line. Confirm this is acceptable for callers.

    Returns ``{"result": "OK", "analyses": [...]}``, where each analysis is the
    ``dataclasses.asdict`` form of one annotation.
    """
    per_line = [make_prediction(line, nlp, predictor) for line in text.split("\n")]
    return {
        "result": "OK",
        "analyses": list(map(dataclasses.asdict, per_line)),
    }