File size: 5,206 Bytes
bc6f57a
1681237
bc6f57a
 
 
86277c0
a8df5fb
bc6f57a
bfcba2d
bc6f57a
1681237
 
a8df5fb
 
 
 
 
 
 
 
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5003662
bc6f57a
 
 
5003662
bc6f57a
 
 
a8df5fb
bc6f57a
 
 
a8df5fb
 
 
 
 
 
 
 
 
bc6f57a
 
 
 
a8df5fb
 
 
 
5003662
bc6f57a
 
 
 
 
 
 
 
 
a8df5fb
bc6f57a
 
 
 
 
 
 
 
a8df5fb
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
a8df5fb
bc6f57a
 
a8df5fb
bc6f57a
 
 
 
 
 
 
a8df5fb
 
 
 
 
1681237
bc6f57a
 
 
4467900
bc6f57a
 
 
 
 
4467900
bc6f57a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union

from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer

# Module-level logger; inject_relation_data uses it to report entity text mismatches.
logger = logging.getLogger(__name__)

# adjusted from rendering_utils_displacy.TPL_ENT
# HTML template for one rendered entity. render_displacy passes it to EntityRenderer as the
# "template" option and gives every entity a "params" dict holding its "id", presumably
# used to fill the {id} placeholder; {bg}, {text} and {label} are presumably filled in by
# EntityRenderer itself (TODO confirm in rendering_utils_displacy).
# inject_relation_data depends on three details of this markup:
# - the class="entity" attribute, used to find the elements;
# - the id attribute, used to look up the labeled span;
# - the inline "background: ...;" declaration, from which it reads the original color.
TPL_ENT_WITH_ID = """
<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""


def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's gold and predicted binary relations as an HTML table.

    Each relation becomes one left-aligned row holding the string forms of its head
    and tail and its label. Gold relations come first, followed by predictions. The
    table is wrapped in a scrollable ``<div>`` capped at 360px height.

    Args:
        document: document whose ``binary_relations`` layer (including its
            ``predictions``) is rendered.
        **render_kwargs: accepted for interface compatibility; unused here.

    Returns:
        The HTML string.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    relation_layer = document.binary_relations
    all_relations = [*relation_layer, *relation_layer.predictions]
    for rel in all_relations:
        head_text, tail_text = str(rel.head), str(rel.tail)
        table.add_row([head_text, tail_text, rel.label])

    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + table.get_html_string(format=True)
        + "</div>"
    )


def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the document's labeled spans as displaCy-style entity HTML.

    Gold and predicted labeled spans are rendered together. Every entity is rendered
    with ``TPL_ENT_WITH_ID`` so that its ``<mark>`` element carries the span's ID, which
    ``inject_relation_data`` uses to attach relation data.

    Args:
        document: document whose ``labeled_spans`` layer (and, if ``inject_relations``
            is set, ``binary_relations`` layer), including predictions, is rendered.
        inject_relations: if True, annotate the entity elements with the gold and
            predicted binary relations via ``inject_relation_data``.
        colors_hover: optional mapping forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: options passed to ``EntityRenderer``. The caller's mapping is
            never modified; its ``"template"`` entry is always replaced by
            ``TPL_ENT_WITH_ID``. ``None`` (the default) means no extra options.
        **render_kwargs: accepted for interface compatibility; unused here.

    Returns:
        The rendered HTML wrapped in a scrollable ``<div>`` capped at 360px height.
    """
    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    spacy_doc = {
        "text": document.text,
        "ents": [
            {
                "start": labeled_span.start,
                "end": labeled_span.end,
                "label": labeled_span.label,
                # pass the ID as a parameter to the entity. The id is required to fetch the
                # entity annotations on hover and to inject the relation data.
                "params": {"id": labeled_span_to_id(labeled_span)},
            }
            for labeled_span in labeled_spans
        ],
        "title": None,
    }

    # Build a fresh options dict. This never mutates the caller's mapping, and the None
    # default avoids a shared mutable default argument.
    options = entity_options.copy() if entity_options is not None else {}
    # use the custom template with the entity ID
    options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            labeled_spans=labeled_spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def inject_relation_data(
    html: str,
    labeled_spans: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach color, label and relation metadata to the entity elements of *html*.

    Every element with CSS class ``entity`` gets these attributes:

    - ``data-color-original``: the value of its inline ``background:`` style;
    - ``data-color-<key>``: one per entry of *additional_colors* (dict values are
      JSON-encoded, other values are used as-is);
    - ``data-label``: the label of the labeled span whose ID equals the element's ``id``;
    - ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` objects for the relations in which that span
      is the head / the tail, respectively.

    Args:
        html: HTML containing the entity elements (e.g. rendered with ``TPL_ENT_WITH_ID``).
        labeled_spans: spans that the element ids refer to via ``labeled_span_to_id``.
        binary_relations: relations whose heads and tails are attached to the elements.
        additional_colors: optional extra colors, keyed by the attribute-name suffix.

    Returns:
        The modified HTML as a string.

    Raises:
        KeyError: if an entity element lacks a ``style`` or ``id`` attribute, or its
            ``id`` does not belong to any of *labeled_spans*.
        IndexError: if an entity element's style has no ``background:`` declaration.
    """
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Index relations by endpoint: span -> list of (other endpoint, relation label).
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}

    # The extra color attributes are the same for every entity, so serialize them once
    # instead of re-encoding them inside the per-entity loop.
    color_attributes = {}
    if additional_colors is not None:
        for key, color in additional_colors.items():
            color_attributes[f"data-color-{key}"] = (
                json.dumps(color) if isinstance(color, dict) else color
            )

    for entity in soup.find_all(class_="entity"):
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        # Assigned after data-color-original, so a key "original" still overrides it.
        for attribute_name, attribute_value in color_attributes.items():
            entity[attribute_name] = attribute_value
        entity_annotation = ann_id2annotation[entity["id"]]
        # Sanity check. Newlines are removed from the span text before comparing
        # (presumably because the minified HTML no longer contains them). Only the start
        # is checked because the element text has the label appended.
        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)
        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    # Return the modified HTML as a string
    return str(soup)