ArneBinder committed on
Commit 1681237
1 Parent(s): 148e0d6

Upload 10 files

from https://github.com/ArneBinder/pie-document-level/pull/221

Files changed (6)
  1. app.py +59 -13
  2. document_store.py +0 -2
  3. embedding.py +124 -0
  4. model_utils.py +68 -116
  5. rendering_utils.py +4 -1
  6. requirements.txt +3 -0
app.py CHANGED
@@ -7,8 +7,10 @@ from typing import List, Optional, Tuple
 
 import gradio as gr
 import pandas as pd
+import torch
 from document_store import DocumentStore, get_annotation_from_document
-from model_utils import create_and_annotate_document, load_models
+from embedding import EmbeddingModel
+from model_utils import annotate_document, create_document, load_models
 from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
 from pytorch_ie import Pipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
@@ -26,6 +28,9 @@ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
 # DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
 # DEFAULT_MODEL_REVISION = None
 DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
+DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DEFAULT_EMBEDDING_MAX_LENGTH = 512
+DEFAULT_EMBEDDING_BATCH_SIZE = 32
 
 
 def render_annotated_document(
@@ -47,22 +52,26 @@ def render_annotated_document(
 def wrapped_process_text(
     text: str,
     doc_id: str,
-    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+    models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
 ) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
-    document = create_and_annotate_document(
-        text=text,
-        doc_id=doc_id,
-        models=models,
-    )
-    document_store.add_document(document)
+    try:
+        document = create_document(text=text, doc_id=doc_id)
+        annotate_document(
+            document=document,
+            annotation_pipeline=models[0],
+            embedding_model=models[1],
+        )
+        document_store.add_document(document)
+    except Exception as e:
+        raise gr.Error(f"Failed to process text: {e}")
     # Return as dict and document to avoid serialization issues
     return document.asdict(), document
 
 
 def process_uploaded_files(
     file_names: List[str],
-    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+    models: Tuple[Pipeline, Optional[EmbeddingModel]],
     document_store: DocumentStore,
 ) -> pd.DataFrame:
     try:
@@ -74,7 +83,16 @@
                     text = f.read()
                 base_file_name = os.path.basename(file_name)
                 gr.Info(f"Processing file '{base_file_name}' ...")
-                new_documents.append(create_and_annotate_document(text, base_file_name, models))
+                new_document = create_document(
+                    text=text,
+                    doc_id=base_file_name,
+                )
+                annotate_document(
+                    document=new_document,
+                    annotation_pipeline=models[0],
+                    embedding_model=models[1],
+                )
+                new_documents.append(new_document)
             else:
                 raise gr.Error(f"Unsupported file format: {file_name}")
         document_store.add_documents(new_documents)
@@ -143,10 +161,13 @@ def main():
     example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
 
     print("Loading models ...")
-    argumentation_model, embedding_model, embedding_tokenizer = load_models(
+    argumentation_model, embedding_model = load_models(
         model_name=DEFAULT_MODEL_NAME,
         revision=DEFAULT_MODEL_REVISION,
         embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
+        embedding_max_length=DEFAULT_EMBEDDING_MAX_LENGTH,
+        embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
+        device=DEFAULT_DEVICE,
     )
 
     default_render_kwargs = {
@@ -179,7 +200,7 @@ def main():
         DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
     )
     # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
-    models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
+    models_state = gr.State((argumentation_model, embedding_model))
    with gr.Row():
        with gr.Column(scale=1):
            doc_id = gr.Textbox(
@@ -204,10 +225,35 @@
                 label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
                 value=DEFAULT_EMBEDDING_MODEL_NAME,
             )
+            embedding_max_length = gr.Slider(
+                label="Embedding Model Max Length",
+                minimum=16,
+                maximum=2048,
+                step=8,
+                value=DEFAULT_EMBEDDING_MAX_LENGTH,
+            )
+            embedding_batch_size = gr.Slider(
+                label="Embedding Model Batch Size",
+                minimum=1,
+                maximum=128,
+                step=1,
+                value=DEFAULT_EMBEDDING_BATCH_SIZE,
+            )
+            device = gr.Textbox(
+                label="Device (e.g. 'cuda' or 'cpu')",
+                value=DEFAULT_DEVICE,
+            )
             load_models_btn = gr.Button("Load Models")
             load_models_btn.click(
                 fn=load_models,
-                inputs=[model_name, model_revision, embedding_model_name],
+                inputs=[
+                    model_name,
+                    model_revision,
+                    embedding_model_name,
+                    embedding_max_length,
+                    embedding_batch_size,
+                    device,
+                ],
                 outputs=models_state,
             )
 
document_store.py CHANGED
@@ -307,8 +307,6 @@ class DocumentStore:
 
     def add_document_from_dict(self, document_dict: dict) -> None:
         document = self.document_type.fromdict(document_dict)
-        # metadata is not automatically deserialized, so we need to set it manually
-        document.metadata = document_dict["metadata"]
         self.add_document(document)
 
     def add_documents(self, documents: List[TextBasedDocument]) -> None:
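
Note on the serialization path these two files share: `wrapped_process_text` in app.py returns both `document.asdict()` and the document object ("to avoid serialization issues"), and `DocumentStore.add_document_from_dict` rebuilds a document from such a dict via `fromdict`. A minimal sketch of that round trip, assuming the pytorch-ie version pinned here (the example text and partition are illustrative, not from the commit):

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

# build a small document the same way create_document does
doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
    id="demo", text="Argument mining is challenging.", metadata={}
)
doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="text"))

as_dict = doc.asdict()  # JSON-serializable form that the Gradio app passes around
restored = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(as_dict)

assert restored.text == doc.text
assert len(restored.labeled_partitions) == 1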
embedding.py ADDED
@@ -0,0 +1,124 @@
+import abc
+import logging
+from typing import Dict
+
+import torch
+from datasets import Dataset
+from pie_modules.document.processing import tokenize_document
+from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from pytorch_ie.annotations import Span
+from pytorch_ie.documents import TextBasedDocument
+from torch import FloatTensor, Tensor
+from torch.utils.data import DataLoader
+from transformers import AutoModel, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class EmbeddingModel(abc.ABC):
+    def __call__(
+        self, document: TextBasedDocument, span_layer_name: str
+    ) -> Dict[Span, FloatTensor]:
+        """Embed text annotations from a document.
+
+        Args:
+            document: The document to embed.
+            span_layer_name: The name of the annotation layer in the document that contains the
+                text span annotations to embed.
+
+        Returns:
+            A dictionary mapping text annotations to their embeddings.
+        """
+        pass
+
+
+class HuggingfaceEmbeddingModel(EmbeddingModel):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        revision: str = None,
+        device: str = "cpu",
+        max_length: int = 512,
+        batch_size: int = 16,
+    ):
+        self.load(model_name_or_path, revision, device)
+        self.max_length = max_length
+        self.batch_size = batch_size
+
+    def load(self, model_name_or_path: str, revision: str = None, device: str = "cpu") -> None:
+        self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)
+
+    def __call__(
+        self, document: TextBasedDocument, span_layer_name: str
+    ) -> Dict[Span, FloatTensor]:
+        # to not modify the original document
+        document = document.copy()
+        # tokenize_document does not yet consider predictions, so we need to add them manually
+        document[span_layer_name].extend(document[span_layer_name].predictions.clear())
+        added_annotations = []
+        tokenizer_kwargs = {
+            "max_length": self.max_length,
+            "stride": self.max_length // 8,
+            "truncation": True,
+            "padding": True,
+            "return_overflowing_tokens": True,
+        }
+        # tokenize once to get the tokenized documents with mapped annotations
+        tokenized_documents = tokenize_document(
+            document,
+            tokenizer=self._tokenizer,
+            result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
+            partition_layer="labeled_partitions",
+            added_annotations=added_annotations,
+            strict_span_conversion=False,
+            **tokenizer_kwargs,
+        )
+
+        # just tokenize again to get tensors in the correct format for the model
+        dataset = Dataset.from_dict({"text": [document.text]})
+
+        def tokenize_function(examples):
+            return self._tokenizer(examples["text"], **tokenizer_kwargs)
+
+        # Tokenize the texts. Note that we remove the text column directly in the map call,
+        # otherwise the map would fail because we may produce multiple new rows
+        # (tokenization result) for each input row (text).
+        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
+        # remove the overflow_to_sample_mapping column
+        tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
+        tokenized_dataset.set_format(type="torch")
+
+        dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)
+
+        embeddings = {}
+        example_idx = 0
+        for batch in dataloader:
+            batch_at_device = {
+                k: v.to(self._model.device) if isinstance(v, Tensor) else v
+                for k, v in batch.items()
+            }
+            with torch.no_grad():
+                model_output = self._model(**batch_at_device)
+
+            for last_hidden_state in model_output.last_hidden_state:
+                text2tok_ann = added_annotations[example_idx][span_layer_name]
+                tok2text_ann = {v: k for k, v in text2tok_ann.items()}
+                for tok_ann in tokenized_documents[example_idx].labeled_spans:
+                    # skip "empty" annotations
+                    if tok_ann.start == tok_ann.end:
+                        continue
+                    # use the max pooling strategy to get a single embedding for the annotation text
+                    embedding = (
+                        last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
+                    )
+                    text_ann = tok2text_ann[tok_ann]
+
+                    if text_ann in embeddings:
+                        logger.warning(
+                            f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
+                        )
+                    embeddings[text_ann] = embedding
+                example_idx += 1
+
+        return embeddings
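
A short usage sketch for the new `HuggingfaceEmbeddingModel` (not part of the commit): it assumes the pinned dependencies are installed and downloads the SciBERT checkpoint; the document content and span are illustrative, whereas in the app the spans come from the argumentation pipeline's predictions.

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

from embedding import HuggingfaceEmbeddingModel

doc = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
    id="demo", text="Scholarly Argumentation Mining is challenging.", metadata={}
)
# the embedding model tokenizes per partition, so cover the whole text with a single partition
doc.labeled_partitions.append(LabeledSpan(start=0, end=len(doc.text), label="text"))
# add one span to embed (illustrative label)
doc.labeled_spans.append(LabeledSpan(start=0, end=30, label="background_claim"))

embedding_model = HuggingfaceEmbeddingModel(
    "allenai/scibert_scivocab_uncased", device="cpu", max_length=512, batch_size=8
)
span2embedding = embedding_model(doc, span_layer_name="labeled_spans")
for span, vector in span2embedding.items():
    print(span, vector.shape)  # one max-pooled hidden-state vector per labeled span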
model_utils.py CHANGED
@@ -1,98 +1,45 @@
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Optional, Tuple
 
 import gradio as gr
+import torch
 from annotation_utils import labeled_span_to_id
-from pie_modules.document.processing import tokenize_document
-from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
+from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
 from pytorch_ie import Pipeline
 from pytorch_ie.annotations import LabeledSpan
 from pytorch_ie.auto import AutoPipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
-from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
 
 logger = logging.getLogger(__name__)
 
 
-def _embed_text_annotations(
+def annotate_document(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    model: PreTrainedModel,
-    tokenizer: PreTrainedTokenizer,
-    text_layer_name: str,
-) -> Dict[LabeledSpan, List[float]]:
-    # to not modify the original document
-    document = document.copy()
-    # tokenize_document does not yet consider predictions, so we need to add them manually
-    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
-    added_annotations = []
-    tokenizer_kwargs = {
-        "max_length": 512,
-        "stride": 64,
-        "truncation": True,
-        "return_overflowing_tokens": True,
-    }
-    tokenized_documents = tokenize_document(
-        document,
-        tokenizer=tokenizer,
-        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-        partition_layer="labeled_partitions",
-        added_annotations=added_annotations,
-        strict_span_conversion=False,
-        **tokenizer_kwargs,
-    )
-    # just tokenize again to get tensors in the correct format for the model
-    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
-    # this is added when using return_overflowing_tokens=True, but the model does not accept it
-    model_inputs.pop("overflow_to_sample_mapping", None)
-    assert len(model_inputs.encodings) == len(tokenized_documents)
-    model_output = model(**model_inputs)
-
-    # get embeddings for all text annotations
-    embeddings = {}
-    for batch_idx in range(len(model_output.last_hidden_state)):
-        text2tok_ann = added_annotations[batch_idx][text_layer_name]
-        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
-        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
-            # skip "empty" annotations
-            if tok_ann.start == tok_ann.end:
-                continue
-            # use the max pooling strategy to get a single embedding for the annotation text
-            embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
-                dim=0
-            )[0]
-            text_ann = tok2text_ann[tok_ann]
-
-            if text_ann in embeddings:
-                logger.warning(
-                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
-                )
-            embeddings[text_ann] = embedding
-
-    return embeddings
-
-
-def _annotate(
-    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
-    pipeline: Pipeline,
-    embedding_model: Optional[PreTrainedModel] = None,
-    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
+    annotation_pipeline: Pipeline,
+    embedding_model: Optional[EmbeddingModel] = None,
 ) -> None:
+    """Annotate a document with the provided pipeline. If an embedding model is provided, also
+    extract embeddings for the labeled spans.
+
+    Args:
+        document: The document to annotate.
+        annotation_pipeline: The pipeline to use for annotation.
+        embedding_model: The embedding model to use for extracting text span embeddings.
+    """
 
     # execute prediction pipeline
-    pipeline(document)
+    annotation_pipeline(document)
 
-    if embedding_model is not None and embedding_tokenizer is not None:
-        adu_embeddings = _embed_text_annotations(
+    if embedding_model is not None:
+        text_span_embeddings = embedding_model(
             document=document,
-            model=embedding_model,
-            tokenizer=embedding_tokenizer,
-            text_layer_name="labeled_spans",
+            span_layer_name="labeled_spans",
         )
         # convert keys to str because JSON keys must be strings
-        adu_embeddings_dict = {
-            labeled_span_to_id(k): v.detach().tolist() for k, v in adu_embeddings.items()
+        text_span_embeddings_dict = {
+            labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
        }
-        document.metadata["embeddings"] = adu_embeddings_dict
+        document.metadata["embeddings"] = text_span_embeddings_dict
    else:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
@@ -100,47 +47,47 @@ def _annotate(
         )
 
 
-def create_and_annotate_document(
-    text: str,
-    doc_id: str,
-    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
+def create_document(
+    text: str, doc_id: str
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
-    text, annotate it, and add it to the index.
+    text.
 
     Parameters:
         text: The text to process.
         doc_id: The ID of the document.
-        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
 
     Returns:
         The processed document.
     """
 
-    try:
-        document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
-            id=doc_id, text=text, metadata={}
-        )
-        # add single partition from the whole text (the model only considers text in partitions)
-        document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
-        # annotate the document
-        _annotate(
-            document=document,
-            pipeline=models[0],
-            embedding_model=models[1],
-            embedding_tokenizer=models[2],
-        )
-
-        return document
-    except Exception as e:
-        raise gr.Error(f"Failed to process text: {e}")
+    document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
+        id=doc_id, text=text, metadata={}
+    )
+    # add single partition from the whole text (the model only considers text in partitions)
+    document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
+    return document
 
 
-def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
+def load_argumentation_model(
+    model_name: str,
+    revision: Optional[str] = None,
+    device: str = "cpu",
+) -> Pipeline:
     try:
+        # the Pipeline class expects an integer for the device
+        if device == "cuda":
+            pipeline_device = 0
+        elif device.startswith("cuda:"):
+            pipeline_device = int(device.split(":")[1])
+        elif device == "cpu":
+            pipeline_device = -1
+        else:
+            raise gr.Error(f"Invalid device: {device}")
+
         model = AutoPipeline.from_pretrained(
             model_name,
-            device=-1,
+            device=pipeline_device,
             num_workers=0,
             taskmodule_kwargs=dict(revision=revision),
             model_kwargs=dict(revision=revision),
@@ -151,23 +98,28 @@ def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
     return model
 
 
-def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
-    try:
-        embedding_model = AutoModel.from_pretrained(model_name)
-        embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
-    except Exception as e:
-        raise gr.Error(f"Failed to load embedding model: {e}")
-    gr.Info(f"Loaded embedding model: model_name={model_name})")
-    return embedding_model, embedding_tokenizer
-
-
 def load_models(
-    model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
-) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
-    argumentation_model = load_argumentation_model(model_name, revision)
+    model_name: str,
+    revision: Optional[str] = None,
+    embedding_model_name: Optional[str] = None,
+    # embedding_model_revision: Optional[str] = None,
+    embedding_max_length: int = 512,
+    embedding_batch_size: int = 16,
+    device: str = "cpu",
+) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
+    torch.cuda.empty_cache()
+    argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
     embedding_model = None
-    embedding_tokenizer = None
     if embedding_model_name is not None and embedding_model_name.strip():
-        embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
-
-    return argumentation_model, embedding_model, embedding_tokenizer
+        try:
+            embedding_model = HuggingfaceEmbeddingModel(
+                embedding_model_name.strip(),
+                # revision=embedding_model_revision,
+                device=device,
+                max_length=embedding_max_length,
+                batch_size=embedding_batch_size,
+            )
+        except Exception as e:
+            raise gr.Error(f"Failed to load embedding model: {e}")
+
+    return argumentation_model, embedding_model
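
Taken together, the refactor splits the old `create_and_annotate_document` into `create_document` plus `annotate_document`, with model loading handled separately by `load_models`. A sketch of how the new functions compose (the argumentation model name is a placeholder; in the app it comes from `DEFAULT_MODEL_NAME` / `DEFAULT_MODEL_REVISION`, and the example text is illustrative):

from model_utils import annotate_document, create_document, load_models

# placeholder: substitute the argumentation model configured in app.py
ARGUMENTATION_MODEL_NAME = "<argumentation-model-name>"

pipeline, embedding_model = load_models(
    model_name=ARGUMENTATION_MODEL_NAME,
    revision=None,
    embedding_model_name="allenai/scibert_scivocab_uncased",
    embedding_max_length=512,
    embedding_batch_size=32,
    device="cpu",
)

document = create_document(text="Scholarly Argumentation Mining is challenging.", doc_id="demo")
annotate_document(
    document=document,
    annotation_pipeline=pipeline,
    embedding_model=embedding_model,
)

print(document.labeled_spans.predictions)         # predicted ADUs
print(list(document.metadata["embeddings"])[:3])  # span ids that received embedding vectors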
rendering_utils.py CHANGED
@@ -1,4 +1,5 @@
 import json
+import logging
 from collections import defaultdict
 from typing import Dict, List, Optional, Union
 
@@ -7,6 +8,8 @@ from pytorch_ie.annotations import BinaryRelation
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils_displacy import EntityRenderer
 
+logger = logging.getLogger(__name__)
+
 
 def render_pretty_table(
     document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
@@ -92,7 +95,7 @@ def inject_relation_data(
         entity_annotation = sorted_entities[idx]
         # sanity check
         if str(entity_annotation) != entity.next:
-            raise ValueError(f"Entity text mismatch: {entity_annotation} != {entity.text}")
+            logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
         entity["data-label"] = entity_annotation.label
         entity["data-relation-tails"] = json.dumps(
             [
requirements.txt CHANGED
@@ -2,3 +2,6 @@ gradio==4.36.0
 prettytable==3.10.0
 pie-modules==0.12.0
 beautifulsoup4==4.12.3
+datasets==2.14.4
+# numpy 2.0.0 breaks the code
+numpy==1.25.2