ArneBinder commited on
Commit
efae5be
1 Parent(s): 47cc11e

from https://github.com/ArneBinder/pie-document-level/pull/243

Browse files
Files changed (7) hide show
  1. annotation_utils.py +26 -5
  2. app.py +63 -107
  3. document_store.py +35 -25
  4. embedding.py +46 -13
  5. model_utils.py +29 -5
  6. rendering_utils.py +168 -22
  7. requirements.txt +1 -0
annotation_utils.py CHANGED
@@ -1,10 +1,31 @@
1
- from pytorch_ie.annotations import LabeledSpan
2
 
 
3
 
4
- def labeled_span_to_id(span: LabeledSpan) -> str:
5
- return f"span-{span.start}-{span.end}-{span.label}"
6
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- def labeled_span_from_id(span_id: str) -> LabeledSpan:
 
9
  parts = span_id.split("-")
10
- return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
 
3
+ from pytorch_ie.annotations import LabeledMultiSpan, LabeledSpan
4
 
 
 
5
 
6
+ def labeled_span_to_id(span: Union[LabeledSpan, LabeledMultiSpan]) -> str:
7
+ if isinstance(span, LabeledSpan):
8
+ # {type indicator}-{start}-{end}-{label}
9
+ return f"span-{span.start}-{span.end}-{span.label}"
10
+ elif isinstance(span, LabeledMultiSpan):
11
+ # {type indicator}-({start}-{end})*-{label
12
+ starts_ends = "-".join(f"{start}-{end}" for start, end in span.slices)
13
+ return f"multispan-{starts_ends}-{span.label}"
14
+ else:
15
+ raise ValueError(f"Unsupported span type: {type(span)}")
16
 
17
+
18
+ def labeled_span_from_id(span_id: str) -> Union[LabeledSpan, LabeledMultiSpan]:
19
  parts = span_id.split("-")
20
+ if parts[0] == "span":
21
+ return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
22
+ elif parts[0] == "multispan":
23
+ label = parts[-1]
24
+ # this contains: start1, end1, start2, end2, ...
25
+ starts_ends = parts[1:-1]
26
+ slices = tuple(
27
+ (int(start), int(end)) for start, end in zip(starts_ends[::2], starts_ends[1::2])
28
+ )
29
+ return LabeledMultiSpan(slices, label)
30
+ else:
31
+ raise ValueError(f"Unsupported span id: {span_id}")
app.py CHANGED
@@ -4,7 +4,7 @@ import os.path
4
  import re
5
  import tempfile
6
  from functools import partial
7
- from typing import List, Optional, Tuple
8
 
9
  import gradio as gr
10
  import pandas as pd
@@ -14,8 +14,11 @@ from embedding import EmbeddingModel
14
  from model_utils import annotate_document, create_document, load_models
15
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
16
  from pytorch_ie import Pipeline
17
- from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
18
- from rendering_utils import render_displacy, render_pretty_table
 
 
 
19
  from transformers import PreTrainedModel, PreTrainedTokenizer
20
  from vector_store import QdrantVectorStore, SimpleVectorStore
21
 
@@ -35,6 +38,10 @@ DEFAULT_EMBEDDING_MAX_LENGTH = 512
35
  DEFAULT_EMBEDDING_BATCH_SIZE = 32
36
  DEFAULT_SPLIT_REGEX = "\n\n\n+"
37
 
 
 
 
 
38
 
39
  def escape_regex(regex: str) -> str:
40
  # "double escape" the backslashes
@@ -49,7 +56,10 @@ def unescape_regex(regex: str) -> str:
49
 
50
 
51
  def render_annotated_document(
52
- document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 
 
 
53
  render_with: str,
54
  render_kwargs_json: str,
55
  ) -> str:
@@ -70,7 +80,14 @@ def wrapped_process_text(
70
  models: Tuple[Pipeline, Optional[EmbeddingModel]],
71
  document_store: DocumentStore,
72
  split_regex_escaped: str,
73
- ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
 
 
 
 
 
 
 
74
  try:
75
  document = create_document(
76
  text=text,
@@ -79,10 +96,11 @@ def wrapped_process_text(
79
  if len(split_regex_escaped) > 0
80
  else None,
81
  )
82
- annotate_document(
83
  document=document,
84
  annotation_pipeline=models[0],
85
  embedding_model=models[1],
 
86
  )
87
  document_store.add_document(document)
88
  except Exception as e:
@@ -100,6 +118,8 @@ def process_uploaded_files(
100
  document_store: DocumentStore,
101
  split_regex_escaped: str,
102
  show_max_cross_doc_sims: bool = False,
 
 
103
  ) -> pd.DataFrame:
104
  try:
105
  new_documents = []
@@ -117,10 +137,11 @@ def process_uploaded_files(
117
  if len(split_regex_escaped) > 0
118
  else None,
119
  )
120
- annotate_document(
121
  document=new_document,
122
  annotation_pipeline=models[0],
123
  embedding_model=models[1],
 
124
  )
125
  new_documents.append(new_document)
126
  else:
@@ -129,7 +150,9 @@ def process_uploaded_files(
129
  except Exception as e:
130
  raise gr.Error(f"Failed to process uploaded files: {e}")
131
 
132
- return document_store.overview(with_max_cross_doc_sims=show_max_cross_doc_sims)
 
 
133
 
134
 
135
  def open_accordion():
@@ -144,9 +167,15 @@ def select_processed_document(
144
  evt: gr.SelectData,
145
  processed_documents_df: pd.DataFrame,
146
  document_store: DocumentStore,
147
- ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
 
 
 
148
  row_idx, col_idx = evt.index
149
- doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
 
 
 
150
  doc = document_store.get_document(doc_id, with_embeddings=False)
151
  return doc
152
 
@@ -231,6 +260,12 @@ def main():
231
  span_annotation_caption="adu",
232
  relation_annotation_caption="relation",
233
  vector_store=QdrantVectorStore(),
 
 
 
 
 
 
234
  )
235
  )
236
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
@@ -399,14 +434,20 @@ def main():
399
 
400
  show_overview_kwargs = dict(
401
  fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
402
- with_max_cross_doc_sims=show_max_sims
403
  ),
404
  inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
405
  outputs=[processed_documents_df],
406
  )
407
  predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
408
- fn=wrapped_process_text,
409
- inputs=[doc_text, doc_id, models_state, document_store_state, split_regex_escaped],
 
 
 
 
 
 
410
  outputs=[document_json, document_state],
411
  api_name="predict",
412
  ).success(**show_overview_kwargs)
