Zhyever
refactor
1f418ff
raw
history blame
3.35 kB
import os
import cv2
import wandb
import numpy as np
import torch
import mmengine
from mmengine.optim import build_optim_wrapper
import torch.optim as optim
import matplotlib.pyplot as plt
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu
from mmengine import print_log
from estimator.utils import colorize, colorize_infer_pfv1, colorize_rescale
import torch.nn.functional as F
from tqdm import tqdm
from mmengine.utils import mkdir_or_exist
import copy
from skimage import io
import kornia
from PIL import Image
class Tester:
    """Inference/evaluation driver for a model over a dataloader.

    Each batch is fed to the model in ``'infer'`` mode. Predictions are
    optionally written to ``runner_info.work_dir``; when batches carry a
    ``'depth_gt'`` tensor, per-batch metrics are gathered across ranks and
    passed to ``dataset.evaluate``.
    """

    def __init__(
            self,
            config,
            runner_info,
            dataloader,
            model):
        """Store the collaborators used by :meth:`run`.

        Args:
            config: Object exposing ``collect_input_args``, the batch keys
                that are forwarded to the model.
            runner_info: Object exposing ``rank``, ``save`` and ``work_dir``.
            dataloader: Iterable of batch dicts exposing a ``dataset``
                attribute.
            model: Callable invoked as ``model(mode='infer', ...)``.
        """
        self.config = config
        self.runner_info = runner_info
        self.dataloader = dataloader
        self.model = model
        self.collect_input_args = config.collect_input_args

    def collect_input(self, batch_data):
        """Select the tensors the model consumes and move them to CUDA.

        Only entries that are ``torch.Tensor`` instances *and* whose key is
        listed in ``self.collect_input_args`` are kept; all other entries
        (non-tensors, unlisted keys) are dropped.

        Args:
            batch_data: Mapping of batch keys to values.

        Returns:
            A new dict holding the selected tensors after ``.cuda()``.
        """
        return {
            key: value.cuda()
            for key, value in batch_data.items()
            if isinstance(value, torch.Tensor) and key in self.collect_input_args
        }

    def _save_prediction(self, result, basename):
        """Write a colour-mapped PNG and a 16-bit depth PNG for ``result``."""
        work_dir = self.runner_info.work_dir

        # Channels are reversed before cv2.imwrite, which stores BGR
        # (presumably colorize returns RGB -- confirm against its definition).
        color_pred = colorize(result, cmap='magma_r')[:, :, [2, 1, 0]]
        cv2.imwrite(os.path.join(work_dir, '{}.png'.format(basename)), color_pred)

        # Depth is stored as uint16 with a fixed scale of 256. Saturate
        # instead of letting out-of-range values wrap around in the cast.
        depth = result.squeeze().detach().cpu().numpy() * 256
        depth = np.clip(depth, 0, np.iinfo(np.uint16).max)
        raw_depth = Image.fromarray(depth.astype('uint16'))
        raw_depth.save(os.path.join(work_dir, '{}_uint16.png'.format(basename)))

    @torch.no_grad()
    def run(self, cai_mode='p16', process_num=4):
        """Run inference over the whole dataloader and evaluate if possible.

        Args:
            cai_mode: Forwarded unchanged to the model as ``cai_mode``.
            process_num: Forwarded unchanged to the model as ``process_num``.

        Returns:
            The value returned by ``dataset.evaluate`` on rank 0 when the
            last processed batch carried ``'depth_gt'``; ``None`` otherwise
            (other ranks, no ground truth, or an empty dataloader).
        """
        dataset = self.dataloader.dataset
        _, world_size = get_dist_info()
        is_main = self.runner_info.rank == 0
        prog_bar = mmengine.utils.ProgressBar(len(dataset)) if is_main else None

        results = []
        # Tracked explicitly so the post-loop check is well defined even when
        # the dataloader yields no batches at all.
        has_gt = False

        for batch_data in self.dataloader:
            model_inputs = self.collect_input(batch_data)
            # might use test/val to split cases
            result, _log_dict = self.model(
                mode='infer',
                cai_mode=cai_mode,
                process_num=process_num,
                **model_inputs)

            if self.runner_info.save:
                # Only the first basename of the batch is used for the files.
                self._save_prediction(result, batch_data['img_file_basename'][0])

            depth_gt = model_inputs.get('depth_gt', None)
            has_gt = depth_gt is not None
            if has_gt:
                metrics = dataset.get_metrics(
                    depth_gt,
                    result,
                    seg_image=model_inputs.get('seg_image', None),
                    disp_gt_edges=batch_data.get('boundary', None),
                    image_hr=batch_data.get('image_hr', None))
                results.append(metrics)

            if is_main:
                # Rank 0 advances the bar on behalf of every rank.
                for _ in range(len(result) * world_size):
                    prog_bar.update()

        if not has_gt:
            return None

        # Executed on every rank; only rank 0 evaluates the gathered results.
        results = collect_results_gpu(results, len(dataset))
        if is_main:
            return dataset.evaluate(results)
        return None