Spaces:
Runtime error
Runtime error
File size: 7,052 Bytes
abca9bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import faiss
import faiss.contrib.torch_utils
import time
import logging
import torch
import numpy as np
code_size = 64
class DatastoreBatch():
    """A per-example collection of Datastore indices for a batch.

    Holds one ``Datastore`` per batch element and fans each operation
    (training, adding keys, searching) out to the individual indices,
    re-stacking results along a leading batch dimension.
    """

    def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None:
        self.batch_size = batch_size
        # Fall back to cuda/cpu based on gpu_index when no explicit device is given.
        self.device = index_device if index_device is not None else torch.device('cuda' if gpu_index else 'cpu')
        self.indices = [
            Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device)
            for _ in range(batch_size)
        ]

    def move_to_gpu(self):
        # Migrate every per-example index onto the GPU.
        for index in self.indices:
            index.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=100000):
        # keys is batched: keys[i] belongs to batch element i.
        for i, index in enumerate(self.indices):
            index.add_keys(keys[i], num_keys_to_add_at_a_time)

    def train_index(self, keys):
        for index, example_keys in zip(self.indices, keys):
            index.train_index(example_keys)

    def search(self, queries, k):
        """Search each example's index with its own queries; return stacked (scores, values)."""
        per_example = [self.indices[i].search(queries[i], k) for i in range(self.batch_size)]
        scores, values = zip(*per_example)
        return torch.stack(scores, dim=0), torch.stack(values, dim=0)

    def search_and_reconstruct(self, queries, k):
        """Like search, but also returns the reconstructed key vectors, stacked per batch."""
        per_example = [self.indices[i].search_and_reconstruct(queries[i], k) for i in range(self.batch_size)]
        scores, values, vectors = zip(*per_example)
        return torch.stack(scores, dim=0), torch.stack(values, dim=0), torch.stack(vectors, dim=0)
class Datastore():
    """A single faiss index over key vectors, searched by inner product.

    Two modes:
      * flat (``use_flat_index=True``): keys are kept as a tensor and
        searched exactly via ``faiss.knn`` / ``faiss.knn_gpu``;
      * quantized (default): an ``IndexIVFPQ`` built on an ``IndexFlatIP``
        coarse quantizer, trained on the keys before adding them.
    """

    def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None:
        self.dimension = dim
        self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu')
        self.logger = logging.getLogger('index_building')
        self.logger.setLevel(logging.INFO)  # 20 == logging.INFO
        self.use_flat_index = use_flat_index
        self.gpu_index = gpu_index
        # Initialize faiss index.
        # TODO: is preprocessing efficient enough to spend time on?
        if not use_flat_index:
            # Inner-product quantizer, because attention scores are dot products.
            self.index = faiss.IndexFlatIP(self.dimension)
            self.index_size = 0

    def move_to_gpu(self):
        """Move the quantized index to the GPU; flat mode needs no migration."""
        if self.use_flat_index:
            # Flat search runs faiss.knn_gpu on the raw key tensor directly.
            return
        co = faiss.GpuClonerOptions()
        co.useFloat16 = True
        self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co)

    def train_index(self, keys):
        """Train the IVFPQ index on `keys` and add them (flat mode just stores them).

        Args:
            keys: (num_keys, dim) tensor of key vectors.
        """
        if self.use_flat_index:
            self.add_keys(keys=keys, index_is_trained=True)
            return
        keys = keys.cpu().float()
        # Rule of thumb: one coarse centroid per ~128 keys.
        # Guard with max(1, ...): the original int(n / 128) was 0 for fewer
        # than 128 keys, which faiss cannot train with.
        ncentroids = max(1, keys.shape[0] // 128)
        # NOTE(review): IndexIVFPQ defaults to the L2 metric even with an IP
        # quantizer; verify whether METRIC_INNER_PRODUCT was intended here.
        self.index = faiss.IndexIVFPQ(self.index, self.dimension,
                                      ncentroids, code_size, 8)
        self.index.nprobe = min(32, ncentroids)
        self.logger.info('Training index')
        start_time = time.time()
        self.index.train(keys)
        self.logger.info(f'Training took {time.time() - start_time} s')
        self.add_keys(keys=keys, index_is_trained=True)
        if self.gpu_index:
            self.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False):
        """Store `keys`; for a trained quantized index, also add them in chunks.

        Chunking bounds peak memory when the key set is large.
        """
        self.keys = keys  # kept for flat search and for masking out-of-range ids
        if not self.use_flat_index and index_is_trained:
            start = 0
            while start < keys.shape[0]:
                end = min(len(keys), start + num_keys_to_add_at_a_time)
                self.index.add(keys[start:end])
                self.index_size += end - start
                # BUG FIX: was `start += end`, which over-advanced the cursor
                # and silently skipped key ranges after the first two chunks.
                start = end
                if (start % 1000000) == 0:
                    self.logger.info(f'Added {start} tokens so far')

    def search_and_reconstruct(self, queries, k):
        """Return (scores, ids, vectors) for the top-k neighbors of `queries`."""
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors
        # BUG FIX: was `self.index.index.search_and_reconstruct(...)`, a
        # leftover from the commented-out IndexIDMap wrapper; the index is no
        # longer wrapped, so call the method on the index itself.
        scores, values, vectors = self.index.search_and_reconstruct(queries.cpu().detach(), k)
        return scores, values, vectors

    def search(self, queries, k):
        """Return (scores, ids) of the k nearest keys by inner product.

        Args:
            queries: (num_queries, dim) tensor, or a single (dim,) vector.
            k: number of neighbors to retrieve.
        """
        if len(queries.shape) == 1:  # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension  # query vectors are same shape as "key" vectors
        if self.use_flat_index:
            if self.gpu_index:
                scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k,
                                               metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index)
            else:
                scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT)
                # faiss.knn returns numpy arrays; convert back to torch.
                scores = torch.from_numpy(scores).to(queries.dtype)
                values = torch.from_numpy(values)
        else:
            scores, values = self.index.search(queries.float(), k)
            # faiss returns -1 ids when fewer than k neighbors exist; clamp to 0
            # so downstream indexing does not crash.
            # TODO: get a handle on the attention mask and mask the values that were -1
            values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]),
                                 torch.zeros_like(values), values)
        return scores, values
|