import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn.parallel import DistributedDataParallel as DDP
from ptflops import get_model_complexity_info

from .DarkIR import DarkIR   

def create_model(opt, rank, adapter=False):
    '''
    Creates the model and wraps it in DistributedDataParallel.
    opt: the dictionary under the 'network' key of the yaml config
    rank: local GPU rank of the current process
    adapter: if True, DDP is built with find_unused_parameters=True
    '''
    name = opt['name']


    model = DarkIR(img_channel=opt['img_channels'],
                   width=opt['width'],
                   middle_blk_num_enc=opt['middle_blk_num_enc'],
                   middle_blk_num_dec=opt['middle_blk_num_dec'],
                   enc_blk_nums=opt['enc_blk_nums'],
                   dec_blk_nums=opt['dec_blk_nums'],
                   dilations=opt['dilations'],
                   extra_depth_wise=opt['extra_depth_wise'])

    if rank == 0:
        print(f'Using {name} network')

        input_size = (3, 256, 256)
        macs, params = get_model_complexity_info(model, input_size, print_per_layer_stat=False)
        print(f'Computational complexity at {input_size}: {macs}')
        print('Number of parameters: ', params)
    else:
        macs, params = None, None

    model.to(rank)
    
    model = DDP(model, device_ids=[rank], find_unused_parameters=adapter)
    
    return model, macs, params
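
# Minimal usage sketch of create_model, kept commented out so the module stays import-only.
# The keys match what create_model reads above; the concrete values are placeholders and are
# NOT taken from any config file in this repo. It also assumes torch.distributed has already
# been initialized and that `rank` indexes a visible GPU.
#
# network_opt = {
#     'name': 'DarkIR',
#     'img_channels': 3,
#     'width': 32,
#     'middle_blk_num_enc': 2,
#     'middle_blk_num_dec': 2,
#     'enc_blk_nums': [1, 2, 3],
#     'dec_blk_nums': [3, 1, 1],
#     'dilations': [1, 4],
#     'extra_depth_wise': True,
# }
# model, macs, params = create_model(network_opt, rank=0)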

def create_optim_scheduler(opt, model):
    '''
    Returns the optimizer and its scheduler.
    opt: the dictionary under the 'train' key of the yaml config
    '''
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=opt['lr_initial'],
                              weight_decay=opt['weight_decay'],
                              betas=opt['betas'])

    if opt['lr_scheme'] == 'CosineAnnealing':
        scheduler = CosineAnnealingLR(optim, T_max=opt['epochs'], eta_min=opt['eta_min'])
    else:
        raise NotImplementedError('scheduler not implemented')

    return optim, scheduler
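
# Sketch of the 'train' dictionary expected by create_optim_scheduler. The values are
# placeholders, not taken from any config in this repo.
#
# train_opt = {
#     'lr_initial': 1e-3,
#     'weight_decay': 1e-4,
#     'betas': [0.9, 0.9],
#     'lr_scheme': 'CosineAnnealing',
#     'epochs': 300,
#     'eta_min': 1e-6,
# }
# optim, scheduler = create_optim_scheduler(train_opt, model)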

def load_weights(model, old_weights):
    '''
    Loads the weights of a pretrained model, keeping only the entries
    whose keys also exist in the new model.
    '''
    new_weights = model.state_dict()
    new_weights.update({k: v for k, v in old_weights.items() if k in new_weights})
    
    model.load_state_dict(new_weights)
    return model

def load_optim(optim, optim_weights):
    '''
    Loads the state of the optimizer, keeping only the entries whose keys
    also exist in the new optimizer.
    '''
    optim_new_weights = optim.state_dict()
    optim_new_weights.update({k: v for k, v in optim_weights.items() if k in optim_new_weights})
    optim.load_state_dict(optim_new_weights)  # apply the merged state back to the optimizer
    return optim

def resume_model(model,
                 optim,
                 scheduler,
                 path_model,
                 rank, resume: str = None):
    '''
    Loads the states of the model, optimizer and scheduler from a checkpoint
    if the resume flag is set; otherwise training starts from epoch 0.
    '''
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    if resume:
        checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
        weights = checkpoints['model_state_dict']
        model = load_weights(model, old_weights=weights)
        optim = load_optim(optim, optim_weights=checkpoints['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
        start_epochs = checkpoints['epoch']

        if rank == 0: print('Loaded weights')
    else:
        start_epochs = 0
        if rank == 0: print('Starting training from scratch')

    return model, optim, scheduler, start_epochs
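
# Typical resume flow (sketch only; the checkpoint path is hypothetical and assumes a file
# previously written by save_checkpoint below):
#
# model, optim, scheduler, start_epochs = resume_model(
#     model, optim, scheduler,
#     path_model='./checkpoints/last.pt',
#     rank=rank, resume='latest')
# for epoch in range(start_epochs, train_opt['epochs']):
#     ...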

def find_different_keys(dict1, dict2):
    # Keys that appear in only one of the two dictionaries (symmetric difference)
    different_keys = set(dict1.keys()) ^ set(dict2.keys())

    return different_keys

def number_common_keys(dict1, dict2):
    # Finding common keys
    common_keys = set(dict1.keys()) & set(dict2.keys())

    # Counting the number of common keys
    common_keys_count = len(common_keys)
    return common_keys_count

# # Function to add 'modules_list' prefix after the first numeric index
# def add_middle_prefix(state_dict, middle_prefix, target_strings):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         for target in target_strings:
#             if target in key:
#                 parts = key.split('.')
#                 # Find the first numeric index after the target string
#                 for i, part in enumerate(parts):
#                     if part == target:
#                         # Insert the middle prefix after the first numeric index
#                         if i + 1 < len(parts) and parts[i + 1].isdigit():
#                             parts.insert(i + 2, middle_prefix)
#                             break
#                 new_key = '.'.join(parts)
#                 new_state_dict[new_key] = value
#                 break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# # Function to adjust keys for 'middle_blks.' prefix
# def adjust_middle_blks_keys(state_dict, target_prefix, middle_prefix):
#     new_state_dict = {}
#     for key, value in state_dict.items():
#         if target_prefix in key:
#             parts = key.split('.')
#             # Find the target prefix and adjust the key
#             for i, part in enumerate(parts):
#                 if part == target_prefix.rstrip('.'):
#                     if i + 1 < len(parts) and parts[i + 1].isdigit():
#                         # Swap the numerical part and the middle prefix
#                         new_key = '.'.join(parts[:i + 1] + [middle_prefix] + parts[i + 1:i + 2] + parts[i + 2:])
#                         new_state_dict[new_key] = value
#                         break
#         else:
#             new_state_dict[key] = value
#     return new_state_dict

# def resume_nafnet(model,
#                  optim,
#                  scheduler, 
#                  path_adapter,
#                  path_model,
#                  rank, resume:str=None):
#     '''
#     Returns the loaded weights of model and optimizer if resume flag is True
#     '''
#     map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
#     #first load the model weights
#     checkpoints = torch.load(path_model, map_location=map_location, weights_only=False)
#     weights = checkpoints
#     if rank==0:
#         print(len(weights), len(model.state_dict().keys()))

#         different_keys = find_different_keys(weights, model.state_dict())
#         filtered_keys = {item for item in different_keys if 'adapter' not in item}
#         print(filtered_keys)
#         print(len(filtered_keys))
#     model = load_weights(model, old_weights=weights) 
#     #now if needed load the adapter weights
#     if resume:
#         checkpoints = torch.load(path_adapter, map_location=map_location, weights_only=False)
#         weights = checkpoints
#         model = load_weights(model, old_weights=weights)
#         # optim = load_optim(optim, optim_weights = checkpoints['optimizer_state_dict'])
#         scheduler.load_state_dict(checkpoints['scheduler_state_dict'])
#         start_epochs = checkpoints['epoch']

#         if rank == 0: print('Loaded weights')
#     else:
#         start_epochs = 0
#         if rank == 0: print('Starting from zero the training')
    
#     return model, optim, scheduler, start_epochs

def save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, adapter=False, rank=None):
    '''
    Saves the .pt checkpoint of the model after each epoch (only on rank 0),
    and keeps a separate copy of the best model according to validation PSNR.
    '''
    best_psnr = metrics_train['best_psnr']
    if rank != 0:
        return best_psnr

    if not isinstance(next(iter(metrics_eval.values())), dict):
        metrics_eval = {'metrics': metrics_eval}

    weights = model.state_dict()

    # Save the model after every epoch
    model_to_save = {
        'epoch': metrics_train['epoch'],
        'model_state_dict': weights,
        'optimizer_state_dict': optim.state_dict(),
        'loss': metrics_train['train_loss'],
        'scheduler_state_dict': scheduler.state_dict()
    }

    try:
        torch.save(model_to_save, paths['new'])

        # Save the best model if the new valid_psnr is higher than the best one so far
        if next(iter(metrics_eval.values()))['valid_psnr'] >= metrics_train['best_psnr']:
            torch.save(model_to_save, paths['best'])
            metrics_train['best_psnr'] = next(iter(metrics_eval.values()))['valid_psnr']  # update best psnr
    except Exception as e:
        print(f"Error saving model: {e}")
    return metrics_train['best_psnr']
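
# Sketch of the dictionaries save_checkpoint expects. Names, paths and values below are
# placeholders for illustration only:
#
# paths = {'new': './checkpoints/last.pt', 'best': './checkpoints/best.pt'}
# metrics_train = {'epoch': epoch, 'train_loss': loss.item(), 'best_psnr': best_psnr}
# metrics_eval = {'valid_set': {'valid_psnr': 26.5, 'valid_ssim': 0.85}}  # one entry per eval set
# best_psnr = save_checkpoint(model, optim, scheduler, metrics_eval, metrics_train, paths, rank=rank)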

__all__ = ['create_model', 'resume_model', 'create_optim_scheduler', 'save_checkpoint',
           'load_optim', 'load_weights']