Commit
•
86277c0
1
Parent(s):
04ce9af
Upload 9 files
Browse files- annotation_utils.py +10 -0
- app.py +8 -49
- document_store.py +218 -0
- model_utils.py +173 -0
- rendering_utils.py +2 -10
- vector_store.py +18 -3
annotation_utils.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pytorch_ie.annotations import LabeledSpan
|
2 |
+
|
3 |
+
|
4 |
+
def labeled_span_to_id(span: LabeledSpan) -> str:
    """Build the string id ``span-<start>-<end>-<label>`` for a labeled span.

    Only the ``start``, ``end`` and ``label`` attributes of ``span`` are used.
    """
    id_fields = ("span", span.start, span.end, span.label)
    return "-".join(f"{field}" for field in id_fields)
|
6 |
+
|
7 |
+
|
8 |
+
def labeled_span_from_id(span_id: str) -> LabeledSpan:
    """Reconstruct a ``LabeledSpan`` from an id of the form ``span-<start>-<end>-<label>``.

    Parameters:
        span_id: The id string, e.g. ``"span-3-17-claim"``.

    Returns:
        A ``LabeledSpan`` with the start, end and label encoded in the id.

    Raises:
        ValueError: If the id has fewer than four ``-``-separated fields, or if
            the start/end fields are not integers.
    """
    # Split at most three times so that a label which itself contains "-"
    # (e.g. "own-claim") is kept intact instead of being cut at its first hyphen.
    parts = span_id.split("-", 3)
    if len(parts) != 4:
        raise ValueError(f"Malformed span id: {span_id!r}")
    _prefix, start, end, label = parts
    return LabeledSpan(start=int(start), end=int(end), label=label)
