# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

import math

import torch
import torch.nn as nn

from vidar.utils.distributed import print0, rank, dist_mode
from vidar.utils.logging import pcolor
from vidar.utils.tensor import same_shape
from vidar.utils.types import is_list


def freeze_layers(network, layers=('ALL',), flag_freeze=True):
    """
    Freeze layers of a network (weights and biases)

    A parameter is selected when any entry of `layers` is a substring of its
    name, or when `layers` contains 'ALL'.

    Parameters
    ----------
    network : nn.Module
        Network to be modified
    layers : List or Tuple
        List of layers to freeze/unfreeze ('ALL' for everything)
    flag_freeze : Bool
        Whether the layers will be frozen (True) or not (False)
    """
    if not layers:
        return
    for name, parameters in network.named_parameters():
        if any(layer == 'ALL' or layer in name for layer in layers):
            parameters.requires_grad_(not flag_freeze)


def freeze_norms(network, layers=('ALL',), flag_freeze=True):
    """
    Freeze layers of a network (normalization)

    Only nn.BatchNorm2d modules are affected. Besides toggling gradients on
    their affine parameters, the module is switched to eval mode when frozen
    and to train mode when unfrozen.

    Parameters
    ----------
    network : nn.Module
        Network to be modified
    layers : List or Tuple
        List of layers to freeze/unfreeze ('ALL' for everything)
    flag_freeze : Bool
        Whether the layers will be frozen (True) or not (False)
    """
    if not layers:
        return
    for name, module in network.named_modules():
        if not isinstance(module, nn.BatchNorm2d):
            continue
        if not any(layer == 'ALL' or layer in name for layer in layers):
            continue
        # weight/bias exist as attributes but are None when affine=False,
        # so test the value rather than attribute presence
        for attr in ('weight', 'bias'):
            param = getattr(module, attr, None)
            if param is not None:
                param.requires_grad_(not flag_freeze)
        if flag_freeze:
            module.eval()
        else:
            module.train()


def freeze_layers_and_norms(network, layers=('ALL',), flag_freeze=True):
    """Freeze layers and normalizations of a network"""
    freeze_layers(network, layers, flag_freeze)
    freeze_norms(network, layers, flag_freeze)


def make_val_fit(model, key, val, updated_state_dict, strict=False):
    """
    Parse state dictionary to fit a model, and make tensors fit if requested

    If `val` already has the shape of the model tensor stored under `key`, it
    is used as is. Otherwise, when `strict` is False, every mismatching
    dimension is tiled (when too small) and then trimmed to the target size.

    Parameters
    ----------
    model : nn.Module
        Network to be used
    key : String
        Which key will be used (must exist in model.state_dict())
    val : torch.Tensor
        Key value
    updated_state_dict : Dict
        Updated dictionary (modified in place when the tensor fits)
    strict : Bool
        True if no changes are allowed, False if tensors can be changed to fit

    Returns
    -------
    fit : Int
        Number of tensors that fit the model (0 or 1)
    """
    target = model.state_dict()[key]
    if same_shape(val.shape, target.shape):
        updated_state_dict[key] = val
        return 1
    # Tiling/trimming cannot reconcile tensors of different rank
    if strict or val.dim() != target.dim():
        return 0
    for dim in range(val.dim()):
        have, want = val.shape[dim], target.shape[dim]
        if have == want:
            continue
        if want > have:
            if have == 0:
                # Nothing to repeat along an empty dimension
                return 0
            ratio = math.ceil(want / have)
            val = torch.cat([val] * ratio, dim)
        # Trim along the mismatching dimension itself (not dimension 0)
        val = val.narrow(dim, 0, want)
    if same_shape(val.shape, target.shape):
        updated_state_dict[key] = val
        return 1
    return 0


def load_checkpoint(model, checkpoint, strict=False, verbose=False, prefix=None):
    """
    Load checkpoint into a model

    The checkpoint file is expected to hold a dictionary with a 'state_dict'
    entry. Leading 'model.' and 'module.' are stripped from each key, and keys
    that are not present in the model are ignored.

    Parameters
    ----------
    model : nn.Module
        Input network
    checkpoint : String or list[String]
        Checkpoint path (if it's a list, load them in order)
    strict : Bool
        True if all tensors are required, False if can be partially loaded
    verbose : Bool
        Print information on screen
    prefix : String
        Prefix used to change keys (if found in a key, everything up to and
        including the prefix and the following character is removed)

    Returns
    -------
    model: nn.Module
        Loaded network
    """
    if is_list(checkpoint):
        for ckpt in checkpoint:
            # Forward prefix so every checkpoint in the list is parsed alike
            load_checkpoint(model, ckpt, strict, verbose, prefix)
        return model

    font1 = {'color': 'magenta', 'attrs': ('bold', 'dark')}
    font2 = {'color': 'magenta', 'attrs': ('bold',)}

    if verbose:
        print0(pcolor('#' * 60, **font1))
        print0(pcolor('###### Loading from checkpoint: ', **font1) +
               pcolor('{}'.format(checkpoint), **font2))

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    map_location = 'cpu' if dist_mode() == 'cpu' else 'cuda:{}'.format(rank())
    state_dict = torch.load(checkpoint, map_location=map_location)['state_dict']

    # Built once: the model is not modified while keys are being matched
    model_state = model.state_dict()
    updated_state_dict = {}

    total, fit = len(model_state), 0
    for key, val in state_dict.items():
        for start in ('model.', 'module.'):
            if key.startswith(start):
                key = key[len(start):]
        if prefix is not None:
            idx = key.find(prefix)
            if idx > -1:
                # The extra +1 also drops the character following the prefix
                key = key[(idx + len(prefix) + 1):]
        if key in model_state:
            fit += make_val_fit(model, key, val, updated_state_dict, strict=strict)

    model.load_state_dict(updated_state_dict, strict=strict)

    if verbose:
        color = 'red' if fit == 0 else 'yellow' if fit < total else 'green'
        print0(pcolor('###### Loaded ', **font1) +
               pcolor('{}/{}'.format(fit, total), color=color, attrs=('bold',)) +
               pcolor(' tensors', **font1))
        print0(pcolor('#' * 60, **font1))

    return model


def save_checkpoint(filename, wrapper, epoch=None):
    """
    Save checkpoint to disk

    Without an epoch only the wrapper state dictionary is stored. With an
    epoch, the epoch, wrapper.cfg and the state dictionary of wrapper.arch
    are stored.

    Parameters
    ----------
    filename : String
        Name of the file
    wrapper : nn.Module
        Model wrapper to save
    epoch : Int
        Training epoch
    """
    if epoch is None:
        payload = {
            'state_dict': wrapper.state_dict(),
        }
    else:
        payload = {
            'epoch': epoch,
            'config': wrapper.cfg,
            'state_dict': wrapper.arch.state_dict(),
        }
    torch.save(payload, filename)