|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torchvision.transforms as transforms |
|
|
|
import timm |
|
|
|
from PIL import Image |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Test images to score; paths are relative to the working directory.
IMG_FILE_LIST = [
    './testcases/14.jpg',
    './testcases/15.jpg',
    './testcases/16.jpg',
    './testcases/17.jpg',
    './testcases/18.jpg',
    './testcases/19.jpg'
]

# Multiplier applied to the tanh-bounded regression heads in Scorer.forward
# (1 leaves the heads in the native [-1, 1] range).
TANH_SCALE = 1
|
|
|
|
|
class Scorer(nn.Module):
    """Multi-head image scorer built on a timm feature-extraction backbone.

    The backbone (created with ``features_only=True``) yields one feature map
    per stage; each map is global-average-pooled, layer-normalized, and the
    results are concatenated before being fed to small MLP heads that produce
    four scalar predictions per image.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: whether timm loads pretrained backbone weights.
        features_only: must stay True — ``forward`` expects a list of
            per-stage feature maps, not a classification output.
        embedding_dim: accepted for interface compatibility; currently unused.
    """

    def __init__(
        self,
        model_name,
        pretrained=False,
        features_only=True,
        embedding_dim=128
    ):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained, features_only=features_only)
        # Per-stage channel counts; hard-coded to match a ConvNeXt-base-like
        # backbone (assumes 4 stages with these widths — TODO confirm for
        # other model_name values).
        stage_dims = (128, 256, 512, 1024)
        pooled_dim = sum(stage_dims)
        # One LayerNorm per backbone stage, sized to that stage's channels.
        self.layer_norms = nn.ModuleList([nn.LayerNorm(d) for d in stage_dims])
        # Shared trunk over the concatenated pooled features.
        self.mlp = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim),
            nn.BatchNorm1d(pooled_dim),
            nn.GELU(),
        )

        # Head 1: three tanh-bounded regression outputs
        # (liking, collection, relative popularity).
        self.mlp_1 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.BatchNorm1d(pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 3),
            nn.Tanh()
        )
        # Head 2: single logit for the "AI-generated" probability.
        self.mlp_2 = nn.Sequential(
            nn.Linear(pooled_dim, pooled_dim // 4),
            nn.GELU(),
            nn.Linear(pooled_dim // 4, 1),
        )

    def forward(self, x, upload_date=None, freeze_backbone=False):
        """Score a batch of images.

        Args:
            x: image batch tensor accepted by the backbone
               (presumably (B, 3, H, W) — TODO confirm).
            upload_date: accepted for interface compatibility; unused.
            freeze_backbone: when True, run the backbone under
                ``torch.no_grad()`` so no gradients flow into it.

        Returns:
            Tuple of four (B,)-shaped tensors:
            (liking, collection, sigmoid(ai_logit), relative_popularity).
        """
        if freeze_backbone:
            with torch.no_grad():
                out_features = self.model(x)
        else:
            out_features = self.model(x)

        # Global-average-pool each stage's (B, C, H, W) map down to (B, C).
        pooled_features = [F.adaptive_avg_pool2d(f, 1).squeeze(-1).squeeze(-1) for f in out_features]
        # Normalize each stage independently before concatenation.
        pooled_features = [ln(p) for ln, p in zip(self.layer_norms, pooled_features)]

        out = torch.cat(pooled_features, dim=-1)

        out = self.mlp(out)
        rl_out = self.mlp_1(out) * TANH_SCALE
        ai_out = self.mlp_2(out).squeeze(-1)
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return rl_out[:, 0], rl_out[:, 1], torch.sigmoid(ai_out), rl_out[:, 2]
|
|
|
|
|
# timm backbone identifier (ConvNeXt V2 base, FCMAE pretraining).
BACKBONE = 'convnextv2_base.fcmae'
# Square input resolution fed to the network.
RESOLUTION = 640
# When True, overlay input gradients on the displayed images.
SHOW_GRAD = False
# Multiplier for the gradient overlay intensity.
GRAD_SCALE = 50

# Objective flags: which prediction heads contribute to the gradient
# objective (scores to maximize are subtracted, scores to minimize added).
MORE_LIKE = False
MORE_COLLECTION = False
LESS_AI = False
MORE_RELATIVE_POP = True

# Path to the trained Scorer state dict.
WEIGHT_PATH = './scorer.pt'

# NOTE(review): "DECIVE" looks like a typo for "DEVICE"; it is referenced in
# main(), so renaming must be coordinated with that function — confirm first.
DECIVE = 'cuda'
|
|
|
|
|
def main():
    """Score each test image with the trained Scorer and plot the results.

    Loads weights from WEIGHT_PATH, runs every image in IMG_FILE_LIST through
    the model, and shows one subplot per image titled with the four predicted
    scores. When SHOW_GRAD is True, the input gradient of the configured
    objective is overlaid on each image.
    """
    model = Scorer(BACKBONE)
    transform = transforms.Compose([
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])
    # map_location keeps loading working even when the checkpoint was saved
    # on a different device than the one we run on.
    model.load_state_dict(torch.load(WEIGHT_PATH, map_location=DECIVE))
    model.eval()
    model.to(DECIVE)

    fig = plt.figure(figsize=(20, 20))
    for i, img_file in enumerate(IMG_FILE_LIST):
        img = Image.open(img_file, 'r').convert('RGB')
        transformed_img = transform(img).unsqueeze(0).to(DECIVE)
        # Track gradients w.r.t. the input so SHOW_GRAD can visualize them.
        transformed_img.requires_grad = True
        liking_pred, collection_pred, ai_pred, relative_pop = model(transformed_img, torch.tensor([1]), False)
        ax = fig.add_subplot(1, len(IMG_FILE_LIST), i + 1)

        # Build the differentiable objective: subtract scores we want to
        # maximize, add scores we want to minimize.
        backwardee = 0
        if MORE_LIKE:
            backwardee -= liking_pred
        if MORE_COLLECTION:
            backwardee -= collection_pred
        if LESS_AI:
            backwardee += ai_pred
        if MORE_RELATIVE_POP:
            backwardee -= relative_pop
        # Guard on is_tensor: if every objective flag is False, backwardee is
        # still the int 0 and .backward() would raise AttributeError.
        if SHOW_GRAD and torch.is_tensor(backwardee):
            model.zero_grad()

            backwardee.backward()

            gradients = transformed_img.grad.squeeze(0).detach()

            # Resize the gradient map back to the original image resolution.
            gradients = transforms.Resize((img.height, img.width))(gradients)

            # Overlay the (scaled) gradient on the original image.
            img = transforms.ToTensor()(img)
            img = img + gradients.cpu() * GRAD_SCALE
            img = transforms.ToPILImage()(img.cpu())
        ax.imshow(img)
        ax.set_title(
            f'Liking: {liking_pred.item():.3f}\nCollection: {collection_pred.item():.3f}\nAI: {ai_pred.item() * 100:.3f}%\nPopularity: {relative_pop.item():.3f}')
    plt.show()
|
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()