add rendering utils
Browse files- rendering_utils.py +113 -0
- rendering_utils_displacy.py +217 -0
rendering_utils.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
from typing import Dict, List, Optional, Union
|
4 |
+
|
5 |
+
from pytorch_ie.annotations import BinaryRelation
|
6 |
+
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
7 |
+
|
8 |
+
from .rendering_utils_displacy import EntityRenderer
|
9 |
+
|
10 |
+
|
11 |
+
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the document's binary relations as a scrollable HTML table.

    Gold relations are listed first, followed by the predicted ones. Each row
    holds the string form of the head, the string form of the tail and the
    relation label.

    Args:
        document: The document whose ``binary_relations`` layer (including its
            ``predictions``) is tabulated.
        **render_kwargs: Accepted but not used by this renderer.

    Returns:
        The HTML table wrapped in a size-limited, scrollable ``<div>``.
    """
    # Imported lazily so prettytable is only needed when this renderer is used.
    from prettytable import PrettyTable

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"

    gold_relations = list(document.binary_relations)
    predicted_relations = list(document.binary_relations.predictions)
    for rel in gold_relations + predicted_relations:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    return (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + table.get_html_string(format=True)
        + "</div>"
    )
|
26 |
+
|
27 |
+
|
28 |
+
def render_spacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    style="ent",
    inject_relations=True,
    colors_hover=None,
    options=None,
    **render_kwargs,
):
    """Render the document's labeled spans as displaCy-style entity HTML.

    Gold and predicted spans are merged and sorted by ``(start, end)`` before
    rendering, so the marks appear in text order regardless of the order in
    which the annotations were added.

    Args:
        document: The document providing ``text``, ``labeled_spans`` and
            ``binary_relations`` (each including its ``predictions``).
        style: Accepted for interface compatibility; not used here.
        inject_relations: If true, relation data attributes are added to the
            rendered entity marks via ``inject_relation_data``.
        colors_hover: Optional extra colors that are forwarded to
            ``inject_relation_data`` as ``additional_colors``.
        options: Options passed to ``EntityRenderer``. Defaults to an empty
            dict.
        **render_kwargs: Accepted but not used by this renderer.

    Returns:
        The rendered HTML wrapped in a size-limited, scrollable ``<div>``.
    """
    # Use a fresh dict per call instead of a shared mutable default.
    renderer_options = {} if options is None else options

    # The renderer consumes the text left to right, and inject_relation_data
    # pairs the i-th rendered mark with the i-th sorted span, so the spans have
    # to be in positional order *before* rendering.
    sorted_entities = sorted(
        list(document.labeled_spans) + list(document.labeled_spans.predictions),
        key=lambda span: (span.start, span.end),
    )
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted_entities
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=renderer_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
|
62 |
+
|
63 |
+
|
64 |
+
def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Annotate rendered entity marks with ids, colors and relation data.

    Every element with class ``entity`` receives, in document order:
    an ``id`` (``entity-<idx>``), ``data-color-original`` (the background
    color parsed from its inline style), one ``data-color-<key>`` attribute
    per entry of ``additional_colors``, ``data-label``, and the JSON encoded
    lists ``data-relation-tails`` / ``data-relation-heads``.

    Args:
        html: Markup containing one ``entity`` element per annotation.
        sorted_entities: The entity annotations in the same order as their
            marks appear in ``html``; the i-th mark is paired with the i-th
            annotation.
        binary_relations: Relations whose heads and tails are members of
            ``sorted_entities``.
        additional_colors: Optional mapping from a key to a color; dict values
            are JSON encoded, other values are stored as given.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If the leading text of a mark differs from the string form
            of the annotation paired with it.
    """
    from bs4 import BeautifulSoup

    soup = BeautifulSoup(html, "html.parser")

    # Index the relations from both ends: entity -> [(other entity, label)].
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    entity2id = {entity: f"entity-{idx}" for idx, entity in enumerate(sorted_entities)}

    for idx, entity in enumerate(soup.find_all(class_="entity")):
        entity["id"] = f"entity-{idx}"
        # The background color is read back from the inline style so it can be
        # restored on the client side.
        entity["data-color-original"] = (
            entity["style"].split("background:")[1].split(";")[0].strip()
        )
        if additional_colors is not None:
            for key, color in additional_colors.items():
                entity[f"data-color-{key}"] = (
                    json.dumps(color) if isinstance(color, dict) else color
                )

        entity_annotation = sorted_entities[idx]
        # Sanity check: the first node inside the mark is its text, which is
        # expected to equal str(annotation) (presumably the covered text --
        # verify against the annotation type).
        rendered_text = entity.next_element
        if str(entity_annotation) != rendered_text:
            raise ValueError(f"Entity text mismatch: {entity_annotation} != {rendered_text}")

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": entity2id[tail], "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": entity2id[head], "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    return str(soup)
|
rendering_utils_displacy.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code is mainly taken from
|
2 |
+
# https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from
|
3 |
+
# https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py.
|
4 |
+
|
5 |
+
# Setting explicit height and max-width: none on the SVG is required for
|
6 |
+
# Jupyter to render it properly in a cell
|
7 |
+
|
8 |
+
# Markup templates for the renderer. Placeholders in curly braces are filled
# with str.format(); nested lines are indented in units of four spaces.

TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""


TPL_DEP_WORDS = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


TPL_DEP_WORDS_LEMMA = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
    <tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
    <tspan class="displacy-lemma" dy="2em" fill="currentColor" x="{x}">{lemma}</tspan>
    <tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


