import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

import timm

from PIL import Image

import matplotlib.pyplot as plt

import os

# Thanks to ( ), a proxy can be essential :)
# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['HTTP_PROXY'] = 'http://127.0.0.1:10809'
# os.environ['ALL_PROXY'] = 'socks5://127.0.0.1:10808'

IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg'
]

TANH_SCALE = 1


class Scorer(nn.Module):
    def __init__(
            self,
            model_name,
            pretrained=False,
            features_only=True,
            embedding_dim=128
    ):
        super(Scorer, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        pooled_dim = 128 + 256 + 512 + 1024  # channel widths of the four backbone feature stages
        self.layer_norms = nn.ModuleList([
            nn.LayerNorm(128),
            nn.LayerNorm(256),
            nn.LayerNorm(512),
            nn.LayerNorm(1024)
        ])
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )
        # Perhaps a BYOL-style "accidental" BatchNorm could help?
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    def forward(self, x, upload_date=None, freeze_backbone=False):
        if freeze_backbone:
            with torch.no_grad():
                out_features = self.model(x)
        else:
            out_features = self.model(x)
        # out_features: list of four stage outputs, e.g. for a single image:
        #   torch.Size([1, 128, H1, W1])
        #   torch.Size([1, 256, H2, W2])
        #   torch.Size([1, 512, H3, W3])
        #   torch.Size([1, 1024, H4, W4])
        # Global-average-pool each stage's feature map over its spatial dimensions
        pooled_features = [F.adaptive_avg_pool2d(feat, 1).squeeze(-1).squeeze(-1) for feat in out_features]
        # LayerNorm each pooled stage vector
        pooled_features = [self.layer_norms[i](feat) for i, feat in enumerate(pooled_features)]
        # Embed the upload date
        # date_embedding_features = self.embedding(upload_date)
        # Concatenate the pooled features
        out = torch.cat(pooled_features, dim=-1)
        # Concatenate the date embedding features
        # out = torch.cat([out, date_embedding_features], dim=-1)
        out = self.mlp(out)
        rl_out = self.mlp_1(out) * TANH_SCALE
        ai_out = self.mlp_2(out).squeeze(-1)
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_out), rl_out[:, 2]


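# A minimal usage sketch, kept out of the script's normal flow. It assumes the
# convnextv2_base backbone, whose four feature stages emit 128/256/512/1024
# channels (the widths hard-coded in pooled_dim above); other backbones would
# need those widths adjusted.
def _scorer_shape_check():
    scorer = Scorer('convnextv2_base.fcmae', pretrained=False)
    scorer.eval()  # BatchNorm1d layers need eval mode (or a batch size > 1)
    dummy = torch.randn(2, 3, 640, 640)  # two fake 640x640 RGB images
    with torch.no_grad():
        like, collect, ai_prob, rel_pop = scorer(dummy)
    # Each head yields one value per image; ai_prob is already sigmoid-squashed.
    print(like.shape, collect.shape, ai_prob.shape, rel_pop.shape)  # 4x torch.Size([2])
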
BACKBONE = 'convnextv2_base.fcmae'
RESOLUTION = 640
SHOW_GRAD = False
GRAD_SCALE = 50

MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True

WEIGHT_PATH = './scorer.pt'

DEVICE = 'cuda'


def main():
    model = Scorer(BACKBONE)
    transform = transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    model.load_state_dict(torch.load(WEIGHT_PATH, map_location=DEVICE))
    model.eval()
    model.to(DEVICE)

    # Show all the images in pyplot horizontally, and mark the predicted values under each image
    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        img = Image.open(img_file, 'r').convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DEVICE)
        transformed_img.requires_grad = True
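        # Note: the upload_date argument is currently unused in Scorer.forward
        # (the date embedding is commented out), so the placeholder tensor below has no effect.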
        liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)

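        # Build the scalar objective to differentiate when SHOW_GRAD is True;
        # each enabled flag adds the corresponding head (with a sign) to it.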
        backwardee = 0
        if MORE_LIKE:
            backwardee -= liking_pred
        if MORE_COLLECTION:
            backwardee -= collection_pred
        if LESS_AI:
            backwardee += ai_pred
        if MORE_RELATIVE_POP:
            backwardee -= relative_pop
        if SHOW_GRAD:
            model.zero_grad()
            # Figure out which part of the image is the most important to popularity
            backwardee.backward()
            # Get the gradients of the image, and normalize them
            gradients = transformed_img.grad
            # squeeze the batch dimension
            gradients = gradients.squeeze(0).detach()
            # resize the gradients to the same size as the image
            gradients = transforms.Resize((img.height, img.width))(gradients)
            # add the gradients to the image
            img = transforms.ToTensor()(img)
            img = img + gradients.cpu() * GRAD_SCALE
            img = transforms.ToPILImage()(img.cpu())
        ax.imshow(img)
        del img
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\nCollection: {collection_pred.item():.3f}\nAI: {ai_pred.item() * 100:.3f}%\nPopularity: {relative_pop.item():.3f}')
    plt.show()


if __name__ == '__main__':
    main()