Gosse Minnema
commited on
Commit
·
adbf6a0
1
Parent(s):
34cc06e
Allow user to set confidence threshold for frames + roles, default to 0.95
Browse files
sociolome/lome_webserver.py
CHANGED
@@ -47,7 +47,7 @@ def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, A
|
|
47 |
labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
|
48 |
return labels
|
49 |
|
50 |
-
def make_prediction(sentence, spacy_model, predictor):
|
51 |
spacy_doc = spacy_model(sentence)
|
52 |
tokens = [t.text for t in spacy_doc]
|
53 |
tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
|
@@ -59,7 +59,7 @@ def make_prediction(sentence, spacy_model, predictor):
|
|
59 |
continue
|
60 |
if frm.upper() == frm:
|
61 |
continue
|
62 |
-
if fr_proba.max() !=
|
63 |
continue
|
64 |
|
65 |
arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
|
@@ -70,7 +70,7 @@ def make_prediction(sentence, spacy_model, predictor):
|
|
70 |
"roles": [
|
71 |
{"boundary": bnd, "label": label}
|
72 |
for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
|
73 |
-
if label != "Target" and max(probas) ==
|
74 |
]
|
75 |
}
|
76 |
|
@@ -96,9 +96,10 @@ app = Flask(__name__)
|
|
96 |
@app.route("/analyze")
|
97 |
def analyze():
|
98 |
text = request.args.get("text")
|
|
|
99 |
analyses = []
|
100 |
for sentence in text.split("\n"):
|
101 |
-
analyses.append(make_prediction(sentence, nlp, predictor))
|
102 |
|
103 |
return jsonify({
|
104 |
"result": "OK",
|
|
|
47 |
labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
|
48 |
return labels
|
49 |
|
50 |
+
def make_prediction(sentence, spacy_model, predictor, confidence_threshold):
|
51 |
spacy_doc = spacy_model(sentence)
|
52 |
tokens = [t.text for t in spacy_doc]
|
53 |
tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
|
|
|
59 |
continue
|
60 |
if frm.upper() == frm:
|
61 |
continue
|
62 |
+
if fr_proba.max() != confidence_threshold:
|
63 |
continue
|
64 |
|
65 |
arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
|
|
|
70 |
"roles": [
|
71 |
{"boundary": bnd, "label": label}
|
72 |
for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
|
73 |
+
if label != "Target" and max(probas) == confidence_threshold
|
74 |
]
|
75 |
}
|
76 |
|
|
|
96 |
@app.route("/analyze")
|
97 |
def analyze():
|
98 |
text = request.args.get("text")
|
99 |
+
confidence_threshold = float(request.args.get("confidence_threshold", 0.95))
|
100 |
analyses = []
|
101 |
for sentence in text.split("\n"):
|
102 |
+
analyses.append(make_prediction(sentence, nlp, predictor, confidence_threshold))
|
103 |
|
104 |
return jsonify({
|
105 |
"result": "OK",
|