sam-pointer-bart-base-v0.3 / vector_store.py
ArneBinder's picture
Upload 9 files
86277c0 verified
raw
history blame
3.29 kB
import abc
from typing import Generic, Hashable, List, Optional, Tuple, TypeVar
# Type of the IDs that key stored embeddings; bound to Hashable because
# implementations (e.g. SimpleVectorStore) use them as dict keys.
T = TypeVar("T", bound=Hashable)
# Type of the stored embedding values; left unconstrained at the interface level.
E = TypeVar("E")
class VectorStore(Generic[T, E], abc.ABC):
    """Abstract store mapping IDs of type ``T`` to embeddings of type ``E``.

    Concrete subclasses must provide ``save``, ``retrieve_similar`` and ``__len__``.
    """

    @abc.abstractmethod
    def save(self, emb_id: T, embedding: E) -> None:
        """Store ``embedding`` under the identifier ``emb_id``."""
        ...

    @abc.abstractmethod
    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Rank stored entries by their similarity to the entry stored under ``ref_id``.

        The reference entry must already be present in the store.

        Args:
            ref_id: Identifier of the reference entry.
            top_k: When given, at most this many of the highest-scoring entries are returned.
            min_similarity: When given, only entries whose score is greater than or equal
                to this threshold are returned.

        Returns:
            ``(id, score)`` pairs ordered from the highest to the lowest score.
        """
        ...

    @abc.abstractmethod
    def __len__(self):
        """Return how many embeddings the store currently holds."""
        ...
def vector_norm(vector: List[float]) -> float:
    """Return the Euclidean (L2) norm of ``vector``."""
    return sum(x**2 for x in vector) ** 0.5


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine similarity between vectors ``a`` and ``b``.

    Args:
        a: First vector.
        b: Second vector; must have the same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)``.

    Raises:
        ValueError: If ``a`` and ``b`` differ in length. ``zip`` would otherwise
            truncate the dot product while the norms still use every component,
            silently producing a wrong score.
        ZeroDivisionError: If the product of the norms is zero (e.g. a zero
            vector), where cosine similarity is undefined.
    """
    if len(a) != len(b):
        raise ValueError(f"Vectors must have the same length, got {len(a)} and {len(b)}.")
    denominator = vector_norm(a) * vector_norm(b)
    if denominator == 0:
        # Keep the original exception type so existing callers still catch it,
        # but explain why it happened.
        raise ZeroDivisionError("Cosine similarity is undefined for zero-norm vectors.")
    dot = sum(x * y for x, y in zip(a, b))
    return dot / denominator
class SimpleVectorStore(VectorStore[T, List[float]]):
    """In-memory vector store that ranks entries by cosine similarity.

    Embeddings are kept in a plain dict keyed by ID; the list passed to ``save`` is
    stored by reference, not copied. Similarities computed by ``retrieve_similar`` are
    memoised in ``self._cache`` keyed by ``(emb_id, ref_id)``. Every cached pair that
    involves an ID is dropped when that ID is overwritten via ``save`` or removed via
    ``delete``, so the cache never serves a score computed from a replaced embedding.
    """

    def __init__(self):
        # ID -> embedding vector.
        self.vectors: dict[T, List[float]] = {}
        # (emb_id, ref_id) -> similarity between the two stored embeddings.
        self._cache: dict[Tuple[T, T], float] = {}
        # Similarity function. As a plain instance attribute it is not bound as a
        # method, so it is called with just the two vectors.
        self._sim = cosine_similarity

    def _invalidate(self, emb_id: T) -> None:
        """Drop every cached similarity whose key pair contains ``emb_id``."""
        self._cache = {k: v for k, v in self._cache.items() if emb_id not in k}

    def save(self, emb_id: T, embedding: List[float]) -> None:
        """Store ``embedding`` under ``emb_id``, replacing any previous embedding.

        When ``emb_id`` already exists, cached similarities involving it are discarded
        because they were computed from the embedding being replaced.
        """
        if emb_id in self.vectors:
            self._invalidate(emb_id)
        self.vectors[emb_id] = embedding

    def get(self, emb_id: T) -> Optional[List[float]]:
        """Return the embedding stored under ``emb_id``, or ``None`` if absent."""
        return self.vectors.get(emb_id)

    def delete(self, emb_id: T) -> None:
        """Remove ``emb_id`` and its cached similarities; unknown IDs are ignored."""
        if emb_id in self.vectors:
            del self.vectors[emb_id]
            self._invalidate(emb_id)

    def clear(self) -> None:
        """Remove all embeddings and all cached similarities."""
        self.vectors.clear()
        self._cache.clear()

    def __len__(self):
        """Return the number of stored embeddings."""
        return len(self.vectors)

    def retrieve_similar(
        self, ref_id: T, top_k: Optional[int] = None, min_similarity: Optional[float] = None
    ) -> List[Tuple[T, float]]:
        """Rank all stored entries by similarity to the entry stored under ``ref_id``.

        Every stored entry, including ``ref_id`` itself, is scored with ``self._sim``.
        Scores are memoised per ``(emb_id, ref_id)`` pair.

        Args:
            ref_id: ID of the reference entry; it must be present in the store.
            top_k: If provided, keep only the first ``top_k`` entries after sorting
                and filtering.
            min_similarity: If provided, keep only entries whose score is greater
                than or equal to this value.

        Returns:
            ``(id, score)`` tuples sorted by score in descending order. Ties keep the
            insertion order of ``self.vectors`` because the sort is stable.

        Raises:
            ValueError: If no embedding is stored under ``ref_id``.
        """
        ref_embedding = self.get(ref_id)
        if ref_embedding is None:
            raise ValueError(f"Reference embedding '{ref_id}' not found.")

        # Score the reference against every stored embedding, reusing cached scores.
        similarities = {}
        for emb_id, embedding in self.vectors.items():
            key = (emb_id, ref_id)
            if key not in self._cache:
                self._cache[key] = self._sim(ref_embedding, embedding)
            similarities[emb_id] = self._cache[key]

        similar_entries = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        if min_similarity is not None:
            similar_entries = [
                (emb_id, sim) for emb_id, sim in similar_entries if sim >= min_similarity
            ]
        if top_k is not None:
            similar_entries = similar_entries[:top_k]
        return similar_entries