ArneBinder commited on
Commit
ee9934e
1 Parent(s): b5bd682

simple emebdding calculation

Browse files
Files changed (2) hide show
  1. app.py +105 -5
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,21 +1,84 @@
1
  import json
2
- from typing import Tuple
 
 
3
 
4
  import gradio as gr
 
 
5
  from pie_modules.models import * # noqa: F403
6
  from pie_modules.taskmodules import * # noqa: F403
 
7
  from pytorch_ie.annotations import LabeledSpan
8
  from pytorch_ie.auto import AutoPipeline
9
  from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
10
  from pytorch_ie.models import * # noqa: F403
11
  from pytorch_ie.taskmodules import * # noqa: F403
12
  from rendering_utils import render_displacy, render_pretty_table
 
 
 
13
 
14
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
15
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
16
 
17
 
18
- def predict(text: str) -> Tuple[dict, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(text=text)
20
 
21
  # add single partition from the whole text (the model only considers text in partitions)
@@ -25,11 +88,27 @@ def predict(text: str) -> Tuple[dict, str]:
25
  pipeline(document)
26
 
27
  document_dict = document.asdict()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  return document_dict, json.dumps(document_dict)
29
 
30
 
31
  def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
32
  document_dict = json.loads(document_txt)
 
 
33
  document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
34
  document_dict
35
  )
@@ -52,16 +131,18 @@ def close_accordion():
52
  return gr.Accordion(open=False)
53
 
54
 
55
- if __name__ == "__main__":
56
 
57
  model_name_or_path = "ArneBinder/sam-pointer-bart-base-v0.3"
 
58
  revision = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
59
  # local path
60
- # model_name_or_path = "models/dataset-sciarg/task-ner_re/v0.3/2024-03-01_18-25-32"
61
  # revision = None
62
 
63
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
64
 
 
65
  pipeline = AutoPipeline.from_pretrained(
66
  model_name_or_path,
67
  device=-1,
@@ -70,6 +151,10 @@ if __name__ == "__main__":
70
  model_kwargs=dict(revision=revision),
71
  )
72
 
 
 
 
 
73
  default_render_kwargs = {
74
  "entity_options": {
75
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
@@ -132,7 +217,15 @@ if __name__ == "__main__":
132
  fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
133
  )
134
  predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
135
- fn=predict, inputs=text, outputs=[output_json, output_txt], api_name="predict"
 
 
 
 
 
 
 
 
136
  ).success(**render_button_kwargs).success(
137
  close_accordion, inputs=[], outputs=[output_accordion]
138
  )
@@ -223,3 +316,10 @@ if __name__ == "__main__":
223
  rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
224
 
225
  demo.launch()
 
 
 
 
 
 
 
 
1
  import json
2
+ import logging
3
+ from functools import partial
4
+ from typing import Optional, Tuple
5
 
6
  import gradio as gr
7
+ from pie_modules.document.processing import tokenize_document
8
+ from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
9
  from pie_modules.models import * # noqa: F403
10
  from pie_modules.taskmodules import * # noqa: F403
11
+ from pytorch_ie import Pipeline
12
  from pytorch_ie.annotations import LabeledSpan
13
  from pytorch_ie.auto import AutoPipeline
14
  from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
15
  from pytorch_ie.models import * # noqa: F403
16
  from pytorch_ie.taskmodules import * # noqa: F403
17
  from rendering_utils import render_displacy, render_pretty_table
18
+ from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
19
+
20
+ logger = logging.getLogger(__name__)
21
 
22
  RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
23
  RENDER_WITH_PRETTY_TABLE = "Pretty Table"
24
 
25
 
