ner_evaluation_metrics / predefined_example.py
wadood's picture
added default predictions in predefined examples
8f56f21
from dataclasses import dataclass, field
from span_dataclass_converters import get_ner_spans_from_annotations
@dataclass
class PredefinedExample:
    """A ready-made NER example: raw text, ground truth and sample predictions.

    Attributes:
        text: The raw input text that the spans index into.
        gt_labels: Ground-truth annotations keyed by tag name; each value is
            a list of annotation dicts.
        default_predictions: Pre-built prediction sets, each one a list of
            span dicts. Defaults to a fresh empty list per instance.
    """

    text: str
    gt_labels: dict
    default_predictions: list = field(default_factory=list)
    # ``gt_spans`` and ``predictions`` are not stored fields; they are
    # recomputed on every access by the properties below.

    @property
    def gt_spans(self):
        """Return the ground-truth spans ordered by their ``"start"`` value.

        The spans are produced by ``get_ner_spans_from_annotations`` from
        ``gt_labels``; a new list is built on each access.
        """
        spans = list(get_ner_spans_from_annotations(self.gt_labels))
        # list.sort is stable, so spans sharing a start keep their order.
        spans.sort(key=lambda item: item["start"])
        return spans

    @property
    def predictions(self):
        """Return all prediction sets, with the ground-truth spans first.

        The first element is ``gt_spans`` itself (i.e. a perfect prediction),
        followed by every entry of ``default_predictions``.
        """
        ground_truth = self.gt_spans
        return [ground_truth] + self.default_predictions

    @property
    def tags(self):
        """Return the tag names, i.e. the keys of ``gt_labels``, as a list."""
        return list(self.gt_labels)
# Single-sentence example with one Disease and one Drug entity.
small_example = PredefinedExample(
    text="The patient was diagnosed with bronchitis and was prescribed a mucolytic",
    # Keyed by entity type; within each annotation, "label" carries the
    # annotated surface text and start/end are character offsets into `text`.
    gt_labels={
        "Disease": [{"start": 31, "end": 41, "label": "bronchitis"}],
        "Drug": [{"start": 63, "end": 72, "label": "mucolytic"}],
    },
    default_predictions=[
        # Correct entity types, but each span starts one word too early.
        [
            {
                "start": 26,
                "end": 41,
                "label": "Disease",
                "span_text": "with bronchitis",
            },
            {"start": 61, "end": 72, "label": "Drug", "span_text": "a mucolytic"},
        ],
        # Exact boundaries, but "bronchitis" is tagged Drug instead of Disease.
        [
            {"start": 31, "end": 41, "label": "Drug", "span_text": "bronchitis"},
            {"start": 63, "end": 72, "label": "Drug", "span_text": "mucolytic"},
        ],
        # A single Disease span running from "bronchitis" through "mucolytic".
        [
            {
                "start": 31,
                "end": 72,
                "label": "Disease",
                "span_text": "bronchitis and was prescribed a mucolytic",
            },
        ],
    ],
)
# Multi-sentence example with Disease, Drug and Symptoms entities.
big_example = PredefinedExample(
    # Adjacent string literals are concatenated into one text; every fragment
    # except the last ends with a space so words do not run together.
    text=(
        "The patient was experiencing stomach pain and flu like symptoms for 3 days. "
        "Upon investigation, the chest xray revealed acute bronchitis disease. "
        "The patient was asked to take rest for a week and was prescribed a mucolytic "
        "along with paracetamol for body pains."
    ),
    # Keyed by entity type; within each annotation, "label" carries the
    # annotated surface text and start/end are character offsets into `text`.
    gt_labels={
        "Disease": [{"start": 120, "end": 144, "label": "acute bronchitis disease"}],
        "Drug": [
            {"start": 213, "end": 222, "label": "mucolytic"},
            {"start": 234, "end": 245, "label": "paracetamol"},
        ],
        "Symptoms": [
            {"start": 29, "end": 41, "label": "stomach pain"},
            {"start": 46, "end": 63, "label": "flu like symptoms"},
        ],
    },
    default_predictions=[
        # Correct entity types; "flu" and "acute bronchitis" are shorter than
        # the corresponding ground-truth spans, the rest match exactly.
        [
            {"start": 29, "end": 41, "label": "Symptoms", "span_text": "stomach pain"},
            {"start": 46, "end": 49, "label": "Symptoms", "span_text": "flu"},
            {
                "start": 120,
                "end": 136,
                "label": "Disease",
                "span_text": "acute bronchitis",
            },
            {"start": 213, "end": 222, "label": "Drug", "span_text": "mucolytic"},
            {"start": 234, "end": 245, "label": "Drug", "span_text": "paracetamol"},
        ],
        # Same spans as the first set, except "flu" is tagged Disease and an
        # extra "body pains" Symptoms span (absent from gt_labels) is added.
        [
            {"start": 29, "end": 41, "label": "Symptoms", "span_text": "stomach pain"},
            {"start": 46, "end": 49, "label": "Disease", "span_text": "flu"},
            {
                "start": 120,
                "end": 136,
                "label": "Disease",
                "span_text": "acute bronchitis",
            },
            {"start": 213, "end": 222, "label": "Drug", "span_text": "mucolytic"},
            {"start": 234, "end": 245, "label": "Drug", "span_text": "paracetamol"},
            {"start": 250, "end": 260, "label": "Symptoms", "span_text": "body pains"},
        ],
    ],
)
# All predefined examples, in order: the short one first, then the long one.
EXAMPLES = [
    small_example,
    big_example,
]