Commit 04ce9af
Parent(s): b77f1d0
Upload 7 files

Files changed:
- app.py (+30, -78)
- backend.py (+160, -140)
app.py CHANGED

@@ -3,18 +3,17 @@ import logging
 import os.path
 import tempfile
 from functools import partial
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import gradio as gr
 import pandas as pd
-from backend import get_relevant_adus, get_similar_adus, process_text
+from backend import DocumentStore, create_and_annotate_document, get_annotation_from_document
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
 from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import render_displacy, render_pretty_table
 from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
-from vector_store import SimpleVectorStore, VectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -49,18 +48,14 @@ def wrapped_process_text(
     text: str,
     doc_id: str,
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
+    document_store: DocumentStore,
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
-    document = process_text(
+    document = create_and_annotate_document(
         text=text,
         doc_id=doc_id,
         models=models,
-        processed_documents=processed_documents,
-        vector_store=vector_store,
     )
+    document_store.add_document(document)
     # Return as dict and document to avoid serialization issues
     return document.asdict(), document
 
@@ -68,10 +63,7 @@ def wrapped_process_text(
 def process_uploaded_files(
     file_names: List[str],
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
+    document_store: DocumentStore,
 ) -> pd.DataFrame:
     try:
         for file_name in file_names:
@@ -81,13 +73,14 @@ def process_uploaded_files(
                     text = f.read()
                 base_file_name = os.path.basename(file_name)
                 gr.Info(f"Processing file '{base_file_name}' ...")
-                process_text(text, base_file_name, models, processed_documents, vector_store)
+                document = create_and_annotate_document(text, base_file_name, models)
+                document_store.add_document(document)
             else:
                 raise gr.Error(f"Unsupported file format: {file_name}")
     except Exception as e:
         raise gr.Error(f"Failed to process uploaded files: {e}")
 
-    return update_processed_documents_df(processed_documents)
+    return document_store.overview()
 
 
 def open_accordion():
@@ -135,34 +128,15 @@ def load_models(
     return argumentation_model, embedding_model, embedding_tokenizer
 
 
-def update_processed_documents_df(
-    processed_documents: dict[str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
-) -> pd.DataFrame:
-    df = pd.DataFrame(
-        [
-            (
-                doc_id,
-                len(document.labeled_spans.predictions),
-                len(document.binary_relations.predictions),
-            )
-            for doc_id, document in processed_documents.items()
-        ],
-        columns=["doc_id", "num_adus", "num_relations"],
-    )
-    return df
-
-
 def select_processed_document(
     evt: gr.SelectData,
     processed_documents_df: pd.DataFrame,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
+    document_store: DocumentStore,
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     row_idx, col_idx = evt.index
     doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
     gr.Info(f"Select document: {doc_id}")
-    doc = processed_documents[doc_id]
+    doc = document_store.get_document(doc_id)
     return doc
 
 
@@ -185,38 +159,24 @@ def set_relation_types(
 
 
 def download_processed_documents(
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
+    document_store: DocumentStore,
     file_name: str = "processed_documents.json",
 ) -> str:
-    processed_documents_json = {
-        doc_id: document.asdict() for doc_id, document in processed_documents.items()
-    }
     file_path = os.path.join(tempfile.gettempdir(), file_name)
     with open(file_path, "w", encoding="utf-8") as f:
-        json.dump(processed_documents_json, f, indent=2)
+        json.dump(document_store.as_dict(), f, indent=2)
     return file_path
 
 
 def upload_processed_documents(
     file_name: str,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-) -> Dict[str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
+    document_store: DocumentStore,
+) -> pd.DataFrame:
     with open(file_name, "r", encoding="utf-8") as f:
         processed_documents_json = json.load(f)
-    for doc_id, document_json in processed_documents_json.items():
-        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
-            document_json
-        )
-        # metadata is not automatically deserialized, so we need to set it manually
-        document.metadata["embeddings"] = document_json["metadata"]["embeddings"]
-        if doc_id in processed_documents:
-            gr.Warning(f"Document '{doc_id}' already exists. Overwriting.")
-        processed_documents[doc_id] = document
-    return processed_documents
+    for _, document_json in processed_documents_json.items():
+        document_store.add_document_from_dict(document_dict=document_json)
+    return document_store.overview()
 
 
 def main():
@@ -256,8 +216,7 @@ def main():
     }
 
     with gr.Blocks() as demo:
-        processed_documents_state = gr.State(dict())
-        vector_store_state = gr.State(SimpleVectorStore())
+        document_store_state = gr.State(DocumentStore())
         # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
        models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
        with gr.Row():
@@ -381,12 +340,12 @@ def main():
 
        predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
            fn=wrapped_process_text,
-            inputs=[doc_text, doc_id, models_state, processed_documents_state, vector_store_state],
+            inputs=[doc_text, doc_id, models_state, document_store_state],
            outputs=[document_json, document_state],
            api_name="predict",
        ).success(
-            fn=update_processed_documents_df,
-            inputs=[processed_documents_state],
+            fn=lambda document_store: document_store.overview(),
+            inputs=[document_store_state],
            outputs=[processed_documents_df],
        )
        render_btn.click(**render_event_kwargs, api_name="render")
@@ -403,41 +362,35 @@ def main():
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=process_uploaded_files,
-            inputs=[upload_btn, models_state, processed_documents_state, vector_store_state],
+            inputs=[upload_btn, models_state, document_store_state],
            outputs=[processed_documents_df],
        )
        processed_documents_df.select(
            select_processed_document,
-            inputs=[processed_documents_df, processed_documents_state],
+            inputs=[processed_documents_df, document_store_state],
            outputs=[document_state],
        )
 
        download_processed_documents_btn.click(
            fn=download_processed_documents,
-            inputs=[processed_documents_state],
+            inputs=[document_store_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=upload_processed_documents,
-            inputs=[upload_processed_documents_btn, processed_documents_state],
-            outputs=[processed_documents_state],
-        ).success(
-            fn=update_processed_documents_df,
-            inputs=[processed_documents_state],
+            inputs=[upload_processed_documents_btn, document_store_state],
            outputs=[processed_documents_df],
        )
 
        retrieve_relevant_adus_event_kwargs = dict(
-            fn=get_relevant_adus,
+            fn=partial(DocumentStore.get_relevant_adus_df, columns=relevant_adus.headers),
            inputs=[
+                document_store_state,
                selected_adu_id,
                document_state,
-                vector_store_state,
-                processed_documents_state,
                min_similarity,
                top_k,
                relation_types,
-                relevant_adus,
            ],
            outputs=[relevant_adus],
        )
@@ -449,12 +402,11 @@ def main():
        ).success(**retrieve_relevant_adus_event_kwargs)
 
        retrieve_similar_adus_btn.click(
-            fn=get_similar_adus,
+            fn=DocumentStore.get_similar_adus_df,
            inputs=[
+                document_store_state,
                selected_adu_id,
                document_state,
-                vector_store_state,
-                processed_documents_state,
                min_similarity,
                top_k,
            ],
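
Note on the rewiring above: it relies on Gradio's gr.State holding a mutable per-session object that event handlers receive as an input and update in place; document_store_state is never listed as an output. A minimal, self-contained sketch of that pattern, assuming nothing beyond stock Gradio (TinyStore and add_item are illustrative names, not code from this commit):

    import gradio as gr

    class TinyStore:
        """Illustrative stand-in for DocumentStore."""

        def __init__(self):
            self.items = {}

    def add_item(key: str, value: str, store: TinyStore) -> str:
        # the same TinyStore instance arrives via gr.State on every event within
        # a session, so in-place mutation persists across handler calls
        store.items[key] = value
        return f"store now holds {len(store.items)} item(s)"

    with gr.Blocks() as demo:
        store_state = gr.State(TinyStore())
        key = gr.Textbox(label="key")
        value = gr.Textbox(label="value")
        status = gr.Textbox(label="status")
        add_btn = gr.Button("Add")
        # store_state is an input only; no output is needed to keep the mutation
        add_btn.click(fn=add_item, inputs=[key, value, store_state], outputs=[status])

    if __name__ == "__main__":
        demo.launch()

This is also why the commit can pass unbound methods such as DocumentStore.get_similar_adus_df as event handlers: the store state supplies the self argument.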
backend.py CHANGED

@@ -11,12 +11,12 @@ from pytorch_ie.annotations import LabeledSpan, Span
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import labeled_span_to_id
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from vector_store import VectorStore
+from vector_store import SimpleVectorStore, VectorStore
 
 logger = logging.getLogger(__name__)
 
 
-def embed_text_annotations(
+def _embed_text_annotations(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
@@ -73,7 +73,7 @@ def embed_text_annotations(
     return embeddings
 
 
-def annotate(
+def _annotate(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
     pipeline: Pipeline,
     embedding_model: Optional[PreTrainedModel] = None,
@@ -84,7 +84,7 @@ def annotate(
     pipeline(document)
 
     if embedding_model is not None and embedding_tokenizer is not None:
-        adu_embeddings = embed_text_annotations(
+        adu_embeddings = _embed_text_annotations(
             document=document,
             model=embedding_model,
             tokenizer=embedding_tokenizer,
@@ -102,38 +102,10 @@ def annotate(
         )
 
 
-def add_to_index(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    processed_documents: dict,
-    vector_store: VectorStore[Tuple[str, str]],
-) -> None:
-    try:
-        if document.id in processed_documents:
-            gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
-
-        # save the processed document to the index
-        processed_documents[document.id] = document
-
-        # save the embeddings to the vector store
-        for adu_id, embedding in document.metadata["embeddings"].items():
-            vector_store.save((document.id, adu_id), embedding)
-
-        gr.Info(
-            f"Added document {document.id} to index (index contains {len(processed_documents)} "
-            f"documents and {len(vector_store)} embeddings)."
-        )
-    except Exception as e:
-        raise gr.Error(f"Failed to add document {document.id} to index: {e}")
-
-
-def process_text(
+def create_and_annotate_document(
     text: str,
     doc_id: str,
     models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    vector_store: VectorStore[Tuple[str, str]],
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
     text, annotate it, and add it to the index.
@@ -142,8 +114,6 @@ def process_text(
         text: The text to process.
         doc_id: The ID of the document.
         models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
-        processed_documents: The index of processed documents.
-        vector_store: The vector store to save the embeddings.
 
     Returns:
         The processed document.
@@ -156,14 +126,12 @@ def process_text(
         # add single partition from the whole text (the model only considers text in partitions)
        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
        # annotate the document
-        annotate(
+        _annotate(
            document=document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )
-        # add the document to the index
-        add_to_index(document, processed_documents, vector_store)
 
        return document
    except Exception as e:
@@ -187,113 +155,165 @@ def get_annotation_from_document(
     return annotation
 
 
-def get_annotation(
-    doc_id: str,
-    annotation_id: str,
-    annotation_layer: str,
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-) -> LabeledSpan:
-    document = processed_documents.get(doc_id)
-    if document is None:
-        raise gr.Error(
-            f"Document '{doc_id}' not found in index. Available documents: {list(processed_documents)}"
-        )
-    return get_annotation_from_document(document, annotation_id, annotation_layer)
-
-
-def get_similar_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: VectorStore[Tuple[str, str]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-    top_k: int,
-) -> pd.DataFrame:
-    similar_entries = vector_store.retrieve_similar(
-        ref_id=(ref_document.id, ref_annotation_id),
-        min_similarity=min_similarity,
-        top_k=top_k,
-    )
-
-    similar_annotations = [
-        get_annotation(
-            doc_id=doc_id,
-            annotation_id=annotation_id,
-            annotation_layer="labeled_spans",
-            processed_documents=processed_documents,
-        )
-        for (doc_id, annotation_id), _ in similar_entries
-    ]
-    df = pd.DataFrame(
-        [
-            # unpack the tuple (doc_id, annotation_id) to separate columns
-            # and add the similarity score and the text of the annotation
-            (doc_id, annotation_id, score, str(annotation))
-            for ((doc_id, annotation_id), score), annotation in zip(
-                similar_entries, similar_annotations
-            )
-        ],
-        columns=["doc_id", "adu_id", "sim_score", "text"],
-    )
-
-    return df
-
-
-def get_relevant_adus(
-    ref_annotation_id: str,
-    ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    vector_store: VectorStore[Tuple[str, str]],
-    processed_documents: dict[
-        str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-    ],
-    min_similarity: float,
-    top_k: int,
-    relation_types: List[str],
-    relevant_adus: pd.DataFrame,
-) -> pd.DataFrame:
-    similar_entries = vector_store.retrieve_similar(
-        ref_id=(ref_document.id, ref_annotation_id),
-        min_similarity=min_similarity,
-        top_k=top_k,
-    )
-    result = []
-    for (doc_id, annotation_id), score in similar_entries:
-        # skip entries from the same document
-        if doc_id == ref_document.id:
-            continue
-        document = processed_documents[doc_id]
-        tail2rels = defaultdict(list)
-        head2rels = defaultdict(list)
-        for rel in document.binary_relations.predictions:
-            # skip non-argumentative relations
-            if rel.label not in relation_types:
-                continue
-            head2rels[rel.head].append(rel)
-            tail2rels[rel.tail].append(rel)
-
-        id2annotation = {
-            labeled_span_to_id(annotation): annotation
-            for annotation in document.labeled_spans.predictions
-        }
-        annotation = id2annotation.get(annotation_id)
-        # note: we do not need to check if the annotation is different from the reference annotation,
-        # because they come from different documents and we already skip entries from the same document
-        for rel in head2rels.get(annotation, []):
-            result.append(
-                {
-                    "doc_id": doc_id,
-                    "reference_adu": str(annotation),
-                    "sim_score": score,
-                    "rel_score": rel.score,
-                    "relation": rel.label,
-                    "adu": str(rel.tail),
-                }
-            )
-
-    # define column order
-    df = pd.DataFrame(result, columns=relevant_adus.columns)
-    return df
+class DocumentStore:
+
+    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+
+    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str]]] = None):
+        self.documents = {}
+        self.vector_store = vector_store or SimpleVectorStore()
+
+    def get_annotation(
+        self,
+        doc_id: str,
+        annotation_id: str,
+        annotation_layer: str,
+    ) -> LabeledSpan:
+        document = self.documents.get(doc_id)
+        if document is None:
+            raise gr.Error(
+                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
+            )
+        return get_annotation_from_document(document, annotation_id, annotation_layer)
+
+    def get_similar_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+
+        similar_annotations = [
+            self.get_annotation(
+                doc_id=doc_id,
+                annotation_id=annotation_id,
+                annotation_layer="labeled_spans",
+            )
+            for (doc_id, annotation_id), _ in similar_entries
+        ]
+        df = pd.DataFrame(
+            [
+                # unpack the tuple (doc_id, annotation_id) to separate columns
+                # and add the similarity score and the text of the annotation
+                (doc_id, annotation_id, score, str(annotation))
+                for ((doc_id, annotation_id), score), annotation in zip(
+                    similar_entries, similar_annotations
+                )
+            ],
+            columns=["doc_id", "adu_id", "sim_score", "text"],
+        )
+
+        return df
+
+    def get_relevant_adus_df(
+        self,
+        ref_annotation_id: str,
+        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+        min_similarity: float,
+        top_k: int,
+        relation_types: List[str],
+        columns: List[str],
+    ) -> pd.DataFrame:
+        similar_entries = self.vector_store.retrieve_similar(
+            ref_id=(ref_document.id, ref_annotation_id),
+            min_similarity=min_similarity,
+            top_k=top_k,
+        )
+        result = []
+        for (doc_id, annotation_id), score in similar_entries:
+            # skip entries from the same document
+            if doc_id == ref_document.id:
+                continue
+            document = self.documents[doc_id]
+            tail2rels = defaultdict(list)
+            head2rels = defaultdict(list)
+            for rel in document.binary_relations.predictions:
+                # skip non-argumentative relations
+                if rel.label not in relation_types:
+                    continue
+                head2rels[rel.head].append(rel)
+                tail2rels[rel.tail].append(rel)
+
+            id2annotation = {
+                labeled_span_to_id(annotation): annotation
+                for annotation in document.labeled_spans.predictions
+            }
+            annotation = id2annotation.get(annotation_id)
+            # note: we do not need to check if the annotation is different from the reference annotation,
+            # because they come from different documents and we already skip entries from the same document
+            for rel in head2rels.get(annotation, []):
+                result.append(
+                    {
+                        "doc_id": doc_id,
+                        "reference_adu": str(annotation),
+                        "sim_score": score,
+                        "rel_score": rel.score,
+                        "relation": rel.label,
+                        "adu": str(rel.tail),
+                    }
+                )
+
+        # define column order
+        df = pd.DataFrame(result, columns=columns)
+        return df
+
+    def add_document(
+        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+    ) -> None:
+        try:
+            if document.id in self.documents:
+                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
+
+            # save the processed document to the index
+            self.documents[document.id] = document
+
+            # save the embeddings to the vector store
+            for adu_id, embedding in document.metadata["embeddings"].items():
+                self.vector_store.save((document.id, adu_id), embedding)
+
+            gr.Info(
+                f"Added document {document.id} to index (index contains {len(self.documents)} "
+                f"documents and {len(self.vector_store)} embeddings)."
+            )
+        except Exception as e:
+            raise gr.Error(f"Failed to add document {document.id} to index: {e}")
+
+    def add_document_from_dict(self, document_dict: dict) -> None:
+        document = self.DOCUMENT_TYPE.fromdict(document_dict)
+        # metadata is not automatically deserialized, so we need to set it manually
+        document.metadata = document_dict["metadata"]
+        self.add_document(document)
+
+    def add_documents(
+        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
+    ) -> None:
+        for document in documents:
+            self.add_document(document)
+
+    def get_document(
+        self, doc_id: str
+    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
+        return self.documents[doc_id]
+
+    def overview(self) -> pd.DataFrame:
+        df = pd.DataFrame(
+            [
+                (
+                    doc_id,
+                    len(document.labeled_spans.predictions),
+                    len(document.binary_relations.predictions),
+                )
+                for doc_id, document in self.documents.items()
+            ],
+            columns=["doc_id", "num_adus", "num_relations"],
+        )
+        return df
+
+    def as_dict(self) -> dict:
+        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
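
Note on the vector store: DocumentStore only assumes a small contract from it, visible in the diff above: save(id, embedding), retrieve_similar(ref_id, min_similarity, top_k) returning (id, score) pairs, and __len__. A minimal, self-contained sketch of a store satisfying that contract with cosine similarity; ToyVectorStore is illustrative only and is not the repo's SimpleVectorStore:

    from typing import Dict, Hashable, List, Sequence, Tuple

    import numpy as np

    class ToyVectorStore:
        """Cosine-similarity store keyed by arbitrary hashable IDs,
        e.g. (doc_id, adu_id) tuples as used above."""

        def __init__(self) -> None:
            self._vectors: Dict[Hashable, np.ndarray] = {}

        def __len__(self) -> int:
            return len(self._vectors)

        def save(self, entry_id: Hashable, embedding: Sequence[float]) -> None:
            vec = np.asarray(embedding, dtype=np.float32)
            # normalize once so retrieval is a plain dot product
            self._vectors[entry_id] = vec / np.linalg.norm(vec)

        def retrieve_similar(
            self, ref_id: Hashable, min_similarity: float, top_k: int
        ) -> List[Tuple[Hashable, float]]:
            ref = self._vectors[ref_id]
            scored = [
                (other_id, float(ref @ other))
                for other_id, other in self._vectors.items()
                if other_id != ref_id
            ]
            # drop weak matches, then return the top_k best
            scored = [(i, s) for i, s in scored if s >= min_similarity]
            scored.sort(key=lambda pair: pair[1], reverse=True)
            return scored[:top_k]

    store = ToyVectorStore()
    store.save(("doc1", "adu1"), [1.0, 0.0])
    store.save(("doc2", "adu7"), [0.9, 0.1])
    store.save(("doc3", "adu2"), [0.0, 1.0])
    # -> [(("doc2", "adu7"), ~0.99)]; ("doc3", "adu2") is filtered by min_similarity
    print(store.retrieve_similar(("doc1", "adu1"), min_similarity=0.5, top_k=3))

Because the contract is this narrow, DocumentStore.__init__ can accept any VectorStore[Tuple[str, str]] and only falls back to SimpleVectorStore when none is given.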