ArneBinder commited on
Commit
70fea2e
1 Parent(s): ff28cb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -21
app.py CHANGED
@@ -1,11 +1,11 @@
1
  import json
2
  from collections import defaultdict
3
- from typing import Tuple
4
 
5
  import gradio as gr
6
  from pie_modules.models import * # noqa: F403
7
  from pie_modules.taskmodules import * # noqa: F403
8
- from pytorch_ie.annotations import LabeledSpan
9
  from pytorch_ie.auto import AutoPipeline
10
  from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
11
  from pytorch_ie.models import * # noqa: F403
@@ -33,6 +33,7 @@ def render_spacy(
33
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
34
  style="ent",
35
  inject_relations=True,
 
36
  **render_kwargs,
37
  ):
38
  from spacy import displacy
@@ -51,20 +52,25 @@ def render_spacy(
51
  )
52
  html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
53
  if inject_relations:
54
- print("Injecting relation data")
55
  binary_relations = list(document.binary_relations) + list(
56
  document.binary_relations.predictions
57
  )
58
  sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
59
  html = inject_relation_data(
60
- html, sorted_entities=sorted_entities, binary_relations=binary_relations
 
 
 
61
  )
62
- else:
63
- print("Not injecting relation data")
64
  return html
65
 
66
 
67
- def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
 
 
 
 
 
68
  from bs4 import BeautifulSoup
69
 
70
  # Parse the HTML using BeautifulSoup
@@ -82,13 +88,18 @@ def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
82
  entities = soup.find_all(class_="entity")
83
  for idx, entity in enumerate(entities):
84
  entity["id"] = f"entity-{idx}"
85
- entity["data-original-color"] = (
86
- entity["style"].split("background:")[1].split(";")[0].strip()
87
- )
 
 
 
 
88
  entity_annotation = sorted_entities[idx]
89
  # sanity check
90
  if str(entity_annotation) != entity.next:
91
  raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
 
92
  entity["data-relation-tails"] = json.dumps(
93
  [
94
  {"entity-id": entity2id[tail], "label": label}
@@ -165,10 +176,24 @@ if __name__ == "__main__":
165
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
166
  "colors": {
167
  "own_claim".upper(): "#009933",
168
- "background_claim".upper(): "#0033cc",
169
  "data".upper(): "#993399",
170
  }
171
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  }
173
 
174
  with gr.Blocks() as demo:
@@ -216,26 +241,47 @@ if __name__ == "__main__":
216
 
217
  js = """
218
  () => {
219
- function highlightRelations(entityId) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  const entities = document.querySelectorAll('.entity');
 
221
  entities.forEach(entity => {
222
- entity.style.backgroundColor = entity.getAttribute('data-original-color');
 
223
  entity.style.color = '';
224
  });
225
 
226
  if (entityId !== null) {
 
 
227
  const selectedEntity = document.getElementById(entityId);
228
  if (selectedEntity) {
229
- selectedEntity.style.backgroundColor = '#ffa';
230
- selectedEntity.style.color = '#000';
 
231
  }
232
  // highlight tails
233
  const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
234
  relationTailsAndLabels.forEach(relationTail => {
235
  const tailEntity = document.getElementById(relationTail['entity-id']);
236
  if (tailEntity) {
237
- tailEntity.style.backgroundColor = '#aff';
238
- tailEntity.style.color = '#000';
 
239
  }
240
  });
241
  // highlight heads
@@ -243,8 +289,16 @@ if __name__ == "__main__":
243
  relationHeadsAndLabels.forEach(relationHead => {
244
  const headEntity = document.getElementById(relationHead['entity-id']);
245
  if (headEntity) {
246
- headEntity.style.backgroundColor = '#faf';
247
- headEntity.style.color = '#000';
 
 
 
 
 
 
 
 
248
  }
249
  });
250
  }
@@ -252,12 +306,17 @@ if __name__ == "__main__":
252
 
253
  const entities = document.querySelectorAll('.entity');
254
  entities.forEach(entity => {
 
 
 
 
255
  entity.addEventListener('mouseover', () => {
256
- highlightRelations(entity.id);
257
  });
258
  entity.addEventListener('mouseout', () => {
259
- highlightRelations(null);
260
  });
 
261
  });
262
  }
263
  """
 
1
  import json
2
  from collections import defaultdict
3
+ from typing import Dict, List, Optional, Tuple, Union
4
 
5
  import gradio as gr
6
  from pie_modules.models import * # noqa: F403
7
  from pie_modules.taskmodules import * # noqa: F403
8
+ from pytorch_ie.annotations import BinaryRelation, LabeledSpan
9
  from pytorch_ie.auto import AutoPipeline
10
  from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
11
  from pytorch_ie.models import * # noqa: F403
 
33
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
34
  style="ent",
35
  inject_relations=True,
36
+ colors_hover=None,
37
  **render_kwargs,
38
  ):
39
  from spacy import displacy
 
52
  )
53
  html = "<div style='max-width:100%; max-height:360px; overflow:auto'>" + html + "</div>"
54
  if inject_relations:
 
