import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union

from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer

logger = logging.getLogger(__name__)

# adjusted from rendering_utils_displacy.TPL_ENT
# NOTE(review): inject_relation_data() below selects elements by class "entity" and reads
# their "id" and "style" attributes, none of which appear in the template text as it stands
# here. Presumably the HTML markup of this template was lost in this copy of the file;
# confirm against the original template before relying on it.
TPL_ENT_WITH_ID = """ {text} {label} """


def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
) -> str:
    """Render the binary relations of a document as an HTML table.

    The table has the columns ``head``, ``tail`` and ``relation`` and contains one row per
    relation, gold annotations first, followed by the predictions.

    Args:
        document: The document whose ``binary_relations`` (and their ``predictions``)
            are rendered.
        **render_kwargs: Accepted for signature compatibility with the other renderers;
            not used here.

    Returns:
        The HTML string produced by PrettyTable, surrounded by the wrapper strings below.
    """
    # imported lazily so that prettytable is only required when this renderer is used
    from prettytable import PrettyTable

    t = PrettyTable()
    t.field_names = ["head", "tail", "relation"]
    t.align = "l"
    for relation in list(document.binary_relations) + list(document.binary_relations.predictions):
        t.add_row([str(relation.head), str(relation.tail), relation.label])

    html = t.get_html_string(format=True)
    # NOTE(review): the wrapper strings look like they lost their markup in this copy
    # (presumably an enclosing container element) -- confirm against the original.
    html = "\n" + html + "\n"

    return html


def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations: bool = True,
    colors_hover: Optional[Dict[str, Union[str, dict]]] = None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
) -> str:
    """Render the labeled spans of a document as HTML via the displaCy entity renderer.

    Args:
        document: The document to render. Gold and predicted ``labeled_spans`` are shown;
            gold and predicted ``binary_relations`` are used when ``inject_relations`` is set.
        inject_relations: If true, the rendered HTML is post-processed with
            :func:`inject_relation_data` to attach relation data to the entity elements.
        colors_hover: Passed to :func:`inject_relation_data` as ``additional_colors``.
        entity_options: Options for ``EntityRenderer``. The mapping is copied and its
            ``"template"`` entry is replaced with ``TPL_ENT_WITH_ID``; the caller's
            object is never modified. ``None`` means no extra options.
        **render_kwargs: Accepted for signature compatibility with the other renderers;
            not used here.

    Returns:
        The rendered HTML string.
    """
    labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    spacy_doc = {
        "text": document.text,
        "ents": [
            {
                "start": labeled_span.start,
                "end": labeled_span.end,
                "label": labeled_span.label,
                # pass the ID as a parameter to the entity. The id is required to fetch the
                # entity annotations on hover and to inject the relation data.
                "params": {"id": labeled_span_to_id(labeled_span)},
            }
            for labeled_span in labeled_spans
        ],
        "title": None,
    }

    # copy to avoid modifying the original options (None stands for "no options";
    # a None default replaces the former shared mutable `{}` default)
    entity_options = {} if entity_options is None else entity_options.copy()
    # use the custom template with the entity ID
    entity_options["template"] = TPL_ENT_WITH_ID

    renderer = EntityRenderer(options=entity_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    # NOTE(review): the wrapper strings look like they lost their markup in this copy
    # (presumably an enclosing container element) -- confirm against the original.
    html = "\n" + html + "\n"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            labeled_spans=labeled_spans,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def inject_relation_data(
    html: str,
    labeled_spans: List[LabeledSpan],
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach label, color and relation data attributes to the entity elements in ``html``.

    Every element with class ``entity`` receives these attributes:

    * ``data-color-original``: the ``background:`` value taken from its ``style`` attribute
    * ``data-color-<key>``: one per entry of ``additional_colors`` (dict values are
      JSON-encoded, other values are used as they are)
    * ``data-label``: the label of the matching labeled span
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for the relations in which the span is the
      head / the tail, respectively

    Args:
        html: The HTML to process. Entity elements must carry an ``id`` attribute equal to
            ``labeled_span_to_id(span)`` of one of ``labeled_spans`` and a ``style``
            attribute containing ``background:``.
        labeled_spans: The spans that were rendered as entity elements.
        binary_relations: The relations between those spans.
        additional_colors: Optional extra colors to expose as ``data-color-<key>``.

    Returns:
        The modified HTML as a string.

    Raises:
        KeyError: If an entity element has an id that matches none of ``labeled_spans``.
        IndexError: If the style of an entity element contains no ``background:`` entry.
    """
    # imported lazily so that bs4 is only required when relation data is injected
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # index the relations by both of their arguments
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}

    # Add the data attributes to each entity element
    entities = soup.find_all(class_="entity")
    for entity in entities:
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        if additional_colors is not None:
            for key, color in additional_colors.items():
                entity[f"data-color-{key}"] = (
                    json.dumps(color) if isinstance(color, dict) else color
                )

        entity_annotation = ann_id2annotation[entity["id"]]

        # sanity check.
        annotation_text_without_newline = str(entity_annotation).replace("\n", "")
        # Just check the start, because the text has the label attached to the end
        if not entity.text.startswith(annotation_text_without_newline):
            logger.warning("Entity text mismatch: %s != %s", entity_annotation, entity.text)

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    # Return the modified HTML as a string
    return str(soup)