"""This file is for Inception model borrowed from torch metrics / fidelity.
This file may have been modified by Bytedance Ltd. and/or its affiliates (“Bytedance's Modifications”).
All Bytedance's Modifications are Copyright (year) Bytedance Ltd. and/or its affiliates.
Reference:
https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py
"""
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch_fidelity.feature_extractor_base import FeatureExtractorBase
from torch_fidelity.helpers import vassert
from torch_fidelity.feature_extractor_inceptionv3 import BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE_1, InceptionE_2
from torch_fidelity.interpolate_compat_tensorflow import interpolate_bilinear_2d_like_tensorflow1x
try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

# Note: the shasum of these weights was compared; the model should be identical to the original pytorch-fid release.
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'


class FeatureExtractorInceptionV3(FeatureExtractorBase):
    INPUT_IMAGE_SIZE = 299

    def __init__(
self,
name,
features_list,
**kwargs,
):
"""
InceptionV3 feature extractor for 2D RGB 24bit images.
Args:
name (str): Unique name of the feature extractor, must be the same as used in
:func:`register_feature_extractor`.
features_list (list): A list of the requested feature names, which will be produced for each input. This
feature extractor provides the following features:
- '64'
- '192'
- '768'
- '2048'
- 'logits_unbiased'
- 'logits'
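        Example (illustrative sketch; constructing the extractor downloads the weights from FID_WEIGHTS_URL):
            extractor = FeatureExtractorInceptionV3("inception_model", ["2048", "logits"])
            images = (torch.rand(4, 3, 299, 299) * 255).to(torch.uint8)
            features = extractor(images)  # dict mapping each requested feature name to a tensor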
"""
super(FeatureExtractorInceptionV3, self).__init__(name, features_list)
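        # All internal computation is carried out in double precision, presumably to keep the
        # extracted features numerically close to the reference TensorFlow InceptionV3 used for FID.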
self.feature_extractor_internal_dtype = torch.float64
self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
self.MaxPool_1 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
self.MaxPool_2 = torch.nn.MaxPool2d(kernel_size=3, stride=2)
self.Mixed_5b = InceptionA(192, pool_features=32)
self.Mixed_5c = InceptionA(256, pool_features=64)
self.Mixed_5d = InceptionA(288, pool_features=64)
self.Mixed_6a = InceptionB(288)
self.Mixed_6b = InceptionC(768, channels_7x7=128)
self.Mixed_6c = InceptionC(768, channels_7x7=160)
self.Mixed_6d = InceptionC(768, channels_7x7=160)
self.Mixed_6e = InceptionC(768, channels_7x7=192)
self.Mixed_7a = InceptionD(768)
self.Mixed_7b = InceptionE_1(1280)
self.Mixed_7c = InceptionE_2(2048)
self.AvgPool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
self.fc = torch.nn.Linear(2048, 1008)
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
#state_dict = torch.load(FID_WEIGHTS_URL, map_location='cpu')
self.load_state_dict(state_dict)
self.to(self.feature_extractor_internal_dtype)
self.requires_grad_(False)
        self.eval()

    def forward(self, x):
vassert(torch.is_tensor(x) and x.dtype == torch.uint8, 'Expecting image as torch.Tensor with dtype=torch.uint8')
vassert(x.dim() == 4 and x.shape[1] == 3, f'Input is not Bx3xHxW: {x.shape}')
features = {}
remaining_features = self.features_list.copy()
x = x.to(self.feature_extractor_internal_dtype)
# N x 3 x ? x ?
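        # Resize to 299x299 with TensorFlow-1.x-compatible bilinear interpolation so the
        # features stay comparable to the original FID implementation.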
x = interpolate_bilinear_2d_like_tensorflow1x(
x,
size=(self.INPUT_IMAGE_SIZE, self.INPUT_IMAGE_SIZE),
align_corners=False,
)
# N x 3 x 299 x 299
# x = (x - 128) * torch.tensor(0.0078125, dtype=torch.float32, device=x.device) # really happening in graph
x = (x - 128) / 128 # but this gives bit-exact output _of this step_ too
# N x 3 x 299 x 299
x = self.Conv2d_1a_3x3(x)
# N x 32 x 149 x 149
x = self.Conv2d_2a_3x3(x)
# N x 32 x 147 x 147
x = self.Conv2d_2b_3x3(x)
# N x 64 x 147 x 147
x = self.MaxPool_1(x)
# N x 64 x 73 x 73
if '64' in remaining_features:
features['64'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
remaining_features.remove('64')
if len(remaining_features) == 0:
return features
x = self.Conv2d_3b_1x1(x)
# N x 80 x 73 x 73
x = self.Conv2d_4a_3x3(x)
# N x 192 x 71 x 71
x = self.MaxPool_2(x)
# N x 192 x 35 x 35
if '192' in remaining_features:
features['192'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1)
remaining_features.remove('192')
if len(remaining_features) == 0:
return features
x = self.Mixed_5b(x)
# N x 256 x 35 x 35
x = self.Mixed_5c(x)
# N x 288 x 35 x 35
x = self.Mixed_5d(x)
# N x 288 x 35 x 35
x = self.Mixed_6a(x)
# N x 768 x 17 x 17
x = self.Mixed_6b(x)
# N x 768 x 17 x 17
x = self.Mixed_6c(x)
# N x 768 x 17 x 17
x = self.Mixed_6d(x)
# N x 768 x 17 x 17
x = self.Mixed_6e(x)
# N x 768 x 17 x 17
if '768' in remaining_features:
features['768'] = F.adaptive_avg_pool2d(x, output_size=(1, 1)).squeeze(-1).squeeze(-1).to(torch.float32)
remaining_features.remove('768')
if len(remaining_features) == 0:
return features
x = self.Mixed_7a(x)
# N x 1280 x 8 x 8
x = self.Mixed_7b(x)
# N x 2048 x 8 x 8
x = self.Mixed_7c(x)
# N x 2048 x 8 x 8
x = self.AvgPool(x)
# N x 2048 x 1 x 1
x = torch.flatten(x, 1)
# N x 2048
if '2048' in remaining_features:
features['2048'] = x
remaining_features.remove('2048')
if len(remaining_features) == 0:
return features
if 'logits_unbiased' in remaining_features:
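            # Project onto the classifier weights without adding the bias; these "unbiased" logits
            # are the default feature for the 'isc' metric (see get_default_feature_layer_for_metric).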
x = x.mm(self.fc.weight.T)
# N x 1008 (num_classes)
features['logits_unbiased'] = x
remaining_features.remove('logits_unbiased')
if len(remaining_features) == 0:
return features
x = x + self.fc.bias.unsqueeze(0)
else:
x = self.fc(x)
# N x 1008 (num_classes)
features['logits'] = x
        return features

    @staticmethod
def get_provided_features_list():
        return '64', '192', '768', '2048', 'logits_unbiased', 'logits'

    @staticmethod
def get_default_feature_layer_for_metric(metric):
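        # Default feature layer per metric: Inception Score (isc), Frechet Inception Distance (fid),
        # Kernel Inception Distance (kid), and precision/recall (prc).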
return {
'isc': 'logits_unbiased',
'fid': '2048',
'kid': '2048',
'prc': '2048',
        }[metric]

    @staticmethod
def can_be_compiled():
        return True

    @staticmethod
def get_dummy_input_for_compile():
        return (torch.rand([1, 3, 4, 4]) * 255).to(torch.uint8)


def get_inception_model():
model = FeatureExtractorInceptionV3("inception_model", ["2048", "logits_unbiased"])
return model
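# Minimal usage sketch: extract FID/ISC features for a batch of uint8 RGB images.
# Instantiating the extractor downloads the weights from FID_WEIGHTS_URL, so network
# access is assumed here.
if __name__ == "__main__":
    inception = get_inception_model()
    dummy_images = (torch.rand(2, 3, 256, 256) * 255).to(torch.uint8)  # B x 3 x H x W, dtype uint8
    with torch.no_grad():
        features = inception(dummy_images)
    print(features["2048"].shape)             # torch.Size([2, 2048]), pooled features used for FID/KID
    print(features["logits_unbiased"].shape)  # torch.Size([2, 1008]), bias-free logits used for ISC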