55
  binary_relations = list(document.binary_relations) + list(
56
  document.binary_relations.predictions
57
  )
58
  sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
59
  html = inject_relation_data(
60
+ html,
61
+ sorted_entities=sorted_entities,
62
+ binary_relations=binary_relations,
63
+ additional_colors=colors_hover,
64
  )
 
 
65
  return html
66
 
67
 
68
+ def inject_relation_data(
69
+ html: str,
70
+ sorted_entities,
71
+ binary_relations: List[BinaryRelation],
72
+ additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
73
+ ) -> str:
74
  from bs4 import BeautifulSoup
75
 
76
  # Parse the HTML using BeautifulSoup
 
88
  entities = soup.find_all(class_="entity")
89
  for idx, entity in enumerate(entities):
90
  entity["id"] = f"entity-{idx}"
91
+ original_color = entity["style"].split("background:")[1].split(";")[0].strip()
92
+ entity["data-color-original"] = original_color
93
+ if additional_colors is not None:
94
+ for key, color in additional_colors.items():
95
+ entity[f"data-color-{key}"] = (
96
+ json.dumps(color) if isinstance(color, dict) else color
97
+ )
98
  entity_annotation = sorted_entities[idx]
99
  # sanity check
100
  if str(entity_annotation) != entity.next:
101
  raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
102
+ entity["data-label"] = entity_annotation.label
103
  entity["data-relation-tails"] = json.dumps(
104
  [
105
  {"entity-id": entity2id[tail], "label": label}
 
176
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
177
  "colors": {
178
  "own_claim".upper(): "#009933",
179
+ "background_claim".upper(): "#99ccff",
180
  "data".upper(): "#993399",
181
  }
182
  },
183
+ "colors_hover": {
184
+ "selected": "#ffa",
185
+ # "tail": "#aff",
186
+ "tail": {
187
+ # green
188
+ "supports": "#9f9",
189
+ # red
190
+ "contradicts": "#f99",
191
+ # do not highlight
192
+ "parts_of_same": None,
193
+ },
194
+ "head": None, # "#faf",
195
+ "other": None,
196
+ },
197
  }
198
 
199
  with gr.Blocks() as demo:
 
241
 
242
  js = """
243
  () => {
244
+ function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
245
+ var color = entity.getAttribute('data-color-' + colorAttributeKey);
246
+ // if color is a json string, parse it and use the value at colorDictKey
247
+ try {
248
+ const colors = JSON.parse(color);
249
+ color = colors[colorDictKey];
250
+ } catch (e) {}
251
+ if (color) {
252
+ console.log('setting color', color);
253
+ console.log('entity', entity);
254
+ entity.style.backgroundColor = color;
255
+ entity.style.color = '#000';
256
+ }
257
+ }
258
+
259
+ function highlightRelationArguments(entityId) {
260
  const entities = document.querySelectorAll('.entity');
261
+ // reset all entities
262
  entities.forEach(entity => {
263
+ const color = entity.getAttribute('data-color-original');
264
+ entity.style.backgroundColor = color;
265
  entity.style.color = '';
266
  });
267
 
268
  if (entityId !== null) {
269
+ var visitedEntities = new Set();
270
+ // highlight selected entity
271
  const selectedEntity = document.getElementById(entityId);
272
  if (selectedEntity) {
273
+ const label = selectedEntity.getAttribute('data-label');
274
+ maybeSetColor(selectedEntity, 'selected', label);
275
+ visitedEntities.add(selectedEntity);
276
  }
277
  // highlight tails
278
  const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
279
  relationTailsAndLabels.forEach(relationTail => {
280
  const tailEntity = document.getElementById(relationTail['entity-id']);
281
  if (tailEntity) {
282
+ const label = relationTail['label'];
283
+ maybeSetColor(tailEntity, 'tail', label);
284
+ visitedEntities.add(tailEntity);
285
  }
286
  });
287
  // highlight heads
 
289
  relationHeadsAndLabels.forEach(relationHead => {
290
  const headEntity = document.getElementById(relationHead['entity-id']);
291
  if (headEntity) {
292
+ const label = relationHead['label'];
293
+ maybeSetColor(headEntity, 'head', label);
294
+ visitedEntities.add(headEntity);
295
+ }
296
+ });
297
+ // highlight other entities
298
+ entities.forEach(entity => {
299
+ if (!visitedEntities.has(entity)) {
300
+ const label = entity.getAttribute('data-label');
301
+ maybeSetColor(entity, 'other', label);
302
  }
303
  });
304
  }
 
306
 
307
  const entities = document.querySelectorAll('.entity');
308
  entities.forEach(entity => {
309
+ const alreadyHasListener = entity.getAttribute('data-has-listener');
310
+ if (alreadyHasListener) {
311
+ return;
312
+ }
313
  entity.addEventListener('mouseover', () => {
314
+ highlightRelationArguments(entity.id);
315
  });
316
  entity.addEventListener('mouseout', () => {
317
+ highlightRelationArguments(null);
318
  });
319
+ entity.setAttribute('data-has-listener', 'true');
320
  });
321
  }
322
  """