File size: 5,206 Bytes
bc6f57a 1681237 bc6f57a 86277c0 a8df5fb bc6f57a bfcba2d bc6f57a 1681237 a8df5fb bc6f57a 5003662 bc6f57a 5003662 bc6f57a a8df5fb bc6f57a a8df5fb bc6f57a a8df5fb 5003662 bc6f57a a8df5fb bc6f57a a8df5fb bc6f57a a8df5fb bc6f57a a8df5fb bc6f57a a8df5fb 1681237 bc6f57a 4467900 bc6f57a 4467900 bc6f57a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union
from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer
# Module-level logger; used below to warn when a rendered entity's text does not
# match the text of its annotation.
logger = logging.getLogger(__name__)
# Entity markup template, adjusted from rendering_utils_displacy.TPL_ENT.
# Placeholders: {id} (element id of the <mark>), {bg} (background color),
# {text} (entity text) and {label} (entity label). render_displacy installs it
# as the renderer's "template" option and passes the id via each ent's "params".
# NOTE(review): presumably the only change vs. TPL_ENT is the added id
# attribute - confirm against rendering_utils_displacy.
TPL_ENT_WITH_ID = """
<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the binary relations of ``document`` as an HTML table.

    The table has one row per relation (gold relations first, then predicted
    ones) with the columns ``head``, ``tail`` and ``relation``. Head and tail
    are shown via their ``str()`` representation, the relation via its label.

    Args:
        document: Document whose ``binary_relations`` layer (including its
            ``predictions``) is rendered.
        **render_kwargs: Accepted for signature compatibility with the other
            render functions; not used here.

    Returns:
        The table HTML wrapped in a scrollable ``<div>``.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    # Gold annotations first, predictions afterwards.
    all_relations = [*document.binary_relations, *document.binary_relations.predictions]
    for rel in all_relations:
        row = [str(rel.head), str(rel.tail), rel.label]
        table.add_row(row)

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
):
    """Render the labeled spans of ``document`` as displaCy-style entity HTML.

    Gold and predicted labeled spans are rendered as ``<mark>`` elements that
    carry the span ID as their ``id`` attribute. Optionally, relation data is
    attached to these elements as ``data-*`` attributes.

    Args:
        document: Document providing ``text``, ``labeled_spans`` and
            ``binary_relations`` (each including its ``predictions``).
        inject_relations: If true, post-process the HTML with
            ``inject_relation_data`` to attach relation information.
        colors_hover: Forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options for the ``EntityRenderer``. The mapping is
            copied before use; its ``"template"`` entry is always replaced by
            ``TPL_ENT_WITH_ID``. Defaults to no options.
        **render_kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The rendered HTML wrapped in a scrollable ``<div>``.
    """
    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)

    ents = [
        {
            "start": labeled_span.start,
            "end": labeled_span.end,
            "label": labeled_span.label,
            # pass the ID as a parameter to the entity. The id is required to fetch the
            # entity annotations on hover and to inject the relation data.
            "params": {"id": labeled_span_to_id(labeled_span)},
        }
        for labeled_span in labeled_spans
    ]
    # Gold spans followed by predictions are not necessarily in text order, but
    # displaCy-style entity rendering walks the text left to right and expects
    # manually supplied ents ordered by start position. Sorting is a no-op if the
    # renderer already sorts its input.
    ents.sort(key=lambda ent: (ent["start"], ent["end"]))

    spacy_doc = {
        "text": document.text,
        "ents": ents,
        "title": None,
    }

    # copy to avoid modifying the caller's options (None means "no options")
    entity_options = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    entity_options["template"] = TPL_ENT_WITH_ID
    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            labeled_spans=labeled_spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
def inject_relation_data(
    html: str,
    labeled_spans: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach color, label and relation metadata to the entity elements in ``html``.

    Every element with class ``entity`` is looked up by its ``id`` attribute
    among ``labeled_spans`` and receives these attributes:

    * ``data-color-original``: the background color taken from its ``style``.
    * ``data-color-<key>``: one per entry of ``additional_colors`` (dict values
      are JSON-encoded, other values are used as they are).
    * ``data-label``: the label of the span.
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the span
      is the head / the tail, respectively.

    Args:
        html: HTML containing the rendered entities.
        labeled_spans: Spans that were rendered; their IDs (as produced by
            ``labeled_span_to_id``) must cover the ``id`` of every entity
            element.
        binary_relations: Relations between the spans.
        additional_colors: Optional extra colors, keyed by name.

    Returns:
        The modified HTML as a string.

    Raises:
        KeyError: If an entity element has no ``id`` attribute or its ``id``
            does not belong to any of ``labeled_spans``.
        IndexError: If an entity element's ``style`` has no ``background:``.
    """
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Index the relations by both arguments so each entity can list its partners.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}

    entities = soup.find_all(class_="entity")

    # The additional color attributes are identical for all entities, so
    # serialize them once instead of once per entity. Nothing is computed
    # when there is no entity to annotate.
    color_attributes = {}
    if entities and additional_colors is not None:
        color_attributes = {
            f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
            for key, color in additional_colors.items()
        }

    # Add the data attributes to each entity (the IDs are already in the markup)
    for entity in entities:
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        entity_annotation = ann_id2annotation[entity["id"]]

        # sanity check.
        # NOTE(review): relies on str(annotation) yielding the covered text - confirm.
        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
        # Just check the start, because the text has the label attached to the end
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    # Return the modified HTML as a string
    return str(soup)
|