fillmorle-app / sociolome /lome_wrapper.py
gossminn's picture
First version
6680682
raw
history blame
2.76 kB
from sftp import SpanPredictor
import spacy
import sys
import dataclasses
from typing import List, Optional, Dict, Any
# Module-level singletons: both models are loaded once, when this module is
# imported, and are then passed by analyze() into make_prediction().

# Span predictor restored from the "model.mod.tar.gz" archive (the path has no
# directory part, so it is presumably resolved against the working directory).
predictor = SpanPredictor.from_path("model.mod.tar.gz")
# spaCy pipeline "xx_sent_ud_sm"; make_prediction() reads token texts and
# `pos_` tags from the documents it produces.
nlp = spacy.load("xx_sent_ud_sm")
@dataclasses.dataclass
class FrameAnnotation:
    """Base record for one annotated sentence: its tokens and their POS tags.

    Both fields default to a fresh empty list per instance.
    """

    # Surface text of each token, in sentence order.
    tokens: List[str] = dataclasses.field(default_factory=list)
    # Part-of-speech tag strings; callers in this module supply one per token.
    pos: List[str] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
class MultiLabelAnnotation(FrameAnnotation):
    """Sentence annotation where each token can carry several labels at once."""

    # One list of label strings per token.
    frame_list: List[List[str]] = dataclasses.field(default_factory=list)
    # One optional string per token (presumably the lexical unit; TODO confirm).
    lu_list: List[Optional[str]] = dataclasses.field(default_factory=list)

    def to_txt(self):
        """Yield one space-separated text line per token.

        The columns are: token, POS tag, the token's labels joined with "|",
        and the token's `lu_list` entry. An empty label list and a falsy
        `lu_list` entry are both written as "_".
        """
        for idx, token in enumerate(self.tokens):
            # Columns are looked up left to right, in output order.
            pos_col = self.pos[idx]
            joined_labels = "|".join(self.frame_list[idx])
            frame_col = joined_labels if joined_labels else "_"
            lu_entry = self.lu_list[idx]
            lu_col = lu_entry if lu_entry else "_"
            yield f"{token} {pos_col} {frame_col} {lu_col}"
# reused from "combine_predictions.py" (cloned/lome/src/spanfinder/sociolome)
def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, Any]]) -> List[List[str]]:
    """Project frame structures onto the sentence as per-token label lists.

    Args:
        sentence: The tokens of the sentence; only its length is used.
        structures: Maps a structure id to a dict with keys ``"target"``
            (inclusive ``[start, end]`` token span), ``"frame"`` (frame name)
            and ``"roles"`` (list of dicts with an inclusive ``"boundary"``
            span and a ``"label"``).

    Returns:
        One list per token. Target tokens get ``T:<frame>@<id>``; role tokens
        get ``B:<frame>:<role>@<id>`` on the first token of the role span and
        ``I:<frame>:<role>@<id>`` on the rest. Ids are zero-padded to width 2.

    Raises:
        IndexError: If a span reaches past the end of ``sentence``.
    """
    per_token: List[List[str]] = [[] for _ in sentence]

    for struct_id, struct in structures.items():
        target = struct["target"]
        frame = struct["frame"]

        # Span ends are inclusive, hence the +1.
        for pos in range(target[0], target[1] + 1):
            per_token[pos].append(f"T:{frame}@{struct_id:02}")

        for role in struct["roles"]:
            boundary = role["boundary"]
            role_label = role["label"]
            # The first position of the span opens the role ("B"); all
            # following positions continue it ("I").
            for offset, pos in enumerate(range(boundary[0], boundary[1] + 1)):
                prefix = "I" if offset else "B"
                per_token[pos].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")

    return per_token
def make_prediction(sentence, spacy_model, predictor):
    """Annotate one sentence with frame targets and their roles.

    Args:
        sentence: Raw sentence text.
        spacy_model: Callable that turns the text into an iterable of tokens
            exposing ``.text`` and ``.pos_``.
        predictor: Object whose ``force_decode`` returns a
            ``(spans, labels, _)`` triple; the third element is ignored here.

    Returns:
        A ``MultiLabelAnnotation`` holding the tokens, their POS tags, the
        per-token label lists from ``convert_to_seq_labels``, and an
        all-``None`` ``lu_list``.
    """
    doc = spacy_model(sentence)
    tokens = [tok.text for tok in doc]

    # First pass, without a parent: the spans and labels it returns are used
    # as frame targets and frame names.
    tgt_spans, fr_labels, _ = predictor.force_decode(tokens)
    # Number the structures by the start position of their target span.
    ordered_targets = sorted(zip(tgt_spans, fr_labels), key=lambda pair: pair[0][0])

    frame_structures = {}
    for idx, (tgt, frm) in enumerate(ordered_targets):
        # Second pass, conditioned on one target/frame: decode its arguments.
        arg_spans, arg_labels, _ = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)

        roles = []
        for bnd, label in zip(arg_spans, arg_labels):
            # The "Target" label is skipped; the target is stored separately.
            if label == "Target":
                continue
            roles.append({"boundary": bnd, "label": label})

        frame_structures[idx] = {"target": tgt, "frame": frm, "roles": roles}

    return MultiLabelAnnotation(
        tokens=tokens,
        pos=[tok.pos_ for tok in doc],
        frame_list=convert_to_seq_labels(tokens, frame_structures),
        lu_list=[None] * len(tokens),
    )
def analyze(text):
    """Run frame prediction on every line of ``text``.

    The text is split on "\\n" and each piece, including any empty one, is
    passed as a sentence to ``make_prediction`` with the module-level ``nlp``
    and ``predictor``.

    Returns:
        A dict with ``"result"`` set to ``"OK"`` and ``"analyses"`` holding
        one plain-dict annotation (via ``dataclasses.asdict``) per line, in
        input order.
    """
    annotations = [make_prediction(line, nlp, predictor) for line in text.split("\n")]
    return {
        "result": "OK",
        "analyses": list(map(dataclasses.asdict, annotations)),
    }