|
app.py
CHANGED
@@ -7,13 +7,13 @@ from typing import List, Optional, Tuple
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import pandas as pd
|
10 |
-
from
|
|
|
11 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
12 |
from pytorch_ie import Pipeline
|
13 |
-
from pytorch_ie.auto import AutoPipeline
|
14 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
15 |
from rendering_utils import render_displacy, render_pretty_table
|
16 |
-
from transformers import
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
@@ -66,6 +66,7 @@ def process_uploaded_files(
|
|
66 |
document_store: DocumentStore,
|
67 |
) -> pd.DataFrame:
|
68 |
try:
|
|
|
69 |
for file_name in file_names:
|
70 |
if file_name.lower().endswith(".txt"):
|
71 |
# read the file content
|
@@ -73,10 +74,10 @@ def process_uploaded_files(
|
|
73 |
text = f.read()
|
74 |
base_file_name = os.path.basename(file_name)
|
75 |
gr.Info(f"Processing file '{base_file_name}' ...")
|
76 |
-
|
77 |
-
document_store.add_document(document)
|
78 |
else:
|
79 |
raise gr.Error(f"Unsupported file format: {file_name}")
|
|
|
80 |
except Exception as e:
|
81 |
raise gr.Error(f"Failed to process uploaded files: {e}")
|
82 |
|
@@ -91,43 +92,6 @@ def close_accordion():
|
|
91 |
return gr.Accordion(open=False)
|
92 |
|
93 |
|
94 |
-
def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
|
95 |
-
try:
|
96 |
-
model = AutoPipeline.from_pretrained(
|
97 |
-
model_name,
|
98 |
-
device=-1,
|
99 |
-
num_workers=0,
|
100 |
-
taskmodule_kwargs=dict(revision=revision),
|
101 |
-
model_kwargs=dict(revision=revision),
|
102 |
-
)
|
103 |
-
except Exception as e:
|
104 |
-
raise gr.Error(f"Failed to load argumentation model: {e}")
|
105 |
-
gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision})")
|
106 |
-
return model
|
107 |
-
|
108 |
-
|
109 |
-
def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
|
110 |
-
try:
|
111 |
-
embedding_model = AutoModel.from_pretrained(model_name)
|
112 |
-
embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
|
113 |
-
except Exception as e:
|
114 |
-
raise gr.Error(f"Failed to load embedding model: {e}")
|
115 |
-
gr.Info(f"Loaded embedding model: model_name={model_name})")
|
116 |
-
return embedding_model, embedding_tokenizer
|
117 |
-
|
118 |
-
|
119 |
-
def load_models(
|
120 |
-
model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
|
121 |
-
) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
|
122 |
-
argumentation_model = load_argumentation_model(model_name, revision)
|
123 |
-
embedding_model = None
|
124 |
-
embedding_tokenizer = None
|
125 |
-
if embedding_model_name is not None and embedding_model_name.strip():
|
126 |
-
embedding_model, embedding_tokenizer = load_embedding_model(embedding_model_name)
|
127 |
-
|
128 |
-
return argumentation_model, embedding_model, embedding_tokenizer
|
129 |
-
|
130 |
-
|
131 |
def select_processed_document(
|
132 |
evt: gr.SelectData,
|
133 |
processed_documents_df: pd.DataFrame,
|
@@ -135,7 +99,6 @@ def select_processed_document(
|
|
135 |
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
|
136 |
row_idx, col_idx = evt.index
|
137 |
doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
|
138 |
-
gr.Info(f"Select document: {doc_id}")
|
139 |
doc = document_store.get_document(doc_id)
|
140 |
return doc
|
141 |
|
@@ -163,8 +126,7 @@ def download_processed_documents(
|
|
163 |
file_name: str = "processed_documents.json",
|
164 |
) -> str:
|
165 |
file_path = os.path.join(tempfile.gettempdir(), file_name)
|
166 |
-
|
167 |
-
json.dump(document_store.as_dict(), f, indent=2)
|
168 |
return file_path
|
169 |
|
170 |
|
@@ -172,10 +134,7 @@ def upload_processed_documents(
|
|
172 |
file_name: str,
|
173 |
document_store: DocumentStore,
|
174 |
) -> pd.DataFrame:
|
175 |
-
|
176 |
-
processed_documents_json = json.load(f)
|
177 |
-
for _, document_json in processed_documents_json.items():
|
178 |
-
document_store.add_document_from_dict(document_dict=document_json)
|
179 |
return document_store.overview()
|
180 |
|
181 |
|
|
|
7 |
|
8 |
import gradio as gr
|
9 |
import pandas as pd
|
10 |
+
from document_store import DocumentStore, get_annotation_from_document
|
11 |
+
from model_utils import create_and_annotate_document, load_models
|
12 |
from pie_modules.taskmodules import PointerNetworkTaskModuleForEnd2EndRE
|
13 |
from pytorch_ie import Pipeline
|
|
|
14 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
15 |
from rendering_utils import render_displacy, render_pretty_table
|
16 |
+
from transformers import PreTrainedModel, PreTrainedTokenizer
|
17 |
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
|
|
66 |
document_store: DocumentStore,
|
67 |
) -> pd.DataFrame:
|
68 |
try:
|
69 |
+
new_documents = []
|
70 |
for file_name in file_names:
|
71 |
if file_name.lower().endswith(".txt"):
|
72 |
# read the file content
|
|
|
74 |
text = f.read()
|
75 |
base_file_name = os.path.basename(file_name)
|
76 |
gr.Info(f"Processing file '{base_file_name}' ...")
|
77 |
+
new_documents.append(create_and_annotate_document(text, base_file_name, models))
|
|
|
78 |
else:
|
79 |
raise gr.Error(f"Unsupported file format: {file_name}")
|
80 |
+
document_store.add_documents(new_documents)
|
81 |
except Exception as e:
|
82 |
raise gr.Error(f"Failed to process uploaded files: {e}")
|
83 |
|
|
|
92 |
return gr.Accordion(open=False)
|
93 |
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
def select_processed_document(
|
96 |
evt: gr.SelectData,
|
97 |
processed_documents_df: pd.DataFrame,
|
|
|
99 |
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
|
100 |
row_idx, col_idx = evt.index
|
101 |
doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
|
|
|
102 |
doc = document_store.get_document(doc_id)
|
103 |
return doc
|
104 |
|
|
|
126 |
file_name: str = "processed_documents.json",
|
127 |
) -> str:
|
128 |
file_path = os.path.join(tempfile.gettempdir(), file_name)
|
129 |
+
document_store.save_to_json(file_path, indent=2)
|
|
|
130 |
return file_path
|
131 |
|
132 |
|
|
|
134 |
file_name: str,
|
135 |
document_store: DocumentStore,
|
136 |
) -> pd.DataFrame:
|
137 |
+
document_store.add_from_json(file_name)
|
|
|
|
|
|
|
138 |
return document_store.overview()
|
139 |
|
140 |
|
document_store.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from collections import defaultdict
|
4 |
+
from typing import Dict, List, Optional, Tuple
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
import pandas as pd
|
8 |
+
from annotation_utils import labeled_span_to_id
|
9 |
+
from pytorch_ie.annotations import LabeledSpan
|
10 |
+
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
11 |
+
from vector_store import SimpleVectorStore, VectorStore
|
12 |
+
|
13 |
+
logger = logging.getLogger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
def get_annotation_from_document(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    annotation_id: str,
    annotation_layer: str,
) -> LabeledSpan:
    """Look up a predicted annotation of ``document`` by its string id.

    Parameters:
        document: The document whose annotation layer is searched.
        annotation_id: The id as produced by ``labeled_span_to_id``.
        annotation_layer: Name of the annotation layer to search.

    Returns:
        The matching annotation.

    Raises:
        gr.Error: If no predicted annotation in the layer has the requested id.
    """
    # only the predictions of the layer are considered
    id2annotation = {}
    for candidate in document[annotation_layer].predictions:
        id2annotation[labeled_span_to_id(candidate)] = candidate

    match = id2annotation.get(annotation_id)
    if match is not None:
        return match

    raise gr.Error(
        f"annotation '{annotation_id}' not found in document '{document.id}'. Available "
        f"annotations: {id2annotation}"
    )
