""" |
|
The completion for Mean-opinion Network(MoNet) |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
import timm |
|
|
|
from timm.models.vision_transformer import Block |
|
from einops import rearrange |
|
|
|
|
|

class Attention_Block(nn.Module):
    """Single-head self-attention with a residual connection, applied to a (B, C, N) tensor."""

    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5  # 1 / sqrt(dim) scaling for the dot-product attention
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        # Scaled dot-product attention over the C tokens, followed by dropout
        # and a residual connection back to the input.
        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x
        return x
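
# Illustrative helper, not part of the original model: a minimal shape check for
# Attention_Block, assuming the (batch, tokens, dim) layout used throughout this file.
def _attention_block_example():
    block = Attention_Block(dim=16)
    tokens = torch.randn(2, 8, 16)    # B=2, C=8 tokens, N=16 features per token
    out = block(tokens)
    assert out.shape == tokens.shape  # the residual connection preserves the shape
    return out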

class Self_Attention(nn.Module):
    """Spatial (non-local) self-attention layer for a (B, C, W, H) feature map."""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight, initialised to zero

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        # Project to query / key / value and attend over the w * h spatial positions.
        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1)
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(bs, C, w, h)

        # Residual connection scaled by the learnable gamma.
        out = self.gamma * out + inFeature

        return out
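
# Illustrative helper, not part of the original model: Self_Attention preserves the
# (B, C, W, H) shape of a feature map; attention is computed over the W * H positions.
def _self_attention_example():
    layer = Self_Attention(in_dim=16)
    feat = torch.randn(2, 16, 7, 7)
    out = layer(feat)
    assert out.shape == feat.shape  # gamma starts at zero, so initially out equals feat
    return out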

class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module.
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        # Attention blocks whose projections act on the stacked channel dimension
        # (in_dim * feature_num) and on the stacked spatial dimension
        # (feature_size ** 2 * feature_num), respectively.
        self.channel_attention = Attention_Block(in_dim * feature_num)
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)

        # One spatial self-attention block per input view.
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        # features: (feature_num, B, C, W, H). Apply the per-view spatial attention and
        # re-stack the results (the buffer is allocated on the GPU, matching the pipeline).
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # B, feature_num * C, W * H
        bs, _, _ = input_tensor.shape

        # Attention with the stacked spatial dimension as the token features.
        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, c=self.feature_num)
        feature_weight_sum = self.feature_attention(in_feature)

        # Attention with the stacked channel dimension as the token features.
        in_channel = input_tensor.permute(0, 2, 1)
        channel_weight_sum = self.channel_attention(in_channel)

        # Average the two attention results and fuse the views by mean pooling.
        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1)) / 2
        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # B, in_dim, W * H
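
# Illustrative helper, not part of the original model: the expected MAL shapes with the
# default arguments. A CUDA device is assumed, since MAL allocates on the GPU internally.
def _mal_example():
    mal = MAL(in_dim=768, feature_num=4, feature_size=28).cuda()
    views = torch.randn(4, 2, 768, 28, 28).cuda()  # (feature_num, B, C, W, H)
    fused = mal(views)                             # (B, in_dim, feature_size ** 2)
    assert fused.shape == (2, 768, 28 * 28)
    return fused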

class SaveOutput:
    """Forward hook that stores the output of every module it is registered on."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []
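
# Illustrative helper, not part of the original model: how the SaveOutput hook is used.
# Once registered on a module, it stores that module's output on every forward pass.
def _save_output_example():
    saver = SaveOutput()
    layer = nn.Linear(4, 4)
    handle = layer.register_forward_hook(saver)
    layer(torch.randn(1, 4))
    assert len(saver.outputs) == 1
    handle.remove()
    saver.clear()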

class MoNet(nn.Module):
    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model(config.backbone, pretrained=False)
        self.save_output = SaveOutput()

        # Hook every transformer block of the ViT backbone so that its output is
        # collected during the forward pass.
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        # Multi-view Attention Learning modules and their fusion.
        self.MALs = nn.ModuleList()
        for _ in range(config.mal_num):
            self.MALs.append(MAL())

        self.fusion_wam = MAL(feature_num=config.mal_num)
        self.block = Block(dim_mlp, 12)

        # Convolutional head that pools the fused feature map down to a 128-d vector,
        # followed by the score regressor.
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )
        self.fc_score = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()
        )

    def extract_feature(self, save_output, block_index=(2, 5, 8, 11)):
        # Concatenate the patch tokens (the class token is dropped) of four
        # intermediate transformer blocks along the embedding dimension.
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def forward(self, x):
        # The ViT forward pass is only used to trigger the hooks; its own output is discarded.
        _x = self.vit(x)
        x = self.extract_feature(self.save_output)
        self.save_output.outputs.clear()

        # Reshape the concatenated tokens into the four feature views:
        # (4, B, dim_mlp, input_size, input_size).
        x = x.permute(0, 2, 1)
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)
        x = x.permute(1, 0, 2, 3, 4)

        # Collect the output of every MAL module (one "opinion" per module).
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)

        # Fuse the opinions, refine with a transformer block, and regress the quality score.
        wam = self.fusion_wam(DOF).permute(0, 2, 1)
        wam = self.block(wam).permute(0, 2, 1)
        wam = rearrange(wam, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)
        score = self.cnn(wam).squeeze(-1).squeeze(-1)
        score = self.fc_score(score).view(-1)

        return score

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', dest='seed', type=int, default=3407)
    parser.add_argument('--gpu_id', dest='gpu_id', type=str, default='0')

    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of MAL modules.')

    parser.add_argument('--dataset', dest='dataset', type=str, default='livec',
                        help='Supported datasets: livec|koniq10k|bid|spaq')
    parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=5,
                        help='Number of patches sampled from each training image')
    parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
                        help='Number of patches sampled from each testing image')
    parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
                        help='Crop size for training and testing image patches')

    parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-5, help='Weight decay')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=11, help='Batch size')
    parser.add_argument('--epochs', dest='epochs', type=int, default=50, help='Epochs for training')
    parser.add_argument('--T_max', dest='T_max', type=int, default=50, help='Hyper-parameter for CosineAnnealingLR')
    parser.add_argument('--eta_min', dest='eta_min', type=int, default=0, help='Hyper-parameter for CosineAnnealingLR')

    parser.add_argument('--save_path', dest='save_path', type=str, default='./training_for_IQA',
                        help='The path where the model and logs will be saved.')

    config = parser.parse_args()

    # Sanity check: run a dummy batch through the model and report the parameter count.
    in_tensor = torch.zeros((2, 3, 224, 224), dtype=torch.float).cuda()
    model = MoNet(config).cuda()
    res = model(in_tensor)

    print('{} : {} [M]'.format('#Params', sum(map(lambda x: x.numel(), model.parameters())) / 10 ** 6))
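
# Note (illustrative, not part of the original script): MoNet only reads `backbone` and
# `mal_num` from the config object, so a plain namespace is enough when instantiating it
# programmatically, e.g.
#
#     from types import SimpleNamespace
#     model = MoNet(SimpleNamespace(backbone='vit_base_patch8_224', mal_num=3)).cuda()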