Zhyever
refactor
1f418ff
import os
import wandb
import numpy as np
import torch
import mmengine
from mmengine.optim import build_optim_wrapper
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.distributed as dist
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu
from mmengine import print_log
import torch.nn.functional as F
from tqdm import tqdm
from estimator.utils import colorize
class Trainer:
    """Training loop with optional epoch- or iteration-based validation and wandb logging.

    The trainer builds the optimizer wrapper with mmengine's ``build_optim_wrapper``
    and a ``OneCycleLR`` scheduler that is stepped once per training iteration.
    Validation runs either every ``train_cfg.val_interval`` epochs
    (``train_cfg.val_type == 'epoch_base'``) or every ``val_interval`` iterations
    (``'iter_base'``). Scalars and images are pushed to wandb only on rank 0 and only
    when ``config.debug`` is False.
    """

    def __init__(
            self,
            config,
            runner_info,
            train_sampler,
            train_dataloader,
            val_dataloader,
            model):
        """Store the run components and build the optimizer and LR schedule.

        Args:
            config: experiment config; this class reads ``optim_wrapper``,
                ``param_scheduler``, ``train_cfg``, ``model.min_depth``,
                ``model.max_depth``, ``debug`` and ``collect_input_args``.
            runner_info: run metadata; reads ``rank``, ``distributed`` and ``work_dir``.
            train_sampler: sampler whose ``set_epoch`` is called every epoch in
                distributed runs.
            train_dataloader: training data loader.
            val_dataloader: validation data loader.
            model: the model; ``save_checkpoint`` accesses ``model.module``, so a
                DDP-style wrapper is expected.
        """
        self.config = config
        self.runner_info = runner_info
        self.train_sampler = train_sampler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.model = model

        # Build optimizer and schedule. Each param group's configured lr is used
        # as that group's OneCycle peak learning rate (max_lr).
        self.optimizer_wrapper = build_optim_wrapper(self.model, config.optim_wrapper)
        sched_cfg = config.param_scheduler
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer_wrapper.optimizer,
            [group['lr'] for group in self.optimizer_wrapper.optimizer.param_groups],
            epochs=self.config.train_cfg.max_epochs,
            steps_per_epoch=len(self.train_dataloader),
            cycle_momentum=sched_cfg.cycle_momentum,
            base_momentum=sched_cfg.get('base_momentum', 0.85),
            max_momentum=sched_cfg.get('max_momentum', 0.95),
            div_factor=sched_cfg.div_factor,
            final_div_factor=sched_cfg.final_div_factor,
            pct_start=sched_cfg.pct_start,
            three_phase=sched_cfg.three_phase)

        # I'd like to use the wandb log_name here.
        self.train_step = 0  # global training-iteration counter (across epochs)
        self.val_step = 0  # global validation-iteration counter (across val runs)
        self.iters_per_train_epoch = len(self.train_dataloader)
        self.iters_per_val_epoch = len(self.val_dataloader)
        # NOTE(review): the scaler is created but train_epoch passes the unscaled
        # loss to update_params, so it is currently unused.
        self.grad_scaler = torch.cuda.amp.GradScaler()
        self.collect_input_args = config.collect_input_args
        print_log('successfully init trainer', logger='current')

    def _wandb_enabled(self):
        """Return True when this process should write to wandb (rank 0, debug off)."""
        # `== False` (not `not ...`) keeps the original semantics for non-bool values.
        return self.runner_info.rank == 0 and self.config.debug == False  # noqa: E712

    def _iter_val_due(self):
        """Return True when iteration-based validation should run at ``self.train_step``."""
        train_cfg = self.config.train_cfg
        if train_cfg.val_type != 'iter_base':
            return False
        # train_step has already been incremented for the current iteration, so it
        # equals the number of iterations completed so far; no extra +1 is needed.
        return (self.train_step % train_cfg.val_interval == 0
                and self.train_step >= train_cfg.get('eval_start', 0))

    def log_images(self, log_dict, prefix="", scalar_cmap="turbo_r", min_depth=1e-3, max_depth=80, step=0):
        """Log the first sample of ``log_dict`` to wandb as images.

        Custom log images. Please add more items to the log dict returned from the model.

        ``log_dict['rgb']`` is required. Optional groups are
        ``depth_pred``/``depth_gt``, ``seg_pred``/``seg_gt``, ``mask`` and
        ``pseudo_gt``. For every entry only element ``[0]`` is used.

        Args:
            log_dict: mapping returned by the model alongside its loss/prediction.
            prefix: wandb key prefix (e.g. ``"Train"`` / ``"Val"``).
            scalar_cmap: colormap for the ground-truth and pseudo-gt depth maps.
            min_depth: gt pixels ``<= min_depth`` are treated as invalid.
            max_depth: gt pixels ``>= max_depth`` are treated as invalid.
            step: value logged under ``'<prefix>/step'``.

        Nothing is logged when a depth gt is present but has no valid pixel.
        """
        wimages = {'{}/step'.format(prefix): step}
        depth_key = '{}/LogImageDepth'.format(prefix)

        rgb = log_dict.get('rgb')[0]

        if 'depth_pred' in log_dict:
            depth_pred = log_dict.get('depth_pred')[0].squeeze()
            depth_gt = log_dict.get('depth_gt')[0].squeeze()
            # presumably (h, w) after the squeeze
            invalid_mask = torch.logical_or(
                depth_gt <= min_depth, depth_gt >= max_depth).detach().cpu().squeeze().numpy()
            if np.sum(np.logical_not(invalid_mask)) == 0:
                # All pixels in gt are invalid: skip logging this sample entirely.
                return
            depth_gt_color = colorize(depth_gt, vmin=None, vmax=None, invalid_mask=invalid_mask, cmap=scalar_cmap)
            depth_pred_color = colorize(depth_pred, vmin=None, vmax=None)
            wimages[depth_key] = [
                wandb.Image(rgb, caption='rgb'),
                wandb.Image(depth_gt_color, caption='depth_gt'),
                wandb.Image(depth_pred_color, caption='depth_pred'),
            ]

        if 'seg_pred' in log_dict:
            seg_pred = log_dict.get('seg_pred')[0].squeeze()
            seg_gt = log_dict.get('seg_gt')[0].squeeze()
            class_labels = {0: "bg", 1: "edge"}
            mask_img = wandb.Image(
                rgb,
                masks={
                    "predictions": {"mask_data": seg_pred.detach().cpu().numpy(), "class_labels": class_labels},
                    "ground_truth": {"mask_data": seg_gt.detach().cpu().numpy(), "class_labels": class_labels},
                },
                caption='segmentation')
            wimages['{}/LogImageSeg'.format(prefix)] = [mask_img]

        if 'mask' in log_dict:
            mask = log_dict.get('mask')[0].squeeze().float() * 255
            mask_img = wandb.Image(mask.detach().cpu().numpy(), caption='segmentation')
            # setdefault: the depth group may be absent when no depth_pred was logged.
            wimages.setdefault(depth_key, []).append(mask_img)

        # some other things
        if 'pseudo_gt' in log_dict:
            pseudo_gt = log_dict.get('pseudo_gt')[0].squeeze()
            pseudo_gt_color = colorize(pseudo_gt, vmin=None, vmax=None, cmap=scalar_cmap)
            wimages.setdefault(depth_key, []).append(wandb.Image(pseudo_gt_color, caption='pseudo_gt'))

        wandb.log(wimages)

    def collect_input(self, batch_data):
        """Keep the tensors whose keys are in ``collect_input_args`` and move them to the GPU.

        Non-tensor values and keys not listed in ``config.collect_input_args`` are dropped.
        """
        return {
            key: value.cuda()
            for key, value in batch_data.items()
            if isinstance(value, torch.Tensor) and key in self.collect_input_args
        }

    @torch.no_grad()
    def val_epoch(self):
        """Run the model over the whole validation loader, evaluate and log the metrics.

        Per-batch metrics come from ``dataset.get_metrics``. They are collected from
        all ranks with ``collect_results_gpu`` and evaluated with ``dataset.evaluate``
        on rank 0. When the model returns a list of predictions, each prediction is
        collected and evaluated separately, but only the metrics of the last one are
        pushed to wandb. The model is put back into train mode at the end.
        """
        self.model.eval()
        dataset = self.val_dataloader.dataset
        _, world_size = get_dist_info()
        is_main = self.runner_info.rank == 0
        if is_main:
            prog_bar = mmengine.utils.ProgressBar(len(dataset))

        results = []  # per-batch metrics when the model returns a single prediction
        results_list = []  # results_list[i]: per-batch metrics of the i-th prediction
        multi_result = False

        for idx, batch_data in enumerate(self.val_dataloader):
            self.val_step += 1
            batch_data_collect = self.collect_input(batch_data)
            # NOTE(review): inference kwargs are hard-coded; might use test/val to split cases.
            result, log_dict = self.model(mode='infer', cai_mode='m1', process_num=4, **batch_data_collect)
            multi_result = isinstance(result, list)

            metric_kwargs = dict(
                disp_gt_edges=batch_data.get('boundary', None),
                additional_mask=log_dict.get('mask', None),
                image_hr=batch_data.get('image_hr', None))
            if multi_result:
                # Grow the per-prediction buckets to however many predictions the model returns.
                while len(results_list) < len(result):
                    results_list.append([])
                for num_res, res in enumerate(result):
                    metrics = dataset.get_metrics(batch_data_collect['depth_gt'], res, **metric_kwargs)
                    results_list[num_res].append(metrics)
            else:
                metrics = dataset.get_metrics(
                    batch_data_collect['depth_gt'],
                    result,
                    seg_image=batch_data_collect.get('seg_image', None),
                    **metric_kwargs)
                results.append(metrics)

            if is_main:
                first_result = result[0] if multi_result else result
                # Every rank processes a batch of the same size per step.
                for _ in range(len(first_result) * world_size):
                    prog_bar.update()

            if self._wandb_enabled() and (idx + 1) % self.config.train_cfg.val_log_img_interval == 0:
                self.log_images(
                    log_dict=log_dict, prefix="Val", min_depth=self.config.model.min_depth,
                    max_depth=self.config.model.max_depth, step=self.val_step)

        # collect results from all ranks
        if multi_result:
            results_collect = [collect_results_gpu(res, len(dataset)) for res in results_list]
        else:
            results = collect_results_gpu(results, len(dataset))

        if is_main:
            if multi_result:
                # Each prediction is evaluated; only the last ret_dict is logged below.
                for stage_results in results_collect:
                    ret_dict = dataset.evaluate(stage_results)
            else:
                ret_dict = dataset.evaluate(results)

        if self._wandb_enabled():
            wdict = {
                "Val/{}".format(k): (v.item() if hasattr(v, 'item') else v)
                for k, v in ret_dict.items()
            }
            wdict['Val/step'] = self.val_step
            wandb.log(wdict)

        torch.cuda.empty_cache()
        if self.runner_info.distributed:
            dist.barrier()
        self.model.train()  # avoid changing model state

    def train_epoch(self, epoch_idx):
        """Train for one pass over ``train_dataloader``.

        Each iteration runs forward/backward via ``optimizer_wrapper.update_params``,
        steps the LR scheduler and does rank-0 logging. It also triggers
        iteration-based validation when due.

        Args:
            epoch_idx: zero-based epoch index (displayed as ``epoch_idx + 1``).
        """
        self.model.train()
        if self.runner_info.distributed:
            dist.barrier()

        train_cfg = self.config.train_cfg
        max_epochs = train_cfg.max_epochs
        is_main = self.runner_info.rank == 0
        pbar = enumerate(self.train_dataloader)
        if is_main:
            pbar = tqdm(pbar, desc=f"Epoch: {epoch_idx + 1}/{self.config.train_cfg.max_epochs}. Loop: Train",
                        total=self.iters_per_train_epoch)

        for idx, batch_data in pbar:
            self.train_step += 1
            batch_data_collect = self.collect_input(batch_data)
            loss_dict, log_dict = self.model(mode='train', **batch_data_collect)
            total_loss = loss_dict['total_loss']
            self.optimizer_wrapper.update_params(total_loss)
            self.scheduler.step()

            if is_main:
                log_info = 'Epoch: [{:02d}/{:02d}]'.format(epoch_idx + 1, max_epochs)
                for k, v in loss_dict.items():
                    value = v.item() if isinstance(v, torch.Tensor) else v
                    log_info += ' - {}: {:.2f}'.format(k, value)
                pbar.set_description(log_info)

                if (idx + 1) % train_cfg.log_interval == 0:
                    # NOTE(review): the "Time" field is a placeholder (always 1/1).
                    log_info = 'Epoch: [{:02d}/{:02d}] - Step: [{:05d}/{:05d}] - Time: [{}/{}] - Total Loss: {}'.format(
                        epoch_idx + 1, max_epochs, idx + 1, len(self.train_dataloader), 1, 1, total_loss)
                    for k, v in loss_dict.items():
                        if k != 'total_loss':
                            log_info += ' - {}: {}'.format(k, v)
                    print_log(log_info, logger='current')

                    # Scalars are pushed to wandb every `log_interval` iterations.
                    if self._wandb_enabled():
                        wdict = {
                            'Train/total_loss': total_loss.item(),
                            'Train/LR': self.optimizer_wrapper.get_lr()['lr'][0],
                            'Train/momentum': self.optimizer_wrapper.get_momentum()['momentum'][0],
                            'Train/step': self.train_step,
                        }
                        for k, v in loss_dict.items():
                            if k != 'total_loss':
                                wdict['Train/{}'.format(k)] = v.item() if isinstance(v, torch.Tensor) else v
                        wandb.log(wdict)

            if self._wandb_enabled() and (idx + 1) % train_cfg.train_log_img_interval == 0:
                self.log_images(
                    log_dict=log_dict, prefix="Train", min_depth=self.config.model.min_depth,
                    max_depth=self.config.model.max_depth, step=self.train_step)

            if self._iter_val_due():
                self.val_epoch()

    def save_checkpoint(self, epoch_idx):
        """Save model, optimizer and scheduler state as ``checkpoint_{epoch_idx + 1:02d}.pth``.

        The state dicts are built on every rank; only rank 0 writes the file into
        ``runner_info.work_dir``. If the wrapped model defines ``get_save_dict``, that
        is used instead of ``state_dict``.
        """
        # By default the model is wrapped by DDP, so even single-GPU runs should be
        # launched through dist_train.sh (this method accesses `self.model.module`).
        inner_model = self.model.module
        if hasattr(inner_model, 'get_save_dict'):
            print_log('Saving ckp, but use the inner get_save_dict fuction to get model_dict', logger='current')
            model_dict = inner_model.get_save_dict()
        else:
            model_dict = inner_model.state_dict()

        checkpoint_dict = {
            'epoch': epoch_idx,
            'model_state_dict': model_dict,
            'optim_state_dict': self.optimizer_wrapper.state_dict(),
            'schedule_state_dict': self.scheduler.state_dict()}

        if self.runner_info.rank == 0:
            ckpt_name = 'checkpoint_{:02d}.pth'.format(epoch_idx + 1)
            torch.save(checkpoint_dict, os.path.join(self.runner_info.work_dir, ckpt_name))
            log_info = 'save checkpoint_{:02d}.pth at {}'.format(epoch_idx + 1, self.runner_info.work_dir)
            print_log(log_info, logger='current')

    def run(self):
        """Train for ``train_cfg.max_epochs`` epochs with validation, checkpointing and early stop.

        With ``val_type == 'iter_base'`` a final validation pass runs after the loop.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print_log('training param: {}'.format(name), logger='current')

        train_cfg = self.config.train_cfg
        for epoch_idx in range(train_cfg.max_epochs):
            if self.runner_info.distributed:
                self.train_sampler.set_epoch(epoch_idx)
            self.train_epoch(epoch_idx)

            epoch_num = epoch_idx + 1
            if (epoch_num % train_cfg.val_interval == 0
                    and epoch_num >= train_cfg.get('eval_start', 0)
                    and train_cfg.val_type == 'epoch_base'):
                self.val_epoch()

            if epoch_num % train_cfg.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch_idx)

            # The large default effectively disables early stopping.
            if epoch_num % train_cfg.get('early_stop_epoch', 9999999) == 0:
                print_log('early stop at epoch: {}'.format(epoch_idx), logger='current')
                break

        if train_cfg.val_type == 'iter_base':
            # Final evaluation after the last training iteration.
            self.val_epoch()