ArneBinder committed • Commit fed112f • 1 parent: 38e5624

from https://github.com/ArneBinder/pie-document-level/pull/229

except that we still use `SimpleVectorStore` instead of `QdrantVectorStore`

Files changed:

- app.py +78 -16
- document_store.py +85 -1
- model_utils.py +42 -11
app.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 import json
 import logging
 import os.path
+import re
 import tempfile
 from functools import partial
 from typing import List, Optional, Tuple
@@ -32,6 +33,19 @@ DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEFAULT_EMBEDDING_MAX_LENGTH = 512
 DEFAULT_EMBEDDING_BATCH_SIZE = 32
+DEFAULT_SPLIT_REGEX = "\n\n\n+"
+
+
+def escape_regex(regex: str) -> str:
+    # "double escape" the backslashes
+    result = regex.encode("unicode_escape").decode("utf-8")
+    return result
+
+
+def unescape_regex(regex: str) -> str:
+    # reverse of escape_regex
+    result = regex.encode("utf-8").decode("unicode_escape")
+    return result
 
 
 def render_annotated_document(
@@ -55,9 +69,16 @@ def wrapped_process_text(
     doc_id: str,
     models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
+    split_regex_escaped: str,
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
     try:
-        document = create_document(
+        document = create_document(
+            text=text,
+            doc_id=doc_id,
+            split_regex=unescape_regex(split_regex_escaped)
+            if len(split_regex_escaped) > 0
+            else None,
+        )
         annotate_document(
             document=document,
             annotation_pipeline=models[0],
@@ -77,6 +98,8 @@ def process_uploaded_files(
     file_names: List[str],
     models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
+    split_regex_escaped: str,
+    show_max_cross_doc_sims: bool = False,
 ) -> pd.DataFrame:
     try:
         new_documents = []
@@ -90,6 +113,9 @@ def process_uploaded_files(
             new_document = create_document(
                 text=text,
                 doc_id=base_file_name,
+                split_regex=unescape_regex(split_regex_escaped)
+                if len(split_regex_escaped) > 0
+                else None,
             )
             annotate_document(
                 document=new_document,
@@ -103,7 +129,7 @@ def process_uploaded_files(
     except Exception as e:
         raise gr.Error(f"Failed to process uploaded files: {e}")
 
-    return document_store.overview()
+    return document_store.overview(with_max_cross_doc_sims=show_max_cross_doc_sims)
 
 
 def open_accordion():
@@ -137,7 +163,7 @@ def set_relation_types(
 
    return gr.Dropdown(
        choices=relation_types,
-       label="Relation Types",
+       label="Argumentative Relation Types",
        value=default,
        multiselect=True,
    )
@@ -204,7 +230,7 @@ def main():
        DocumentStore(
            span_annotation_caption="adu",
            relation_annotation_caption="relation",
-           vector_store=
+           vector_store=QdrantVectorStore(),
        )
    )
    # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
@@ -264,6 +290,11 @@ def main():
                ],
                outputs=models_state,
            )
+           split_regex_escaped = gr.Textbox(
+               label="Regex to partition the text",
+               placeholder="Regular expression pattern to split the text into partitions",
+               value=escape_regex(DEFAULT_SPLIT_REGEX),
+           )
 
            predict_btn = gr.Button("Analyse")
 
@@ -289,11 +320,6 @@ def main():
 
            rendered_output = gr.HTML(label="Rendered Output")
 
-           # add_to_index_btn = gr.Button("Add current result to Index")
-           upload_btn = gr.UploadButton(
-               "Upload & Analyse Documents", file_types=["text"], file_count="multiple"
-           )
-
        with gr.Column(scale=1):
            with gr.Accordion(
                "Indexed Documents", open=False
@@ -302,12 +328,22 @@ def main():
                    headers=["id", "num_adus", "num_relations"],
                    interactive=False,
                )
+               show_max_cross_docu_sims = gr.Checkbox(
+                   label="Show max cross-document similarities", value=False
+               )
+               gr.Markdown("Data Snapshot:")
                with gr.Row():
                    download_processed_documents_btn = gr.DownloadButton("Download")
                    upload_processed_documents_btn = gr.UploadButton(
                        "Upload", file_types=["json"]
                    )
 
+               upload_btn = gr.UploadButton(
+                   "Upload & Analyse Reference Documents",
+                   file_types=["text"],
+                   file_count="multiple",
+               )
+
            with gr.Accordion("Selected ADU", open=False):
                selected_adu_id = gr.Textbox(label="ID", elem_id="selected_adu_id")
                selected_adu_text = gr.Textbox(label="Text")
@@ -329,6 +365,14 @@ def main():
                )
                retrieve_similar_adus_btn = gr.Button("Retrieve similar ADUs")
                similar_adus = gr.DataFrame(headers=["doc_id", "adu_id", "score", "text"])
+
+               all2all_adu_similarities_button = gr.Button(
+                   "Compute all ADU-to-ADU similarities"
+               )
+               all2all_adu_similarities = gr.DataFrame(
+                   headers=["sim_score", "doc_id", "other_doc_id", "text", "other_text"]
+               )
+
                relation_types = set_relation_types(
                    models_state.value, default=["supports", "contradicts"]
                )
@@ -353,16 +397,19 @@ def main():
            outputs=rendered_output,
        )
 
+       show_overview_kwargs = dict(
+           fn=lambda document_store, show_max_sims: document_store.overview(
+               with_max_cross_doc_sims=show_max_sims
+           ),
+           inputs=[document_store_state, show_max_cross_docu_sims],
+           outputs=[processed_documents_df],
+       )
        predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
            fn=wrapped_process_text,
-           inputs=[doc_text, doc_id, models_state, document_store_state],
+           inputs=[doc_text, doc_id, models_state, document_store_state, split_regex_escaped],
            outputs=[document_json, document_state],
            api_name="predict",
-       ).success(
-           fn=lambda document_store: document_store.overview(),
-           inputs=[document_store_state],
-           outputs=[processed_documents_df],
-       )
+       ).success(**show_overview_kwargs)
        render_btn.click(**render_event_kwargs, api_name="render")
 
        document_state.change(
@@ -377,7 +424,13 @@ def main():
            fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
        ).then(
            fn=process_uploaded_files,
-           inputs=[
+           inputs=[
+               upload_btn,
+               models_state,
+               document_store_state,
+               split_regex_escaped,
+               show_max_cross_docu_sims,
+           ],
            outputs=[processed_documents_df],
        )
        processed_documents_df.select(
@@ -385,6 +438,7 @@ def main():
            inputs=[processed_documents_df, document_store_state],
            outputs=[document_state],
        )
+       show_max_cross_docu_sims.change(**show_overview_kwargs)
 
        download_processed_documents_btn.click(
            fn=partial(download_processed_documents, file_name="processed_documents.zip"),
@@ -446,6 +500,14 @@ def main():
            inputs=[models_state],
            outputs=[relation_types],
        )
+       all2all_adu_similarities_button.click(
+           fn=partial(
+               DocumentStore.get_all2all_adu_similarities,
+               columns=all2all_adu_similarities.headers,
+           ),
+           inputs=[document_store_state],
+           outputs=[all2all_adu_similarities],
+       )
 
        # retrieve_relevant_adus_btn.click(
        #     **retrieve_relevant_adus_event_kwargs
```
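Note on the split-regex round trip: the new textbox holds the pattern in escaped form (so backslash sequences stay visible and editable), and `unescape_regex` turns it back into the real pattern before `create_document` is called. A minimal, self-contained sketch of that round trip, copying the two helpers from the diff above (only the Python standard library is used; the `pattern` value is the default `DEFAULT_SPLIT_REGEX`):

```python
# Standalone copy of the helpers added in app.py.
def escape_regex(regex: str) -> str:
    # "double escape" the backslashes so the pattern is shown literally in the UI
    return regex.encode("unicode_escape").decode("utf-8")


def unescape_regex(regex: str) -> str:
    # reverse of escape_regex
    return regex.encode("utf-8").decode("unicode_escape")


pattern = "\n\n\n+"                      # DEFAULT_SPLIT_REGEX: three or more newlines
shown = escape_regex(pattern)            # '\\n\\n\\n+' -- what the textbox displays
assert unescape_regex(shown) == pattern  # what wrapped_process_text passes on as split_regex
```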
document_store.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from collections import defaultdict
 from typing import Any, Dict, List, Optional
 
 import gradio as gr
+import numpy as np
 import pandas as pd
 from annotation_utils import labeled_span_to_id
 from pytorch_ie import Annotation
@@ -417,7 +418,7 @@ class DocumentStore:
 
        return document
 
-    def overview(self) -> pd.DataFrame:
+    def overview(self, with_max_cross_doc_sims: bool = False) -> pd.DataFrame:
        rows = []
        for doc_id, document in self.documents.items():
            layers = {
@@ -429,6 +430,38 @@ class DocumentStore:
            layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
            rows.append({"doc_id": doc_id, **layer_sizes})
        df = pd.DataFrame(rows)
+
+        # add highest cross-document similarity score for each document
+        if with_max_cross_doc_sims and len(self.documents) > 1:
+            # Setting min_similarity to None and top_k to None to get all similarities. Otherwise,
+            # it may happen that this occludes max cross-doc sim for some documents in the
+            # case that there are more than top_k ADUs in the reference document that have a higher
+            # similarity with each other than the highest similarity to any ADU in another document
+            # or if the cross-doc similarity is below the min_similarity threshold.
+            all2all_adu_similarities = self.get_all2all_adu_similarities(
+                min_similarity=None, top_k=None, columns=["doc_id", "other_doc_id", "sim_score"]
+            )
+            max_doc2doc_similarities = all2all_adu_similarities.pivot_table(
+                values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
+            )
+
+            max_doc2doc_similarities.sort_index(axis="index", inplace=True)
+            max_doc2doc_similarities.sort_index(axis="columns", inplace=True)
+            # check that the index and columns are the same
+            if (max_doc2doc_similarities.index != max_doc2doc_similarities.columns).any():
+                raise gr.Error("Index and columns of max_doc2doc_similarities are not the same.")
+            # set diagonal entries to minus infinity to exclude them from the maximum
+            np.fill_diagonal(max_doc2doc_similarities.values, -np.inf)
+
+            max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
+            max_similarities = max_doc2doc_similarities.max(axis="columns")
+
+            # set the index to the doc_id to correctly join the series
+            df.set_index("doc_id", inplace=True)
+            df["max_cross_doc_sim_doc_id"] = max_doc_ids
+            df["max_cross_doc_sim_score"] = max_similarities
+            df.reset_index(inplace=True)
+
        return df
 
    def as_dict(self, include_embeddings: bool = True) -> dict:
@@ -441,3 +474,54 @@
            }
            result[doc_id] = doc_dict
        return result
+
+    def get_all2all_adu_similarities(
+        self,
+        min_similarity: Optional[float] = 0.5,
+        top_k: Optional[int] = 100,
+        columns: Optional[List[str]] = None,
+    ) -> pd.DataFrame:
+        """Get the similarities between all ADUs in the store.
+
+        Args:
+            min_similarity: The minimum similarity score to consider.
+            top_k: The number of similar ADUs to return.
+            columns: The columns to include in the result DataFrame. If None, all columns are included.
+
+        Returns:
+            A DataFrame with the columns: doc_id, text, other_doc_id, other_text, sim_score.
+        """
+        result = []
+        document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+        for doc_id, document in self.documents.items():
+            for adu in document.labeled_spans.predictions:
+                adu_id = labeled_span_to_id(adu)
+                similar_entries = self.vector_store.retrieve_similar(
+                    ref_payload=self.construct_embedding_payload(document, adu_id),
+                    min_similarity=min_similarity,
+                    top_k=top_k,
+                )
+                for _, payload, score in similar_entries:
+                    other_doc_id = payload["doc_id"]
+                    other_document = self.documents[other_doc_id]
+                    other_adu = get_annotation_from_document(
+                        other_document,
+                        payload["annotation_id"],
+                        self.span_layer_name,
+                        use_predictions=self.use_predictions,
+                    )
+                    result.append(
+                        {
+                            "sim_score": score,
+                            "doc_id": doc_id,
+                            "other_doc_id": other_doc_id,
+                            "adu_id": adu_id,
+                            "other_adu_id": payload["annotation_id"],
+                            "text": str(adu),
+                            "other_text": str(other_adu),
+                        }
+                    )
+        result_df = pd.DataFrame(result)
+        if columns is not None:
+            result_df = result_df[columns]
+        return result_df
```
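To see what the new `overview(with_max_cross_doc_sims=True)` aggregation does with the pairwise similarities from `get_all2all_adu_similarities`, here is a small sketch on made-up scores (the doc ids and values are invented; only the pandas/numpy steps mirror the code above):

```python
import numpy as np
import pandas as pd

# Invented pairwise ADU similarities between two documents "a" and "b".
sims = pd.DataFrame(
    [
        {"doc_id": "a", "other_doc_id": "a", "sim_score": 0.99},  # within-document pair
        {"doc_id": "a", "other_doc_id": "b", "sim_score": 0.40},
        {"doc_id": "b", "other_doc_id": "a", "sim_score": 0.40},
        {"doc_id": "b", "other_doc_id": "b", "sim_score": 0.95},  # within-document pair
    ]
)

# Max similarity per document pair, as in overview().
max_doc2doc = sims.pivot_table(
    values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
)
max_doc2doc = max_doc2doc.sort_index(axis="index").sort_index(axis="columns")
# Exclude within-document similarities from the maximum.
np.fill_diagonal(max_doc2doc.values, -np.inf)

print(max_doc2doc.idxmax(axis="columns"))  # a -> b, b -> a (most similar other document)
print(max_doc2doc.max(axis="columns"))     # 0.4 for both (max cross-document score)
```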
model_utils.py
CHANGED
```diff
@@ -5,6 +5,7 @@ import gradio as gr
 import torch
 from annotation_utils import labeled_span_to_id
 from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
+from pie_modules.document.processing import RegexPartitioner
 from pytorch_ie import Pipeline
 from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.auto import AutoPipeline
@@ -48,7 +49,7 @@ def annotate_document(
 
 
 def create_document(
-    text: str, doc_id: str
+    text: str, doc_id: str, split_regex: Optional[str] = None
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
     text.
@@ -56,6 +57,7 @@ def create_document(
     Parameters:
         text: The text to process.
         doc_id: The ID of the document.
+        split_regex: A regular expression pattern to use for splitting the text into partitions.
 
     Returns:
         The processed document.
@@ -64,8 +66,14 @@ def create_document(
     document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
        id=doc_id, text=text, metadata={}
     )
-
-
+    if split_regex is not None:
+        partitioner = RegexPartitioner(
+            pattern=split_regex, partition_layer_name="labeled_partitions"
+        )
+        document = partitioner(document)
+    else:
+        # add single partition from the whole text (the model only considers text in partitions)
+        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
     return document
 
 
@@ -92,24 +100,22 @@ def load_argumentation_model(
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
+        gr.Info(
+            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
+        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}")
-
+
    return model
 
 
-def load_models(
-    model_name: str,
-    revision: Optional[str] = None,
+def load_embedding_model(
     embedding_model_name: Optional[str] = None,
     # embedding_model_revision: Optional[str] = None,
     embedding_max_length: int = 512,
     embedding_batch_size: int = 16,
     device: str = "cpu",
-) ->
-    torch.cuda.empty_cache()
-    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
-    embedding_model = None
+) -> Optional[EmbeddingModel]:
    if embedding_model_name is not None and embedding_model_name.strip():
        try:
            embedding_model = HuggingfaceEmbeddingModel(
@@ -119,7 +125,32 @@ def load_models(
                max_length=embedding_max_length,
                batch_size=embedding_batch_size,
            )
+            gr.Info(f"Loaded embedding model: model_name={embedding_model_name}, device={device}")
        except Exception as e:
            raise gr.Error(f"Failed to load embedding model: {e}")
+    else:
+        embedding_model = None
+
+    return embedding_model
+
+
+def load_models(
+    model_name: str,
+    revision: Optional[str] = None,
+    embedding_model_name: Optional[str] = None,
+    # embedding_model_revision: Optional[str] = None,
+    embedding_max_length: int = 512,
+    embedding_batch_size: int = 16,
+    device: str = "cpu",
+) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
+    torch.cuda.empty_cache()
+    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
+    embedding_model = load_embedding_model(
+        embedding_model_name=embedding_model_name,
+        # embedding_model_revision=embedding_model_revision,
+        embedding_max_length=embedding_max_length,
+        embedding_batch_size=embedding_batch_size,
+        device=device,
+    )
 
    return argumentation_model, embedding_model
```
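A hypothetical usage of the updated `create_document` (assuming the Space's modules and `pie_modules` are importable): with a `split_regex`, the text is split into `labeled_partitions` via `RegexPartitioner`; without one, a single partition spanning the whole text is added, since the model only considers text inside partitions.

```python
# Hypothetical call against the updated create_document from model_utils.py.
from model_utils import create_document

doc = create_document(
    text="First part of the text.\n\n\n\nSecond part of the text.",
    doc_id="example-doc",
    split_regex="\n\n\n+",  # same pattern as DEFAULT_SPLIT_REGEX in app.py
)
# With the regex, RegexPartitioner fills the labeled_partitions layer;
# with split_regex=None a single whole-text partition labeled "text" is appended instead.
print(list(doc.labeled_partitions))
```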