ArneBinder committed
Commit 4467900
Parent: c049f99

Upload 7 files

Files changed (5):
  1. __init__.py +6 -0
  2. app.py +47 -310
  3. backend.py +300 -0
  4. rendering_utils.py +13 -4
  5. vector_store.py +20 -4
__init__.py CHANGED
@@ -3,3 +3,9 @@ import sys
 
 # add current folder to the python path
 sys.path.append(os.path.dirname(__file__))
+
+# this is required to dynamically load the PIE models
+from pie_modules.models import *  # noqa: F403
+from pie_modules.taskmodules import *  # noqa: F403
+from pytorch_ie.models import *  # noqa: F403
+from pytorch_ie.taskmodules import *  # noqa: F403
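
The wildcard imports exist because a saved pipeline config references model and taskmodule classes by name, so those classes must already be importable when the pipeline is restored. A minimal sketch of that resolution idea (illustrative only; resolve_class_by_name is a hypothetical helper, not the actual pytorch_ie loading code):

import sys

def resolve_class_by_name(class_name: str):
    # look the name up among everything the wildcard imports pulled into this module
    cls = getattr(sys.modules[__name__], class_name, None)
    if cls is None:
        raise ImportError(f"class '{class_name}' is not loaded; did the wildcard imports run?")
    return cls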
app.py CHANGED
@@ -1,25 +1,19 @@
 import json
 import logging
 import os.path
-from collections import defaultdict
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import gradio as gr
 import pandas as pd
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-from pie_modules.models import *  # noqa: F403
-from pie_modules.taskmodules import *  # noqa: F403
+from backend import get_annotation_from_document, get_relevant_adus, get_similar_adus, process_text
+from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
-from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-from pytorch_ie.models import *  # noqa: F403
-from pytorch_ie.taskmodules import *  # noqa: F403
 from rendering_utils import render_displacy, render_pretty_table
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
-from vector_store import SimpleVectorStore
+from vector_store import SimpleVectorStore, VectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -34,91 +28,6 @@ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
 DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
 
 
-def embed_text_annotations(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizer,
-    text_layer_name: str,
-) -> dict:
-    # to not modify the original document
-    document = document.copy()
-    # tokenize_document does not yet consider predictions, so we need to add them manually
-    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
-    added_annotations = []
-    tokenizer_kwargs = {
-        "max_length": 512,
-        "stride": 64,
-        "truncation": True,
-        "return_overflowing_tokens": True,
-    }
-    tokenized_documents = tokenize_document(
-        document,
-        tokenizer=tokenizer,
-        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-        partition_layer="labeled_partitions",
-        added_annotations=added_annotations,
-        strict_span_conversion=False,
-        **tokenizer_kwargs,
-    )
-    # just tokenize again to get tensors in the correct format for the model
-    # TODO: fix for A34.txt from sciarg corpus
-    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
-    # this is added when using return_overflowing_tokens=True, but the model does not accept it
-    model_inputs.pop("overflow_to_sample_mapping", None)
-    assert len(model_inputs.encodings) == len(tokenized_documents)
-    model_output = model(**model_inputs)
-
-    # get embeddings for all text annotations
-    embeddings = {}
-    for batch_idx in range(len(model_output.last_hidden_state)):
-        text2tok_ann = added_annotations[batch_idx][text_layer_name]
-        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
-        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
-            # skip "empty" annotations
-            if tok_ann.start == tok_ann.end:
-                continue
-            # use the max pooling strategy to get a single embedding for the annotation text
-            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
-                dim=0
-            )[0]
-            text_ann = tok2text_ann[tok_ann]
-
-            if text_ann in embeddings:
-                logger.warning(
-                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
-                )
-            embeddings[text_ann] = embedding
-
-    return embeddings
-
-
-def annotate(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    pipeline: Pipeline,
-    embedding_model: Optional[PreTrainedModel] = None,
-    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
-) -> None:
-
-    # execute prediction pipeline
-    pipeline(document)
-
-    if embedding_model is not None and embedding_tokenizer is not None:
-        adu_embeddings = embed_text_annotations(
-            document=document,
-            model=embedding_model,
-            tokenizer=embedding_tokenizer,
-            text_layer_name="labeled_spans",
-        )
-        # convert keys to str because JSON keys must be strings
-        adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
-        document.metadata["embeddings"] = adu_embeddings_dict
-    else:
-        gr.Warning(
-            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
-            "model in the 'Model Configuration' section."
-        )
-
-
 def render_annotated_document(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     render_with: str,
@@ -135,57 +44,6 @@ def render_annotated_document(
     return html
 
 
-def add_to_index(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    processed_documents: dict,
-    vector_store: SimpleVectorStore,
-) -> None:
-    try:
-        if document.id in processed_documents:
-            gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
-        # save the processed document to the index
-        processed_documents[document.id] = document
-        # save the embeddings to the vector store
-        for adu_id, embedding in document.metadata["embeddings"].items():
-            vector_store.save((document.id, adu_id), embedding)
-        gr.Info(
-            f"Added document {document.id} to index (index contains {len(processed_documents)} "
-            f"documents and {len(vector_store)} embeddings)."
-        )
-    except Exception as e:
-        raise gr.Error(f"Failed to add document {document.id} to index: {e}")
-
-
-def process_text(
-    text: str,
-    doc_id: str,
-    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: SimpleVectorStore,
-) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
-    try:
-        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
-            id=doc_id, text=text, metadata={}
-        )
-        # add single partition from the whole text (the model only considers text in partitions)
-        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
-        # annotate the document
-        annotate(
-            document=document,
-            pipeline=models[0],
-            embedding_model=models[1],
-            embedding_tokenizer=models[2],
-        )
-        # add the document to the index
-        add_to_index(document, processed_documents, vector_store)
-
-        return document
-    except Exception as e:
-        raise gr.Error(f"Failed to process text: {e}")
-
-
 def wrapped_process_text(
     text: str,
     doc_id: str,
@@ -193,7 +51,7 @@ def wrapped_process_text(
     processed_documents: dict[
         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
     ],
-    vector_store: SimpleVectorStore,
+    vector_store: VectorStore[Tuple[str, str]],
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
     document = process_text(
         text=text,
@@ -206,13 +64,13 @@ def wrapped_process_text(
     return document.asdict(), document
 
 
-def process_uploaded_file(
+def process_uploaded_files(
     file_names: List[str],
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
     processed_documents: dict[
         str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
     ],
-    vector_store: SimpleVectorStore,
+    vector_store: VectorStore[Tuple[str, str]],
 ) -> None:
     try:
         for file_name in file_names:
@@ -229,164 +87,6 @@
         raise gr.Error(f"Failed to process uploaded files: {e}")
 
 
-def _get_annotation_from_document(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    annotation_id: str,
-    annotation_layer: str,
-) -> LabeledSpan:
-    # use predictions
-    annotations = document[annotation_layer].predictions
-    id2annotation = {str(annotation._id): annotation for annotation in annotations}
-    annotation = id2annotation.get(annotation_id)
-    if annotation is None:
-        raise gr.Error(
-            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
-            f"annotations: {id2annotation}"
-        )
-    return annotation
-
-
-def _get_annotation(
-    doc_id: str,
-    annotation_id: str,
-    annotation_layer: str,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-) -> LabeledSpan:
-    document = processed_documents.get(doc_id)
-    if document is None:
-        raise gr.Error(
-            f"Document '{doc_id}' not found in index. Available documents: {list(processed_documents)}"
-        )
-    return _get_annotation_from_document(document, annotation_id, annotation_layer)
-
-
-def _get_similar_entries_from_vector_store(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: SimpleVectorStore[Tuple[str, str]],
-    **retrieval_kwargs,
-) -> List[Tuple[Tuple[str, str], float]]:
-    embeddings = ref_document.metadata["embeddings"]
-    ref_embedding = embeddings.get(ref_annotation_id)
-    if ref_embedding is None:
-        raise gr.Error(
-            f"Embedding for annotation '{ref_annotation_id}' not found in metadata of "
-            f"document '{ref_document.id}'. Annotations with embeddings: {list(embeddings)}"
-        )
-
-    try:
-        similar_entries = vector_store.retrieve_similar(
-            ref_id=(ref_document.id, ref_annotation_id), **retrieval_kwargs
-        )
-    except Exception as e:
-        raise gr.Error(f"Failed to retrieve similar ADUs: {e}")
-
-    return similar_entries
-
-
-def get_similar_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: SimpleVectorStore,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-) -> pd.DataFrame:
-    similar_entries = _get_similar_entries_from_vector_store(
-        ref_annotation_id=ref_annotation_id,
-        ref_document=ref_document,
-        vector_store=vector_store,
-        min_similarity=min_similarity,
-    )
-
-    similar_annotations = [
-        _get_annotation(
-            doc_id=doc_id,
-            annotation_id=annotation_id,
-            annotation_layer="labeled_spans",
-            processed_documents=processed_documents,
-        )
-        for (doc_id, annotation_id), _ in similar_entries
-    ]
-    df = pd.DataFrame(
-        [
-            # unpack the tuple (doc_id, annotation_id) to separate columns
-            # and add the similarity score and the text of the annotation
-            (doc_id, annotation_id, score, str(annotation))
-            for ((doc_id, annotation_id), score), annotation in zip(
-                similar_entries, similar_annotations
-            )
-        ],
-        columns=["doc_id", "adu_id", "sim_score", "text"],
-    )
-
-    return df
-
-
-def get_relevant_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: SimpleVectorStore,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-) -> pd.DataFrame:
-    similar_entries = _get_similar_entries_from_vector_store(
-        ref_annotation_id=ref_annotation_id,
-        ref_document=ref_document,
-        vector_store=vector_store,
-        min_similarity=min_similarity,
-    )
-    ref_annotation = _get_annotation(
-        doc_id=ref_document.id,
-        annotation_id=ref_annotation_id,
-        annotation_layer="labeled_spans",
-        processed_documents=processed_documents,
-    )
-    result = []
-    for (doc_id, annotation_id), score in similar_entries:
-        # skip entries from the same document
-        if doc_id == ref_document.id:
-            continue
-        document = processed_documents[doc_id]
-        tail2rels = defaultdict(list)
-        head2rels = defaultdict(list)
-        for rel in document.binary_relations.predictions:
-            # skip non-argumentative relations
-            if rel.label in ["parts_of_same", "semantically_same"]:
-                continue
-            head2rels[rel.head].append(rel)
-            tail2rels[rel.tail].append(rel)
-
-        id2annotation = {
-            str(annotation._id): annotation for annotation in document.labeled_spans.predictions
-        }
-        annotation = id2annotation.get(annotation_id)
-        # note: we do not need to check if the annotation is different from the reference annotation,
-        # because they com from different documents and we already skip entries from the same document
-        for rel in head2rels.get(annotation, []):
-            result.append(
-                {
-                    "doc_id": doc_id,
-                    "reference_adu": str(annotation),
-                    "sim_score": score,
-                    "rel_score": rel.score,
-                    "relation": rel.label,
-                    "text": str(rel.tail),
-                }
-            )
-
-    # define column order
-    df = pd.DataFrame(
-        result, columns=["text", "relation", "doc_id", "reference_adu", "sim_score", "rel_score"]
-    )
-    return df
-
-
 def open_accordion():
     return gr.Accordion(open=True)
 
@@ -463,6 +163,24 @@ def select_processed_document(
     return doc
 
 
+def set_relation_types(
+    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+    default: Optional[List[str]] = None,
+) -> gr.Dropdown:
+    arg_pipeline = models[0]
+    if isinstance(arg_pipeline.taskmodule, PointerNetworkTaskModuleForEnd2EndRE):
+        relation_types = arg_pipeline.taskmodule.labels_per_layer["binary_relations"]
+    else:
+        raise gr.Error("Unsupported taskmodule for relation types")
+
+    return gr.Dropdown(
+        choices=relation_types,
+        label="Relation Types",
+        value=default,
+        multiselect=True,
+    )
+
+
 def main():
 
     example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
@@ -526,7 +244,7 @@ def main():
     )
     embedding_model_name = gr.Textbox(
         label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
-        value="",
+        value=DEFAULT_EMBEDDING_MODEL_NAME,
    )
    load_models_btn = gr.Button("Load Models")
    load_models_btn.click(
@@ -583,8 +301,18 @@
         step=0.01,
         value=0.8,
     )
+    top_k = gr.Slider(
+        label="Top K",
+        minimum=2,
+        maximum=50,
+        step=1,
+        value=20,
+    )
     retrieve_similar_adus_btn = gr.Button("Retrieve similar ADUs")
     similar_adus = gr.DataFrame(headers=["doc_id", "adu_id", "score", "text"])
+    relation_types = set_relation_types(
+        models_state.value, default=["supports", "contradicts"]
+    )
 
     # retrieve_relevant_adus_btn = gr.Button("Retrieve relevant ADUs")
     relevant_adus = gr.DataFrame(
@@ -626,7 +354,7 @@
     )
 
     upload_btn.upload(
-        fn=process_uploaded_file,
+        fn=process_uploaded_files,
         inputs=[upload_btn, models_state, processed_documents_state, vector_store_state],
         outputs=[],
     ).success(
@@ -648,12 +376,14 @@
             vector_store_state,
             processed_documents_state,
             min_similarity,
+            top_k,
+            relation_types,
         ],
         outputs=[relevant_adus],
     )
 
     reference_adu_id.change(
-        fn=partial(_get_annotation_from_document, annotation_layer="labeled_spans"),
+        fn=partial(get_annotation_from_document, annotation_layer="labeled_spans"),
         inputs=[document_state, reference_adu_id],
         outputs=[reference_adu_text],
     ).success(**retrieve_relevant_adus_event_kwargs)
@@ -666,10 +396,17 @@
             vector_store_state,
             processed_documents_state,
             min_similarity,
+            top_k,
         ],
         outputs=[similar_adus],
     )
 
+    models_state.change(
+        fn=set_relation_types,
+        inputs=[models_state],
+        outputs=[relation_types],
+    )
+
     # retrieve_relevant_adus_btn.click(
     #     **retrieve_relevant_adus_event_kwargs
     # )
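
The new retrieval controls follow a common Gradio pattern: component values (`min_similarity`, `top_k`, `relation_types`) are passed as extra `inputs` to the event handlers, and `models_state.change` rebuilds the dropdown whenever a different pipeline is loaded. A stripped-down sketch of that update pattern, with a hypothetical label mapping standing in for the taskmodule's `labels_per_layer`:

import gradio as gr

# hypothetical mapping from pipeline name to its relation label set
LABELS = {"pipeline-a": ["supports", "contradicts"], "pipeline-b": ["parts_of_same"]}

def update_relation_types(name: str) -> gr.Dropdown:
    # same idea as set_relation_types: rebuild the choices from the loaded pipeline
    return gr.Dropdown(choices=LABELS[name], multiselect=True)

with gr.Blocks() as demo:
    pipeline_name = gr.Dropdown(choices=list(LABELS), value="pipeline-a", label="Pipeline")
    relation_types = gr.Dropdown(
        choices=LABELS["pipeline-a"], value=["supports"], multiselect=True, label="Relation Types"
    )
    # mirrors models_state.change(fn=set_relation_types, ...)
    pipeline_name.change(fn=update_relation_types, inputs=[pipeline_name], outputs=[relation_types])

demo.launch()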
backend.py ADDED
@@ -0,0 +1,300 @@
+import logging
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+import pandas as pd
+from pie_modules.document.processing import tokenize_document
+from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from pytorch_ie import Pipeline
+from pytorch_ie.annotations import LabeledSpan, Span
+from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from rendering_utils import labeled_span_to_id
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from vector_store import VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+def embed_text_annotations(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    text_layer_name: str,
+) -> Dict[Span, List[float]]:
+    # to not modify the original document
+    document = document.copy()
+    # tokenize_document does not yet consider predictions, so we need to add them manually
+    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
+    added_annotations = []
+    tokenizer_kwargs = {
+        "max_length": 512,
+        "stride": 64,
+        "truncation": True,
+        "return_overflowing_tokens": True,
+    }
+    tokenized_documents = tokenize_document(
+        document,
+        tokenizer=tokenizer,
+        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        partition_layer="labeled_partitions",
+        added_annotations=added_annotations,
+        strict_span_conversion=False,
+        **tokenizer_kwargs,
+    )
+    # just tokenize again to get tensors in the correct format for the model
+    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
+    # this is added when using return_overflowing_tokens=True, but the model does not accept it
+    model_inputs.pop("overflow_to_sample_mapping", None)
+    assert len(model_inputs.encodings) == len(tokenized_documents)
+    model_output = model(**model_inputs)
+
+    # get embeddings for all text annotations
+    embeddings = {}
+    for batch_idx in range(len(model_output.last_hidden_state)):
+        text2tok_ann = added_annotations[batch_idx][text_layer_name]
+        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
+        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
+            # skip "empty" annotations
+            if tok_ann.start == tok_ann.end:
+                continue
+            # use the max pooling strategy to get a single embedding for the annotation text
+            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
+                dim=0
+            )[0]
+            text_ann = tok2text_ann[tok_ann]
+
+            if text_ann in embeddings:
+                logger.warning(
+                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
+                )
+            embeddings[text_ann] = embedding
+
+    return embeddings
+
+
+def annotate(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    pipeline: Pipeline,
+    embedding_model: Optional[PreTrainedModel] = None,
+    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
+) -> None:
+
+    # execute prediction pipeline
+    pipeline(document)
+
+    if embedding_model is not None and embedding_tokenizer is not None:
+        adu_embeddings = embed_text_annotations(
+            document=document,
+            model=embedding_model,
+            tokenizer=embedding_tokenizer,
+            text_layer_name="labeled_spans",
+        )
+        # convert keys to str because JSON keys must be strings
+        adu_embeddings_dict = {
+            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
+        }
+        document.metadata["embeddings"] = adu_embeddings_dict
+    else:
+        gr.Warning(
+            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
+            "model in the 'Model Configuration' section."
+        )
+
+
+def add_to_index(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    processed_documents: dict,
+    vector_store: VectorStore[Tuple[str, str]],
+) -> None:
+    try:
+        if document.id in processed_documents:
+            gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
+
+        # save the processed document to the index
+        processed_documents[document.id] = document
+
+        # save the embeddings to the vector store
+        for adu_id, embedding in document.metadata["embeddings"].items():
+            vector_store.save((document.id, adu_id), embedding)
+
+        gr.Info(
+            f"Added document {document.id} to index (index contains {len(processed_documents)} "
+            f"documents and {len(vector_store)} embeddings)."
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to add document {document.id} to index: {e}")
+
+
+def process_text(
+    text: str,
+    doc_id: str,
+    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+    processed_documents: dict[
+        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ],
+    vector_store: VectorStore[Tuple[str, str]],
+) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
+    text, annotate it, and add it to the index.
+
+    Parameters:
+        text: The text to process.
+        doc_id: The ID of the document.
+        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
+        processed_documents: The index of processed documents.
+        vector_store: The vector store to save the embeddings.
+
+    Returns:
+        The processed document.
+    """
+
+    try:
+        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+            id=doc_id, text=text, metadata={}
+        )
+        # add single partition from the whole text (the model only considers text in partitions)
+        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+        # annotate the document
+        annotate(
+            document=document,
+            pipeline=models[0],
+            embedding_model=models[1],
+            embedding_tokenizer=models[2],
+        )
+        # add the document to the index
+        add_to_index(document, processed_documents, vector_store)
+
+        return document
+    except Exception as e:
+        raise gr.Error(f"Failed to process text: {e}")
+
+
+def get_annotation_from_document(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    annotation_id: str,
+    annotation_layer: str,
+) -> LabeledSpan:
+    # use predictions
+    annotations = document[annotation_layer].predictions
+    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
+    annotation = id2annotation.get(annotation_id)
+    if annotation is None:
+        raise gr.Error(
+            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
+            f"annotations: {id2annotation}"
+        )
+    return annotation
+
+
+def get_annotation_from_processed_documents(
+    doc_id: str,
+    annotation_id: str,
+    annotation_layer: str,
+    processed_documents: dict[
+        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ],
+) -> LabeledSpan:
+    document = processed_documents.get(doc_id)
+    if document is None:
+        raise gr.Error(
+            f"Document '{doc_id}' not found in index. Available documents: {list(processed_documents)}"
+        )
+    return get_annotation_from_document(document, annotation_id, annotation_layer)
+
+
+def get_similar_adus(
+    ref_annotation_id: str,
+    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    vector_store: VectorStore[Tuple[str, str]],
+    processed_documents: dict[
+        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ],
+    min_similarity: float,
+    top_k: int,
+) -> pd.DataFrame:
+    similar_entries = vector_store.retrieve_similar(
+        ref_id=(ref_document.id, ref_annotation_id),
+        min_similarity=min_similarity,
+        top_k=top_k,
+    )
+
+    similar_annotations = [
+        get_annotation_from_processed_documents(
+            doc_id=doc_id,
+            annotation_id=annotation_id,
+            annotation_layer="labeled_spans",
+            processed_documents=processed_documents,
+        )
+        for (doc_id, annotation_id), _ in similar_entries
+    ]
+    df = pd.DataFrame(
+        [
+            # unpack the tuple (doc_id, annotation_id) to separate columns
+            # and add the similarity score and the text of the annotation
+            (doc_id, annotation_id, score, str(annotation))
+            for ((doc_id, annotation_id), score), annotation in zip(
+                similar_entries, similar_annotations
+            )
+        ],
+        columns=["doc_id", "adu_id", "sim_score", "text"],
+    )
+
+    return df
+
+
+def get_relevant_adus(
+    ref_annotation_id: str,
+    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    vector_store: VectorStore[Tuple[str, str]],
+    processed_documents: dict[
+        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ],
+    min_similarity: float,
+    top_k: int,
+    relation_types: List[str],
+) -> pd.DataFrame:
+    similar_entries = vector_store.retrieve_similar(
+        ref_id=(ref_document.id, ref_annotation_id),
+        min_similarity=min_similarity,
+        top_k=top_k,
+    )
+    result = []
+    for (doc_id, annotation_id), score in similar_entries:
+        # skip entries from the same document
+        if doc_id == ref_document.id:
+            continue
+        document = processed_documents[doc_id]
+        tail2rels = defaultdict(list)
+        head2rels = defaultdict(list)
+        for rel in document.binary_relations.predictions:
+            # skip non-argumentative relations
+            if rel.label not in relation_types:
+                continue
+            head2rels[rel.head].append(rel)
+            tail2rels[rel.tail].append(rel)
+
+        id2annotation = {
+            labeled_span_to_id(annotation): annotation
+            for annotation in document.labeled_spans.predictions
+        }
+        annotation = id2annotation.get(annotation_id)
+        # note: we do not need to check if the annotation is different from the reference annotation,
+        # because they come from different documents and we already skip entries from the same document
+        for rel in head2rels.get(annotation, []):
+            result.append(
+                {
+                    "doc_id": doc_id,
+                    "reference_adu": str(annotation),
+                    "sim_score": score,
+                    "rel_score": rel.score,
+                    "relation": rel.label,
+                    "text": str(rel.tail),
+                }
+            )
+
+    # define column order
+    df = pd.DataFrame(
+        result, columns=["text", "relation", "doc_id", "reference_adu", "sim_score", "rel_score"]
+    )
+    return df
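
The span embeddings in `embed_text_annotations` are produced by element-wise max pooling over the token vectors between `tok_ann.start` and `tok_ann.end`. A self-contained sketch of just that pooling step, using a dummy tensor in place of `model_output.last_hidden_state` (shape `(batch, seq_len, hidden)`):

import torch

# dummy stand-in for model_output.last_hidden_state: batch=1, seq_len=8, hidden=4
last_hidden_state = torch.randn(1, 8, 4)

start, end = 2, 6  # token offsets of one labeled span
# element-wise max over the span's token vectors yields a single (hidden,) embedding
span_embedding = last_hidden_state[0, start:end].max(dim=0)[0]
assert span_embedding.shape == (4,)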
rendering_utils.py CHANGED
@@ -2,7 +2,7 @@ import json
 from collections import defaultdict
 from typing import Dict, List, Optional, Union
 
-from pytorch_ie.annotations import BinaryRelation
+from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils_displacy import EntityRenderer
 
@@ -59,6 +59,15 @@ def render_displacy(
     return html
 
 
+def labeled_span_to_id(span: LabeledSpan) -> str:
+    return f"span-{span.start}-{span.end}-{span.label}"
+
+
+def labeled_span_from_id(span_id: str) -> LabeledSpan:
+    parts = span_id.split("-")
+    return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
+
+
 def inject_relation_data(
     html: str,
     sorted_entities,
@@ -80,7 +89,7 @@ def inject_relation_data(
     entities = soup.find_all(class_="entity")
     for idx, entity in enumerate(entities):
         annotation = sorted_entities[idx]
-        entity["id"] = str(annotation._id)
+        entity["id"] = labeled_span_to_id(annotation)
         original_color = entity["style"].split("background:")[1].split(";")[0].strip()
         entity["data-color-original"] = original_color
         if additional_colors is not None:
@@ -95,13 +104,13 @@ def inject_relation_data(
         entity["data-label"] = entity_annotation.label
         entity["data-relation-tails"] = json.dumps(
             [
-                {"entity-id": str(tail._id), "label": label}
+                {"entity-id": labeled_span_to_id(tail), "label": label}
                 for tail, label in entity2tails.get(entity_annotation, [])
             ]
         )
         entity["data-relation-heads"] = json.dumps(
            [
-                {"entity-id": str(head._id), "label": label}
+                {"entity-id": labeled_span_to_id(head), "label": label}
                for head, label in entity2heads.get(entity_annotation, [])
            ]
        )
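
`labeled_span_to_id` replaces the unstable `annotation._id` with a deterministic ID derived from the span's offsets and label, so the same span always maps to the same DOM id across renders; `labeled_span_from_id` inverts it. A quick round-trip check (note that the `split("-")` parsing assumes labels contain no hyphen):

from pytorch_ie.annotations import LabeledSpan

span = LabeledSpan(start=10, end=25, label="claim")
span_id = labeled_span_to_id(span)  # -> "span-10-25-claim"
recovered = labeled_span_from_id(span_id)
assert (recovered.start, recovered.end, recovered.label) == (10, 25, "claim")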
vector_store.py CHANGED
@@ -1,5 +1,24 @@
+import abc
 from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
 
+T = TypeVar("T", bound=Hashable)
+
+
+class VectorStore(Generic[T], abc.ABC):
+    @abc.abstractmethod
+    def save(self, emb_id: T, embedding: List[float]) -> None:
+        pass
+
+    @abc.abstractmethod
+    def retrieve_similar(
+        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
+    ) -> List[Tuple[T, float]]:
+        pass
+
+    @abc.abstractmethod
+    def __len__(self):
+        pass
+
 
 def vector_norm(vector: List[float]) -> float:
     return sum(x**2 for x in vector) ** 0.5
@@ -9,10 +28,7 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
     return sum(a * b for a, b in zip(a, b)) / (vector_norm(a) * vector_norm(b))
 
 
-T = TypeVar("T", bound=Hashable)
-
-
-class SimpleVectorStore(Generic[T]):
+class SimpleVectorStore(VectorStore[T]):
     def __init__(self):
         self.vectors: dict[T, List[float]] = {}
         self._cache = {}
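
With the abstract `VectorStore` in place, `SimpleVectorStore` is just one in-memory cosine-similarity implementation behind the interface; alternative backends only need to provide `save`, `retrieve_similar`, and `__len__`. A small usage sketch with made-up `(doc_id, adu_id)` keys and toy vectors (exact results depend on the `retrieve_similar` implementation, which this diff does not show):

store = SimpleVectorStore()
store.save(("doc1", "span-0-12-claim"), [0.1, 0.9, 0.2])
store.save(("doc2", "span-4-20-claim"), [0.1, 0.8, 0.3])
store.save(("doc3", "span-7-15-data"), [-0.9, 0.1, 0.0])

# nearest entries to the reference, filtered by cosine similarity
hits = store.retrieve_similar(ref_id=("doc1", "span-0-12-claim"), top_k=2, min_similarity=0.8)
for (doc_id, adu_id), score in hits:
    print(doc_id, adu_id, round(score, 3))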