|
31 |
+
|
32 |
+
|
33 |
+
class DocumentStore:
    """In-memory store of annotated documents plus a vector store for their ADU embeddings.

    Documents are kept in a dict keyed by document id. Embeddings found in a
    document's ``metadata["embeddings"]`` are mirrored into the vector store under
    the key ``(document id, annotation id)``.
    """

    DOCUMENT_TYPE = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions

    def __init__(self, vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None):
        """Create an empty store.

        Parameters:
            vector_store: The vector store to use. If ``None``, a new
                ``SimpleVectorStore`` is created.
        """
        # The annotated documents. As key, we use the document id. All documents keep the
        # embeddings of the ADUs in the metadata.
        self.documents: Dict[
            str, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
        ] = {}
        # The vector store to efficiently retrieve similar ADUs. Can be constructed from the
        # documents. Compare against None explicitly so that a store passed by the caller is
        # never replaced just because it evaluates as falsy.
        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
            vector_store if vector_store is not None else SimpleVectorStore()
        )

    def get_annotation(
        self,
        doc_id: str,
        annotation_id: str,
        annotation_layer: str,
    ) -> LabeledSpan:
        """Return the predicted annotation ``annotation_id`` of document ``doc_id``.

        Raises:
            gr.Error: If the document is not in the store (or, via
                ``get_annotation_from_document``, if the annotation is not found).
        """
        document = self.documents.get(doc_id)
        if document is None:
            raise gr.Error(
                f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
            )
        return get_annotation_from_document(document, annotation_id, annotation_layer)

    def get_similar_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
    ) -> pd.DataFrame:
        """Retrieve the entries of the vector store that are similar to a reference ADU.

        Parameters:
            ref_annotation_id: Id of the reference annotation.
            ref_document: The document the reference annotation belongs to.
            min_similarity: Passed to ``vector_store.retrieve_similar``.
            top_k: Passed to ``vector_store.retrieve_similar``.

        Returns:
            A data frame with the columns ``doc_id``, ``adu_id``, ``sim_score`` and ``text``.
        """
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )

        # unpack the tuple (doc_id, annotation_id) to separate columns
        # and add the similarity score and the text of the annotation
        rows = [
            (
                doc_id,
                annotation_id,
                score,
                str(
                    self.get_annotation(
                        doc_id=doc_id,
                        annotation_id=annotation_id,
                        annotation_layer="labeled_spans",
                    )
                ),
            )
            for (doc_id, annotation_id), score in similar_entries
        ]
        return pd.DataFrame(rows, columns=["doc_id", "adu_id", "sim_score", "text"])

    def get_relevant_adus_df(
        self,
        ref_annotation_id: str,
        ref_document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        min_similarity: float,
        top_k: int,
        relation_types: List[str],
        columns: List[str],
    ) -> pd.DataFrame:
        """Collect ADUs that are related to ADUs similar to the reference ADU.

        For every similar ADU from a *different* document, each predicted binary
        relation whose label is in ``relation_types`` and whose head is that ADU
        contributes one row describing the relation tail.

        Parameters:
            ref_annotation_id: Id of the reference annotation.
            ref_document: The document the reference annotation belongs to.
            min_similarity: Passed to ``vector_store.retrieve_similar``.
            top_k: Passed to ``vector_store.retrieve_similar``.
            relation_types: Relation labels to consider.
            columns: Columns (and their order) of the resulting data frame.

        Returns:
            A data frame built from rows with the keys ``doc_id``, ``reference_adu``,
            ``sim_score``, ``rel_score``, ``relation`` and ``adu``, restricted to ``columns``.
        """
        similar_entries = self.vector_store.retrieve_similar(
            ref_id=(ref_document.id, ref_annotation_id),
            min_similarity=min_similarity,
            top_k=top_k,
        )
        result = []
        for (doc_id, annotation_id), score in similar_entries:
            # skip entries from the same document
            if doc_id == ref_document.id:
                continue
            document = self.documents[doc_id]
            head2rels = defaultdict(list)
            for rel in document.binary_relations.predictions:
                # skip non-argumentative relations
                if rel.label not in relation_types:
                    continue
                head2rels[rel.head].append(rel)

            id2annotation = {
                labeled_span_to_id(annotation): annotation
                for annotation in document.labeled_spans.predictions
            }
            annotation = id2annotation.get(annotation_id)
            # note: we do not need to check if the annotation is different from the reference
            # annotation, because they come from different documents and we already skip
            # entries from the same document
            for rel in head2rels.get(annotation, []):
                result.append(
                    {
                        "doc_id": doc_id,
                        "reference_adu": str(annotation),
                        "sim_score": score,
                        "rel_score": rel.score,
                        "relation": rel.label,
                        "adu": str(rel.tail),
                    }
                )

        # define column order
        return pd.DataFrame(result, columns=columns)

    def add_document(
        self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
    ) -> None:
        """Add a single document to the store, overwriting any document with the same id.

        Embeddings stored in ``document.metadata["embeddings"]`` are saved to the vector
        store. A document without that entry is still added; it just contributes no
        vectors. This happens when the document was annotated without an embedding model.

        Raises:
            gr.Error: If the document could not be added.
        """
        try:
            if document.id in self.documents:
                gr.Warning(f"Document '{document.id}' already in index. Overwriting.")

            # save the processed document to the index
            self.documents[document.id] = document

            # save the embeddings to the vector store (there may be none)
            embeddings = document.metadata.get("embeddings") or {}
            for adu_id, embedding in embeddings.items():
                self.vector_store.save((document.id, adu_id), embedding)

        except Exception as e:
            raise gr.Error(f"Failed to add document {document.id} to index: {e}") from e

    def add_document_from_dict(self, document_dict: dict) -> None:
        """Deserialize a document from its dict representation and add it to the store."""
        document = self.DOCUMENT_TYPE.fromdict(document_dict)
        # metadata is not automatically deserialized, so we need to set it manually
        document.metadata = document_dict["metadata"]
        self.add_document(document)

    def add_documents(
        self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
    ) -> None:
        """Add several documents and report how much the store grew."""
        size_before = len(self.documents)
        for document in documents:
            self.add_document(document)
        size_after = len(self.documents)
        gr.Info(
            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
        )

    def add_from_json(self, file_path: str) -> None:
        """Load documents from a JSON file (as written by ``save_to_json``) into the store."""
        size_before = len(self.documents)
        with open(file_path, "r", encoding="utf-8") as f:
            processed_documents_json = json.load(f)
        # the keys (document ids) are not needed: each document dict is deserialized on its own
        for document_json in processed_documents_json.values():
            self.add_document_from_dict(document_dict=document_json)
        size_after = len(self.documents)
        gr.Info(
            f"Added {size_after - size_before} documents to the index ({size_after} documents in total)."
        )

    def save_to_json(self, file_path: str, **kwargs) -> None:
        """Write all documents to ``file_path`` as JSON; ``kwargs`` go to ``json.dump``."""
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(self.as_dict(), f, **kwargs)

    def get_document(
        self, doc_id: str
    ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
        """Return the document with the given id (raises ``KeyError`` if it is unknown)."""
        return self.documents[doc_id]

    def overview(self) -> pd.DataFrame:
        """Return one row per document with its number of predicted ADUs and relations."""
        rows = [
            (
                doc_id,
                len(document.labeled_spans.predictions),
                len(document.binary_relations.predictions),
            )
            for doc_id, document in self.documents.items()
        ]
        return pd.DataFrame(rows, columns=["doc_id", "num_adus", "num_relations"])

    def as_dict(self) -> dict:
        """Return a mapping from document id to the dict representation of the document."""
        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
|
model_utils.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from annotation_utils import labeled_span_to_id
|
6 |
+
from pie_modules.document.processing import tokenize_document
|
7 |
+
from pie_modules.documents import TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
8 |
+
from pytorch_ie import Pipeline
|
9 |
+
from pytorch_ie.annotations import LabeledSpan
|
10 |
+
from pytorch_ie.auto import AutoPipeline
|
11 |
+
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
12 |
+
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
|
13 |
+
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
def _embed_text_annotations(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    text_layer_name: str,
) -> Dict[LabeledSpan, List[float]]:
    """Compute one embedding per predicted annotation of a text layer.

    The document is tokenized into windows, the model is run on these windows, and
    each annotation is represented by max pooling over the hidden states of its tokens.

    Parameters:
        document: The annotated document. It is copied and not modified.
        model: The embedding model.
        tokenizer: The tokenizer belonging to ``model``.
        text_layer_name: Name of the annotation layer whose predictions are embedded.

    Returns:
        A mapping from the text annotation to its embedding.
        NOTE(review): the values are the pooled slices of ``last_hidden_state``
        (tensors), not plain lists as the annotation suggests; the caller converts
        them with ``.detach().tolist()``.
    """
    # local import: the surrounding module does not import torch at the top level
    import torch

    # to not modify the original document
    document = document.copy()
    # tokenize_document does not yet consider predictions, so we need to add them manually
    document[text_layer_name].extend(document[text_layer_name].predictions.clear())
    added_annotations = []
    tokenizer_kwargs = {
        "max_length": 512,
        "stride": 64,
        "truncation": True,
        "return_overflowing_tokens": True,
    }
    tokenized_documents = tokenize_document(
        document,
        tokenizer=tokenizer,
        result_document_type=TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
        partition_layer="labeled_partitions",
        added_annotations=added_annotations,
        strict_span_conversion=False,
        **tokenizer_kwargs,
    )
    # just tokenize again to get tensors in the correct format for the model
    model_inputs = tokenizer(document.text, return_tensors="pt", **tokenizer_kwargs)
    # this is added when using return_overflowing_tokens=True, but the model does not accept it
    model_inputs.pop("overflow_to_sample_mapping", None)
    assert len(model_inputs.encodings) == len(tokenized_documents)
    # pure inference: do not record the autograd graph for the forward pass
    with torch.no_grad():
        model_output = model(**model_inputs)

    # get embeddings for all text annotations
    embeddings = {}
    for batch_idx, hidden_states in enumerate(model_output.last_hidden_state):
        text2tok_ann = added_annotations[batch_idx][text_layer_name]
        tok2text_ann = {v: k for k, v in text2tok_ann.items()}
        for tok_ann in tokenized_documents[batch_idx].labeled_spans:
            # skip "empty" annotations
            if tok_ann.start == tok_ann.end:
                continue
            # use the max pooling strategy to get a single embedding for the annotation text
            # (max over the token dimension returns (values, indices); keep the values)
            embedding = hidden_states[tok_ann.start : tok_ann.end].max(dim=0)[0]
            text_ann = tok2text_ann[tok_ann]

            if text_ann in embeddings:
                logger.warning(
                    f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
                )
            embeddings[text_ann] = embedding

    return embeddings
|
72 |
+
|
73 |
+
|
74 |
+
def _annotate(
    document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
    pipeline: Pipeline,
    embedding_model: Optional[PreTrainedModel] = None,
    embedding_tokenizer: Optional[PreTrainedTokenizer] = None,
) -> None:
    """Run the prediction pipeline on ``document`` and, if possible, attach ADU embeddings.

    Parameters:
        document: The document to annotate. Its metadata receives the entry
            ``"embeddings"`` when an embedding model and tokenizer are given.
        pipeline: The prediction pipeline; it is called with the document
            (presumably annotating it in place; the return value is not used).
        embedding_model: Optional model used to embed the predicted labeled spans.
        embedding_tokenizer: Optional tokenizer belonging to ``embedding_model``.
    """
    # run the predictions first; the embeddings are computed for the predicted spans
    pipeline(document)

    if embedding_model is None or embedding_tokenizer is None:
        gr.Warning(
            "No embedding model provided. Skipping embedding extraction. You can load an embedding "
            "model in the 'Model Configuration' section."
        )
        return

    span2embedding = _embed_text_annotations(
        document=document,
        model=embedding_model,
        tokenizer=embedding_tokenizer,
        text_layer_name="labeled_spans",
    )
    # JSON keys must be strings, so the spans are replaced by their string ids
    id2embedding = {}
    for span, vector in span2embedding.items():
        id2embedding[labeled_span_to_id(span)] = vector.detach().tolist()
    document.metadata["embeddings"] = id2embedding
|
101 |
+
|
102 |
+
|
103 |
+
def create_and_annotate_document(
    text: str,
    doc_id: str,
    models: Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]],
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
    """Create a TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions from the provided
    text and annotate it.

    Parameters:
        text: The text to process.
        doc_id: The ID of the document.
        models: A tuple containing the prediction pipeline and the embedding model and tokenizer.

    Returns:
        The processed document.

    Raises:
        gr.Error: If creating or annotating the document fails.
    """

    try:
        new_document = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions(
            id=doc_id, text=text, metadata={}
        )
        # the model only considers text inside partitions, so cover the whole text with one
        whole_text_partition = LabeledSpan(start=0, end=len(text), label="text")
        new_document.labeled_partitions.append(whole_text_partition)
        _annotate(
            document=new_document,
            pipeline=models[0],
            embedding_model=models[1],
            embedding_tokenizer=models[2],
        )
    except Exception as e:
        raise gr.Error(f"Failed to process text: {e}")

    return new_document
