ArneBinder committed on
Commit
bc6f57a
1 Parent(s): 7d208a6

add rendering utils

Browse files
Files changed (2) hide show
  1. rendering_utils.py +113 -0
  2. rendering_utils_displacy.py +217 -0
rendering_utils.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections import defaultdict
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ from pytorch_ie.annotations import BinaryRelation
6
+ from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
7
+
8
+ from .rendering_utils_displacy import EntityRenderer
9
+
10
+
11
def render_pretty_table(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
):
    """Render the binary relations of a document as a scrollable HTML table.

    Gold relations come first, followed by the predicted ones. Each row holds
    ``str(head)``, ``str(tail)`` and the relation label.

    Args:
        document: Document whose ``binary_relations`` layer (including its
            ``predictions``) is rendered.
        **render_kwargs: Accepted for signature compatibility; not used.

    Returns:
        The HTML table wrapped in a size-limited ``<div>`` with scrollbars.
    """
    from prettytable import PrettyTable

    relation_layer = document.binary_relations
    all_relations = [*relation_layer, *relation_layer.predictions]

    table = PrettyTable()
    table.field_names = ["head", "tail", "relation"]
    table.align = "l"
    for rel in all_relations:
        table.add_row([str(rel.head), str(rel.tail), rel.label])

    container_open = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    return container_open + table.get_html_string(format=True) + "</div>"
26
+
27
+
28
def render_spacy(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    style="ent",
    inject_relations=True,
    colors_hover=None,
    options=None,
    **render_kwargs,
):
    """Render the labeled spans of a document as displaCy-style entity HTML.

    Gold and predicted spans are merged and rendered in text order. Optionally,
    relation information is attached to the rendered entities as ``data-*``
    attributes (see ``inject_relation_data``).

    Args:
        document: Document providing ``text``, ``labeled_spans`` and
            ``binary_relations`` (each including its ``predictions``).
        style: Accepted for signature compatibility; not used (only entity
            rendering is implemented).
        inject_relations: If True, annotate the rendered entities with ids,
            labels, colors and their relation heads/tails.
        colors_hover: Optional mapping forwarded to ``inject_relation_data`` as
            ``additional_colors``.
        options: Options passed to ``EntityRenderer`` (e.g. ``colors``, ``ents``,
            ``template``). ``None`` means no options.
        **render_kwargs: Accepted for signature compatibility; not used.

    Returns:
        The rendered HTML wrapped in a size-limited ``<div>`` with scrollbars.
    """
    # Avoid a shared mutable default; the renderer expects a dict.
    renderer_options = {} if options is None else options

    # The renderer walks the text with a running offset, so spans must be in text
    # order. Sorting once here also guarantees that the i-th rendered <mark> is the
    # i-th entry of `sorted_entities`, which inject_relation_data() relies on.
    sorted_entities = sorted(
        list(document.labeled_spans) + list(document.labeled_spans.predictions),
        key=lambda span: (span.start, span.end),
    )
    spacy_doc = {
        "text": document.text,
        "ents": [
            {"start": entity.start, "end": entity.end, "label": entity.label}
            for entity in sorted_entities
        ],
        "title": None,
    }

    renderer = EntityRenderer(options=renderer_options)
    html = renderer.render([spacy_doc], page=True, minify=True).strip()

    html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
    if inject_relations:
        binary_relations = list(document.binary_relations) + list(
            document.binary_relations.predictions
        )
        html = inject_relation_data(
            html,
            sorted_entities=sorted_entities,
            binary_relations=binary_relations,
            additional_colors=colors_hover,
        )
    return html