@@ -423,13 +464,14 @@ def main():
423
  upload_btn.upload(
424
  fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
425
  ).then(
426
- fn=process_uploaded_files,
427
  inputs=[
428
  upload_btn,
429
  models_state,
430
  document_store_state,
431
  split_regex_escaped,
432
  show_max_cross_docu_sims,
 
433
  ],
434
  outputs=[processed_documents_df],
435
  )
@@ -470,7 +512,9 @@ def main():
470
  selected_adu_id.change(
471
  fn=partial(
472
  get_annotation_from_document,
473
- annotation_layer="labeled_spans",
 
 
474
  use_predictions=True,
475
  ),
476
  inputs=[document_state, selected_adu_id],
@@ -483,7 +527,9 @@ def main():
483
  ref_document=document,
484
  min_similarity=min_sim,
485
  top_k=k,
486
- annotation_layer="labeled_spans",
 
 
487
  ),
488
  inputs=[
489
  document_store_state,
@@ -513,97 +559,7 @@ def main():
513
  # **retrieve_relevant_adus_event_kwargs
514
  # )
515
 
516
- js = """
517
- () => {
518
- function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
519
- var color = entity.getAttribute('data-color-' + colorAttributeKey);
520
- // if color is a json string, parse it and use the value at colorDictKey
521
- try {
522
- const colors = JSON.parse(color);
523
- color = colors[colorDictKey];
524
- } catch (e) {}
525
- if (color) {
526
- entity.style.backgroundColor = color;
527
- entity.style.color = '#000';
528
- }
529
- }
530
-
531
- function highlightRelationArguments(entityId) {
532
- const entities = document.querySelectorAll('.entity');
533
- // reset all entities
534
- entities.forEach(entity => {
535
- const color = entity.getAttribute('data-color-original');
536
- entity.style.backgroundColor = color;
537
- entity.style.color = '';
538
- });
539
-
540
- if (entityId !== null) {
541
- var visitedEntities = new Set();
542
- // highlight selected entity
543
- const selectedEntity = document.getElementById(entityId);
544
- if (selectedEntity) {
545
- const label = selectedEntity.getAttribute('data-label');
546
- maybeSetColor(selectedEntity, 'selected', label);
547
- visitedEntities.add(selectedEntity);
548
- }
549
- // highlight tails
550
- const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
551
- relationTailsAndLabels.forEach(relationTail => {
552
- const tailEntity = document.getElementById(relationTail['entity-id']);
553
- if (tailEntity) {
554
- const label = relationTail['label'];
555
- maybeSetColor(tailEntity, 'tail', label);
556
- visitedEntities.add(tailEntity);
557
- }
558
- });
559
- // highlight heads
560
- const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
561
- relationHeadsAndLabels.forEach(relationHead => {
562
- const headEntity = document.getElementById(relationHead['entity-id']);
563
- if (headEntity) {
564
- const label = relationHead['label'];
565
- maybeSetColor(headEntity, 'head', label);
566
- visitedEntities.add(headEntity);
567
- }
568
- });
569
- // highlight other entities
570
- entities.forEach(entity => {
571
- if (!visitedEntities.has(entity)) {
572
- const label = entity.getAttribute('data-label');
573
- maybeSetColor(entity, 'other', label);
574
- }
575
- });
576
- }
577
- }
578
- function setReferenceAduId(entityId) {
579
- // get the textarea element that holds the reference adu id
580
- let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
581
- // set the value of the input field
582
- referenceAduIdDiv.value = entityId;
583
- // trigger an input event to update the state
584
- var event = new Event('input');
585
- referenceAduIdDiv.dispatchEvent(event);
586
- }
587
-
588
- const entities = document.querySelectorAll('.entity');
589
- entities.forEach(entity => {
590
- const alreadyHasListener = entity.getAttribute('data-has-listener');
591
- if (alreadyHasListener) {
592
- return;
593
- }
594
- entity.addEventListener('mouseover', () => {
595
- highlightRelationArguments(entity.id);
596
- setReferenceAduId(entity.id);
597
- });
598
- entity.addEventListener('mouseout', () => {
599
- highlightRelationArguments(null);
600
- });
601
- entity.setAttribute('data-has-listener', 'true');
602
- });
603
- }
604
- """
605
-
606
- rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
607
 
608
  demo.launch()
609
 
 
4
  import re
5
  import tempfile
6
  from functools import partial
7
+ from typing import List, Optional, Tuple, Union
8
 
9
  import gradio as gr
10
  import pandas as pd
 
14
  from model_utils import annotate_document, create_document, load_models
15
  from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
16
  from pytorch_ie import Pipeline
17
+ from pytorch_ie.documents import (
18
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
19
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
20
+ )
21
+ from rendering_utils import HIGHLIGHT_SPANS_JS, render_displacy, render_pretty_table
22
  from transformers import PreTrainedModel, PreTrainedTokenizer
23
  from vector_store import QdrantVectorStore, SimpleVectorStore
24
 
 
38
  DEFAULT_EMBEDDING_BATCH_SIZE = 32
39
  DEFAULT_SPLIT_REGEX = "\n\n\n+"
40
 
41
+ # Whether to handle segmented entities in the document. If True, labeled_spans are converted
42
+ # to labeled_multi_spans and binary_relations with label "parts_of_same" are used to merge them.
43
+ HANDLE_PARTS_OF_SAME = True
44
+
45
 
46
  def escape_regex(regex: str) -> str:
47
  # "double escape" the backslashes
 
56
 
57
 
58
  def render_annotated_document(
59
+ document: Union[
60
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
61
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
62
+ ],
63
  render_with: str,
64
  render_kwargs_json: str,
65
  ) -> str:
 
80
  models: Tuple[Pipeline, Optional[EmbeddingModel]],
81
  document_store: DocumentStore,
82
  split_regex_escaped: str,
83
+ handle_parts_of_same: bool = False,
84
+ ) -> Tuple[
85
+ dict,
86
+ Union[
87
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
88
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
89
+ ],
90
+ ]:
91
  try:
92
  document = create_document(
93
  text=text,
 
96
  if len(split_regex_escaped) > 0
97
  else None,
98
  )
99
+ document = annotate_document(
100
  document=document,
101
  annotation_pipeline=models[0],
102
  embedding_model=models[1],
103
+ handle_parts_of_same=handle_parts_of_same,
104
  )
105
  document_store.add_document(document)
106
  except Exception as e:
 
118
  document_store: DocumentStore,
119
  split_regex_escaped: str,
120
  show_max_cross_doc_sims: bool = False,
121
+ min_similarity: float = 0.95,
122
+ handle_parts_of_same: bool = False,
123
  ) -> pd.DataFrame:
124
  try:
125
  new_documents = []
 
137
  if len(split_regex_escaped) > 0
138
  else None,
139
  )
140
+ new_document = annotate_document(
141
  document=new_document,
142
  annotation_pipeline=models[0],
143
  embedding_model=models[1],
144
+ handle_parts_of_same=handle_parts_of_same,
145
  )
146
  new_documents.append(new_document)
147
  else:
 
150
  except Exception as e:
151
  raise gr.Error(f"Failed to process uploaded files: {e}")
152
 
153
+ return document_store.overview(
154
+ with_max_cross_doc_sims=show_max_cross_doc_sims, min_similarity=min_similarity
155
+ )
156
 
157
 
158
  def open_accordion():
 
167
  evt: gr.SelectData,
168
  processed_documents_df: pd.DataFrame,
169
  document_store: DocumentStore,
170
+ ) -> Union[
171
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
172
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
173
+ ]:
174
  row_idx, col_idx = evt.index
175
+ col_name = processed_documents_df.columns[col_idx]
176
+ if not col_name.endswith("doc_id"):
177
+ col_name = "doc_id"
178
+ doc_id = processed_documents_df.iloc[row_idx][col_name]
179
  doc = document_store.get_document(doc_id, with_embeddings=False)
180
  return doc
181
 
 
260
  span_annotation_caption="adu",
261
  relation_annotation_caption="relation",
262
  vector_store=QdrantVectorStore(),
263
+ document_type=TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
264
+ if not HANDLE_PARTS_OF_SAME
265
+ else TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
266
+ span_layer_name="labeled_spans"
267
+ if not HANDLE_PARTS_OF_SAME
268
+ else "labeled_multi_spans",
269
  )
270
  )
271
  # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
 
434
 
435
  show_overview_kwargs = dict(
436
  fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
437
+ with_max_cross_doc_sims=show_max_sims, min_similarity=min_sim
438
  ),
439
  inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
440
  outputs=[processed_documents_df],
441
  )
442
  predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
443
+ fn=partial(wrapped_process_text, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
444
+ inputs=[
445
+ doc_text,
446
+ doc_id,
447
+ models_state,
448
+ document_store_state,
449
+ split_regex_escaped,
450
+ ],
451
  outputs=[document_json, document_state],
452
  api_name="predict",
453
  ).success(**show_overview_kwargs)
 
464
  upload_btn.upload(
465
  fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
466
  ).then(
467
+ fn=partial(process_uploaded_files, handle_parts_of_same=HANDLE_PARTS_OF_SAME),
468
  inputs=[
469
  upload_btn,
470
  models_state,
471
  document_store_state,
472
  split_regex_escaped,
473
  show_max_cross_docu_sims,
474
+ min_similarity,
475
  ],
476
  outputs=[processed_documents_df],
477
  )
 
512
  selected_adu_id.change(
513
  fn=partial(
514
  get_annotation_from_document,
515
+ annotation_layer="labeled_spans"
516
+ if not HANDLE_PARTS_OF_SAME
517
+ else "labeled_multi_spans",
518
  use_predictions=True,
519
  ),
520
  inputs=[document_state, selected_adu_id],
 
527
  ref_document=document,
528
  min_similarity=min_sim,
529
  top_k=k,
530
+ annotation_layer="labeled_spans"
531
+ if not HANDLE_PARTS_OF_SAME
532
+ else "labeled_multi_spans",
533
  ),
