ArneBinder
commited on
Commit
•
ee9934e
1
Parent(s):
b5bd682
simple emebdding calculation
Browse files- app.py +105 -5
- requirements.txt +1 -2
app.py
CHANGED
@@ -1,21 +1,84 @@
|
|
1 |
import json
|
2 |
-
|
|
|
|
|
3 |
|
4 |
import gradio as gr
|
|
|
|
|
5 |
from pie_modules.models import * # noqa: F403
|
6 |
from pie_modules.taskmodules import * # noqa: F403
|
|
|
7 |
from pytorch_ie.annotations import LabeledSpan
|
8 |
from pytorch_ie.auto import AutoPipeline
|
9 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
10 |
from pytorch_ie.models import * # noqa: F403
|
11 |
from pytorch_ie.taskmodules import * # noqa: F403
|
12 |
from rendering_utils import render_displacy, render_pretty_table
|
|
|
|
|
|
|
13 |
|
14 |
RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
|
15 |
RENDER_WITH_PRETTY_TABLE = "Pretty Table"
|
16 |
|
17 |
|
18 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(text=text)
|
20 |
|
21 |
# add single partition from the whole text (the model only considers text in partitions)
|
@@ -25,11 +88,27 @@ def predict(text: str) -> Tuple[dict, str]:
|
|
25 |
pipeline(document)
|
26 |
|
27 |
document_dict = document.asdict()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
return document_dict, json.dumps(document_dict)
|
29 |
|
30 |
|
31 |
def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
|
32 |
document_dict = json.loads(document_txt)
|
|
|
|
|
33 |
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
|
34 |
document_dict
|
35 |
)
|
@@ -52,16 +131,18 @@ def close_accordion():
|
|
52 |
return gr.Accordion(open=False)
|
53 |
|
54 |
|
55 |
-
|
56 |
|
57 |
model_name_or_path = "ArneBinder/sam-pointer-bart-base-v0.3"
|
|
|
58 |
revision = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
|
59 |
# local path
|
60 |
-
# model_name_or_path = "models/dataset-sciarg/task-ner_re/v0.3/2024-
|
61 |
# revision = None
|
62 |
|
63 |
example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
|
64 |
|
|
|
65 |
pipeline = AutoPipeline.from_pretrained(
|
66 |
model_name_or_path,
|
67 |
device=-1,
|
@@ -70,6 +151,10 @@ if __name__ == "__main__":
|
|
70 |
model_kwargs=dict(revision=revision),
|
71 |
)
|
72 |
|
|
|
|
|
|
|
|
|
73 |
default_render_kwargs = {
|
74 |
"entity_options": {
|
75 |
# we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
|
@@ -132,7 +217,15 @@ if __name__ == "__main__":
|
|
132 |
fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
|
133 |
)
|
134 |
predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
|
135 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
).success(**render_button_kwargs).success(
|
137 |
close_accordion, inputs=[], outputs=[output_accordion]
|
138 |
)
|
@@ -223,3 +316,10 @@ if __name__ == "__main__":
|
|
223 |
rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
|
224 |
|
225 |
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
2 |
+
import logging
|
3 |
+
from functools import partial
|
4 |
+
from typing import Optional, Tuple
|
5 |
|
6 |
import gradio as gr
|
7 |
+
from pie_modules.document.processing import tokenize_document
|
8 |
+
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
9 |
from pie_modules.models import * # noqa: F403
|
10 |
from pie_modules.taskmodules import * # noqa: F403
|
11 |
+
from pytorch_ie import Pipeline
|
12 |
from pytorch_ie.annotations import LabeledSpan
|
13 |
from pytorch_ie.auto import AutoPipeline
|
14 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
15 |
from pytorch_ie.models import * # noqa: F403
|
16 |
from pytorch_ie.taskmodules import * # noqa: F403
|
17 |
from rendering_utils import render_displacy, render_pretty_table
|
18 |
+
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
|
22 |
RENDER_WITH_DISPLACY = "displaCy + highlighted arguments"
|
23 |
RENDER_WITH_PRETTY_TABLE = "Pretty Table"
|
24 |
|
25 |
|
26 |
+
def embed_text_annotations(
|
27 |
+
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
28 |
+
model: PreTrainedModel,
|
29 |
+
tokenizer: PreTrainedTokenizer,
|
30 |
+
text_layer_name: str,
|
31 |
+
) -> dict:
|
32 |
+
# to not modify the original document
|
33 |
+
document = document.copy()
|
34 |
+
# tokenize_document does not yet consider predictions, so we need to add them manually
|
35 |
+
document[text_layer_name].extend(document[text_layer_name].predictions.clear())
|
36 |
+
added_annotations = []
|
37 |
+
# TODO: set return_overflowing_tokens=True and max_length=...?
|
38 |
+
tokenizer_kwargs = {}
|
39 |
+
tokenized_documents = tokenize_document(
|
40 |
+
document,
|
41 |
+
tokenizer=tokenizer,
|
42 |
+
result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
43 |
+
partition_layer="labeled_partitions",
|
44 |
+
added_annotations=added_annotations,
|
45 |
+
**tokenizer_kwargs,
|
46 |
+
)
|
47 |
+
# just tokenize again to get tensors in the correct format for the model
|
48 |
+
model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
|
49 |
+
assert len(model_inputs.encodings) == len(tokenized_documents)
|
50 |
+
model_output = model(**model_inputs)
|
51 |
+
|
52 |
+
# get embeddings for all text annotations
|
53 |
+
embeddings = {}
|
54 |
+
for batch_idx in range(len(model_output.last_hidden_state)):
|
55 |
+
text2tok_ann = added_annotations[batch_idx][text_layer_name]
|
56 |
+
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
|
57 |
+
for tok_ann in tokenized_documents[batch_idx].labeled_spans:
|
58 |
+
# skip "empty" annotations
|
59 |
+
if tok_ann.start == tok_ann.end:
|
60 |
+
continue
|
61 |
+
# use the max pooling strategy to get a single embedding for the annotation text
|
62 |
+
embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
|
63 |
+
dim=0
|
64 |
+
)[0]
|
65 |
+
text_ann = tok2text_ann[tok_ann]
|
66 |
+
|
67 |
+
if text_ann in embeddings:
|
68 |
+
logger.warning(
|
69 |
+
f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
|
70 |
+
)
|
71 |
+
embeddings[text_ann] = embedding
|
72 |
+
|
73 |
+
return embeddings
|
74 |
+
|
75 |
+
|
76 |
+
def predict(
|
77 |
+
text: str,
|
78 |
+
pipeline: Pipeline,
|
79 |
+
embedding_model: Optional[PreTrainedModel] = None,
|
80 |
+
embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
|
81 |
+
) -> Tuple[dict, str]:
|
82 |
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(text=text)
|
83 |
|
84 |
# add single partition from the whole text (the model only considers text in partitions)
|
|
|
88 |
pipeline(document)
|
89 |
|
90 |
document_dict = document.asdict()
|
91 |
+
|
92 |
+
if embedding_model is not None and embedding_tokenizer is not None:
|
93 |
+
adu_embeddings = embed_text_annotations(
|
94 |
+
document=document,
|
95 |
+
model=embedding_model,
|
96 |
+
tokenizer=embedding_tokenizer,
|
97 |
+
text_layer_name="labeled_spans",
|
98 |
+
)
|
99 |
+
# convert keys to str because JSON keys must be strings
|
100 |
+
adu_embeddings_dict = {str(k._id): v.detach().tolist() for k, v in adu_embeddings.items()}
|
101 |
+
document_dict["embeddings"] = adu_embeddings_dict
|
102 |
+
|
103 |
+
# Return as dict and JSON string. The latter is required because the JSON component converts floats
|
104 |
+
# to ints which destroys de-serialization of the document (the scores of the annotations need to be floats)
|
105 |
return document_dict, json.dumps(document_dict)
|
106 |
|
107 |
|
108 |
def render(document_txt: str, render_with: str, render_kwargs_json: str) -> str:
|
109 |
document_dict = json.loads(document_txt)
|
110 |
+
# remove embeddings from document_dict to make it de-serializable
|
111 |
+
document_dict.pop("embeddings", None)
|
112 |
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions.fromdict(
|
113 |
document_dict
|
114 |
)
|
|
|
131 |
return gr.Accordion(open=False)
|
132 |
|
133 |
|
134 |
+
def main():
|
135 |
|
136 |
model_name_or_path = "ArneBinder/sam-pointer-bart-base-v0.3"
|
137 |
+
# W&B run: https://wandb.ai/arne/dataset-sciarg-task-ner_re-v0.3-training/runs/prik91di
|
138 |
revision = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
|
139 |
# local path
|
140 |
+
# model_name_or_path = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
|
141 |
# revision = None
|
142 |
|
143 |
example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
|
144 |
|
145 |
+
print("loading argumentation mining model ...")
|
146 |
pipeline = AutoPipeline.from_pretrained(
|
147 |
model_name_or_path,
|
148 |
device=-1,
|
|
|
151 |
model_kwargs=dict(revision=revision),
|
152 |
)
|
153 |
|
154 |
+
print("loading SciBERT embedding model ...")
|
155 |
+
embedding_model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
|
156 |
+
embedding_tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
|
157 |
+
|
158 |
default_render_kwargs = {
|
159 |
"entity_options": {
|
160 |
# we need to convert the keys to uppercase because the spacy rendering function expects them in uppercase
|
|
|
217 |
fn=render, inputs=[output_txt, render_as, render_kwargs], outputs=rendered_output
|
218 |
)
|
219 |
predict_btn.click(open_accordion, inputs=[], outputs=[output_accordion]).then(
|
220 |
+
fn=partial(
|
221 |
+
predict,
|
222 |
+
pipeline=pipeline,
|
223 |
+
embedding_model=embedding_model,
|
224 |
+
embedding_tokenizer=embedding_tokenizer,
|
225 |
+
),
|
226 |
+
inputs=text,
|
227 |
+
outputs=[output_json, output_txt],
|
228 |
+
api_name="predict",
|
229 |
).success(**render_button_kwargs).success(
|
230 |
close_accordion, inputs=[], outputs=[output_accordion]
|
231 |
)
|
|
|
316 |
rendered_output.change(fn=None, js=js, inputs=[], outputs=[])
|
317 |
|
318 |
demo.launch()
|
319 |
+
|
320 |
+
|
321 |
+
if __name__ == "__main__":
|
322 |
+
# configure logging
|
323 |
+
logging.basicConfig()
|
324 |
+
|
325 |
+
main()
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
-
transformers==4.39.3
|
2 |
gradio==4.31.4
|
3 |
prettytable==3.10.0
|
4 |
-
pie-modules
|
5 |
beautifulsoup4==4.12.3
|
|
|
|
|
1 |
gradio==4.31.4
|
2 |
prettytable==3.10.0
|
3 |
+
pie-modules==0.12.0
|
4 |
beautifulsoup4==4.12.3
|