import json
import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union

from annotation_utils import labeled_span_to_id
from pytorch_ie.annotations import BinaryRelation
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
from rendering_utils_displacy import EntityRenderer

logger = logging.getLogger(__name__)


def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's binary relations as an HTML table.

    Gold relations (``document.binary_relations``) are listed first, followed
    by the predicted ones (``document.binary_relations.predictions``). Each
    row holds the string form of the head, the string form of the tail and
    the relation label.

    Args:
        document: The document whose relations are rendered.
        **render_kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The HTML string produced by PrettyTable, wrapped in the prefix/suffix
        strings below.
    """
    # Imported lazily so prettytable is only required when this renderer is used.
    from prettytable import PrettyTable

    t = PrettyTable()
    t.field_names = ["head", "tail", "relation"]
    t.align = "l"
    for relation in list(document.binary_relations) + list(document.binary_relations.predictions):
        t.add_row([str(relation.head), str(relation.tail), relation.label])

    html = t.get_html_string(format=True)
    # NOTE(review): in the reviewed source both wrapper literals consisted of a
    # bare line break (an unterminated string literal). They are reproduced
    # here as "\n"; if they originally carried wrapping markup, restore it.
    html = "\n" + html + "\n"

    return html


def render_displacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    inject_relations=True,
    colors_hover=None,
    entity_options: Optional[dict] = None,
    **render_kwargs,
):
    """Render the document's labeled spans as displaCy-style entity HTML.

    Gold and predicted labeled spans are passed to ``EntityRenderer`` as
    ``ents``. When ``inject_relations`` is true, the resulting markup is
    post-processed by :func:`inject_relation_data` so that every entity
    element carries its relation data as attributes.

    Args:
        document: The document to render.
        inject_relations: Whether to annotate the entity elements with
            relation data.
        colors_hover: Optional mapping forwarded to
            :func:`inject_relation_data` as ``additional_colors``.
        entity_options: Options passed to ``EntityRenderer``. ``None`` (the
            default) is treated as an empty dict.
        **render_kwargs: Accepted for signature compatibility; not used here.

    Returns:
        The rendered HTML string.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    options = {} if entity_options is None else entity_options

    spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label} for entity in spans
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    # NOTE(review): in the reviewed source both wrapper literals consisted of a
    # bare line break (an unterminated string literal). They are reproduced
    # here as "\n"; if they originally carried wrapping markup, restore it.
    html = "\n" + html + "\n"

    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        # inject_relation_data pairs the i-th entity element in the markup with
        # the i-th span, so the spans are handed over sorted by (start, end).
        sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html


def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach ids, colors and relation data to the entity elements in ``html``.

    Every element with class ``entity`` is matched by position with the entry
    of ``sorted_entities`` at the same index and receives:

    * ``id``: the result of ``labeled_span_to_id`` for the span,
    * ``data-color-original``: the ``background`` value of its ``style``,
    * ``data-color-<key>``: one attribute per entry of ``additional_colors``
      (dict values are JSON-encoded),
    * ``data-label``: the span label,
    * ``data-relation-tails`` / ``data-relation-heads``: JSON lists of
      ``{"entity-id": ..., "label": ...}`` for relations in which the span is
      the head / the tail, respectively.

    Args:
        html: Markup containing the entity elements.
        sorted_entities: Spans in the same order as the entity elements appear
            in ``html``.
        binary_relations: Relations whose heads and tails are looked up among
            the spans.
        additional_colors: Optional extra colors to expose per entity.

    Returns:
        The modified markup as a string.

    Raises:
        IndexError: If ``html`` contains more entity elements than there are
            entries in ``sorted_entities``, or if an entity's ``style`` has no
            ``background:`` declaration.
    """
    # Imported lazily so bs4 is only required when relations are injected.
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Index the relations from both ends: span -> [(other argument, label), ...]
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    # The extra color attributes are identical for every entity, so they are
    # serialized once up front instead of once per element.
    extra_color_attrs = {}
    if additional_colors is not None:
        for key, color in additional_colors.items():
            extra_color_attrs[f"data-color-{key}"] = (
                json.dumps(color) if isinstance(color, dict) else color
            )

    # Add unique IDs to each entity
    entities = soup.find_all(class_="entity")
    for idx, entity in enumerate(entities):
        entity_annotation = sorted_entities[idx]
        entity["id"] = labeled_span_to_id(entity_annotation)

        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attr_name, attr_value in extra_color_attrs.items():
            entity[attr_name] = attr_value

        # sanity check
        # NOTE(review): `entity.next` is presumably the node parsed directly
        # after the opening tag (the entity's own text), while the message
        # reports `entity.text` - confirm this difference is intended.
        if str(entity_annotation) != entity.next:
            logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(tail), "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    # Return the modified HTML as a string
    return str(soup)