import faiss import faiss.contrib.torch_utils import time import logging import torch import numpy as np code_size = 64 class DatastoreBatch(): def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None: self.indices = [] self.batch_size = batch_size self.device = index_device if index_device is not None else torch.device('cuda' if gpu_index else 'cpu') for i in range(batch_size): self.indices.append(Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device)) def move_to_gpu(self): for i in range(self.batch_size): self.indices[i].move_to_gpu() def add_keys(self, keys, num_keys_to_add_at_a_time=100000): for i in range(self.batch_size): self.indices[i].add_keys(keys[i], num_keys_to_add_at_a_time) def train_index(self, keys): for index, example_keys in zip(self.indices, keys): index.train_index(example_keys) def search(self, queries, k): found_scores, found_values = [], [] for i in range(self.batch_size): scores, values = self.indices[i].search(queries[i], k) found_scores.append(scores) found_values.append(values) return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0) def search_and_reconstruct(self, queries, k): found_scores, found_values = [], [] found_vectors = [] for i in range(self.batch_size): scores, values, vectors = self.indices[i].search_and_reconstruct(queries[i], k) found_scores.append(scores) found_values.append(values) found_vectors.append(vectors) return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0), torch.stack(found_vectors, dim=0) class Datastore(): def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None: self.dimension = dim self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu') self.logger = logging.getLogger('index_building') self.logger.setLevel(20) self.use_flat_index = use_flat_index self.gpu_index = gpu_index # Initialize faiss index # TODO: is preprocessing efficient enough to spend time on? if not use_flat_index: self.index = faiss.IndexFlatIP(self.dimension) # inner product index because we use IP attention # need to wrap in index ID map to enable add_with_ids # self.index = faiss.IndexIDMap(self.index) self.index_size = 0 # if self.gpu_index: # self.move_to_gpu() def move_to_gpu(self): if self.use_flat_index: # self.keys = self.keys.to(self.device) return else: co = faiss.GpuClonerOptions() co.useFloat16 = True self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co) def train_index(self, keys): if self.use_flat_index: self.add_keys(keys=keys, index_is_trained=True) else: keys = keys.cpu().float() ncentroids = int(keys.shape[0] / 128) self.index = faiss.IndexIVFPQ(self.index, self.dimension, ncentroids, code_size, 8) self.index.nprobe = min(32, ncentroids) # if not self.gpu_index: # keys = keys.cpu() self.logger.info('Training index') start_time = time.time() self.index.train(keys) self.logger.info(f'Training took {time.time() - start_time} s') self.add_keys(keys=keys, index_is_trained=True) # self.keys = None if self.gpu_index: self.move_to_gpu() def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False): self.keys = keys if not self.use_flat_index and index_is_trained: start = 0 while start < keys.shape[0]: end = min(len(keys), start + num_keys_to_add_at_a_time) to_add = keys[start:end] # if not self.gpu_index: # to_add = to_add.cpu() # self.index.add_with_ids(to_add, torch.arange(start+self.index_size, end+self.index_size)) self.index.add(to_add) self.index_size += end - start start += end if (start % 1000000) == 0: self.logger.info(f'Added {start} tokens so far') # else: # self.keys.append(keys) # self.logger.info(f'Adding total {start} keys') # self.logger.info(f'Adding took {time.time() - start_time} s') def search_and_reconstruct(self, queries, k): if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim self.logger.info("Searching for a single vector; unsqueezing") queries = queries.unsqueeze(0) # self.logger.info("Searching with reconstruct") assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors scores, values, vectors = self.index.index.search_and_reconstruct(queries.cpu().detach(), k) # self.logger.info("Searching done") return scores, values, vectors def search(self, queries, k): # model_device = queries.device # model_dtype = queries.dtype if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim self.logger.info("Searching for a single vector; unsqueezing") queries = queries.unsqueeze(0) assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors # if not self.gpu_index: # queries = queries.cpu() # else: # queries = queries.to(self.device) if self.use_flat_index: if self.gpu_index: scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index) else: scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT) scores = torch.from_numpy(scores).to(queries.dtype) values = torch.from_numpy(values) #.to(model_dtype) else: scores, values = self.index.search(queries.float(), k) # avoid returning -1 as a value # TODO: get a handle on the attention mask and mask the values that were -1 values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]), torch.zeros_like(values), values) # self.logger.info("Searching done") # return scores.to(model_dtype).to(model_device), values.to(model_device) return scores, values