534
  inputs=[
535
  document_store_state,
 
559
  # **retrieve_relevant_adus_event_kwargs
560
  # )
561
 
562
+ rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
 
564
  demo.launch()
565
 
document_store.py CHANGED
@@ -14,6 +14,7 @@ from annotation_utils import labeled_span_to_id
14
  from pytorch_ie import Annotation
15
  from pytorch_ie.documents import (
16
  TextBasedDocument,
 
17
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
18
  )
19
  from scipy.sparse import csr_matrix
@@ -45,7 +46,7 @@ def get_annotation_from_document(
45
  if use_predictions:
46
  annotations = annotations.predictions
47
 
48
- if annotation_layer == "labeled_spans":
49
  annotation_to_id_func = labeled_span_to_id
50
  else:
51
  raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
@@ -301,6 +302,12 @@ class DocumentStore:
301
 
302
  def add_document(self, document: TextBasedDocument) -> None:
303
  try:
 
 
 
 
 
 
304
  if document.id in self.documents:
305
  gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
306
 
@@ -485,6 +492,11 @@ class DocumentStore:
485
 
486
  max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
487
  max_similarities = max_doc2doc_similarities.max(axis="columns")
 
 
 
 
 
488
 
489
  # set the index to the doc_id to correctly join the series
490
  df.set_index("doc_id", inplace=True)
@@ -551,7 +563,8 @@ class DocumentStore:
551
  # set similarities below min_similarity to 0
552
  similarities[similarities < min_similarity] = 0.0
553
 
554
- # set triangular part to 0
 
555
  similarities = np.triu(similarities, k=1)
556
  # create a sparse matrix
557
  sparse_matrix = csr_matrix(similarities)
@@ -564,29 +577,26 @@ class DocumentStore:
564
 
565
  # construct the DataFrame
566
  records = []
567
- for idx1, idx2 in zip(non_zero_idx[0], non_zero_idx[1]):
568
- if idx1 < idx2:
569
- doc_id1 = all_payloads[idx1]["doc_id"]
570
- doc_id2 = all_payloads[idx2]["doc_id"]
571
- annotation_id1 = all_payloads[idx1]["annotation_id"]
572
- annotation_id2 = all_payloads[idx2]["annotation_id"]
573
- annotation_text1 = doc_id_and_annotation_id2annotation_text[
574
- (doc_id1, annotation_id1)
575
- ]
576
- annotation_text2 = doc_id_and_annotation_id2annotation_text[
577
- (doc_id2, annotation_id2)
578
- ]
579
- records.append(
580
- {
581
- "sim_score": scores[idx1],
582
- "doc_id": doc_id1,
583
- "other_doc_id": doc_id2,
584
- "adu_id": annotation_id1,
585
- "other_adu_id": annotation_id2,
586
- "text": annotation_text1,
587
- "other_text": annotation_text2,
588
- }
589
- )
590
  result_df = pd.DataFrame(records)
591
  gr.Info(f"DataFrame shape: {result_df.shape}")
592
 
 
14
  from pytorch_ie import Annotation
15
  from pytorch_ie.documents import (
16
  TextBasedDocument,
17
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
18
  TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
19
  )
20
  from scipy.sparse import csr_matrix
 
46
  if use_predictions:
47
  annotations = annotations.predictions
48
 
49
+ if annotation_layer in ["labeled_spans", "labeled_multi_spans"]:
50
  annotation_to_id_func = labeled_span_to_id
51
  else:
52
  raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
 
302
 
303
  def add_document(self, document: TextBasedDocument) -> None:
304
  try:
305
+ if not isinstance(document, self.document_type):
306
+ raise gr.Error(
307
+ f"The document to add must be of type {self.document_type}, but is of type "
308
+ f"{type(document)}."
309
+ )
310
+
311
  if document.id in self.documents:
312
  gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
313
 
 
492
 
493
  max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
494
  max_similarities = max_doc2doc_similarities.max(axis="columns")
495
+ # entries where max_similarities is -inf are documents with no entries in other documents
496
+ # with similarity > min_similarity
497
+ mask = max_similarities == -np.inf
498
+ max_doc_ids[mask] = np.nan
499
+ max_similarities[mask] = np.nan
500
 
501
  # set the index to the doc_id to correctly join the series
502
  df.set_index("doc_id", inplace=True)
 
563
  # set similarities below min_similarity to 0
564
  similarities[similarities < min_similarity] = 0.0
565
 
566
+ # set triangular part to 0 because we only want the upper triangular part which
567
+ # contains entries with idx1 < idx2
568
  similarities = np.triu(similarities, k=1)
569
  # create a sparse matrix
570
  sparse_matrix = csr_matrix(similarities)
 
577
 
578
  # construct the DataFrame
579
  records = []
580
+ for sparse_idx, (idx1, idx2) in enumerate(zip(non_zero_idx[0], non_zero_idx[1])):
581
+ payload1 = all_payloads[idx1]
582
+ payload2 = all_payloads[idx2]
583
+ doc_id1 = payload1["doc_id"]
584
+ doc_id2 = payload2["doc_id"]
585
+ annotation_id1 = payload1["annotation_id"]
586
+ annotation_id2 = payload2["annotation_id"]
587
+ annotation_text1 = doc_id_and_annotation_id2annotation_text[(doc_id1, annotation_id1)]
588
+ annotation_text2 = doc_id_and_annotation_id2annotation_text[(doc_id2, annotation_id2)]
589
+ records.append(
590
+ {
591
+ "sim_score": scores[sparse_idx],
592
+ "doc_id": doc_id1,
593
+ "other_doc_id": doc_id2,
594
+ "adu_id": annotation_id1,
595
+ "other_adu_id": annotation_id2,
596
+ "text": annotation_text1,
597
+ "other_text": annotation_text2,
598
+ }
599
+ )
 
 
 
600
  result_df = pd.DataFrame(records)
601
  gr.Info(f"DataFrame shape: {result_df.shape}")
602
 
embedding.py CHANGED
@@ -1,12 +1,15 @@
1
  import abc
2
  import logging
3
- from typing import Dict
4
 
5
  import torch
6
  from datasets import Dataset
7
  from pie_modules.document.processing import tokenize_document
8
- from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
9
- from pytorch_ie.annotations import Span
 
 
 
10
  from pytorch_ie.documents import TextBasedDocument
11
  from torch import FloatTensor, Tensor
12
  from torch.utils.data import DataLoader
@@ -18,7 +21,7 @@ logger = logging.getLogger(__name__)
18
  class EmbeddingModel(abc.ABC):
19
  def __call__(
20
  self, document: TextBasedDocument, span_layer_name: str
21
- ) -> Dict[Span, FloatTensor]:
22
  """Embed text annotations from a document.
23
 
24
  Args:
@@ -51,7 +54,7 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
51
 
52
  def __call__(
53
  self, document: TextBasedDocument, span_layer_name: str
54
- ) -> Dict[Span, FloatTensor]:
55
  # to not modify the original document
56
  document = document.copy()
57
  # tokenize_document does not yet consider predictions, so we need to add them manually
@@ -65,10 +68,21 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
65
  "return_overflowing_tokens": True,
66
  }
67
  # tokenize once to get the tokenized documents with mapped annotations
 
 
 
 
 
 
 
 
 
 
 
68
  tokenized_documents = tokenize_document(
69
  document,
70
  tokenizer=self._tokenizer,
71
- result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
72
  partition_layer="labeled_partitions",
73
  added_annotations=added_annotations,
74
  strict_span_conversion=False,
@@ -104,14 +118,33 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
104
  for last_hidden_state in model_output.last_hidden_state:
105
  text2tok_ann = added_annotations[example_idx][span_layer_name]
106
  tok2text_ann = {v: k for k, v in text2tok_ann.items()}
107
- for tok_ann in tokenized_documents[example_idx].labeled_spans:
108
- # skip "empty" annotations
109
- if tok_ann.start == tok_ann.end:
110
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  # use the max pooling strategy to get a single embedding for the annotation text
112
- embedding = (
113
- last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
114
- )
115
  text_ann = tok2text_ann[tok_ann]
116
 
117
  # if text_ann in embeddings:
 
1
  import abc
2
  import logging
3
+ from typing import Dict, Union
4
 
5
  import torch
6
  from datasets import Dataset
7
  from pie_modules.document.processing import tokenize_document
8
+ from pie_modules.documents import (
9
+ TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
10
+ TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
11
+ )
12
+ from pytorch_ie.annotations import LabeledSpan, MultiSpan, Span
13
  from pytorch_ie.documents import TextBasedDocument
14
  from torch import FloatTensor, Tensor
15
  from torch.utils.data import DataLoader
 
21
  class EmbeddingModel(abc.ABC):
22
  def __call__(
23
  self, document: TextBasedDocument, span_layer_name: str
24
+ ) -> Dict[Union[Span, MultiSpan], FloatTensor]:
25
  """Embed text annotations from a document.
26
 
27
  Args:
 
54
 
55
  def __call__(
56
  self, document: TextBasedDocument, span_layer_name: str
57
+ ) -> Dict[Union[Span, MultiSpan], FloatTensor]:
58
  # to not modify the original document
59
  document = document.copy()
60
  # tokenize_document does not yet consider predictions, so we need to add them manually
 
68
  "return_overflowing_tokens": True,
69
  }
70
  # tokenize once to get the tokenized documents with mapped annotations
71
+ span_annotation_type = document.annotation_types()[span_layer_name]
72
+ if issubclass(span_annotation_type, Span):
73
+ result_document_type = TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
74
+ tokenized_span_layer_name = "labeled_spans"
75
+ elif issubclass(span_annotation_type, MultiSpan):
76
+ result_document_type = (
77
+ TokenDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
78
+ )
79
+ tokenized_span_layer_name = "labeled_multi_spans"
80
+ else:
81
+ raise ValueError(f"Unsupported annotation type: {span_annotation_type}")
82
  tokenized_documents = tokenize_document(
83
  document,
84
  tokenizer=self._tokenizer,
85
+ result_document_type=result_document_type,
86
  partition_layer="labeled_partitions",
87
  added_annotations=added_annotations,
88
  strict_span_conversion=False,
 
118
  for last_hidden_state in model_output.last_hidden_state:
119
  text2tok_ann = added_annotations[example_idx][span_layer_name]
120
  tok2text_ann = {v: k for k, v in text2tok_ann.items()}
121
+ for tok_ann in tokenized_documents[example_idx][tokenized_span_layer_name]:
122
+ if isinstance(tok_ann, LabeledSpan):
123
+ # skip "empty" annotations
124
+ if tok_ann.start == tok_ann.end:
125
+ continue
126
+
127
+ embedded_tokens = last_hidden_state[tok_ann.start : tok_ann.end]
128
+
129
+ elif isinstance(tok_ann, MultiSpan):
130
+ # skip "empty" annotations
131
+ if all(start == end for start, end in tok_ann.slices):
132
+ continue
133
+
134
+ # concatenate the embeddings of the tokens that make up the multi-span
135
+ embedded_tokens = torch.concat(
136
+ [
137
+ last_hidden_state[start:end]
138
+ for start, end in tok_ann.slices
139
+ if start != end
140
+ ],
141
+ dim=0,
142
+ )
143
+ else:
144
+ raise ValueError(f"Unsupported annotation type: {type(tok_ann)}")
145
  # use the max pooling strategy to get a single embedding for the annotation text
146
+ embedding = embedded_tokens.max(dim=0)[0].detach().cpu()
147
+
 
148
  text_ann = tok2text_ann[tok_ann]
149
 
150
  # if text_ann in embeddings:
model_utils.py CHANGED
@@ -1,15 +1,18 @@
1
  import logging
2
- from typing import Optional, Tuple
3
 
4
  import gradio as gr
5
  import torch
6
  from annotation_utils import labeled_span_to_id
7
  from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
8
- from pie_modules.document.processing import RegexPartitioner
9
  from pytorch_ie import Pipeline
10
  from pytorch_ie.annotations import LabeledSpan
11
  from pytorch_ie.auto import AutoPipeline
12
- from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 
 
 
13
 
14
  logger = logging.getLogger(__name__)
15
 
@@ -18,7 +21,11 @@ def annotate_document(
18
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
19
  annotation_pipeline: Pipeline,
20
  embedding_model: Optional[EmbeddingModel] = None,
21
- ) -> None:
 
 
 
 
22
  """Annotate a document with the provided pipeline. If an embedding model is provided, also
23
  extract embeddings for the labeled spans.
24
 
@@ -26,15 +33,30 @@ def annotate_document(
26
  document: The document to annotate.
27
  annotation_pipeline: The pipeline to use for annotation.
28
  embedding_model: The embedding model to use for extracting text span embeddings.
 
29
  """
30
 
31
  # execute prediction pipeline
32
  annotation_pipeline(document)
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  if embedding_model is not None:
35
  text_span_embeddings = embedding_model(
36
  document=document,
37
- span_layer_name="labeled_spans",
38
  )
39
  # convert keys to str because JSON keys must be strings
40
  text_span_embeddings_dict = {
@@ -47,6 +69,8 @@ def annotate_document(
47
  "model in the 'Model Configuration' section."
48
  )
49
 
 
 
50
 
51
  def create_document(
52
  text: str, doc_id: str, split_regex: Optional[str] = None
 
1
  import logging
2
+ from typing import Optional, Tuple, Union
3
 
4
  import gradio as gr
5
  import torch
6
  from annotation_utils import labeled_span_to_id
7
  from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
8
+ from pie_modules.document.processing import RegexPartitioner, SpansViaRelationMerger
9
  from pytorch_ie import Pipeline
10
  from pytorch_ie.annotations import LabeledSpan
11
  from pytorch_ie.auto import AutoPipeline
12
+ from pytorch_ie.documents import (
13
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
14
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
15
+ )
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
21
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
22
  annotation_pipeline: Pipeline,
23
  embedding_model: Optional[EmbeddingModel] = None,
24
+ handle_parts_of_same: bool = False,
25
+ ) -> Union[
26
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
27
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
28
+ ]:
29
  """Annotate a document with the provided pipeline. If an embedding model is provided, also
30
  extract embeddings for the labeled spans.
31
 
 
33
  document: The document to annotate.
34
  annotation_pipeline: The pipeline to use for annotation.
35
  embedding_model: The embedding model to use for extracting text span embeddings.
36
+ handle_parts_of_same: Whether to merge spans that are part of the same entity into a single multi span.
37
  """
38
 
39
  # execute prediction pipeline
40
  annotation_pipeline(document)
41
 
42
+ if handle_parts_of_same:
43
+ merger = SpansViaRelationMerger(
44
+ relation_layer="binary_relations",
45
+ link_relation_label="parts_of_same",
46
+ create_multi_spans=True,
47
+ result_document_type=TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
48
+ result_field_mapping={
49
+ "labeled_spans": "labeled_multi_spans",
50
+ "binary_relations": "binary_relations",
51
+ "labeled_partitions": "labeled_partitions",
52
+ },
53
+ )
54
+ document = merger(document)
55
+
56
  if embedding_model is not None:
57
  text_span_embeddings = embedding_model(
58
  document=document,
59
+ span_layer_name="labeled_spans" if not handle_parts_of_same else "labeled_multi_spans",
60
  )
61
  # convert keys to str because JSON keys must be strings
62
  text_span_embeddings_dict = {
 
69
  "model in the 'Model Configuration' section."
70
  )
71
 
72
+ return document
73
+
74
 
75
  def create_document(
76
  text: str, doc_id: str, split_regex: Optional[str] = None
rendering_utils.py CHANGED
@@ -4,23 +4,130 @@ from collections import defaultdict
4
  from typing import Dict, List, Optional, Union
5
 
6
  from annotation_utils import labeled_span_to_id
7
- from pytorch_ie.annotations import BinaryRelation, LabeledSpan
8
- from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 
 
 
9
  from rendering_utils_displacy import EntityRenderer
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
  # adjusted from rendering_utils_displacy.TPL_ENT
14
  TPL_ENT_WITH_ID = """
15
- <mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
16
  {text}
17
  <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
18
  </mark>
19
  """
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def render_pretty_table(
23
- document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
 
 
 
 
24
  ):
25
  from prettytable import PrettyTable
26
 
@@ -37,27 +144,57 @@ def render_pretty_table(
37
 
38
 
39
  def render_displacy(
40
- document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 
 
 
41
  inject_relations=True,
42
  colors_hover=None,
43
  entity_options={},
44
  **render_kwargs,
45
  ):
46
 
47
- labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  spacy_doc = {
49
  "text": document.text,
50
- "ents": [
51
- {
52
- "start": labeled_span.start,
53
- "end": labeled_span.end,
54
- "label": labeled_span.label,
55
- # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
56
- # on hover and to inject the relation data.
57
- "params": {"id": labeled_span_to_id(labeled_span)},
58
- }
59
- for labeled_span in labeled_spans
60
- ],
61
  "title": None,
62
  }
63
 
@@ -75,7 +212,7 @@ def render_displacy(
75
  )
76
  html = inject_relation_data(
77
  html,
78
- labeled_spans=labeled_spans,
79
  binary_relations=binary_relations,
80
  additional_colors=colors_hover,
81
  )
@@ -84,7 +221,7 @@ def render_displacy(
84
 
85
  def inject_relation_data(
86
  html: str,
87
- labeled_spans: List[LabeledSpan],
88
  binary_relations: List[BinaryRelation],
89
  additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
90
  ) -> str:
@@ -99,7 +236,7 @@ def inject_relation_data(
99
  entity2heads[relation.tail].append((relation.head, relation.label))
100
  entity2tails[relation.head].append((relation.tail, relation.label))
101
 
102
- ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}
103
  # Add unique IDs to each entity
104
  entities = soup.find_all(class_="entity")
105
  for entity in entities:
@@ -110,12 +247,21 @@ def inject_relation_data(
110
  entity[f"data-color-{key}"] = (
111
  json.dumps(color) if isinstance(color, dict) else color
112
  )
113
- entity_annotation = ann_id2annotation[entity["id"]]
 
114
  # sanity check.
115
- annotation_text_without_newline = str(entity_annotation).replace("\n", "")
 
 
 
 
 
 
 
116
  # Just check the start, because the text has the label attached to the end
117
  if not entity.text.startswith(annotation_text_without_newline):
118
  logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
 
119
  entity["data-label"] = entity_annotation.label
120
  entity["data-relation-tails"] = json.dumps(
121
  [
 
4
  from typing import Dict, List, Optional, Union
5
 
6
  from annotation_utils import labeled_span_to_id
7
+ from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
8
+ from pytorch_ie.documents import (
9
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
10
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
11
+ )
12
  from rendering_utils_displacy import EntityRenderer
13
 
14
  logger = logging.getLogger(__name__)
15
 
16
  # adjusted from rendering_utils_displacy.TPL_ENT
17
  TPL_ENT_WITH_ID = """
18
+ <mark class="entity" data-entity-id="{entity_id}" data-slice-idx="{slice_idx}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
19
  {text}
20
  <span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
21
  </mark>
22
  """
23
 
24
+ HIGHLIGHT_SPANS_JS = """
25
+ () => {
26
+ function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
27
+ var color = entity.getAttribute('data-color-' + colorAttributeKey);
28
+ // if color is a json string, parse it and use the value at colorDictKey
29
+ try {
30
+ const colors = JSON.parse(color);
31
+ color = colors[colorDictKey];
32
+ } catch (e) {}
33
+ if (color) {
34
+ entity.style.backgroundColor = color;
35
+ entity.style.color = '#000';
36
+ }
37
+ }
38
+
39
+ function highlightRelationArguments(entityId) {
40
+ const entities = document.querySelectorAll('.entity');
41
+ // reset all entities
42
+ entities.forEach(entity => {
43
+ const color = entity.getAttribute('data-color-original');
44
+ entity.style.backgroundColor = color;
45
+ entity.style.color = '';
46
+ });
47
+
48
+ if (entityId !== null) {
49
+ var visitedEntities = new Set();
50
+ // highlight selected entity
51
+ // get all elements with attribute data-entity-id==entityId
52
+ const selectedEntityParts = document.querySelectorAll(`[data-entity-id="${entityId}"]`);
53
+ selectedEntityParts.forEach(selectedEntityPart => {
54
+ const label = selectedEntityPart.getAttribute('data-label');
55
+ maybeSetColor(selectedEntityPart, 'selected', label);
56
+ visitedEntities.add(selectedEntityPart);
57
+ }); // <-- Corrected closing parenthesis here
58
+ // if there is at least one part, get the first one and ...
59
+ if (selectedEntityParts.length > 0) {
60
+ const selectedEntity = selectedEntityParts[0];
61
+
62
+ // ... highlight tails and ...
63
+ const relationTailsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-tails'));
64
+ relationTailsAndLabels.forEach(relationTail => {
65
+ const tailEntityId = relationTail['entity-id'];
66
+ const tailEntityParts = document.querySelectorAll(`[data-entity-id="${tailEntityId}"]`);
67
+ tailEntityParts.forEach(tailEntity => {
68
+ const label = relationTail['label'];
69
+ maybeSetColor(tailEntity, 'tail', label);
70
+ visitedEntities.add(tailEntity);
71
+ }); // <-- Corrected closing parenthesis here
72
+ }); // <-- Corrected closing parenthesis here
73
+ // .. highlight heads
74
+ const relationHeadsAndLabels = JSON.parse(selectedEntity.getAttribute('data-relation-heads'));
75
+ relationHeadsAndLabels.forEach(relationHead => {
76
+ const headEntityId = relationHead['entity-id'];
77
+ const headEntityParts = document.querySelectorAll(`[data-entity-id="${headEntityId}"]`);
78
+ headEntityParts.forEach(headEntity => {
79
+ const label = relationHead['label'];
80
+ maybeSetColor(headEntity, 'head', label);
81
+ visitedEntities.add(headEntity);
82
+ }); // <-- Corrected closing parenthesis here
83
+ }); // <-- Corrected closing parenthesis here
84
+ }
85
+
86
+ // highlight other entities
87
+ entities.forEach(entity => {
88
+ if (!visitedEntities.has(entity)) {
89
+ const label = entity.getAttribute('data-label');
90
+ maybeSetColor(entity, 'other', label);
91
+ }
92
+ });
93
+ }
94
+ }
95
+ function setReferenceAduId(entityId) {
96
+ // get the textarea element that holds the reference adu id
97
+ let referenceAduIdDiv = document.querySelector('#selected_adu_id textarea');
98
+ // set the value of the input field
99
+ referenceAduIdDiv.value = entityId;
100
+ // trigger an input event to update the state
101
+ var event = new Event('input');
102
+ referenceAduIdDiv.dispatchEvent(event);
103
+ }
104
+
105
+ const entities = document.querySelectorAll('.entity');
106
+ entities.forEach(entity => {
107
+ const alreadyHasListener = entity.getAttribute('data-has-listener');
108
+ if (alreadyHasListener) {
109
+ return;
110
+ }
111
+ entity.addEventListener('mouseover', () => {
112
+ const entityId = entity.getAttribute('data-entity-id');
113
+ highlightRelationArguments(entityId);
114
+ setReferenceAduId(entityId);
115
+ });
116
+ entity.addEventListener('mouseout', () => {
117
+ highlightRelationArguments(null);
118
+ });
119
+ entity.setAttribute('data-has-listener', 'true');
120
+ });
121
+ }
122
+ """
123
+
124
 
125
  def render_pretty_table(
126
+ document: Union[
127
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
128
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
129
+ ],
130
+ **render_kwargs,
131
  ):
132
  from prettytable import PrettyTable
133
 
 
144
 
145
 
146
  def render_displacy(
147
+ document: Union[
148
+ TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
149
+ TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions,
150
+ ],
151
  inject_relations=True,
152
  colors_hover=None,
153
  entity_options={},
154
  **render_kwargs,
155
  ):
156
 
157
+ if isinstance(document, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions):
158
+ span_layer = document.labeled_spans
159
+ elif isinstance(
160
+ document, TextDocumentWithLabeledMultiSpansBinaryRelationsAndLabeledPartitions
161
+ ):
162
+ span_layer = document.labeled_multi_spans
163
+ else:
164
+ raise ValueError(f"Unsupported document type: {type(document)}")
165
+
166
+ span_annotations = list(span_layer) + list(span_layer.predictions)
167
+ ents = []
168
+ for labeled_span in span_annotations:
169
+ entity_id = labeled_span_to_id(labeled_span)
170
+ # pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
171
+ # on hover and to inject the relation data.
172
+ if isinstance(labeled_span, LabeledSpan):
173
+ ents.append(
174
+ {
175
+ "start": labeled_span.start,
176
+ "end": labeled_span.end,
177
+ "label": labeled_span.label,
178
+ "params": {"entity_id": entity_id, "slice_idx": 0},
179
+ }
180
+ )
181
+ elif isinstance(labeled_span, LabeledMultiSpan):
182
+ for i, (start, end) in enumerate(labeled_span.slices):
183
+ ents.append(
184
+ {
185
+ "start": start,
186
+ "end": end,
187
+ "label": labeled_span.label,
188
+ "params": {"entity_id": entity_id, "slice_idx": i},
189
+ }
190
+ )
191
+ else:
192
+ raise ValueError(f"Unsupported labeled span type: {type(labeled_span)}")
193
+
194
  spacy_doc = {
195
  "text": document.text,
196
+ # the ents MUST be sorted by start and end
197
+ "ents": sorted(ents, key=lambda x: (x["start"], x["end"])),
 
 
 
 
 
 
 
 
 
198
  "title": None,
199
  }
200
 
 
212
  )
213
  html = inject_relation_data(
214
  html,
215
+ span_annotations=span_annotations,
216
  binary_relations=binary_relations,
217
  additional_colors=colors_hover,
218
  )
 
221
 
222
  def inject_relation_data(
223
  html: str,
224
+ span_annotations: Union[List[LabeledSpan], List[LabeledMultiSpan]],
225
  binary_relations: List[BinaryRelation],
226
  additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
227
  ) -> str:
 
236
  entity2heads[relation.tail].append((relation.head, relation.label))
237
  entity2tails[relation.head].append((relation.tail, relation.label))
238
 
239
+ ann_id2annotation = {labeled_span_to_id(entity): entity for entity in span_annotations}
240
  # Add unique IDs to each entity
241
  entities = soup.find_all(class_="entity")
242
  for entity in entities:
 
247
  entity[f"data-color-{key}"] = (
248
  json.dumps(color) if isinstance(color, dict) else color
249
  )
250
+ entity_annotation = ann_id2annotation[entity["data-entity-id"]]
251
+
252
  # sanity check.
253
+ if isinstance(entity_annotation, LabeledSpan):
254
+ annotation_text = entity_annotation.resolve()[1]
255
+ elif isinstance(entity_annotation, LabeledMultiSpan):
256
+ slice_idx = int(entity["data-slice-idx"])
257
+ annotation_text = entity_annotation.resolve()[1][slice_idx]
258
+ else:
259
+ raise ValueError(f"Unsupported entity type: {type(entity_annotation)}")
260
+ annotation_text_without_newline = annotation_text.replace("\n", "")
261
  # Just check the start, because the text has the label attached to the end
262
  if not entity.text.startswith(annotation_text_without_newline):
263
  logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
264
+
265
  entity["data-label"] = entity_annotation.label
266
  entity["data-relation-tails"] = json.dumps(
267
  [
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  gradio==4.36.0
2
  prettytable==3.10.0
3
  pie-modules==0.12.0
 
1
+ pytorch-ie==0.31.1
2
  gradio==4.36.0
3
  prettytable==3.10.0
4
  pie-modules==0.12.0