Zevin2023 committed on
Commit
07e1105
1 Parent(s): 82f6e27
app.py ADDED
@@ -0,0 +1,45 @@
+ import torch
+ import torchvision
+
+ import cv2
+ import numpy as np
+ from models import monet as MoNet
+ import argparse
+ from utils.dataset.process import ToTensor, Normalize
+ import gradio as gr
+
+ def load_image(img_path):
+     d_img = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)
+     # d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+     d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
+     d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
+     d_img = np.array(d_img).astype('float32') / 255
+     d_img = np.transpose(d_img, (2, 0, 1))
+
+     return d_img
+
+ def predict(image):
+     """Run a single prediction on the model"""
+     parser = argparse.ArgumentParser()
+     # model related
+     parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
+                         help='The backbone for MoNet.')
+     parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
+     config = parser.parse_args([])  # parse an empty list so Gradio's own CLI arguments are not consumed
+
+     model = MoNet.MoNet(config).cuda()
+     model.load_state_dict(torch.load('./checkpoints/best_model.pkl'))
+     model.eval()
+
+     trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
+
+     img = load_image(image)
+     img_tensor = trans(img).unsqueeze(0).cuda()
+     iq = model(img_tensor).cpu().detach().numpy().tolist()[0]
+
+     return "The image quality of the image is: {}".format(round(iq, 4))
+
+ # os.system("wget -O ./checkpoints/best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")
+
+ interface = gr.Interface(fn=predict, inputs="image", outputs="text")
+ interface.launch(server_name='127.0.0.1', server_port=8088)
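The commented-out wget line above suggests the checkpoint is fetched from the Hub before the demo starts. A minimal sketch of the same step in pure Python (assuming that URL is still live and that `./checkpoints` is writable):

```python
import os
import urllib.request

CKPT_URL = "https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl"
CKPT_PATH = "./checkpoints/best_model.pkl"

if not os.path.exists(CKPT_PATH):
    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    urllib.request.urlretrieve(CKPT_URL, CKPT_PATH)  # fetch once; later runs reuse the local file
```

Note also that predict() rebuilds the model and reloads the checkpoint on every request; hoisting the model construction to module level would avoid that cost.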
models/__pycache__/WAIQT.cpython-38.pyc ADDED
Binary file (6.91 kB).

models/__pycache__/gc_loss.cpython-37.pyc ADDED
Binary file (5.34 kB).

models/__pycache__/gc_loss.cpython-38.pyc ADDED
Binary file (3.11 kB).

models/__pycache__/monet.cpython-37.pyc ADDED
Binary file (8.07 kB).

models/__pycache__/monet.cpython-38.pyc ADDED
Binary file (8.1 kB).
models/gc_loss.py ADDED
@@ -0,0 +1,187 @@
+ import torch.nn as nn
+ import torch
+ import numpy as np
+
+ # class GC_Loss(nn.Module):
+ #     def __init__(self, queue_len=800):
+ #         super(GC_Loss, self).__init__()
+ #         self.pred_queue = list()
+ #         self.gt_queue = list()
+ #         self.queue_len = 0
+
+ #         self.queue_max_len = queue_len
+
+ #         print('CCWD Length: ', queue_len)
+
+ #         self.l1_loss = torch.nn.L1Loss().cuda()
+ #         self.l2_loss = torch.nn.MSELoss().cuda()
+
+ #     def enqueue(self, pred, gt):
+ #         bs = pred.shape[0]
+ #         self.queue_len = self.queue_len + bs
+
+ #         self.pred_queue = self.pred_queue + pred.cpu().detach().numpy().tolist()
+ #         self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()
+
+ #         if self.queue_len > self.queue_max_len:
+ #             self.dequeue(self.queue_len - self.queue_max_len)
+ #             self.queue_len = self.queue_max_len
+
+ #     def dequeue(self, n):
+ #         for index in range(n):
+ #             self.pred_queue.pop(0)
+ #             self.gt_queue.pop(0)
+
+ #     def clear(self):
+ #         self.pred_queue.clear()
+ #         self.gt_queue.clear()
+
+ #     def forward(self, x, y):
+ #         x_queue = self.pred_queue.copy()
+ #         y_queue = self.gt_queue.copy()
+
+ #         # Gather all values in the queue
+ #         x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
+ #         y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)
+
+ #         # Estimate mean and standard deviation
+ #         x_bar = torch.mean(x_all, dim=0)
+ #         x_std = torch.std(x_all, dim=0)
+
+ #         y_bar = torch.mean(y_all, dim=0)
+ #         y_std = torch.std(y_all, dim=0)
+
+ #         # Estimate the PLCC of the current batch within all values
+ #         diff_x_plcc = (x - x_bar)  # [bs, 1]
+ #         diff_y_plcc = (y - y_bar)  # [bs, 1]
+
+ #         x1 = torch.sum(torch.mul(diff_x_plcc, diff_y_plcc))
+ #         x2_1 = torch.sqrt(torch.sum(torch.mul(diff_x_plcc, diff_x_plcc)))
+ #         x2_2 = torch.sqrt(torch.sum(torch.mul(diff_y_plcc, diff_y_plcc)))
+
+ #         # Standardize all values
+ #         diff_x = (x_all - x_bar) / x_std  # [bs, 1]
+ #         diff_y = (y_all - y_bar) / y_std  # [bs, 1]
+
+ #         rank_x = diff_x.reshape(-1, 1)
+ #         rank_y = diff_y.reshape(-1, 1)
+
+ #         rank_x = rank_x - rank_x.transpose(1, 0)
+ #         rank_y = rank_y - rank_y.transpose(1, 0)
+
+ #         # Estimate the rank of all values
+ #         rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x)), dim=1)
+ #         rank_y = torch.sum(1 / 2 * (1 + torch.erf(rank_y)), dim=1)
+
+ #         # Mean and standard deviation of the ranks
+ #         rank_x_bar = torch.mean(rank_x, dim=0)
+ #         rank_x_std = torch.std(rank_x, dim=0)
+ #         rank_y_bar = torch.mean(rank_y, dim=0)
+ #         rank_y_std = torch.std(rank_y, dim=0)
+
+ #         # Estimate the SROCC of the current batch within all values
+ #         rank_x_ = (x - rank_x_bar) / rank_x_std  # [bs, 1]
+ #         rank_y_ = (y - rank_y_bar) / rank_y_std  # [bs, 1]
+
+ #         x1_rank = torch.sum(torch.mul(rank_x_, rank_y_))
+ #         x2_1_rank = torch.sqrt(torch.sum(torch.mul(rank_x_, rank_x_)))
+ #         x2_2_rank = torch.sqrt(torch.sum(torch.mul(rank_y_, rank_y_)))
+
+ #         self.enqueue(x, y)
+
+ #         return (0.5 * ((1 - x1 / (x2_1 * x2_2)) + (1 - (x1_rank / (x2_1_rank * x2_2_rank)))) + 1) * self.l2_loss(x, y)
+
+ class GC_Loss(nn.Module):
+     def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
+         super(GC_Loss, self).__init__()
+         self.pred_queue = list()
+         self.gt_queue = list()
+         self.queue_len = 0
+
+         self.queue_max_len = queue_len
+         print('The queue length is: ', self.queue_max_len)
+         self.mse = torch.nn.MSELoss().cuda()
+
+         self.alpha, self.beta, self.gamma = alpha, beta, gamma
+
+     def consistency(self, pred_data, gt_data):
+         pred_one_batch, pred_queue = pred_data
+         gt_one_batch, gt_queue = gt_data
+
+         pred_mean = torch.mean(pred_queue)
+         gt_mean = torch.mean(gt_queue)
+
+         diff_pred = pred_one_batch - pred_mean
+         diff_gt = gt_one_batch - gt_mean
+
+         x1 = torch.sum(torch.mul(diff_pred, diff_gt))
+         x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
+         x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))
+
+         return x1 / (x2_1 * x2_2)
+
+     def ppra(self, x):
+         """
+         Pairwise Preference-based Rank Approximation
+         """
+         x_bar, x_std = torch.mean(x), torch.std(x)
+         x_n = (x - x_bar) / x_std
+         x_n_T = x_n.reshape(-1, 1)
+
+         rank_x = x_n_T - x_n_T.transpose(1, 0)
+         rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1)
+
+         return rank_x
+
+     @torch.no_grad()
+     def enqueue(self, pred, gt):
+         bs = pred.shape[0]
+         self.queue_len = self.queue_len + bs
+
+         self.pred_queue = self.pred_queue + pred.tolist()
+         self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()
+
+         if self.queue_len > self.queue_max_len:
+             self.dequeue(self.queue_len - self.queue_max_len)
+             self.queue_len = self.queue_max_len
+
+     @torch.no_grad()
+     def dequeue(self, n):
+         for _ in range(n):
+             self.pred_queue.pop(0)
+             self.gt_queue.pop(0)
+
+     def clear(self):
+         self.pred_queue.clear()
+         self.gt_queue.clear()
+
+     def forward(self, x, y):
+         x_queue = self.pred_queue.copy()
+         y_queue = self.gt_queue.copy()
+
+         x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
+         y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)
+
+         PLCC = self.consistency((x, x_all), (y, y_all))
+         PGC = 1 - PLCC
+
+         rank_x = self.ppra(x_all)
+         rank_y = self.ppra(y_all)
+         SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y))
+         SGC = 1 - SROCC
+
+         GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
+         self.enqueue(x, y)
+
+         return GC
+
+
+ if __name__ == '__main__':
+     gc = GC_Loss().cuda()
+     x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
+     y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()
+
+     res = gc(x, y)
+
+     print(res)
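As a quick sanity check (a sketch, not part of the commit; it assumes a CUDA device, since GC_Loss calls .cuda() internally), the consistency() term reduces to Pearson's r when the batch is compared against itself:

```python
import torch
from scipy import stats

from models.gc_loss import GC_Loss

gc = GC_Loss(queue_len=800).cuda()
x = torch.tensor([1., 2., 3., 4., 5.]).cuda()
y = torch.tensor([6., 7., 8., 9., 15.]).cuda()

plcc = gc.consistency((x, x), (y, y)).item()
ref, _ = stats.pearsonr(x.cpu().numpy(), y.cpu().numpy())
print(plcc, ref)  # the two values should agree up to floating-point error
```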
models/monet.py ADDED
@@ -0,0 +1,250 @@
+ """
+ The implementation of the Mean-opinion Network (MoNet)
+ """
+ import torch
+ import torch.nn as nn
+ import timm
+
+ from timm.models.vision_transformer import Block
+ from einops import rearrange
+
+
+ class Attention_Block(nn.Module):
+     def __init__(self, dim, drop=0.1):
+         super().__init__()
+         self.c_q = nn.Linear(dim, dim)
+         self.c_k = nn.Linear(dim, dim)
+         self.c_v = nn.Linear(dim, dim)
+         self.norm_fact = dim ** -0.5
+         self.softmax = nn.Softmax(dim=-1)
+         self.proj_drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         _x = x
+         B, C, N = x.shape
+         q = self.c_q(x)
+         k = self.c_k(x)
+         v = self.c_v(x)
+
+         attn = q @ k.transpose(-2, -1) * self.norm_fact
+         attn = self.softmax(attn)
+         x = (attn @ v).transpose(1, 2).reshape(B, C, N)
+         x = self.proj_drop(x)
+         x = x + _x  # residual connection
+         return x
+
+
+ class Self_Attention(nn.Module):
+     """ Self attention Layer"""
+
+     def __init__(self, in_dim):
+         super(Self_Attention, self).__init__()
+
+         self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+         self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
+         self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
+         self.gamma = nn.Parameter(torch.zeros(1))
+
+         self.softmax = nn.Softmax(dim=-1)
+
+     def forward(self, inFeature):
+         bs, C, w, h = inFeature.size()
+
+         proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1)
+         proj_key = self.kConv(inFeature).view(bs, -1, w * h)
+         energy = torch.bmm(proj_query, proj_key)
+         attention = self.softmax(energy)
+         proj_value = self.vConv(inFeature).view(bs, -1, w * h)
+
+         out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+         out = out.view(bs, C, w, h)
+
+         out = self.gamma * out + inFeature  # learnable residual weight, initialized to 0
+
+         return out
+
+
+ class MAL(nn.Module):
+     """
+     Multi-view Attention Learning (MAL) module
+     """
+
+     def __init__(self, in_dim=768, feature_num=4, feature_size=28):
+         super().__init__()
+
+         self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
+         self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention
+
+         # Self attention module for each input feature
+         self.attention_module = nn.ModuleList()
+         for _ in range(feature_num):
+             self.attention_module.append(Self_Attention(in_dim))
+
+         self.feature_num = feature_num
+         self.in_dim = in_dim
+
+     def forward(self, features):
+         feature = torch.tensor([]).cuda()
+         for index, _ in enumerate(features):
+             feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
+         features = feature
+
+         input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
+         bs, _, _ = input_tensor.shape  # [2, 3072, 784]
+
+         in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
+         feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 768
+
+         in_channel = input_tensor.permute(0, 2, 1)  # bs, 28 * 28, 768 * feature_num
+         channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 28 * 28
+
+         weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
+                                     c=self.feature_num) + channel_weight_sum.permute(0, 2, 1)) / 2  # [2, 3072, 784]
+
+         weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)
+
+         return weight_sum_res  # bs, 768, 28 * 28
+
+
+ class SaveOutput:
+     def __init__(self):
+         self.outputs = []
+
+     def __call__(self, module, module_in, module_out):
+         self.outputs.append(module_out)
+
+     def clear(self):
+         self.outputs = []
+
+
+ class MoNet(nn.Module):
+     def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
+         super().__init__()
+         self.img_size = img_size
+         self.input_size = img_size // patch_size
+         self.dim_mlp = dim_mlp
+
+         self.vit = timm.create_model(config.backbone, pretrained=False)
+         self.save_output = SaveOutput()
+
+         # Register Hooks
+         hook_handles = []
+         for layer in self.vit.modules():
+             if isinstance(layer, Block):
+                 handle = layer.register_forward_hook(self.save_output)
+                 hook_handles.append(handle)
+
+         self.MALs = nn.ModuleList()
+         for _ in range(config.mal_num):
+             self.MALs.append(MAL())
+
+         # Image Quality Score Regression
+         self.fusion_wam = MAL(feature_num=config.mal_num)
+         self.block = Block(dim_mlp, 12)
+         self.cnn = nn.Sequential(
+             nn.Conv2d(dim_mlp, 256, 5),
+             nn.BatchNorm2d(256),
+             nn.ReLU(inplace=True),
+             nn.AvgPool2d((2, 2)),
+             nn.Conv2d(256, 128, 3),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+             nn.AvgPool2d((2, 2)),
+             nn.Conv2d(128, 128, 3),
+             nn.BatchNorm2d(128),
+             nn.ReLU(inplace=True),
+             nn.AvgPool2d((3, 3)),
+         )
+         self.fc_score = nn.Sequential(
+             nn.Linear(128, 128 // 2),
+             nn.ReLU(),
+             nn.Dropout(drop),
+             nn.Linear(128 // 2, 1),
+             nn.Sigmoid()
+         )
+
+     def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
+         x1 = save_output.outputs[block_index[0]][:, 1:]
+         x2 = save_output.outputs[block_index[1]][:, 1:]
+         x3 = save_output.outputs[block_index[2]][:, 1:]
+         x4 = save_output.outputs[block_index[3]][:, 1:]
+         x = torch.cat((x1, x2, x3, x4), dim=2)
+         return x
+
+     def forward(self, x):
+         # Multi-level Feature From Different Transformer Blocks
+         _ = self.vit(x)  # the forward pass is only needed to trigger the hooks
+         x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
+         self.save_output.outputs.clear()
+
+         x = x.permute(0, 2, 1)  # bs, 768 * 4, 28 * 28
+         x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
+         x = x.permute(1, 0, 2, 3, 4)  # 4, bs, 768, 28, 28
+
+         # Different Opinion Features (DOF)
+         DOF = torch.tensor([]).cuda()
+         for index, _ in enumerate(self.MALs):
+             DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
+         DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # 3, bs, 768, 28, 28
+
+         # Image Quality Score Regression
+         wam = self.fusion_wam(DOF).permute(0, 2, 1)  # bs, 28 * 28, 768
+         wam = self.block(wam).permute(0, 2, 1)
+         wam = rearrange(wam, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)
+         score = self.cnn(wam).squeeze(-1).squeeze(-1)
+         score = self.fc_score(score).view(-1)
+
+         return score
+
+
+ if __name__ == '__main__':
+     import argparse
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--seed', dest='seed', type=int, default=3407)
+     parser.add_argument('--gpu_id', dest='gpu_id', type=str, default='0')
+
+     # model related
+     parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
+                         help='The backbone for MoNet.')
+     parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
+
+     # data related
+     parser.add_argument('--dataset', dest='dataset', type=str, default='livec',
+                         help='Support datasets: livec|koniq10k|bid|spaq')
+     parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=5,
+                         help='Number of sample patches from training image')
+     parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
+                         help='Number of sample patches from testing image')
+     parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
+                         help='Crop size for training & testing image patches')
+
+     # training related
+     parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
+     parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-5, help='Weight decay')
+     parser.add_argument('--batch_size', dest='batch_size', type=int, default=11, help='Batch size')
+     parser.add_argument('--epochs', dest='epochs', type=int, default=50, help='Epochs for training')
+     parser.add_argument('--T_max', dest='T_max', type=int, default=50, help='Hyper-parameter for CosineAnnealingLR')
+     parser.add_argument('--eta_min', dest='eta_min', type=int, default=0, help='Hyper-parameter for CosineAnnealingLR')
+
+     parser.add_argument('--save_path', dest='save_path', type=str, default='./training_for_IQA',
+                         help='The path where the model and logs will be saved.')
+
+     config = parser.parse_args()
+
+     # torch.autograd.set_detect_anomaly(True)
+     # with torch.autograd.detect_anomaly():
+     in_tensor = torch.zeros((2, 3, 224, 224), dtype=torch.float).cuda()
+     model = MoNet(config).cuda()
+     res = model(in_tensor)
+
+     print('{}: {} [M]'.format('#Params', sum(map(lambda p: p.numel(), model.parameters())) / 10 ** 6))
+
+     # label = torch.tensor([1, 2], dtype=torch.float).cuda()
+     # loss = torch.nn.L1Loss().cuda()
+     #
+     # res = model(in_tensor)
+     # # loss = loss_func()
+     # l = loss(label, res)
+     # print(l)
+     # l.backward()
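The SaveOutput class above is a standard forward-hook pattern: the backbone's forward pass fills save_output.outputs, and extract_feature() then pulls the token sequences of blocks 2, 5, 8 and 11. A self-contained sketch of the same mechanism on a toy model (all names below are illustrative):

```python
import torch
import torch.nn as nn

class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
saver = SaveOutput()
for layer in net.modules():
    if isinstance(layer, nn.Linear):
        layer.register_forward_hook(saver)  # record each Linear layer's output

net(torch.randn(1, 4))
print([tuple(o.shape) for o in saver.outputs])  # [(1, 8), (1, 2)]
```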
utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (258 Bytes).

utils/__pycache__/iqa_solver.cpython-36.pyc ADDED
Binary file (3.73 kB).

utils/__pycache__/iqa_solver.cpython-37.pyc ADDED
Binary file (3.81 kB).

utils/__pycache__/iqa_solver.cpython-38.pyc ADDED
Binary file (3.64 kB).

utils/__pycache__/log_writer.cpython-36.pyc ADDED
Binary file (799 Bytes).

utils/__pycache__/log_writer.cpython-37.pyc ADDED
Binary file (809 Bytes).

utils/__pycache__/log_writer.cpython-38.pyc ADDED
Binary file (825 Bytes).

utils/__pycache__/process.cpython-38.pyc ADDED
Binary file (1.78 kB).

utils/dataset/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (266 Bytes).

utils/dataset/__pycache__/data_loader.cpython-37.pyc ADDED
Binary file (1.8 kB).

utils/dataset/__pycache__/data_loader.cpython-38.pyc ADDED
Binary file (1.44 kB).

utils/dataset/__pycache__/folders.cpython-37.pyc ADDED
Binary file (6.02 kB).

utils/dataset/__pycache__/folders.cpython-38.pyc ADDED
Binary file (5.75 kB).

utils/dataset/__pycache__/process.cpython-38.pyc ADDED
Binary file (1.93 kB).
utils/dataset/data_loader.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+ import torchvision
+
+ from utils.dataset import folders
+ from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip
+
+ class Data_Loader(object):
+     """Dataset class for IQA databases"""
+
+     def __init__(self, config, path, img_indx, istrain=True):
+
+         self.batch_size = config.batch_size
+         self.istrain = istrain
+         dataset = config.dataset
+         patch_size = config.patch_size
+
+         # Train transforms
+         if istrain:
+             transforms = torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()])
+         else:
+             transforms = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])
+
+         if dataset == 'livec':
+             self.data = folders.LIVEC(root=path, index=img_indx, transform=transforms)
+         elif dataset == 'koniq10k':
+             self.data = folders.Koniq10k(root=path, index=img_indx, transform=transforms)
+         elif dataset == 'bid':
+             self.data = folders.BID(root=path, index=img_indx, transform=transforms)
+         elif dataset == 'spaq':
+             self.data = folders.SPAQ(root=path, index=img_indx, transform=transforms)
+         else:
+             raise Exception("Only support livec, koniq10k, bid, spaq.")
+
+     def get_data(self):
+         dataloader = torch.utils.data.DataLoader(self.data, batch_size=self.batch_size, shuffle=self.istrain, num_workers=8)
+         return dataloader
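Data_Loader only reads attribute-style fields off config, so an argparse namespace is not strictly required; for a quick smoke test a SimpleNamespace with the same fields works (the path and indices below are placeholders):

```python
from types import SimpleNamespace

from utils.dataset.data_loader import Data_Loader

config = SimpleNamespace(batch_size=11, dataset='livec', patch_size=224)
loader = Data_Loader(config, path='/path/to/ChallengeDB_release', img_indx=list(range(100)), istrain=True)
for img, gt in loader.get_data():
    print(img.shape, gt.shape)  # e.g. torch.Size([11, 3, 224, 224]) and torch.Size([11, 1])
    break
```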
utils/dataset/dataset_info.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "livec": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/ChallengeDB_release", 1162],
+     "koniq10k": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/koniq-10k", 10073],
+     "bid": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/BID/ImageDatabase", 586],
+     "spaq": ["/home/ssl/Database/ChallengeDB_release/ChallengeDB_release/", 11125]
+ }
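Each entry is a [root_path, image_count] pair consumed by get_data() in utils/iqa_solver.py. The roots above are machine-specific absolute paths and must be edited to point at local copies of each database; note that the spaq entry appears to reuse a ChallengeDB_release path rather than a SPAQ root.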
utils/dataset/folders.py ADDED
@@ -0,0 +1,207 @@
+ import torch.utils.data as data
+ import torch
+
+ from PIL import Image
+ import os
+ import scipy.io
+ import numpy as np
+ import csv
+ from openpyxl import load_workbook
+ import cv2
+
+ class LIVEC(data.Dataset):
+     def __init__(self, root, index, transform):
+         imgpath = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat'))
+         imgpath = imgpath['AllImages_release']
+         imgpath = imgpath[7:1169]
+         mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat'))
+         labels = mos['AllMOS_release'].astype(np.float32)
+         labels = labels[0][7:1169]
+
+         sample, gt = [], []
+         for i, item in enumerate(index):
+             sample.append(os.path.join(root, 'Images', imgpath[item][0][0]))
+             gt.append(labels[item])
+         gt = normalization(gt)
+
+         self.samples, self.gt = sample, gt
+         self.transform = transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is the normalized MOS of the image.
+         """
+         img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)
+
+         return img_tensor, gt_tensor
+
+     def __len__(self):
+         length = len(self.samples)
+         return length
+
+
+ class Koniq10k(data.Dataset):
+     def __init__(self, root, index, transform):
+         imgname = []
+         mos_all = []
+         csv_file = os.path.join(root, 'koniq10k_distributions_sets.csv')
+         with open(csv_file) as f:
+             reader = csv.DictReader(f)
+             for row in reader:
+                 imgname.append(row['image_name'])
+                 mos = np.array(float(row['MOS'])).astype(np.float32)
+                 mos_all.append(mos)
+
+         sample, gt = [], []
+         for i, item in enumerate(index):
+             sample.append(os.path.join(root, '1024x768', imgname[item]))
+             gt.append(mos_all[item])
+         gt = normalization(gt)
+
+         self.samples, self.gt = sample, gt
+         self.transform = transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is the normalized MOS of the image.
+         """
+         img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)
+
+         return img_tensor, gt_tensor
+
+     def __len__(self):
+         length = len(self.samples)
+         return length
+
+
+ class SPAQ(data.Dataset):
+     def __init__(self, root, index, transform):
+         imgname = []
+         mos_all = []
+         csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
+         with open(csv_file) as f:
+             reader = csv.DictReader(f)
+             for row in reader:
+                 imgname.append(row['image_name'])
+                 mos = np.array(float(row['MOS_zscore'])).astype(np.float32)
+                 mos_all.append(mos)
+
+         sample, gt = [], []
+         for i, item in enumerate(index):
+             sample.append(os.path.join(root, '1024x768', imgname[item]))
+             gt.append(mos_all[item])
+         gt = normalization(gt)
+
+         self.samples, self.gt = sample, gt
+         self.transform = transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is the normalized MOS of the image.
+         """
+         img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)
+
+         return img_tensor, gt_tensor
+
+     def __len__(self):
+         length = len(self.samples)
+         return length
+
+
+ class BID(data.Dataset):
+     def __init__(self, root, index, transform):
+
+         imgname = []
+         mos_all = []
+
+         xls_file = os.path.join(root, 'DatabaseGrades.xlsx')
+         workbook = load_workbook(xls_file)
+         booksheet = workbook.active
+         rows = booksheet.rows
+         count = 1
+         for row in rows:
+             count += 1
+             img_num = booksheet.cell(row=count, column=1).value
+             img_name = "DatabaseImage%04d.JPG" % (img_num)
+             imgname.append(img_name)
+             mos = booksheet.cell(row=count, column=2).value
+             mos = np.array(mos)
+             mos = mos.astype(np.float32)
+             mos_all.append(mos)
+             if count == 587:
+                 break
+
+         sample, gt = [], []
+         for i, item in enumerate(index):
+             sample.append(os.path.join(root, imgname[item]))
+             gt.append(mos_all[item])
+         gt = normalization(gt)
+
+         self.samples, self.gt = sample, gt
+         self.transform = transform
+
+     def __getitem__(self, index):
+         """
+         Args:
+             index (int): Index
+
+         Returns:
+             tuple: (sample, target) where target is the normalized MOS of the image.
+         """
+         img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)
+
+         return img_tensor, gt_tensor
+
+     def __len__(self):
+         length = len(self.samples)
+         return length
+
+
+ def get_item(samples, gt, index, transform):
+     path, target = samples[index], gt[index]
+     sample = load_image(path)
+     samples = {'img': sample, 'gt': target}
+     samples = transform(samples)
+
+     return samples['img'], samples['gt'].type(torch.FloatTensor)
+
+
+ def getFileName(path, suffix):
+     filename = []
+     f_list = os.listdir(path)
+     for i in f_list:
+         if os.path.splitext(i)[1] == suffix:
+             filename.append(i)
+     return filename
+
+
+ def load_image(img_path):
+     d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+     d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
+     d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
+     d_img = np.array(d_img).astype('float32') / 255
+     d_img = np.transpose(d_img, (2, 0, 1))
+
+     return d_img
+
+ def normalization(data):
+     data = np.array(data)
+     value_range = np.max(data) - np.min(data)
+     data = (data - np.min(data)) / value_range  # min-max scale to [0, 1]
+     data = list(data.astype('float').reshape(-1, 1))
+
+     return data
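normalization() min-max scales the MOS values to [0, 1] (matching the Sigmoid output head in models/monet.py) and returns them as one-element arrays. A tiny check of that behavior:

```python
from utils.dataset.folders import normalization

print(normalization([10.0, 20.0, 40.0]))
# [array([0.]), array([0.33333333]), array([1.])]
```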
utils/dataset/process.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import numpy as np
+
+
+ class Normalize(object):
+     def __init__(self, mean, var):
+         self.mean = mean
+         self.var = var
+
+     def __call__(self, sample):
+         if isinstance(sample, dict):
+             img = sample['img']
+             gt = sample['gt']
+             img = (img - self.mean) / self.var
+             sample = {'img': img, 'gt': gt}
+         else:
+             sample = (sample - self.mean) / self.var
+
+         return sample
+
+
+ class RandHorizontalFlip(object):
+     def __init__(self, prob_aug):
+         self.prob_aug = prob_aug
+
+     def __call__(self, sample):
+         p_aug = np.array([self.prob_aug, 1 - self.prob_aug])
+         prob_lr = np.random.choice([1, 0], p=p_aug.ravel())
+
+         if isinstance(sample, dict):
+             img = sample['img']
+             gt = sample['gt']
+
+             if prob_lr > 0.5:
+                 img = np.flip(img, axis=-1).copy()  # flip along the width axis (images here are C x H x W)
+             sample = {'img': img, 'gt': gt}
+         else:
+             if prob_lr > 0.5:
+                 sample = np.flip(sample, axis=-1).copy()
+         return sample
+
+
+ class ToTensor(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, sample):
+         if isinstance(sample, dict):
+             img = sample['img']
+             gt = sample['gt']
+             img = torch.from_numpy(img).type(torch.FloatTensor)
+             gt = torch.from_numpy(gt).type(torch.FloatTensor)
+             sample = {'img': img, 'gt': gt}
+         else:
+             sample = torch.from_numpy(sample).type(torch.FloatTensor)
+         return sample
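These transforms operate on the {'img', 'gt'} dictionaries produced by get_item() in utils/dataset/folders.py (bare arrays are also accepted, which is the path app.py uses). A short sketch of the composed pipeline on dummy data:

```python
import numpy as np
import torchvision

from utils.dataset.process import Normalize, RandHorizontalFlip, ToTensor

trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()])
sample = {'img': np.random.rand(3, 224, 224).astype('float32'),
          'gt': np.array([0.5], dtype='float32')}
out = trans(sample)
print(out['img'].shape, out['gt'])  # torch.Size([3, 224, 224]) tensor([0.5000])
```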
utils/iqa_solver.py ADDED
@@ -0,0 +1,130 @@
+ import torch
+ from scipy import stats
+ import numpy as np
+ from models import monet as MoNet
+ from models import gc_loss as GC_Loss
+ from utils.dataset import data_loader
+ import json
+ import random
+ import os
+ from tqdm import tqdm
+
+
+ def get_data(dataset, data_path='./utils/dataset/dataset_info.json'):
+     with open(data_path, 'r') as data_info:
+         data_info = json.load(data_info)
+     path, img_num = data_info[dataset]
+     img_num = list(range(img_num))
+
+     random.shuffle(img_num)
+     train_index = img_num[0:int(round(0.8 * len(img_num)))]
+     test_index = img_num[int(round(0.8 * len(img_num))):len(img_num)]
+
+     return path, train_index, test_index
+
+
+ def cal_srocc_plcc(pred_score, gt_score):
+     srocc, _ = stats.spearmanr(pred_score, gt_score)
+     plcc, _ = stats.pearsonr(pred_score, gt_score)
+
+     return srocc, plcc
+
+
+ class Solver:
+     def __init__(self, config):
+
+         path, train_index, test_index = get_data(dataset=config.dataset)
+
+         train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
+         test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
+         self.train_data = train_loader.get_data()
+         self.test_data = test_loader.get_data()
+
+         print('Training data number: ', len(train_index))
+         print('Testing data number: ', len(test_index))
+
+         if config.loss == 'MAE':
+             self.loss = torch.nn.L1Loss().cuda()
+         elif config.loss == 'MSE':
+             self.loss = torch.nn.MSELoss().cuda()
+         elif config.loss == 'GC':
+             self.loss = GC_Loss.GC_Loss(queue_len=int(len(train_index) * config.queue_ratio))
+         else:
+             raise ValueError('Only MAE, MSE and GC loss are supported.')
+
+         print('Loading MoNet...')
+         self.MoNet = MoNet.MoNet(config).cuda()
+         self.MoNet.train(True)
+
+         self.epochs = config.epochs
+         self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr, weight_decay=config.weight_decay)
+         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=config.T_max, eta_min=config.eta_min)
+
+         self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')
+
+     def train(self):
+         """Training"""
+         best_srocc = 0.0
+         best_plcc = 0.0
+         print('----------------------------------')
+         print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
+         for t in range(self.epochs):
+             epoch_loss = []
+             pred_scores = []
+             gt_scores = []
+
+             for img, label in tqdm(self.train_data):
+                 img = img.cuda()
+                 label = label.view(-1).cuda()
+
+                 self.optimizer.zero_grad()
+
+                 pred = self.MoNet(img)  # predicted quality scores for the batch
+
+                 pred_scores = pred_scores + pred.cpu().tolist()
+                 gt_scores = gt_scores + label.cpu().tolist()
+
+                 loss = self.loss(pred.squeeze(), label.float().detach())
+                 epoch_loss.append(loss.item())
+
+                 loss.backward()
+                 self.optimizer.step()
+             self.scheduler.step()  # step the cosine schedule once per epoch (T_max is given in epochs)
+
+             train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)
+
+             test_srocc, test_plcc = self.test()
+             if test_srocc + test_plcc > best_srocc + best_plcc:
+                 best_srocc = test_srocc
+                 best_plcc = test_plcc
+                 torch.save(self.MoNet.state_dict(), self.model_save_path)
+                 print('Model saved in: ', self.model_save_path)
+
+             print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(np.mean(epoch_loss), 4), round(train_srocc, 4),
+                                                   round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4)))
+
+         print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))
+
+         return best_srocc, best_plcc
+
+     def test(self):
+         """Testing"""
+         self.MoNet.train(False)
+         pred_scores = []
+         gt_scores = []
+
+         with torch.no_grad():
+             for img, label in tqdm(self.test_data):
+                 # Data.
+                 img = img.cuda()
+                 label = label.view(-1).cuda()
+
+                 pred = self.MoNet(img)
+
+                 pred_scores = pred_scores + pred.cpu().tolist()
+                 gt_scores = gt_scores + label.cpu().tolist()
+
+         test_srocc, test_plcc = cal_srocc_plcc(pred_scores, gt_scores)
+
+         self.MoNet.train(True)
+         return test_srocc, test_plcc
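Solver is not wired to a training script in this commit. A hypothetical entry point (the field names follow models/monet.py's argument parser; loss and queue_ratio are the extra fields Solver reads, and the queue_ratio value below is an assumption) might look like this, run from the repository root so get_data() finds dataset_info.json:

```python
from types import SimpleNamespace

from utils.iqa_solver import Solver

config = SimpleNamespace(
    dataset='livec', batch_size=11, patch_size=224,   # data
    backbone='vit_base_patch8_224', mal_num=3,        # model
    loss='GC', queue_ratio=0.8,                       # GC queue sized as a fraction of the training set
    lr=1e-5, weight_decay=1e-5, epochs=50, T_max=50, eta_min=0,
    save_path='./training_for_IQA',
)
best_srocc, best_plcc = Solver(config).train()
```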
utils/log_writer.py ADDED
@@ -0,0 +1,14 @@
+ import sys
+
+ class Logger(object):
+     def __init__(self, filename="Default.log"):
+         self.terminal = sys.stdout
+         self.log = open(filename, "w")
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+         self.flush()
+
+     def flush(self):
+         self.log.flush()
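Logger mirrors everything written to stdout into a file; the intended use (a sketch, since nothing in this commit imports it yet) is to swap it in for sys.stdout:

```python
import sys

from utils.log_writer import Logger

sys.stdout = Logger('train.log')
print('this line goes to both the terminal and train.log')
```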