62
+
63
+
64
def inject_relation_data(
    html: str,
    sorted_entities,
    binary_relations: List[BinaryRelation],
    additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
) -> str:
    """Attach ids, labels, colors and relation links to rendered entity elements.

    Every element with class ``entity`` in ``html`` is paired, by position, with the
    entry of ``sorted_entities`` at the same index and receives these attributes:
    ``id`` (``entity-<idx>``), ``data-color-original``, one ``data-color-<key>`` per
    entry of ``additional_colors``, ``data-label``, ``data-relation-tails`` and
    ``data-relation-heads`` (JSON lists of ``{"entity-id", "label"}`` objects).

    Args:
        html: Markup containing the rendered entities.
        sorted_entities: Entity annotations in the same order as the entity
            elements appear in ``html``.
        binary_relations: Relations whose heads/tails are looked up in
            ``sorted_entities``.
        additional_colors: Optional extra colors; dict values are JSON-encoded,
            other values are written as they are.

    Returns:
        The modified HTML as a string.

    Raises:
        ValueError: If the text rendered for an entity element does not equal
            ``str()`` of the annotation at the same index.
    """
    from bs4 import BeautifulSoup

    # Parse the HTML using BeautifulSoup
    soup = BeautifulSoup(html, "html.parser")

    # Index the relations from both directions: for each entity, collect the
    # (other argument, relation label) pairs where it acts as head or as tail.
    entity2tails = defaultdict(list)
    entity2heads = defaultdict(list)
    for relation in binary_relations:
        entity2heads[relation.tail].append((relation.head, relation.label))
        entity2tails[relation.head].append((relation.tail, relation.label))

    entity2id = {entity: f"entity-{idx}" for idx, entity in enumerate(sorted_entities)}

    # The extra color attributes are the same for every entity, so build them once.
    color_attributes = {
        f"data-color-{key}": json.dumps(color) if isinstance(color, dict) else color
        for key, color in (additional_colors or {}).items()
    }

    # Add unique IDs and the relation data to each entity
    for idx, entity in enumerate(soup.find_all(class_="entity")):
        entity["id"] = f"entity-{idx}"
        # The background color is taken from the inline style of the element.
        original_color = entity["style"].split("background:")[1].split(";")[0].strip()
        entity["data-color-original"] = original_color
        for attribute, value in color_attributes.items():
            entity[attribute] = value

        entity_annotation = sorted_entities[idx]
        # Sanity check: the first node inside the element is the rendered entity
        # text. This relies on str(annotation) yielding the covered text.
        rendered_text = entity.next_element
        if str(entity_annotation) != rendered_text:
            # Report the value that was actually compared (entity.text would also
            # contain the label text of the nested <span>).
            raise ValueError(f"Entity text mismatch: {entity_annotation} != {rendered_text}")

        entity["data-label"] = entity_annotation.label
        entity["data-relation-tails"] = json.dumps(
            [
                {"entity-id": entity2id[tail], "label": label}
                for tail, label in entity2tails.get(entity_annotation, [])
            ]
        )
        entity["data-relation-heads"] = json.dumps(
            [
                {"entity-id": entity2id[head], "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )

    # Return the modified HTML as a string
    return str(soup)
rendering_utils_displacy.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is mainly taken from
2
+ # https://github.com/explosion/spaCy/blob/master/spacy/displacy/templates.py, and from
3
+ # https://github.com/explosion/spaCy/blob/master/spacy/displacy/render.py.
4
+
5
+ # Setting explicit height and max-width: none on the SVG is required for
6
+ # Jupyter to render it properly in a cell
7
+
8
# Root <svg> element of a dependency visualisation. Placeholders: lang, id, width,
# height, dir, color, bg, font, content.
TPL_DEP_SVG = """
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:lang="{lang}" id="{id}" class="displacy" width="{width}" height="{height}" direction="{dir}" style="max-width: none; height: {height}px; color: {color}; background: {bg}; font-family: {font}; direction: {dir}">{content}</svg>
"""


# One token of a dependency visualisation (word plus tag). Placeholders: x, y, text, tag.
TPL_DEP_WORDS = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
<tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


# Same as TPL_DEP_WORDS, with an additional lemma line. Placeholders: x, y, text, lemma, tag.
TPL_DEP_WORDS_LEMMA = """
<text class="displacy-token" fill="currentColor" text-anchor="middle" y="{y}">
<tspan class="displacy-word" fill="currentColor" x="{x}">{text}</tspan>
<tspan class="displacy-lemma" dy="2em" fill="currentColor" x="{x}">{lemma}</tspan>
<tspan class="displacy-tag" dy="2em" fill="currentColor" x="{x}">{tag}</tspan>
</text>
"""


# One labelled arc with its arrowhead. Placeholders: id, i, stroke, arc, label_side,
# label, head.
TPL_DEP_ARCS = """
<g class="displacy-arrow">
<path class="displacy-arc" id="arrow-{id}-{i}" stroke-width="{stroke}px" d="{arc}" fill="none" stroke="currentColor"/>
<text dy="1.25em" style="font-size: 0.8em; letter-spacing: 1px">
<textPath xlink:href="#arrow-{id}-{i}" class="displacy-label" startOffset="50%" side="{label_side}" fill="currentColor" text-anchor="middle">{label}</textPath>
</text>
<path class="displacy-arrowhead" d="{head}" fill="currentColor"/>
</g>
"""


# Wrapper around one rendered document when a full page is produced. Placeholder: content.
TPL_FIGURE = """
<figure style="margin-bottom: 6rem">{content}</figure>
"""

# Optional document title placed in front of the entity markup. Placeholder: title.
TPL_TITLE = """
<h2 style="margin: 0">{title}</h2>
"""


# Container for the text of one document with its entity marks. Placeholders: dir, content.
TPL_ENTS = """
<div class="entities" style="line-height: 2.5; direction: {dir}">{content}</div>
"""


# A single highlighted entity for left-to-right text (label offset with margin-left).
# Placeholders: bg, text, label.
TPL_ENT = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
</mark>
"""

# A single highlighted entity for right-to-left text (label offset with margin-right).
# Placeholders: bg, text, label.
TPL_ENT_RTL = """
<mark class="entity" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em">
{text}
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-right: 0.5rem">{label}</span>
</mark>
"""


# Full HTML page around the rendered documents. Placeholders: lang, dir, content.
TPL_PAGE = """
<!DOCTYPE html>
<html lang="{lang}">
<head>
<title>displaCy</title>
</head>

<body style="font-size: 16px; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif, 'Apple Color Emoji', 'Segoe UI Emoji', 'Segoe UI Symbol'; padding: 4rem 2rem; direction: {dir}">{content}</body>
</html>
"""


# Fallbacks used when a parsed document carries no "lang" / "direction" settings.
DEFAULT_LANG = "en"
DEFAULT_DIR = "ltr"
84
+
85
+
86
def minify_html(html):
    """Perform a template-specific, rudimentary HTML minification for displaCy.
    Disclaimer: NOT a general-purpose solution, only removes indentation and
    newlines.

    html (unicode): Markup to minify.
    RETURNS (unicode): "Minified" HTML.
    """
    # Only four-space indentation units are dropped. Removing every single space
    # would also delete the separators between a tag name and its attributes
    # (e.g. `<mark class=...>` would become `<markclass=...>`) and the spaces
    # between the words of the rendered text.
    return html.strip().replace("    ", "").replace("\n", "")
95
+
96
+
97
def escape_html(text):
    """Encode the characters &, <, > and " as HTML entities so that the text can be
    embedded in rendered displaCy markup without breaking it.

    text (unicode): The original text.
    RETURNS (unicode): Equivalent text to be safely used within HTML.
    """
    # "&" has to be handled first, otherwise the ampersands introduced by the
    # other replacements would be encoded a second time.
    substitutions = (
        ("&", "&amp;"),
        ("<", "&lt;"),
        (">", "&gt;"),
        ('"', "&quot;"),
    )
    for raw, encoded in substitutions:
        text = text.replace(raw, encoded)
    return text
109
+
110
+
111
class EntityRenderer(object):
    """Render named entities as HTML."""

    style = "ent"

    def __init__(self, options=None):
        """Initialise entity renderer.

        options (dict): Visualiser-specific options. Recognised keys: "colors"
            (label -> color, merged over the built-in palette), "ents" (labels to
            highlight; None highlights all) and "template" (custom entity template).
        """
        # Avoid a shared mutable default argument.
        options = {} if options is None else options
        colors = {
            "ORG": "#7aecec",
            "PRODUCT": "#bfeeb7",
            "GPE": "#feca74",
            "LOC": "#ff9561",
            "PERSON": "#aa9cfc",
            "NORP": "#c887fb",
            "FACILITY": "#9cc9cc",
            "EVENT": "#ffeb80",
            "LAW": "#ff8197",
            "LANGUAGE": "#ff8197",
            "WORK_OF_ART": "#f0d0ff",
            "DATE": "#bfe1d9",
            "TIME": "#bfe1d9",
            "MONEY": "#e4e7d2",
            "QUANTITY": "#e4e7d2",
            "ORDINAL": "#e4e7d2",
            "CARDINAL": "#e4e7d2",
            "PERCENT": "#e4e7d2",
        }
        # NOTE: colors from spaCy's `displacy_colors` registry are not merged here;
        # only the "colors" option is applied on top of the built-in palette.
        colors.update(options.get("colors", {}))
        self.default_color = "#ddd"
        self.colors = colors
        self.ents = options.get("ents", None)
        self.direction = DEFAULT_DIR
        self.lang = DEFAULT_LANG

        template = options.get("template")
        # Remember whether the caller supplied a template: a custom template is
        # never replaced, a built-in one is re-selected in render() once the text
        # direction of the input is known.
        self._has_custom_template = bool(template)
        self.ent_template = template if template else self._builtin_template()

    def _builtin_template(self):
        """Return the built-in entity template matching the current text direction."""
        return TPL_ENT_RTL if self.direction == "rtl" else TPL_ENT

    def render(self, parsed, page=False, minify=False):
        """Render complete markup.

        parsed (list): Documents to render; dicts with "text" and "ents" and
            optionally "title" and "settings" ("direction", "lang").
        page (bool): Render documents wrapped as full HTML page.
        minify (bool): Minify HTML markup.
        RETURNS (unicode): Rendered HTML markup.
        """
        rendered = []
        for i, p in enumerate(parsed):
            if i == 0:
                # Language and direction of the whole output follow the first document.
                settings = p.get("settings", {})
                self.direction = settings.get("direction", DEFAULT_DIR)
                self.lang = settings.get("lang", DEFAULT_LANG)
                # The direction is only known now, so the choice between the LTR
                # and RTL template has to be made here and not in __init__.
                if not self._has_custom_template:
                    self.ent_template = self._builtin_template()
            rendered.append(self.render_ents(p["text"], p["ents"], p.get("title")))
        if page:
            docs = "".join([TPL_FIGURE.format(content=doc) for doc in rendered])
            markup = TPL_PAGE.format(content=docs, lang=self.lang, dir=self.direction)
        else:
            markup = "".join(rendered)
        if minify:
            return minify_html(markup)
        return markup

    @staticmethod
    def _render_plain_text(text):
        """Escape text that is not part of an entity and turn newlines into <br> tags."""
        return "<br>".join(escape_html(fragment) for fragment in text.split("\n"))

    def render_ents(self, text, spans, title):
        """Render entities in text.

        text (unicode): Original text.
        spans (list): Individual entity spans and their start, end and label
            (optionally "params" with extra template fields). The spans are
            processed in the given order with a running offset, so they are
            expected to be sorted by position.
        title (unicode or None): Document title set in Doc.user_data['title'].
        RETURNS (unicode): Rendered HTML markup of the document.
        """
        parts = []
        offset = 0
        for span in spans:
            label = span["label"]
            start = span["start"]
            end = span["end"]
            additional_params = span.get("params", {})
            entity = escape_html(text[start:end])
            # Text between the previous entity and this one.
            parts.append(self._render_plain_text(text[offset:start]))
            if self.ents is None or label.upper() in self.ents:
                color = self.colors.get(label.upper(), self.default_color)
                ent_settings = {"label": label, "text": entity, "bg": color}
                ent_settings.update(additional_params)
                parts.append(self.ent_template.format(**ent_settings))
            else:
                # Entity label is filtered out: emit the text without highlighting.
                parts.append(entity)
            offset = end
        # Remaining text after the last entity.
        parts.append(self._render_plain_text(text[offset:]))
        markup = TPL_ENTS.format(content="".join(parts), dir=self.direction)
        if title:
            markup = TPL_TITLE.format(title=title) + markup
        return markup