|
import json |
|
from collections import defaultdict |
|
from typing import Dict, List, Optional, Union |
|
|
|
from annotation_utils import labeled_span_to_id |
|
from pytorch_ie.annotations import BinaryRelation |
|
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions |
|
from rendering_utils_displacy import EntityRenderer |
|
|
|
|
|
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render every binary relation of ``document`` as an HTML table.

    Gold relations come first, followed by predicted relations. Each row holds
    ``str(head)``, ``str(tail)`` and the relation label, left-aligned.

    Args:
        document: Document whose ``binary_relations`` and
            ``binary_relations.predictions`` are listed.
        **render_kwargs: Accepted but not used.

    Returns:
        The table's HTML wrapped in a ``<div>`` limited to 360px height with
        scrolling enabled.
    """
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    gold = list(document.binary_relations)
    predicted = list(document.binary_relations.predictions)
    for rel in gold + predicted:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    table_html = table.get_html_string(format=True)
    return f"<div style='max-width:100%; max-height:360px; overflow:auto'>{table_html}</div>"
|
|
|
|
|
def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options=None,
    **render_kwargs,
):
    """Render the labeled spans of ``document`` as displaCy-style entity HTML.

    Both gold spans (``document.labeled_spans``) and predicted spans
    (``document.labeled_spans.predictions``) are rendered. They are sorted by
    ``(start, end)`` before rendering. That way the order of the rendered entity
    elements matches the order ``inject_relation_data`` uses to pair HTML
    elements with annotations.

    Args:
        document: Document providing ``text``, labeled spans and binary relations.
        inject_relations: If True, annotate every rendered entity element with
            ids, colors and relation data via ``inject_relation_data``.
        colors_hover: Forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the
            default) means an empty dict.
        **render_kwargs: Accepted but not used.

    Returns:
        The rendered HTML wrapped in a ``<div>`` limited to 360px height with
        scrolling enabled.
    """
    # Avoid a shared mutable default that the renderer could modify across calls.
    if entity_options is None:
        entity_options = {}

    # Sort once and reuse the same sequence for rendering and for relation
    # injection: inject_relation_data pairs the i-th rendered entity element
    # with the i-th entry of sorted_entities, so both must share one order.
    # (A displaCy-style renderer also walks the entities left to right.)
    sorted_entities = sorted(
        list(document.labeled_spans) + list(document.labeled_spans.predictions),
        key=lambda span: (span.start, span.end),
    )
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted_entities
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()
    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
|
|
|
|
|
def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate the entity elements in ``html`` with ids, colors and relation data.

    The ``idx``-th element with CSS class ``entity`` is paired with
    ``sorted_entities[idx]``. The entity elements must therefore appear in the
    HTML in the same order as ``sorted_entities``.

    For every entity element this sets these attributes:

    * ``id``: ``labeled_span_to_id(annotation)``.
    * ``data-color-original``: the value following ``background:`` in the
      element's ``style`` attribute.
    * ``data-color-<key>``: one per entry of ``additional_colors``; dict values
      are JSON-encoded, other values are stored as-is.
    * ``data-label``: the annotation's label.
    * ``data-relation-tails``: a JSON list of ``{"entity-id", "label"}`` for
      relations whose head is this annotation.
    * ``data-relation-heads``: the same structure for relations whose tail is
      this annotation.

    Args:
        html: HTML containing the rendered entity elements.
        sorted_entities: Annotations in the same order as the entity elements.
            They are used as dict keys, so they must be hashable.
        binary_relations: Relations whose endpoints are (some of) the
            annotations in ``sorted_entities``.
        additional_colors: Optional mapping of color-variant name to color.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If the first node inside an entity element differs from
            ``str(annotation)`` for the paired annotation.
        IndexError: If ``html`` has more entity elements than
            ``sorted_entities``, or an element's style has no ``background:``.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index relations by both endpoints, so each entity can list its
    # outgoing (tails) and incoming (heads) relations.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    for idx, entity in enumerate(soup.find_all(class_="entity")):
        annotation = sorted_entities[idx]
        entity["id"] = labeled_span_to_id(annotation)

        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        if additional_colors is not None:
            for key, color in additional_colors.items():
                entity[f"data-color-{key}"] = (
                    json.dumps(color) if isinstance(color, dict) else color
                )

        # entity.next is the first node inside the element. entity.text would
        # also include any nested children (presumably the rendered label), so
        # report the value that was actually compared.
        rendered_text = entity.next
        if str(annotation) != rendered_text:
            raise ValueError(f"Entity text mismatch: {annotation} != {rendered_text}")

        entity["data-label"] = annotation.label
        # .get() avoids inserting empty entries into the defaultdicts.
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(annotation, [])
            ]
        )

    return str(soup)
|
|