# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import math
import torch
import torch.nn as nn
from vidar.utils.distributed import print0, rank, dist_mode
from vidar.utils.logging import pcolor
from vidar.utils.tensor import same_shape
from vidar.utils.types import is_list


def freeze_layers(network, layers=('ALL',), flag_freeze=True):
    """
    Freeze layers of a network (weights and biases)

    Parameters
    ----------
    network : nn.Module
        Network to be modified
    layers : List or Tuple
        List of layers to freeze/unfreeze ('ALL' for everything)
    flag_freeze : Bool
        Whether the layers will be frozen (True) or not (False)
    """
if len(layers) > 0:
for name, parameters in network.named_parameters():
for layer in layers:
if layer in name or layer == 'ALL':
parameters.requires_grad_(not flag_freeze)
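

# Hypothetical usage sketch (not part of the original file): freeze the first
# Linear layer of a toy network by name substring; the network and the layer
# names are assumptions for illustration only.
def _demo_freeze_layers():
    net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    # Parameter names here are '0.weight', '0.bias', '1.weight', '1.bias'
    freeze_layers(net, layers=('0.',), flag_freeze=True)
    assert not net[0].weight.requires_grad and net[1].weight.requires_grad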


def freeze_norms(network, layers=('ALL',), flag_freeze=True):
    """
    Freeze layers of a network (normalization)

    Parameters
    ----------
    network : nn.Module
        Network to be modified
    layers : List or Tuple
        List of layers to freeze/unfreeze ('ALL' for everything)
    flag_freeze : Bool
        Whether the layers will be frozen (True) or not (False)
    """
if len(layers) > 0:
for name, module in network.named_modules():
for layer in layers:
if layer in name or layer == 'ALL':
if isinstance(module, nn.BatchNorm2d):
if hasattr(module, 'weight'):
module.weight.requires_grad_(not flag_freeze)
if hasattr(module, 'bias'):
module.bias.requires_grad_(not flag_freeze)
                        if flag_freeze:
                            # eval() also stops running mean/var updates,
                            # which requires_grad alone does not prevent
                            module.eval()
else:
module.train()
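

# Hypothetical usage sketch (not part of the original file): freezing BatchNorm
# affine parameters and switching the module to eval mode, so its running
# statistics stay fixed while the rest of the network trains.
def _demo_freeze_norms():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    net.train()
    freeze_norms(net, layers=('ALL',), flag_freeze=True)
    assert not net[1].weight.requires_grad and not net[1].training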


def freeze_layers_and_norms(network, layers=('ALL',), flag_freeze=True):
"""Freeze layers and normalizations of a network"""
freeze_layers(network, layers, flag_freeze)
freeze_norms(network, layers, flag_freeze)
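

# Hypothetical sketch (not part of the original file): the typical fine-tuning
# pattern, freezing a pretrained backbone in one call while a new head keeps
# training. The module names 'backbone' and 'head' are illustrative only.
def _demo_freeze_backbone():
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
            self.head = nn.Linear(8, 2)
    model = Model()
    freeze_layers_and_norms(model, layers=('backbone',), flag_freeze=True)
    assert not model.backbone[0].weight.requires_grad
    assert model.head.weight.requires_grad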


def make_val_fit(model, key, val, updated_state_dict, strict=False):
    """
    Check whether a checkpoint tensor fits the corresponding model tensor,
    reshaping it to fit if requested

    Parameters
    ----------
    model : nn.Module
        Network to be used
    key : String
        Which key will be used
    val : torch.Tensor
        Key value
    updated_state_dict : Dict
        Updated dictionary (modified in place if the tensor fits)
    strict : Bool
        True if no changes are allowed, False if tensors can be reshaped to fit

    Returns
    -------
    fit : Int
        1 if the tensor fits the model (and was stored), 0 otherwise
    """
    fit = 0
    val_new = model.state_dict()[key]
    if same_shape(val.shape, val_new.shape):
        updated_state_dict[key] = val
        fit += 1
    elif not strict:
        for i in range(val.dim()):
            if val.shape[i] != val_new.shape[i]:
                if val_new.shape[i] > val.shape[i]:
                    # Tile the checkpoint tensor along dimension i until it
                    # covers the model tensor, then truncate the excess
                    ratio = math.ceil(val_new.shape[i] / val.shape[i])
                    val = torch.cat([val] * ratio, i)
                    if val.shape[i] != val_new.shape[i]:
                        val = val.narrow(i, 0, val_new.shape[i])
                    if same_shape(val.shape, val_new.shape):
                        updated_state_dict[key] = val
                        fit += 1
                elif val_new.shape[i] < val.shape[i]:
                    # Truncate the checkpoint tensor along dimension i
                    val = val.narrow(i, 0, val_new.shape[i])
                    if same_shape(val.shape, val_new.shape):
                        updated_state_dict[key] = val
                        fit += 1
    assert fit <= 1  # Each tensor cannot fit more than once
    return fit
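

# Hypothetical sketch (not part of the original file): with strict=False, a
# checkpoint tensor with fewer input channels is tiled along the mismatched
# dimension to fit the model, e.g. 3-channel conv weights into a 6-channel conv.
def _demo_make_val_fit():
    model = nn.Conv2d(6, 8, 3)
    ckpt_val = torch.randn(8, 3, 3, 3)  # weights trained with 3 input channels
    updated = {}
    fit = make_val_fit(model, 'weight', ckpt_val, updated, strict=False)
    assert fit == 1 and updated['weight'].shape == model.weight.shape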


def load_checkpoint(model, checkpoint, strict=False, verbose=False, prefix=None):
    """
    Load checkpoint into a model

    Parameters
    ----------
    model : nn.Module
        Input network
    checkpoint : String or list[String]
        Checkpoint path (if it's a list, load them in order)
    strict : Bool
        True if all tensors are required, False if the checkpoint can be
        partially loaded
    verbose : Bool
        Print information on screen
    prefix : String
        Prefix used to change keys

    Returns
    -------
    model : nn.Module
        Loaded network
    """
    if is_list(checkpoint):
        for ckpt in checkpoint:
            load_checkpoint(model, ckpt, strict, verbose, prefix)
        return model
font1 = {'color': 'magenta', 'attrs': ('bold', 'dark')}
font2 = {'color': 'magenta', 'attrs': ('bold',)}
if verbose:
print0(pcolor('#' * 60, **font1))
print0(pcolor('###### Loading from checkpoint: ', **font1) +
pcolor('{}'.format(checkpoint), **font2))
state_dict = torch.load(
checkpoint,
map_location='cpu' if dist_mode() == 'cpu' else 'cuda:{}'.format(rank())
)['state_dict']
updated_state_dict = {}
total, fit = len(model.state_dict()), 0
    for key, val in state_dict.items():
        # Strip common wrapper prefixes (e.g. 'module.' from DataParallel)
        for start in ['model.', 'module.']:
            if key.startswith(start):
                key = key[len(start):]
        # Drop everything up to and including 'prefix.' if a prefix is given
        if prefix is not None:
            idx = key.find(prefix)
            if idx > -1:
                key = key[(idx + len(prefix) + 1):]
if key in model.state_dict().keys():
fit += make_val_fit(model, key, val, updated_state_dict, strict=strict)
model.load_state_dict(updated_state_dict, strict=strict)
    if verbose:
        color = 'red' if fit == 0 else 'yellow' if fit < total else 'green'
        print0(pcolor('###### Loaded ', **font1) +
               pcolor('{}/{}'.format(fit, total), color=color, attrs=('bold',)) +
               pcolor(' tensors', **font1))
        print0(pcolor('#' * 60, **font1))
return model
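

# Hypothetical round-trip sketch (not part of the original file): keys saved
# under a 'model.' wrapper prefix are remapped automatically when loading.
# Note that torch.load inside load_checkpoint targets the device chosen by
# dist_mode()/rank(), so this assumes a matching (e.g. CPU) environment.
def _demo_load_checkpoint(path='/tmp/demo_ckpt.pt'):
    net = nn.Linear(4, 2)
    torch.save({'state_dict': {'model.' + k: v
                               for k, v in net.state_dict().items()}}, path)
    reloaded = load_checkpoint(nn.Linear(4, 2), path, strict=True)
    assert torch.equal(reloaded.weight.cpu(), net.weight)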


def save_checkpoint(filename, wrapper, epoch=None):
    """
    Save checkpoint to disk

    Parameters
    ----------
    filename : String
        Name of the file
    wrapper : nn.Module
        Model wrapper to save
    epoch : Int or None
        Training epoch (None saves only the state dictionary)
    """
if epoch is None:
torch.save({
'state_dict': wrapper.state_dict(),
}, filename)
else:
torch.save({
'epoch': epoch,
'config': wrapper.cfg,
'state_dict': wrapper.arch.state_dict(),
}, filename)
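

# Hypothetical sketch (not part of the original file): with epoch=None only the
# state dictionary is stored, so any nn.Module works; the epoch branch instead
# expects a vidar wrapper exposing `cfg` and `arch`.
def _demo_save_checkpoint(path='/tmp/demo_save.pt'):
    net = nn.Linear(4, 2)
    save_checkpoint(path, net)
    assert 'state_dict' in torch.load(path, map_location='cpu')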