import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info
from .DarkIR import DarkIR
def create_model(opt, rank, adapter = False):
Creates the model.
opt: a dictionary from the yaml config key network
# Builds the DarkIR network from `opt`, reports its complexity on rank 0 and wraps it in
# DistributedDataParallel.
#   opt     -- mapping; the visible lines read the keys 'name' and 'img_channels'
#   rank    -- process rank; compared against 0 for logging and passed to DDP as the
#              single entry of `device_ids`
#   adapter -- forwarded to DDP as `find_unused_parameters`
# Returns the tuple (model, macs, params).
name = opt['name']
# NOTE(review): only the first keyword argument of the DarkIR(...) call is visible in this
# view; the remaining arguments and the closing parenthesis are not shown -- confirm
# against the complete file before editing this call.
model = DarkIR(img_channel=opt['img_channels'],
# Only rank 0 prints the network name and the complexity report.
if rank ==0:
print(f'Using {name} network')
# The complexity figures are computed for this (3, 256, 256) input size.
input_size = (3, 256, 256)
macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat = False)
print(f'Computational complexity at {input_size}: {macs}')
print('Number of parameters: ', params)
# NOTE(review): no `else:` is visible before this reset; as shown it would also discard the
# values just computed on rank 0 -- presumably meant for the non-zero ranks only, confirm.
macs, params = None, None
# Wrap for distributed training on the device indexed by `rank`.
model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)
return model, macs, params
def create_optim_scheduler(opt, model):
    """Build the AdamW optimizer and its learning-rate scheduler.

    Args:
        opt: Mapping with the training options. The keys read are
            ``'lr_initial'``, ``'weight_decay'``, ``'betas'`` and
            ``'lr_scheme'``; the cosine scheme additionally reads
            ``'epochs'`` and ``'eta_min'``.
        model: Module whose parameters with ``requires_grad`` set are
            handed to the optimizer.

    Returns:
        The tuple ``(optim, scheduler)``.

    Raises:
        NotImplementedError: If ``opt['lr_scheme']`` is anything other
            than ``'CosineAnnealing'``.
    """
    # Frozen parameters are left out so the optimizer keeps no state for them.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.AdamW(
        trainable_params,
        lr=opt['lr_initial'],
        weight_decay=opt['weight_decay'],
        betas=opt['betas'],
    )

    if opt['lr_scheme'] == 'CosineAnnealing':
        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        # Only reached for unknown schemes; the supported scheme must fall
        # through to the return below.
        raise NotImplementedError('scheduler not implemented')

    return optim, scheduler
def load_weights(model, old_weights):
    """Load pretrained weights into ``model``, ignoring keys it does not have.

    Entries of ``old_weights`` whose key is absent from the model's own
    state dict are dropped; keys of the model that are missing from
    ``old_weights`` keep their current values.

    Args:
        model: Module that receives the weights (modified in place).
        old_weights: Mapping from parameter/buffer name to tensor, e.g. a
            state dict taken from a checkpoint.

    Returns:
        The same ``model`` object, with the matching weights loaded.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when a matching
            key holds a tensor of incompatible shape.
    """
    merged_weights = model.state_dict()
    merged_weights.update(
        {key: value for key, value in old_weights.items() if key in merged_weights}
    )
    # state_dict() hands back a fresh mapping, so updating it does not touch
    # the module; the merged result has to be loaded explicitly.
    model.load_state_dict(merged_weights)
    return model
def load_optim(optim, optim_weights):
    """Load a saved optimizer state into ``optim``.

    Only top-level entries of ``optim_weights`` whose key already exists in
    the optimizer's own state dict are taken over; unknown keys are ignored.

    Args:
        optim: Optimizer that receives the state (modified in place).
        optim_weights: Mapping in the layout produced by
            ``Optimizer.state_dict()``.

    Returns:
        The same ``optim`` object, with the saved state loaded.

    Raises:
        ValueError: Propagated from ``load_state_dict`` when the saved
            parameter groups do not match those of ``optim``.
    """
    merged_state = optim.state_dict()
    merged_state.update(
        {key: value for key, value in optim_weights.items() if key in merged_state}
    )
    # Updating the dict returned by state_dict() does not change the optimizer;
    # the merged state has to be loaded back explicitly.
    optim.load_state_dict(merged_state)
    return optim
