File size: 2,043 Bytes
bbd199b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import faiss
import numpy as np


class FaissNeighbors:
    """Exact nearest-neighbour lookup over labelled vectors (flat L2 index).

    Call :meth:`fit` once with the reference vectors and their labels, then
    query with :meth:`get_distances_and_indices` or
    :meth:`get_nearest_labels`.
    """

    def __init__(self):
        # Both attributes are populated by fit().
        self.index = None
        self.y = None

    def fit(self, X, y):
        """Index the rows of ``X`` and remember the label of each row.

        Args:
            X: 2-D array with one vector per row; ``X.shape[1]`` is used as
                the index dimension.
            y: Labels aligned with the rows of ``X``. Stored as a NumPy array
                so that it can be indexed with the 2-D index arrays returned
                by a search (a plain list could not be).
        """
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(np.ascontiguousarray(X, dtype=np.float32))
        self.y = np.asarray(y)

    def _search(self, X, top_K):
        """Run the index search and return ``(distances, indices)``."""
        # A single 1-D vector is treated as a batch containing one query.
        queries = np.ascontiguousarray(np.atleast_2d(X), dtype=np.float32)
        # Never ask for more neighbours than there are indexed vectors: Faiss
        # pads missing results with index -1, which ``self.y[-1]`` would
        # silently turn into the *last* label.
        k = min(top_K, self.index.ntotal)
        if k <= 0:
            shape = (queries.shape[0], 0)
            return (
                np.empty(shape, dtype=np.float32),
                np.empty(shape, dtype=np.int64),
            )
        return self.index.search(queries, k=k)

    def get_distances_and_indices(self, X, top_K=1000):
        """Return the nearest neighbours of every query in ``X``.

        Args:
            X: Query vector(s); a 2-D array with one query per row, or a
                single 1-D vector.
            top_K: Maximum number of neighbours per query. Capped at the
                number of indexed vectors.

        Returns:
            Tuple ``(distances, indices, labels)``; each array has one row
            per query and ``min(top_K, n_indexed)`` columns. ``distances``
            are the values reported by the L2 index and ``labels`` are the
            entries of ``y`` at ``indices``.
        """
        distances, indices = self._search(X, top_K)
        # Indexing with an integer array already yields a fresh array.
        return distances, indices, self.y[indices]

    def get_nearest_labels(self, X, top_K=1000):
        """Return only the labels of the nearest neighbours of ``X``.

        Takes the same arguments as :meth:`get_distances_and_indices` and
        returns the third element of its result.
        """
        _, indices = self._search(X, top_K)
        return self.y[indices]


class FaissCosineNeighbors:
    """Cosine-similarity neighbour lookup over labelled vectors.

    Vectors are L2-normalised before being added to an inner-product index,
    and queries are normalised the same way, so the scores returned by a
    search are cosine similarities (larger means more similar).
    """

    def __init__(self):
        # Both attributes are populated by fit().
        self.cindex = None
        self.y = None

    def fit(self, X, y):
        """Index the normalised rows of ``X`` and remember their labels.

        Args:
            X: 2-D array with one vector per row; ``X.shape[1]`` is used as
                the index dimension. The caller's array is not modified.
            y: Labels aligned with the rows of ``X``. Stored as a NumPy array
                so that it can be indexed with the 2-D index arrays returned
                by a search.
        """
        self.cindex = faiss.index_factory(
            X.shape[1], "Flat", faiss.METRIC_INNER_PRODUCT
        )
        # np.array copies by default, so the in-place normalisation below
        # never touches the caller's data.
        vectors = np.array(X, dtype=np.float32, order="C")
        faiss.normalize_L2(vectors)
        self.cindex.add(vectors)
        self.y = np.asarray(y)

    def _search(self, Q, topK):
        """Normalise the queries and return ``(similarities, indices)``."""
        # Convert to float32 *before* normalising, exactly as fit() does;
        # the copy also keeps the caller's array untouched. ``ndmin=2`` lets
        # a single 1-D vector act as a batch of one query.
        queries = np.array(Q, dtype=np.float32, order="C", ndmin=2)
        faiss.normalize_L2(queries)
        # Never ask for more neighbours than there are indexed vectors: Faiss
        # pads missing results with index -1, which ``self.y[-1]`` would
        # silently turn into the *last* label.
        k = min(topK, self.cindex.ntotal)
        if k <= 0:
            shape = (queries.shape[0], 0)
            return (
                np.empty(shape, dtype=np.float32),
                np.empty(shape, dtype=np.int64),
            )
        return self.cindex.search(queries, k=k)

    def get_distances_and_indices(self, Q, topK=1000):
        """Return the most similar indexed vectors for every query in ``Q``.

        Args:
            Q: Query vector(s); a 2-D array with one query per row, or a
                single 1-D vector.
            topK: Maximum number of neighbours per query. Capped at the
                number of indexed vectors.

        Returns:
            Tuple ``(similarities, indices, labels)``; each array has one row
            per query and ``min(topK, n_indexed)`` columns. ``labels`` are
            the entries of ``y`` at ``indices``.
        """
        similarities, indices = self._search(Q, topK)
        # Indexing with an integer array already yields a fresh array.
        return similarities, indices, self.y[indices]

    def get_nearest_labels(self, Q, topK=1000):
        """Return only the labels of the most similar indexed vectors.

        Takes the same arguments as :meth:`get_distances_and_indices` and
        returns the third element of its result.
        """
        _, indices = self._search(Q, topK)
        return self.y[indices]


class SearchableTrainingSet:
    """Training embeddings and labels bundled with a cosine-similarity searcher."""

    def __init__(self, embeddings, labels):
        """Store the training data; the index is not built until build_index().

        Args:
            embeddings: Training vectors, passed to the searcher's ``fit``.
            labels: Labels aligned with ``embeddings``.
        """
        self.simsearcher = FaissCosineNeighbors()
        self.X_train = embeddings
        self.y_train = labels

    def build_index(self):
        """Fit the similarity searcher on the stored embeddings and labels."""
        self.simsearcher.fit(self.X_train, self.y_train)

    def search(self, query, k=20, n_candidates=100):
        """Return the nearest training points for ``query``.

        Args:
            query: Query embedding(s), forwarded unchanged to the searcher.
            k: Number of neighbours the caller needs. The search never
                requests fewer than this.
            n_candidates: Size of the candidate pool to retrieve. Defaults to
                100, the value that used to be hard-coded, so calls with
                ``k <= 100`` behave exactly as before.

        Returns:
            Whatever the searcher's ``get_distances_and_indices`` returns for
            ``max(k, n_candidates)`` neighbours.
        """
        # NOTE: no per-label voting or trimming down to ``k`` happens here;
        # the raw candidate pool is returned to the caller.
        return self.simsearcher.get_distances_and_indices(
            Q=query, topK=max(k, n_candidates)
        )