Gosse Minnema commited on
Commit
adbf6a0
·
1 Parent(s): 34cc06e

Allow user to set confidence threshold for frames + roles, default to 0.95

Browse files
Files changed (1) hide show
  1. sociolome/lome_webserver.py +5 -4
sociolome/lome_webserver.py CHANGED
@@ -47,7 +47,7 @@ def convert_to_seq_labels(sentence: List[str], structures: Dict[int, Dict[str, A
47
  labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
48
  return labels
49
 
50
- def make_prediction(sentence, spacy_model, predictor):
51
  spacy_doc = spacy_model(sentence)
52
  tokens = [t.text for t in spacy_doc]
53
  tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
@@ -59,7 +59,7 @@ def make_prediction(sentence, spacy_model, predictor):
59
  continue
60
  if frm.upper() == frm:
61
  continue
62
- if fr_proba.max() != 1.0:
63
  continue
64
 
65
  arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
@@ -70,7 +70,7 @@ def make_prediction(sentence, spacy_model, predictor):
70
  "roles": [
71
  {"boundary": bnd, "label": label}
72
  for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
73
- if label != "Target" and max(probas) == 1.0
74
  ]
75
  }
76
 
@@ -96,9 +96,10 @@ app = Flask(__name__)
96
  @app.route("/analyze")
97
  def analyze():
98
  text = request.args.get("text")
 
99
  analyses = []
100
  for sentence in text.split("\n"):
101
- analyses.append(make_prediction(sentence, nlp, predictor))
102
 
103
  return jsonify({
104
  "result": "OK",
 
47
  labels[i].append(f"{prefix}:{frame}:{role_label}@{struct_id:02}")
48
  return labels
49
 
50
+ def make_prediction(sentence, spacy_model, predictor, confidence_threshold):
51
  spacy_doc = spacy_model(sentence)
52
  tokens = [t.text for t in spacy_doc]
53
  tgt_spans, fr_labels, fr_probas = predictor.force_decode(tokens)
 
59
  continue
60
  if frm.upper() == frm:
61
  continue
62
+ if fr_proba.max() != confidence_threshold:
63
  continue
64
 
65
  arg_spans, arg_labels, label_probas = predictor.force_decode(tokens, parent_span=tgt, parent_label=frm)
 
70
  "roles": [
71
  {"boundary": bnd, "label": label}
72
  for bnd, label, probas in zip(arg_spans, arg_labels, label_probas)
73
+ if label != "Target" and max(probas) == confidence_threshold
74
  ]
75
  }
76
 
 
96
  @app.route("/analyze")
97
  def analyze():
98
  text = request.args.get("text")
99
+ confidence_threshold = float(request.args.get("confidence_threshold", 0.95))
100
  analyses = []
101
  for sentence in text.split("\n"):
102
+ analyses.append(make_prediction(sentence, nlp, predictor, confidence_threshold))
103
 
104
  return jsonify({
105
  "result": "OK",