MoC-IQA
- app.py +45 -0
- models/__pycache__/WAIQT.cpython-38.pyc +0 -0
- models/__pycache__/gc_loss.cpython-37.pyc +0 -0
- models/__pycache__/gc_loss.cpython-38.pyc +0 -0
- models/__pycache__/monet.cpython-37.pyc +0 -0
- models/__pycache__/monet.cpython-38.pyc +0 -0
- models/gc_loss.py +187 -0
- models/monet.py +250 -0
- utils/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/__pycache__/iqa_solver.cpython-36.pyc +0 -0
- utils/__pycache__/iqa_solver.cpython-37.pyc +0 -0
- utils/__pycache__/iqa_solver.cpython-38.pyc +0 -0
- utils/__pycache__/log_writer.cpython-36.pyc +0 -0
- utils/__pycache__/log_writer.cpython-37.pyc +0 -0
- utils/__pycache__/log_writer.cpython-38.pyc +0 -0
- utils/__pycache__/process.cpython-38.pyc +0 -0
- utils/dataset/__pycache__/__init__.cpython-37.pyc +0 -0
- utils/dataset/__pycache__/data_loader.cpython-37.pyc +0 -0
- utils/dataset/__pycache__/data_loader.cpython-38.pyc +0 -0
- utils/dataset/__pycache__/folders.cpython-37.pyc +0 -0
- utils/dataset/__pycache__/folders.cpython-38.pyc +0 -0
- utils/dataset/__pycache__/process.cpython-38.pyc +0 -0
- utils/dataset/data_loader.py +36 -0
- utils/dataset/dataset_info.json +6 -0
- utils/dataset/folders.py +207 -0
- utils/dataset/process.py +57 -0
- utils/iqa_solver.py +130 -0
- utils/log_writer.py +14 -0
app.py
ADDED
@@ -0,0 +1,45 @@
import torch
import torchvision

import cv2
import numpy as np
from models import monet as MoNet
import argparse
from utils.dataset.process import ToTensor, Normalize
import gradio as gr


def load_image(img_path):
    # Gradio hands predict() a PIL image, so convert it to an OpenCV BGR array first.
    d_img = cv2.cvtColor(np.asarray(img_path), cv2.COLOR_RGB2BGR)
    # d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    d_img = np.array(d_img).astype('float32') / 255
    d_img = np.transpose(d_img, (2, 0, 1))

    return d_img


def predict(image):
    parser = argparse.ArgumentParser()
    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')
    config = parser.parse_args()

    model = MoNet.MoNet(config).cuda()
    model.load_state_dict(torch.load('./checkpoints/best_model.pkl'))
    model.eval()

    trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

    # Run a single prediction on the model.
    img = load_image(image)
    img_tensor = trans(img).unsqueeze(0).cuda()
    iq = model(img_tensor).cpu().detach().numpy().tolist()[0]

    return "The image quality of the image is: {}".format(round(iq, 4))


# os.system("wget -O ./checkpoints/best_model.pkl https://huggingface.co/Zevin2023/MoC-IQA/resolve/main/Koniq10K_570908.pkl")

interface = gr.Interface(fn=predict, inputs="image", outputs="text")
interface.launch(server_name='127.0.0.1', server_port=8088)
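For a quick local smoke test, predict() can also be called directly, bypassing the Gradio UI. A minimal sketch, assuming a CUDA device, an existing ./checkpoints/best_model.pkl, and a hypothetical sample.jpg:

from PIL import Image

img = Image.open('sample.jpg').convert('RGB')  # sample.jpg is a placeholder path
print(predict(img))  # prints e.g. "The image quality of the image is: 0.5321"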
models/__pycache__/WAIQT.cpython-38.pyc
ADDED
Binary file (6.91 kB).
models/__pycache__/gc_loss.cpython-37.pyc
ADDED
Binary file (5.34 kB).
models/__pycache__/gc_loss.cpython-38.pyc
ADDED
Binary file (3.11 kB).
models/__pycache__/monet.cpython-37.pyc
ADDED
Binary file (8.07 kB).
models/__pycache__/monet.cpython-38.pyc
ADDED
Binary file (8.1 kB).
models/gc_loss.py
ADDED
@@ -0,0 +1,187 @@
import torch.nn as nn
import torch
import numpy as np

# class GC_Loss(nn.Module):
#     def __init__(self, queue_len=800):
#         super(GC_Loss, self).__init__()
#         self.pred_queue = list()
#         self.gt_queue = list()
#         self.queue_len = 0

#         self.queue_max_len = queue_len

#         print('CCWD Length: ', queue_len)

#         self.l1_loss = torch.nn.L1Loss().cuda()
#         self.l2_loss = torch.nn.MSELoss().cuda()

#     def enqueue(self, pred, gt):
#         bs = pred.shape[0]
#         self.queue_len = self.queue_len + bs

#         self.pred_queue = self.pred_queue + pred.cpu().detach().numpy().tolist()
#         self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()

#         if self.queue_len > self.queue_max_len:
#             self.dequeue(self.queue_len - self.queue_max_len)
#             self.queue_len = self.queue_max_len

#     def dequeue(self, n):
#         for index in range(n):
#             self.pred_queue.pop(0)
#             self.gt_queue.pop(0)

#     def clear(self):
#         self.pred_queue.clear()
#         self.gt_queue.clear()

#     def forward(self, x, y):
#         x_queue = self.pred_queue.copy()
#         y_queue = self.gt_queue.copy()

#         # Gather all values in the queue
#         x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
#         y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)

#         # Estimate mean and standard deviation
#         x_bar = torch.mean(x_all, dim=0)
#         x_std = torch.std(x_all, dim=0)

#         y_bar = torch.mean(y_all, dim=0)
#         y_std = torch.std(y_all, dim=0)

#         # Estimate the PLCC of the predictions within the overall values
#         diff_x_plcc = (x - x_bar)  # [bs, 1]
#         diff_y_plcc = (y - y_bar)  # [bs, 1]

#         x1 = torch.sum(torch.mul(diff_x_plcc, diff_y_plcc))
#         x2_1 = torch.sqrt(torch.sum(torch.mul(diff_x_plcc, diff_x_plcc)))
#         x2_2 = torch.sqrt(torch.sum(torch.mul(diff_y_plcc, diff_y_plcc)))

#         # Standardize all values
#         diff_x = (x_all - x_bar) / x_std  # [bs, 1]
#         diff_y = (y_all - y_bar) / y_std  # [bs, 1]

#         rank_x = diff_x.reshape(-1, 1)
#         rank_y = diff_y.reshape(-1, 1)

#         rank_x = rank_x - rank_x.transpose(1, 0)
#         rank_y = rank_y - rank_y.transpose(1, 0)

#         # Estimate ranks over all values
#         rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x)), dim=1)
#         rank_y = torch.sum(1 / 2 * (1 + torch.erf(rank_y)), dim=1)

#         # Mean and standard deviation of the ranks
#         rank_x_bar = torch.mean(rank_x, dim=0)
#         rank_x_std = torch.std(rank_x, dim=0)
#         rank_y_bar = torch.mean(rank_y, dim=0)
#         rank_y_std = torch.std(rank_y, dim=0)

#         # Estimate the SROCC of the predictions within the overall values
#         rank_x_ = (x - rank_x_bar) / rank_x_std  # [bs, 1]
#         rank_y_ = (y - rank_y_bar) / rank_y_std  # [bs, 1]

#         x1_rank = torch.sum(torch.mul(rank_x_, rank_y_))
#         x2_1_rank = torch.sqrt(torch.sum(torch.mul(rank_x_, rank_x_)))
#         x2_2_rank = torch.sqrt(torch.sum(torch.mul(rank_y_, rank_y_)))

#         self.enqueue(x, y)

#         return (0.5 * ((1 - x1 / (x2_1 * x2_2)) + (1 - (x1_rank / (x2_1_rank * x2_2_rank)))) + 1) * self.l2_loss(x, y)

class GC_Loss(nn.Module):
    def __init__(self, queue_len=800, alpha=0.5, beta=0.5, gamma=1):
        super(GC_Loss, self).__init__()
        self.pred_queue = list()
        self.gt_queue = list()
        self.queue_len = 0

        self.queue_max_len = queue_len
        print('The queue length is: ', self.queue_max_len)
        self.mse = torch.nn.MSELoss().cuda()

        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def consistency(self, pred_data, gt_data):
        pred_one_batch, pred_queue = pred_data
        gt_one_batch, gt_queue = gt_data

        pred_mean = torch.mean(pred_queue)
        gt_mean = torch.mean(gt_queue)

        diff_pred = pred_one_batch - pred_mean
        diff_gt = gt_one_batch - gt_mean

        x1 = torch.sum(torch.mul(diff_pred, diff_gt))
        x2_1 = torch.sqrt(torch.sum(torch.mul(diff_pred, diff_pred)))
        x2_2 = torch.sqrt(torch.sum(torch.mul(diff_gt, diff_gt)))

        return x1 / (x2_1 * x2_2)

    def ppra(self, x):
        """
        Pairwise Preference-based Rank Approximation
        """

        x_bar, x_std = torch.mean(x), torch.std(x)
        x_n = (x - x_bar) / x_std
        x_n_T = x_n.reshape(-1, 1)

        rank_x = x_n_T - x_n_T.transpose(1, 0)
        rank_x = torch.sum(1 / 2 * (1 + torch.erf(rank_x / torch.sqrt(torch.tensor(2, dtype=torch.float)))), dim=1)

        return rank_x

    @torch.no_grad()
    def enqueue(self, pred, gt):
        bs = pred.shape[0]
        self.queue_len = self.queue_len + bs

        self.pred_queue = self.pred_queue + pred.tolist()
        self.gt_queue = self.gt_queue + gt.cpu().detach().numpy().tolist()

        if self.queue_len > self.queue_max_len:
            self.dequeue(self.queue_len - self.queue_max_len)
            self.queue_len = self.queue_max_len

    @torch.no_grad()
    def dequeue(self, n):
        for _ in range(n):
            self.pred_queue.pop(0)
            self.gt_queue.pop(0)

    def clear(self):
        self.pred_queue.clear()
        self.gt_queue.clear()

    def forward(self, x, y):
        x_queue = self.pred_queue.copy()
        y_queue = self.gt_queue.copy()

        x_all = torch.cat((x, torch.tensor(x_queue).cuda()), dim=0)
        y_all = torch.cat((y, torch.tensor(y_queue).cuda()), dim=0)

        PLCC = self.consistency((x, x_all), (y, y_all))
        PGC = 1 - PLCC

        rank_x = self.ppra(x_all)
        rank_y = self.ppra(y_all)
        SROCC = self.consistency((rank_x[:x.shape[0]], rank_x), (rank_y[:y.shape[0]], rank_y))
        SGC = 1 - SROCC

        GC = (self.alpha * PGC + self.beta * SGC + self.gamma) * self.mse(x, y)
        self.enqueue(x, y)

        return GC


if __name__ == '__main__':
    gc = GC_Loss().cuda()
    x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float).cuda()
    y = torch.tensor([6, 7, 8, 9, 15], dtype=torch.float).cuda()

    res = gc(x, y)

    print(res)
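Reading the forward pass above as a formula: with the batch predictions x and labels y correlated against the batch-plus-queue pool, the loss is

$$\mathrm{GC}(x, y) = \left(\alpha\,(1 - \mathrm{PLCC}) + \beta\,(1 - \mathrm{SROCC}) + \gamma\right) \cdot \mathrm{MSE}(x, y),$$

where SROCC is computed on the differentiable pairwise rank approximation from ppra(),

$$\widehat{\mathrm{rank}}(x_i) = \sum_j \Phi\!\left(\hat{x}_i - \hat{x}_j\right), \qquad \Phi(z) = \tfrac{1}{2}\left(1 + \operatorname{erf}\!\left(z/\sqrt{2}\right)\right), \qquad \hat{x} = \frac{x - \bar{x}}{\sigma_x}.$$

This is a reading of the code above, not a formula quoted from the MoC-IQA paper.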
models/monet.py
ADDED
@@ -0,0 +1,250 @@
"""
The complete implementation of the Mean-opinion Network (MoNet).
"""
import torch
import torch.nn as nn
import timm

from timm.models.vision_transformer import Block
from einops import rearrange


class Attention_Block(nn.Module):
    def __init__(self, dim, drop=0.1):
        super().__init__()
        self.c_q = nn.Linear(dim, dim)
        self.c_k = nn.Linear(dim, dim)
        self.c_v = nn.Linear(dim, dim)
        self.norm_fact = dim ** -0.5
        self.softmax = nn.Softmax(dim=-1)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        _x = x
        B, C, N = x.shape
        q = self.c_q(x)
        k = self.c_k(x)
        v = self.c_v(x)

        attn = q @ k.transpose(-2, -1) * self.norm_fact
        attn = self.softmax(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, C, N)
        x = self.proj_drop(x)
        x = x + _x  # residual connection
        return x


class Self_Attention(nn.Module):
    """ Self attention Layer"""

    def __init__(self, in_dim):
        super(Self_Attention, self).__init__()

        self.qConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.kConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.vConv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # learnable residual weight

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inFeature):
        bs, C, w, h = inFeature.size()

        proj_query = self.qConv(inFeature).view(bs, -1, w * h).permute(0, 2, 1)
        proj_key = self.kConv(inFeature).view(bs, -1, w * h)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        proj_value = self.vConv(inFeature).view(bs, -1, w * h)

        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(bs, C, w, h)

        out = self.gamma * out + inFeature

        return out


class MAL(nn.Module):
    """
    Multi-view Attention Learning (MAL) module
    """

    def __init__(self, in_dim=768, feature_num=4, feature_size=28):
        super().__init__()

        self.channel_attention = Attention_Block(in_dim * feature_num)  # Channel-wise self attention
        self.feature_attention = Attention_Block(feature_size ** 2 * feature_num)  # Pixel-wise self attention

        # Self attention module for each input feature
        self.attention_module = nn.ModuleList()
        for _ in range(feature_num):
            self.attention_module.append(Self_Attention(in_dim))

        self.feature_num = feature_num
        self.in_dim = in_dim

    def forward(self, features):
        feature = torch.tensor([]).cuda()
        for index, _ in enumerate(features):
            feature = torch.cat((feature, self.attention_module[index](features[index]).unsqueeze(0)), dim=0)
        features = feature

        input_tensor = rearrange(features, 'n b c w h -> b (n c) (w h)')  # bs, 768 * feature_num, 28 * 28
        bs, _, _ = input_tensor.shape  # [2, 3072, 784]

        in_feature = rearrange(input_tensor, 'b (w c) h -> b w (c h)', w=self.in_dim, c=self.feature_num)  # bs, 768, 28 * 28 * feature_num
        feature_weight_sum = self.feature_attention(in_feature)  # bs, 768, 28 * 28 * feature_num

        in_channel = input_tensor.permute(0, 2, 1)  # bs, 28 * 28, 768 * feature_num
        channel_weight_sum = self.channel_attention(in_channel)  # bs, 28 * 28, 768 * feature_num

        weight_sum_res = (rearrange(feature_weight_sum, 'b w (c h) -> b (w c) h', w=self.in_dim,
                                    c=self.feature_num) + channel_weight_sum.permute(0, 2, 1)) / 2  # [2, 3072, 784]

        weight_sum_res = torch.mean(weight_sum_res.view(bs, self.feature_num, self.in_dim, -1), dim=1)

        return weight_sum_res  # bs, 768, 28 * 28


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []


class MoNet(nn.Module):
    def __init__(self, config, patch_size=8, drop=0.1, dim_mlp=768, img_size=224):
        super().__init__()
        self.img_size = img_size
        self.input_size = img_size // patch_size
        self.dim_mlp = dim_mlp

        self.vit = timm.create_model(config.backbone, pretrained=False)
        self.save_output = SaveOutput()

        # Register hooks so every transformer block's output is recorded
        hook_handles = []
        for layer in self.vit.modules():
            if isinstance(layer, Block):
                handle = layer.register_forward_hook(self.save_output)
                hook_handles.append(handle)

        self.MALs = nn.ModuleList()
        for _ in range(config.mal_num):
            self.MALs.append(MAL())

        # Image Quality Score Regression
        self.fusion_wam = MAL(feature_num=config.mal_num)
        self.block = Block(dim_mlp, 12)
        self.cnn = nn.Sequential(
            nn.Conv2d(dim_mlp, 256, 5),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(256, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((2, 2)),
            nn.Conv2d(128, 128, 3),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AvgPool2d((3, 3)),
        )
        self.fc_score = nn.Sequential(
            nn.Linear(128, 128 // 2),
            nn.ReLU(),
            nn.Dropout(drop),
            nn.Linear(128 // 2, 1),
            nn.Sigmoid()  # scores land in [0, 1], matching the min-max normalized MOS labels
        )

    def extract_feature(self, save_output, block_index=[2, 5, 8, 11]):
        # Take the patch tokens (dropping the class token) from four transformer blocks
        x1 = save_output.outputs[block_index[0]][:, 1:]
        x2 = save_output.outputs[block_index[1]][:, 1:]
        x3 = save_output.outputs[block_index[2]][:, 1:]
        x4 = save_output.outputs[block_index[3]][:, 1:]
        x = torch.cat((x1, x2, x3, x4), dim=2)
        return x

    def forward(self, x):
        # Multi-level Feature From Different Transformer Blocks
        _x = self.vit(x)  # the ViT forward pass is only needed to trigger the hooks
        x = self.extract_feature(self.save_output)  # bs, 28 * 28, 768 * 4
        self.save_output.outputs.clear()

        x = x.permute(0, 2, 1)  # bs, 768 * 4, 28 * 28
        x = rearrange(x, 'b (d n) (w h) -> b d n w h', d=4, n=self.dim_mlp, w=self.input_size, h=self.input_size)  # bs, 4, 768, 28, 28
        x = x.permute(1, 0, 2, 3, 4)  # 4, bs, 768, 28, 28

        # Different Opinion Features (DOF)
        DOF = torch.tensor([]).cuda()
        for index, _ in enumerate(self.MALs):
            DOF = torch.cat((DOF, self.MALs[index](x).unsqueeze(0)), dim=0)
        DOF = rearrange(DOF, 'n c d (w h) -> n c d w h', w=self.input_size, h=self.input_size)  # 3, bs, 768, 28, 28

        # Image Quality Score Regression
        wam = self.fusion_wam(DOF).permute(0, 2, 1)  # bs, 28 * 28, 768
        wam = self.block(wam).permute(0, 2, 1)
        wam = rearrange(wam, 'c d (w h) -> c d w h', w=self.input_size, h=self.input_size)
        score = self.cnn(wam).squeeze(-1).squeeze(-1)
        score = self.fc_score(score).view(-1)

        return score


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', dest='seed', type=int, default=3407)
    parser.add_argument('--gpu_id', dest='gpu_id', type=str, default='0')

    # model related
    parser.add_argument('--backbone', dest='backbone', type=str, default='vit_base_patch8_224',
                        help='The backbone for MoNet.')
    parser.add_argument('--mal_num', dest='mal_num', type=int, default=3, help='The number of the MAL modules.')

    # data related
    parser.add_argument('--dataset', dest='dataset', type=str, default='livec',
                        help='Support datasets: livec|koniq10k|bid|spaq')
    parser.add_argument('--train_patch_num', dest='train_patch_num', type=int, default=5,
                        help='Number of sample patches from training image')
    parser.add_argument('--test_patch_num', dest='test_patch_num', type=int, default=25,
                        help='Number of sample patches from testing image')
    parser.add_argument('--patch_size', dest='patch_size', type=int, default=224,
                        help='Crop size for training & testing image patches')

    # training related
    parser.add_argument('--lr', dest='lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float, default=1e-5, help='Weight decay')
    parser.add_argument('--batch_size', dest='batch_size', type=int, default=11, help='Batch size')
    parser.add_argument('--epochs', dest='epochs', type=int, default=50, help='Epochs for training')
    parser.add_argument('--T_max', dest='T_max', type=int, default=50, help='Hyper-parameter for CosineAnnealingLR')
    parser.add_argument('--eta_min', dest='eta_min', type=int, default=0, help='Hyper-parameter for CosineAnnealingLR')

    parser.add_argument('--save_path', dest='save_path', type=str, default='./training_for_IQA',
                        help='The path where the model and logs will be saved.')

    config = parser.parse_args()

    # torch.autograd.set_detect_anomaly(True)
    # with torch.autograd.detect_anomaly():
    in_tensor = torch.zeros((2, 3, 224, 224), dtype=torch.float).cuda()
    model = MoNet(config).cuda()
    res = model(in_tensor)

    print('{} : {} [M]'.format('#Params', sum(map(lambda x: x.numel(), model.parameters())) / 10 ** 6))

    # label = torch.tensor([1, 2], dtype=torch.float).cuda()
    # loss = torch.nn.L1Loss().cuda()
    #
    # res = model(in_tensor)
    # # loss = loss_func()
    # l = loss(label, res)
    # print(l)
    # l.backward()
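MoNet only reads config.backbone and config.mal_num, so it can be instantiated without the CLI parser. A minimal sketch (SimpleNamespace stands in for the argparse config; a CUDA device is assumed, since the model hard-codes .cuda() internally):

from types import SimpleNamespace

import torch

from models import monet as MoNet

config = SimpleNamespace(backbone='vit_base_patch8_224', mal_num=3)
model = MoNet.MoNet(config).cuda().eval()
with torch.no_grad():
    score = model(torch.rand(1, 3, 224, 224).cuda())  # one normalized 224x224 RGB image
print(score.item())  # quality score in [0, 1]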
utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (258 Bytes).
utils/__pycache__/iqa_solver.cpython-36.pyc
ADDED
Binary file (3.73 kB).
utils/__pycache__/iqa_solver.cpython-37.pyc
ADDED
Binary file (3.81 kB).
utils/__pycache__/iqa_solver.cpython-38.pyc
ADDED
Binary file (3.64 kB).
utils/__pycache__/log_writer.cpython-36.pyc
ADDED
Binary file (799 Bytes).
utils/__pycache__/log_writer.cpython-37.pyc
ADDED
Binary file (809 Bytes).
utils/__pycache__/log_writer.cpython-38.pyc
ADDED
Binary file (825 Bytes).
utils/__pycache__/process.cpython-38.pyc
ADDED
Binary file (1.78 kB).
utils/dataset/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (266 Bytes).
utils/dataset/__pycache__/data_loader.cpython-37.pyc
ADDED
Binary file (1.8 kB).
utils/dataset/__pycache__/data_loader.cpython-38.pyc
ADDED
Binary file (1.44 kB).
utils/dataset/__pycache__/folders.cpython-37.pyc
ADDED
Binary file (6.02 kB).
utils/dataset/__pycache__/folders.cpython-38.pyc
ADDED
Binary file (5.75 kB).
utils/dataset/__pycache__/process.cpython-38.pyc
ADDED
Binary file (1.93 kB).
utils/dataset/data_loader.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torchvision

from utils.dataset import folders
from utils.dataset.process import ToTensor, Normalize, RandHorizontalFlip


class Data_Loader(object):
    """Dataset class for IQA databases"""

    def __init__(self, config, path, img_indx, istrain=True):

        self.batch_size = config.batch_size
        self.istrain = istrain
        dataset = config.dataset
        patch_size = config.patch_size

        # Train transforms
        if istrain:
            transforms = torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()])
        else:
            transforms = torchvision.transforms.Compose([Normalize(0.5, 0.5), ToTensor()])

        if dataset == 'livec':
            self.data = folders.LIVEC(root=path, index=img_indx, transform=transforms)
        elif dataset == 'koniq10k':
            self.data = folders.Koniq10k(root=path, index=img_indx, transform=transforms)
        elif dataset == 'bid':
            self.data = folders.BID(root=path, index=img_indx, transform=transforms)
        elif dataset == 'spaq':
            self.data = folders.SPAQ(root=path, index=img_indx, transform=transforms)
        else:
            raise Exception("Only support livec, koniq10k, bid, spaq.")

    def get_data(self):
        dataloader = torch.utils.data.DataLoader(self.data, batch_size=self.batch_size, shuffle=self.istrain, num_workers=8)
        return dataloader
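Data_Loader only reads batch_size, dataset, and patch_size from the config, so building a training loader can look like the following sketch (the KonIQ-10k path and the index split are hypothetical placeholders):

from types import SimpleNamespace

from utils.dataset.data_loader import Data_Loader

config = SimpleNamespace(batch_size=11, dataset='koniq10k', patch_size=224)
train_loader = Data_Loader(config, '/path/to/koniq-10k', list(range(8000)), istrain=True).get_data()
for img, gt in train_loader:
    break  # img: [bs, 3, 224, 224] FloatTensor, gt: [bs, 1] FloatTensor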
utils/dataset/dataset_info.json
ADDED
@@ -0,0 +1,6 @@
{
    "livec": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/ChallengeDB_release", 1162],
    "koniq10k": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/koniq-10k", 10073],
    "bid": ["/disk1/chenzewen/sciResLife/ALotOfDataset/IQA/BID/ImageDatabase", 586],
    "spaq": ["/home/ssl/Database/ChallengeDB_release/ChallengeDB_release/", 11125]
}
utils/dataset/folders.py
ADDED
@@ -0,0 +1,207 @@
import torch.utils.data as data
import torch

from PIL import Image
import os
import scipy.io
import numpy as np
import csv
from openpyxl import load_workbook
import cv2


class LIVEC(data.Dataset):
    def __init__(self, root, index, transform):
        imgpath = scipy.io.loadmat(os.path.join(root, 'Data', 'AllImages_release.mat'))
        imgpath = imgpath['AllImages_release']
        imgpath = imgpath[7:1169]
        mos = scipy.io.loadmat(os.path.join(root, 'Data', 'AllMOS_release.mat'))
        labels = mos['AllMOS_release'].astype(np.float32)
        labels = labels[0][7:1169]

        sample, gt = [], []
        for i, item in enumerate(index):
            sample.append(os.path.join(root, 'Images', imgpath[item][0][0]))
            gt.append(labels[item])
        gt = normalization(gt)

        self.samples, self.gt = sample, gt
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)

        return img_tensor, gt_tensor

    def __len__(self):
        length = len(self.samples)
        return length


class Koniq10k(data.Dataset):
    def __init__(self, root, index, transform):
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, 'koniq10k_distributions_sets.csv')
        with open(csv_file) as f:
            reader = csv.DictReader(f)
            for row in reader:
                imgname.append(row['image_name'])
                mos = np.array(float(row['MOS'])).astype(np.float32)
                mos_all.append(mos)

        sample, gt = [], []
        for i, item in enumerate(index):
            sample.append(os.path.join(root, '1024x768', imgname[item]))
            gt.append(mos_all[item])
        gt = normalization(gt)

        self.samples, self.gt = sample, gt
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)

        return img_tensor, gt_tensor

    def __len__(self):
        length = len(self.samples)
        return length


class SPAQ(data.Dataset):
    def __init__(self, root, index, transform):
        imgname = []
        mos_all = []
        csv_file = os.path.join(root, 'koniq10k_scores_and_distributions.csv')
        with open(csv_file) as f:
            reader = csv.DictReader(f)
            for row in reader:
                imgname.append(row['image_name'])
                mos = np.array(float(row['MOS_zscore'])).astype(np.float32)
                mos_all.append(mos)

        sample, gt = [], []
        for i, item in enumerate(index):
            sample.append(os.path.join(root, '1024x768', imgname[item]))
            gt.append(mos_all[item])
        gt = normalization(gt)

        self.samples, self.gt = sample, gt
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)

        return img_tensor, gt_tensor

    def __len__(self):
        length = len(self.samples)
        return length


class BID(data.Dataset):
    def __init__(self, root, index, transform):

        imgname = []
        mos_all = []

        xls_file = os.path.join(root, 'DatabaseGrades.xlsx')
        workbook = load_workbook(xls_file)
        booksheet = workbook.active
        rows = booksheet.rows
        count = 1
        for row in rows:
            count += 1
            img_num = booksheet.cell(row=count, column=1).value
            img_name = "DatabaseImage%04d.JPG" % (img_num)
            imgname.append(img_name)
            mos = booksheet.cell(row=count, column=2).value
            mos = np.array(mos)
            mos = mos.astype(np.float32)
            mos_all.append(mos)
            if count == 587:
                break

        sample, gt = [], []
        for i, item in enumerate(index):
            sample.append(os.path.join(root, imgname[item]))
            gt.append(mos_all[item])
        gt = normalization(gt)

        self.samples, self.gt = sample, gt
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        img_tensor, gt_tensor = get_item(self.samples, self.gt, index, self.transform)

        return img_tensor, gt_tensor

    def __len__(self):
        length = len(self.samples)
        return length


def get_item(samples, gt, index, transform):
    path, target = samples[index], gt[index]
    sample = load_image(path)
    samples = {'img': sample, 'gt': target}
    samples = transform(samples)

    return samples['img'], samples['gt'].type(torch.FloatTensor)


def getFileName(path, suffix):
    filename = []
    f_list = os.listdir(path)
    for i in f_list:
        if os.path.splitext(i)[1] == suffix:
            filename.append(i)
    return filename


def load_image(img_path):
    d_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    d_img = cv2.resize(d_img, (224, 224), interpolation=cv2.INTER_CUBIC)
    d_img = cv2.cvtColor(d_img, cv2.COLOR_BGR2RGB)
    d_img = np.array(d_img).astype('float32') / 255
    d_img = np.transpose(d_img, (2, 0, 1))

    return d_img


def normalization(data):
    data = np.array(data)
    data_range = np.max(data) - np.min(data)
    data = (data - np.min(data)) / data_range
    data = list(data.astype('float').reshape(-1, 1))

    return data
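All four dataset classes pass their MOS values through normalization(), which min-max scales them into [0, 1]; this matches the Sigmoid regression head in models/monet.py. A quick sanity check (sketch):

from utils.dataset.folders import normalization

print(normalization([10.0, 30.0, 50.0]))
# [array([0.]), array([0.5]), array([1.])]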
utils/dataset/process.py
ADDED
@@ -0,0 +1,57 @@
import torch
import numpy as np


class Normalize(object):
    def __init__(self, mean, var):
        self.mean = mean
        self.var = var

    def __call__(self, sample):
        if isinstance(sample, dict):
            img = sample['img']
            gt = sample['gt']
            img = (img - self.mean) / self.var
            sample = {'img': img, 'gt': gt}
        else:
            sample = (sample - self.mean) / self.var

        return sample


class RandHorizontalFlip(object):
    def __init__(self, prob_aug):
        self.prob_aug = prob_aug

    def __call__(self, sample):
        p_aug = np.array([self.prob_aug, 1 - self.prob_aug])
        prob_lr = np.random.choice([1, 0], p=p_aug.ravel())

        if isinstance(sample, dict):
            img = sample['img']
            gt = sample['gt']

            if prob_lr > 0.5:
                img = np.fliplr(img).copy()
            sample = {'img': img, 'gt': gt}
        else:
            if prob_lr > 0.5:
                sample = np.fliplr(sample).copy()
        return sample


class ToTensor(object):
    def __init__(self):
        pass

    def __call__(self, sample):
        if isinstance(sample, dict):
            img = sample['img']
            gt = sample['gt']
            img = torch.from_numpy(img).type(torch.FloatTensor)
            gt = torch.from_numpy(gt).type(torch.FloatTensor)
            sample = {'img': img, 'gt': gt}
        else:
            sample = torch.from_numpy(sample).type(torch.FloatTensor)
        return sample
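These transforms accept either a raw array or a dict sample of the form {'img': CHW float array, 'gt': float array}, which is what folders.get_item feeds them. A minimal sketch of composing them by hand:

import numpy as np
import torchvision

from utils.dataset.process import Normalize, RandHorizontalFlip, ToTensor

trans = torchvision.transforms.Compose([Normalize(0.5, 0.5), RandHorizontalFlip(prob_aug=0.5), ToTensor()])
sample = {'img': np.random.rand(3, 224, 224).astype('float32'), 'gt': np.array([0.5], dtype='float32')}
out = trans(sample)  # out['img'] and out['gt'] are now FloatTensors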
utils/iqa_solver.py
ADDED
@@ -0,0 +1,130 @@
import torch
from scipy import stats
import numpy as np
from models import monet as MoNet
from models import gc_loss as GC_Loss
from utils.dataset import data_loader
import json
import random
import os
from tqdm import tqdm


def get_data(dataset, data_path='./utils/dataset/dataset_info.json'):
    with open(data_path, 'r') as data_info:
        data_info = json.load(data_info)
    path, img_num = data_info[dataset]
    img_num = list(range(img_num))

    # Random 80/20 train/test split over the image indices
    random.shuffle(img_num)
    train_index = img_num[0:int(round(0.8 * len(img_num)))]
    test_index = img_num[int(round(0.8 * len(img_num))):len(img_num)]

    return path, train_index, test_index


def cal_srocc_plcc(pred_score, gt_score):
    srocc, _ = stats.spearmanr(pred_score, gt_score)
    plcc, _ = stats.pearsonr(pred_score, gt_score)

    return srocc, plcc


class Solver:
    def __init__(self, config):

        path, train_index, test_index = get_data(dataset=config.dataset)

        train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
        test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()

        print('Training data number: ', len(train_index))
        print('Testing data number: ', len(test_index))

        if config.loss == 'MAE':
            self.loss = torch.nn.L1Loss().cuda()
        elif config.loss == 'MSE':
            self.loss = torch.nn.MSELoss().cuda()
        elif config.loss == 'GC':
            self.loss = GC_Loss.GC_Loss(queue_len=int(len(train_index) * config.queue_ratio))
        else:
            raise ValueError('Only support MAE, MSE and GC loss.')

        print('Loading MoNet...')
        self.MoNet = MoNet.MoNet(config).cuda()
        self.MoNet.train(True)

        self.epochs = config.epochs
        self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr, weight_decay=config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=config.T_max, eta_min=config.eta_min)

        self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')

    def train(self):
        """Training"""
        best_srocc = 0.0
        best_plcc = 0.0
        print('----------------------------------')
        print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
        for t in range(self.epochs):
            epoch_loss = []
            pred_scores = []
            gt_scores = []

            for img, label in tqdm(self.train_data):
                img = img.cuda()
                label = label.view(-1).cuda()

                self.optimizer.zero_grad()

                pred = self.MoNet(img)

                pred_scores = pred_scores + pred.cpu().tolist()
                gt_scores = gt_scores + label.cpu().tolist()

                loss = self.loss(pred.squeeze(), label.float().detach())
                epoch_loss.append(loss.item())

                loss.backward()
                self.optimizer.step()
            self.scheduler.step()

            train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)

            test_srocc, test_plcc = self.test()
            if test_srocc + test_plcc > best_srocc + best_plcc:
                best_srocc = test_srocc
                best_plcc = test_plcc
                torch.save(self.MoNet.state_dict(), self.model_save_path)
                print('Model saved in: ', self.model_save_path)

            print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(np.mean(epoch_loss), 4), round(train_srocc, 4),
                                                  round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4)))

        print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))

        return best_srocc, best_plcc

    def test(self):
        """Testing"""
        self.MoNet.train(False)
        pred_scores = []
        gt_scores = []

        with torch.no_grad():
            for img, label in tqdm(self.test_data):
                # Data.
                img = img.cuda()
                label = label.view(-1).cuda()

                pred = self.MoNet(img)

                pred_scores = pred_scores + pred.cpu().tolist()
                gt_scores = gt_scores + label.cpu().tolist()

        test_srocc, test_plcc = cal_srocc_plcc(pred_scores, gt_scores)

        self.MoNet.train(True)
        return test_srocc, test_plcc
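A hypothetical training driver for Solver: the fields below mirror the argparse defaults in models/monet.py, plus the two loss-related fields the constructor reads (loss and queue_ratio; the 0.8 ratio is an illustrative value, not one taken from this commit). It assumes the dataset path in utils/dataset/dataset_info.json exists:

from types import SimpleNamespace

from utils.iqa_solver import Solver

config = SimpleNamespace(
    dataset='koniq10k', batch_size=11, patch_size=224,
    backbone='vit_base_patch8_224', mal_num=3,
    loss='GC', queue_ratio=0.8,
    lr=1e-5, weight_decay=1e-5, T_max=50, eta_min=0, epochs=50,
    save_path='./training_for_IQA',
)

best_srocc, best_plcc = Solver(config).train()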
utils/log_writer.py
ADDED
@@ -0,0 +1,14 @@
import sys


class Logger(object):
    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        self.log.flush()
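Logger tees everything written to stdout into a file while still echoing it to the terminal. Typical use (sketch; train.log is a placeholder filename):

import sys

from utils.log_writer import Logger

sys.stdout = Logger('train.log')
print('this line goes to both the terminal and train.log')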