ArneBinder committed on
Commit
1f79774
1 Parent(s): ee9934e

save processed documents and model loading

Browse files

same as https://github.com/ArneBinder/pie-document-level/pull/213/commits/34327b71b6b1a50341d003311888402f6705b3bc

Files changed (1) hide show
  1. app.py +123 -28
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import json
2
  import logging
3
  from functools import partial
4
- from typing import Optional, Tuple
5
 
6
  import gradio as gr
7
  from pie_modules.document.processing import tokenize_document
@@ -22,6 +22,13 @@ logger = logging.getLogger(__name__)
22
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
23
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
24
 
 
 
 
 
 
 
 
25
 
26
  def embed_text_annotations(
27
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
@@ -99,6 +106,10 @@ def predict(
99
  # convert keys to str because JSON keys must be strings
100
  adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
101
  document_dict["embeddings"] = adu_embeddings_dict
 
 
 
 
102
 
103
  # Return as dict and JSON string. The latter is required because the JSON component converts floats
104
  # to ints which destroys de-serialization of the document (the scores of the annotations need to be floats)
@@ -123,6 +134,29 @@ def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
123
  return html
124
 
125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  def open_accordion():
127
  return gr.Accordion(open=True)
128
 
@@ -131,30 +165,52 @@ def close_accordion():
131
  return gr.Accordion(open=False)
132
 
133
 
134
- def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- model_name_or_path = "ArneBinder/sam-pointer-bart-base-v0.3"
137
- # W&B run: https://wandb.ai/arne/dataset-sciarg-task-ner_re-v0.3-training/runs/prik91di
138
- revision = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
139
- # local path
140
- # model_name_or_path = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
141
- # revision = None
142
 
143
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
144
 
145
- print("loading argumentation mining model ...")
146
- pipeline = AutoPipeline.from_pretrained(
147
- model_name_or_path,
148
- device=-1,
149
- num_workers=0,
150
- taskmodule_kwargs=dict(revision=revision),
151
- model_kwargs=dict(revision=revision),
152
  )
153
 
154
- print("loading SciBERT embedding model ...")
155
- embedding_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
156
- embedding_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
157
-
158
  default_render_kwargs = {
159
  "entity_options": {
160
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
@@ -180,18 +236,49 @@ def main():
180
  },
181
  }
182
 
 
 
 
183
  with gr.Blocks() as demo:
 
 
 
 
184
  with gr.Row():
185
  with gr.Column(scale=1):
 
 
 
 
186
  text = gr.Textbox(
187
- label="Input Text",
188
  lines=20,
189
  value=example_text,
190
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- predict_btn = gr.Button("Predict")
193
 
194
  output_txt = gr.Textbox(visible=False)
 
195
 
196
  with gr.Column(scale=1):
197
 
@@ -216,14 +303,16 @@ def main():
216
  render_button_kwargs = dict(
217
  fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
218
  )
 
 
 
 
 
 
 
219
  predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
220
- fn=partial(
221
- predict,
222
- pipeline=pipeline,
223
- embedding_model=embedding_model,
224
- embedding_tokenizer=embedding_tokenizer,
225
- ),
226
- inputs=text,
227
  outputs=[output_json, output_txt],
228
  api_name="predict",
229
  ).success(**render_button_kwargs).success(
@@ -231,6 +320,12 @@ def main():
231
  )
232
  render_btn.click(**render_button_kwargs, api_name="render")
233
 
 
 
 
 
 
 
234
  js = """
235
  () => {
236
  function maybeSetColor(entity, colorAttributeKey, colorDictKey) {
 
1
  import json
2
  import logging
3
  from functools import partial
4
+ from typing import Any, Optional, Tuple
5
 
6
  import gradio as gr
7
  from pie_modules.document.processing import tokenize_document
 
22
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
23
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
24
 
25
+ DEFAULT_MODEL_NAME = "ArneBinder/sam-pointer-bart-base-v0.3"
26
+ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
27
+ # local path
28
+ # DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
29
+ # DEFAULT_MODEL_REVISION = None
30
+ DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
31
+
32
 
33
  def embed_text_annotations(
34
  document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 
106
  # convert keys to str because JSON keys must be strings
107
  adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
108
  document_dict["embeddings"] = adu_embeddings_dict
109
+ else:
110
+ gr.Warning(
111
+ "No embedding model provided. Skipping embedding extraction. You can load an embedding model in the 'Model Configuration' section."
112
+ )
113
 
114
  # Return as dict and JSON string. The latter is required because the JSON component converts floats
115
  # to ints which destroys de-serialization of the document (the scores of the annotations need to be floats)
 
134
  return html
135
 
136
 
137
+ def add_to_index(
138
+ output_txt: str, doc_id: str, processed_documents: dict, vector_store: Any
139
+ ) -> None:
140
+ try:
141
+ if doc_id in processed_documents:
142
+ gr.Warning(f"Document {doc_id} already in index. Overwriting.")
143
+ output = json.loads(output_txt)
144
+ # get the embeddings from the output and remove them from the output
145
+ embeddings = output.pop("embeddings")
146
+ # save the processed document to the index
147
+ processed_documents[doc_id] = output
148
+ # save the embeddings to the vector store
149
+ for adu_id, embedding in embeddings.items():
150
+ emb_id = f"{doc_id}:{adu_id}"
151
+ # TODO: save embedding to vector store at emb_id (embedding is a list of 768 floats)
152
+
153
+ gr.Info(
154
+ f"Added document {doc_id} to index (index contains {len(processed_documents)} entries). (NOT YET IMPLEMENTED)"
155
+ )
156
+ except Exception as e:
157
+ raise gr.Error(f"Failed to add document {doc_id} to index: {e}")
158
+
159
+
160
  def open_accordion():
161
  return gr.Accordion(open=True)
162
 
 
165
  return gr.Accordion(open=False)
166
 
167
 
168
+ def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
169
+ try:
170
+ model = AutoPipeline.from_pretrained(
171
+ model_name,
172
+ device=-1,
173
+ num_workers=0,
174
+ taskmodule_kwargs=dict(revision=revision),
175
+ model_kwargs=dict(revision=revision),
176
+ )
177
+ except Exception as e:
178
+ raise gr.Error(f"Failed to load argumentation model: {e}")
179
+ gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
180
+ return model
181
+
182
+
183
+ def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
184
+ try:
185
+ embedding_model = AutoModel.from_pretrained(model_name)
186
+ embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
187
+ except Exception as e:
188
+ raise gr.Error(f"Failed to load embedding model: {e}")
189
+ gr.Info(f"Loaded embedding model: model_name={model_name})")
190
+ return embedding_model, embedding_tokenizer
191
+
192
+
193
+ def load_models(
194
+ model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
195
+ ) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
196
+ argumentation_model = load_argumentation_model(model_name, revision)
197
+ embedding_model = None
198
+ embedding_tokenizer = None
199
+ if embedding_model_name is not None and embedding_model_name.strip():
200
+ embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
201
+
202
+ return argumentation_model, embedding_model, embedding_tokenizer
203
 
204
+
205
+ def main():
 
 
 
 
206
 
207
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
208
 
209
+ print("Loading argumentation model ...")
210
+ argumentation_model = load_argumentation_model(
211
+ model_name=DEFAULT_MODEL_NAME, revision=DEFAULT_MODEL_REVISION
 
 
 
 
212
  )
213
 
 
 
 
 
214
  default_render_kwargs = {
215
  "entity_options": {
216
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
 
236
  },
237
  }
238
 
239
+ # TODO: setup the vector store
240
+ vector_store = None
241
+
242
  with gr.Blocks() as demo:
243
+ processed_documents_state = gr.State(dict())
244
+ vector_store_state = gr.State(vector_store)
245
+ # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
246
+ models_state = gr.State((argumentation_model, None, None))
247
  with gr.Row():
248
  with gr.Column(scale=1):
249
+ doc_id = gr.Textbox(
250
+ label="Document ID",
251
+ value="user_input",
252
+ )
253
  text = gr.Textbox(
254
+ label="Text",
255
  lines=20,
256
  value=example_text,
257
  )
258
+ with gr.Accordion("Model Configuration", open=False):
259
+ model_name = gr.Textbox(
260
+ label="Model Name",
261
+ value=DEFAULT_MODEL_NAME,
262
+ )
263
+ model_revision = gr.Textbox(
264
+ label="Model Revision",
265
+ value=DEFAULT_MODEL_REVISION,
266
+ )
267
+ embedding_model_name = gr.Textbox(
268
+ label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
269
+ value="",
270
+ )
271
+ load_models_btn = gr.Button("Load Models")
272
+ load_models_btn.click(
273
+ fn=load_models,
274
+ inputs=[model_name, model_revision, embedding_model_name],
275
+ outputs=models_state,
276
+ )
277
 
278
+ predict_btn = gr.Button("Analyse")
279
 
280
  output_txt = gr.Textbox(visible=False)
281
+ add_to_index_btn = gr.Button("Add current result to Index")
282
 
283
  with gr.Column(scale=1):
284
 
 
303
  render_button_kwargs = dict(
304
  fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
305
  )
306
+
307
+ def _predict(
308
+ text: str,
309
+ models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
310
+ ) -> Tuple[dict, str]:
311
+ return predict(text, *models)
312
+
313
  predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
314
+ fn=_predict,
315
+ inputs=[text, models_state],
 
 
 
 
 
316
  outputs=[output_json, output_txt],
317
  api_name="predict",
318
  ).success(**render_button_kwargs).success(
 
320
  )
321
  render_btn.click(**render_button_kwargs, api_name="render")
322
 
323
+ add_to_index_btn.click(
324
+ fn=add_to_index,
325
+ inputs=[output_txt, doc_id, processed_documents_state, vector_store_state],
326
+ outputs=[],
327
+ )
328
+
329
  js = """
330
  () => {
331
  function maybeSetColor(entity, colorAttributeKey, colorDictKey) {