ArneBinder
commited on
Commit
•
a8df5fb
1
Parent(s):
b0174fa
from https://github.com/ArneBinder/pie-document-level/pull/238
Browse files- app.py +4 -4
- document_store.py +108 -40
- embedding.py +4 -4
- rendering_utils.py +32 -12
- vector_store.py +9 -1
app.py
CHANGED
@@ -354,7 +354,7 @@ def main():
|
|
354 |
minimum=0.0,
|
355 |
maximum=1.0,
|
356 |
step=0.01,
|
357 |
-
value=0.
|
358 |
)
|
359 |
top_k = gr.Slider(
|
360 |
label="Top K",
|
@@ -398,10 +398,10 @@ def main():
|
|
398 |
)
|
399 |
|
400 |
show_overview_kwargs = dict(
|
401 |
-
fn=lambda document_store, show_max_sims: document_store.overview(
|
402 |
with_max_cross_doc_sims=show_max_sims
|
403 |
),
|
404 |
-
inputs=[document_store_state, show_max_cross_docu_sims],
|
405 |
outputs=[processed_documents_df],
|
406 |
)
|
407 |
predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
|
@@ -505,7 +505,7 @@ def main():
|
|
505 |
DocumentStore.get_all2all_adu_similarities,
|
506 |
columns=all2all_adu_similarities.headers,
|
507 |
),
|
508 |
-
inputs=[document_store_state],
|
509 |
outputs=[all2all_adu_similarities],
|
510 |
)
|
511 |
|
|
|
354 |
minimum=0.0,
|
355 |
maximum=1.0,
|
356 |
step=0.01,
|
357 |
+
value=0.95,
|
358 |
)
|
359 |
top_k = gr.Slider(
|
360 |
label="Top K",
|
|
|
398 |
)
|
399 |
|
400 |
show_overview_kwargs = dict(
|
401 |
+
fn=lambda document_store, show_max_sims, min_sim: document_store.overview(
|
402 |
with_max_cross_doc_sims=show_max_sims
|
403 |
),
|
404 |
+
inputs=[document_store_state, show_max_cross_docu_sims, min_similarity],
|
405 |
outputs=[processed_documents_df],
|
406 |
)
|
407 |
predict_btn.click(fn=open_accordion, inputs=[], outputs=[output_accordion]).then(
|
|
|
505 |
DocumentStore.get_all2all_adu_similarities,
|
506 |
columns=all2all_adu_similarities.headers,
|
507 |
),
|
508 |
+
inputs=[document_store_state, min_similarity],
|
509 |
outputs=[all2all_adu_similarities],
|
510 |
)
|
511 |
|
document_store.py
CHANGED
@@ -16,6 +16,7 @@ from pytorch_ie.documents import (
|
|
16 |
TextBasedDocument,
|
17 |
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
18 |
)
|
|
|
19 |
from vector_store import VectorStore
|
20 |
|
21 |
logger = logging.getLogger(__name__)
|
@@ -342,6 +343,38 @@ class DocumentStore:
|
|
342 |
f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
|
343 |
)
|
344 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
345 |
def add_documents_from_zip(self, file_path: str) -> None:
|
346 |
temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
|
347 |
# remove the temporary directory if it already exists
|
@@ -418,7 +451,9 @@ class DocumentStore:
|
|
418 |
|
419 |
return document
|
420 |
|
421 |
-
def overview(
|
|
|
|
|
422 |
rows = []
|
423 |
for doc_id, document in self.documents.items():
|
424 |
layers = {
|
@@ -433,13 +468,8 @@ class DocumentStore:
|
|
433 |
|
434 |
# add highest cross-document similarity score for each document
|
435 |
if with_max_cross_doc_sims and len(self.documents) > 1:
|
436 |
-
# Setting min_similarity to None and top_k to None to get all similarities. Otherwise,
|
437 |
-
# it may happen that this occludes max cross-doc sim for some documents in the
|
438 |
-
# case that there are more than top_k ADUs in the reference document that have a higher
|
439 |
-
# similarity with each other than the highest similarity to any ADU in another document
|
440 |
-
# or if the cross-doc similarity is below the min_similarity threshold.
|
441 |
all2all_adu_similarities = self.get_all2all_adu_similarities(
|
442 |
-
min_similarity=
|
443 |
)
|
444 |
max_doc2doc_similarities = all2all_adu_similarities.pivot_table(
|
445 |
values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
|
@@ -478,50 +508,88 @@ class DocumentStore:
|
|
478 |
def get_all2all_adu_similarities(
|
479 |
self,
|
480 |
min_similarity: Optional[float] = 0.5,
|
481 |
-
top_k: Optional[int] = 100,
|
482 |
columns: Optional[List[str]] = None,
|
483 |
) -> pd.DataFrame:
|
484 |
"""Get the similarities between all ADUs in the store.
|
485 |
|
486 |
Args:
|
487 |
-
min_similarity: The minimum similarity score to consider.
|
488 |
-
top_k: The number of similar ADUs to return.
|
489 |
columns: The columns to include in the result DataFrame. If None, all columns are included.
|
490 |
|
491 |
Returns:
|
492 |
A DataFrame with the columns: doc_id, text, other_doc_id, other_text, sim_score.
|
493 |
"""
|
494 |
-
|
495 |
-
|
496 |
-
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
503 |
)
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
other_adu = get_annotation_from_document(
|
508 |
-
other_document,
|
509 |
-
payload["annotation_id"],
|
510 |
-
self.span_layer_name,
|
511 |
-
use_predictions=self.use_predictions,
|
512 |
-
)
|
513 |
-
result.append(
|
514 |
-
{
|
515 |
-
"sim_score": score,
|
516 |
-
"doc_id": doc_id,
|
517 |
-
"other_doc_id": other_doc_id,
|
518 |
-
"adu_id": adu_id,
|
519 |
-
"other_adu_id": payload["annotation_id"],
|
520 |
-
"text": str(adu),
|
521 |
-
"other_text": str(other_adu),
|
522 |
-
}
|
523 |
-
)
|
524 |
-
result_df = pd.DataFrame(result)
|
525 |
if columns is not None:
|
526 |
result_df = result_df[columns]
|
527 |
return result_df
|
|
|
16 |
TextBasedDocument,
|
17 |
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
|
18 |
)
|
19 |
+
from scipy.sparse import csr_matrix
|
20 |
from vector_store import VectorStore
|
21 |
|
22 |
logger = logging.getLogger(__name__)
|
|
|
343 |
f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
|
344 |
)
|
345 |
|
346 |
+
def get_payloads_for_missing_and_unexpected_embeddings(self) -> dict[str, dict[str, Any]]:
|
347 |
+
"""Get the payloads for missing and unexpected embeddings in the vector store. An embedding
|
348 |
+
is missing if its annotation is in the documents but the embedding is not in the vector
|
349 |
+
store. An embedding is unexpected if it is in the vector store but the annotation is not in
|
350 |
+
the documents.
|
351 |
+
|
352 |
+
Returns:
|
353 |
+
A dictionary with the missing and unexpected payloads.
|
354 |
+
"""
|
355 |
+
expected_payloads = []
|
356 |
+
for document in self.documents.values():
|
357 |
+
for annotation in document[self.span_layer_name].predictions:
|
358 |
+
annotation_id = labeled_span_to_id(annotation)
|
359 |
+
payload = self.construct_embedding_payload(document, annotation_id)
|
360 |
+
expected_payloads.append(payload)
|
361 |
+
vector_sore_payloads = self.vector_store.as_indices_vectors_payloads()[2]
|
362 |
+
# construct mappings from ids to payloads to compare the expected and actual payloads
|
363 |
+
expected_mapping = {
|
364 |
+
json.dumps(payload, sort_keys=True): payload for payload in expected_payloads
|
365 |
+
}
|
366 |
+
vector_store_mapping = {
|
367 |
+
json.dumps(payload, sort_keys=True): payload for payload in vector_sore_payloads
|
368 |
+
}
|
369 |
+
missing = set(expected_mapping) - set(vector_store_mapping)
|
370 |
+
unexpected = set(vector_store_mapping) - set(expected_mapping)
|
371 |
+
|
372 |
+
# return the missing and unexpected payloads
|
373 |
+
return {
|
374 |
+
"missing": {payload: expected_mapping[payload] for payload in missing},
|
375 |
+
"unexpected": {payload: vector_store_mapping[payload] for payload in unexpected},
|
376 |
+
}
|
377 |
+
|
378 |
def add_documents_from_zip(self, file_path: str) -> None:
|
379 |
temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
|
380 |
# remove the temporary directory if it already exists
|
|
|
451 |
|
452 |
return document
|
453 |
|
454 |
+
def overview(
|
455 |
+
self, with_max_cross_doc_sims: bool = False, min_similarity: float = 0.9
|
456 |
+
) -> pd.DataFrame:
|
457 |
rows = []
|
458 |
for doc_id, document in self.documents.items():
|
459 |
layers = {
|
|
|
468 |
|
469 |
# add highest cross-document similarity score for each document
|
470 |
if with_max_cross_doc_sims and len(self.documents) > 1:
|
|
|
|
|
|
|
|
|
|
|
471 |
all2all_adu_similarities = self.get_all2all_adu_similarities(
|
472 |
+
min_similarity=min_similarity, columns=["doc_id", "other_doc_id", "sim_score"]
|
473 |
)
|
474 |
max_doc2doc_similarities = all2all_adu_similarities.pivot_table(
|
475 |
values="sim_score", index="doc_id", columns="other_doc_id", aggfunc="max"
|
|
|
508 |
def get_all2all_adu_similarities(
|
509 |
self,
|
510 |
min_similarity: Optional[float] = 0.5,
|
|
|
511 |
columns: Optional[List[str]] = None,
|
512 |
) -> pd.DataFrame:
|
513 |
"""Get the similarities between all ADUs in the store.
|
514 |
|
515 |
Args:
|
516 |
+
min_similarity: The minimum similarity score to consider. If None, all similarities are included.
|
|
|
517 |
columns: The columns to include in the result DataFrame. If None, all columns are included.
|
518 |
|
519 |
Returns:
|
520 |
A DataFrame with the columns: doc_id, text, other_doc_id, other_text, sim_score.
|
521 |
"""
|
522 |
+
|
523 |
+
# shape of all_embeddings: (num_embeddings, embedding_dim)
|
524 |
+
(
|
525 |
+
all_embed_ids,
|
526 |
+
all_embeddings,
|
527 |
+
all_payloads,
|
528 |
+
) = self.vector_store.as_indices_vectors_payloads()
|
529 |
+
|
530 |
+
doc_id_and_annotation_id2annotation_text = {}
|
531 |
+
for doc in self.documents.values():
|
532 |
+
for annotation in doc[self.span_layer_name]:
|
533 |
+
doc_id_and_annotation_id2annotation_text[
|
534 |
+
(doc.id, labeled_span_to_id(annotation))
|
535 |
+
] = str(annotation)
|
536 |
+
for annotation in doc[self.span_layer_name].predictions:
|
537 |
+
doc_id_and_annotation_id2annotation_text[
|
538 |
+
(doc.id, labeled_span_to_id(annotation))
|
539 |
+
] = str(annotation)
|
540 |
+
|
541 |
+
# calculate cosine similarities between all embeddings
|
542 |
+
dot_prod = np.dot(all_embeddings, all_embeddings.T)
|
543 |
+
norm = np.linalg.norm(all_embeddings, axis=1)
|
544 |
+
norm_prod = np.outer(norm, norm)
|
545 |
+
similarities = dot_prod / norm_prod
|
546 |
+
|
547 |
+
gr.Info(f"Similarities shape: {similarities.shape}")
|
548 |
+
|
549 |
+
if min_similarity is not None:
|
550 |
+
gr.Info(f"Filtering similarities below {min_similarity}.")
|
551 |
+
# set similarities below min_similarity to 0
|
552 |
+
similarities[similarities < min_similarity] = 0.0
|
553 |
+
|
554 |
+
# set triangular part to 0
|
555 |
+
similarities = np.triu(similarities, k=1)
|
556 |
+
# create a sparse matrix
|
557 |
+
sparse_matrix = csr_matrix(similarities)
|
558 |
+
sparse_matrix.eliminate_zeros()
|
559 |
+
# Get the non-zero values and their indices
|
560 |
+
non_zero_idx = sparse_matrix.nonzero()
|
561 |
+
scores = sparse_matrix.data
|
562 |
+
|
563 |
+
gr.Info(f"Number of similarities above {min_similarity}: {len(scores)}")
|
564 |
+
|
565 |
+
# construct the DataFrame
|
566 |
+
records = []
|
567 |
+
for idx1, idx2 in zip(non_zero_idx[0], non_zero_idx[1]):
|
568 |
+
if idx1 < idx2:
|
569 |
+
doc_id1 = all_payloads[idx1]["doc_id"]
|
570 |
+
doc_id2 = all_payloads[idx2]["doc_id"]
|
571 |
+
annotation_id1 = all_payloads[idx1]["annotation_id"]
|
572 |
+
annotation_id2 = all_payloads[idx2]["annotation_id"]
|
573 |
+
annotation_text1 = doc_id_and_annotation_id2annotation_text[
|
574 |
+
(doc_id1, annotation_id1)
|
575 |
+
]
|
576 |
+
annotation_text2 = doc_id_and_annotation_id2annotation_text[
|
577 |
+
(doc_id2, annotation_id2)
|
578 |
+
]
|
579 |
+
records.append(
|
580 |
+
{
|
581 |
+
"sim_score": scores[idx1],
|
582 |
+
"doc_id": doc_id1,
|
583 |
+
"other_doc_id": doc_id2,
|
584 |
+
"adu_id": annotation_id1,
|
585 |
+
"other_adu_id": annotation_id2,
|
586 |
+
"text": annotation_text1,
|
587 |
+
"other_text": annotation_text2,
|
588 |
+
}
|
589 |
)
|
590 |
+
result_df = pd.DataFrame(records)
|
591 |
+
gr.Info(f"DataFrame shape: {result_df.shape}")
|
592 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
593 |
if columns is not None:
|
594 |
result_df = result_df[columns]
|
595 |
return result_df
|
embedding.py
CHANGED
@@ -114,10 +114,10 @@ class HuggingfaceEmbeddingModel(EmbeddingModel):
|
|
114 |
)
|
115 |
text_ann = tok2text_ann[tok_ann]
|
116 |
|
117 |
-
if text_ann in embeddings:
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
embeddings[text_ann] = embedding
|
122 |
example_idx += 1
|
123 |
|
|
|
114 |
)
|
115 |
text_ann = tok2text_ann[tok_ann]
|
116 |
|
117 |
+
# if text_ann in embeddings:
|
118 |
+
# logger.warning(
|
119 |
+
# f"Overwriting embedding for annotation '{text_ann}' (do you use striding?)"
|
120 |
+
# )
|
121 |
embeddings[text_ann] = embedding
|
122 |
example_idx += 1
|
123 |
|
rendering_utils.py
CHANGED
@@ -4,12 +4,20 @@ from collections import defaultdict
|
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
-
from pytorch_ie.annotations import BinaryRelation
|
8 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
9 |
from rendering_utils_displacy import EntityRenderer
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
def render_pretty_table(
|
15 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
@@ -36,15 +44,27 @@ def render_displacy(
|
|
36 |
**render_kwargs,
|
37 |
):
|
38 |
|
39 |
-
|
40 |
spacy_doc = {
|
41 |
"text": document.text,
|
42 |
"ents": [
|
43 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
],
|
45 |
"title": None,
|
46 |
}
|
47 |
|
|
|
|
|
|
|
|
|
48 |
renderer = EntityRenderer(options=entity_options)
|
49 |
html = renderer.render([spacy_doc], page=True, minify=True).strip()
|
50 |
|
@@ -53,10 +73,9 @@ def render_displacy(
|
|
53 |
binary_relations = list(document.binary_relations) + list(
|
54 |
document.binary_relations.predictions
|
55 |
)
|
56 |
-
sorted_entities = sorted(spans, key=lambda x: (x.start, x.end))
|
57 |
html = inject_relation_data(
|
58 |
html,
|
59 |
-
|
60 |
binary_relations=binary_relations,
|
61 |
additional_colors=colors_hover,
|
62 |
)
|
@@ -65,7 +84,7 @@ def render_displacy(
|
|
65 |
|
66 |
def inject_relation_data(
|
67 |
html: str,
|
68 |
-
|
69 |
binary_relations: List[BinaryRelation],
|
70 |
additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
|
71 |
) -> str:
|
@@ -80,11 +99,10 @@ def inject_relation_data(
|
|
80 |
entity2heads[relation.tail].append((relation.head, relation.label))
|
81 |
entity2tails[relation.head].append((relation.tail, relation.label))
|
82 |
|
|
|
83 |
# Add unique IDs to each entity
|
84 |
entities = soup.find_all(class_="entity")
|
85 |
-
for
|
86 |
-
annotation = sorted_entities[idx]
|
87 |
-
entity["id"] = labeled_span_to_id(annotation)
|
88 |
original_color = entity["style"].split("background:")[1].split(";")[0].strip()
|
89 |
entity["data-color-original"] = original_color
|
90 |
if additional_colors is not None:
|
@@ -92,9 +110,11 @@ def inject_relation_data(
|
|
92 |
entity[f"data-color-{key}"] = (
|
93 |
json.dumps(color) if isinstance(color, dict) else color
|
94 |
)
|
95 |
-
entity_annotation =
|
96 |
-
# sanity check
|
97 |
-
|
|
|
|
|
98 |
logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
99 |
entity["data-label"] = entity_annotation.label
|
100 |
entity["data-relation-tails"] = json.dumps(
|
|
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
6 |
from annotation_utils import labeled_span_to_id
|
7 |
+
from pytorch_ie.annotations import BinaryRelation, LabeledSpan
|
8 |
from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
|
9 |
from rendering_utils_displacy import EntityRenderer
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
13 |
+
# adjusted from rendering_utils_displacy.TPL_ENT
|
14 |
+
TPL_ENT_WITH_ID = """
|
15 |
+
<mark class="entity" id="{id}" style="background: {bg}; padding: 0.45em 0.6em; margin: 0 0.25em; line-height: 1; border-radius: 0.35em;">
|
16 |
+
{text}
|
17 |
+
<span style="font-size: 0.8em; font-weight: bold; line-height: 1; border-radius: 0.35em; text-transform: uppercase; vertical-align: middle; margin-left: 0.5rem">{label}</span>
|
18 |
+
</mark>
|
19 |
+
"""
|
20 |
+
|
21 |
|
22 |
def render_pretty_table(
|
23 |
document: TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, **render_kwargs
|
|
|
44 |
**render_kwargs,
|
45 |
):
|
46 |
|
47 |
+
labeled_spans = list(document.labeled_spans) + list(document.labeled_spans.predictions)
|
48 |
spacy_doc = {
|
49 |
"text": document.text,
|
50 |
"ents": [
|
51 |
+
{
|
52 |
+
"start": labeled_span.start,
|
53 |
+
"end": labeled_span.end,
|
54 |
+
"label": labeled_span.label,
|
55 |
+
# pass the ID as a parameter to the entity. The id is required to fetch the entity annotations
|
56 |
+
# on hover and to inject the relation data.
|
57 |
+
"params": {"id": labeled_span_to_id(labeled_span)},
|
58 |
+
}
|
59 |
+
for labeled_span in labeled_spans
|
60 |
],
|
61 |
"title": None,
|
62 |
}
|
63 |
|
64 |
+
# copy to avoid modifying the original options
|
65 |
+
entity_options = entity_options.copy()
|
66 |
+
# use the custom template with the entity ID
|
67 |
+
entity_options["template"] = TPL_ENT_WITH_ID
|
68 |
renderer = EntityRenderer(options=entity_options)
|
69 |
html = renderer.render([spacy_doc], page=True, minify=True).strip()
|
70 |
|
|
|
73 |
binary_relations = list(document.binary_relations) + list(
|
74 |
document.binary_relations.predictions
|
75 |
)
|
|
|
76 |
html = inject_relation_data(
|
77 |
html,
|
78 |
+
labeled_spans=labeled_spans,
|
79 |
binary_relations=binary_relations,
|
80 |
additional_colors=colors_hover,
|
81 |
)
|
|
|
84 |
|
85 |
def inject_relation_data(
|
86 |
html: str,
|
87 |
+
labeled_spans: List[LabeledSpan],
|
88 |
binary_relations: List[BinaryRelation],
|
89 |
additional_colors: Optional[Dict[str, Union[str, dict]]] = None,
|
90 |
) -> str:
|
|
|
99 |
entity2heads[relation.tail].append((relation.head, relation.label))
|
100 |
entity2tails[relation.head].append((relation.tail, relation.label))
|
101 |
|
102 |
+
ann_id2annotation = {labeled_span_to_id(entity): entity for entity in labeled_spans}
|
103 |
# Add unique IDs to each entity
|
104 |
entities = soup.find_all(class_="entity")
|
105 |
+
for entity in entities:
|
|
|
|
|
106 |
original_color = entity["style"].split("background:")[1].split(";")[0].strip()
|
107 |
entity["data-color-original"] = original_color
|
108 |
if additional_colors is not None:
|
|
|
110 |
entity[f"data-color-{key}"] = (
|
111 |
json.dumps(color) if isinstance(color, dict) else color
|
112 |
)
|
113 |
+
entity_annotation = ann_id2annotation[entity["id"]]
|
114 |
+
# sanity check.
|
115 |
+
annotation_text_without_newline = str(entity_annotation).replace("\n", "")
|
116 |
+
# Just check the start, because the text has the label attached to the end
|
117 |
+
if not entity.text.startswith(annotation_text_without_newline):
|
118 |
logger.warning(f"Entity text mismatch: {entity_annotation} != {entity.text}")
|
119 |
entity["data-label"] = entity_annotation.label
|
120 |
entity["data-relation-tails"] = json.dumps(
|
vector_store.py
CHANGED
@@ -52,12 +52,16 @@ class VectorStore(Generic[T, E], abc.ABC):
|
|
52 |
def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
|
53 |
return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))
|
54 |
|
|
|
|
|
|
|
55 |
@abc.abstractmethod
|
56 |
def _retrieve_similar(
|
57 |
self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
|
58 |
) -> List[Tuple[T, float]]:
|
59 |
"""Retrieve IDs, payloads and the respective similarity scores with respect to the
|
60 |
-
reference entry.
|
|
|
61 |
|
62 |
Args:
|
63 |
ref_id: The ID of the reference entry.
|
@@ -74,6 +78,8 @@ class VectorStore(Generic[T, E], abc.ABC):
|
|
74 |
def retrieve_similar(
|
75 |
self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
|
76 |
) -> List[Tuple[T, float]]:
|
|
|
|
|
77 |
return self._retrieve_similar(
|
78 |
ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
|
79 |
)
|
@@ -244,6 +250,8 @@ class QdrantVectorStore(VectorStore[T, List[float]]):
|
|
244 |
)
|
245 |
|
246 |
def _get(self, emb_id: str) -> Optional[List[float]]:
|
|
|
|
|
247 |
points = self.client.retrieve(
|
248 |
collection_name=self.COLLECTION_NAME,
|
249 |
ids=[self.emb_id2point_id[emb_id]],
|
|
|
52 |
def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
|
53 |
return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))
|
54 |
|
55 |
+
def has(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> bool:
|
56 |
+
return self.get(emb_id=emb_id, payload=payload) is not None
|
57 |
+
|
58 |
@abc.abstractmethod
|
59 |
def _retrieve_similar(
|
60 |
self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
|
61 |
) -> List[Tuple[T, float]]:
|
62 |
"""Retrieve IDs, payloads and the respective similarity scores with respect to the
|
63 |
+
reference entry. In the case that the reference entry is not in the store itself, an empty
|
64 |
+
list will be returned.
|
65 |
|
66 |
Args:
|
67 |
ref_id: The ID of the reference entry.
|
|
|
78 |
def retrieve_similar(
|
79 |
self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
|
80 |
) -> List[Tuple[T, float]]:
|
81 |
+
if not self.has(emb_id=ref_id, payload=ref_payload):
|
82 |
+
return []
|
83 |
return self._retrieve_similar(
|
84 |
ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
|
85 |
)
|
|
|
250 |
)
|
251 |
|
252 |
def _get(self, emb_id: str) -> Optional[List[float]]:
|
253 |
+
if emb_id not in self.emb_id2point_id:
|
254 |
+
return None
|
255 |
points = self.client.retrieve(
|
256 |
collection_name=self.COLLECTION_NAME,
|
257 |
ids=[self.emb_id2point_id[emb_id]],
|