Commit fed112f by ArneBinder
Parent: 38e5624

from https://github.com/ArneBinder/pie-document-level/pull/229,
except that we still use `SimpleVectorStore` instead of `QdrantVectorStore`

Files changed (3):
  1. app.py +78 -16
  2. document_store.py +85 -1
  3. model_utils.py +42 -11
app.py CHANGED

@@ -1,6 +1,7 @@
 import json
 import logging
 import os.path
+import re
 import tempfile
 from functools import partial
 from typing import List, Optional, Tuple
@@ -32,6 +33,19 @@ DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEFAULT_EMBEDDING_MAX_LENGTH = 512
 DEFAULT_EMBEDDING_BATCH_SIZE = 32
+DEFAULT_SPLIT_REGEX = "\n\n\n+"
+
+
+def escape_regex(regex: str) -> str:
+    # "double escape" the backslashes
+    result = regex.encode("unicode_escape").decode("utf-8")
+    return result
+
+
+def unescape_regex(regex: str) -> str:
+    # reverse of escape_regex
+    result = regex.encode("utf-8").decode("unicode_escape")
+    return result
 
 
 def render_annotated_document(
@@ -55,9 +69,16 @@ def wrapped_process_text(
     doc_id: str,
     models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
+    split_regex_escaped: str,
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
     try:
-        document = create_document(text=text, doc_id=doc_id)
+        document = create_document(
+            text=text,
+            doc_id=doc_id,
+            split_regex=unescape_regex(split_regex_escaped)
+            if len(split_regex_escaped) > 0
+            else None,
+        )
         annotate_document(
             document=document,
             annotation_pipeline=models[0],
@@ -77,6 +98,8 @@ def process_uploaded_files(
     file_names: List[str],
     models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
+    split_regex_escaped: str,
+    show_max_cross_doc_sims: bool = False,
 ) -> pd.DataFrame:
     try:
         new_documents = []
@@ -90,6 +113,9 @@ def process_uploaded_files(
             new_document = create_document(
                 text=text,
                 doc_id=base_file_name,
+                split_regex=unescape_regex(split_regex_escaped)
+                if len(split_regex_escaped) > 0
+                else None,
             )
             annotate_document(
                 document=new_document,
@@ -103,7 +129,7 @@ def process_uploaded_files(
     except Exception as e:
         raise gr.Error(f"Failed to process uploaded files: {e}")
 
-    return document_store.overview()
+    return document_store.overview(with_max_cross_doc_sims=show_max_cross_doc_sims)
 
 
 def open_accordion():
@@ -137,7 +163,7 @@ def set_relation_types(
 
     return gr.Dropdown(
         choices=relation_types,
-        label="Relation Types",
+        label="Argumentative Relation Types",
         value=default,
         multiselect=True,
     )
@@ -204,7 +230,7 @@ def main():
         DocumentStore(
             span_annotation_caption="adu",
             relation_annotation_caption="relation",
-            vector_store=SimpleVectorStore(),
+            vector_store=QdrantVectorStore(),
         )
     )
     # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
@@ -264,6 +290,11 @@ def main():
                 ],
                 outputs=models_state,
             )
+            split_regex_escaped = gr.Textbox(
+                label="Regex to partition the text",
+                placeholder="Regular expression pattern to split the text into partitions",
+                value=escape_regex(DEFAULT_SPLIT_REGEX),
+            )
 
             predict_btn = gr.Button("Analyse")
 
@@ -289,11 +320,6 @@ def main():
 
             rendered_output = gr.HTML(label="Rendered Output")
 
-            # add_to_index_btn = gr.Button("Add current result to Index")
-            upload_btn = gr.UploadButton(
-                "Upload & Analyse Documents", file_types=["text"], file_count="multiple"
-            )
-
         with gr.Column(scale=1):
             with gr.Accordion(
                 "Indexed Documents", open=False
@@ -302,12 +328,22 @@ def main():
                     headers=["id", "num_adus", "num_relations"],
                     interactive=False,
                 )
+                show_max_cross_docu_sims = gr.Checkbox(
+                    label="Show max cross-document similarities", value=False
+                )
+                gr.Markdown("Data Snapshot:")
                 with gr.Row():
                     download_processed_documents_btn = gr.DownloadButton("Download")
                     upload_processed_documents_btn = gr.UploadButton(
                         "Upload", file_types=["json"]
                     )
 
+                upload_btn = gr.UploadButton(
+                    "Upload & Analyse Reference Documents",
+                    file_types=["text"],
+                    file_count="multiple",
+                )
+
             with gr.Accordion("Selected ADU", open=False):
                 selected_adu_id = gr.Textbox(label="ID", elem_id="selected_adu_id")
                 selected_adu_text = gr.Textbox(label="Text")
@@ -329,6 +365,14 @@ def main():
                 )
                 retrieve_similar_adus_btn = gr.Button("Retrieve similar ADUs")
                 similar_adus = gr.DataFrame(headers=["doc_id", "adu_id", "score", "text"])
+
+                all2all_adu_similarities_button = gr.Button(
+                    "Compute all ADU-to-ADU similarities"
+                )
+                all2all_adu_similarities = gr.DataFrame(
+                    headers=["sim_score", "doc_id", "other_doc_id", "text", "other_text"]
+                )
+
                 relation_types = set_relation_types(
                     models_state.value, default=["supports", "contradicts"]
                 )
@@ -353,16 +397,19 @@ def main():
             outputs=rendered_output,
         )
 
+        show_overview_kwargs = dict(
+            fn=lambda document_store, show_max_sims: document_store.overview(
+                with_max_cross_doc_sims=show_max_sims
+            ),
+            inputs=[document_store_state, show_max_cross_docu_sims],
+            outputs=[processed_documents_df],
+        )
         predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
             fn=wrapped_process_text,
-            inputs=[doc_text, doc_id, models_state, document_store_state],
+            inputs=[doc_text, doc_id, models_state, document_store_state, split_regex_escaped],
            outputs=[document_json, document_state],
             api_name="predict",
-        ).success(
-            fn=lambda document_store: document_store.overview(),
-            inputs=[document_store_state],
-            outputs=[processed_documents_df],
-        )
+        ).success(**show_overview_kwargs)
         render_btn.click(**render_event_kwargs, api_name="render")
 
         document_state.change(
@@ -377,7 +424,13 @@ def main():
             fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]
         ).then(
             fn=process_uploaded_files,
-            inputs=[upload_btn, models_state, document_store_state],
+            inputs=[
+                upload_btn,
+                models_state,
+                document_store_state,
+                split_regex_escaped,
+                show_max_cross_docu_sims,
+            ],
             outputs=[processed_documents_df],
         )
         processed_documents_df.select(
@@ -385,6 +438,7 @@ def main():
             inputs=[processed_documents_df, document_store_state],
             outputs=[document_state],
         )
+        show_max_cross_docu_sims.change(**show_overview_kwargs)
 
         download_processed_documents_btn.click(
             fn=partial(download_processed_documents, file_name="processed_documents.zip"),
@@ -446,6 +500,14 @@ def main():
             inputs=[models_state],
             outputs=[relation_types],
         )
+        all2all_adu_similarities_button.click(
+            fn=partial(
+                DocumentStore.get_all2all_adu_similarities,
+                columns=all2all_adu_similarities.headers,
+            ),
+            inputs=[document_store_state],
+            outputs=[all2all_adu_similarities],
+        )
 
         # retrieve_relevant_adus_btn.click(
         #     **retrieve_relevant_adus_event_kwargs
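
Note on the escape_regex/unescape_regex pair added above: the split pattern travels through a Gradio textbox as a literal string, so it is "double escaped" for display and unescaped again before use. A minimal, self-contained sketch of the round trip (the helper bodies are copied from the diff; the assert values are what Python's unicode_escape codec produces):

def escape_regex(regex: str) -> str:
    # "double escape": the newline character "\n" becomes the two characters backslash + "n"
    return regex.encode("unicode_escape").decode("utf-8")

def unescape_regex(regex: str) -> str:
    # reverse of escape_regex
    return regex.encode("utf-8").decode("unicode_escape")

DEFAULT_SPLIT_REGEX = "\n\n\n+"  # matches three or more consecutive newlines
escaped = escape_regex(DEFAULT_SPLIT_REGEX)
assert escaped == r"\n\n\n+"  # what the user sees and can edit in the textbox
assert unescape_regex(escaped) == DEFAULT_SPLIT_REGEX  # what create_document() receives

Also note the wiring of the new click handler: it passes the unbound method DocumentStore.get_all2all_adu_similarities through functools.partial, so the DocumentStore instance held in document_store_state is supplied as self by Gradio.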
document_store.py CHANGED

@@ -8,6 +8,7 @@ from collections import defaultdict
 from typing import Any, Dict, List, Optional
 
 import gradio as gr
+import numpy as np
 import pandas as pd
 from annotation_utils import labeled_span_to_id
 from pytorch_ie import Annotation
@@ -417,7 +418,7 @@ class DocumentStore:
 
         return document
 
-    def overview(self) -> pd.DataFrame:
+    def overview(self, with_max_cross_doc_sims: bool = False) -> pd.DataFrame:
         rows = []
         for doc_id, document in self.documents.items():
             layers = {
@@ -429,6 +430,38 @@ class DocumentStore:
             layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
             rows.append({"doc_id": doc_id, **layer_sizes})
         df = pd.DataFrame(rows)
+
+        # add highest cross-document similarity score for each document
+        if with_max_cross_doc_sims and len(self.documents) > 1:
+            # Setting min_similarity to None and top_k to None to get all similarities. Otherwise,
+            # it may happen that this occludes max cross-doc sim for some documents in the
+            # case that there are more than top_k ADUs in the reference document that have a higher
+            # similarity with each other than the highest similarity to any ADU in another document
+            # or if the cross-doc similarity is below the min_similarity threshold.
+            all2all_adu_similarities = self.get_all2all_adu_similarities(
+                min_similarity=None, top_k=None, columns=["doc_id", "other_doc_id", "sim_score"]
+            )
+            max_doc2doc_similarities = all2all_adu_similarities.pivot_table(
+                values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
+            )
+
+            max_doc2doc_similarities.sort_index(axis="index", inplace=True)
+            max_doc2doc_similarities.sort_index(axis="columns", inplace=True)
+            # check that the index and columns are the same
+            if (max_doc2doc_similarities.index != max_doc2doc_similarities.columns).any():
+                raise gr.Error("Index and columns of max_doc2doc_similarities are not the same.")
+            # set diagonal entries to minus infinity to exclude them from the maximum
+            np.fill_diagonal(max_doc2doc_similarities.values, -np.inf)
+
+            max_doc_ids = max_doc2doc_similarities.idxmax(axis="columns")
+            max_similarities = max_doc2doc_similarities.max(axis="columns")
+
+            # set the index to the doc_id to correctly join the series
+            df.set_index("doc_id", inplace=True)
+            df["max_cross_doc_sim_doc_id"] = max_doc_ids
+            df["max_cross_doc_sim_score"] = max_similarities
+            df.reset_index(inplace=True)
+
         return df
 
     def as_dict(self, include_embeddings: bool = True) -> dict:
@@ -441,3 +474,54 @@
             }
             result[doc_id] = doc_dict
         return result
+
+    def get_all2all_adu_similarities(
+        self,
+        min_similarity: Optional[float] = 0.5,
+        top_k: Optional[int] = 100,
+        columns: Optional[List[str]] = None,
+    ) -> pd.DataFrame:
+        """Get the similarities between all ADUs in the store.
+
+        Args:
+            min_similarity: The minimum similarity score to consider.
+            top_k: The number of similar ADUs to return.
+            columns: The columns to include in the result DataFrame. If None, all columns are included.
+
+        Returns:
+            A DataFrame with the columns: doc_id, text, other_doc_id, other_text, sim_score.
+        """
+        result = []
+        document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+        for doc_id, document in self.documents.items():
+            for adu in document.labeled_spans.predictions:
+                adu_id = labeled_span_to_id(adu)
+                similar_entries = self.vector_store.retrieve_similar(
+                    ref_payload=self.construct_embedding_payload(document, adu_id),
+                    min_similarity=min_similarity,
+                    top_k=top_k,
+                )
+                for _, payload, score in similar_entries:
+                    other_doc_id = payload["doc_id"]
+                    other_document = self.documents[other_doc_id]
+                    other_adu = get_annotation_from_document(
+                        other_document,
+                        payload["annotation_id"],
+                        self.span_layer_name,
+                        use_predictions=self.use_predictions,
+                    )
+                    result.append(
+                        {
+                            "sim_score": score,
+                            "doc_id": doc_id,
+                            "other_doc_id": other_doc_id,
+                            "adu_id": adu_id,
+                            "other_adu_id": payload["annotation_id"],
+                            "text": str(adu),
+                            "other_text": str(other_adu),
+                        }
+                    )
+        result_df = pd.DataFrame(result)
+        if columns is not None:
+            result_df = result_df[columns]
+        return result_df
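
To see what the new with_max_cross_doc_sims branch of overview() computes, here is a toy run of its pivot/fill-diagonal/idxmax steps with two documents and fabricated similarity scores (the real scores come from get_all2all_adu_similarities):

import numpy as np
import pandas as pd

sims = pd.DataFrame(
    [
        {"doc_id": "a", "other_doc_id": "a", "sim_score": 0.99},  # within-document pair
        {"doc_id": "a", "other_doc_id": "b", "sim_score": 0.70},
        {"doc_id": "b", "other_doc_id": "a", "sim_score": 0.70},
        {"doc_id": "b", "other_doc_id": "b", "sim_score": 0.95},  # within-document pair
    ]
)
max_doc2doc = sims.pivot_table(
    values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
)
np.fill_diagonal(max_doc2doc.values, -np.inf)  # exclude within-document maxima
print(max_doc2doc.idxmax(axis="columns"))  # a -> b, b -> a
print(max_doc2doc.max(axis="columns"))     # a -> 0.7, b -> 0.7

The diagonal trick is also why min_similarity=None and top_k=None matter here: with a cut-off, a document whose most similar ADUs all lie within itself could end up with no cross-document entry at all, and the pivot table would contain gaps instead of the true maxima.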
model_utils.py CHANGED

@@ -5,6 +5,7 @@ import gradio as gr
 import torch
 from annotation_utils import labeled_span_to_id
 from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
+from pie_modules.document.processing import RegexPartitioner
 from pytorch_ie import Pipeline
 from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.auto import AutoPipeline
@@ -48,7 +49,7 @@ def annotate_document(
 
 
 def create_document(
-    text: str, doc_id: str
+    text: str, doc_id: str, split_regex: Optional[str] = None
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
     text.
@@ -56,6 +57,7 @@ def create_document(
     Parameters:
         text: The text to process.
        doc_id: The ID of the document.
+        split_regex: A regular expression pattern to use for splitting the text into partitions.
 
     Returns:
         The processed document.
@@ -64,8 +66,14 @@ def create_document(
     document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
         id=doc_id, text=text, metadata={}
     )
-    # add single partition from the whole text (the model only considers text in partitions)
-    document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+    if split_regex is not None:
+        partitioner = RegexPartitioner(
+            pattern=split_regex, partition_layer_name="labeled_partitions"
+        )
+        document = partitioner(document)
+    else:
+        # add single partition from the whole text (the model only considers text in partitions)
+        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
     return document
 
 
@@ -92,24 +100,22 @@ def load_argumentation_model(
             taskmodule_kwargs=dict(revision=revision),
             model_kwargs=dict(revision=revision),
         )
+        gr.Info(
+            f"Loaded argumentation model: model_name={model_name}, revision={revision}, device={device}"
+        )
     except Exception as e:
         raise gr.Error(f"Failed to load argumentation model: {e}")
-    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
+
     return model
 
 
-def load_models(
-    model_name: str,
-    revision: Optional[str] = None,
+def load_embedding_model(
     embedding_model_name: Optional[str] = None,
     # embedding_model_revision: Optional[str] = None,
     embedding_max_length: int = 512,
     embedding_batch_size: int = 16,
     device: str = "cpu",
-) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
-    torch.cuda.empty_cache()
-    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
-    embedding_model = None
+) -> Optional[EmbeddingModel]:
     if embedding_model_name is not None and embedding_model_name.strip():
         try:
             embedding_model = HuggingfaceEmbeddingModel(
@@ -119,7 +125,32 @@ def load_models(
                 max_length=embedding_max_length,
                 batch_size=embedding_batch_size,
             )
+            gr.Info(f"Loaded embedding model: model_name={embedding_model_name}, device={device}")
         except Exception as e:
             raise gr.Error(f"Failed to load embedding model: {e}")
+    else:
+        embedding_model = None
+
+    return embedding_model
+
+
+def load_models(
+    model_name: str,
+    revision: Optional[str] = None,
+    embedding_model_name: Optional[str] = None,
+    # embedding_model_revision: Optional[str] = None,
+    embedding_max_length: int = 512,
+    embedding_batch_size: int = 16,
+    device: str = "cpu",
+) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
+    torch.cuda.empty_cache()
+    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
+    embedding_model = load_embedding_model(
+        embedding_model_name=embedding_model_name,
+        # embedding_model_revision=embedding_model_revision,
+        embedding_max_length=embedding_max_length,
+        embedding_batch_size=embedding_batch_size,
+        device=device,
+    )
 
     return argumentation_model, embedding_model
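
For intuition about the default split pattern, a rough standalone approximation with re.split (the actual RegexPartitioner from pie_modules records each partition as a LabeledSpan with character offsets into the original text rather than returning plain strings, as the else branch of create_document suggests):

import re

text = "Title\n\n\nFirst section.\n\n\n\nSecond section."
print(re.split("\n\n\n+", text))
# ['Title', 'First section.', 'Second section.']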