26
+ def embed_text_annotations(
27
+ document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
28
+ model: PreTrainedModel,
29
+ tokenizer: PreTrainedTokenizer,
30
+ text_layer_name: str,
31
+ ) -> dict:
32
+ # to not modify the original document
33
+ document = document.copy()
34
+ # tokenize_document does not yet consider predictions, so we need to add them manually
35
+ document[text_layer_name].extend(document[text_layer_name].predictions.clear())
36
+ added_annotations = []
37
+ # TODO: set return_overflowing_tokens=True and max_length=...?
38
+ tokenizer_kwargs = {}
39
+ tokenized_documents = tokenize_document(
40
+ document,
41
+ tokenizer=tokenizer,
42
+ result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
43
+ partition_layer="labeled_partitions",
44
+ added_annotations=added_annotations,
45
+ **tokenizer_kwargs,
46
+ )
47
+ # just tokenize again to get tensors in the correct format for the model
48
+ model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
49
+ assert len(model_inputs.encodings) == len(tokenized_documents)
50
+ model_output = model(**model_inputs)
51
+
52
+ # get embeddings for all text annotations
53
+ embeddings = {}
54
+ for batch_idx in range(len(model_output.last_hidden_state)):
55
+ text2tok_ann = added_annotations[batch_idx][text_layer_name]
56
+ tok2text_ann = {v: k for k, v in text2tok_ann.items()}
57
+ for tok_ann in tokenized_documents[batch_idx].labeled_spans:
58
+ # skip "empty" annotations
59
+ if tok_ann.start == tok_ann.end:
60
+ continue
61
+ # use the max pooling strategy to get a single embedding for the annotation text
62
+ embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
63
+ dim=0
64
+ )[0]
65
+ text_ann = tok2text_ann[tok_ann]
66
+
67
+ if text_ann in embeddings:
68
+ logger.warning(
69
+ f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
70
+ )
71
+ embeddings[text_ann] = embedding
72
+
73
+ return embeddings
74
+
75
+
76
+ def predict(
77
+ text: str,
78
+ pipeline: Pipeline,
79
+ embedding_model: Optional[PreTrainedModel] = None,
80
+ embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
81
+ ) -> Tuple[dict, str]:
82
  document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(text=text)
83
 
84
  # add single partition from the whole text (the model only considers text in partitions)
 
88
  pipeline(document)
89
 
90
  document_dict = document.asdict()
91
+
92
+ if embedding_model is not None and embedding_tokenizer is not None:
93
+ adu_embeddings = embed_text_annotations(
94
+ document=document,
95
+ model=embedding_model,
96
+ tokenizer=embedding_tokenizer,
97
+ text_layer_name="labeled_spans",
98
+ )
99
+ # convert keys to str because JSON keys must be strings
100
+ adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
101
+ document_dict["embeddings"] = adu_embeddings_dict
102
+
103
+ # Return as dict and JSON string. The latter is required because the JSON component converts floats
104
+ # to ints which destroys de-serialization of the document (the scores of the annotations need to be floats)
105
  return document_dict, json.dumps(document_dict)
106
 
107
 
108
  def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
109
  document_dict = json.loads(document_txt)
110
+ # remove embeddings from document_dict to make it de-serializable
111
+ document_dict.pop("embeddings", None)
112
  document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
113
  document_dict
114
  )
 
131
  return gr.Accordion(open=False)
132
 
133
 
134
+ def main():
135
 
136
  model_name_or_path = "ArneBinder/sam-pointer-bart-base-v0.3"
137
+ # W&B run: https://wandb.ai/arne/dataset-sciarg-task-ner_re-v0.3-training/runs/prik91di
138
  revision = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
139
  # local path
140
+ # model_name_or_path = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
141
  # revision = None
142
 
143
  example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
144
 
145
+ print("loading argumentation mining model ...")
146
  pipeline = AutoPipeline.from_pretrained(
147
  model_name_or_path,
148
  device=-1,
 
151
  model_kwargs=dict(revision=revision),
152
  )
153
 
154
+ print("loading SciBERT embedding model ...")
155
+ embedding_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
156
+ embedding_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
157
+
158
  default_render_kwargs = {
159
  "entity_options": {
160
  # we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
 
217
  fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
218
  )
219
  predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
220
+ fn=partial(
221
+ predict,
222
+ pipeline=pipeline,
223
+ embedding_model=embedding_model,
224
+ embedding_tokenizer=embedding_tokenizer,
225
+ ),
226
+ inputs=text,
227
+ outputs=[output_json, output_txt],
228
+ api_name="predict",
229
  ).success(**render_button_kwargs).success(
230
  close_accordion, inputs=[], outputs=[output_accordion]
231
  )
 
316
  rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
317
 
318
  demo.launch()
319
+
320
+
321
+ if __name__ == "__main__":
322
+ # configure logging
323
+ logging.basicConfig()
324
+
325
+ main()
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
- transformers==4.39.3
2
  gradio==4.31.4
3
  prettytable==3.10.0
4
- pie-modules>=0.11.0,<0.12.0
5
  beautifulsoup4==4.12.3
 
 
1
  gradio==4.31.4
2
  prettytable==3.10.0
3
+ pie-modules==0.12.0
4
  beautifulsoup4==4.12.3