import argparse
import pickle

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm


def concat(x):
    """Concatenate a list of per-batch arrays along the batch axis."""
    return np.concatenate(x, axis=0)


def to_np(x):
    """Detach a tensor and move it to host memory as a NumPy array."""
    return x.detach().cpu().numpy()


class Wrapper(torch.nn.Module):
    """Wraps a torchvision ResNet and exposes its global-average-pooled
    features via a forward hook on the `avgpool` layer."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.avgpool_output = None

        def fw_hook(module, input, output):
            # avgpool emits (N, 2048, 1, 1); flatten to (N, 2048). Using
            # flatten(1) instead of squeeze() preserves the batch dimension
            # even when the final batch contains a single sample.
            self.avgpool_output = torch.flatten(output, 1)

        self.model.avgpool.register_forward_hook(fw_hook)

    def forward(self, input):
        _ = self.model(input)
        return self.avgpool_output

    def __repr__(self):
        return "Wrapper"
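
# Note: the hook-based Wrapper above is one way to grab the pooled features.
# A common alternative (a sketch only, not used in this script) is to truncate
# the network just before its classifier head, relying on the standard
# torchvision ResNet layout where `fc` is the last child module:
#
#     backbone = torch.nn.Sequential(*list(model.children())[:-1])
#     features = backbone(batch).flatten(1)  # (N, 2048)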


def run(training_set_path):
    # Standard ImageNet preprocessing: resize, center-crop, and normalize
    # with the ImageNet channel statistics.
    dataset_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    training_imagefolder = ImageFolder(
        root=training_set_path, transform=dataset_transform
    )
    # shuffle=False keeps row i of the saved arrays aligned with sample i of
    # ImageFolder's sorted file order.
    train_loader = torch.utils.data.DataLoader(
        training_imagefolder,
        batch_size=512,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    print(f"# of Training folder samples: {len(training_imagefolder)}")
    ########################################################################################################################
    # `pretrained=True` is deprecated since torchvision 0.13; on newer
    # versions use `weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1`.
    model = torchvision.models.resnet50(pretrained=True)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    myw = Wrapper(model).to(device)

    training_embeddings = []
    training_labels = []

    with torch.no_grad():
        for data, target in tqdm(train_loader):
            embeddings = to_np(myw(data.to(device)))
            labels = to_np(target)

            training_embeddings.append(embeddings)
            training_labels.append(labels)

    training_embeddings_concatted = concat(training_embeddings)
    training_labels_concatted = concat(training_labels)
    
    print(f"Concatenated embeddings shape: {training_embeddings_concatted.shape}")
    return training_embeddings_concatted, training_labels_concatted


def main():
    parser = argparse.ArgumentParser(description="Saving Embeddings")
    parser.add_argument("--train", help="Path to the dataset", type=str, required=True)
    args = parser.parse_args()

    embeddings, labels = run(args.train)

    # Save embeddings and labels to disk
    with open("embeddings.pickle", "wb") as f:
        pickle.dump(embeddings, f)

    with open("labels.pickle", "wb") as f:
        pickle.dump(labels, f)


if __name__ == "__main__":
    main()
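
# Usage (script name and dataset path are hypothetical):
#
#     python save_embeddings.py --train /path/to/train_folder
#
# The resulting pickles can then be consumed downstream, e.g. for a
# cosine-similarity nearest-neighbor lookup. A minimal sketch, assuming the
# files written above and an arbitrary query row `q`:
#
#     import pickle
#     import numpy as np
#
#     with open("embeddings.pickle", "rb") as f:
#         emb = pickle.load(f)            # (N, 2048) float32
#     with open("labels.pickle", "rb") as f:
#         lab = pickle.load(f)            # (N,)
#
#     q = emb[0]                          # hypothetical query embedding
#     sims = emb @ q / (np.linalg.norm(emb, axis=1) * np.linalg.norm(q))
#     print(lab[np.argsort(-sims)[:5]])   # labels of the 5 nearest neighbors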