|
import abc |
|
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar |
|
|
|
# Type of the IDs that key stored embeddings. Bound to Hashable because
# SimpleVectorStore uses IDs as dictionary keys.
T = TypeVar("T", bound=Hashable)

# Type of an embedding payload; unconstrained here (SimpleVectorStore fixes
# it to List[float]).
E = TypeVar("E")
|
|
|
|
|
class VectorStore(Generic[T, E], abc.ABC):
    """Abstract interface for a store of embeddings keyed by ID.

    ``T`` is the ID type and ``E`` is the embedding type. Concrete stores
    must implement every abstract method below.
    """

    @abc.abstractmethod
    def save(self, emb_id: T, embedding: E) -> None:
        """Store ``embedding`` under the ID ``emb_id``."""

    @abc.abstractmethod
    def retrieve_similar(
        self,
        ref_id: T,
        top_k: Optional[int] = None,
        min_similarity: Optional[float] = None,
    ) -> List[Tuple[T, float]]:
        """Rank stored entries by their similarity to a reference entry.

        The reference entry has to be present in the store already.

        Args:
            ref_id: ID of the entry every other entry is compared against.
            top_k: When given, at most this many of the most similar entries
                are returned.
            min_similarity: When given, entries scoring below this value are
                left out (a score equal to it is kept).

        Returns:
            ``(ID, similarity score)`` tuples ordered from the highest score
            to the lowest.
        """

    @abc.abstractmethod
    def __len__(self) -> int:
        """Return the size of the store, as defined by the implementation."""
|
|
|
|
|
def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) norm of ``vector``.

    An empty vector has norm ``0.0``.
    """
    squared_total = sum(pow(component, 2) for component in vector)
    return pow(squared_total, 0.5)
|
|
|
|
|
def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity of vectors ``a`` and ``b``.

    The result is the dot product of the two vectors divided by the product
    of their Euclidean norms.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        The cosine similarity, or ``0.0`` when the product of the norms is
        zero (for example when either vector is empty or all zeros), since
        the angle is undefined in that case.

    Note:
        Components are paired with ``zip``, so if the vectors differ in
        length the surplus components of the longer one do not contribute to
        the dot product (they still count toward that vector's norm).
    """
    dot_product = sum(x * y for x, y in zip(a, b))
    norm_product = vector_norm(a) * vector_norm(b)
    if norm_product == 0:
        # Avoid ZeroDivisionError: a zero-norm vector has no direction, so it
        # is treated as having no similarity to anything.
        return 0.0
    return dot_product / norm_product
|
|
|
|
|
class SimpleVectorStore(VectorStore[T, List[float]]):
    """In-memory vector store backed by plain dictionaries.

    Embeddings live in ``vectors`` keyed by ID. Similarity scores are computed
    with ``cosine_similarity`` and memoised in ``_cache`` under the key
    ``(emb_id, ref_id)``. Cache entries that involve an ID are dropped when
    that ID's embedding is overwritten or deleted.

    NOTE: embeddings are stored by reference, not copied; mutating a list
    in place after saving it is not detected by the cache.
    """

    def __init__(self):
        # ID -> embedding.
        self.vectors: dict[T, List[float]] = {}
        # (emb_id, ref_id) -> similarity between the two stored embeddings.
        self._cache: dict[Tuple[T, T], float] = {}
        # Similarity function, called as ``_sim(ref_embedding, embedding)``.
        self._sim = cosine_similarity

    def _invalidate(self, emb_id: T) -> None:
        """Drop every cached similarity whose key pair contains ``emb_id``."""
        self._cache = {
            pair: score for pair, score in self._cache.items() if emb_id not in pair
        }

    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under ``emb_id``, replacing any previous one.

        Args:
            emb_id: ID to store the embedding under.
            embedding: The embedding vector.
        """
        if emb_id in self.vectors:
            # Scores cached for the old embedding would otherwise be served
            # for the new one.
            self._invalidate(emb_id)
        self.vectors[emb_id] = embedding

    def get(self, emb_id: T) -> Optional[List[float]]:
        """Return the embedding stored under ``emb_id``, or ``None`` if absent."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: T) -> None:
        """Remove ``emb_id`` and its cached similarities; no-op if absent."""
        self.vectors.pop(emb_id, None)
        self._invalidate(emb_id)

    def clear(self) -> None:
        """Remove all embeddings and all cached similarities."""
        self.vectors.clear()
        self._cache.clear()

    def __len__(self) -> int:
        """Return the number of stored embeddings."""
        return len(self.vectors)

    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Rank all stored entries by similarity to the entry ``ref_id``.

        Every stored entry is scored, including the reference entry itself.

        Args:
            ref_id: ID of the reference entry; must be present in the store.
            top_k: If provided, only the first ``top_k`` entries of the ranked
                list are returned.
            min_similarity: If provided, only entries whose score is greater
                than or equal to this value are returned.

        Returns:
            A list of ``(ID, similarity)`` tuples sorted by similarity in
            descending order.

        Raises:
            ValueError: If no embedding is stored under ``ref_id``.
        """
        ref_embedding = self.get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        scored: List[Tuple[T, float]] = []
        for emb_id, embedding in self.vectors.items():
            key = (emb_id, ref_id)
            if key not in self._cache:
                self._cache[key] = self._sim(ref_embedding, embedding)
            scored.append((emb_id, self._cache[key]))

        # sorted() is stable, so ties keep their insertion order.
        similar_entries = sorted(scored, key=lambda entry: entry[1], reverse=True)

        if min_similarity is not None:
            similar_entries = [
                (emb_id, sim) for emb_id, sim in similar_entries if sim >= min_similarity
            ]
        if top_k is not None:
            similar_entries = similar_entries[:top_k]

        return similar_entries
|
|