import os
import wandb
import numpy as np
import torch
import mmengine
from mmengine.optim import build_optim_wrapper
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.distributed as dist
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu
from mmengine import print_log
import torch.nn.functional as F
from tqdm import tqdm
from estimator.utils import colorize


class Trainer:
    """
    Trainer class
    """
    def __init__(
        self,
        config,
        runner_info,
        train_sampler,
        train_dataloader,
        val_dataloader,
        model):

        self.config = config
        self.runner_info = runner_info
        self.train_sampler = train_sampler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.model = model

        # build opt and schedule
        self.optimizer_wrapper = build_optim_wrapper(self.model, config.optim_wrapper)
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer_wrapper.optimizer,
            [l['lr'] for l in self.optimizer_wrapper.optimizer.param_groups],
            epochs=self.config.train_cfg.max_epochs,
            steps_per_epoch=len(self.train_dataloader),
            cycle_momentum=config.param_scheduler.cycle_momentum,
            base_momentum=config.param_scheduler.get('base_momentum', 0.85),
            max_momentum=config.param_scheduler.get('max_momentum', 0.95),
            div_factor=config.param_scheduler.div_factor,
            final_div_factor=config.param_scheduler.final_div_factor,
            pct_start=config.param_scheduler.pct_start,
            three_phase=config.param_scheduler.three_phase)
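        # Note: the scheduler is stepped once per training iteration (see train_epoch),
        # which is why steps_per_epoch=len(self.train_dataloader) is passed here.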

        # I'd like to use wandb log_name
        self.train_step = 0  # for training
        self.val_step = 0  # for validation

        self.iters_per_train_epoch = len(self.train_dataloader)
        self.iters_per_val_epoch = len(self.val_dataloader)

        self.grad_scaler = torch.cuda.amp.GradScaler()  # kept for the (currently commented-out) AMP path in train_epoch
        self.collect_input_args = config.collect_input_args
        print_log('successfully init trainer', logger='current')

    def log_images(self, log_dict, prefix="", scalar_cmap="turbo_r", min_depth=1e-3, max_depth=80, step=0):
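        """Log the first sample of a batch (rgb / depth / segmentation / mask) to wandb."""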
        # Custom log images. Please add more items to the log dict returned from the model.
        wimages = dict()
        wimages['{}/step'.format(prefix)] = step

        rgb = log_dict.get('rgb')[0]
        _, h_rgb, w_rgb = rgb.shape

        if 'depth_pred' in log_dict.keys():
            depth_pred = log_dict.get('depth_pred')[0]
            depth_pred = depth_pred.squeeze()
            depth_gt = log_dict.get('depth_gt')[0]
            depth_gt = depth_gt.squeeze()

            invalid_mask = torch.logical_or(depth_gt <= min_depth, depth_gt >= max_depth).detach().cpu().squeeze().numpy()  # (h, w)
            if np.sum(np.logical_not(invalid_mask)) == 0:  # all pixels in gt are invalid
                return

            depth_gt_color = colorize(depth_gt, vmin=None, vmax=None, invalid_mask=invalid_mask, cmap=scalar_cmap)
            depth_pred_color = colorize(depth_pred, vmin=None, vmax=None)

            depth_gt_img = wandb.Image(depth_gt_color, caption='depth_gt')
            depth_pred_img = wandb.Image(depth_pred_color, caption='depth_pred')
            rgb_img = wandb.Image(rgb, caption='rgb')  # keep the raw tensor in `rgb` for the seg branch below
            wimages['{}/LogImageDepth'.format(prefix)] = [rgb_img, depth_gt_img, depth_pred_img]

        if 'seg_pred' in log_dict.keys():
            seg_pred = log_dict.get('seg_pred')[0]
            seg_pred = seg_pred.squeeze()
            seg_gt = log_dict.get('seg_gt')[0]
            seg_gt = seg_gt.squeeze()

            # class_labels = {0: "good", 1: "refine", 2: "oor", 3: "sky"}
            class_labels = {0: "bg", 1: "edge"}
            mask_img = wandb.Image(
                rgb,
                masks={
                    "predictions": {"mask_data": seg_pred.detach().cpu().numpy(), "class_labels": class_labels},
                    "ground_truth": {"mask_data": seg_gt.detach().cpu().numpy(), "class_labels": class_labels},
                },
                caption='segmentation')
            wimages['{}/LogImageSeg'.format(prefix)] = [mask_img]

        if 'mask' in log_dict.keys():
            mask = log_dict.get('mask')[0]
            mask = mask.squeeze().float() * 255
            mask_img = wandb.Image(
                mask.detach().cpu().numpy(),
                caption='segmentation')
            cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            cur_log.append(mask_img)
            wimages['{}/LogImageDepth'.format(prefix)] = cur_log

        # some other things
        if 'pseudo_gt' in log_dict.keys():
            pseudo_gt = log_dict.get('pseudo_gt')[0]
            pseudo_gt = pseudo_gt.squeeze()
            pseudo_gt_color = colorize(pseudo_gt, vmin=None, vmax=None, cmap=scalar_cmap)
            pseudo_gt_img = wandb.Image(pseudo_gt_color, caption='pseudo_gt')
            cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            cur_log.append(pseudo_gt_img)

            # pseudo_gt = log_dict.get('pseudo_gt')[0][0]
            # pseudo_gt = pseudo_gt * 255
            # pseudo_gt = pseudo_gt.astype(np.uint8)
            # pseudo_gt_img = wandb.Image(pseudo_gt, caption='pseudo_gt')
            # cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            # cur_log.append(pseudo_gt_img)

        wandb.log(wimages)

    def collect_input(self, batch_data):
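        """Keep only the tensors named in config.collect_input_args and move them to GPU."""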
        collect_batch_data = dict()
        for k, v in batch_data.items():
            if isinstance(v, torch.Tensor):
                if k in self.collect_input_args:
                    collect_batch_data[k] = v.cuda()
        return collect_batch_data

    def val_epoch(self):
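        """Run inference on the validation set, gather per-sample metrics across ranks,
        and (on rank 0) log metrics and example images to wandb."""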
        results = []
        results_list = [[] for _ in range(8)]  # buffers for models that return a list of predictions

        self.model.eval()
        dataset = self.val_dataloader.dataset
        loader_indices = self.val_dataloader.batch_sampler

        rank, world_size = get_dist_info()
        if self.runner_info.rank == 0:
            prog_bar = mmengine.utils.ProgressBar(len(dataset))

        for idx, (batch_indices, batch_data) in enumerate(zip(loader_indices, self.val_dataloader)):
            self.val_step += 1

            batch_data_collect = self.collect_input(batch_data)
            # result, log_dict = self.model(mode='infer', **batch_data_collect)
            result, log_dict = self.model(mode='infer', cai_mode='m1', process_num=4, **batch_data_collect)  # might use test/val to split cases

            if isinstance(result, list):
                # in case you have multiple results
                for num_res in range(len(result)):
                    metrics = dataset.get_metrics(
                        batch_data_collect['depth_gt'],
                        result[num_res],
                        disp_gt_edges=batch_data.get('boundary', None),
                        additional_mask=log_dict.get('mask', None),
                        image_hr=batch_data.get('image_hr', None))
                    results_list[num_res].extend([metrics])
            else:
                metrics = dataset.get_metrics(
                    batch_data_collect['depth_gt'],
                    result,
                    seg_image=batch_data_collect.get('seg_image', None),
                    disp_gt_edges=batch_data.get('boundary', None),
                    additional_mask=log_dict.get('mask', None),
                    image_hr=batch_data.get('image_hr', None))
                results.extend([metrics])

            if self.runner_info.rank == 0:
                if isinstance(result, list):
                    batch_size = len(result[0]) * world_size
                else:
                    batch_size = len(result) * world_size
                for _ in range(batch_size):
                    prog_bar.update()

            if self.runner_info.rank == 0 and self.config.debug == False and (idx + 1) % self.config.train_cfg.val_log_img_interval == 0:
                self.log_images(log_dict=log_dict, prefix="Val", min_depth=self.config.model.min_depth, max_depth=self.config.model.max_depth, step=self.val_step)

        # collect results from all ranks
        if isinstance(result, list):
            results_collect = []
            for res in results_list:
                res = collect_results_gpu(res, len(dataset))
                results_collect.append(res)
        else:
            results = collect_results_gpu(results, len(dataset))

        if self.runner_info.rank == 0:
            if isinstance(result, list):
                for num_refine in range(len(result)):
                    ret_dict = dataset.evaluate(results_collect[num_refine])
            else:
                ret_dict = dataset.evaluate(results)

        if self.runner_info.rank == 0 and self.config.debug == False:
            wdict = dict()
            for k, v in ret_dict.items():
                wdict["Val/{}".format(k)] = v.item()
            wdict['Val/step'] = self.val_step
            wandb.log(wdict)

        torch.cuda.empty_cache()
        if self.runner_info.distributed:
            torch.distributed.barrier()

        self.model.train()  # avoid changing model state

    def train_epoch(self, epoch_idx):
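        """Train for one epoch: forward, update parameters via the optim wrapper, step the
        LR scheduler every iteration, and log losses/images on rank 0."""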
        self.model.train()

        if self.runner_info.distributed:
            dist.barrier()

        pbar = tqdm(enumerate(self.train_dataloader), desc=f"Epoch: {epoch_idx + 1}/{self.config.train_cfg.max_epochs}. Loop: Train",
                    total=self.iters_per_train_epoch) if self.runner_info.rank == 0 else enumerate(self.train_dataloader)

        for idx, batch_data in pbar:
            self.train_step += 1

            batch_data_collect = self.collect_input(batch_data)
            loss_dict, log_dict = self.model(mode='train', **batch_data_collect)

            total_loss = loss_dict['total_loss']
            # total_loss = self.grad_scaler.scale(loss_dict['total_loss'])
            self.optimizer_wrapper.update_params(total_loss)
            self.scheduler.step()

            # log something here
            if self.runner_info.rank == 0:
                log_info = 'Epoch: [{:02d}/{:02d}] - Step: [{:05d}/{:05d}]'.format(epoch_idx + 1, self.config.train_cfg.max_epochs, idx + 1, len(self.train_dataloader))
                for k, v in loss_dict.items():
                    log_info += ' - {}: {:.2f}'.format(k, v.item())
                pbar.set_description(log_info)

                if (idx + 1) % self.config.train_cfg.log_interval == 0:
                    log_info = 'Epoch: [{:02d}/{:02d}] - Step: [{:05d}/{:05d}] - Time: [{}/{}] - Total Loss: {:.4f}'.format(
                        epoch_idx + 1, self.config.train_cfg.max_epochs, idx + 1, len(self.train_dataloader), 1, 1, total_loss.item())
                    for k, v in loss_dict.items():
                        if k != 'total_loss':
                            log_info += ' - {}: {}'.format(k, v)
                    print_log(log_info, logger='current')

            if self.runner_info.rank == 0 and self.config.debug == False:
                wdict = dict()
                wdict['Train/total_loss'] = total_loss.item()
                wdict['Train/LR'] = self.optimizer_wrapper.get_lr()['lr'][0]
                wdict['Train/momentum'] = self.optimizer_wrapper.get_momentum()['momentum'][0]
                wdict['Train/step'] = self.train_step
                for k, v in loss_dict.items():
                    if k != 'total_loss':
                        if isinstance(v, torch.Tensor):
                            wdict['Train/{}'.format(k)] = v.item()
                        else:
                            wdict['Train/{}'.format(k)] = v
                wandb.log(wdict)

            if self.runner_info.rank == 0 and self.config.debug == False and (idx + 1) % self.config.train_cfg.train_log_img_interval == 0:
                self.log_images(log_dict=log_dict, prefix="Train", min_depth=self.config.model.min_depth, max_depth=self.config.model.max_depth, step=self.train_step)

            if self.config.train_cfg.val_type == 'iter_base':
                if (self.train_step + 1) % self.config.train_cfg.val_interval == 0 and (self.train_step + 1) >= self.config.train_cfg.get('eval_start', 0):
                    self.val_epoch()

    def save_checkpoint(self, epoch_idx):
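        """Save model / optimizer / scheduler state on rank 0; weights are read from
        self.model.module since the model is expected to be DDP-wrapped."""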
        # By default, the model is wrapped by DDP!!! Hence, even if you're using one gpu, please use dist_train.sh
        if hasattr(self.model.module, 'get_save_dict'):
            print_log('Saving ckp, but use the inner get_save_dict function to get model_dict', logger='current')
            # print_log('For saving space. Would you like to save base model several times? :>', logger='current')
            model_dict = self.model.module.get_save_dict()
        else:
            model_dict = self.model.module.state_dict()

        checkpoint_dict = {
            'epoch': epoch_idx,
            'model_state_dict': model_dict,
            'optim_state_dict': self.optimizer_wrapper.state_dict(),
            'schedule_state_dict': self.scheduler.state_dict()}

        if self.runner_info.rank == 0:
            torch.save(checkpoint_dict, os.path.join(self.runner_info.work_dir, 'checkpoint_{:02d}.pth'.format(epoch_idx + 1)))
            log_info = 'save checkpoint_{:02d}.pth at {}'.format(epoch_idx + 1, self.runner_info.work_dir)
            print_log(log_info, logger='current')

    def run(self):
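        """Main loop: iterate over epochs, run validation and checkpointing at the
        configured intervals, and honor the optional early-stop epoch."""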
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print_log('training param: {}'.format(name), logger='current')

        # self.val_epoch() # do you want to debug val step?

        for epoch_idx in range(self.config.train_cfg.max_epochs):
            if self.runner_info.distributed:
                self.train_sampler.set_epoch(epoch_idx)
            self.train_epoch(epoch_idx)

            if (epoch_idx + 1) % self.config.train_cfg.val_interval == 0 and (epoch_idx + 1) >= self.config.train_cfg.get('eval_start', 0) and self.config.train_cfg.val_type == 'epoch_base':
                self.val_epoch()

            if (epoch_idx + 1) % self.config.train_cfg.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch_idx)

            if (epoch_idx + 1) % self.config.train_cfg.get('early_stop_epoch', 9999999) == 0:  # Are you using 99999999+ epochs?
                print_log('early stop at epoch: {}'.format(epoch_idx), logger='current')
                break

        if self.config.train_cfg.val_type == 'iter_base':
            self.val_epoch()
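

# A minimal usage sketch (hedged: `build_config`, `build_runner_info`, `build_dataloaders`,
# and `build_model` are hypothetical helpers; the real project wires these pieces up in its
# launch/dist_train script before handing them to Trainer):
#
#     config = build_config('configs/example_config.py')              # mmengine-style config
#     runner_info = build_runner_info(config)                         # rank, distributed flag, work_dir, ...
#     train_sampler, train_loader, val_loader = build_dataloaders(config)
#     model = build_model(config).cuda()                              # typically DDP-wrapped
#
#     trainer = Trainer(config, runner_info, train_sampler, train_loader, val_loader, model)
#     trainer.run()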