|
137 |
+
|
138 |
+
|
139 |
+
def load_argumentation_model(model_name: str, revision: Optional[str] = None) -> Pipeline:
    """Load the argumentation mining pipeline via ``AutoPipeline.from_pretrained``.

    Parameters:
        model_name: Name or path of the pretrained pipeline.
        revision: Optional revision, passed to both the taskmodule and the model.

    Returns:
        The loaded pipeline.

    Raises:
        gr.Error: If loading fails; the original exception is attached as the cause.
    """
    try:
        model = AutoPipeline.from_pretrained(
            model_name,
            device=-1,
            num_workers=0,
            taskmodule_kwargs=dict(revision=revision),
            model_kwargs=dict(revision=revision),
        )
    except Exception as e:
        raise gr.Error(f"Failed to load argumentation model: {e}") from e
    # (the message previously ended with a stray, unbalanced ")")
    gr.Info(f"Loaded argumentation model: model_name={model_name}, revision={revision}")
    return model
|
152 |
+
|
153 |
+
|
154 |
+
def load_embedding_model(model_name: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the embedding model and its tokenizer with the transformers auto classes.

    Parameters:
        model_name: Name or path of the pretrained model.

    Returns:
        A tuple ``(embedding_model, embedding_tokenizer)``.

    Raises:
        gr.Error: If loading fails; the original exception is attached as the cause.
    """
    try:
        embedding_model = AutoModel.from_pretrained(model_name)
        embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception as e:
        raise gr.Error(f"Failed to load embedding model: {e}") from e
    # (the message previously ended with a stray, unbalanced ")")
    gr.Info(f"Loaded embedding model: model_name={model_name}")
    return embedding_model, embedding_tokenizer
|
162 |
+
|
163 |
+
|
164 |
+
def load_models(
    model_name: str, revision: Optional[str] = None, embedding_model_name: Optional[str] = None
) -> Tuple[Pipeline, Optional[PreTrainedModel], Optional[PreTrainedTokenizer]]:
    """Load the argumentation pipeline and, if a name is given, the embedding model.

    Parameters:
        model_name: Name or path of the argumentation pipeline.
        revision: Optional revision of the argumentation pipeline.
        embedding_model_name: Optional name of the embedding model. ``None`` or a
            blank string means that no embedding model is loaded.

    Returns:
        A tuple ``(pipeline, embedding_model, embedding_tokenizer)``; the last two
        entries are ``None`` when no embedding model was requested.
    """
    pipeline = load_argumentation_model(model_name, revision)

    # nothing more to load if the embedding model name is missing or blank
    if embedding_model_name is None or not embedding_model_name.strip():
        return pipeline, None, None

    model, tokenizer = load_embedding_model(embedding_model_name)
    return pipeline, model, tokenizer
|
rendering_utils.py
CHANGED
@@ -2,7 +2,8 @@ import json
|
|
2 |
from collections import defaultdict
|
3 |
from typing import Dict, List, Optional, Union
|
4 |
|
5 |
-
from
|
|
|
6 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
7 |
from rendering_utils_displacy import EntityRenderer
|
8 |
|
@@ -59,15 +60,6 @@ def render_displacy(
|
|
59 |
return html
|
60 |
|
61 |
|
62 |
-
def labeled_span_to_id(span: LabeledSpan) -> str:
|
63 |
-
return f"span-{span.start}-{span.end}-{span.label}"
|
64 |
-
|
65 |
-
|
66 |
-
def labeled_span_from_id(span_id: str) -> LabeledSpan:
|
67 |
-
parts = span_id.split("-")
|
68 |
-
return LabeledSpan(int(parts[1]), int(parts[2]), parts[3])
|
69 |
-
|
70 |
-
|
71 |
def inject_relation_data(
|
72 |
html: str,
|
73 |
sorted_entities,
|
|
|
2 |
from collections import defaultdict
|
3 |
from typing import Dict, List, Optional, Union
|
4 |
|
5 |
+
from annotation_utils import labeled_span_to_id
|
6 |
+
from pytorch_ie.annotations import BinaryRelation
|
7 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
8 |
from rendering_utils_displacy import EntityRenderer
|
9 |
|
|
|
60 |
return html
|
61 |
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
def inject_relation_data(
|
64 |
html: str,
|
65 |
sorted_entities,
|
vector_store.py
CHANGED
@@ -2,17 +2,32 @@ import abc
|
|
2 |
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
|
3 |
|
4 |
T = TypeVar("T", bound=Hashable)
|
|
|
5 |
|
6 |
|
7 |
-
class VectorStore(Generic[T], abc.ABC):
|
8 |
@abc.abstractmethod
|
9 |
-
def save(self, emb_id: T, embedding:
|
|
|
10 |
pass
|
11 |
|
12 |
@abc.abstractmethod
|
13 |
def retrieve_similar(
|
14 |
self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
|
15 |
) -> List[Tuple[T, float]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
pass
|
17 |
|
18 |
@abc.abstractmethod
|
@@ -28,7 +43,7 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
|
|
28 |
return sum(a * b for a, b in zip(a, b)) / (vector_norm(a) * vector_norm(b))
|
29 |
|
30 |
|
31 |
-
class SimpleVectorStore(VectorStore[T]):
|
32 |
def __init__(self):
|
33 |
self.vectors: dict[T, List[float]] = {}
|
34 |
self._cache = {}
|
|
|
2 |
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
|
3 |
|
4 |
T = TypeVar("T", bound=Hashable)
|
5 |
+
E = TypeVar("E")
|
6 |
|
7 |
|
8 |
+
class VectorStore(Generic[T, E], abc.ABC):
|
9 |
@abc.abstractmethod
|
10 |
+
def save(self, emb_id: T, embedding: E) -> None:
|
11 |
+
"""Save an embedding for a given ID."""
|
12 |
pass
|
13 |
|
14 |
@abc.abstractmethod
|
15 |
def retrieve_similar(
|
16 |
self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
|
17 |
) -> List[Tuple[T, float]]:
|
18 |
+
"""Retrieve IDs and the respective similarity scores with respect to the reference entry.
|
19 |
+
Note that this requires the reference entry to be present in the store.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
ref_id: The ID of the reference entry.
|
23 |
+
top_k: If provided, only the top-k most similar entries will be returned.
|
24 |
+
min_similarity: If provided, only entries with a similarity score greater or equal to
|
25 |
+
this value will be returned.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
A list of tuples consisting of the ID and the similarity score, sorted by similarity
|
29 |
+
score in descending order.
|
30 |
+
"""
|
31 |
pass
|
32 |
|
33 |
@abc.abstractmethod
|
|
|
43 |
return sum(a * b for a, b in zip(a, b)) / (vector_norm(a) * vector_norm(b))
|
44 |
|
45 |
|
46 |
+
class SimpleVectorStore(VectorStore[T, List[float]]):
|
47 |
def __init__(self):
|
48 |
self.vectors: dict[T, List[float]] = {}
|
49 |
self._cache = {}
|