image-search-engine / datastore.py
dmedhi's picture
initial commit
40a2cd4
# Swap the stdlib sqlite3 module for pysqlite3 before chromadb is imported.
# NOTE(review): presumably because the system SQLite is older than chromadb
# requires; this must stay above `import chromadb`.
__import__("pysqlite3")
import sys

sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

import uuid
from collections import defaultdict
from typing import Any, List, Optional

import chromadb
import numpy as np
from chromadb import Collection
from PIL.Image import Image

from embeddings import Embedding
from utils import base64_to_image
class ChromaStore:
    """Store and search image embeddings in a persistent Chroma collection.

    Each record keeps an image embedding plus the image itself as a base64
    string in the record's ``documents`` field; :meth:`query` decodes those
    strings back into images with ``base64_to_image``.
    """

    def __init__(
        self,
        collection_name: str,
        storage_path: str = "./chroma",
        database: str = "database",
        metadata: Optional[dict] = None,
    ) -> None:
        """Initiate a persistent Chroma client.

        Args:
            - collection_name (str): name of the collection used by ``create``.
            - storage_path (str): directory passed to
              ``chromadb.PersistentClient(path=...)``.
            - database (str): stored on the instance; it is NOT currently
              passed to the Chroma client.
            - metadata (dict | None): collection metadata applied by
              ``create``. Available options for 'hnsw:space' are 'l2', 'ip'
              or 'cosine'. Defaults to ``{"hnsw:space": "cosine"}``.
        """
        self.collection_name = collection_name
        # Build the default per instance: a dict literal used as a default
        # argument is evaluated once and shared (mutably) by every instance.
        self.metadata = (
            metadata if metadata is not None else {"hnsw:space": "cosine"}
        )
        self.storage_path = storage_path
        self.database = database
        self.client = chromadb.PersistentClient(path=self.storage_path)

    def _health_check(self) -> bool:
        """Return True if the client's heartbeat comes back as an ``int``.

        Any exception raised by ``heartbeat()`` propagates to the caller.
        """
        return isinstance(self.client.heartbeat(), int)

    def generate_embeddings(
        self, images: List[Image], embedding: Embedding
    ) -> np.ndarray:
        """Encode ``images`` with ``embedding.encode_images`` and return the result."""
        return embedding.encode_images(images)

    def create(self) -> Collection:
        """Get the configured collection, creating it if it does not exist.

        ``self.metadata`` is forwarded so that a newly created collection
        uses the 'hnsw:space' distance configured in ``__init__``; without
        it Chroma falls back to its default space ('l2').

        NOTE(review): Chroma does not support changing the distance function
        of a collection that already exists; confirm how the installed
        chromadb version treats metadata for an existing collection.
        """
        return self.client.get_or_create_collection(
            name=self.collection_name,
            metadata=self.metadata,
        )

    def add(
        self,
        collection: Collection,
        embeddings: List[List[float]],
        documents: List[str],
        ids: List[str],
    ):
        """Add embeddings, documents to index or collection.

        Args:
            - collection: created collection.
            - embeddings: list of image embeddings (one vector per image).
            - documents: list of base64 strings of the images.
            - ids: list of ids for the images.

        Raises:
            Exception: wrapping whatever ``collection.add`` raised; the
            original error is kept as ``__cause__``.
        """
        try:
            collection.add(
                embeddings=embeddings,
                ids=ids,
                documents=documents,
            )
        except Exception as e:
            raise Exception(f"Failed to add documents to Chroma store. {e}") from e

    def query(
        self,
        collection: Collection,
        query_embedding: List[float],
        top_k: int = 3,
    ) -> list:
        """Retrieve relevant images from chroma database.

        Args:
            - collection: created collection.
            - query_embedding: query image embedding.
            - top_k (int): top k images to retrieve.

        Returns:
            - list of ``(image, score)`` tuples for the first query, where
              ``score`` is Chroma's distance rounded to 3 decimals (lower
              means closer under the collection's 'hnsw:space').
        """
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=top_k,
            # Only documents and distances are used below; skip metadatas.
            include=["documents", "distances"],
        )
        # Chroma returns one result list per query embedding; take the first.
        relevant_images = [
            base64_to_image(img_str) for img_str in result["documents"][0]
        ]
        scores = [round(score, 3) for score in result["distances"][0]]
        return list(zip(relevant_images, scores))

    def delete(self, collection_name: str) -> bool:
        """Delete the named collection from the client.

        Returns:
            True when ``delete_collection`` completes without raising.

        Raises:
            Exception: wrapping whatever ``delete_collection`` raised; the
            original error is kept as ``__cause__``.
        """
        try:
            self.client.delete_collection(collection_name)
        except Exception as e:
            # A formatted message, consistent with ``add``; passing ``e`` as a
            # second positional argument made the message render as a tuple.
            raise Exception(
                f"Failed to delete collection '{collection_name}'. {e}"
            ) from e
        return True

    @staticmethod
    def collection_info(collection: Collection):
        """Summarise a collection.

        Returns:
            A ``defaultdict(str)`` with ``"count"`` (``collection.count()``)
            and ``"top_10_items"`` (``collection.peek()``, whose default
            limit in Chroma is 10). Missing keys read as ``""``.
        """
        info = defaultdict(str)
        info["count"] = collection.count()
        info["top_10_items"] = collection.peek()
        return info