File size: 4,295 Bytes
bc6f57a
 
 
 
4467900
bc6f57a
bfcba2d
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5003662
bc6f57a
 
 
5003662
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
5003662
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4467900
 
 
 
 
 
 
 
 
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25fcabc
4467900
bc6f57a
 
 
 
 
 
 
 
 
 
 
 
 
 
4467900
bc6f57a
 
 
 
 
4467900
bc6f57a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import json
from collections import defaultdict
from typing import Dict, List, Optional, Union

from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer


def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's binary relations (gold, then predicted) as an HTML table.

    Each row holds the string form of the relation's head, the string form of its
    tail, and its label. The table is left-aligned and wrapped in a scrollable
    ``<div>`` (max height 360px).

    Args:
        document: Document whose ``binary_relations`` layer and that layer's
            ``predictions`` are rendered.
        **render_kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The HTML markup as a string.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    relation_layer = document.binary_relations
    # Gold relations first, followed by the predicted ones.
    for rel in [*relation_layer, *relation_layer.predictions]:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"


def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the document's labeled spans (gold and predicted) as displaCy-style HTML.

    Args:
        document: Document whose ``labeled_spans`` layer (plus its ``predictions``) is
            rendered; its ``binary_relations`` (plus predictions) are injected when
            ``inject_relations`` is True.
        inject_relations: If True, annotate every rendered entity element with an id,
            colors and its outgoing/incoming relations via ``inject_relation_data``.
        colors_hover: Optional mapping forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the default)
            means an empty dict; a fresh one is created per call.
        **render_kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The HTML markup, wrapped in a scrollable ``<div>`` (max height 360px).
    """
    # Avoid a shared mutable default: the dict is handed to EntityRenderer, and any
    # mutation there would otherwise persist across calls.
    if entity_options is None:
        entity_options = {}

    spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    # inject_relation_data pairs the n-th rendered entity element with the n-th span of
    # sorted_entities, so give the renderer the spans in exactly that order rather than
    # relying on the renderer to sort them itself.
    sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted_entities
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def labeled_span_to_id(span: LabeledSpan) -> str:
    """Build the HTML element id for a labeled span.

    The id has the form ``span-<start>-<end>-<label>``.
    """
    return "span-{}-{}-{}".format(span.start, span.end, span.label)


def labeled_span_from_id(span_id: str) -> LabeledSpan:
    """Reconstruct a LabeledSpan from an id of the form ``span-<start>-<end>-<label>``.

    This is the inverse of ``labeled_span_to_id``.

    Args:
        span_id: The element id to parse.

    Returns:
        A new LabeledSpan built from the parsed start, end and label.

    Raises:
        IndexError: If ``span_id`` has fewer than four ``-``-separated fields.
        ValueError: If the start or end field is not an integer.
    """
    # Split at most three times: the label is the last field and may itself contain
    # "-" (e.g. "own-claim"); an unbounded split would cut it down to its first segment.
    parts = span_id.split("-", 3)
    return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])


def inject_relation_data(
    html: str,
    sorted_entities: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate rendered entity elements with span ids, colors and relation data.

    The n-th element with CSS class ``entity`` in ``html`` is paired with
    ``sorted_entities[n]``, so the spans must be given in the same order in which the
    renderer emitted them. For each pair the following attributes are set:

    - ``id``: the span id from ``labeled_span_to_id``;
    - ``data-color-original``: the value after ``background:`` in the element's style;
    - ``data-color-<key>``: one per entry of ``additional_colors`` (dict values are
      JSON-encoded, other values are used as-is);
    - ``data-label``: the span label;
    - ``data-relation-tails``: JSON list of ``{"entity-id", "label"}`` for every
      relation whose head is this span;
    - ``data-relation-heads``: JSON list of ``{"entity-id", "label"}`` for every
      relation whose tail is this span.

    Args:
        html: HTML markup containing the entity elements.
        sorted_entities: Spans in the same order as the entity elements in ``html``.
        binary_relations: Relations whose heads/tails are indexed by span.
        additional_colors: Optional extra colors attached to every entity element.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If the text of an entity element does not match its paired span.
        IndexError: If ``html`` contains more entity elements than ``sorted_entities``.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations by endpoint: span -> [(other endpoint, relation label), ...]
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # The extra color attributes are identical for every entity, so encode them once.
    color_attributes = {}
    if additional_colors is not None:
        for key, color in additional_colors.items():
            color_attributes[f"data-color-{key}"] = (
                json.dumps(color) if isinstance(color, dict) else color
            )

    for idx, entity in enumerate(soup.find_all(class_="entity")):
        annotation = sorted_entities[idx]
        # Sanity check that the element and the span really correspond. `entity.next`
        # is the first node parsed inside the element — presumably the displayed span
        # text, with the label rendered in a later child (confirm against the template).
        if str(annotation) != entity.next:
            raise ValueError(f"Entity text mismatch: {annotation} != {entity.next}")

        entity["id"] = labeled_span_to_id(annotation)
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value
        entity["data-label"] = annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(annotation, [])
            ]
        )

    return str(soup)