ArneBinder
committed on
Commit
•
1681237
1
Parent(s):
148e0d6
Upload 10 files
Browse files from https://github.com/ArneBinder/pie-document-level/pull/221
- app.py +59 -13
- document_store.py +0 -2
- embedding.py +124 -0
- model_utils.py +68 -116
- rendering_utils.py +4 -1
- requirements.txt +3 -0
app.py
CHANGED
@@ -7,8 +7,10 @@ from typing import List, Optional, Tuple
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import pandas as pd
|
|
|
10 |
from document_store import DocumentStore, get_annotation_from_document
|
11 |
-
from
|
|
|
12 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
13 |
from pytorch_ie import Pipeline
|
14 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
@@ -26,6 +28,9 @@ DEFAULT_MODEL_REVISION = "76300f8e534e2fcf695f00cb49bba166739b8d8a"
|
|
26 |
# DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
|
27 |
# DEFAULT_MODEL_REVISION = None
|
28 |
DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
|
|
|
|
|
|
|
29 |
|
30 |
|
31 |
def render_annotated_document(
|
@@ -47,22 +52,26 @@ def render_annotated_document(
|
|
47 |
def wrapped_process_text(
|
48 |
text: str,
|
49 |
doc_id: str,
|
50 |
-
models: Tuple[Pipeline, Optional[
|
51 |
document_store: DocumentStore,
|
52 |
) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
|
53 |
-
|
54 |
-
text=text,
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
59 |
# Return as dict and document to avoid serialization issues
|
60 |
return document.asdict(), document
|
61 |
|
62 |
|
63 |
def process_uploaded_files(
|
64 |
file_names: List[str],
|
65 |
-
models: Tuple[Pipeline, Optional[
|
66 |
document_store: DocumentStore,
|
67 |
) -> pd.DataFrame:
|
68 |
try:
|
@@ -74,7 +83,16 @@ def process_uploaded_files(
|
|
74 |
text = f.read()
|
75 |
base_file_name = os.path.basename(file_name)
|
76 |
gr.Info(f"Processing file '{base_file_name}' ...")
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
else:
|
79 |
raise gr.Error(f"Unsupported file format: {file_name}")
|
80 |
document_store.add_documents(new_documents)
|
@@ -143,10 +161,13 @@ def main():
|
|
143 |
example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
|
144 |
|
145 |
print("Loading models ...")
|
146 |
-
argumentation_model, embedding_model
|
147 |
model_name=DEFAULT_MODEL_NAME,
|
148 |
revision=DEFAULT_MODEL_REVISION,
|
149 |
embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
|
|
|
|
|
|
|
150 |
)
|
151 |
|
152 |
default_render_kwargs = {
|
@@ -179,7 +200,7 @@ def main():
|
|
179 |
DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
|
180 |
)
|
181 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
182 |
-
models_state = gr.State((argumentation_model, embedding_model
|
183 |
with gr.Row():
|
184 |
with gr.Column(scale=1):
|
185 |
doc_id = gr.Textbox(
|
@@ -204,10 +225,35 @@ def main():
|
|
204 |
label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
|
205 |
value=DEFAULT_EMBEDDING_MODEL_NAME,
|
206 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
load_models_btn = gr.Button("Load Models")
|
208 |
load_models_btn.click(
|
209 |
fn=load_models,
|
210 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
outputs=models_state,
|
212 |
)
|
213 |
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import pandas as pd
|
10 |
+
import torch
|
11 |
from document_store import DocumentStore, get_annotation_from_document
|
12 |
+
from embedding import EmbeddingModel
|
13 |
+
from model_utils import annotate_document, create_document, load_models
|
14 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
15 |
from pytorch_ie import Pipeline
|
16 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
|
|
28 |
# DEFAULT_MODEL_NAME = "models/dataset-sciarg/task-ner_re/v0.3/2024-05-28_23-33-46"
|
29 |
# DEFAULT_MODEL_REVISION = None
|
30 |
DEFAULT_EMBEDDING_MODEL_NAME = "allenai/scibert_scivocab_uncased"
|
31 |
+
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
32 |
+
DEFAULT_EMBEDDING_MAX_LENGTH = 512
|
33 |
+
DEFAULT_EMBEDDING_BATCH_SIZE = 32
|
34 |
|
35 |
|
36 |
def render_annotated_document(
|
|
|
52 |
def wrapped_process_text(
|
53 |
text: str,
|
54 |
doc_id: str,
|
55 |
+
models: Tuple[Pipeline, Optional[EmbeddingModel]],
|
56 |
document_store: DocumentStore,
|
57 |
) -> Tuple[dict, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]:
|
58 |
+
try:
|
59 |
+
document = create_document(text=text, doc_id=doc_id)
|
60 |
+
annotate_document(
|
61 |
+
document=document,
|
62 |
+
annotation_pipeline=models[0],
|
63 |
+
embedding_model=models[1],
|
64 |
+
)
|
65 |
+
document_store.add_document(document)
|
66 |
+
except Exception as e:
|
67 |
+
raise gr.Error(f"Failed to process text: {e}")
|
68 |
# Return as dict and document to avoid serialization issues
|
69 |
return document.asdict(), document
|
70 |
|
71 |
|
72 |
def process_uploaded_files(
|
73 |
file_names: List[str],
|
74 |
+
models: Tuple[Pipeline, Optional[EmbeddingModel]],
|
75 |
document_store: DocumentStore,
|
76 |
) -> pd.DataFrame:
|
77 |
try:
|
|
|
83 |
text = f.read()
|
84 |
base_file_name = os.path.basename(file_name)
|
85 |
gr.Info(f"Processing file '{base_file_name}' ...")
|
86 |
+
new_document = create_document(
|
87 |
+
text=text,
|
88 |
+
doc_id=base_file_name,
|
89 |
+
)
|
90 |
+
annotate_document(
|
91 |
+
document=new_document,
|
92 |
+
annotation_pipeline=models[0],
|
93 |
+
embedding_model=models[1],
|
94 |
+
)
|
95 |
+
new_documents.append(new_document)
|
96 |
else:
|
97 |
raise gr.Error(f"Unsupported file format: {file_name}")
|
98 |
document_store.add_documents(new_documents)
|
|
|
161 |
example_text = "Scholarly Argumentation Mining (SAM) has recently gained attention due to its potential to help scholars with the rapid growth of published scientific literature. It comprises two subtasks: argumentative discourse unit recognition (ADUR) and argumentative relation extraction (ARE), both of which are challenging since they require e.g. the integration of domain knowledge, the detection of implicit statements, and the disambiguation of argument structure. While previous work focused on dataset construction and baseline methods for specific document sections, such as abstract or results, full-text scholarly argumentation mining has seen little progress. In this work, we introduce a sequential pipeline model combining ADUR and ARE for full-text SAM, and provide a first analysis of the performance of pretrained language models (PLMs) on both subtasks. We establish a new SotA for ADUR on the Sci-Arg corpus, outperforming the previous best reported result by a large margin (+7% F1). We also present the first results for ARE, and thus for the full AM pipeline, on this benchmark dataset. Our detailed error analysis reveals that non-contiguous ADUs as well as the interpretation of discourse connectors pose major challenges and that data annotation needs to be more consistent."
|
162 |
|
163 |
print("Loading models ...")
|
164 |
+
argumentation_model, embedding_model = load_models(
|
165 |
model_name=DEFAULT_MODEL_NAME,
|
166 |
revision=DEFAULT_MODEL_REVISION,
|
167 |
embedding_model_name=DEFAULT_EMBEDDING_MODEL_NAME,
|
168 |
+
embedding_max_length=DEFAULT_EMBEDDING_MAX_LENGTH,
|
169 |
+
embedding_batch_size=DEFAULT_EMBEDDING_BATCH_SIZE,
|
170 |
+
device=DEFAULT_DEVICE,
|
171 |
)
|
172 |
|
173 |
default_render_kwargs = {
|
|
|
200 |
DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
|
201 |
)
|
202 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
203 |
+
models_state = gr.State((argumentation_model, embedding_model))
|
204 |
with gr.Row():
|
205 |
with gr.Column(scale=1):
|
206 |
doc_id = gr.Textbox(
|
|
|
225 |
label=f"Embedding Model Name (e.g. {DEFAULT_EMBEDDING_MODEL_NAME})",
|
226 |
value=DEFAULT_EMBEDDING_MODEL_NAME,
|
227 |
)
|
228 |
+
embedding_max_length = gr.Slider(
|
229 |
+
label="Embedding Model Max Length",
|
230 |
+
minimum=16,
|
231 |
+
maximum=2048,
|
232 |
+
step=8,
|
233 |
+
value=DEFAULT_EMBEDDING_MAX_LENGTH,
|
234 |
+
)
|
235 |
+
embedding_batch_size = gr.Slider(
|
236 |
+
label="Embedding Model Batch Size",
|
237 |
+
minimum=1,
|
238 |
+
maximum=128,
|
239 |
+
step=1,
|
240 |
+
value=DEFAULT_EMBEDDING_BATCH_SIZE,
|
241 |
+
)
|
242 |
+
device = gr.Textbox(
|
243 |
+
label="Device (e.g. 'cuda' or 'cpu')",
|
244 |
+
value=DEFAULT_DEVICE,
|
245 |
+
)
|
246 |
load_models_btn = gr.Button("Load Models")
|
247 |
load_models_btn.click(
|
248 |
fn=load_models,
|
249 |
+
inputs=[
|
250 |
+
model_name,
|
251 |
+
model_revision,
|
252 |
+
embedding_model_name,
|
253 |
+
embedding_max_length,
|
254 |
+
embedding_batch_size,
|
255 |
+
device,
|
256 |
+
],
|
257 |
outputs=models_state,
|
258 |
)
|
259 |
|
document_store.py
CHANGED
@@ -307,8 +307,6 @@ class DocumentStore:
|
|
307 |
|
308 |
def add_document_from_dict(self, document_dict: dict) -> None:
|
309 |
document = self.document_type.fromdict(document_dict)
|
310 |
-
# metadata is not automatically deserialized, so we need to set it manually
|
311 |
-
document.metadata = document_dict["metadata"]
|
312 |
self.add_document(document)
|
313 |
|
314 |
def add_documents(self, documents: List[TextBasedDocument]) -> None:
|
|
|
307 |
|
308 |
def add_document_from_dict(self, document_dict: dict) -> None:
|
309 |
document = self.document_type.fromdict(document_dict)
|
|
|
|
|
310 |
self.add_document(document)
|
311 |
|
312 |
def add_documents(self, documents: List[TextBasedDocument]) -> None:
|
embedding.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import logging
|
3 |
+
from typing import Dict
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset
|
7 |
+
from pie_modules.document.processing import tokenize_document
|
8 |
+
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
9 |
+
from pytorch_ie.annotations import Span
|
10 |
+
from pytorch_ie.documents import TextBasedDocument
|
11 |
+
from torch import FloatTensor, Tensor
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
from transformers import AutoModel, AutoTokenizer
|
14 |
+
|
15 |
+
logger = logging.getLogger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class EmbeddingModel(abc.ABC):
|
19 |
+
def __call__(
|
20 |
+
self, document: TextBasedDocument, span_layer_name: str
|
21 |
+
) -> Dict[Span, FloatTensor]:
|
22 |
+
"""Embed text annotations from a document.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
document: The document to embed.
|
26 |
+
span_layer_name: The name of the annotation layer in the document that contains the
|
27 |
+
text span annotations to embed.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
A dictionary mapping text annotations to their embeddings.
|
31 |
+
"""
|
32 |
+
pass
|
33 |
+
|
34 |
+
|
35 |
+
class HuggingfaceEmbeddingModel(EmbeddingModel):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
model_name_or_path: str,
|
39 |
+
revision: str = None,
|
40 |
+
device: str = "cpu",
|
41 |
+
max_length: int = 512,
|
42 |
+
batch_size: int = 16,
|
43 |
+
):
|
44 |
+
self.load(model_name_or_path, revision, device)
|
45 |
+
self.max_length = max_length
|
46 |
+
self.batch_size = batch_size
|
47 |
+
|
48 |
+
def load(self, model_name_or_path: str, revision: str = None, device: str = "cpu") -> None:
|
49 |
+
self._model = AutoModel.from_pretrained(model_name_or_path, revision=revision).to(device)
|
50 |
+
self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, revision=revision)
|
51 |
+
|
52 |
+
def __call__(
|
53 |
+
self, document: TextBasedDocument, span_layer_name: str
|
54 |
+
) -> Dict[Span, FloatTensor]:
|
55 |
+
# to not modify the original document
|
56 |
+
document = document.copy()
|
57 |
+
# tokenize_document does not yet consider predictions, so we need to add them manually
|
58 |
+
document[span_layer_name].extend(document[span_layer_name].predictions.clear())
|
59 |
+
added_annotations = []
|
60 |
+
tokenizer_kwargs = {
|
61 |
+
"max_length": self.max_length,
|
62 |
+
"stride": self.max_length // 8,
|
63 |
+
"truncation": True,
|
64 |
+
"padding": True,
|
65 |
+
"return_overflowing_tokens": True,
|
66 |
+
}
|
67 |
+
# tokenize once to get the tokenized documents with mapped annotations
|
68 |
+
tokenized_documents = tokenize_document(
|
69 |
+
document,
|
70 |
+
tokenizer=self._tokenizer,
|
71 |
+
result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
72 |
+
partition_layer="labeled_partitions",
|
73 |
+
added_annotations=added_annotations,
|
74 |
+
strict_span_conversion=False,
|
75 |
+
**tokenizer_kwargs,
|
76 |
+
)
|
77 |
+
|
78 |
+
# just tokenize again to get tensors in the correct format for the model
|
79 |
+
dataset = Dataset.from_dict({"text": [document.text]})
|
80 |
+
|
81 |
+
def tokenize_function(examples):
|
82 |
+
return self._tokenizer(examples["text"], **tokenizer_kwargs)
|
83 |
+
|
84 |
+
# Tokenize the texts. Note that we remove the text column directly in the map call,
|
85 |
+
# otherwise the map would fail because we may produce multiple new rows
|
86 |
+
# (tokenization result) for each input row (text).
|
87 |
+
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
|
88 |
+
# remove the overflow_to_sample_mapping column
|
89 |
+
tokenized_dataset = tokenized_dataset.remove_columns(["overflow_to_sample_mapping"])
|
90 |
+
tokenized_dataset.set_format(type="torch")
|
91 |
+
|
92 |
+
dataloader = DataLoader(tokenized_dataset, batch_size=self.batch_size)
|
93 |
+
|
94 |
+
embeddings = {}
|
95 |
+
example_idx = 0
|
96 |
+
for batch in dataloader:
|
97 |
+
batch_at_device = {
|
98 |
+
k: v.to(self._model.device) if isinstance(v, Tensor) else v
|
99 |
+
for k, v in batch.items()
|
100 |
+
}
|
101 |
+
with torch.no_grad():
|
102 |
+
model_output = self._model(**batch_at_device)
|
103 |
+
|
104 |
+
for last_hidden_state in model_output.last_hidden_state:
|
105 |
+
text2tok_ann = added_annotations[example_idx][span_layer_name]
|
106 |
+
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
|
107 |
+
for tok_ann in tokenized_documents[example_idx].labeled_spans:
|
108 |
+
# skip "empty" annotations
|
109 |
+
if tok_ann.start == tok_ann.end:
|
110 |
+
continue
|
111 |
+
# use the max pooling strategy to get a single embedding for the annotation text
|
112 |
+
embedding = (
|
113 |
+
last_hidden_state[tok_ann.start : tok_ann.end].max(dim=0)[0].detach().cpu()
|
114 |
+
)
|
115 |
+
text_ann = tok2text_ann[tok_ann]
|
116 |
+
|
117 |
+
if text_ann in embeddings:
|
118 |
+
logger.warning(
|
119 |
+
f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
|
120 |
+
)
|
121 |
+
embeddings[text_ann] = embedding
|
122 |
+
example_idx += 1
|
123 |
+
|
124 |
+
return embeddings
|
model_utils.py
CHANGED
@@ -1,98 +1,45 @@
|
|
1 |
import logging
|
2 |
-
from typing import
|
3 |
|
4 |
import gradio as gr
|
|
|
5 |
from annotation_utils import labeled_span_to_id
|
6 |
-
from
|
7 |
-
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
8 |
from pytorch_ie import Pipeline
|
9 |
from pytorch_ie.annotations import LabeledSpan
|
10 |
from pytorch_ie.auto import AutoPipeline
|
11 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
12 |
-
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
|
13 |
|
14 |
logger = logging.getLogger(__name__)
|
15 |
|
16 |
|
17 |
-
def
|
18 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
19 |
-
|
20 |
-
|
21 |
-
text_layer_name: str,
|
22 |
-
) -> Dict[LabeledSpan, List[float]]:
|
23 |
-
# to not modify the original document
|
24 |
-
document = document.copy()
|
25 |
-
# tokenize_document does not yet consider predictions, so we need to add them manually
|
26 |
-
document[text_layer_name].extend(document[text_layer_name].predictions.clear())
|
27 |
-
added_annotations = []
|
28 |
-
tokenizer_kwargs = {
|
29 |
-
"max_length": 512,
|
30 |
-
"stride": 64,
|
31 |
-
"truncation": True,
|
32 |
-
"return_overflowing_tokens": True,
|
33 |
-
}
|
34 |
-
tokenized_documents = tokenize_document(
|
35 |
-
document,
|
36 |
-
tokenizer=tokenizer,
|
37 |
-
result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
38 |
-
partition_layer="labeled_partitions",
|
39 |
-
added_annotations=added_annotations,
|
40 |
-
strict_span_conversion=False,
|
41 |
-
**tokenizer_kwargs,
|
42 |
-
)
|
43 |
-
# just tokenize again to get tensors in the correct format for the model
|
44 |
-
model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
|
45 |
-
# this is added when using return_overflowing_tokens=True, but the model does not accept it
|
46 |
-
model_inputs.pop("overflow_to_sample_mapping", None)
|
47 |
-
assert len(model_inputs.encodings) == len(tokenized_documents)
|
48 |
-
model_output = model(**model_inputs)
|
49 |
-
|
50 |
-
# get embeddings for all text annotations
|
51 |
-
embeddings = {}
|
52 |
-
for batch_idx in range(len(model_output.last_hidden_state)):
|
53 |
-
text2tok_ann = added_annotations[batch_idx][text_layer_name]
|
54 |
-
tok2text_ann = {v: k for k, v in text2tok_ann.items()}
|
55 |
-
for tok_ann in tokenized_documents[batch_idx].labeled_spans:
|
56 |
-
# skip "empty" annotations
|
57 |
-
if tok_ann.start == tok_ann.end:
|
58 |
-
continue
|
59 |
-
# use the max pooling strategy to get a single embedding for the annotation text
|
60 |
-
embedding = model_output.last_hidden_state[batch_idx, tok_ann.start : tok_ann.end].max(
|
61 |
-
dim=0
|
62 |
-
)[0]
|
63 |
-
text_ann = tok2text_ann[tok_ann]
|
64 |
-
|
65 |
-
if text_ann in embeddings:
|
66 |
-
logger.warning(
|
67 |
-
f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
|
68 |
-
)
|
69 |
-
embeddings[text_ann] = embedding
|
70 |
-
|
71 |
-
return embeddings
|
72 |
-
|
73 |
-
|
74 |
-
def _annotate(
|
75 |
-
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
76 |
-
pipeline: Pipeline,
|
77 |
-
embedding_model: Optional[PreTrainedModel] = None,
|
78 |
-
embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
|
79 |
) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
# execute prediction pipeline
|
82 |
-
|
83 |
|
84 |
-
if embedding_model is not None
|
85 |
-
|
86 |
document=document,
|
87 |
-
|
88 |
-
tokenizer=embedding_tokenizer,
|
89 |
-
text_layer_name="labeled_spans",
|
90 |
)
|
91 |
# convert keys to str because JSON keys must be strings
|
92 |
-
|
93 |
-
labeled_span_to_id(k): v.
|
94 |
}
|
95 |
-
document.metadata["embeddings"] =
|
96 |
else:
|
97 |
gr.Warning(
|
98 |
"No embedding model provided. Skipping embedding extraction. You can load an embedding "
|
@@ -100,47 +47,47 @@ def _annotate(
|
|
100 |
)
|
101 |
|
102 |
|
103 |
-
def
|
104 |
-
text: str,
|
105 |
-
doc_id: str,
|
106 |
-
models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
|
107 |
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
|
108 |
"""Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
|
109 |
-
text
|
110 |
|
111 |
Parameters:
|
112 |
text: The text to process.
|
113 |
doc_id: The ID of the document.
|
114 |
-
models: A tuple containing the prediction pipeline and the embedding model and tokenizer.
|
115 |
|
116 |
Returns:
|
117 |
The processed document.
|
118 |
"""
|
119 |
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
# annotate the document
|
127 |
-
_annotate(
|
128 |
-
document=document,
|
129 |
-
pipeline=models[0],
|
130 |
-
embedding_model=models[1],
|
131 |
-
embedding_tokenizer=models[2],
|
132 |
-
)
|
133 |
-
|
134 |
-
return document
|
135 |
-
except Exception as e:
|
136 |
-
raise gr.Error(f"Failed to process text: {e}")
|
137 |
|
138 |
|
139 |
-
def load_argumentation_model(
|
|
|
|
|
|
|
|
|
140 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
model = AutoPipeline.from_pretrained(
|
142 |
model_name,
|
143 |
-
device
|
144 |
num_workers=0,
|
145 |
taskmodule_kwargs=dict(revision=revision),
|
146 |
model_kwargs=dict(revision=revision),
|
@@ -151,23 +98,28 @@ def load_argumentation_model(model_name: str, revision: Optional[str] = None) ->
|
|
151 |
return model
|
152 |
|
153 |
|
154 |
-
def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
155 |
-
try:
|
156 |
-
embedding_model = AutoModel.from_pretrained(model_name)
|
157 |
-
embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
158 |
-
except Exception as e:
|
159 |
-
raise gr.Error(f"Failed to load embedding model: {e}")
|
160 |
-
gr.Info(f"Loaded embedding model: model_name={model_name})")
|
161 |
-
return embedding_model, embedding_tokenizer
|
162 |
-
|
163 |
-
|
164 |
def load_models(
|
165 |
-
model_name: str,
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
embedding_model = None
|
169 |
-
embedding_tokenizer = None
|
170 |
if embedding_model_name is not None and embedding_model_name.strip():
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import logging
|
2 |
+
from typing import Optional, Tuple
|
3 |
|
4 |
import gradio as gr
|
5 |
+
import torch
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
+
from embedding import EmbeddingModel, HuggingfaceEmbeddingModel
|
|
|
8 |
from pytorch_ie import Pipeline
|
9 |
from pytorch_ie.annotations import LabeledSpan
|
10 |
from pytorch_ie.auto import AutoPipeline
|
11 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
|
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
|
16 |
+
def annotate_document(
|
17 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
18 |
+
annotation_pipeline: Pipeline,
|
19 |
+
embedding_model: Optional[EmbeddingModel] = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
) -> None:
|
21 |
+
"""Annotate a document with the provided pipeline. If an embedding model is provided, also
|
22 |
+
extract embeddings for the labeled spans.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
document: The document to annotate.
|
26 |
+
annotation_pipeline: The pipeline to use for annotation.
|
27 |
+
embedding_model: The embedding model to use for extracting text span embeddings.
|
28 |
+
"""
|
29 |
|
30 |
# execute prediction pipeline
|
31 |
+
annotation_pipeline(document)
|
32 |
|
33 |
+
if embedding_model is not None:
|
34 |
+
text_span_embeddings = embedding_model(
|
35 |
document=document,
|
36 |
+
span_layer_name="labeled_spans",
|
|
|
|
|
37 |
)
|
38 |
# convert keys to str because JSON keys must be strings
|
39 |
+
text_span_embeddings_dict = {
|
40 |
+
labeled_span_to_id(k): v.tolist() for k, v in text_span_embeddings.items()
|
41 |
}
|
42 |
+
document.metadata["embeddings"] = text_span_embeddings_dict
|
43 |
else:
|
44 |
gr.Warning(
|
45 |
"No embedding model provided. Skipping embedding extraction. You can load an embedding "
|
|
|
47 |
)
|
48 |
|
49 |
|
50 |
+
def create_document(
|
51 |
+
text: str, doc_id: str
|
|
|
|
|
52 |
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
|
53 |
"""Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
|
54 |
+
text.
|
55 |
|
56 |
Parameters:
|
57 |
text: The text to process.
|
58 |
doc_id: The ID of the document.
|
|
|
59 |
|
60 |
Returns:
|
61 |
The processed document.
|
62 |
"""
|
63 |
|
64 |
+
document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
|
65 |
+
id=doc_id, text=text, metadata={}
|
66 |
+
)
|
67 |
+
# add single partition from the whole text (the model only considers text in partitions)
|
68 |
+
document.labeled_partitions.append(LabeledSpan(start=0, end=len(text), label="text"))
|
69 |
+
return document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
|
72 |
+
def load_argumentation_model(
|
73 |
+
model_name: str,
|
74 |
+
revision: Optional[str] = None,
|
75 |
+
device: str = "cpu",
|
76 |
+
) -> Pipeline:
|
77 |
try:
|
78 |
+
# the Pipeline class expects an integer for the device
|
79 |
+
if device == "cuda":
|
80 |
+
pipeline_device = 0
|
81 |
+
elif device.startswith("cuda:"):
|
82 |
+
pipeline_device = int(device.split(":")[1])
|
83 |
+
elif device == "cpu":
|
84 |
+
pipeline_device = -1
|
85 |
+
else:
|
86 |
+
raise gr.Error(f"Invalid device: {device}")
|
87 |
+
|
88 |
model = AutoPipeline.from_pretrained(
|
89 |
model_name,
|
90 |
+
device=pipeline_device,
|
91 |
num_workers=0,
|
92 |
taskmodule_kwargs=dict(revision=revision),
|
93 |
model_kwargs=dict(revision=revision),
|
|
|
98 |
return model
|
99 |
|
100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
def load_models(
|
102 |
+
model_name: str,
|
103 |
+
revision: Optional[str] = None,
|
104 |
+
embedding_model_name: Optional[str] = None,
|
105 |
+
# embedding_model_revision: Optional[str] = None,
|
106 |
+
embedding_max_length: int = 512,
|
107 |
+
embedding_batch_size: int = 16,
|
108 |
+
device: str = "cpu",
|
109 |
+
) -> Tuple[Pipeline, Optional[EmbeddingModel]]:
|
110 |
+
torch.cuda.empty_cache()
|
111 |
+
argumentation_model = load_argumentation_model(model_name, revision=revision, device=device)
|
112 |
embedding_model = None
|
|
|
113 |
if embedding_model_name is not None and embedding_model_name.strip():
|
114 |
+
try:
|
115 |
+
embedding_model = HuggingfaceEmbeddingModel(
|
116 |
+
embedding_model_name.strip(),
|
117 |
+
# revision=embedding_model_revision,
|
118 |
+
device=device,
|
119 |
+
max_length=embedding_max_length,
|
120 |
+
batch_size=embedding_batch_size,
|
121 |
+
)
|
122 |
+
except Exception as e:
|
123 |
+
raise gr.Error(f"Failed to load embedding model: {e}")
|
124 |
+
|
125 |
+
return argumentation_model, embedding_model
|
rendering_utils.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import json
|
|
|
2 |
from collections import defaultdict
|
3 |
from typing import Dict, List, Optional, Union
|
4 |
|
@@ -7,6 +8,8 @@ from pytorch_ie.annotations import BinaryRelation
|
|
7 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
8 |
from rendering_utils_displacy import EntityRenderer
|
9 |
|
|
|
|
|
10 |
|
11 |
def render_pretty_table(
|
12 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
@@ -92,7 +95,7 @@ def inject_relation_data(
|
|
92 |
entity_annotation = sorted_entities[idx]
|
93 |
# sanity check
|
94 |
if str(entity_annotation) != entity.next:
|
95 |
-
|
96 |
entity["data-label"] = entity_annotation.label
|
97 |
entity["data-relation-tails"] = json.dumps(
|
98 |
[
|
|
|
1 |
import json
|
2 |
+
import logging
|
3 |
from collections import defaultdict
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
|
|
8 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
9 |
from rendering_utils_displacy import EntityRenderer
|
10 |
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
|
14 |
def render_pretty_table(
|
15 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
|
|
95 |
entity_annotation = sorted_entities[idx]
|
96 |
# sanity check
|
97 |
if str(entity_annotation) != entity.next:
|
98 |
+
logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
99 |
entity["data-label"] = entity_annotation.label
|
100 |
entity["data-relation-tails"] = json.dumps(
|
101 |
[
|
requirements.txt
CHANGED
@@ -2,3 +2,6 @@ gradio==4.36.0
|
|
2 |
prettytable==3.10.0
|
3 |
pie-modules==0.12.0
|
4 |
beautifulsoup4==4.12.3
|
|
|
|
|
|
|
|
2 |
prettytable==3.10.0
|
3 |
pie-modules==0.12.0
|
4 |
beautifulsoup4==4.12.3
|
5 |
+
datasets==2.14.4
|
6 |
+
# numpy 2.0.0 breaks the code
|
7 |
+
numpy==1.25.2
|