"""
The completion for Mean-opinion Network(MoNet)
"""
import torch
import torch.nn as nn
import timm
from timm.models.vision_transformer import Block
from einops import rearrange


class Attention_Block(nn.Module):
    """Single-head self attention over the last dimension, with a residual connection."""

    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5  # 1/sqrt(d) scaling for dot-product attention
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)
        attn = q @ k.transpose(-2, -1) * self.norm_fact  # (B, C, C) attention map
        attn = self.softmax(attn)
        # Note: (attn @ v) is already (B, C, N); the transpose/reshape round-trip is
        # kept from the multi-head formulation this block was adapted from.
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x  # residual connection
        return x
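
# Hedged usage sketch (illustrative, not part of the original file): a quick shape
# check showing that the block preserves its input shape, with `dim` matching the
# size of the last axis.
#   blk = Attention_Block(dim=784)
#   tokens = torch.randn(2, 768, 784)          # (batch, rows, dim), as used by MAL below
#   assert blk(tokens).shape == (2, 768, 784)  # shape preserved by the residual block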


class Self_Attention(nn.Module):
    """Spatial self attention over a (B, C, H, W) feature map, with a learnable
    residual weight gamma initialised to zero."""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()
        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()
        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1)  # (bs, wh, C//8)
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)                     # (bs, C//8, wh)
        energy = torch.bmm(proj_query, proj_key)                                 # (bs, wh, wh)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)                   # (bs, C, wh)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))                  # (bs, C, wh)
        out = out.view(bs, C, w, h)
        out = self.gamma * out + inFeature  # residual, gated by gamma
        return out
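
# Hedged usage sketch (illustrative names, not from the original file): spatial
# attention over a convolutional-style feature map, again shape-preserving.
#   sa = Self_Attention(in_dim=768)
#   fmap = torch.randn(2, 768, 28, 28)        # (batch, channels, H, W)
#   assert sa(fmap).shape == (2, 768, 28, 28)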


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module.

    Fuses `feature_num` feature maps of shape (bs, in_dim, feature_size, feature_size)
    into a single (bs, in_dim, feature_size ** 2) representation.
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()
        self.channel_attention = Attention_Block(in_dim * feature_num)  # channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # pixel-wise self attention

        # One spatial self-attention module per input feature map
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        # Apply per-view spatial self attention, then stack along a new leading axis
        feature = torch.tensor([], device=features[0].device)
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature  # (feature_num, bs, in_dim, w, h)

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # (bs, in_dim * feature_num, w * h)
        bs, _, _ = input_tensor.shape  # e.g. [2, 3072, 784]

        # Pixel-wise branch: attend over the w * h * feature_num axis
        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim,
                               c=self.feature_num)  # (bs, in_dim, w * h * feature_num)
        feature_weight_sum = self.feature_attention(in_feature)

        # Channel-wise branch: attend over the in_dim * feature_num axis
        in_channel = input_tensor.permute(0, 2, 1)  # (bs, w * h, in_dim * feature_num)
        channel_weight_sum = self.channel_attention(in_channel)

        # Average the two branches, then average over the feature_num views
        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1)) / 2  # (bs, 3072, 784)
        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
        return weight_sum_res  # (bs, in_dim, w * h)
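
# Hedged usage sketch (illustrative, not from the original file): fusing four
# 28x28 ViT feature maps from a batch of two images into one representation.
#   mal = MAL(in_dim=768, feature_num=4, feature_size=28)
#   views = torch.randn(4, 2, 768, 28, 28)    # (feature_num, batch, in_dim, H, W)
#   assert mal(views).shape == (2, 768, 28 * 28)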


class SaveOutput:
    """Forward-hook callable that collects the output of every hooked module."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []
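
# Hedged usage sketch (illustrative, not from the original file): registering the
# collector on a module so every forward pass records that module's output.
#   saver = SaveOutput()
#   layer = nn.Linear(4, 4)
#   handle = layer.register_forward_hook(saver)
#   layer(torch.randn(1, 4))
#   assert len(saver.outputs) == 1
#   handle.remove()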


class MoNet(nn.Module):
    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size  # token grid side, e.g. 224 // 8 = 28
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model(config.backbone, pretrained=False)
        self.save_output = SaveOutput()

        # Register a forward hook on every transformer block to capture its output
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(config.mal_num):
            self.MALs.append(MAL())

        # Image quality score regression head
        self.fusion_wam = MAL(feature_num=config.mal_num)
        self.block = Block(dim_mlp, 12)
        # Spatial reduction for a 28x28 input: 28 -> 24 -> 12 -> 10 -> 5 -> 3 -> 1
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )
        self.fc_score = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()  # quality score in (0, 1)
        )

    def extract_feature(self, save_output, block_index=(2, 5, 8, 11)):
        # Concatenate patch tokens (CLS token dropped) from four transformer blocks
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)  # (bs, 28 * 28, 768 * 4)
        return x

    def forward(self, x):
        # Multi-level features from different transformer blocks
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)  # (bs, 28 * 28, 768 * 4)
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1)  # (bs, 768 * 4, 28 * 28)
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp,
                      w=self.input_size, h=self.input_size)  # (bs, 4, 768, 28, 28)
        x = x.permute(1, 0, 2, 3, 4)  # (4, bs, 768, 28, 28)

        # Different Opinion Features (DOF): one fused view per MAL
        DOF = torch.tensor([], device=x.device)
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size,
                        h=self.input_size)  # (mal_num, bs, 768, 28, 28)

        # Image quality score regression
        wam = self.fusion_wam(DOF).permute(0, 2, 1)  # (bs, 28 * 28, 768)
        wam = self.block(wam).permute(0, 2, 1)
        wam = rearrange(wam, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)
        score = self.cnn(wam).squeeze(-1).squeeze(-1)  # (bs, 128)
        score = self.fc_score(score).view(-1)  # (bs,)
        return score


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', dest='seed', type=int, default=3407)
    parser.add_argument('--gpu_id', dest='gpu_id', type=str, default='0')

    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of MAL modules.')

    # data related
    parser.add_argument('--dataset', dest='dataset', type=str, default='livec',
                        help='Supported datasets: livec|koniq10k|bid|spaq')
    parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=5,
                        help='Number of patches sampled from each training image')
    parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
                        help='Number of patches sampled from each testing image')
    parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
                        help='Crop size for training & testing image patches')

    # training related
    parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-5, help='Weight decay')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=11, help='Batch size')
    parser.add_argument('--epochs', dest='epochs', type=int, default=50, help='Epochs for training')
    parser.add_argument('--T_max', dest='T_max', type=int, default=50, help='Hyper-parameter for CosineAnnealingLR')
    parser.add_argument('--eta_min', dest='eta_min', type=float, default=0,
                        help='Hyper-parameter for CosineAnnealingLR')
    parser.add_argument('--save_path', dest='save_path', type=str, default='./training_for_IQA',
                        help='The path where the model and logs will be saved.')
    config = parser.parse_args()

    # torch.autograd.set_detect_anomaly(True)
    # with torch.autograd.detect_anomaly():
    in_tensor = torch.zeros((2, 3, 224, 224), dtype=torch.float).cuda()
    model = MoNet(config).cuda()
    res = model(in_tensor)

    print('{} : {} [M]'.format('#Params', sum(map(lambda x: x.numel(), model.parameters())) / 10 ** 6))

    # label = torch.tensor([1, 2], dtype=torch.float).cuda()
    # loss = torch.nn.L1Loss().cuda()
    # res = model(in_tensor)
    # l = loss(label, res)
    # print(l)
    # l.backward()