ArneBinder committed
Commit d7a2972
1 Parent(s): 1681237

from https://github.com/ArneBinder/pie-document-level/pull/225

Files changed (4)
  1. app.py +13 -5
  2. document_store.py +120 -27
  3. requirements.txt +1 -0
  4. vector_store.py +165 -17
app.py CHANGED
@@ -16,6 +16,7 @@ from pytorch_ie import Pipeline
 from pytorch_ie.documents import TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions
 from rendering_utils import render_displacy, render_pretty_table
 from transformers import PreTrainedModel, PreTrainedTokenizer
+from vector_store import QdrantVectorStore, SimpleVectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -65,6 +66,9 @@ def wrapped_process_text(
         document_store.add_document(document)
     except Exception as e:
         raise gr.Error(f"Failed to process text: {e}")
+    # remove the embeddings because they are very large
+    if document.metadata.get("embeddings"):
+        document.metadata = {k: v for k, v in document.metadata.items() if k != "embeddings"}
     # Return as dict and document to avoid serialization issues
     return document.asdict(), document
 
@@ -117,7 +121,7 @@ def select_processed_document(
 ) -> TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions:
     row_idx, col_idx = evt.index
     doc_id = processed_documents_df.iloc[row_idx]["doc_id"]
-    doc = document_store.get_document(doc_id)
+    doc = document_store.get_document(doc_id, with_embeddings=False)
     return doc
 
 
@@ -144,7 +148,7 @@ def download_processed_documents(
     file_name: str = "processed_documents.json",
 ) -> str:
     file_path = os.path.join(tempfile.gettempdir(), file_name)
-    document_store.save_to_json(file_path, indent=2)
+    document_store.save_to_file(file_path, indent=2)
     return file_path
 
 
@@ -152,7 +156,7 @@ def upload_processed_documents(
     file_name: str,
     document_store: DocumentStore,
 ) -> pd.DataFrame:
-    document_store.add_documents_from_json(file_name)
+    document_store.add_documents_from_file(file_name)
     return document_store.overview()
 
 
@@ -197,7 +201,11 @@ def main():
 
     with gr.Blocks() as demo:
         document_store_state = gr.State(
-            DocumentStore(span_annotation_caption="adu", relation_annotation_caption="relation")
+            DocumentStore(
+                span_annotation_caption="adu",
+                relation_annotation_caption="relation",
+                vector_store=QdrantVectorStore(),
+            )
         )
         # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
         models_state = gr.State((argumentation_model, embedding_model))
@@ -379,7 +387,7 @@ def main():
         )
 
         download_processed_documents_btn.click(
-            fn=download_processed_documents,
+            fn=partial(download_processed_documents, file_name="processed_documents.zip"),
            inputs=[document_store_state],
            outputs=[download_processed_documents_btn],
        )
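
Note: the click handler is now bound with functools.partial, so only the DocumentStore state flows through the Gradio event while the target file name (and thus the .zip format) is fixed at wiring time. A minimal sketch of that pattern, assuming a stand-in save_store function rather than the app's actual handler:

import os
import tempfile
from functools import partial


def save_store(document_store, file_name: str = "processed_documents.json") -> str:
    # stand-in for download_processed_documents: the store comes from the event inputs,
    # while file_name is pre-bound via partial
    file_path = os.path.join(tempfile.gettempdir(), file_name)
    # document_store.save_to_file(file_path, indent=2) would run here
    return file_path


# the click event only needs to pass the store; the name is already fixed to .zip
handler = partial(save_store, file_name="processed_documents.zip")
print(handler({"doc-1": "..."}))  # -> .../processed_documents.zip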
document_store.py CHANGED
@@ -1,7 +1,11 @@
 import json
 import logging
+import os
+import shutil
+import tempfile
+import zipfile
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
 
 import gradio as gr
 import pandas as pd
@@ -11,7 +15,7 @@ from pytorch_ie.documents import (
     TextBasedDocument,
     TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
 )
-from vector_store import SimpleVectorStore, VectorStore
+from vector_store import VectorStore
 
 logger = logging.getLogger(__name__)
 
@@ -134,9 +138,11 @@ class DocumentStore:
     are used, otherwise the gold annotations are used.
     """
 
+    JSON_FILE_NAME = "documents.json"
+
     def __init__(
         self,
-        vector_store: Optional[VectorStore[Tuple[str, str], List[float]]] = None,
+        vector_store: VectorStore[Dict[str, Any], List[float]],
         document_type: type[
             TextBasedDocument
         ] = TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
@@ -151,9 +157,7 @@ class DocumentStore:
         self.documents: Dict[str, TextBasedDocument] = {}
         # The vector store to efficiently retrieve similar spans. Can be constructed from the
        # documents.
-        self.vector_store: VectorStore[Tuple[str, str], List[float]] = (
-            vector_store or SimpleVectorStore()
-        )
+        self.vector_store = vector_store
         # the document type (to create new documents from dicts)
         self.document_type = document_type
         self.span_layer_name = span_layer_name
@@ -180,6 +184,10 @@ class DocumentStore:
             document, annotation_id, annotation_layer, use_predictions=use_predictions
         )
 
+    def construct_embedding_payload(self, document: TextBasedDocument, annotation_id: str) -> dict:
+        payload = {"doc_id": document.id, "annotation_id": annotation_id}
+        return payload
+
     def get_similar_annotations_df(
         self,
         ref_annotation_id: str,
@@ -203,27 +211,25 @@ class DocumentStore:
         """
 
         similar_entries = self.vector_store.retrieve_similar(
-            ref_id=(ref_document.id, ref_annotation_id),
+            ref_payload=self.construct_embedding_payload(ref_document, ref_annotation_id),
             **similarity_kwargs,
         )
 
         similar_annotations = [
             self.get_annotation(
-                doc_id=doc_id,
-                annotation_id=annotation_id,
+                doc_id=payload["doc_id"],
+                annotation_id=payload["annotation_id"],
                 annotation_layer=annotation_layer,
                 use_predictions=self.use_predictions,
             )
-            for (doc_id, annotation_id), _ in similar_entries
+            for _, payload, _ in similar_entries
         ]
         df = pd.DataFrame(
             [
                 # unpack the tuple (doc_id, annotation_id) to separate columns
                 # and add the similarity score and the text of the annotation
-                (doc_id, annotation_id, score, str(annotation))
-                for ((doc_id, annotation_id), score), annotation in zip(
-                    similar_entries, similar_annotations
-                )
+                (payload["doc_id"], payload["annotation_id"], score, str(annotation))
+                for (_, payload, score), annotation in zip(similar_entries, similar_annotations)
             ],
             columns=["doc_id", "annotation_id", "sim_score", "text"],
         )
@@ -258,19 +264,20 @@ class DocumentStore:
         """
 
         similar_entries = self.vector_store.retrieve_similar(
-            ref_id=(ref_document.id, ref_annotation_id),
+            ref_payload=self.construct_embedding_payload(ref_document, ref_annotation_id),
             min_similarity=min_similarity,
             top_k=top_k,
         )
         result = []
-        for (doc_id, annotation_id), score in similar_entries:
+        for _, payload, score in similar_entries:
+            doc_id = payload["doc_id"]
             # skip entries from the same document
             if doc_id == ref_document.id:
                 continue
             document = self.documents[doc_id]
             reference_annotation = get_annotation_from_document(
                 document=document,
-                annotation_id=annotation_id,
+                annotation_id=payload["annotation_id"],
                 annotation_layer=self.span_layer_name,
                 use_predictions=self.use_predictions,
             )
@@ -295,12 +302,21 @@ class DocumentStore:
             if document.id in self.documents:
                 gr.Warning(f"Document '{document.id}' already in index. Overwriting.")
 
+            # copy the document to avoid side effects
+            document = document.copy()
+
             # save the processed document to the index
             self.documents[document.id] = document
 
-            # save the embeddings to the vector store
-            for annotation_id, embedding in document.metadata["embeddings"].items():
-                self.vector_store.save((document.id, annotation_id), embedding)
+            # save the embeddings to the vector store, if available
+            if "embeddings" in document.metadata:
+                for annotation_id, embedding in document.metadata["embeddings"].items():
+                    payload = self.construct_embedding_payload(document, annotation_id)
+                    self.vector_store.add(payload=payload, embedding=embedding)
+                # remove the embeddings from the document metadata
+                document.metadata = {
+                    k: v for k, v in document.metadata.items() if k != "embeddings"
+                }
 
         except Exception as e:
             raise gr.Error(f"Failed to add document {document.id} to index: {e}")
@@ -325,12 +341,81 @@ class DocumentStore:
         f"Added {len(documents_json)} documents to the index ({len(self.documents)} documents in total)."
     )
 
-    def save_to_json(self, file_path: str, **kwargs) -> None:
+    def add_documents_from_zip(self, file_path: str) -> None:
+        temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
+        # remove the temporary directory if it already exists
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        with zipfile.ZipFile(file_path, "r") as zipf:
+            # extract all files to the temporary directory
+            zipf.extractall(temp_dir)
+        json_file_path = os.path.join(temp_dir, self.JSON_FILE_NAME)
+        self.add_documents_from_json(json_file_path)
+        # load the vector store from the temporary directory
+        self.vector_store.load_from_directory(temp_dir)
+        # delete the temporary directory
+        shutil.rmtree(temp_dir)
+
+    def add_documents_from_file(self, file_path: str) -> None:
+        if file_path.endswith(".json"):
+            self.add_documents_from_json(file_path)
+        elif file_path.endswith(".zip"):
+            self.add_documents_from_zip(file_path)
+        else:
+            raise gr.Error(f"Unsupported file format: {file_path}")
+
+    def save_to_json(self, file_path: str, include_embeddings: bool = True, **kwargs) -> None:
         with open(file_path, "w", encoding="utf-8") as f:
-            json.dump(self.as_dict(), f, **kwargs)
-
-    def get_document(self, doc_id: str) -> TextBasedDocument:
-        return self.documents[doc_id]
+            json.dump(self.as_dict(include_embeddings=include_embeddings), f, **kwargs)
+
+    def save_to_zip(self, file_path: str, **kwargs) -> None:
+        # first create a new temporary directory and save the documents as json file in it
+        temp_dir = os.path.join(tempfile.gettempdir(), "document_store")
+        # remove the temporary directory if it already exists
+        if os.path.exists(temp_dir):
+            shutil.rmtree(temp_dir)
+        os.makedirs(temp_dir)
+        temp_file_path = os.path.join(temp_dir, self.JSON_FILE_NAME)
+        self.save_to_json(temp_file_path, include_embeddings=False, **kwargs)
+        self.vector_store.save_to_directory(temp_dir)
+        # then zip all files in the temporary directory and write them to the target file
+        with zipfile.ZipFile(file_path, "w") as zipf:
+            for root, _, files in os.walk(temp_dir):
+                for file in files:
+                    zipf.write(
+                        os.path.join(root, file),
+                        os.path.relpath(os.path.join(root, file), temp_dir),
+                    )
+        # delete the temporary directory
+        shutil.rmtree(temp_dir)
+
+    def save_to_file(self, file_path: str, **kwargs) -> None:
+        if file_path.endswith(".json"):
+            self.save_to_json(file_path, **kwargs)
+        elif file_path.endswith(".zip"):
+            self.save_to_zip(file_path, **kwargs)
+        else:
+            raise gr.Error(f"Unsupported file format: {file_path}")
+
+    def get_document(self, doc_id: str, with_embeddings: bool = False) -> TextBasedDocument:
+        document = self.documents[doc_id]
+        if not with_embeddings:
+            return document
+
+        # TODO: is this really required?
+        # copy because we add the embeddings to the metadata
+        document = document.copy()
+        # get the embeddings from the vector store
+        embeddings = {}
+        for annotation in document[self.span_layer_name].predictions:
+            annotation_id = labeled_span_to_id(annotation)
+            payload = self.construct_embedding_payload(document, annotation_id)
+            embedding = self.vector_store.get(payload=payload)
+            if embedding is not None:
+                embeddings[annotation_id] = embedding
+        document.metadata["embeddings"] = embeddings
+
+        return document
 
     def overview(self) -> pd.DataFrame:
         rows = []
@@ -346,5 +431,13 @@ class DocumentStore:
         df = pd.DataFrame(rows)
         return df
 
-    def as_dict(self) -> dict:
-        return {doc_id: document.asdict() for doc_id, document in self.documents.items()}
+    def as_dict(self, include_embeddings: bool = True) -> dict:
+        result = {}
+        for doc_id, document in self.documents.items():
+            doc_dict = document.asdict()
+            if not include_embeddings and "embeddings" in (doc_dict.get("metadata") or {}):
+                doc_dict["metadata"] = {
+                    k: v for k, v in doc_dict["metadata"].items() if k != "embeddings"
+                }
+            result[doc_id] = doc_dict
+        return result
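
Note: save_to_zip and add_documents_from_zip bundle documents.json (written without embeddings) together with whatever files the vector store emits via save_to_directory, and unpack them again on upload. A self-contained sketch of that directory-to-archive round trip, assuming illustrative save_bundle/load_bundle helpers in place of the actual DocumentStore methods:

import json
import os
import shutil
import tempfile
import zipfile


def save_bundle(zip_path: str, documents: dict, extra_files: dict) -> None:
    # write documents.json plus extra files into a temp dir, then zip it (mirrors save_to_zip)
    temp_dir = os.path.join(tempfile.gettempdir(), "bundle_demo")
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.makedirs(temp_dir)
    with open(os.path.join(temp_dir, "documents.json"), "w", encoding="utf-8") as f:
        json.dump(documents, f, indent=2)
    for name, content in extra_files.items():  # e.g. vector store index/payload files
        with open(os.path.join(temp_dir, name), "w", encoding="utf-8") as f:
            f.write(content)
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for root, _, files in os.walk(temp_dir):
            for file in files:
                full = os.path.join(root, file)
                zipf.write(full, os.path.relpath(full, temp_dir))
    shutil.rmtree(temp_dir)


def load_bundle(zip_path: str) -> dict:
    # extract the archive and read documents.json back (mirrors add_documents_from_zip)
    temp_dir = os.path.join(tempfile.gettempdir(), "bundle_demo")
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    with zipfile.ZipFile(zip_path, "r") as zipf:
        zipf.extractall(temp_dir)
    with open(os.path.join(temp_dir, "documents.json"), encoding="utf-8") as f:
        documents = json.load(f)
    shutil.rmtree(temp_dir)
    return documents


archive = os.path.join(tempfile.gettempdir(), "processed_documents.zip")
save_bundle(archive, {"doc-1": {"text": "..."}}, {"vectors_index.json": "[]"})
print(load_bundle(archive))  # -> {'doc-1': {'text': '...'}}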
requirements.txt CHANGED
@@ -5,3 +5,4 @@ beautifulsoup4==4.12.3
 datasets==2.14.4
 # numpy 2.0.0 breaks the code
 numpy==1.25.2
+qdrant-client==1.9.1
vector_store.py CHANGED
@@ -1,22 +1,48 @@
 import abc
-from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
+import json
+import os
+from typing import Any, Generic, List, Optional, Tuple, TypeVar
 
-T = TypeVar("T", bound=Hashable)
+import numpy as np
+from qdrant_client import QdrantClient
+from qdrant_client.models import Distance, PointStruct, VectorParams
+
+T = TypeVar("T", bound=dict[str, Any])
 E = TypeVar("E")
 
 
 class VectorStore(Generic[T, E], abc.ABC):
     @abc.abstractmethod
-    def save(self, emb_id: T, embedding: E) -> None:
-        """Save an embedding for a given ID."""
+    def _add(self, embedding: E, payload: T, emb_id: str) -> None:
+        """Save an embedding with payload for a given ID."""
         pass
 
     @abc.abstractmethod
-    def retrieve_similar(
-        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
+    def _get(self, emb_id: str) -> Optional[E]:
+        """Get the embedding for a given ID."""
+        pass
+
+    def _get_emb_id(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> str:
+        if emb_id is None:
+            if payload is None:
+                raise ValueError("Either emb_id or payload must be provided.")
+            emb_id = json.dumps(payload, sort_keys=True)
+        return emb_id
+
+    def add(self, embedding: E, payload: T, emb_id: Optional[str] = None) -> None:
+        if emb_id is None:
+            emb_id = json.dumps(payload, sort_keys=True)
+        self._add(embedding=embedding, payload=payload, emb_id=emb_id)
+
+    def get(self, emb_id: Optional[str] = None, payload: Optional[T] = None) -> Optional[E]:
+        return self._get(emb_id=self._get_emb_id(emb_id=emb_id, payload=payload))
+
+    @abc.abstractmethod
+    def _retrieve_similar(
+        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
     ) -> List[Tuple[T, float]]:
-        """Retrieve IDs and the respective similarity scores with respect to the reference entry.
-        Note that this requires the reference entry to be present in the store.
+        """Retrieve IDs, payloads and the respective similarity scores with respect to the
+        reference entry. Note that this requires the reference entry to be present in the store.
 
         Args:
             ref_id: The ID of the reference entry.
@@ -30,10 +56,28 @@ class VectorStore(Generic[T, E], abc.ABC):
         """
         pass
 
+    def retrieve_similar(
+        self, ref_id: Optional[str] = None, ref_payload: Optional[T] = None, **kwargs
+    ) -> List[Tuple[T, float]]:
+        return self._retrieve_similar(
+            ref_id=self._get_emb_id(emb_id=ref_id, payload=ref_payload), **kwargs
+        )
+
     @abc.abstractmethod
     def __len__(self):
         pass
 
+    def save_to_directory(self, directory: str) -> None:
+        """Save the vector store to a directory."""
+        raise NotImplementedError
+
+    def load_from_directory(self, directory: str, replace: bool = False) -> None:
+        """Load the vector store from a directory.
+
+        If `replace` is True, the current content of the store will be replaced.
+        """
+        raise NotImplementedError
+
 
 def vector_norm(vector: List[float]) -> float:
     return sum(x**2 for x in vector) ** 0.5
@@ -44,34 +88,43 @@ def cosine_similarity(a: List[float], b: List[float]) -> float:
 
 
 class SimpleVectorStore(VectorStore[T, List[float]]):
+
+    INDEX_FILE = "vectors_index.json"
+    EMBEDDINGS_FILE = "vectors_data.npy"
+    PAYLOADS_FILE = "vectors_payloads.json"
+
     def __init__(self):
-        self.vectors: dict[T, List[float]] = {}
+        self.vectors: dict[str, List[float]] = {}
+        self.payloads: dict[str, T] = {}
         self._cache = {}
         self._sim = cosine_similarity
 
-    def save(self, emb_id: T, embedding: List[float]) -> None:
+    def _add(self, embedding: E, payload: T, emb_id: str) -> None:
         self.vectors[emb_id] = embedding
+        self.payloads[emb_id] = payload
 
-    def get(self, emb_id: T) -> Optional[List[float]]:
+    def _get(self, emb_id: str) -> Optional[E]:
         return self.vectors.get(emb_id)
 
-    def delete(self, emb_id: T) -> None:
+    def delete(self, emb_id: str) -> None:
         if emb_id in self.vectors:
             del self.vectors[emb_id]
+            del self.payloads[emb_id]
         # remove from cache
         self._cache = {k: v for k, v in self._cache.items() if emb_id not in k}
 
     def clear(self) -> None:
         self.vectors.clear()
         self._cache.clear()
+        self.payloads.clear()
 
     def __len__(self):
         return len(self.vectors)
 
-    def retrieve_similar(
-        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
-    ) -> List[Tuple[T, float]]:
-        ref_embedding = self.get(ref_id)
+    def _retrieve_similar(
+        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
+    ) -> List[Tuple[str, T, float]]:
+        ref_embedding = self.get(emb_id=ref_id)
         if ref_embedding is None:
             raise ValueError(f"Reference embedding '{ref_id}' not found.")
 
@@ -93,4 +146,99 @@ class SimpleVectorStore(VectorStore[T, List[float]]):
         if top_k is not None:
             similar_entries = similar_entries[:top_k]
 
-        return similar_entries
+        return [(emb_id, self.payloads[emb_id], sim) for emb_id, sim in similar_entries]
+
+    def save_to_directory(self, directory: str) -> None:
+        os.makedirs(directory, exist_ok=True)
+        indices = list(self.vectors.keys())
+        with open(os.path.join(directory, self.INDEX_FILE), "w") as f:
+            json.dump(indices, f)
+        embeddings_np = np.array(list(self.vectors.values()))
+        np.save(os.path.join(directory, self.EMBEDDINGS_FILE), embeddings_np)
+        payloads = [self.payloads[idx] for idx in indices]
+        with open(os.path.join(directory, self.PAYLOADS_FILE), "w") as f:
+            json.dump(payloads, f)
+
+    def load_from_directory(self, directory: str, replace: bool = False) -> None:
+        if replace:
+            self.clear()
+        with open(os.path.join(directory, self.INDEX_FILE), "r") as f:
+            index = json.load(f)
+        embeddings_np = np.load(os.path.join(directory, self.EMBEDDINGS_FILE))
+        with open(os.path.join(directory, self.PAYLOADS_FILE), "r") as f:
+            payloads = json.load(f)
+        for emb_id, emb, payload in zip(index, embeddings_np, payloads):
+            self.vectors[emb_id] = emb.tolist()
+            self.payloads[emb_id] = payload
+
+
+class QdrantVectorStore(VectorStore[T, List[float]]):
+
+    COLLECTION_NAME = "ADUs"
+    MAX_LIMIT = 100
+
+    def __init__(
+        self,
+        location: str = ":memory:",
+        vector_size: int = 768,
+        distance: Distance = Distance.COSINE,
+    ):
+        self.client = QdrantClient(location=location)
+        self.id2idx = {}
+        self.idx2id = {}
+        self.client.create_collection(
+            collection_name=self.COLLECTION_NAME,
+            vectors_config=VectorParams(size=vector_size, distance=distance),
+        )
+
+    def __len__(self):
+        return self.client.get_collection(collection_name=self.COLLECTION_NAME).points_count
+
+    def _add(self, emb_id: str, payload: T, embedding: List[float]) -> None:
+
+        # we use the length of the id2idx dict as the index,
+        # because we assume that, even when we delete an entry from
+        # the store, we do not delete it from the index
+        _id = len(self.id2idx)
+        self.client.upsert(
+            collection_name=self.COLLECTION_NAME,
+            points=[PointStruct(id=_id, vector=embedding, payload=payload)],
+        )
+        self.id2idx[emb_id] = _id
+        self.idx2id[_id] = emb_id
+
+    def _get(self, emb_id: str) -> Optional[List[float]]:
+        points = self.client.retrieve(
+            collection_name=self.COLLECTION_NAME,
+            ids=[self.id2idx[emb_id]],
+            with_vectors=True,
+        )
+        if len(points) == 0:
+            return None
+        elif len(points) == 1:
+            return points[0].vector
+        else:
+            raise ValueError(f"Multiple points found for ID '{emb_id}'.")
+
+    def _retrieve_similar(
+        self, ref_id: str, top_k: Optional[int] = None, min_similarity: Optional[float] = None
+    ) -> List[Tuple[str, T, float]]:
+        similar_entries = self.client.recommend(
+            collection_name=self.COLLECTION_NAME,
+            positive=[self.id2idx[ref_id]],
+            limit=top_k or self.MAX_LIMIT,
+            score_threshold=min_similarity,
+        )
+        return [(self.idx2id[entry.id], entry.payload, entry.score) for entry in similar_entries]
+
+    def clear(self) -> None:
+        vectors_config = self.client.get_collection(
+            collection_name=self.COLLECTION_NAME
+        ).vectors_config
+        self.client.delete_collection(collection_name=self.COLLECTION_NAME)
+        self.client.create_collection(
+            collection_name=self.COLLECTION_NAME,
+            vectors_config=vectors_config,
+        )
+        self.id2idx.clear()
+        self.idx2id.clear()
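
Note: the store API now keys every embedding by its payload ({"doc_id": ..., "annotation_id": ...}); the internal ID is derived as json.dumps(payload, sort_keys=True), and retrieve_similar returns (emb_id, payload, score) triples. A toy, dependency-free sketch of that payload-keyed lookup (the PayloadKeyedStore class is illustrative only, not part of the PR):

import json
from typing import Any, Dict, List, Optional, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    norm = lambda v: sum(x * x for x in v) ** 0.5
    return sum(x * y for x, y in zip(a, b)) / (norm(a) * norm(b))


class PayloadKeyedStore:
    """Toy analogue of the new VectorStore API: entries are addressed by their payload."""

    def __init__(self) -> None:
        self.vectors: Dict[str, List[float]] = {}
        self.payloads: Dict[str, Dict[str, Any]] = {}

    def add(self, embedding: List[float], payload: Dict[str, Any]) -> None:
        # the payload itself acts as the key, serialized deterministically
        emb_id = json.dumps(payload, sort_keys=True)
        self.vectors[emb_id] = embedding
        self.payloads[emb_id] = payload

    def retrieve_similar(
        self, ref_payload: Dict[str, Any], top_k: Optional[int] = None
    ) -> List[Tuple[str, Dict[str, Any], float]]:
        ref_id = json.dumps(ref_payload, sort_keys=True)
        ref = self.vectors[ref_id]
        # score every other entry against the reference and return (emb_id, payload, score)
        scored = [
            (emb_id, self.payloads[emb_id], cosine(ref, emb))
            for emb_id, emb in self.vectors.items()
            if emb_id != ref_id
        ]
        scored.sort(key=lambda entry: entry[2], reverse=True)
        return scored[:top_k] if top_k is not None else scored


store = PayloadKeyedStore()
store.add([1.0, 0.0], {"doc_id": "doc-1", "annotation_id": "adu-0"})
store.add([0.9, 0.1], {"doc_id": "doc-2", "annotation_id": "adu-3"})
print(store.retrieve_similar({"doc_id": "doc-1", "annotation_id": "adu-0"}, top_k=1))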