""" trainer.py - warpper and utility functions for network training Compute loss, back-prop, update parameters, logging, etc. """ import datetime import os import time import numpy as np import torch import torch.nn as nn import torch.optim as optim from model.network import XMem from model.losses import LossComputer from util.log_integrator import Integrator from util.image_saver import pool_pairs class XMemTrainer: def __init__(self, config, logger=None, save_path=None, local_rank=0, world_size=1): self.config = config self.num_frames = config['num_frames'] self.num_ref_frames = config['num_ref_frames'] self.deep_update_prob = config['deep_update_prob'] self.local_rank = local_rank self.XMem = nn.parallel.DistributedDataParallel( XMem(config).cuda(), device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False) # Set up logger when local_rank=0 self.logger = logger self.save_path = save_path if logger is not None: self.last_time = time.time() self.logger.log_string('model_size', str(sum([param.nelement() for param in self.XMem.parameters()]))) self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) self.loss_computer = LossComputer(config) self.train() self.optimizer = optim.AdamW(filter( lambda p: p.requires_grad, self.XMem.parameters()), lr=config['lr'], weight_decay=config['weight_decay']) self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, config['steps'], config['gamma']) if config['amp']: self.scaler = torch.cuda.amp.GradScaler() # Logging info self.log_text_interval = config['log_text_interval'] self.log_image_interval = config['log_image_interval'] self.save_network_interval = config['save_network_interval'] self.save_checkpoint_interval = config['save_checkpoint_interval'] if config['debug']: self.log_text_interval = self.log_image_interval = 1 def do_pass(self, data, max_it, it=0): # No need to store the gradient outside training torch.set_grad_enabled(self._is_train) for k, v in data.items(): if type(v) != list and type(v) != dict and type(v) != int: data[k] = v.cuda(non_blocking=True) out = {} frames = data['rgb'] first_frame_gt = data['first_frame_gt'].float() b = frames.shape[0] num_filled_objects = [o.item() for o in data['info']['num_objects']] num_objects = first_frame_gt.shape[2] selector = data['selector'].unsqueeze(2).unsqueeze(2) global_avg = 0 with torch.cuda.amp.autocast(enabled=self.config['amp']): # image features never change, compute once key, shrinkage, selection, f16, f8, f4 = self.XMem('encode_key', frames) filler_one = torch.zeros(1, dtype=torch.int64) hidden = torch.zeros((b, num_objects, self.config['hidden_dim'], *key.shape[-2:])) v16, hidden = self.XMem('encode_value', frames[:,0], f16[:,0], hidden, first_frame_gt[:,0]) values = v16.unsqueeze(3) # add the time dimension for ti in range(1, self.num_frames): if ti <= self.num_ref_frames: ref_values = values ref_keys = key[:,:,:ti] ref_shrinkage = shrinkage[:,:,:ti] if shrinkage is not None else None else: # pick num_ref_frames random frames # this is not very efficient but I think we would # need broadcasting in gather which we don't have indices = [ torch.cat([filler_one, torch.randperm(ti-1)[:self.num_ref_frames-1]+1]) for _ in range(b)] ref_values = torch.stack([ values[bi, :, :, indices[bi]] for bi in range(b) ], 0) ref_keys = torch.stack([ key[bi, :, indices[bi]] for bi in range(b) ], 0) ref_shrinkage = torch.stack([ shrinkage[bi, :, indices[bi]] for bi in range(b) ], 0) if shrinkage is not None else None # Segment 
                    ref_values = torch.stack([
                        values[bi, :, :, indices[bi]] for bi in range(b)
                    ], 0)
                    ref_keys = torch.stack([
                        key[bi, :, indices[bi]] for bi in range(b)
                    ], 0)
                    ref_shrinkage = torch.stack([
                        shrinkage[bi, :, indices[bi]] for bi in range(b)
                    ], 0) if shrinkage is not None else None

                # Segment frame ti
                memory_readout = self.XMem('read_memory', key[:,:,ti],
                        selection[:,:,ti] if selection is not None else None,
                        ref_keys, ref_shrinkage, ref_values)
                hidden, logits, masks = self.XMem('segment', (f16[:,ti], f8[:,ti], f4[:,ti]),
                        memory_readout, hidden, selector, h_out=(ti < (self.num_frames-1)))

                # No need to encode the last frame
                if ti < (self.num_frames-1):
                    is_deep_update = np.random.rand() < self.deep_update_prob
                    v16, hidden = self.XMem('encode_value', frames[:,ti], f16[:,ti],
                            hidden, masks, is_deep_update=is_deep_update)
                    values = torch.cat([values, v16.unsqueeze(3)], 3)

                out[f'masks_{ti}'] = masks
                out[f'logits_{ti}'] = logits

            if self._do_log or self._is_train:
                losses = self.loss_computer.compute({**data, **out}, num_filled_objects, it)

                # Logging
                if self._do_log:
                    self.integrator.add_dict(losses)
                    if self._is_train:
                        if it % self.log_image_interval == 0 and it != 0:
                            if self.logger is not None:
                                images = {**data, **out}
                                size = (384, 384)
                                self.logger.log_cv2('train/pairs',
                                    pool_pairs(images, size, num_filled_objects), it)

            if self._is_train:
                if it % self.log_text_interval == 0 and it != 0:
                    time_spent = time.time() - self.last_time
                    if self.logger is not None:
                        self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                        self.logger.log_metrics('train', 'time', time_spent/self.log_text_interval, it)

                    # Running average of the per-interval wall time; time_spent covers
                    # log_text_interval iterations, hence the division below
                    global_avg = 0.5*global_avg + 0.5*time_spent
                    eta_seconds = global_avg * (max_it - it) / self.log_text_interval
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(f'ETA: {eta_string}')

                    self.last_time = time.time()
                    self.train_integrator.finalize('train', it)
                    self.train_integrator.reset_except_hooks()

                if it % self.save_network_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_network(it)

                if it % self.save_checkpoint_interval == 0 and it != 0:
                    if self.logger is not None:
                        self.save_checkpoint(it)

        # Backward pass
        self.optimizer.zero_grad(set_to_none=True)
        if self.config['amp']:
            self.scaler.scale(losses['total_loss']).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            losses['total_loss'].backward()
            self.optimizer.step()

        self.scheduler.step()

    def save_network(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = f'{self.save_path}_{it}.pth'
        torch.save(self.XMem.module.state_dict(), model_path)
        print(f'Network saved to {model_path}.')

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = f'{self.save_path}_checkpoint_{it}.pth'
        checkpoint = {
            'it': it,
            'network': self.XMem.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print(f'Checkpoint saved to {checkpoint_path}.')

    def load_checkpoint(self, path):
        # This method loads everything and should be used to resume training
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.XMem.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Network weights, optimizer states, and scheduler states loaded.')

        return it
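
    # Hypothetical resume sketch (the path below is a placeholder, not a file in this repo):
    #   it = trainer.load_checkpoint(f'{save_path}_checkpoint_50000.pth')
    #   # ...then keep calling trainer.do_pass(data, max_it, it), counting up from this `it`
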
    def load_network_in_memory(self, src_dict):
        self.XMem.module.load_weights(src_dict)
        print('Network weight loaded from memory.')

    def load_network(self, path):
        # This method loads only the network weight and should be used to load a pretrained model
        map_location = 'cuda:%d' % self.local_rank
        src_dict = torch.load(path, map_location={'cuda:0': map_location})

        self.load_network_in_memory(src_dict)
        print(f'Network weight loaded from {path}')

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # eval() keeps BatchNorm running statistics frozen; the weights are still trained
        self.XMem.eval()
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.XMem.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.XMem.eval()
        return self
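

# Minimal usage sketch (hypothetical, for illustration only): it assumes the launch script
# has already initialised torch.distributed, that `train_loader` is a distributed DataLoader
# yielding the dict consumed by do_pass, and that `config` carries every key read above plus
# a total-iteration count, here called config['iterations'].
#
#   trainer = XMemTrainer(config, logger=logger, save_path='saves/xmem',
#                         local_rank=local_rank, world_size=world_size)
#   it = 0
#   while it < config['iterations']:
#       for data in train_loader:
#           trainer.do_pass(data, max_it=config['iterations'], it=it)
#           it += 1
#           if it >= config['iterations']:
#               break
#   trainer.save_network(it)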