ArneBinder
commited on
Commit
•
efae5be
1
Parent(s):
47cc11e
from https://github.com/ArneBinder/pie-document-level/pull/243
Browse files- annotation_utils.py +26 -5
- app.py +63 -107
- document_store.py +35 -25
- embedding.py +46 -13
- model_utils.py +29 -5
- rendering_utils.py +168 -22
- requirements.txt +1 -0
annotation_utils.py
CHANGED
@@ -1,10 +1,31 @@
|
|
1 |
-
from
|
2 |
|
|
|
3 |
|
4 |
-
def labeled_span_to_id(span: LabeledSpan) -> str:
|
5 |
-
return f"span-{span.start}-{span.end}-{span.label}"
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
|
|
|
9 |
parts = span_id.split("-")
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Union
|
2 |
|
3 |
+
from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan
|
4 |
|
|
|
|
|
5 |
|
6 |
+
def labeled_span_to_id(span: Union[LabeledSpan, LabeledMultiSpan]) -> str:
|
7 |
+
if isinstance(span, LabeledSpan):
|
8 |
+
# {type indicator}-{start}-{end}-{label}
|
9 |
+
return f"span-{span.start}-{span.end}-{span.label}"
|
10 |
+
elif isinstance(span, LabeledMultiSpan):
|
11 |
+
# {type indicator}-({start}-{end})*-{label
|
12 |
+
starts_ends = "-".join(f"{start}-{end}" for start, end in span.slices)
|
13 |
+
return f"multispan-{starts_ends}-{span.label}"
|
14 |
+
else:
|
15 |
+
raise ValueError(f"Unsupported span type: {type(span)}")
|
16 |
|
17 |
+
|
18 |
+
def labeled_span_from_id(span_id: str) -> Union[LabeledSpan, LabeledMultiSpan]:
|
19 |
parts = span_id.split("-")
|
20 |
+
if parts[0] == "span":
|
21 |
+
return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
|
22 |
+
elif parts[0] == "multispan":
|
23 |
+
label = parts[-1]
|
24 |
+
# this contains: start1, end1, start2, end2, ...
|
25 |
+
starts_ends = parts[1:-1]
|
26 |
+
slices = tuple(
|
27 |
+
(int(start), int(end)) for start, end in zip(starts_ends[::2], starts_ends[1::2])
|
28 |
+
)
|
29 |
+
return LabeledMultiSpan(slices, label)
|
30 |
+
else:
|
31 |
+
raise ValueError(f"Unsupported span id: {span_id}")
|
app.py
CHANGED
@@ -4,7 +4,7 @@ import os.path
|
|
4 |
import re
|
5 |
import tempfile
|
6 |
from functools import partial
|
7 |
-
from typing import List, Optional, Tuple
|
8 |
|
9 |
import gradio as gr
|
10 |
import pandas as pd
|
@@ -14,8 +14,11 @@ from embedding import EmbeddingModel
|
|
14 |
from model_utils import annotate_document, create_document, load_models
|
15 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
16 |
from pytorch_ie import Pipeline
|
17 |
-
from pytorch_ie.documents import
|
18 |
-
|
|
|
|
|
|
|
19 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
20 |
from vector_store import QdrantVectorStore, SimpleVectorStore
|
21 |
|
@@ -35,6 +38,10 @@ DEFAULT_EMBEDDING_MAX_LENGTH = 512
|
|
35 |
DEFAULT_EMBEDDING_BATCH_SIZE = 32
|
36 |
DEFAULT_SPLIT_REGEX = "\n\n\n+"
|
37 |
|
|
|
|
|
|
|
|
|
38 |
|
39 |
def escape_regex(regex: str) -> str:
|
40 |
# "double escape" the backslashes
|
@@ -49,7 +56,10 @@ def unescape_regex(regex: str) -> str:
|
|
49 |
|
50 |
|
51 |
def render_annotated_document(
|
52 |
-
document:
|
|
|
|
|
|
|
53 |
render_with: str,
|
54 |
render_kwargs_json: str,
|
55 |
) -> str:
|
@@ -70,7 +80,14 @@ def wrapped_process_text(
|
|
70 |
models: Tuple[Pipeline, Optional[EmbeddingModel]],
|
71 |
document_store: DocumentStore,
|
72 |
split_regex_escaped: str,
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
try:
|
75 |
document = create_document(
|
76 |
text=text,
|
@@ -79,10 +96,11 @@ def wrapped_process_text(
|
|
79 |
if len(split_regex_escaped) > 0
|
80 |
else None,
|
81 |
)
|
82 |
-
annotate_document(
|
83 |
document=document,
|
84 |
annotation_pipeline=models[0],
|
85 |
embedding_model=models[1],
|
|
|
86 |
)
|
87 |
document_store.add_document(document)
|
88 |
except Exception as e:
|
@@ -100,6 +118,8 @@ def process_uploaded_files(
|
|
100 |
document_store: DocumentStore,
|
101 |
split_regex_escaped: str,
|
102 |
show_max_cross_doc_sims: bool = False,
|
|
|
|
|
103 |
) -> pd.DataFrame:
|
104 |
try:
|
105 |
new_documents = []
|
@@ -117,10 +137,11 @@ def process_uploaded_files(
|
|
117 |
if len(split_regex_escaped) > 0
|
118 |
else None,
|
119 |
)
|
120 |
-
annotate_document(
|
121 |
document=new_document,
|
122 |
annotation_pipeline=models[0],
|
123 |
embedding_model=models[1],
|
|
|
124 |
)
|
125 |
new_documents.append(new_document)
|
126 |
else:
|
@@ -129,7 +150,9 @@ def process_uploaded_files(
|
|
129 |
except Exception as e:
|
130 |
raise gr.Error(f"Failed to process uploaded files: {e}")
|
131 |
|
132 |
-
return document_store.overview(
|
|
|
|
|
133 |
|
134 |
|
135 |
def open_accordion():
|
@@ -144,9 +167,15 @@ def select_processed_document(
|
|
144 |
evt: gr.SelectData,
|
145 |
processed_documents_df: pd.DataFrame,
|
146 |
document_store: DocumentStore,
|
147 |
-
) ->
|
|
|
|
|
|
|
148 |
row_idx, col_idx = evt.index
|
149 |
-
|
|
|
|
|
|
|
150 |
doc = document_store.get_document(doc_id, with_embeddings=False)
|
151 |
return doc
|
152 |
|
@@ -231,6 +260,12 @@ def main():
|
|
231 |
span_annotation_caption="adu",
|
232 |
relation_annotation_caption="relation",
|
233 |
vector_store=QdrantVectorStore(),
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
)
|
235 |
)
|
236 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
@@ -399,14 +434,20 @@ def main():
|
|
399 |
|
400 |
show_overview_kwargs = dict(
|
401 |
fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
|
402 |
-
with_max_cross_doc_sims=show_max_sims
|
403 |
),
|
404 |
inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
|
405 |
outputs=[processed_documents_df],
|
406 |
)
|
407 |
predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
|
408 |
-
fn=wrapped_process_text,
|
409 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
410 |
outputs=[document_json, document_state],
|
411 |
api_name="predict",
|
412 |
).success(**show_overview_kwargs)
|
@@ -423,13 +464,14 @@ def main():
|
|
423 |
upload_btn.upload(
|
424 |
fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
|
425 |
).then(
|
426 |
-
fn=process_uploaded_files,
|
427 |
inputs=[
|
428 |
upload_btn,
|
429 |
models_state,
|
430 |
document_store_state,
|
431 |
split_regex_escaped,
|
432 |
show_max_cross_docu_sims,
|
|
|
433 |
],
|
434 |
outputs=[processed_documents_df],
|
435 |
)
|
@@ -470,7 +512,9 @@ def main():
|
|
470 |
selected_adu_id.change(
|
471 |
fn=partial(
|
472 |
get_annotation_from_document,
|
473 |
-
annotation_layer="labeled_spans"
|
|
|
|
|
474 |
use_predictions=True,
|
475 |
),
|
476 |
inputs=[document_state, selected_adu_id],
|
@@ -483,7 +527,9 @@ def main():
|
|
483 |
ref_document=document,
|
484 |
min_similarity=min_sim,
|
485 |
top_k=k,
|
486 |
-
annotation_layer="labeled_spans"
|
|
|
|
|
487 |
),
|
488 |
inputs=[
|
489 |
document_store_state,
|
@@ -513,97 +559,7 @@ def main():
|
|
513 |
# **retrieve_relevant_adus_event_kwargs
|
514 |
# )
|
515 |
|
516 |
-
js =
|
517 |
-
() => {
|
518 |
-
function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
|
519 |
-
var color = entity.getAttribute('data-color-' + colorAttributeKey);
|
520 |
-
// if color is a json string, parse it and use the value at colorDictKey
|
521 |
-
try {
|
522 |
-
const colors = JSON.parse(color);
|
523 |
-
color = colors[colorDictKey];
|
524 |
-
} catch (e) {}
|
525 |
-
if (color) {
|
526 |
-
entity.style.backgroundColor = color;
|
527 |
-
entity.style.color = '#000';
|
528 |
-
}
|
529 |
-
}
|
530 |
-
|
531 |
-
function highlightRelationArguments(entityId) {
|
532 |
-
const entities = document.querySelectorAll('.entity');
|
533 |
-
// reset all entities
|
534 |
-
entities.forEach(entity => {
|
535 |
-
const color = entity.getAttribute('data-color-original');
|
536 |
-
entity.style.backgroundColor = color;
|
537 |
-
entity.style.color = '';
|
538 |
-
});
|
539 |
-
|
540 |
-
if (entityId !== null) {
|
541 |
-
var visitedEntities = new Set();
|
542 |
-
// highlight selected entity
|
543 |
-
const selectedEntity = document.getElementById(entityId);
|
544 |
-
if (selectedEntity) {
|
545 |
-
const label = selectedEntity.getAttribute('data-label');
|
546 |
-
maybeSetColor(selectedEntity, 'selected', label);
|
547 |
-
visitedEntities.add(selectedEntity);
|
548 |
-
}
|
549 |
-
// highlight tails
|
550 |
-
const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
|
551 |
-
relationTailsAndLabels.forEach(relationTail => {
|
552 |
-
const tailEntity = document.getElementById(relationTail['entity-id']);
|
553 |
-
if (tailEntity) {
|
554 |
-
const label = relationTail['label'];
|
555 |
-
maybeSetColor(tailEntity, 'tail', label);
|
556 |
-
visitedEntities.add(tailEntity);
|
557 |
-
}
|
558 |
-
});
|
559 |
-
// highlight heads
|
560 |
-
const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
|
561 |
-
relationHeadsAndLabels.forEach(relationHead => {
|
562 |
-
const headEntity = document.getElementById(relationHead['entity-id']);
|
563 |
-
if (headEntity) {
|
564 |
-
const label = relationHead['label'];
|
565 |
-
maybeSetColor(headEntity, 'head', label);
|
566 |
-
visitedEntities.add(headEntity);
|
567 |
-
}
|
568 |
-
});
|
569 |
-
// highlight other entities
|
570 |
-
entities.forEach(entity => {
|
571 |
-
if (!visitedEntities.has(entity)) {
|
572 |
-
const label = entity.getAttribute('data-label');
|
573 |
-
maybeSetColor(entity, 'other', label);
|
574 |
-
}
|
575 |
-
});
|
576 |
-
}
|
577 |
-
}
|
578 |
-
function setReferenceAduId(entityId) {
|
579 |
-
// get the textarea element that holds the reference adu id
|
580 |
-
let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
|
581 |
-
// set the value of the input field
|
582 |
-
referenceAduIdDiv.value = entityId;
|
583 |
-
// trigger an input event to update the state
|
584 |
-
var event = new Event('input');
|
585 |
-
referenceAduIdDiv.dispatchEvent(event);
|
586 |
-
}
|
587 |
-
|
588 |
-
const entities = document.querySelectorAll('.entity');
|
589 |
-
entities.forEach(entity => {
|
590 |
-
const alreadyHasListener = entity.getAttribute('data-has-listener');
|
591 |
-
if (alreadyHasListener) {
|
592 |
-
return;
|
593 |
-
}
|
594 |
-
entity.addEventListener('mouseover', () => {
|
595 |
-
highlightRelationArguments(entity.id);
|
596 |
-
setReferenceAduId(entity.id);
|
597 |
-
});
|
598 |
-
entity.addEventListener('mouseout', () => {
|
599 |
-
highlightRelationArguments(null);
|
600 |
-
});
|
601 |
-
entity.setAttribute('data-has-listener', 'true');
|
602 |
-
});
|
603 |
-
}
|
604 |
-
"""
|
605 |
-
|
606 |
-
rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
|
607 |
|
608 |
demo.launch()
|
609 |
|
|
|
4 |
import re
|
5 |
import tempfile
|
6 |
from functools import partial
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
|
9 |
import gradio as gr
|
10 |
import pandas as pd
|
|
|
14 |
from model_utils import annotate_document, create_document, load_models
|
15 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
16 |
from pytorch_ie import Pipeline
|
17 |
+
from pytorch_ie.documents import (
|
18 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
19 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
20 |
+
)
|
21 |
+
from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table
|
22 |
from transformers import PreTrainedModel, PreTrainedTokenizer
|
23 |
from vector_store import QdrantVectorStore, SimpleVectorStore
|
24 |
|
|
|
38 |
DEFAULT_EMBEDDING_BATCH_SIZE = 32
|
39 |
DEFAULT_SPLIT_REGEX = "\n\n\n+"
|
40 |
|
41 |
+
# Whether to handle segmented entities in the document. If True, labeled_spans are converted
|
42 |
+
# to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
|
43 |
+
HANDLE_PARTS_OF_SAME = True
|
44 |
+
|
45 |
|
46 |
def escape_regex(regex: str) -> str:
|
47 |
# "double escape" the backslashes
|
|
|
56 |
|
57 |
|
58 |
def render_annotated_document(
|
59 |
+
document: Union[
|
60 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
61 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
62 |
+
],
|
63 |
render_with: str,
|
64 |
render_kwargs_json: str,
|
65 |
) -> str:
|
|
|
80 |
models: Tuple[Pipeline, Optional[EmbeddingModel]],
|
81 |
document_store: DocumentStore,
|
82 |
split_regex_escaped: str,
|
83 |
+
handle_parts_of_same: bool = False,
|
84 |
+
) -> Tuple[
|
85 |
+
dict,
|
86 |
+
Union[
|
87 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
88 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
89 |
+
],
|
90 |
+
]:
|
91 |
try:
|
92 |
document = create_document(
|
93 |
text=text,
|
|
|
96 |
if len(split_regex_escaped) > 0
|
97 |
else None,
|
98 |
)
|
99 |
+
document = annotate_document(
|
100 |
document=document,
|
101 |
annotation_pipeline=models[0],
|
102 |
embedding_model=models[1],
|
103 |
+
handle_parts_of_same=handle_parts_of_same,
|
104 |
)
|
105 |
document_store.add_document(document)
|
106 |
except Exception as e:
|
|
|
118 |
document_store: DocumentStore,
|
119 |
split_regex_escaped: str,
|
120 |
show_max_cross_doc_sims: bool = False,
|
121 |
+
min_similarity: float = 0.95,
|
122 |
+
handle_parts_of_same: bool = False,
|
123 |
) -> pd.DataFrame:
|
124 |
try:
|
125 |
new_documents = []
|
|
|
137 |
if len(split_regex_escaped) > 0
|
138 |
else None,
|
139 |
)
|
140 |
+
new_document = annotate_document(
|
141 |
document=new_document,
|
142 |
annotation_pipeline=models[0],
|
143 |
embedding_model=models[1],
|
144 |
+
handle_parts_of_same=handle_parts_of_same,
|
145 |
)
|
146 |
new_documents.append(new_document)
|
147 |
else:
|
|
|
150 |
except Exception as e:
|
151 |
raise gr.Error(f"Failed to process uploaded files: {e}")
|
152 |
|
153 |
+
return document_store.overview(
|
154 |
+
with_max_cross_doc_sims=show_max_cross_doc_sims, min_similarity=min_similarity
|
155 |
+
)
|
156 |
|
157 |
|
158 |
def open_accordion():
|
|
|
167 |
evt: gr.SelectData,
|
168 |
processed_documents_df: pd.DataFrame,
|
169 |
document_store: DocumentStore,
|
170 |
+
) -> Union[
|
171 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
172 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
173 |
+
]:
|
174 |
row_idx, col_idx = evt.index
|
175 |
+
col_name = processed_documents_df.columns[col_idx]
|
176 |
+
if not col_name.endswith("doc_id"):
|
177 |
+
col_name = "doc_id"
|
178 |
+
doc_id = processed_documents_df.iloc[row_idx][col_name]
|
179 |
doc = document_store.get_document(doc_id, with_embeddings=False)
|
180 |
return doc
|
181 |
|
|
|
260 |
span_annotation_caption="adu",
|
261 |
relation_annotation_caption="relation",
|
262 |
vector_store=QdrantVectorStore(),
|
263 |
+
document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
264 |
+
if not HANDLE_PARTS_OF_SAME
|
265 |
+
else TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
266 |
+
span_layer_name="labeled_spans"
|
267 |
+
if not HANDLE_PARTS_OF_SAME
|
268 |
+
else "labeled_multi_spans",
|
269 |
)
|
270 |
)
|
271 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
|
|
434 |
|
435 |
show_overview_kwargs = dict(
|
436 |
fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
|
437 |
+
with_max_cross_doc_sims=show_max_sims, min_similarity=min_sim
|
438 |
),
|
439 |
inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
|
440 |
outputs=[processed_documents_df],
|
441 |
)
|
442 |
predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
|
443 |
+
fn=partial(wrapped_process_text, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
|
444 |
+
inputs=[
|
445 |
+
doc_text,
|
446 |
+
doc_id,
|
447 |
+
models_state,
|
448 |
+
document_store_state,
|
449 |
+
split_regex_escaped,
|
450 |
+
],
|
451 |
outputs=[document_json, document_state],
|
452 |
api_name="predict",
|
453 |
).success(**show_overview_kwargs)
|
|
|
464 |
upload_btn.upload(
|
465 |
fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
|
466 |
).then(
|
467 |
+
fn=partial(process_uploaded_files, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
|
468 |
inputs=[
|
469 |
upload_btn,
|
470 |
models_state,
|
471 |
document_store_state,
|
472 |
split_regex_escaped,
|
473 |
show_max_cross_docu_sims,
|
474 |
+
min_similarity,
|
475 |
],
|
476 |
outputs=[processed_documents_df],
|
477 |
)
|
|
|
512 |
selected_adu_id.change(
|
513 |
fn=partial(
|
514 |
get_annotation_from_document,
|
515 |
+
annotation_layer="labeled_spans"
|
516 |
+
if not HANDLE_PARTS_OF_SAME
|
517 |
+
else "labeled_multi_spans",
|
518 |
use_predictions=True,
|
519 |
),
|
520 |
inputs=[document_state, selected_adu_id],
|
|
|
527 |
ref_document=document,
|
528 |
min_similarity=min_sim,
|
529 |
top_k=k,
|
530 |
+
annotation_layer="labeled_spans"
|
531 |
+
if not HANDLE_PARTS_OF_SAME
|
532 |
+
else "labeled_multi_spans",
|
533 |
),
|
534 |
inputs=[
|
535 |
document_store_state,
|
|
|
559 |
# **retrieve_relevant_adus_event_kwargs
|
560 |
# )
|
561 |
|
562 |
+
rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
563 |
|
564 |
demo.launch()
|
565 |
|
document_store.py
CHANGED
@@ -14,6 +14,7 @@ from annotation_utils import labeled_span_to_id
|
|
14 |
from pytorch_ie import Annotation
|
15 |
from pytorch_ie.documents import (
|
16 |
TextBasedDocument,
|
|
|
17 |
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
18 |
)
|
19 |
from scipy.sparse import csr_matrix
|
@@ -45,7 +46,7 @@ def get_annotation_from_document(
|
|
45 |
if use_predictions:
|
46 |
annotations = annotations.predictions
|
47 |
|
48 |
-
if annotation_layer
|
49 |
annotation_to_id_func = labeled_span_to_id
|
50 |
else:
|
51 |
raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
|
@@ -301,6 +302,12 @@ class DocumentStore:
|
|
301 |
|
302 |
def add_document(self, document: TextBasedDocument) -> None:
|
303 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
if document.id in self.documents:
|
305 |
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
|
306 |
|
@@ -485,6 +492,11 @@ class DocumentStore:
|
|
485 |
|
486 |
max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
|
487 |
max_similarities = max_doc2doc_similarities.max(axis="columns")
|
|
|
|
|
|
|
|
|
|
|
488 |
|
489 |
# set the index to the doc_id to correctly join the series
|
490 |
df.set_index("doc_id", inplace=True)
|
@@ -551,7 +563,8 @@ class DocumentStore:
|
|
551 |
# set similarities below min_similarity to 0
|
552 |
similarities[similarities < min_similarity] = 0.0
|
553 |
|
554 |
-
# set triangular part to 0
|
|
|
555 |
similarities = np.triu(similarities, k=1)
|
556 |
# create a sparse matrix
|
557 |
sparse_matrix = csr_matrix(similarities)
|
@@ -564,29 +577,26 @@ class DocumentStore:
|
|
564 |
|
565 |
# construct the DataFrame
|
566 |
records = []
|
567 |
-
for idx1, idx2 in zip(non_zero_idx[0], non_zero_idx[1]):
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
|
584 |
-
|
585 |
-
|
586 |
-
|
587 |
-
"other_text": annotation_text2,
|
588 |
-
}
|
589 |
-
)
|
590 |
result_df = pd.DataFrame(records)
|
591 |
gr.Info(f"DataFrame shape: {result_df.shape}")
|
592 |
|
|
|
14 |
from pytorch_ie import Annotation
|
15 |
from pytorch_ie.documents import (
|
16 |
TextBasedDocument,
|
17 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
18 |
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
19 |
)
|
20 |
from scipy.sparse import csr_matrix
|
|
|
46 |
if use_predictions:
|
47 |
annotations = annotations.predictions
|
48 |
|
49 |
+
if annotation_layer in ["labeled_spans", "labeled_multi_spans"]:
|
50 |
annotation_to_id_func = labeled_span_to_id
|
51 |
else:
|
52 |
raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
|
|
|
302 |
|
303 |
def add_document(self, document: TextBasedDocument) -> None:
|
304 |
try:
|
305 |
+
if not isinstance(document, self.document_type):
|
306 |
+
raise gr.Error(
|
307 |
+
f"The document to add must be of type {self.document_type}, but is of type "
|
308 |
+
f"{type(document)}."
|
309 |
+
)
|
310 |
+
|
311 |
if document.id in self.documents:
|
312 |
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
|
313 |
|
|
|
492 |
|
493 |
max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
|
494 |
max_similarities = max_doc2doc_similarities.max(axis="columns")
|
495 |
+
# entries where max_similarities is -inf are documents with no entries in other documents
|
496 |
+
# with similarity > min_similarity
|
497 |
+
mask = max_similarities == -np.inf
|
498 |
+
max_doc_ids[mask] = np.nan
|
499 |
+
max_similarities[mask] = np.nan
|
500 |
|
501 |
# set the index to the doc_id to correctly join the series
|
502 |
df.set_index("doc_id", inplace=True)
|
|
|
563 |
# set similarities below min_similarity to 0
|
564 |
similarities[similarities < min_similarity] = 0.0
|
565 |
|
566 |
+
# set triangular part to 0 because we only want the upper triangular part which
|
567 |
+
# contains entries with idx1 < idx2
|
568 |
similarities = np.triu(similarities, k=1)
|
569 |
# create a sparse matrix
|
570 |
sparse_matrix = csr_matrix(similarities)
|
|
|
577 |
|
578 |
# construct the DataFrame
|
579 |
records = []
|
580 |
+
for sparse_idx, (idx1, idx2) in enumerate(zip(non_zero_idx[0], non_zero_idx[1])):
|
581 |
+
payload1 = all_payloads[idx1]
|
582 |
+
payload2 = all_payloads[idx2]
|
583 |
+
doc_id1 = payload1["doc_id"]
|
584 |
+
doc_id2 = payload2["doc_id"]
|
585 |
+
annotation_id1 = payload1["annotation_id"]
|
586 |
+
annotation_id2 = payload2["annotation_id"]
|
587 |
+
annotation_text1 = doc_id_and_annotation_id2annotation_text[(doc_id1, annotation_id1)]
|
588 |
+
annotation_text2 = doc_id_and_annotation_id2annotation_text[(doc_id2, annotation_id2)]
|
589 |
+
records.append(
|
590 |
+
{
|
591 |
+
"sim_score": scores[sparse_idx],
|
592 |
+
"doc_id": doc_id1,
|
593 |
+
"other_doc_id": doc_id2,
|
594 |
+
"adu_id": annotation_id1,
|
595 |
+
"other_adu_id": annotation_id2,
|
596 |
+
"text": annotation_text1,
|
597 |
+
"other_text": annotation_text2,
|
598 |
+
}
|
599 |
+
)
|
|
|
|
|
|
|
600 |
result_df = pd.DataFrame(records)
|
601 |
gr.Info(f"DataFrame shape: {result_df.shape}")
|
602 |
|
embedding.py
CHANGED
@@ -1,12 +1,15 @@
|
|
1 |
import abc
|
2 |
import logging
|
3 |
-
from typing import Dict
|
4 |
|
5 |
import torch
|
6 |
from datasets import Dataset
|
7 |
from pie_modules.document.processing import tokenize_document
|
8 |
-
from pie_modules.documents import
|
9 |
-
|
|
|
|
|
|
|
10 |
from pytorch_ie.documents import TextBasedDocument
|
11 |
from torch import FloatTensor, Tensor
|
12 |
from torch.utils.data import DataLoader
|
@@ -18,7 +21,7 @@ logger = logging.getLogger(__name__)
|
|
18 |
class EmbeddingModel(abc.ABC):
|
19 |
def __call__(
|
20 |
self, document: TextBasedDocument, span_layer_name: str
|
21 |
-
) -> Dict[Span, FloatTensor]:
|
22 |
"""Embed text annotations from a document.
|
23 |
|
24 |
Args:
|
@@ -51,7 +54,7 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
|
|
51 |
|
52 |
def __call__(
|
53 |
self, document: TextBasedDocument, span_layer_name: str
|
54 |
-
) -> Dict[Span, FloatTensor]:
|
55 |
# to not modify the original document
|
56 |
document = document.copy()
|
57 |
# tokenize_document does not yet consider predictions, so we need to add them manually
|
@@ -65,10 +68,21 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
|
|
65 |
"return_overflowing_tokens": True,
|
66 |
}
|
67 |
# tokenize once to get the tokenized documents with mapped annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
tokenized_documents = tokenize_document(
|
69 |
document,
|
70 |
tokenizer=self._tokenizer,
|
71 |
-
result_document_type=
|
72 |
partition_layer="labeled_partitions",
|
73 |
added_annotations=added_annotations,
|
74 |
strict_span_conversion=False,
|
@@ -104,14 +118,33 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
|
|
104 |
for last_hidden_state in model_output.last_hidden_state:
|
105 |
text2tok_ann = added_annotations[example_idx][span_layer_name]
|
106 |
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
|
107 |
-
for tok_ann in tokenized_documents[example_idx]
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
# use the max pooling strategy to get a single embedding for the annotation text
|
112 |
-
embedding = (
|
113 |
-
|
114 |
-
)
|
115 |
text_ann = tok2text_ann[tok_ann]
|
116 |
|
117 |
# if text_ann in embeddings:
|
|
|
1 |
import abc
|
2 |
import logging
|
3 |
+
from typing import Dict, Union
|
4 |
|
5 |
import torch
|
6 |
from datasets import Dataset
|
7 |
from pie_modules.document.processing import tokenize_document
|
8 |
+
from pie_modules.documents import (
|
9 |
+
TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
10 |
+
TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
11 |
+
)
|
12 |
+
from pytorch_ie.annotations import LabeledSpan, MultiSpan, Span
|
13 |
from pytorch_ie.documents import TextBasedDocument
|
14 |
from torch import FloatTensor, Tensor
|
15 |
from torch.utils.data import DataLoader
|
|
|
21 |
class EmbeddingModel(abc.ABC):
|
22 |
def __call__(
|
23 |
self, document: TextBasedDocument, span_layer_name: str
|
24 |
+
) -> Dict[Union[Span, MultiSpan], FloatTensor]:
|
25 |
"""Embed text annotations from a document.
|
26 |
|
27 |
Args:
|
|
|
54 |
|
55 |
def __call__(
|
56 |
self, document: TextBasedDocument, span_layer_name: str
|
57 |
+
) -> Dict[Union[Span, MultiSpan], FloatTensor]:
|
58 |
# to not modify the original document
|
59 |
document = document.copy()
|
60 |
# tokenize_document does not yet consider predictions, so we need to add them manually
|
|
|
68 |
"return_overflowing_tokens": True,
|
69 |
}
|
70 |
# tokenize once to get the tokenized documents with mapped annotations
|
71 |
+
span_annotation_type = document.annotation_types()[span_layer_name]
|
72 |
+
if issubclass(span_annotation_type, Span):
|
73 |
+
result_document_type = TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
74 |
+
tokenized_span_layer_name = "labeled_spans"
|
75 |
+
elif issubclass(span_annotation_type, MultiSpan):
|
76 |
+
result_document_type = (
|
77 |
+
TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
|
78 |
+
)
|
79 |
+
tokenized_span_layer_name = "labeled_multi_spans"
|
80 |
+
else:
|
81 |
+
raise ValueError(f"Unsupported annotation type: {span_annotation_type}")
|
82 |
tokenized_documents = tokenize_document(
|
83 |
document,
|
84 |
tokenizer=self._tokenizer,
|
85 |
+
result_document_type=result_document_type,
|
86 |
partition_layer="labeled_partitions",
|
87 |
added_annotations=added_annotations,
|
88 |
strict_span_conversion=False,
|
|
|
118 |
for last_hidden_state in model_output.last_hidden_state:
|
119 |
text2tok_ann = added_annotations[example_idx][span_layer_name]
|
120 |
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
|
121 |
+
for tok_ann in tokenized_documents[example_idx][tokenized_span_layer_name]:
|
122 |
+
if isinstance(tok_ann, LabeledSpan):
|
123 |
+
# skip "empty" annotations
|
124 |
+
if tok_ann.start == tok_ann.end:
|
125 |
+
continue
|
126 |
+
|
127 |
+
embedded_tokens = last_hidden_state[tok_ann.start : tok_ann.end]
|
128 |
+
|
129 |
+
elif isinstance(tok_ann, MultiSpan):
|
130 |
+
# skip "empty" annotations
|
131 |
+
if all(start == end for start, end in tok_ann.slices):
|
132 |
+
continue
|
133 |
+
|
134 |
+
# concatenate the embeddings of the tokens that make up the multi-span
|
135 |
+
embedded_tokens = torch.concat(
|
136 |
+
[
|
137 |
+
last_hidden_state[start:end]
|
138 |
+
for start, end in tok_ann.slices
|
139 |
+
if start != end
|
140 |
+
],
|
141 |
+
dim=0,
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
raise ValueError(f"Unsupported annotation type: {type(tok_ann)}")
|
145 |
# use the max pooling strategy to get a single embedding for the annotation text
|
146 |
+
embedding = embedded_tokens.max(dim=0)[0].detach().cpu()
|
147 |
+
|
|
|
148 |
text_ann = tok2text_ann[tok_ann]
|
149 |
|
150 |
# if text_ann in embeddings:
|
model_utils.py
CHANGED
@@ -1,15 +1,18 @@
|
|
1 |
import logging
|
2 |
-
from typing import Optional, Tuple
|
3 |
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
|
8 |
-
from pie_modules.document.processing import RegexPartitioner
|
9 |
from pytorch_ie import Pipeline
|
10 |
from pytorch_ie.annotations import LabeledSpan
|
11 |
from pytorch_ie.auto import AutoPipeline
|
12 |
-
from pytorch_ie.documents import
|
|
|
|
|
|
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
@@ -18,7 +21,11 @@ def annotate_document(
|
|
18 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
19 |
annotation_pipeline: Pipeline,
|
20 |
embedding_model: Optional[EmbeddingModel] = None,
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
"""Annotate a document with the provided pipeline. If an embedding model is provided, also
|
23 |
extract embeddings for the labeled spans.
|
24 |
|
@@ -26,15 +33,30 @@ def annotate_document(
|
|
26 |
document: The document to annotate.
|
27 |
annotation_pipeline: The pipeline to use for annotation.
|
28 |
embedding_model: The embedding model to use for extracting text span embeddings.
|
|
|
29 |
"""
|
30 |
|
31 |
# execute prediction pipeline
|
32 |
annotation_pipeline(document)
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
if embedding_model is not None:
|
35 |
text_span_embeddings = embedding_model(
|
36 |
document=document,
|
37 |
-
span_layer_name="labeled_spans",
|
38 |
)
|
39 |
# convert keys to str because JSON keys must be strings
|
40 |
text_span_embeddings_dict = {
|
@@ -47,6 +69,8 @@ def annotate_document(
|
|
47 |
"model in the 'Model Configuration' section."
|
48 |
)
|
49 |
|
|
|
|
|
50 |
|
51 |
def create_document(
|
52 |
text: str, doc_id: str, split_regex: Optional[str] = None
|
|
|
1 |
import logging
|
2 |
+
from typing import Optional, Tuple, Union
|
3 |
|
4 |
import gradio as gr
|
5 |
import torch
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
|
8 |
+
from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
|
9 |
from pytorch_ie import Pipeline
|
10 |
from pytorch_ie.annotations import LabeledSpan
|
11 |
from pytorch_ie.auto import AutoPipeline
|
12 |
+
from pytorch_ie.documents import (
|
13 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
14 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
15 |
+
)
|
16 |
|
17 |
logger = logging.getLogger(__name__)
|
18 |
|
|
|
21 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
22 |
annotation_pipeline: Pipeline,
|
23 |
embedding_model: Optional[EmbeddingModel] = None,
|
24 |
+
handle_parts_of_same: bool = False,
|
25 |
+
) -> Union[
|
26 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
27 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
28 |
+
]:
|
29 |
"""Annotate a document with the provided pipeline. If an embedding model is provided, also
|
30 |
extract embeddings for the labeled spans.
|
31 |
|
|
|
33 |
document: The document to annotate.
|
34 |
annotation_pipeline: The pipeline to use for annotation.
|
35 |
embedding_model: The embedding model to use for extracting text span embeddings.
|
36 |
+
handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
|
37 |
"""
|
38 |
|
39 |
# execute prediction pipeline
|
40 |
annotation_pipeline(document)
|
41 |
|
42 |
+
if handle_parts_of_same:
|
43 |
+
merger = SpansViaRelationMerger(
|
44 |
+
relation_layer="binary_relations",
|
45 |
+
link_relation_label="parts_of_same",
|
46 |
+
create_multi_spans=True,
|
47 |
+
result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
48 |
+
result_field_mapping={
|
49 |
+
"labeled_spans": "labeled_multi_spans",
|
50 |
+
"binary_relations": "binary_relations",
|
51 |
+
"labeled_partitions": "labeled_partitions",
|
52 |
+
},
|
53 |
+
)
|
54 |
+
document = merger(document)
|
55 |
+
|
56 |
if embedding_model is not None:
|
57 |
text_span_embeddings = embedding_model(
|
58 |
document=document,
|
59 |
+
span_layer_name="labeled_spans" if not handle_parts_of_same else "labeled_multi_spans",
|
60 |
)
|
61 |
# convert keys to str because JSON keys must be strings
|
62 |
text_span_embeddings_dict = {
|
|
|
69 |
"model in the 'Model Configuration' section."
|
70 |
)
|
71 |
|
72 |
+
return document
|
73 |
+
|
74 |
|
75 |
def create_document(
|
76 |
text: str, doc_id: str, split_regex: Optional[str] = None
|
rendering_utils.py
CHANGED
@@ -4,23 +4,130 @@ from collections import defaultdict
|
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
-
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
|
8 |
-
from pytorch_ie.documents import
|
|
|
|
|
|
|
9 |
from rendering_utils_displacy import EntityRenderer
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
# adjusted from rendering_utils_displacy.TPL_ENT
|
14 |
TPL_ENT_WITH_ID = """
|
15 |
-
<mark class="entity" id="{
|
16 |
{text}
|
17 |
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
|
18 |
</mark>
|
19 |
"""
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def render_pretty_table(
|
23 |
-
document:
|
|
|
|
|
|
|
|
|
24 |
):
|
25 |
from prettytable import PrettyTable
|
26 |
|
@@ -37,27 +144,57 @@ def render_pretty_table(
|
|
37 |
|
38 |
|
39 |
def render_displacy(
|
40 |
-
document:
|
|
|
|
|
|
|
41 |
inject_relations=True,
|
42 |
colors_hover=None,
|
43 |
entity_options={},
|
44 |
**render_kwargs,
|
45 |
):
|
46 |
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
spacy_doc = {
|
49 |
"text": document.text,
|
50 |
-
|
51 |
-
|
52 |
-
"start": labeled_span.start,
|
53 |
-
"end": labeled_span.end,
|
54 |
-
"label": labeled_span.label,
|
55 |
-
# pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
|
56 |
-
# on hover and to inject the relation data.
|
57 |
-
"params": {"id": labeled_span_to_id(labeled_span)},
|
58 |
-
}
|
59 |
-
for labeled_span in labeled_spans
|
60 |
-
],
|
61 |
"title": None,
|
62 |
}
|
63 |
|
@@ -75,7 +212,7 @@ def render_displacy(
|
|
75 |
)
|
76 |
html = inject_relation_data(
|
77 |
html,
|
78 |
-
|
79 |
binary_relations=binary_relations,
|
80 |
additional_colors=colors_hover,
|
81 |
)
|
@@ -84,7 +221,7 @@ def render_displacy(
|
|
84 |
|
85 |
def inject_relation_data(
|
86 |
html: str,
|
87 |
-
|
88 |
binary_relations: List[BinaryRelation],
|
89 |
additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
|
90 |
) -> str:
|
@@ -99,7 +236,7 @@ def inject_relation_data(
|
|
99 |
entity2heads[relation.tail].append((relation.head, relation.label))
|
100 |
entity2tails[relation.head].append((relation.tail, relation.label))
|
101 |
|
102 |
-
ann_id2annotation = {labeled_span_to_id(entity): entity for entity in
|
103 |
# Add unique IDs to each entity
|
104 |
entities = soup.find_all(class_="entity")
|
105 |
for entity in entities:
|
@@ -110,12 +247,21 @@ def inject_relation_data(
|
|
110 |
entity[f"data-color-{key}"] = (
|
111 |
json.dumps(color) if isinstance(color, dict) else color
|
112 |
)
|
113 |
-
entity_annotation = ann_id2annotation[entity["id"]]
|
|
|
114 |
# sanity check.
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
# Just check the start, because the text has the label attached to the end
|
117 |
if not entity.text.startswith(annotation_text_without_newline):
|
118 |
logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
|
|
119 |
entity["data-label"] = entity_annotation.label
|
120 |
entity["data-relation-tails"] = json.dumps(
|
121 |
[
|
|
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
+
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
|
8 |
+
from pytorch_ie.documents import (
|
9 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
10 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
11 |
+
)
|
12 |
from rendering_utils_displacy import EntityRenderer
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
16 |
# adjusted from rendering_utils_displacy.TPL_ENT
|
17 |
TPL_ENT_WITH_ID = """
|
18 |
+
<mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
|
19 |
{text}
|
20 |
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
|
21 |
</mark>
|
22 |
"""
|
23 |
|
24 |
+
HIGHLIGHT_SPANS_JS = """
|
25 |
+
() => {
|
26 |
+
function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
|
27 |
+
var color = entity.getAttribute('data-color-' + colorAttributeKey);
|
28 |
+
// if color is a json string, parse it and use the value at colorDictKey
|
29 |
+
try {
|
30 |
+
const colors = JSON.parse(color);
|
31 |
+
color = colors[colorDictKey];
|
32 |
+
} catch (e) {}
|
33 |
+
if (color) {
|
34 |
+
entity.style.backgroundColor = color;
|
35 |
+
entity.style.color = '#000';
|
36 |
+
}
|
37 |
+
}
|
38 |
+
|
39 |
+
function highlightRelationArguments(entityId) {
|
40 |
+
const entities = document.querySelectorAll('.entity');
|
41 |
+
// reset all entities
|
42 |
+
entities.forEach(entity => {
|
43 |
+
const color = entity.getAttribute('data-color-original');
|
44 |
+
entity.style.backgroundColor = color;
|
45 |
+
entity.style.color = '';
|
46 |
+
});
|
47 |
+
|
48 |
+
if (entityId !== null) {
|
49 |
+
var visitedEntities = new Set();
|
50 |
+
// highlight selected entity
|
51 |
+
// get all elements with attribute data-entity-id==entityId
|
52 |
+
const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
|
53 |
+
selectedEntityParts.forEach(selectedEntityPart => {
|
54 |
+
const label = selectedEntityPart.getAttribute('data-label');
|
55 |
+
maybeSetColor(selectedEntityPart, 'selected', label);
|
56 |
+
visitedEntities.add(selectedEntityPart);
|
57 |
+
}); // <-- Corrected closing parenthesis here
|
58 |
+
// if there is at least one part, get the first one and ...
|
59 |
+
if (selectedEntityParts.length > 0) {
|
60 |
+
const selectedEntity = selectedEntityParts[0];
|
61 |
+
|
62 |
+
// ... highlight tails and ...
|
63 |
+
const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
|
64 |
+
relationTailsAndLabels.forEach(relationTail => {
|
65 |
+
const tailEntityId = relationTail['entity-id'];
|
66 |
+
const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
|
67 |
+
tailEntityParts.forEach(tailEntity => {
|
68 |
+
const label = relationTail['label'];
|
69 |
+
maybeSetColor(tailEntity, 'tail', label);
|
70 |
+
visitedEntities.add(tailEntity);
|
71 |
+
}); // <-- Corrected closing parenthesis here
|
72 |
+
}); // <-- Corrected closing parenthesis here
|
73 |
+
// .. highlight heads
|
74 |
+
const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
|
75 |
+
relationHeadsAndLabels.forEach(relationHead => {
|
76 |
+
const headEntityId = relationHead['entity-id'];
|
77 |
+
const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
|
78 |
+
headEntityParts.forEach(headEntity => {
|
79 |
+
const label = relationHead['label'];
|
80 |
+
maybeSetColor(headEntity, 'head', label);
|
81 |
+
visitedEntities.add(headEntity);
|
82 |
+
}); // <-- Corrected closing parenthesis here
|
83 |
+
}); // <-- Corrected closing parenthesis here
|
84 |
+
}
|
85 |
+
|
86 |
+
// highlight other entities
|
87 |
+
entities.forEach(entity => {
|
88 |
+
if (!visitedEntities.has(entity)) {
|
89 |
+
const label = entity.getAttribute('data-label');
|
90 |
+
maybeSetColor(entity, 'other', label);
|
91 |
+
}
|
92 |
+
});
|
93 |
+
}
|
94 |
+
}
|
95 |
+
function setReferenceAduId(entityId) {
|
96 |
+
// get the textarea element that holds the reference adu id
|
97 |
+
let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
|
98 |
+
// set the value of the input field
|
99 |
+
referenceAduIdDiv.value = entityId;
|
100 |
+
// trigger an input event to update the state
|
101 |
+
var event = new Event('input');
|
102 |
+
referenceAduIdDiv.dispatchEvent(event);
|
103 |
+
}
|
104 |
+
|
105 |
+
const entities = document.querySelectorAll('.entity');
|
106 |
+
entities.forEach(entity => {
|
107 |
+
const alreadyHasListener = entity.getAttribute('data-has-listener');
|
108 |
+
if (alreadyHasListener) {
|
109 |
+
return;
|
110 |
+
}
|
111 |
+
entity.addEventListener('mouseover', () => {
|
112 |
+
const entityId = entity.getAttribute('data-entity-id');
|
113 |
+
highlightRelationArguments(entityId);
|
114 |
+
setReferenceAduId(entityId);
|
115 |
+
});
|
116 |
+
entity.addEventListener('mouseout', () => {
|
117 |
+
highlightRelationArguments(null);
|
118 |
+
});
|
119 |
+
entity.setAttribute('data-has-listener', 'true');
|
120 |
+
});
|
121 |
+
}
|
122 |
+
"""
|
123 |
+
|
124 |
|
125 |
def render_pretty_table(
|
126 |
+
document: Union[
|
127 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
128 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
129 |
+
],
|
130 |
+
**render_kwargs,
|
131 |
):
|
132 |
from prettytable import PrettyTable
|
133 |
|
|
|
144 |
|
145 |
|
146 |
def render_displacy(
|
147 |
+
document: Union[
|
148 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
149 |
+
TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
|
150 |
+
],
|
151 |
inject_relations=True,
|
152 |
colors_hover=None,
|
153 |
entity_options={},
|
154 |
**render_kwargs,
|
155 |
):
|
156 |
|
157 |
+
if isinstance(document, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions):
|
158 |
+
span_layer = document.labeled_spans
|
159 |
+
elif isinstance(
|
160 |
+
document, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
|
161 |
+
):
|
162 |
+
span_layer = document.labeled_multi_spans
|
163 |
+
else:
|
164 |
+
raise ValueError(f"Unsupported document type: {type(document)}")
|
165 |
+
|
166 |
+
span_annotations = list(span_layer) + list(span_layer.predictions)
|
167 |
+
ents = []
|
168 |
+
for labeled_span in span_annotations:
|
169 |
+
entity_id = labeled_span_to_id(labeled_span)
|
170 |
+
# pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
|
171 |
+
# on hover and to inject the relation data.
|
172 |
+
if isinstance(labeled_span, LabeledSpan):
|
173 |
+
ents.append(
|
174 |
+
{
|
175 |
+
"start": labeled_span.start,
|
176 |
+
"end": labeled_span.end,
|
177 |
+
"label": labeled_span.label,
|
178 |
+
"params": {"entity_id": entity_id, "slice_idx": 0},
|
179 |
+
}
|
180 |
+
)
|
181 |
+
elif isinstance(labeled_span, LabeledMultiSpan):
|
182 |
+
for i, (start, end) in enumerate(labeled_span.slices):
|
183 |
+
ents.append(
|
184 |
+
{
|
185 |
+
"start": start,
|
186 |
+
"end": end,
|
187 |
+
"label": labeled_span.label,
|
188 |
+
"params": {"entity_id": entity_id, "slice_idx": i},
|
189 |
+
}
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
|
193 |
+
|
194 |
spacy_doc = {
|
195 |
"text": document.text,
|
196 |
+
# the ents MUST be sorted by start and end
|
197 |
+
"ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
"title": None,
|
199 |
}
|
200 |
|
|
|
212 |
)
|
213 |
html = inject_relation_data(
|
214 |
html,
|
215 |
+
span_annotations=span_annotations,
|
216 |
binary_relations=binary_relations,
|
217 |
additional_colors=colors_hover,
|
218 |
)
|
|
|
221 |
|
222 |
def inject_relation_data(
|
223 |
html: str,
|
224 |
+
span_annotations: Union[List[LabeledSpan], List[LabeledMultiSpan]],
|
225 |
binary_relations: List[BinaryRelation],
|
226 |
additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
|
227 |
) -> str:
|
|
|
236 |
entity2heads[relation.tail].append((relation.head, relation.label))
|
237 |
entity2tails[relation.head].append((relation.tail, relation.label))
|
238 |
|
239 |
+
ann_id2annotation = {labeled_span_to_id(entity): entity for entity in span_annotations}
|
240 |
# Add unique IDs to each entity
|
241 |
entities = soup.find_all(class_="entity")
|
242 |
for entity in entities:
|
|
|
247 |
entity[f"data-color-{key}"] = (
|
248 |
json.dumps(color) if isinstance(color, dict) else color
|
249 |
)
|
250 |
+
entity_annotation = ann_id2annotation[entity["data-entity-id"]]
|
251 |
+
|
252 |
# sanity check.
|
253 |
+
if isinstance(entity_annotation, LabeledSpan):
|
254 |
+
annotation_text = entity_annotation.resolve()[1]
|
255 |
+
elif isinstance(entity_annotation, LabeledMultiSpan):
|
256 |
+
slice_idx = int(entity["data-slice-idx"])
|
257 |
+
annotation_text = entity_annotation.resolve()[1][slice_idx]
|
258 |
+
else:
|
259 |
+
raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
|
260 |
+
annotation_text_without_newline = annotation_text.replace("\n", "")
|
261 |
# Just check the start, because the text has the label attached to the end
|
262 |
if not entity.text.startswith(annotation_text_without_newline):
|
263 |
logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
264 |
+
|
265 |
entity["data-label"] = entity_annotation.label
|
266 |
entity["data-relation-tails"] = json.dumps(
|
267 |
[
|
requirements.txt
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
gradio==4.36.0
|
2 |
prettytable==3.10.0
|
3 |
pie-modules==0.12.0
|
|
|
1 |
+
pytorch-ie==0.31.1
|
2 |
gradio==4.36.0
|
3 |
prettytable==3.10.0
|
4 |
pie-modules==0.12.0
|