TPL_DEP_ARCS = """
<g class="displacy-arrow">
    <path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
    <text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
        <textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
    </text>
    <path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
</g>
"""


TPL_FIGURE = """
<figure style="margin-bottom: 6rem">{content}</figure>
"""

TPL_TITLE = """
<h2 style="margin: 0">{title}</h2>
"""


TPL_ENTS = """
<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
"""


TPL_ENT = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""

TPL_ENT_RTL = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
    {text}
    <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-right: 0.5rem">{label}</span>
</mark>
"""


TPL_PAGE = """
<!DOCTYPE html>
<html lang="{lang}">
    <head>
        <title>displaCy</title>
    </head>

    <body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem; direction: {dir}">{content}</body>
</html>
"""


# Fallbacks used when a parsed document carries no "settings".
DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"
|
84 |
+
|
85 |
+
|
86 |
+
def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.

    Disclaimer: NOT a general-purpose solution, only removes indentation and
    newlines.

    Surrounding whitespace is stripped, every run of four spaces (one
    indentation unit) is removed, and all newlines are removed. Single spaces,
    e.g. between tag attributes or between words, are preserved.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    # Only whole indentation units are dropped; removing single spaces would
    # glue attributes and words together.
    return html.strip().replace("    ", "").replace("\n", "")
|
95 |
+
|
96 |
+
|
97 |
+
def escape_html(text):
    """Replace <, >, &, " with their HTML encoded representation. Intended to prevent HTML errors
    in rendered displaCy markup.

    text (unicode): The original text.
    RETURNS (unicode): Equivalent text to be safely used within HTML.
    """
    # The ampersand has to be handled first; otherwise the ampersands
    # introduced by the other replacements would be escaped a second time.
    text = text.replace("&", "&amp;")
    text = text.replace("<", "&lt;")
    text = text.replace(">", "&gt;")
    text = text.replace('"', "&quot;")
    return text
|
109 |
+
|
110 |
+
|
111 |
+
class EntityRenderer:
    """Render named entities as HTML."""

    style = "ent"

    def __init__(self, options=None):
        """Initialise entity renderer.

        options (dict): Visualiser-specific options. Recognised keys:
            "colors" (dict mapping upper-case labels to colors, merged over
            the built-in palette), "ents" (labels to render as marks; all
            labels if missing or None) and "template" (custom entity template).
        """
        # Use a fresh dict per instance instead of a shared mutable default.
        options = {} if options is None else options
        colors = {
            "ORG": "#7aecec",
            "PRODUCT": "#bfeeb7",
            "GPE": "#feca74",
            "LOC": "#ff9561",
            "PERSON": "#aa9cfc",
            "NORP": "#c887fb",
            "FACILITY": "#9cc9cc",
            "EVENT": "#ffeb80",
            "LAW": "#ff8197",
            "LANGUAGE": "#ff8197",
            "WORK_OF_ART": "#f0d0ff",
            "DATE": "#bfe1d9",
            "TIME": "#bfe1d9",
            "MONEY": "#e4e7d2",
            "QUANTITY": "#e4e7d2",
            "ORDINAL": "#e4e7d2",
            "CARDINAL": "#e4e7d2",
            "PERCENT": "#e4e7d2",
        }
        colors.update(options.get("colors", {}))
        self.default_color = "#ddd"
        self.colors = colors
        self.ents = options.get("ents", None)
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG

        # Remembered so render() can re-select the built-in template once the
        # actual text direction is known.
        self._custom_template = options.get("template")
        self.ent_template = self._select_template()

    def _select_template(self):
        """Return the custom template if given, else the one for the current direction."""
        if self._custom_template:
            return self._custom_template
        return TPL_ENT_RTL if self.direction == "rtl" else TPL_ENT

    def render(self, parsed, page=False, minify=False):
        """Render complete markup.

        parsed (list): Documents to render; each is a dict with "text", "ents"
            and optionally "title" and "settings". Language and direction are
            taken from the "settings" of the first document.
        page (bool): Render documents wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (unicode): Rendered HTML markup.
        """
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
                # The direction is only known now, so the template choice made
                # in __init__ has to be refreshed for RTL documents.
                self.ent_template = self._select_template()
            rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
        if page:
            docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
            markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    def render_ents(self, text, spans, title):
        """Render entities in text.

        text (unicode): Original text.
        spans (list): Individual entity spans with "start", "end", "label" and
            optional "params" (extra template fields). They are consumed in
            the given order, so they are expected to be sorted by position.
        title (unicode or None): Document title rendered above the entities.
        RETURNS (unicode): Rendered HTML markup.
        """
        parts = []
        offset = 0
        for span in spans:
            label = span["label"]
            start = span["start"]
            end = span["end"]
            additional_params = span.get("params", {})
            entity = escape_html(text[start:end])
            # Plain text between the previous span and this one.
            parts.append(self._render_plain(text[offset:start]))
            if self.ents is None or label.upper() in self.ents:
                color = self.colors.get(label.upper(), self.default_color)
                ent_settings = {"label": label, "text": entity, "bg": color}
                ent_settings.update(additional_params)
                parts.append(self.ent_template.format(**ent_settings))
            else:
                # Filtered-out labels are emitted as ordinary text.
                parts.append(entity)
            offset = end
        # Remaining text after the last span.
        parts.append(self._render_plain(text[offset:]))
        markup = TPL_ENTS.format(content="".join(parts), dir=self.direction)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup

    @staticmethod
    def _render_plain(text):
        """Escape non-entity text and turn its newlines into <br> elements."""
        return "<br>".join(escape_html(fragment) for fragment in text.split("\n"))
|