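"""FAISS-backed datastores for inner-product nearest-neighbor search over key vectors.

`Datastore` wraps a single index, either exact flat inner-product search or a trained
IVFPQ index for approximate search; `DatastoreBatch` keeps one `Datastore` per example
in a batch and fans batched add/train/search calls out to them.
"""
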
import logging
import time

import faiss
import faiss.contrib.torch_utils  # registers torch tensor support for faiss indexes
import torch

# number of PQ sub-quantizers; with 8 bits each, this is also the bytes per compressed vector
code_size = 64

class DatastoreBatch():
    def __init__(self, dim, batch_size, flat_index=False, gpu_index=False, verbose=False, index_device=None) -> None:
        self.indices = []
        self.batch_size = batch_size
        self.device = index_device if index_device is not None else torch.device('cuda' if gpu_index else 'cpu')
        for i in range(batch_size):
            self.indices.append(Datastore(dim, use_flat_index=flat_index, gpu_index=gpu_index, verbose=verbose, device=self.device))
    
    def move_to_gpu(self):
        for i in range(self.batch_size):
            self.indices[i].move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=100000):
        for i in range(self.batch_size):
            self.indices[i].add_keys(keys[i], num_keys_to_add_at_a_time)
        
    def train_index(self, keys):
        for index, example_keys in zip(self.indices, keys):
            index.train_index(example_keys)
    
    def search(self, queries, k):
        found_scores, found_values = [], []
        for i in range(self.batch_size):
            scores, values = self.indices[i].search(queries[i], k)
            found_scores.append(scores)
            found_values.append(values)
        return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0)

    def search_and_reconstruct(self, queries, k):
        found_scores, found_values = [], []
        found_vectors = []
        for i in range(self.batch_size):
            scores, values, vectors = self.indices[i].search_and_reconstruct(queries[i], k)
            found_scores.append(scores)
            found_values.append(values)
            found_vectors.append(vectors)     
        return torch.stack(found_scores, dim=0), torch.stack(found_values, dim=0), torch.stack(found_vectors, dim=0)

class Datastore():
    def __init__(self, dim, use_flat_index=False, gpu_index=False, verbose=False, device=None) -> None:
        self.dimension = dim
        self.device = device if device is not None else torch.device('cuda' if gpu_index else 'cpu')
        self.logger = logging.getLogger('index_building')
        self.logger.setLevel(logging.INFO)
        self.use_flat_index = use_flat_index
        self.gpu_index = gpu_index

        # Initialize faiss index
        # TODO: is preprocessing efficient enough to spend time on?
        if not use_flat_index:
            self.index = faiss.IndexFlatIP(self.dimension) # inner product index because we use IP attention
        
        # need to wrap in index ID map to enable add_with_ids 
        # self.index = faiss.IndexIDMap(self.index) 

        self.index_size = 0
        # if self.gpu_index:
        #     self.move_to_gpu()
        
    def move_to_gpu(self):
        if self.use_flat_index:
            # the flat path keeps the raw keys and searches them at query time, so there is no index to move
            return
        else:
            co = faiss.GpuClonerOptions()
            co.useFloat16 = True
            self.index = faiss.index_cpu_to_gpu(faiss.StandardGpuResources(), self.device.index, self.index, co)
    
    def train_index(self, keys):
        if self.use_flat_index:
            self.add_keys(keys=keys, index_is_trained=True)
        else:
            keys = keys.cpu().float()
            # aim for roughly 128 keys per centroid; guard against zero centroids for tiny key sets
            ncentroids = max(1, keys.shape[0] // 128)
            # the flat IP index from __init__ becomes the coarse quantizer of the IVFPQ index
            self.index = faiss.IndexIVFPQ(self.index, self.dimension,
                ncentroids, code_size, 8)
            self.index.nprobe = min(32, ncentroids)
            # if not self.gpu_index:
            #     keys = keys.cpu()

            self.logger.info('Training index')
            start_time = time.time()
            self.index.train(keys)
            self.logger.info(f'Training took {time.time() - start_time} s')
            self.add_keys(keys=keys, index_is_trained=True)
            # self.keys = None
            if self.gpu_index:
                self.move_to_gpu()

    def add_keys(self, keys, num_keys_to_add_at_a_time=1000000, index_is_trained=False):
        self.keys = keys  # kept for the flat search path and for bounding the ids returned by search
        if not self.use_flat_index and index_is_trained:
            start = 0
            while start < keys.shape[0]:
                end = min(len(keys), start + num_keys_to_add_at_a_time)
                to_add = keys[start:end]
                # if not self.gpu_index:
                #     to_add = to_add.cpu()
                # self.index.add_with_ids(to_add, torch.arange(start+self.index_size, end+self.index_size))
                self.index.add(to_add)
                self.index_size += end - start
                start = end
                if (start % 1000000) == 0:
                    self.logger.info(f'Added {start} tokens so far')
        # else:
        #     self.keys.append(keys)

        # self.logger.info(f'Adding total {start} keys')
        # self.logger.info(f'Adding took {time.time() - start_time} s')

    def search_and_reconstruct(self, queries, k):
        if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        # self.logger.info("Searching with reconstruct")
        assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors
        # the index is not wrapped in an IndexIDMap here, so call search_and_reconstruct on it directly
        scores, values, vectors = self.index.search_and_reconstruct(queries.cpu().detach(), k)
        # self.logger.info("Searching done")
        return scores, values, vectors
    
    def search(self, queries, k):
        # model_device = queries.device
        # model_dtype = queries.dtype
        if len(queries.shape) == 1: # searching for only 1 vector, add one extra dim
            self.logger.info("Searching for a single vector; unsqueezing")
            queries = queries.unsqueeze(0)
        assert queries.shape[-1] == self.dimension # query vectors are same shape as "key" vectors
        # if not self.gpu_index:
        #     queries = queries.cpu()
        # else:
        #     queries = queries.to(self.device)
        if self.use_flat_index:
            if self.gpu_index:
                scores, values = faiss.knn_gpu(faiss.StandardGpuResources(), queries, self.keys, k, 
                    metric=faiss.METRIC_INNER_PRODUCT, device=self.device.index)
            else:
                scores, values = faiss.knn(queries, self.keys, k, metric=faiss.METRIC_INNER_PRODUCT)
                # faiss.knn returns numpy arrays; convert them back to torch tensors
                scores = torch.from_numpy(scores).to(queries.dtype)
                values = torch.from_numpy(values)
        else:
            scores, values = self.index.search(queries.float(), k)
        
        # faiss returns -1 for ids when fewer than k neighbors are found; clamp invalid ids to 0
        # TODO: get a handle on the attention mask and mask the values that were -1
        values = torch.where(torch.logical_or(values < 0, values >= self.keys.shape[0]), torch.zeros_like(values), values)
        # self.logger.info("Searching done")
        # return scores.to(model_dtype).to(model_device), values.to(model_device)
        return scores, values
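
# Minimal usage sketch (an illustrative addition, not part of the original module):
# build small random datastores on CPU and query them. All sizes below are
# assumptions, chosen so that IVFPQ training has enough points.
if __name__ == '__main__':
    torch.manual_seed(0)
    dim, num_keys, k = 64, 10000, 4
    keys = torch.randn(num_keys, dim)
    queries = torch.randn(2, dim)

    # exact inner-product search over the raw keys
    flat_store = Datastore(dim, use_flat_index=True)
    flat_store.train_index(keys)  # for a flat index this just stores the keys
    scores, values = flat_store.search(queries, k)
    print('flat :', scores.shape, values.shape)  # (2, k) each

    # approximate search with a trained IVFPQ index
    ivf_store = Datastore(dim)
    ivf_store.train_index(keys)  # trains the IVFPQ index, then adds the keys
    scores, values = ivf_store.search(queries, k)
    print('ivfpq:', scores.shape, values.shape)

    # batched variant: one independent store per example in the batch
    batch_store = DatastoreBatch(dim, batch_size=2, flat_index=True)
    batch_store.train_index(torch.randn(2, num_keys, dim))
    scores, values = batch_store.search(torch.randn(2, 3, dim), k)
    print('batch:', scores.shape, values.shape)  # (2, 3, k)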