GMC-IQA / utils /iqa_solver.py
Zevin2023's picture
MoC-IQA
07e1105
raw
history blame contribute delete
No virus
4.58 kB
import torch
from scipy import stats
import numpy as np
from models import monet as MoNet
from models import gc_loss as GC_Loss
from utils.dataset import data_loader
import json
import random
import os
from tqdm import tqdm
def get_data(dataset, data_path='./utils/dataset/dataset_info.json', train_ratio=0.8):
    """Look up ``dataset`` in a JSON index and randomly split its image indices.

    The JSON index maps each dataset name to a ``[path, img_num]`` pair. The
    indices ``0 .. img_num - 1`` are shuffled with the global ``random`` module
    (so ``random.seed`` controls the split). The first
    ``round(train_ratio * img_num)`` shuffled indices form the train split and
    the rest form the test split.

    Args:
        dataset: key of the dataset in the JSON index.
        data_path: path to the JSON index file.
        train_ratio: fraction of images assigned to the train split. The
            default 0.8 reproduces the original fixed 80/20 split.

    Returns:
        tuple: ``(path, train_index, test_index)``, where the two index lists
        are disjoint and together cover ``range(img_num)``.

    Raises:
        ValueError: if ``train_ratio`` is outside ``[0, 1]``.
        KeyError: if ``dataset`` is not present in the index.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError('train_ratio must be in [0, 1], got {!r}'.format(train_ratio))
    with open(data_path, 'r', encoding='utf-8') as info_file:
        data_info = json.load(info_file)
    path, img_num = data_info[dataset]
    indices = list(range(img_num))
    random.shuffle(indices)
    # Compute the split point once so both slices agree on it.
    split = int(round(train_ratio * len(indices)))
    return path, indices[:split], indices[split:]
def cal_srocc_plcc(pred_score, gt_score):
    """Return the Spearman (SROCC) and Pearson (PLCC) correlations of two score sets.

    Both inputs are flattened to 1-D float arrays first. This way, predictions
    collected as nested lists (e.g. ``tensor.tolist()`` on an ``(N, 1)``
    tensor) are paired element-wise with flat ground-truth scores, instead of
    being treated as a 2-D input by scipy.

    Args:
        pred_score: predicted quality scores (array-like holding N numbers).
        gt_score: ground-truth quality scores (array-like holding N numbers).

    Returns:
        tuple: ``(srocc, plcc)`` as Python floats. scipy yields NaN (with a
        warning) when either input is constant.

    Raises:
        ValueError: if the flattened inputs have different lengths.
    """
    pred = np.asarray(pred_score, dtype=np.float64).ravel()
    gt = np.asarray(gt_score, dtype=np.float64).ravel()
    if pred.shape != gt.shape:
        raise ValueError('pred_score and gt_score must hold the same number of '
                         'values, got {} and {}'.format(pred.size, gt.size))
    srocc, _ = stats.spearmanr(pred, gt)
    plcc, _ = stats.pearsonr(pred, gt)
    return float(srocc), float(plcc)
class Solver:
    """Train and evaluate MoNet on one image-quality dataset.

    ``__init__`` builds everything from a single ``config`` object: the data
    split, the loss, the model, the optimizer and the scheduler. ``train`` runs
    the epoch loop, and ``test`` scores the held-out split with SROCC / PLCC.
    """

    def __init__(self, config):
        """Build loaders, loss, model, optimizer and LR scheduler from ``config``.

        Args:
            config: run configuration. Attributes read here are ``dataset``,
                ``loss``, ``queue_ratio`` (GC loss only), ``epochs``, ``lr``,
                ``weight_decay``, ``T_max``, ``eta_min`` and ``save_path``.
                The whole object is also passed to ``Data_Loader`` and
                ``MoNet``.

        Raises:
            ValueError: if ``config.loss`` is not 'MAE', 'MSE' or 'GC'.
        """
        path, train_index, test_index = get_data(dataset=config.dataset)
        train_loader = data_loader.Data_Loader(config, path, train_index, istrain=True)
        test_loader = data_loader.Data_Loader(config, path, test_index, istrain=False)
        self.train_data = train_loader.get_data()
        self.test_data = test_loader.get_data()
        print('Training data number: ', len(train_index))
        print('Testing data number: ', len(test_index))

        self.loss = self._build_loss(config, len(train_index))

        print('Loading MoNet...')
        self.MoNet = MoNet.MoNet(config).cuda()
        self.MoNet.train(True)
        self.epochs = config.epochs
        self.optimizer = torch.optim.Adam(self.MoNet.parameters(), lr=config.lr,
                                          weight_decay=config.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=config.T_max, eta_min=config.eta_min)
        self.model_save_path = os.path.join(config.save_path, 'best_model.pkl')

    @staticmethod
    def _build_loss(config, train_size):
        """Return the training criterion selected by ``config.loss``.

        Args:
            config: object with a ``loss`` attribute, plus ``queue_ratio``
                when ``loss == 'GC'``.
            train_size: number of training images; it sizes the GC-loss queue.

        Raises:
            ValueError: if ``config.loss`` is not 'MAE', 'MSE' or 'GC'.
        """
        if config.loss == 'MAE':
            return torch.nn.L1Loss().cuda()
        if config.loss == 'MSE':
            return torch.nn.MSELoss().cuda()
        if config.loss == 'GC':
            # The queue length is a fraction of the training-set size.
            return GC_Loss.GC_Loss(queue_len=int(train_size * config.queue_ratio))
        # The original raised a plain str, which is itself a TypeError in Python 3.
        raise ValueError('Only Support MAE, MSE and GC loss (got {!r}).'.format(config.loss))

    def train(self):
        """Train for ``self.epochs`` epochs, keeping the best checkpoint.

        After each epoch the model is evaluated with ``test()``. Whenever
        ``SROCC + PLCC`` beats the best seen so far, the state dict is saved to
        ``self.model_save_path``. One tab-separated summary row is printed per
        epoch.

        Returns:
            tuple: ``(best_srocc, best_plcc)`` on the test split. It stays
            ``(0.0, 0.0)`` if no epoch scores a positive sum.
        """
        best_srocc = 0.0
        best_plcc = 0.0
        print('----------------------------------')
        print('Epoch\tTrain_Loss\tTrain_SROCC\tTrain_PLCC\tTest_SROCC\tTest_PLCC')
        for t in range(self.epochs):
            epoch_loss = []
            pred_scores = []
            gt_scores = []
            for img, label in tqdm(self.train_data):
                img = img.cuda()
                label = label.view(-1).cuda()
                self.optimizer.zero_grad()
                # Flatten to one score per image. This assumes MoNet emits
                # (B,) or (B, 1) -- TODO confirm. Unlike squeeze(), reshape(-1)
                # keeps a 1-D shape matching `label` for a single-image batch.
                pred = self.MoNet(img).reshape(-1)
                # Use extend() rather than `list + list`, which re-copies the
                # whole list on every batch.
                pred_scores.extend(pred.detach().cpu().tolist())
                gt_scores.extend(label.cpu().tolist())
                loss = self.loss(pred, label.float())
                epoch_loss.append(loss.item())
                loss.backward()
                self.optimizer.step()
            # NOTE(review): the original indentation was lost. Stepping the
            # scheduler once per epoch is assumed -- confirm against T_max.
            self.scheduler.step()

            train_srocc, train_plcc = cal_srocc_plcc(pred_scores, gt_scores)
            test_srocc, test_plcc = self.test()
            if test_srocc + test_plcc > best_srocc + best_plcc:
                best_srocc = test_srocc
                best_plcc = test_plcc
                torch.save(self.MoNet.state_dict(), self.model_save_path)
                print('Model saved in: ', self.model_save_path)
            print('{}\t{}\t{}\t{}\t{}\t{}'.format(t + 1, round(np.mean(epoch_loss), 4), round(train_srocc, 4),
                                                  round(train_plcc, 4), round(test_srocc, 4), round(test_plcc, 4)))
        print('Best test SROCC {}, PLCC {}'.format(round(best_srocc, 4), round(best_plcc, 4)))
        return best_srocc, best_plcc

    def test(self):
        """Evaluate MoNet on the test split and return ``(srocc, plcc)``.

        The network runs in eval mode without gradients. It is switched back
        to training mode afterwards, even if evaluation raises.
        """
        self.MoNet.train(False)
        pred_scores = []
        gt_scores = []
        try:
            with torch.no_grad():
                for img, label in tqdm(self.test_data):
                    img = img.cuda()
                    pred = self.MoNet(img).reshape(-1)
                    pred_scores.extend(pred.cpu().tolist())
                    # Labels are only used on the host, so skip the GPU round trip.
                    gt_scores.extend(label.view(-1).tolist())
        finally:
            self.MoNet.train(True)
        return cal_srocc_plcc(pred_scores, gt_scores)