ArneBinder commited on
Commit
ff28cb9
1 Parent(s): 3667108

fix relation highligthing

Browse files
Files changed (1) hide show
  1. app.py +72 -73
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  from typing import Tuple
3
 
4
  import gradio as gr
@@ -10,6 +11,7 @@ from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndL
10
  from pytorch_ie.models import * # noqa: F403
11
  from pytorch_ie.taskmodules import * # noqa: F403
12
 
 
13
  def render_pretty_table(
14
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
15
  ):
@@ -68,9 +70,16 @@ def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
68
  # Parse the HTML using BeautifulSoup
69
  soup = BeautifulSoup(html, "html.parser")
70
 
 
 
 
 
 
 
 
 
71
  # Add unique IDs to each entity
72
  entities = soup.find_all(class_="entity")
73
- entity2id = {}
74
  for idx, entity in enumerate(entities):
75
  entity["id"] = f"entity-{idx}"
76
  entity["data-original-color"] = (
@@ -80,77 +89,18 @@ def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
80
  # sanity check
81
  if str(entity_annotation) != entity.next:
82
  raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
83
- entity2id[sorted_entities[idx]] = f"entity-{idx}"
84
-
85
- # Generate prefixed relations
86
- prefixed_relations = [
87
- {
88
- "head": entity2id[relation.head],
89
- "tail": entity2id[relation.tail],
90
- "label": relation.label,
91
- }
92
- for relation in binary_relations
93
- ]
94
-
95
- # Create the JavaScript function to handle mouse over and mouse out events
96
- script = (
97
- """
98
- <script>
99
- function highlightRelations(entityId, relations) {
100
- // Reset all entities' styles
101
- const entities = document.querySelectorAll('.entity');
102
- entities.forEach(entity => {
103
- entity.style.backgroundColor = entity.getAttribute('data-original-color');
104
- entity.style.color = '';
105
- });
106
-
107
- // If an entity is hovered, highlight it and its related entities with different colors
108
- if (entityId !== null) {
109
- const selectedEntity = document.getElementById(entityId);
110
- if (selectedEntity) {
111
- selectedEntity.style.backgroundColor = '#ffa';
112
- selectedEntity.style.color = '#000';
113
- }
114
-
115
- relations.forEach(relation => {
116
- if (relation.head === entityId) {
117
- const tailEntity = document.getElementById(relation.tail);
118
- if (tailEntity) {
119
- tailEntity.style.backgroundColor = '#aff';
120
- tailEntity.style.color = '#000';
121
- }
122
- }
123
- if (relation.tail === entityId) {
124
- const headEntity = document.getElementById(relation.head);
125
- if (headEntity) {
126
- headEntity.style.backgroundColor = '#faf';
127
- headEntity.style.color = '#000';
128
- }
129
- }
130
- });
131
- }
132
- }
133
-
134
- // Event listeners for mouse over and mouse out on each entity
135
- document.addEventListener('DOMContentLoaded', (event) => {
136
- const relations = %s;
137
- const entities = document.querySelectorAll('.entity');
138
- entities.forEach(entity => {
139
- entity.addEventListener('mouseover', () => {
140
- highlightRelations(entity.id, relations);
141
- });
142
- entity.addEventListener('mouseout', () => {
143
- highlightRelations(null, relations);
144
- });
145
- });
146
- });
147
- </script>
148
- """
149
- % prefixed_relations
150
- )
151
-
152
- # Inject the script into the HTML
153
- soup.body.append(BeautifulSoup(script, "html.parser"))
154
 
155
  # Return the modified HTML as a string
156
  return str(soup)
@@ -264,5 +214,54 @@ if __name__ == "__main__":
264
  )
265
  render_btn.click(**render_button_kwargs, api_name="render")
266
 
267
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
 
 
 
 
1
  import json
2
+ from collections import defaultdict
3
  from typing import Tuple
4
 
5
  import gradio as gr
 
11
  from pytorch_ie.models import * # noqa: F403
12
  from pytorch_ie.taskmodules import * # noqa: F403
13
 
14
+
15
  def render_pretty_table(
16
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
17
  ):
 
70
  # Parse the HTML using BeautifulSoup
71
  soup = BeautifulSoup(html, "html.parser")
72
 
73
+ entity2tails = defaultdict(list)
74
+ entity2heads = defaultdict(list)
75
+ for relation in binary_relations:
76
+ entity2heads[relation.tail].append((relation.head, relation.label))
77
+ entity2tails[relation.head].append((relation.tail, relation.label))
78
+
79
+ entity2id = {entity: f"entity-{idx}" for idx, entity in enumerate(sorted_entities)}
80
+
81
  # Add unique IDs to each entity
82
  entities = soup.find_all(class_="entity")
 
83
  for idx, entity in enumerate(entities):
84
  entity["id"] = f"entity-{idx}"
85
  entity["data-original-color"] = (
 
89
  # sanity check
90
  if str(entity_annotation) != entity.next:
91
  raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
92
+ entity["data-relation-tails"] = json.dumps(
93
+ [
94
+ {"entity-id": entity2id[tail], "label": label}
95
+ for tail, label in entity2tails.get(entity_annotation, [])
96
+ ]
97
+ )
98
+ entity["data-relation-heads"] = json.dumps(
99
+ [
100
+ {"entity-id": entity2id[head], "label": label}
101
+ for head, label in entity2heads.get(entity_annotation, [])
102
+ ]
103
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # Return the modified HTML as a string
106
  return str(soup)
 
214
  )
215
  render_btn.click(**render_button_kwargs, api_name="render")
216
 
217
+ js = """
218
+ () => {
219
+ function highlightRelations(entityId) {
220
+ const entities = document.querySelectorAll('.entity');
221
+ entities.forEach(entity => {
222
+ entity.style.backgroundColor = entity.getAttribute('data-original-color');
223
+ entity.style.color = '';
224
+ });
225
+
226
+ if (entityId !== null) {
227
+ const selectedEntity = document.getElementById(entityId);
228
+ if (selectedEntity) {
229
+ selectedEntity.style.backgroundColor = '#ffa';
230
+ selectedEntity.style.color = '#000';
231
+ }
232
+ // highlight tails
233
+ const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
234
+ relationTailsAndLabels.forEach(relationTail => {
235
+ const tailEntity = document.getElementById(relationTail['entity-id']);
236
+ if (tailEntity) {
237
+ tailEntity.style.backgroundColor = '#aff';
238
+ tailEntity.style.color = '#000';
239
+ }
240
+ });
241
+ // highlight heads
242
+ const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
243
+ relationHeadsAndLabels.forEach(relationHead => {
244
+ const headEntity = document.getElementById(relationHead['entity-id']);
245
+ if (headEntity) {
246
+ headEntity.style.backgroundColor = '#faf';
247
+ headEntity.style.color = '#000';
248
+ }
249
+ });
250
+ }
251
+ }
252
+
253
+ const entities = document.querySelectorAll('.entity');
254
+ entities.forEach(entity => {
255
+ entity.addEventListener('mouseover', () => {
256
+ highlightRelations(entity.id);
257
+ });
258
+ entity.addEventListener('mouseout', () => {
259
+ highlightRelations(null);
260
+ });
261
+ });
262
+ }
263
+ """
264
 
265
+ rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
266
+
267
+ demo.launch()