fix relation highligthing
Browse files
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import json
|
|
|
2 |
from typing import Tuple
|
3 |
|
4 |
import gradio as gr
|
@@ -10,6 +11,7 @@ from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndL
|
|
10 |
from pytorch_ie.models import * # noqa: F403
|
11 |
from pytorch_ie.taskmodules import * # noqa: F403
|
12 |
|
|
|
13 |
def render_pretty_table(
|
14 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
15 |
):
|
@@ -68,9 +70,16 @@ def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
|
|
68 |
# Parse the HTML using BeautifulSoup
|
69 |
soup = BeautifulSoup(html, "html.parser")
|
70 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
# Add unique IDs to each entity
|
72 |
entities = soup.find_all(class_="entity")
|
73 |
-
entity2id = {}
|
74 |
for idx, entity in enumerate(entities):
|
75 |
entity["id"] = f"entity-{idx}"
|
76 |
entity["data-original-color"] = (
|
@@ -80,77 +89,18 @@ def inject_relation_data(html: str, sorted_entities, binary_relations) -> str:
|
|
80 |
# sanity check
|
81 |
if str(entity_annotation) != entity.next:
|
82 |
raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
# Create the JavaScript function to handle mouse over and mouse out events
|
96 |
-
script = (
|
97 |
-
"""
|
98 |
-
<script>
|
99 |
-
function highlightRelations(entityId, relations) {
|
100 |
-
// Reset all entities' styles
|
101 |
-
const entities = document.querySelectorAll('.entity');
|
102 |
-
entities.forEach(entity => {
|
103 |
-
entity.style.backgroundColor = entity.getAttribute('data-original-color');
|
104 |
-
entity.style.color = '';
|
105 |
-
});
|
106 |
-
|
107 |
-
// If an entity is hovered, highlight it and its related entities with different colors
|
108 |
-
if (entityId !== null) {
|
109 |
-
const selectedEntity = document.getElementById(entityId);
|
110 |
-
if (selectedEntity) {
|
111 |
-
selectedEntity.style.backgroundColor = '#ffa';
|
112 |
-
selectedEntity.style.color = '#000';
|
113 |
-
}
|
114 |
-
|
115 |
-
relations.forEach(relation => {
|
116 |
-
if (relation.head === entityId) {
|
117 |
-
const tailEntity = document.getElementById(relation.tail);
|
118 |
-
if (tailEntity) {
|
119 |
-
tailEntity.style.backgroundColor = '#aff';
|
120 |
-
tailEntity.style.color = '#000';
|
121 |
-
}
|
122 |
-
}
|
123 |
-
if (relation.tail === entityId) {
|
124 |
-
const headEntity = document.getElementById(relation.head);
|
125 |
-
if (headEntity) {
|
126 |
-
headEntity.style.backgroundColor = '#faf';
|
127 |
-
headEntity.style.color = '#000';
|
128 |
-
}
|
129 |
-
}
|
130 |
-
});
|
131 |
-
}
|
132 |
-
}
|
133 |
-
|
134 |
-
// Event listeners for mouse over and mouse out on each entity
|
135 |
-
document.addEventListener('DOMContentLoaded', (event) => {
|
136 |
-
const relations = %s;
|
137 |
-
const entities = document.querySelectorAll('.entity');
|
138 |
-
entities.forEach(entity => {
|
139 |
-
entity.addEventListener('mouseover', () => {
|
140 |
-
highlightRelations(entity.id, relations);
|
141 |
-
});
|
142 |
-
entity.addEventListener('mouseout', () => {
|
143 |
-
highlightRelations(null, relations);
|
144 |
-
});
|
145 |
-
});
|
146 |
-
});
|
147 |
-
</script>
|
148 |
-
"""
|
149 |
-
% prefixed_relations
|
150 |
-
)
|
151 |
-
|
152 |
-
# Inject the script into the HTML
|
153 |
-
soup.body.append(BeautifulSoup(script, "html.parser"))
|
154 |
|
155 |
# Return the modified HTML as a string
|
156 |
return str(soup)
|
@@ -264,5 +214,54 @@ if __name__ == "__main__":
|
|
264 |
)
|
265 |
render_btn.click(**render_button_kwargs, api_name="render")
|
266 |
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
268 |
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
from collections import defaultdict
|
3 |
from typing import Tuple
|
4 |
|
5 |
import gradio as gr
|
|
|
11 |
from pytorch_ie.models import * # noqa: F403
|
12 |
from pytorch_ie.taskmodules import * # noqa: F403
|
13 |
|
14 |
+
|
15 |
def render_pretty_table(
|
16 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
17 |
):
|
|
|
70 |
# Parse the HTML using BeautifulSoup
|
71 |
soup = BeautifulSoup(html, "html.parser")
|
72 |
|
73 |
+
entity2tails = defaultdict(list)
|
74 |
+
entity2heads = defaultdict(list)
|
75 |
+
for relation in binary_relations:
|
76 |
+
entity2heads[relation.tail].append((relation.head, relation.label))
|
77 |
+
entity2tails[relation.head].append((relation.tail, relation.label))
|
78 |
+
|
79 |
+
entity2id = {entity: f"entity-{idx}" for idx, entity in enumerate(sorted_entities)}
|
80 |
+
|
81 |
# Add unique IDs to each entity
|
82 |
entities = soup.find_all(class_="entity")
|
|
|
83 |
for idx, entity in enumerate(entities):
|
84 |
entity["id"] = f"entity-{idx}"
|
85 |
entity["data-original-color"] = (
|
|
|
89 |
# sanity check
|
90 |
if str(entity_annotation) != entity.next:
|
91 |
raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
92 |
+
entity["data-relation-tails"] = json.dumps(
|
93 |
+
[
|
94 |
+
{"entity-id": entity2id[tail], "label": label}
|
95 |
+
for tail, label in entity2tails.get(entity_annotation, [])
|
96 |
+
]
|
97 |
+
)
|
98 |
+
entity["data-relation-heads"] = json.dumps(
|
99 |
+
[
|
100 |
+
{"entity-id": entity2id[head], "label": label}
|
101 |
+
for head, label in entity2heads.get(entity_annotation, [])
|
102 |
+
]
|
103 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
|
105 |
# Return the modified HTML as a string
|
106 |
return str(soup)
|
|
|
214 |
)
|
215 |
render_btn.click(**render_button_kwargs, api_name="render")
|
216 |
|
217 |
+
js = """
|
218 |
+
() => {
|
219 |
+
function highlightRelations(entityId) {
|
220 |
+
const entities = document.querySelectorAll('.entity');
|
221 |
+
entities.forEach(entity => {
|
222 |
+
entity.style.backgroundColor = entity.getAttribute('data-original-color');
|
223 |
+
entity.style.color = '';
|
224 |
+
});
|
225 |
+
|
226 |
+
if (entityId !== null) {
|
227 |
+
const selectedEntity = document.getElementById(entityId);
|
228 |
+
if (selectedEntity) {
|
229 |
+
selectedEntity.style.backgroundColor = '#ffa';
|
230 |
+
selectedEntity.style.color = '#000';
|
231 |
+
}
|
232 |
+
// highlight tails
|
233 |
+
const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
|
234 |
+
relationTailsAndLabels.forEach(relationTail => {
|
235 |
+
const tailEntity = document.getElementById(relationTail['entity-id']);
|
236 |
+
if (tailEntity) {
|
237 |
+
tailEntity.style.backgroundColor = '#aff';
|
238 |
+
tailEntity.style.color = '#000';
|
239 |
+
}
|
240 |
+
});
|
241 |
+
// highlight heads
|
242 |
+
const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
|
243 |
+
relationHeadsAndLabels.forEach(relationHead => {
|
244 |
+
const headEntity = document.getElementById(relationHead['entity-id']);
|
245 |
+
if (headEntity) {
|
246 |
+
headEntity.style.backgroundColor = '#faf';
|
247 |
+
headEntity.style.color = '#000';
|
248 |
+
}
|
249 |
+
});
|
250 |
+
}
|
251 |
+
}
|
252 |
+
|
253 |
+
const entities = document.querySelectorAll('.entity');
|
254 |
+
entities.forEach(entity => {
|
255 |
+
entity.addEventListener('mouseover', () => {
|
256 |
+
highlightRelations(entity.id);
|
257 |
+
});
|
258 |
+
entity.addEventListener('mouseout', () => {
|
259 |
+
highlightRelations(null);
|
260 |
+
});
|
261 |
+
});
|
262 |
+
}
|
263 |
+
"""
|
264 |
|
265 |
+
rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
|
266 |
+
|
267 |
+
demo.launch()
|