# python3.7
"""Unit test for loading pre-trained models.
Basically, this file tests whether the perceptual model (VGG16) and the
inception model (InceptionV3), which are commonly used for loss computation and
evaluation, have the expected behavior after loading pre-trained weights. In
particular, we compare with the models from repo
https://github.com/NVlabs/stylegan2-ada-pytorch
"""
import torch
from models import build_model
from utils.misc import download_url

__all__ = ['test_model']

_BATCH_SIZE = 4

# pylint: disable=line-too-long
_PERCEPTUAL_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
_INCEPTION_URL = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt'
# pylint: enable=line-too-long
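
# Both checkpoints are TorchScript modules released with the
# stylegan2-ada-pytorch repo; they are deserialized below with
# `torch.jit.load()`.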


def test_model():
    """Collects all model tests."""
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    print('========== Start Model Test ==========')
    test_perceptual()
    test_inception()
    print('========== Finish Model Test ==========')


def test_perceptual():
    """Test the perceptual model."""
    print('===== Testing Perceptual Model =====')
    print('Build test model.')
    model = build_model('PerceptualModel',
                        use_torchvision=False,
                        no_top=False,
                        enable_lpips=True)
    print('Build reference model.')
    ref_model_path, _ = download_url(_PERCEPTUAL_URL)
    with open(ref_model_path, 'rb') as f:
        ref_model = torch.jit.load(f).eval().cuda()
    print('Test performance:')
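    # 224 is VGG16's native input resolution; the remaining sizes check the
    # behavior at non-native resolutions.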
    for size in [224, 128, 256, 512, 1024]:
        raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
        raw_img_comp = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
        # The test model requires input images to have range [-1, 1].
        img = raw_img.to(torch.float32).cuda() / 127.5 - 1
        img_comp = raw_img_comp.to(torch.float32).cuda() / 127.5 - 1
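        # This maps 0 -> -1.0 and 255 -> 1.0 (e.g., 255 / 127.5 - 1 = 1.0).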
        feat = model(img, resize_input=True, return_tensor='feature')
        pred = model(img, resize_input=True, return_tensor='prediction')
        lpips = model(img, img_comp, resize_input=False, return_tensor='lpips')
        assert feat.shape == (_BATCH_SIZE, 4096)
        assert pred.shape == (_BATCH_SIZE, 1000)
        assert lpips.shape == (_BATCH_SIZE,)
        # The reference model requires input images to have range [0, 255].
        img = raw_img.to(torch.float32).cuda()
        img_comp = raw_img_comp.to(torch.float32).cuda()
        ref_feat = ref_model(img, resize_images=True, return_features=True)
        ref_pred = ref_model(img, resize_images=True, return_features=False)
        temp = ref_model(torch.cat([img, img_comp], dim=0),
                         resize_images=False, return_lpips=True).chunk(2)
        ref_lpips = (temp[0] - temp[1]).square().sum(dim=1, keepdim=False)
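        # NOTE: With `return_lpips=True`, the reference model returns LPIPS
        # feature embeddings instead of distances. The concatenated batch is
        # split back into the two image sets with `chunk(2)`, and the LPIPS
        # distance is the squared embedding difference summed over channels.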
        assert ref_feat.shape == (_BATCH_SIZE, 4096)
        assert ref_pred.shape == (_BATCH_SIZE, 1000)
        assert ref_lpips.shape == (_BATCH_SIZE,)
        print(f' Size {size}x{size}, feature (with resize):\n '
              f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
              f'max: {(feat - ref_feat).abs().max().item():.3e}, '
              f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
              f'ref_max: {ref_feat.abs().max().item():.3e}.')
        print(f' Size {size}x{size}, prediction (with resize):\n '
              f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
              f'max: {(pred - ref_pred).abs().max().item():.3e}, '
              f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
              f'ref_max: {ref_pred.abs().max().item():.3e}.')
        print(f' Size {size}x{size}, LPIPS (without resize):\n '
              f'mean: {(lpips - ref_lpips).abs().mean().item():.3e}, '
              f'max: {(lpips - ref_lpips).abs().max().item():.3e}, '
              f'ref_mean: {ref_lpips.abs().mean().item():.3e}, '
              f'ref_max: {ref_lpips.abs().max().item():.3e}.')


def test_inception():
    """Test the inception model."""
    print('===== Testing Inception Model =====')
    print('Build test model.')
    model = build_model('InceptionModel', align_tf=True)
    print('Build reference model.')
    ref_model_path, _ = download_url(_INCEPTION_URL)
    with open(ref_model_path, 'rb') as f:
        ref_model = torch.jit.load(f).eval().cuda()
    print('Test performance:')
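    # 299 is InceptionV3's native input resolution; the remaining sizes check
    # the behavior at non-native resolutions.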
    for size in [299, 128, 256, 512, 1024]:
        raw_img = torch.randint(0, 256, size=(_BATCH_SIZE, 3, size, size))
        # The test model requires input images to have range [-1, 1].
        img = raw_img.to(torch.float32).cuda() / 127.5 - 1
        feat = model(img)
        pred = model(img, output_predictions=True)
        pred_nb = model(img, output_predictions=True, remove_logits_bias=True)
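        # NOTE: The TF-ported Inception graph outputs 1008 logits rather than
        # 1000; the extra entries are unused classes inherited from the
        # original TensorFlow model.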
        assert feat.shape == (_BATCH_SIZE, 2048)
        assert pred.shape == (_BATCH_SIZE, 1008)
        assert pred_nb.shape == (_BATCH_SIZE, 1008)
        # The reference model requires input images to have range [0, 255].
        img = raw_img.to(torch.float32).cuda()
        ref_feat = ref_model(img, return_features=True)
        ref_pred = ref_model(img)
        ref_pred_nb = ref_model(img, no_output_bias=True)
        assert ref_feat.shape == (_BATCH_SIZE, 2048)
        assert ref_pred.shape == (_BATCH_SIZE, 1008)
        assert ref_pred_nb.shape == (_BATCH_SIZE, 1008)
        print(f' Size {size}x{size}, feature:\n '
              f'mean: {(feat - ref_feat).abs().mean().item():.3e}, '
              f'max: {(feat - ref_feat).abs().max().item():.3e}, '
              f'ref_mean: {ref_feat.abs().mean().item():.3e}, '
              f'ref_max: {ref_feat.abs().max().item():.3e}.')
        print(f' Size {size}x{size}, prediction:\n '
              f'mean: {(pred - ref_pred).abs().mean().item():.3e}, '
              f'max: {(pred - ref_pred).abs().max().item():.3e}, '
              f'ref_mean: {ref_pred.abs().mean().item():.3e}, '
              f'ref_max: {ref_pred.abs().max().item():.3e}.')
        print(f' Size {size}x{size}, prediction (without bias):\n '
              f'mean: {(pred_nb - ref_pred_nb).abs().mean().item():.3e}, '
              f'max: {(pred_nb - ref_pred_nb).abs().max().item():.3e}, '
              f'ref_mean: {ref_pred_nb.abs().mean().item():.3e}, '
              f'ref_max: {ref_pred_nb.abs().max().item():.3e}.')
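

# Convenience entry point, added here so the file can be executed directly;
# the original test harness may invoke `test_model()` differently.
if __name__ == '__main__':
    test_model()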