ArneBinder committed
Commit 86277c0
1 Parent(s): 04ce9af

Upload 9 files
Files changed (6)
  1. annotation_utils.py +10 -0
  2. app.py +8 -49
  3. document_store.py +218 -0
  4. model_utils.py +173 -0
  5. rendering_utils.py +2 -10
  6. vector_store.py +18 -3
annotation_utils.py ADDED
@@ -0,0 +1,10 @@
+from pytorch_ie.annotations import LabeledSpan
+
+
+def labeled_span_to_id(span: LabeledSpan) -> str:
+    return f"span-{span.start}-{span.end}-{span.label}"
+
+
+def labeled_span_from_id(span_id: str) -> LabeledSpan:
+    parts = span_id.split("-")
+    return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
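For orientation, a minimal round-trip of these two helpers (assuming pytorch_ie is installed; note the encoding relies on the label itself containing no "-" characters):

from annotation_utils import labeled_span_from_id, labeled_span_to_id
from pytorch_ie.annotations import LabeledSpan

span = LabeledSpan(start=0, end=5, label="claim")
span_id = labeled_span_to_id(span)  # -> "span-0-5-claim"
restored = labeled_span_from_id(span_id)
assert (restored.start, restored.end, restored.label) == (0, 5, "claim")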
app.py CHANGED
@@ -7,13 +7,13 @@ from typing import List, Optional, Tuple

 import gradio as gr
 import pandas as pd
-from backend import DocumentStore, create_and_annotate_document, get_annotation_from_document
+from document_store import DocumentStore, get_annotation_from_document
+from model_utils import create_and_annotate_document, load_models
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
-from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import render_displacy, render_pretty_table
-from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer

 logger = logging.getLogger(__name__)

@@ -66,6 +66,7 @@ def process_uploaded_files(
     document_store: DocumentStore,
 ) -> pd.DataFrame:
     try:
+        new_documents = []
         for file_name in file_names:
             if file_name.lower().endswith(".txt"):
                 # read the file content
@@ -73,10 +74,10 @@ def process_uploaded_files(
                     text = f.read()
                 base_file_name = os.path.basename(file_name)
                 gr.Info(f"Processing file '{base_file_name}' ...")
-                document = create_and_annotate_document(text, base_file_name, models)
-                document_store.add_document(document)
+                new_documents.append(create_and_annotate_document(text, base_file_name, models))
             else:
                 raise gr.Error(f"Unsupported file format: {file_name}")
+        document_store.add_documents(new_documents)
     except Exception as e:
         raise gr.Error(f"Failed to process uploaded files: {e}")

@@ -91,43 +92,6 @@ def close_accordion():
     return gr.Accordion(open=False)


-def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
-    try:
-        model = AutoPipeline.from_pretrained(
-            model_name,
-            device=-1,
-            num_workers=0,
-            taskmodule_kwargs=dict(revision=revision),
-            model_kwargs=dict(revision=revision),
-        )
-    except Exception as e:
-        raise gr.Error(f"Failed to load argumentation model: {e}")
-    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
-    return model
-
-
-def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
-    try:
-        embedding_model = AutoModel.from_pretrained(model_name)
-        embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
-    except Exception as e:
-        raise gr.Error(f"Failed to load embedding model: {e}")
-    gr.Info(f"Loaded embedding model: model_name={model_name})")
-    return embedding_model, embedding_tokenizer
-
-
-def load_models(
-    model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
-) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
-    argumentation_model = load_argumentation_model(model_name, revision)
-    embedding_model = None
-    embedding_tokenizer = None
-    if embedding_model_name is not None and embedding_model_name.strip():
-        embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
-
-    return argumentation_model, embedding_model, embedding_tokenizer
-
-
 def select_processed_document(
     evt: gr.SelectData,
     processed_documents_df: pd.DataFrame,
@@ -135,7 +99,6 @@ def select_processed_document(
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     row_idx, col_idx = evt.index
     doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
-    gr.Info(f"Select document: {doc_id}")
    doc = document_store.get_document(doc_id)
    return doc

@@ -163,8 +126,7 @@ def download_processed_documents(
     file_name: str = "processed_documents.json",
 ) -> str:
     file_path = os.path.join(tempfile.gettempdir(), file_name)
-    with open(file_path, "w", encoding="utf-8") as f:
-        json.dump(document_store.as_dict(), f, indent=2)
+    document_store.save_to_json(file_path, indent=2)
     return file_path


@@ -172,10 +134,7 @@ def upload_processed_documents(
     file_name: str,
     document_store: DocumentStore,
 ) -> pd.DataFrame:
-    with open(file_name, "r", encoding="utf-8") as f:
-        processed_documents_json = json.load(f)
-    for _, document_json in processed_documents_json.items():
-        document_store.add_document_from_dict(document_dict=document_json)
+    document_store.add_from_json(file_name)
     return document_store.overview()

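The download/upload handlers now delegate persistence to the store; a minimal sketch of the resulting export/import round-trip (the file path is illustrative):

from document_store import DocumentStore

store = DocumentStore()
# ... documents get added via process_uploaded_files ...
store.save_to_json("/tmp/processed_documents.json", indent=2)  # download handler

fresh_store = DocumentStore()
fresh_store.add_from_json("/tmp/processed_documents.json")  # upload handler
print(fresh_store.overview())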
document_store.py ADDED
@@ -0,0 +1,218 @@
+import json
+import logging
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+import pandas as pd
+from annotation_utils import labeled_span_to_id
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from vector_store import SimpleVectorStore, VectorStore
+
+logger = logging.getLogger(__name__)
+
+
+def get_annotation_from_document(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    annotation_id: str,
+    annotation_layer: str,
+) -> LabeledSpan:
+    # use predictions
+    annotations = document[annotation_layer].predictions
+    id2annotation = {labeled_span_to_id(annotation): annotation for annotation in annotations}
+    annotation = id2annotation.get(annotation_id)
+    if annotation is None:
+        raise gr.Error(
+            f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
+            f"annotations: {id2annotation}"
+        )
+    return annotation
+
+
+class DocumentStore:
+
+    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+
+    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
+        # The annotated documents. As key, we use the document id. All documents keep the embeddings
+        # of the ADUs in the metadata.
+        self.documents: Dict[
+            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+        ] = {}
+        # The vector store to efficiently retrieve similar ADUs. Can be constructed from the
+        # documents.
+        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
+            vector_store or SimpleVectorStore()
+        )
+
+    def get_annotation(
+        self,
+        doc_id: str,
+        annotation_id: str,
+        annotation_layer: str,
+    ) -> LabeledSpan:
+        document = self.documents.get(doc_id)
+        if document is None:
+            raise gr.Error(
+                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
+            )
+        return get_annotation_from_document(document, annotation_id, annotation_layer)
+
+    def get_similar_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+
+        similar_annotations = [
+            self.get_annotation(
+                doc_id=doc_id,
+                annotation_id=annotation_id,
+                annotation_layer="labeled_spans",
+            )
+            for (doc_id, annotation_id), _ in similar_entries
+        ]
+        df = pd.DataFrame(
+            [
+                # unpack the tuple (doc_id, annotation_id) to separate columns
+                # and add the similarity score and the text of the annotation
+                (doc_id, annotation_id, score, str(annotation))
+                for ((doc_id, annotation_id), score), annotation in zip(
+                    similar_entries, similar_annotations
+                )
+            ],
+            columns=["doc_id", "adu_id", "sim_score", "text"],
+        )
+
+        return df
+
+    def get_relevant_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+        relation_types: List[str],
+        columns: List[str],
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+        result = []
+        for (doc_id, annotation_id), score in similar_entries:
+            # skip entries from the same document
+            if doc_id == ref_document.id:
+                continue
+            document = self.documents[doc_id]
+            tail2rels = defaultdict(list)
+            head2rels = defaultdict(list)
+            for rel in document.binary_relations.predictions:
+                # skip non-argumentative relations
+                if rel.label not in relation_types:
+                    continue
+                head2rels[rel.head].append(rel)
+                tail2rels[rel.tail].append(rel)
+
+            id2annotation = {
+                labeled_span_to_id(annotation): annotation
+                for annotation in document.labeled_spans.predictions
+            }
+            annotation = id2annotation.get(annotation_id)
+            # note: we do not need to check if the annotation is different from the reference annotation,
+            # because they come from different documents and we already skip entries from the same document
+            for rel in head2rels.get(annotation, []):
+                result.append(
+                    {
+                        "doc_id": doc_id,
+                        "reference_adu": str(annotation),
+                        "sim_score": score,
+                        "rel_score": rel.score,
+                        "relation": rel.label,
+                        "adu": str(rel.tail),
+                    }
+                )
+
+        # define column order
+        df = pd.DataFrame(result, columns=columns)
+        return df
+
+    def add_document(
+        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ) -> None:
+        try:
+            if document.id in self.documents:
+                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
+
+            # save the processed document to the index
+            self.documents[document.id] = document
+
+            # save the embeddings to the vector store
+            for adu_id, embedding in document.metadata["embeddings"].items():
+                self.vector_store.save((document.id, adu_id), embedding)
+
+        except Exception as e:
+            raise gr.Error(f"Failed to add document {document.id} to index: {e}")
+
+    def add_document_from_dict(self, document_dict: dict) -> None:
+        document = self.DOCUMENT_TYPE.fromdict(document_dict)
+        # metadata is not automatically deserialized, so we need to set it manually
+        document.metadata = document_dict["metadata"]
+        self.add_document(document)
+
+    def add_documents(
+        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
+    ) -> None:
+        size_before = len(self.documents)
+        for document in documents:
+            self.add_document(document)
+        size_after = len(self.documents)
+        gr.Info(
+            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
+        )
+
+    def add_from_json(self, file_path: str) -> None:
+        size_before = len(self.documents)
+        with open(file_path, "r", encoding="utf-8") as f:
+            processed_documents_json = json.load(f)
+        for _, document_json in processed_documents_json.items():
+            self.add_document_from_dict(document_dict=document_json)
+        size_after = len(self.documents)
+        gr.Info(
+            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
+        )
+
+    def save_to_json(self, file_path: str, **kwargs) -> None:
+        with open(file_path, "w", encoding="utf-8") as f:
+            json.dump(self.as_dict(), f, **kwargs)
+
+    def get_document(
+        self, doc_id: str
+    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+        return self.documents[doc_id]
+
+    def overview(self) -> pd.DataFrame:
+        df = pd.DataFrame(
+            [
+                (
+                    doc_id,
+                    len(document.labeled_spans.predictions),
+                    len(document.binary_relations.predictions),
+                )
+                for doc_id, document in self.documents.items()
+            ],
+            columns=["doc_id", "num_adus", "num_relations"],
+        )
+        return df
+
+    def as_dict(self) -> dict:
+        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
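A minimal sketch of querying the store for similar ADUs (assuming the loaded documents were annotated with an embedding model, so their metadata contains "embeddings"; the path, threshold, and k are illustrative):

from document_store import DocumentStore

store = DocumentStore()
store.add_from_json("processed_documents.json")  # a previously exported file

ref_doc_id = next(iter(store.documents))
ref_doc = store.get_document(ref_doc_id)
# ADU IDs are the labeled_span_to_id encodings used as keys in metadata["embeddings"]
ref_adu_id = next(iter(ref_doc.metadata["embeddings"]))
print(store.get_similar_adus_df(ref_adu_id, ref_doc, min_similarity=0.5, top_k=5))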
model_utils.py ADDED
@@ -0,0 +1,173 @@
+import logging
+from typing import Dict, List, Optional, Tuple
+
+import gradio as gr
+from annotation_utils import labeled_span_to_id
+from pie_modules.document.processing import tokenize_document
+from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from pytorch_ie import Pipeline
+from pytorch_ie.annotations import LabeledSpan
+from pytorch_ie.auto import AutoPipeline
+from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+def _embed_text_annotations(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    text_layer_name: str,
+) -> Dict[LabeledSpan, List[float]]:
+    # to not modify the original document
+    document = document.copy()
+    # tokenize_document does not yet consider predictions, so we need to add them manually
+    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
+    added_annotations = []
+    tokenizer_kwargs = {
+        "max_length": 512,
+        "stride": 64,
+        "truncation": True,
+        "return_overflowing_tokens": True,
+    }
+    tokenized_documents = tokenize_document(
+        document,
+        tokenizer=tokenizer,
+        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        partition_layer="labeled_partitions",
+        added_annotations=added_annotations,
+        strict_span_conversion=False,
+        **tokenizer_kwargs,
+    )
+    # just tokenize again to get tensors in the correct format for the model
+    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
+    # this is added when using return_overflowing_tokens=True, but the model does not accept it
+    model_inputs.pop("overflow_to_sample_mapping", None)
+    assert len(model_inputs.encodings) == len(tokenized_documents)
+    model_output = model(**model_inputs)
+
+    # get embeddings for all text annotations
+    embeddings = {}
+    for batch_idx in range(len(model_output.last_hidden_state)):
+        text2tok_ann = added_annotations[batch_idx][text_layer_name]
+        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
+        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
+            # skip "empty" annotations
+            if tok_ann.start == tok_ann.end:
+                continue
+            # use the max pooling strategy to get a single embedding for the annotation text
+            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
+                dim=0
+            )[0]
+            text_ann = tok2text_ann[tok_ann]
+
+            if text_ann in embeddings:
+                logger.warning(
+                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
+                )
+            embeddings[text_ann] = embedding
+
+    return embeddings
+
+
+def _annotate(
+    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+    pipeline: Pipeline,
+    embedding_model: Optional[PreTrainedModel] = None,
+    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
+) -> None:
+
+    # execute prediction pipeline
+    pipeline(document)
+
+    if embedding_model is not None and embedding_tokenizer is not None:
+        adu_embeddings = _embed_text_annotations(
+            document=document,
+            model=embedding_model,
+            tokenizer=embedding_tokenizer,
+            text_layer_name="labeled_spans",
+        )
+        # convert keys to str because JSON keys must be strings
+        adu_embeddings_dict = {
+            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
+        }
+        document.metadata["embeddings"] = adu_embeddings_dict
+    else:
+        gr.Warning(
+            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
+            "model in the 'Model Configuration' section."
+        )
+
+
+def create_and_annotate_document(
+    text: str,
+    doc_id: str,
+    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
+    text, annotate it, and add it to the index.
+
+    Parameters:
+        text: The text to process.
+        doc_id: The ID of the document.
+        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
+
+    Returns:
+        The processed document.
+    """
+
+    try:
+        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+            id=doc_id, text=text, metadata={}
+        )
+        # add single partition from the whole text (the model only considers text in partitions)
+        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+        # annotate the document
+        _annotate(
+            document=document,
+            pipeline=models[0],
+            embedding_model=models[1],
+            embedding_tokenizer=models[2],
+        )
+
+        return document
+    except Exception as e:
+        raise gr.Error(f"Failed to process text: {e}")
+
+
+def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
+    try:
+        model = AutoPipeline.from_pretrained(
+            model_name,
+            device=-1,
+            num_workers=0,
+            taskmodule_kwargs=dict(revision=revision),
+            model_kwargs=dict(revision=revision),
+        )
+    except Exception as e:
+        raise gr.Error(f"Failed to load argumentation model: {e}")
+    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
+    return model
+
+
+def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+    try:
+        embedding_model = AutoModel.from_pretrained(model_name)
+        embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
+    except Exception as e:
+        raise gr.Error(f"Failed to load embedding model: {e}")
+    gr.Info(f"Loaded embedding model: model_name={model_name})")
+    return embedding_model, embedding_tokenizer
+
+
+def load_models(
+    model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
+) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
+    argumentation_model = load_argumentation_model(model_name, revision)
+    embedding_model = None
+    embedding_tokenizer = None
+    if embedding_model_name is not None and embedding_model_name.strip():
+        embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
+
+    return argumentation_model, embedding_model, embedding_tokenizer
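A minimal sketch of loading the models and annotating a text ("your-org/argumentation-pipeline" is a placeholder for a real pytorch-ie pipeline repo on the hub; "bert-base-uncased" is just one possible encoder, and both require network access to download):

from model_utils import create_and_annotate_document, load_models

models = load_models(
    model_name="your-org/argumentation-pipeline",  # placeholder hub ID
    embedding_model_name="bert-base-uncased",
)
document = create_and_annotate_document(
    text="Argument mining is hard. Therefore, we build tools for it.",
    doc_id="example.txt",
    models=models,
)
print(document.labeled_spans.predictions)     # predicted ADUs
print(document.binary_relations.predictions)  # predicted relations
print(len(document.metadata["embeddings"]))   # one embedding per ADU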
rendering_utils.py CHANGED
@@ -2,7 +2,8 @@ import json
 from collections import defaultdict
 from typing import Dict, List, Optional, Union

-from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span
+from annotation_utils import labeled_span_to_id
+from pytorch_ie.annotations import BinaryRelation
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils_displacy import EntityRenderer

@@ -59,15 +60,6 @@ def render_displacy(
     return html


-def labeled_span_to_id(span: LabeledSpan) -> str:
-    return f"span-{span.start}-{span.end}-{span.label}"
-
-
-def labeled_span_from_id(span_id: str) -> LabeledSpan:
-    parts = span_id.split("-")
-    return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
-
-
 def inject_relation_data(
     html: str,
     sorted_entities,
vector_store.py CHANGED
@@ -2,17 +2,32 @@ import abc
 from typing import Generic, Hashable, List, Optional, Tuple, TypeVar

 T = TypeVar("T", bound=Hashable)
+E = TypeVar("E")


-class VectorStore(Generic[T], abc.ABC):
+class VectorStore(Generic[T, E], abc.ABC):
     @abc.abstractmethod
-    def save(self, emb_id: T, embedding: List[float]) -> None:
+    def save(self, emb_id: T, embedding: E) -> None:
+        """Save an embedding for a given ID."""
         pass

     @abc.abstractmethod
     def retrieve_similar(
         self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
     ) -> List[Tuple[T, float]]:
+        """Retrieve IDs and the respective similarity scores with respect to the reference entry.
+        Note that this requires the reference entry to be present in the store.
+
+        Args:
+            ref_id: The ID of the reference entry.
+            top_k: If provided, only the top-k most similar entries will be returned.
+            min_similarity: If provided, only entries with a similarity score greater or equal to
+                this value will be returned.
+
+        Returns:
+            A list of tuples consisting of the ID and the similarity score, sorted by similarity
+            score in descending order.
+        """
         pass

     @abc.abstractmethod
@@ -28,7 +43,7 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
     return sum(a * b for a, b in zip(a, b)) / (vector_norm(a) * vector_norm(b))


-class SimpleVectorStore(VectorStore[T]):
+class SimpleVectorStore(VectorStore[T, List[float]]):
     def __init__(self):
         self.vectors: dict[T, List[float]] = {}
         self._cache = {}
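A minimal sketch of the widened two-parameter generic in use, keyed by (doc_id, adu_id) tuples as in DocumentStore (toy vectors; assumes SimpleVectorStore implements retrieve_similar as documented in the base class, which the hunk does not show):

from vector_store import SimpleVectorStore

store = SimpleVectorStore()
store.save(("doc1", "span-0-5-claim"), [1.0, 0.0])
store.save(("doc2", "span-3-9-claim"), [0.9, 0.1])
store.save(("doc3", "span-2-8-data"), [0.0, 1.0])

# IDs and cosine similarities relative to the reference entry, best first
print(store.retrieve_similar(ref_id=("doc1", "span-0-5-claim"), top_k=2, min_similarity=0.5))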