ArneBinder committed
Commit 04ce9af
1 Parent(s): b77f1d0

Upload 7 files

Files changed (2):
  1. app.py +30 -78
  2. backend.py +160 -140

app.py CHANGED
@@ -3,18 +3,17 @@ import logging
 import os.path
 import tempfile
 from functools import partial
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import gradio as gr
 import pandas as pd
-from backend import get_annotation_from_document, get_relevant_adus, get_similar_adus, process_text
+from backend import DocumentStore, create_and_annotate_document, get_annotation_from_document
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
 from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import render_displacy, render_pretty_table
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
-from vector_store import SimpleVectorStore, VectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -49,18 +48,14 @@ def wrapped_process_text(
     text: str,
     doc_id: str,
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
+    document_store: DocumentStore,
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
-    document = process_text(
+    document = create_and_annotate_document(
         text=text,
         doc_id=doc_id,
         models=models,
-        processed_documents=processed_documents,
-        vector_store=vector_store,
     )
+    document_store.add_document(document)
     # Return as dict and document to avoid serialization issues
     return document.asdict(), document
 
@@ -68,10 +63,7 @@ def wrapped_process_text(
 def process_uploaded_files(
     file_names: List[str],
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
+    document_store: DocumentStore,
 ) -> pd.DataFrame:
     try:
         for file_name in file_names:
@@ -81,13 +73,14 @@ def process_uploaded_files(
                     text = f.read()
                 base_file_name = os.path.basename(file_name)
                 gr.Info(f"Processing file '{base_file_name}' ...")
-                process_text(text, base_file_name, models, processed_documents, vector_store)
+                document = create_and_annotate_document(text, base_file_name, models)
+                document_store.add_document(document)
             else:
                 raise gr.Error(f"Unsupported file format: {file_name}")
     except Exception as e:
         raise gr.Error(f"Failed to process uploaded files: {e}")
 
-    return update_processed_documents_df(processed_documents)
+    return document_store.overview()
 
 
 def open_accordion():
@@ -135,34 +128,15 @@ def load_models(
     return argumentation_model, embedding_model, embedding_tokenizer
 
 
-def update_processed_documents_df(
-    processed_documents: dict[str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
-) -> pd.DataFrame:
-    df = pd.DataFrame(
-        [
-            (
-                doc_id,
-                len(document.labeled_spans.predictions),
-                len(document.binary_relations.predictions),
-            )
-            for doc_id, document in processed_documents.items()
-        ],
-        columns=["doc_id", "num_adus", "num_relations"],
-    )
-    return df
-
-
 def select_processed_document(
     evt: gr.SelectData,
     processed_documents_df: pd.DataFrame,
-    processed_documents: Dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
+    document_store: DocumentStore,
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     row_idx, col_idx = evt.index
     doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
     gr.Info(f"Select document: {doc_id}")
-    doc = processed_documents[doc_id]
+    doc = document_store.get_document(doc_id)
     return doc
 
 
@@ -185,38 +159,24 @@ def set_relation_types(
 
 
 def download_processed_documents(
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
+    document_store: DocumentStore,
     file_name: str = "processed_documents.json",
 ) -> str:
-    processed_documents_json = {
-        doc_id: document.asdict() for doc_id, document in processed_documents.items()
-    }
     file_path = os.path.join(tempfile.gettempdir(), file_name)
     with open(file_path, "w", encoding="utf-8") as f:
-        json.dump(processed_documents_json, f, indent=2)
+        json.dump(document_store.as_dict(), f, indent=2)
     return file_path
 
 
 def upload_processed_documents(
     file_name: str,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-) -> Dict[str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
+    document_store: DocumentStore,
+) -> pd.DataFrame:
     with open(file_name, "r", encoding="utf-8") as f:
         processed_documents_json = json.load(f)
-    for doc_id, document_json in processed_documents_json.items():
-        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
-            document_json
-        )
-        # metadata is not automatically deserialized, so we need to set it manually
-        document.metadata["embeddings"] = document_json["metadata"]["embeddings"]
-        if doc_id in processed_documents:
-            gr.Warning(f"Document '{doc_id}' already exists. Overwriting.")
-        processed_documents[doc_id] = document
-    return processed_documents
+    for _, document_json in processed_documents_json.items():
+        document_store.add_document_from_dict(document_dict=document_json)
+    return document_store.overview()
 
 
 def main():
@@ -256,8 +216,7 @@ def main():
     }
 
     with gr.Blocks() as demo:
-        processed_documents_state = gr.State(dict())
-        vector_store_state = gr.State(SimpleVectorStore())
+        document_store_state = gr.State(DocumentStore())
         # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
         models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
        with gr.Row():
@@ -381,12 +340,12 @@ def main():
 
         predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
             fn=wrapped_process_text,
-            inputs=[doc_text, doc_id, models_state, processed_documents_state, vector_store_state],
+            inputs=[doc_text, doc_id, models_state, document_store_state],
             outputs=[document_json, document_state],
             api_name="predict",
         ).success(
-            fn=update_processed_documents_df,
-            inputs=[processed_documents_state],
+            fn=lambda document_store: document_store.overview(),
+            inputs=[document_store_state],
             outputs=[processed_documents_df],
         )
         render_btn.click(**render_event_kwargs, api_name="render")
@@ -403,41 +362,35 @@ def main():
             fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
         ).then(
             fn=process_uploaded_files,
-            inputs=[upload_btn, models_state, processed_documents_state, vector_store_state],
+            inputs=[upload_btn, models_state, document_store_state],
             outputs=[processed_documents_df],
         )
         processed_documents_df.select(
             select_processed_document,
-            inputs=[processed_documents_df, processed_documents_state],
+            inputs=[processed_documents_df, document_store_state],
             outputs=[document_state],
         )
 
         download_processed_documents_btn.click(
             fn=download_processed_documents,
-            inputs=[processed_documents_state],
+            inputs=[document_store_state],
             outputs=[download_processed_documents_btn],
         )
         upload_processed_documents_btn.upload(
             fn=upload_processed_documents,
-            inputs=[upload_processed_documents_btn, processed_documents_state],
-            outputs=[processed_documents_state],
-        ).success(
-            fn=update_processed_documents_df,
-            inputs=[processed_documents_state],
+            inputs=[upload_processed_documents_btn, document_store_state],
             outputs=[processed_documents_df],
         )
 
         retrieve_relevant_adus_event_kwargs = dict(
-            fn=get_relevant_adus,
+            fn=partial(DocumentStore.get_relevant_adus_df, columns=relevant_adus.headers),
            inputs=[
+                document_store_state,
                 selected_adu_id,
                 document_state,
-                vector_store_state,
-                processed_documents_state,
                 min_similarity,
                 top_k,
                 relation_types,
-                relevant_adus,
             ],
             outputs=[relevant_adus],
         )
@@ -449,12 +402,11 @@ def main():
         ).success(**retrieve_relevant_adus_event_kwargs)
 
         retrieve_similar_adus_btn.click(
-            fn=get_similar_adus,
+            fn=DocumentStore.get_similar_adus_df,
             inputs=[
+                document_store_state,
                 selected_adu_id,
                 document_state,
-                vector_store_state,
-                processed_documents_state,
                 min_similarity,
                 top_k,
             ],
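
Note on the event wiring above: unbound DocumentStore methods (e.g. DocumentStore.get_similar_adus_df, or partial(DocumentStore.get_relevant_adus_df, columns=relevant_adus.headers)) are passed as Gradio callbacks, and the gr.State that holds the per-session DocumentStore is listed as the first input, so Gradio fills the method's self parameter with the stored instance. A minimal, self-contained sketch of that pattern (the toy Store class and component names are illustrative, not from this repo):

import gradio as gr


class Store:
    """Toy stand-in for DocumentStore: keeps a list of entries per session."""

    def __init__(self):
        self.entries = []

    def add(self, text: str) -> str:
        # mutating the instance held in gr.State persists within the session
        self.entries.append(text)
        return "\n".join(self.entries)


with gr.Blocks() as demo:
    store_state = gr.State(Store())  # each session gets its own copy of the initial value
    text_in = gr.Textbox(label="entry")
    log_out = gr.Textbox(label="all entries")
    add_btn = gr.Button("Add")
    # unbound method as callback: the value of store_state fills the `self` parameter
    add_btn.click(fn=Store.add, inputs=[store_state, text_in], outputs=[log_out])

if __name__ == "__main__":
    demo.launch()
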
backend.py CHANGED
@@ -11,12 +11,12 @@ from pytorch_ie.annotations import LabeledSpan, Span
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import labeled_span_to_id
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from vector_store import VectorStore
+from vector_store import SimpleVectorStore, VectorStore
 
 logger = logging.getLogger(__name__)
 
 
-def embed_text_annotations(
+def _embed_text_annotations(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
@@ -73,7 +73,7 @@ def embed_text_annotations(
     return embeddings
 
 
-def annotate(
+def _annotate(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     pipeline: Pipeline,
     embedding_model: Optional[PreTrainedModel] = None,
@@ -84,7 +84,7 @@ def annotate(
     pipeline(document)
 
     if embedding_model is not None and embedding_tokenizer is not None:
-        adu_embeddings = embed_text_annotations(
+        adu_embeddings = _embed_text_annotations(
             document=document,
             model=embedding_model,
             tokenizer=embedding_tokenizer,
@@ -102,38 +102,10 @@ def annotate(
         )
 
 
-def add_to_index(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    processed_documents: dict,
-    vector_store: VectorStore[Tuple[str, str]],
-) -> None:
-    try:
-        if document.id in processed_documents:
-            gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
-
-        # save the processed document to the index
-        processed_documents[document.id] = document
-
-        # save the embeddings to the vector store
-        for adu_id, embedding in document.metadata["embeddings"].items():
-            vector_store.save((document.id, adu_id), embedding)
-
-        gr.Info(
-            f"Added document {document.id} to index (index contains {len(processed_documents)} "
-            f"documents and {len(vector_store)} embeddings)."
-        )
-    except Exception as e:
-        raise gr.Error(f"Failed to add document {document.id} to index: {e}")
-
-
-def process_text(
+def create_and_annotate_document(
     text: str,
     doc_id: str,
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
     text, annotate it, and add it to the index.
@@ -142,8 +114,6 @@ def process_text(
         text: The text to process.
         doc_id: The ID of the document.
         models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
-        processed_documents: The index of processed documents.
-        vector_store: The vector store to save the embeddings.
 
     Returns:
         The processed document.
@@ -156,14 +126,12 @@ def process_text(
         # add single partition from the whole text (the model only considers text in partitions)
         document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
         # annotate the document
-        annotate(
+        _annotate(
             document=document,
             pipeline=models[0],
             embedding_model=models[1],
             embedding_tokenizer=models[2],
         )
-        # add the document to the index
-        add_to_index(document, processed_documents, vector_store)
 
         return document
     except Exception as e:
@@ -187,113 +155,165 @@ def get_annotation_from_document(
     return annotation
 
 
-def get_annotation_from_processed_documents(
-    doc_id: str,
-    annotation_id: str,
-    annotation_layer: str,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-) -> LabeledSpan:
-    document = processed_documents.get(doc_id)
-    if document is None:
-        raise gr.Error(
-            f"Document '{doc_id}' not found in index. Available documents: {list(processed_documents)}"
-        )
-    return get_annotation_from_document(document, annotation_id, annotation_layer)
-
-
-def get_similar_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: VectorStore[Tuple[str, str]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-    top_k: int,
-) -> pd.DataFrame:
-    similar_entries = vector_store.retrieve_similar(
-        ref_id=(ref_document.id, ref_annotation_id),
-        min_similarity=min_similarity,
-        top_k=top_k,
-    )
-
-    similar_annotations = [
-        get_annotation_from_processed_documents(
-            doc_id=doc_id,
-            annotation_id=annotation_id,
-            annotation_layer="labeled_spans",
-            processed_documents=processed_documents,
-        )
-        for (doc_id, annotation_id), _ in similar_entries
-    ]
-    df = pd.DataFrame(
-        [
-            # unpack the tuple (doc_id, annotation_id) to separate columns
-            # and add the similarity score and the text of the annotation
-            (doc_id, annotation_id, score, str(annotation))
-            for ((doc_id, annotation_id), score), annotation in zip(
-                similar_entries, similar_annotations
-            )
-        ],
-        columns=["doc_id", "adu_id", "sim_score", "text"],
-    )
-
-    return df
-
-
-def get_relevant_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: VectorStore[Tuple[str, str]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-    top_k: int,
-    relation_types: List[str],
-    previous_result: pd.DataFrame,
-) -> pd.DataFrame:
-    similar_entries = vector_store.retrieve_similar(
-        ref_id=(ref_document.id, ref_annotation_id),
-        min_similarity=min_similarity,
-        top_k=top_k,
-    )
-    result = []
-    for (doc_id, annotation_id), score in similar_entries:
-        # skip entries from the same document
-        if doc_id == ref_document.id:
-            continue
-        document = processed_documents[doc_id]
-        tail2rels = defaultdict(list)
-        head2rels = defaultdict(list)
-        for rel in document.binary_relations.predictions:
-            # skip non-argumentative relations
-            if rel.label not in relation_types:
-                continue
-            head2rels[rel.head].append(rel)
-            tail2rels[rel.tail].append(rel)
-
-        id2annotation = {
-            labeled_span_to_id(annotation): annotation
-            for annotation in document.labeled_spans.predictions
-        }
-        annotation = id2annotation.get(annotation_id)
-        # note: we do not need to check if the annotation is different from the reference annotation,
-        # because they come from different documents and we already skip entries from the same document
-        for rel in head2rels.get(annotation, []):
-            result.append(
-                {
-                    "doc_id": doc_id,
-                    "reference_adu": str(annotation),
-                    "sim_score": score,
-                    "rel_score": rel.score,
-                    "relation": rel.label,
-                    "adu": str(rel.tail),
-                }
-            )
-
-    # define column order
-    df = pd.DataFrame(result, columns=previous_result.columns)
-    return df
+class DocumentStore:
+
+    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+
+    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str]]] = None):
+        self.documents = {}
+        self.vector_store = vector_store or SimpleVectorStore()
+
+    def get_annotation(
+        self,
+        doc_id: str,
+        annotation_id: str,
+        annotation_layer: str,
+    ) -> LabeledSpan:
+        document = self.documents.get(doc_id)
+        if document is None:
+            raise gr.Error(
+                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
+            )
+        return get_annotation_from_document(document, annotation_id, annotation_layer)
+
+    def get_similar_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+
+        similar_annotations = [
+            self.get_annotation(
+                doc_id=doc_id,
+                annotation_id=annotation_id,
+                annotation_layer="labeled_spans",
+            )
+            for (doc_id, annotation_id), _ in similar_entries
+        ]
+        df = pd.DataFrame(
+            [
+                # unpack the tuple (doc_id, annotation_id) to separate columns
+                # and add the similarity score and the text of the annotation
+                (doc_id, annotation_id, score, str(annotation))
+                for ((doc_id, annotation_id), score), annotation in zip(
+                    similar_entries, similar_annotations
+                )
+            ],
+            columns=["doc_id", "adu_id", "sim_score", "text"],
+        )
+
+        return df
+
+    def get_relevant_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+        relation_types: List[str],
+        columns: List[str],
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+        result = []
+        for (doc_id, annotation_id), score in similar_entries:
+            # skip entries from the same document
+            if doc_id == ref_document.id:
+                continue
+            document = self.documents[doc_id]
+            tail2rels = defaultdict(list)
+            head2rels = defaultdict(list)
+            for rel in document.binary_relations.predictions:
+                # skip non-argumentative relations
+                if rel.label not in relation_types:
+                    continue
+                head2rels[rel.head].append(rel)
+                tail2rels[rel.tail].append(rel)
+
+            id2annotation = {
+                labeled_span_to_id(annotation): annotation
+                for annotation in document.labeled_spans.predictions
+            }
+            annotation = id2annotation.get(annotation_id)
+            # note: we do not need to check if the annotation is different from the reference annotation,
+            # because they come from different documents and we already skip entries from the same document
+            for rel in head2rels.get(annotation, []):
+                result.append(
+                    {
+                        "doc_id": doc_id,
+                        "reference_adu": str(annotation),
+                        "sim_score": score,
+                        "rel_score": rel.score,
+                        "relation": rel.label,
+                        "adu": str(rel.tail),
+                    }
+                )
+
+        # define column order
+        df = pd.DataFrame(result, columns=columns)
+        return df
+
+    def add_document(
+        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ) -> None:
+        try:
+            if document.id in self.documents:
+                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
+
+            # save the processed document to the index
+            self.documents[document.id] = document
+
+            # save the embeddings to the vector store
+            for adu_id, embedding in document.metadata["embeddings"].items():
+                self.vector_store.save((document.id, adu_id), embedding)
+
+            gr.Info(
+                f"Added document {document.id} to index (index contains {len(self.documents)} "
+                f"documents and {len(self.vector_store)} embeddings)."
+            )
+        except Exception as e:
+            raise gr.Error(f"Failed to add document {document.id} to index: {e}")
+
+    def add_document_from_dict(self, document_dict: dict) -> None:
+        document = self.DOCUMENT_TYPE.fromdict(document_dict)
+        # metadata is not automatically deserialized, so we need to set it manually
+        document.metadata = document_dict["metadata"]
+        self.add_document(document)
+
+    def add_documents(
+        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
+    ) -> None:
+        for document in documents:
+            self.add_document(document)
+
+    def get_document(
+        self, doc_id: str
+    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+        return self.documents[doc_id]
+
+    def overview(self) -> pd.DataFrame:
+        df = pd.DataFrame(
+            [
+                (
+                    doc_id,
+                    len(document.labeled_spans.predictions),
+                    len(document.binary_relations.predictions),
+                )
+                for doc_id, document in self.documents.items()
+            ],
+            columns=["doc_id", "num_adus", "num_relations"],
+        )
+        return df
+
+    def as_dict(self) -> dict:
+        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
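
Taken together, backend.py now separates document creation from indexing: create_and_annotate_document only builds and annotates a document, while all index state (the documents dict plus the vector store) lives in DocumentStore. A minimal usage sketch of the refactored API (not part of the commit; it assumes this Space's backend.py is importable, that models is the (pipeline, embedding model, embedding tokenizer) tuple produced by app.py's load_models, and that an embedding model was loaded so metadata["embeddings"] is populated):

from backend import DocumentStore, create_and_annotate_document


def index_and_query(models, store: DocumentStore, text: str, doc_id: str):
    # creation/annotation no longer touches the index ...
    document = create_and_annotate_document(text=text, doc_id=doc_id, models=models)
    # ... adding it to the store is now an explicit, separate step
    store.add_document(document)

    # overview() replaces the old update_processed_documents_df() helper
    print(store.overview())

    # retrieve ADUs similar to one ADU of the new document from the store's vector store
    some_adu_id = next(iter(document.metadata["embeddings"]))
    return store.get_similar_adus_df(
        ref_annotation_id=some_adu_id,
        ref_document=document,
        min_similarity=0.8,
        top_k=10,
    )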