ArneBinder
committed on
Commit
•
148e0d6
1
Parent(s):
86277c0
Upload 9 files
Browse files- app.py +20 -5
- document_store.py +219 -85
app.py
CHANGED
@@ -134,7 +134,7 @@ def upload_processed_documents(
|
|
134 |
file_name: str,
|
135 |
document_store: DocumentStore,
|
136 |
) -> pd.DataFrame:
|
137 |
-
document_store.
|
138 |
return document_store.overview()
|
139 |
|
140 |
|
@@ -175,7 +175,9 @@ def main():
|
|
175 |
}
|
176 |
|
177 |
with gr.Blocks() as demo:
|
178 |
-
document_store_state = gr.State(
|
|
|
|
|
179 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
180 |
models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
|
181 |
with gr.Row():
|
@@ -342,7 +344,10 @@ def main():
|
|
342 |
)
|
343 |
|
344 |
retrieve_relevant_adus_event_kwargs = dict(
|
345 |
-
fn=partial(
|
|
|
|
|
|
|
346 |
inputs=[
|
347 |
document_store_state,
|
348 |
selected_adu_id,
|
@@ -355,13 +360,23 @@ def main():
|
|
355 |
)
|
356 |
|
357 |
selected_adu_id.change(
|
358 |
-
fn=partial(
|
|
|
|
|
|
|
|
|
359 |
inputs=[document_state, selected_adu_id],
|
360 |
outputs=[selected_adu_text],
|
361 |
).success(**retrieve_relevant_adus_event_kwargs)
|
362 |
|
363 |
retrieve_similar_adus_btn.click(
|
364 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
inputs=[
|
366 |
document_store_state,
|
367 |
selected_adu_id,
|
|
|
134 |
file_name: str,
|
135 |
document_store: DocumentStore,
|
136 |
) -> pd.DataFrame:
|
137 |
+
document_store.add_documents_from_json(file_name)
|
138 |
return document_store.overview()
|
139 |
|
140 |
|
|
|
175 |
}
|
176 |
|
177 |
with gr.Blocks() as demo:
|
178 |
+
document_store_state = gr.State(
|
179 |
+
DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
|
180 |
+
)
|
181 |
# wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
|
182 |
models_state = gr.State((argumentation_model, embedding_model, embedding_tokenizer))
|
183 |
with gr.Row():
|
|
|
344 |
)
|
345 |
|
346 |
retrieve_relevant_adus_event_kwargs = dict(
|
347 |
+
fn=partial(
|
348 |
+
DocumentStore.get_related_annotations_from_other_documents_df,
|
349 |
+
columns=relevant_adus.headers,
|
350 |
+
),
|
351 |
inputs=[
|
352 |
document_store_state,
|
353 |
selected_adu_id,
|
|
|
360 |
)
|
361 |
|
362 |
selected_adu_id.change(
|
363 |
+
fn=partial(
|
364 |
+
get_annotation_from_document,
|
365 |
+
annotation_layer="labeled_spans",
|
366 |
+
use_predictions=True,
|
367 |
+
),
|
368 |
inputs=[document_state, selected_adu_id],
|
369 |
outputs=[selected_adu_text],
|
370 |
).success(**retrieve_relevant_adus_event_kwargs)
|
371 |
|
372 |
retrieve_similar_adus_btn.click(
|
373 |
+
fn=lambda document_store, ann_id, document, min_sim, k: document_store.get_similar_annotations_df(
|
374 |
+
ref_annotation_id=ann_id,
|
375 |
+
ref_document=document,
|
376 |
+
min_similarity=min_sim,
|
377 |
+
top_k=k,
|
378 |
+
annotation_layer="labeled_spans",
|
379 |
+
),
|
380 |
inputs=[
|
381 |
document_store_state,
|
382 |
selected_adu_id,
|
document_store.py
CHANGED
@@ -6,21 +6,45 @@ from typing import Dict, List, Optional, Tuple
|
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
from annotation_utils import labeled_span_to_id
|
9 |
-
from pytorch_ie
|
10 |
-
from pytorch_ie.documents import
|
|
|
|
|
|
|
11 |
from vector_store import SimpleVectorStore, VectorStore
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
|
16 |
def get_annotation_from_document(
|
17 |
-
document:
|
18 |
annotation_id: str,
|
19 |
annotation_layer: str,
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
annotation = id2annotation.get(annotation_id)
|
25 |
if annotation is None:
|
26 |
raise gr.Error(
|
@@ -30,53 +54,165 @@ def get_annotation_from_document(
|
|
30 |
return annotation
|
31 |
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
class DocumentStore:
|
|
|
|
|
|
|
|
|
34 |
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
def __init__(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# The annotated documents. As key, we use the document id. All documents keep the embeddings
|
39 |
-
# of the
|
40 |
-
self.documents: Dict[
|
41 |
-
|
42 |
-
] = {}
|
43 |
-
# The vector store to efficiently retrieve similar ADUs. Can be constructed from the
|
44 |
# documents.
|
45 |
self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
|
46 |
vector_store or SimpleVectorStore()
|
47 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
def get_annotation(
|
50 |
self,
|
51 |
doc_id: str,
|
52 |
annotation_id: str,
|
53 |
annotation_layer: str,
|
54 |
-
|
|
|
55 |
document = self.documents.get(doc_id)
|
56 |
if document is None:
|
57 |
raise gr.Error(
|
58 |
f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
|
59 |
)
|
60 |
-
return get_annotation_from_document(
|
|
|
|
|
61 |
|
62 |
-
def
|
63 |
self,
|
64 |
ref_annotation_id: str,
|
65 |
-
ref_document:
|
66 |
-
|
67 |
-
|
68 |
) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
similar_entries = self.vector_store.retrieve_similar(
|
70 |
ref_id=(ref_document.id, ref_annotation_id),
|
71 |
-
|
72 |
-
top_k=top_k,
|
73 |
)
|
74 |
|
75 |
similar_annotations = [
|
76 |
self.get_annotation(
|
77 |
doc_id=doc_id,
|
78 |
annotation_id=annotation_id,
|
79 |
-
annotation_layer=
|
|
|
80 |
)
|
81 |
for (doc_id, annotation_id), _ in similar_entries
|
82 |
]
|
@@ -89,20 +225,38 @@ class DocumentStore:
|
|
89 |
similar_entries, similar_annotations
|
90 |
)
|
91 |
],
|
92 |
-
columns=["doc_id", "
|
93 |
)
|
94 |
|
95 |
return df
|
96 |
|
97 |
-
def
|
98 |
self,
|
99 |
ref_annotation_id: str,
|
100 |
-
ref_document:
|
101 |
min_similarity: float,
|
102 |
top_k: int,
|
103 |
relation_types: List[str],
|
104 |
columns: List[str],
|
105 |
) -> pd.DataFrame:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
similar_entries = self.vector_store.retrieve_similar(
|
107 |
ref_id=(ref_document.id, ref_annotation_id),
|
108 |
min_similarity=min_similarity,
|
@@ -114,41 +268,29 @@ class DocumentStore:
|
|
114 |
if doc_id == ref_document.id:
|
115 |
continue
|
116 |
document = self.documents[doc_id]
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
result.append(
|
135 |
-
{
|
136 |
-
"doc_id": doc_id,
|
137 |
-
"reference_adu": str(annotation),
|
138 |
-
"sim_score": score,
|
139 |
-
"rel_score": rel.score,
|
140 |
-
"relation": rel.label,
|
141 |
-
"adu": str(rel.tail),
|
142 |
-
}
|
143 |
-
)
|
144 |
|
145 |
# define column order
|
146 |
df = pd.DataFrame(result, columns=columns)
|
147 |
return df
|
148 |
|
149 |
-
def add_document(
|
150 |
-
self, document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
151 |
-
) -> None:
|
152 |
try:
|
153 |
if document.id in self.documents:
|
154 |
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
|
@@ -157,61 +299,53 @@ class DocumentStore:
|
|
157 |
self.documents[document.id] = document
|
158 |
|
159 |
# save the embeddings to the vector store
|
160 |
-
for
|
161 |
-
self.vector_store.save((document.id,
|
162 |
|
163 |
except Exception as e:
|
164 |
raise gr.Error(f"Failed to add document {document.id} to index: {e}")
|
165 |
|
166 |
def add_document_from_dict(self, document_dict: dict) -> None:
|
167 |
-
document = self.
|
168 |
# metadata is not automatically deserialized, so we need to set it manually
|
169 |
document.metadata = document_dict["metadata"]
|
170 |
self.add_document(document)
|
171 |
|
172 |
-
def add_documents(
|
173 |
-
self, documents: List[TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions]
|
174 |
-
) -> None:
|
175 |
-
size_before = len(self.documents)
|
176 |
for document in documents:
|
177 |
self.add_document(document)
|
178 |
-
size_after = len(self.documents)
|
179 |
gr.Info(
|
180 |
-
f"Added {
|
181 |
)
|
182 |
|
183 |
-
def
|
184 |
-
size_before = len(self.documents)
|
185 |
with open(file_path, "r", encoding="utf-8") as f:
|
186 |
-
|
187 |
-
for _, document_json in
|
188 |
self.add_document_from_dict(document_dict=document_json)
|
189 |
-
size_after = len(self.documents)
|
190 |
gr.Info(
|
191 |
-
f"Added {
|
192 |
)
|
193 |
|
194 |
def save_to_json(self, file_path: str, **kwargs) -> None:
|
195 |
with open(file_path, "w", encoding="utf-8") as f:
|
196 |
json.dump(self.as_dict(), f, **kwargs)
|
197 |
|
198 |
-
def get_document(
|
199 |
-
self, doc_id: str
|
200 |
-
) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
|
201 |
return self.documents[doc_id]
|
202 |
|
203 |
def overview(self) -> pd.DataFrame:
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
for
|
212 |
-
|
213 |
-
|
214 |
-
)
|
215 |
return df
|
216 |
|
217 |
def as_dict(self) -> dict:
|
|
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
from annotation_utils import labeled_span_to_id
|
9 |
+
from pytorch_ie import Annotation
|
10 |
+
from pytorch_ie.documents import (
|
11 |
+
TextBasedDocument,
|
12 |
+
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
13 |
+
)
|
14 |
from vector_store import SimpleVectorStore, VectorStore
|
15 |
|
16 |
logger = logging.getLogger(__name__)
|
17 |
|
18 |
|
19 |
def get_annotation_from_document(
|
20 |
+
document: TextBasedDocument,
|
21 |
annotation_id: str,
|
22 |
annotation_layer: str,
|
23 |
+
use_predictions: bool,
|
24 |
+
) -> Annotation:
|
25 |
+
"""Get an annotation from a document by its id. Note that the annotation id is constructed from
|
26 |
+
the annotation itself, so it is unique within the document.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
document: The document to get the annotation from.
|
30 |
+
annotation_id: The id of the annotation.
|
31 |
+
annotation_layer: The name of the annotation layer.
|
32 |
+
use_predictions: Whether to use the predictions of the annotation layer.
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
The annotation with the given id.
|
36 |
+
"""
|
37 |
+
|
38 |
+
annotations = document[annotation_layer]
|
39 |
+
if use_predictions:
|
40 |
+
annotations = annotations.predictions
|
41 |
+
|
42 |
+
if annotation_layer == "labeled_spans":
|
43 |
+
annotation_to_id_func = labeled_span_to_id
|
44 |
+
else:
|
45 |
+
raise gr.Error(f"Unknown annotation layer '{annotation_layer}'.")
|
46 |
+
|
47 |
+
id2annotation = {annotation_to_id_func(annotation): annotation for annotation in annotations}
|
48 |
annotation = id2annotation.get(annotation_id)
|
49 |
if annotation is None:
|
50 |
raise gr.Error(
|
|
|
54 |
return annotation
|
55 |
|
56 |
|
57 |
+
def get_related_annotation_records_from_document(
    document: TextBasedDocument,
    reference_annotation: Annotation,
    relation_layer_name: str,
    use_predictions: bool,
    annotation_caption: str,
    relation_types: Optional[List[str]] = None,
    additional_static_columns: Optional[Dict[str, str]] = None,
) -> List[Dict[str, str]]:
    """Get related annotations from a document for a given reference annotation. The related
    annotations are all annotations that are targets (tails) of relations with the reference
    annotation as source (head).

    Args:
        document: The document to get the related annotations from.
        reference_annotation: The reference annotation. Should be an annotation from the document.
        relation_layer_name: The name of the relation layer.
        use_predictions: Whether to use the predictions of the relation layer.
        annotation_caption: The caption for the related annotations in the result.
        relation_types: The types of relations to consider. If None, all relation types are considered.
        additional_static_columns: Additional static columns to add to the result.

    Returns:
        A list of dictionaries with the related annotations and additional columns.
    """

    # get the relation layer (optionally its predictions)
    relation_layer = document[relation_layer_name]
    if use_predictions:
        relation_layer = relation_layer.predictions

    # Collect all relations that have the reference annotation as source (head).
    # A single filtered pass is sufficient; the previous head2rels/tail2rels helper
    # dictionaries were only ever queried for this one head (tail2rels was unused).
    result = []
    for rel in relation_layer:
        # skip relations of types we are not interested in (e.g. non-argumentative ones)
        if relation_types is not None and rel.label not in relation_types:
            continue
        if rel.head != reference_annotation:
            continue
        result.append(
            {
                "doc_id": document.id,
                f"reference_{annotation_caption}": str(reference_annotation),
                "rel_score": rel.score,
                "relation": rel.label,
                annotation_caption: str(rel.tail),
                **(additional_static_columns or {}),
            }
        )
    return result
|
114 |
+
|
115 |
+
|
116 |
class DocumentStore:
|
117 |
+
"""A document store that allows to add, retrieve, and search for documents and annotations.
|
118 |
+
|
119 |
+
The store keeps the documents in memory and stores the embeddings of the labeled spans in a vector
|
120 |
+
store to efficiently retrieve similar or related spans.
|
121 |
|
122 |
+
Args:
|
123 |
+
vector_store: The vector store to use. If None, a new SimpleVectorStore is created.
|
124 |
+
document_type: The type of the documents to store. Should be a subclass of TextBasedDocument with
|
125 |
+
a span and a relation layer (see below).
|
126 |
+
span_layer_name: The name of the span annotation layer. This should be a valid annotation layer
|
127 |
+
of type LabelSpan in the document type.
|
128 |
+
relation_layer_name: The name of the argumentative relation annotation layer. This should be a
|
129 |
+
valid annotation layer of type BinaryRelation in the document type.
|
130 |
+
span_annotation_caption: The caption for the span annotations (e.g. in the statistical overview)
|
131 |
+
relation_annotation_caption: The caption for the relation annotations (e.g. in the statistical
|
132 |
+
overview)
|
133 |
+
use_predictions: Whether to use the predictions of the annotation layers. If True, the predictions
|
134 |
+
are used, otherwise the gold annotations are used.
|
135 |
+
"""
|
136 |
|
137 |
+
def __init__(
|
138 |
+
self,
|
139 |
+
vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None,
|
140 |
+
document_type: type[
|
141 |
+
TextBasedDocument
|
142 |
+
] = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
143 |
+
span_layer_name: str = "labeled_spans",
|
144 |
+
relation_layer_name: str = "binary_relations",
|
145 |
+
span_annotation_caption: str = "span",
|
146 |
+
relation_annotation_caption: str = "relation",
|
147 |
+
use_predictions: bool = True,
|
148 |
+
):
|
149 |
# The annotated documents. As key, we use the document id. All documents keep the embeddings
|
150 |
+
# of the spans in the metadata.
|
151 |
+
self.documents: Dict[str, TextBasedDocument] = {}
|
152 |
+
# The vector store to efficiently retrieve similar spans. Can be constructed from the
|
|
|
|
|
153 |
# documents.
|
154 |
self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
|
155 |
vector_store or SimpleVectorStore()
|
156 |
)
|
157 |
+
# the document type (to create new documents from dicts)
|
158 |
+
self.document_type = document_type
|
159 |
+
self.span_layer_name = span_layer_name
|
160 |
+
self.relation_layer_name = relation_layer_name
|
161 |
+
self.use_predictions = use_predictions
|
162 |
+
self.layer_captions = {
|
163 |
+
self.span_layer_name: span_annotation_caption,
|
164 |
+
self.relation_layer_name: relation_annotation_caption,
|
165 |
+
}
|
166 |
|
167 |
def get_annotation(
|
168 |
self,
|
169 |
doc_id: str,
|
170 |
annotation_id: str,
|
171 |
annotation_layer: str,
|
172 |
+
use_predictions: bool,
|
173 |
+
) -> Annotation:
|
174 |
document = self.documents.get(doc_id)
|
175 |
if document is None:
|
176 |
raise gr.Error(
|
177 |
f"Document '{doc_id}' not found in index. Available documents: {list(self.documents)}"
|
178 |
)
|
179 |
+
return get_annotation_from_document(
|
180 |
+
document, annotation_id, annotation_layer, use_predictions=use_predictions
|
181 |
+
)
|
182 |
|
183 |
+
def get_similar_annotations_df(
|
184 |
self,
|
185 |
ref_annotation_id: str,
|
186 |
+
ref_document: TextBasedDocument,
|
187 |
+
annotation_layer: str,
|
188 |
+
**similarity_kwargs,
|
189 |
) -> pd.DataFrame:
|
190 |
+
"""Get similar annotations from documents in the store sorted by similarity. Usually, the
|
191 |
+
reference annotation is returned as the most similar annotation.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
ref_annotation_id: The id of the reference annotation.
|
195 |
+
ref_document: The document of the reference annotation.
|
196 |
+
annotation_layer: The name of the annotation layer to consider.
|
197 |
+
**similarity_kwargs: Additional keyword arguments that will be passed to the vector
|
198 |
+
store to retrieve similar entries (see VectorStore.retrieve_similar()).
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
A DataFrame with the similar annotations with columns: doc_id, annotation_id, sim_score,
|
202 |
+
and text.
|
203 |
+
"""
|
204 |
+
|
205 |
similar_entries = self.vector_store.retrieve_similar(
|
206 |
ref_id=(ref_document.id, ref_annotation_id),
|
207 |
+
**similarity_kwargs,
|
|
|
208 |
)
|
209 |
|
210 |
similar_annotations = [
|
211 |
self.get_annotation(
|
212 |
doc_id=doc_id,
|
213 |
annotation_id=annotation_id,
|
214 |
+
annotation_layer=annotation_layer,
|
215 |
+
use_predictions=self.use_predictions,
|
216 |
)
|
217 |
for (doc_id, annotation_id), _ in similar_entries
|
218 |
]
|
|
|
225 |
similar_entries, similar_annotations
|
226 |
)
|
227 |
],
|
228 |
+
columns=["doc_id", "annotation_id", "sim_score", "text"],
|
229 |
)
|
230 |
|
231 |
return df
|
232 |
|
233 |
+
def get_related_annotations_from_other_documents_df(
|
234 |
self,
|
235 |
ref_annotation_id: str,
|
236 |
+
ref_document: TextBasedDocument,
|
237 |
min_similarity: float,
|
238 |
top_k: int,
|
239 |
relation_types: List[str],
|
240 |
columns: List[str],
|
241 |
) -> pd.DataFrame:
|
242 |
+
"""Get related annotations from documents in the store for a given reference annotation.
|
243 |
+
First, similar annotations are retrieved from the vector store. Then, annotations that are
|
244 |
+
linked to them via relations are returned. Only annotations from other documents are
|
245 |
+
considered.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
ref_annotation_id: The id of the reference annotation.
|
249 |
+
ref_document: The document of the reference annotation.
|
250 |
+
min_similarity: The minimum similarity score to consider.
|
251 |
+
top_k: The number of related annotations to return.
|
252 |
+
relation_types: The types of relations to consider.
|
253 |
+
columns: The columns to include in the result DataFrame.
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
A DataFrame with the columns that contain: the related annotation, the relation type,
|
257 |
+
the similar annotation, the similarity score, and the relation score.
|
258 |
+
"""
|
259 |
+
|
260 |
similar_entries = self.vector_store.retrieve_similar(
|
261 |
ref_id=(ref_document.id, ref_annotation_id),
|
262 |
min_similarity=min_similarity,
|
|
|
268 |
if doc_id == ref_document.id:
|
269 |
continue
|
270 |
document = self.documents[doc_id]
|
271 |
+
reference_annotation = get_annotation_from_document(
|
272 |
+
document=document,
|
273 |
+
annotation_id=annotation_id,
|
274 |
+
annotation_layer=self.span_layer_name,
|
275 |
+
use_predictions=self.use_predictions,
|
276 |
+
)
|
277 |
+
|
278 |
+
new_entries = get_related_annotation_records_from_document(
|
279 |
+
document=document,
|
280 |
+
reference_annotation=reference_annotation,
|
281 |
+
relation_types=relation_types,
|
282 |
+
relation_layer_name=self.relation_layer_name,
|
283 |
+
use_predictions=self.use_predictions,
|
284 |
+
annotation_caption=self.layer_captions[self.span_layer_name],
|
285 |
+
additional_static_columns={"sim_score": str(score)},
|
286 |
+
)
|
287 |
+
result.extend(new_entries)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
|
289 |
# define column order
|
290 |
df = pd.DataFrame(result, columns=columns)
|
291 |
return df
|
292 |
|
293 |
+
def add_document(self, document: TextBasedDocument) -> None:
|
|
|
|
|
294 |
try:
|
295 |
if document.id in self.documents:
|
296 |
gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
|
|
|
299 |
self.documents[document.id] = document
|
300 |
|
301 |
# save the embeddings to the vector store
|
302 |
+
for annotation_id, embedding in document.metadata["embeddings"].items():
|
303 |
+
self.vector_store.save((document.id, annotation_id), embedding)
|
304 |
|
305 |
except Exception as e:
|
306 |
raise gr.Error(f"Failed to add document {document.id} to index: {e}")
|
307 |
|
308 |
def add_document_from_dict(self, document_dict: dict) -> None:
|
309 |
+
document = self.document_type.fromdict(document_dict)
|
310 |
# metadata is not automatically deserialized, so we need to set it manually
|
311 |
document.metadata = document_dict["metadata"]
|
312 |
self.add_document(document)
|
313 |
|
314 |
+
def add_documents(self, documents: List[TextBasedDocument]) -> None:
|
|
|
|
|
|
|
315 |
for document in documents:
|
316 |
self.add_document(document)
|
|
|
317 |
gr.Info(
|
318 |
+
f"Added {len(documents)} documents to the index ({len(self.documents)} documents in total)."
|
319 |
)
|
320 |
|
321 |
+
def add_documents_from_json(self, file_path: str) -> None:
|
|
|
322 |
with open(file_path, "r", encoding="utf-8") as f:
|
323 |
+
documents_json = json.load(f)
|
324 |
+
for _, document_json in documents_json.items():
|
325 |
self.add_document_from_dict(document_dict=document_json)
|
|
|
326 |
gr.Info(
|
327 |
+
f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
|
328 |
)
|
329 |
|
330 |
def save_to_json(self, file_path: str, **kwargs) -> None:
|
331 |
with open(file_path, "w", encoding="utf-8") as f:
|
332 |
json.dump(self.as_dict(), f, **kwargs)
|
333 |
|
334 |
+
def get_document(self, doc_id: str) -> TextBasedDocument:
|
|
|
|
|
335 |
return self.documents[doc_id]
|
336 |
|
337 |
def overview(self) -> pd.DataFrame:
|
338 |
+
rows = []
|
339 |
+
for doc_id, document in self.documents.items():
|
340 |
+
layers = {
|
341 |
+
caption: document[layer_name]
|
342 |
+
for layer_name, caption in self.layer_captions.items()
|
343 |
+
}
|
344 |
+
if self.use_predictions:
|
345 |
+
layers = {caption: layer.predictions for caption, layer in layers.items()}
|
346 |
+
layer_sizes = {f"num_{caption}s": len(layer) for caption, layer in layers.items()}
|
347 |
+
rows.append({"doc_id": doc_id, **layer_sizes})
|
348 |
+
df = pd.DataFrame(rows)
|
349 |
return df
|
350 |
|
351 |
def as_dict(self) -> dict:
|