import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info
from .DarkIR import DarkIR
def create_model(opt, rank, adapter=False):
    '''
    Creates the DarkIR model and wraps it in DistributedDataParallel.
    opt: the 'network' section of the YAML config, as a dictionary
    '''
    name = opt['name']

    model = DarkIR(img_channel=opt['img_channels'],
                   width=opt['width'],
                   middle_blk_num_enc=opt['middle_blk_num_enc'],
                   middle_blk_num_dec=opt['middle_blk_num_dec'],
                   enc_blk_nums=opt['enc_blk_nums'],
                   dec_blk_nums=opt['dec_blk_nums'],
                   dilations=opt['dilations'],
                   extra_depth_wise=opt['extra_depth_wise'])
    if rank == 0:
        print(f'Using {name} network')
        # Report MACs and parameter count on the main process only
        input_size = (3, 256, 256)
        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat=False)
        print(f'Computational complexity at {input_size}: {macs}')
        print('Number of parameters: ', params)
    else:
        macs, params = None, None

    model.to(rank)
    model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)

    return model, macs, params
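# Illustrative only: a minimal sketch of the 'network' config dict that create_model
# reads. All values below are hypothetical; the real ones come from the YAML file.
#
# example_network_opt = {
#     'name': 'DarkIR',
#     'img_channels': 3,
#     'width': 32,
#     'middle_blk_num_enc': 2,
#     'middle_blk_num_dec': 2,
#     'enc_blk_nums': [1, 2, 3],
#     'dec_blk_nums': [3, 1, 1],
#     'dilations': [1, 4],
#     'extra_depth_wise': True,
# }
# model, macs, params = create_model(example_network_opt, rank=0)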
def create_optim_scheduler(opt, model):
    '''
    Returns the optimizer and its scheduler.
    opt: the 'train' section of the YAML config, as a dictionary
    '''
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=opt['lr_initial'],
                              weight_decay=opt['weight_decay'],
                              betas=opt['betas'])
    if opt['lr_scheme'] == 'CosineAnnealing':
        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        raise NotImplementedError('scheduler not implemented')
    return optim, scheduler
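# Illustrative only: the 'train' config keys create_optim_scheduler reads, with
# hypothetical values; the real ones come from the YAML file.
#
# example_train_opt = {
#     'lr_initial': 1e-3,
#     'weight_decay': 1e-4,
#     'betas': [0.9, 0.999],
#     'lr_scheme': 'CosineAnnealing',
#     'epochs': 300,
#     'eta_min': 1e-6,
# }
# optim, scheduler = create_optim_scheduler(example_train_opt, model)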
def load_weights(model, old_weights):
    '''
    Loads the weights of a pretrained model, picking only the weights that are
    in the new model.
    '''
    new_weights = model.state_dict()
    new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
    model.load_state_dict(new_weights)
    return model
def load_optim(optim, optim_weights):
    '''
    Loads the optimizer state, picking only the entries that are also present
    in the new optimizer.
    '''
    optim_new_weights = optim.state_dict()
    optim_new_weights.update({k: v for k, v in optim_weights.items() if k in optim_new_weights})
    # Apply the merged state to the optimizer (updating the dict alone has no effect)
    optim.load_state_dict(optim_new_weights)
    return optim
def resume_model(model,
                 optim,
                 scheduler,
                 path_model,
                 rank, resume: str = None):
    '''
    Returns the loaded state of model, optimizer and scheduler if the resume flag is set.
    '''
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    if resume:
        checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
        weights = checkpoints['model_state_dict']
        model = load_weights(model, old_weights=weights)
        optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
        start_epochs = checkpoints['epoch']
        if rank == 0: print('Loaded weights')
    else:
        start_epochs = 0
        if rank == 0: print('Starting training from scratch')

    return model, optim, scheduler, start_epochs
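# Illustrative only (hypothetical path and config names): resuming a run from a
# checkpoint saved by save_checkpoint below.
#
# model, macs, params = create_model(network_opt, rank)
# optim, scheduler = create_optim_scheduler(train_opt, model)
# model, optim, scheduler, start_epochs = resume_model(
#     model, optim, scheduler, path_model='./models/last_DarkIR.pt',
#     rank=rank, resume='latest')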
def find_different_keys(dict1, dict2):
    # Finding the keys that appear in only one of the two dictionaries
    different_keys = set(dict1.keys()) ^ set(dict2.keys())
    return different_keys

def number_common_keys(dict1, dict2):
    # Finding common keys
    common_keys = set(dict1.keys()) & set(dict2.keys())
    # Counting the number of common keys
    common_keys_count = len(common_keys)
    return common_keys_count
# # Function to add 'modules_list' prefix after the first numeric index
# def add_middle_prefix(state_dict, middle_prefix, target_strings):
# new_state_dict = {}
# for key, value in state_dict.items():
# for target in target_strings:
# if target in key:
# parts = key.split('.')
# # Find the first numeric index after the target string
# for i, part in enumerate(parts):
# if part == target:
# # Insert the middle prefix after the first numeric index
# if i + 1 < len(parts) and parts[i + 1].isdigit():
# parts.insert(i + 2, middle_prefix)
# break
# new_key = '.'.join(parts)
# new_state_dict[new_key] = value
# break
# else:
# new_state_dict[key] = value
# return new_state_dict
# # Function to adjust keys for 'middle_blks.' prefix
# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
# new_state_dict = {}
# for key, value in state_dict.items():
# if target_prefix in key:
# parts = key.split('.')
# # Find the target prefix and adjust the key
# for i, part in enumerate(parts):
# if part == target_prefix.rstrip('.'):
# if i + 1 < len(parts) and parts[i + 1].isdigit():
# # Swap the numerical part and the middle prefix
# new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
# new_state_dict[new_key] = value
# break
# else:
# new_state_dict[key] = value
# return new_state_dict
# def resume_nafnet(model,
# optim,
# scheduler,
# path_adapter,
# path_model,
# rank, resume:str=None):
# '''
# Returns the loaded weights of model and optimizer if resume flag is True
# '''
# map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
# #first load the model weights
# checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
# weights = checkpoints
# if rank==0:
# print(len(weights), len(model.state_dict().keys()))
# different_keys = find_different_keys(weights, model.state_dict())
# filtered_keys = {item for item in different_keys if 'adapter' not in item}
# print(filtered_keys)
# print(len(filtered_keys))
# model = load_weights(model, old_weights=weights)
# #now if needed load the adapter weights
# if resume:
# checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
# weights = checkpoints
# model = load_weights(model, old_weights=weights)
# # optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
# scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
# start_epochs = checkpoints['epoch']
# if rank == 0: print('Loaded weights')
# else:
# start_epochs = 0
# if rank == 0: print('Starting from zero the training')
# return model, optim, scheduler, start_epochs
def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    '''
    Saves the .pt checkpoint of the model after each epoch (main process only).
    '''
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr
    # Normalize metrics_eval so it is always a dict of per-dataset metric dicts
    if not isinstance(next(iter(metrics_eval.values())), dict):
        metrics_eval = {'metrics': metrics_eval}

    weights = model.state_dict()
    # Save the model after every epoch
    model_to_save = {
        'epoch': metrics_train['epoch'],
        'model_state_dict': weights,
        'optimizer_state_dict': optim.state_dict(),
        'loss': metrics_train['train_loss'],
        'scheduler_state_dict': scheduler.state_dict()
    }
    try:
        torch.save(model_to_save, paths['new'])
        # Save best model if the new valid_psnr is higher than the best one so far
        if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
            torch.save(model_to_save, paths['best'])
            metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr']  # update best psnr
    except Exception as e:
        print(f"Error saving model: {e}")

    return metrics_train['best_psnr']
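# Illustrative only (hypothetical values and dataset name): the dict shapes that
# save_checkpoint expects for its metrics and paths arguments.
#
# metrics_train = {'epoch': 12, 'train_loss': 0.031, 'best_psnr': 27.5}
# metrics_eval = {'LOLBlur': {'valid_psnr': 27.8, 'valid_ssim': 0.89}}
# paths = {'new': './models/last_DarkIR.pt', 'best': './models/best_DarkIR.pt'}
# best_psnr = save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train,
#                             paths, rank=rank)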
__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
'load_optim', 'load_weights']