vision_pipeline / voyager_index.py
Antoine Chaffin
Initial commit
349b5c2
raw
history blame
6.62 kB
import os
from collections.abc import Iterator

import numpy as np
import pypdfium2 as pdfium
import torch
import tqdm
from PIL import Image
from sqlitedict import SqliteDict
from voyager import Index, Space

from model import encode_images, encode_queries
def iter_batch(
    X: list[str], batch_size: int, tqdm_bar: bool = True, desc: str = ""
) -> Iterator[list]:
    """Iterate over a list of elements by batch.

    Parameters
    ----------
    X
        The elements to split into batches.
    batch_size
        Maximum number of elements per batch (must be >= 1); the last batch
        may be smaller.
    tqdm_bar
        Whether to display a tqdm progress bar while iterating.
    desc
        Description shown next to the progress bar.

    Yields
    ------
    Consecutive slices of ``X`` of length at most ``batch_size``.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1 (raised on first iteration).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")
    batches = [X[pos : pos + batch_size] for pos in range(0, len(X), batch_size)]
    if tqdm_bar:
        # Use the exact number of batches: ``1 + len(X) // batch_size``
        # over-counts whenever len(X) is a multiple of batch_size.
        yield from tqdm.tqdm(
            iterable=batches,
            position=0,
            total=len(batches),
            desc=desc,
        )
    else:
        yield from batches
class Voyager:
    """Voyager index. The Voyager index is a fast and efficient index for approximate nearest neighbor search.

    Every page (each page of a PDF, or a standalone image) is encoded with
    ``encode_images`` and added to a Voyager index saved on disk. A companion
    SQLite database maps each embedding ID to the page's source path, its
    PIL image and its page number.

    Parameters
    ----------
    index_folder
        Folder holding the ``.voyager`` index and the SQLite mapping; created
        if it does not exist.
    index_name
        Base name of the index files inside ``index_folder``.
    override
        Whether to delete and recreate the index (and its mapping) if it
        already exists.
    embedding_size
        The number of dimensions of the embeddings.
    M
        The number of connections per node in the index graph. Larger values
        improve recall at the cost of memory.
    ef_construction
        The number of candidates to evaluate during the construction of the index.
    ef_search
        The number of candidates to evaluate during the search.
    """

    def __init__(
        self,
        index_folder: str = "indexes",
        index_name: str = "base_collection",
        override: bool = False,
        embedding_size: int = 128,
        M: int = 64,
        ef_construction: int = 200,
        ef_search: int = 200,
    ) -> None:
        self.ef_search = ef_search
        # exist_ok avoids a race between an existence check and the creation.
        os.makedirs(index_folder, exist_ok=True)
        self.index_path = os.path.join(index_folder, f"{index_name}.voyager")
        # Must be set before _create_collection, which may reset this file.
        self.page_ids_to_data_path = os.path.join(
            index_folder, f"{index_name}_page_ids_to_data.sqlite"
        )
        self.index = self._create_collection(
            index_path=self.index_path,
            embedding_size=embedding_size,
            M=M,
            ef_constructions=ef_construction,
            override=override,
        )

    def _load_page_ids_to_data(self) -> SqliteDict:
        """Open the SQLite database that maps embedding IDs to page records.

        The caller is responsible for closing the returned connection.
        """
        return SqliteDict(self.page_ids_to_data_path, outer_stack=False)

    def _create_collection(
        self,
        index_path: str,
        embedding_size: int,
        M: int,
        ef_constructions: int,
        override: bool,
    ) -> Index:
        """Load the Voyager index at ``index_path``, or create a fresh one.

        Parameters
        ----------
        index_path
            The path to the index.
        embedding_size
            The size of the embeddings.
        M
            The number of connections per node in the index graph.
        ef_constructions
            The number of candidates to evaluate during the construction of the index.
        override
            Whether to override the collection if it already exists.

        Returns
        -------
        The loaded (existing, ``override=False``) or newly created index.
        """
        if os.path.exists(index_path) and not override:
            return Index.load(index_path)
        if os.path.exists(index_path):
            os.remove(index_path)
        # Create the Voyager index
        index = Index(
            Space.Cosine,
            num_dimensions=embedding_size,
            M=M,
            ef_construction=ef_constructions,
        )
        index.save(index_path)
        # A fresh index hands out IDs from scratch, so any mapping left over
        # from a previous index (even without ``override``, e.g. when only the
        # .voyager file was deleted) would point to the wrong pages.
        if os.path.exists(self.page_ids_to_data_path):
            os.remove(self.page_ids_to_data_path)
        # Create the (empty) SQLite database.
        self._load_page_ids_to_data().close()
        return index

    @staticmethod
    def _load_pages(paths: list[str]) -> tuple[list, list[tuple[str, int]]]:
        """Render every page referenced by ``paths`` to a PIL image.

        Paths ending in ``.pdf`` (case-insensitive) are rendered page by page;
        any other path is opened as a single image.

        Returns
        -------
        The images and, aligned index-for-index with them, the
        ``(path, page_number)`` each image came from.
        """
        images = []
        pages = []
        for path in paths:
            if path.lower().endswith(".pdf"):
                pdf = pdfium.PdfDocument(path)
                try:
                    for page_number in range(len(pdf)):
                        page = pdf.get_page(page_number)
                        pil_image = page.render(scale=1, rotation=0).to_pil()
                        images.append(pil_image)
                        pages.append((path, page_number))
                finally:
                    pdf.close()
            else:
                # Copy the pixels and release the file handle immediately; a
                # bare Image.open keeps the file open until garbage collection.
                with Image.open(path) as opened:
                    images.append(opened.copy())
                pages.append((path, 0))
        return images, pages

    def add_documents(
        self,
        paths: str | list[str],
        batch_size: int = 1,
    ) -> "Voyager":
        """Add documents to the index.

        Parameters
        ----------
        paths
            A path, or list of paths, to PDF files or image files.
        batch_size
            The number of pages to encode at once, not documents.

        Returns
        -------
        The index itself, to allow chaining.

        Raises
        ------
        ValueError
            If ``encode_images`` does not return one embedding per page.
        """
        if isinstance(paths, str):
            paths = [paths]
        images, pages = self._load_pages(paths)
        if not images:
            # Nothing to encode; don't call add_items with an empty batch.
            return self

        embeddings = []
        for batch in iter_batch(
            X=images, batch_size=batch_size, desc=f"Encoding pages (bs={batch_size})"
        ):
            embeddings.extend(encode_images(batch))
        # Validate before mutating the index so it never holds unmapped items.
        if len(embeddings) != len(images):
            raise ValueError(
                f"encode_images returned {len(embeddings)} embeddings "
                f"for {len(images)} pages."
            )
        embeddings_ids = self.index.add_items(embeddings)

        page_ids_to_data = self._load_page_ids_to_data()
        try:
            for embedding_id, image, (path, page_number) in zip(
                embeddings_ids, images, pages
            ):
                # String keys, matching the lookups performed in __call__.
                page_ids_to_data[str(embedding_id)] = {
                    "path": path,
                    "image": image,
                    "page_number": page_number,
                }
            page_ids_to_data.commit()
        finally:
            page_ids_to_data.close()
        self.index.save(self.index_path)
        return self

    def __call__(
        self,
        queries: np.ndarray | torch.Tensor,
        k: int = 10,
    ) -> dict:
        """Query the index for the pages nearest to each query.

        Parameters
        ----------
        queries
            The queries; they are embedded with ``encode_queries`` before
            searching.
        k
            The number of nearest neighbors to return per query, clamped to the
            number of indexed pages.

        Returns
        -------
        A dict with ``"documents"`` (for each query, the stored page records
        with ``path``, ``image`` and ``page_number``) and ``"distances"``
        reshaped to ``(n_queries, -1, k)``.

        Raises
        ------
        ValueError
            If the index is empty.
        """
        # NOTE(review): ``queries`` is annotated as an array/tensor but is
        # passed straight to encode_queries; confirm the accepted input type.
        page_ids_to_data = self._load_page_ids_to_data()
        try:
            n_pages = len(page_ids_to_data)
            # Check emptiness up front: for a batch of queries ``indices`` has
            # one row per query, so a post-query ``len(indices) == 0`` test
            # would not detect an empty index (and reshape with k == 0 fails).
            if n_pages == 0:
                raise ValueError("Index is empty, add documents before querying.")
            queries_embeddings = encode_queries(queries)
            k = min(k, n_pages)
            n_queries = len(queries_embeddings)
            indices, distances = self.index.query(
                queries_embeddings, k, query_ef=self.ef_search
            )
            documents = [
                [page_ids_to_data[str(indice)] for indice in query_indices]
                for query_indices in indices
            ]
        finally:
            page_ids_to_data.close()
        return {
            "documents": documents,
            "distances": distances.reshape(n_queries, -1, k),
        }