def resume_model(model,
# NOTE(review): the rest of the parameter list is not visible in this view; the body
# below also reads `optim`, `scheduler`, `path_model`, `rank` and `resume` -- confirm
# their order and defaults against the complete file.
Returns the loaded weights of model and optimizer if resume flag is True
# Restores model and optimizer state from the checkpoint at `path_model` when `resume`
# is truthy, and returns (model, optim, scheduler, start_epochs).
# Tensors stored on 'cuda:0' are remapped to this process's device, 'cuda:<rank>'.
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
if resume:
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects, so the
# checkpoint file must come from a trusted source.
checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
weights = checkpoints['model_state_dict']
model = load_weights(model, old_weights=weights)
optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
# The stored epoch becomes the epoch to resume from.
start_epochs = checkpoints['epoch']
if rank == 0: print('Loaded weights')
# NOTE(review): no `else:` is visible before this line; as shown, the epoch read from
# the checkpoint would be overwritten with 0 -- presumably this belongs to the
# non-resume branch, confirm.
start_epochs = 0
if rank==0: print('Starting from zero the training')
# NOTE(review): `scheduler` is returned as received; no call restoring its state from the
# checkpoint is visible here -- confirm whether that is intended.
return model, optim, scheduler, start_epochs
def find_different_keys(dict1, dict2):
    """Return the set of keys present in exactly one of the two mappings."""
    left_keys = set(dict1.keys())
    # Symmetric difference: keys only in dict1 plus keys only in dict2.
    return left_keys.symmetric_difference(dict2.keys())
def number_common_keys(dict1, dict2):
    """Return how many keys the two mappings have in common."""
    shared_keys = set(dict1.keys()).intersection(dict2.keys())
    return len(shared_keys)
def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    """Save the training checkpoint and track the best validation PSNR.

    Only the process with ``rank == 0`` writes anything; every other value
    of ``rank`` (including the default ``None``) returns immediately.

    Args:
        model: Object exposing ``state_dict()``; stored under
            ``'model_state_dict'``.
        optim: Object exposing ``state_dict()``; stored under
            ``'optimizer_state_dict'``.
        scheduler: Object exposing ``state_dict()``; stored under
            ``'scheduler_state_dict'``.
        metrics_eval: Either a flat mapping containing ``'valid_psnr'`` or a
            mapping whose values are such mappings. In the nested case only
            the first value is used for the comparison.
        metrics_train: Mapping with the keys ``'best_psnr'``, ``'epoch'`` and
            ``'train_loss'``. ``'best_psnr'`` is updated in place when the
            new validation PSNR is at least as high.
        paths: Mapping with the keys ``'new'`` (written on every call) and
            ``'best'`` (written when the PSNR does not get worse).
        adapter: Unused; kept so existing callers keep working.
        rank: Rank of the calling process.

    Returns:
        The value of ``metrics_train['best_psnr']`` after the call.
    """
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr

    # Accept both layouts: a flat metrics mapping, or one metrics mapping per
    # evaluation set. isinstance() also recognises dict subclasses such as
    # OrderedDict as the nested layout.
    first_eval = next(iter(metrics_eval.values()))
    if not isinstance(first_eval, dict):
        first_eval = metrics_eval

    checkpoint = {
        'epoch': metrics_train['epoch'],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optim.state_dict(),
        'loss': metrics_train['train_loss'],
        'scheduler_state_dict': scheduler.state_dict(),
    }

    # Saving is best effort: a failed write is reported but must not stop the
    # training loop.
    try:
        # Save the model after every epoch.
        torch.save(checkpoint, paths['new'])

        # Keep a separate copy whenever the validation PSNR does not get worse.
        valid_psnr = first_eval['valid_psnr']
        if valid_psnr >= metrics_train['best_psnr']:
            torch.save(checkpoint, paths['best'])
            metrics_train['best_psnr'] = valid_psnr
    except Exception as e:
        print(f"Error saving model: {e}")

    return metrics_train['best_psnr']
# Names exported by a star-import of this module.
__all__ = [
    'create_model',
    'resume_model',
    'create_optim_scheduler',
    'save_checkpoint',
    'load_optim',
    'load_weights',
]