# Spaces: Runtime error
import os | |
import cv2 | |
import wandb | |
import numpy as np | |
import torch | |
import mmengine | |
from mmengine.optim import build_optim_wrapper | |
import torch.optim as optim | |
import matplotlib.pyplot as plt | |
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu | |
from mmengine import print_log | |
from estimator.utils import colorize, colorize_infer_pfv1, colorize_rescale | |
import torch.nn.functional as F | |
from tqdm import tqdm | |
from mmengine.utils import mkdir_or_exist | |
import copy | |
from skimage import io | |
import kornia | |
from PIL import Image | |
class Tester: | |
""" | |
Tester class | |
""" | |
def __init__(self, config, runner_info, dataloader, model):
    """Keep references to everything :meth:`run` needs.

    Args:
        config: Test configuration. Must expose ``collect_input_args``, the
            collection of batch keys that :meth:`collect_input` keeps.
        runner_info: Run-time settings; ``run`` reads ``rank``, ``save`` and
            ``work_dir`` from it.
        dataloader: Loader iterated by ``run``; its ``dataset`` and
            ``batch_sampler`` attributes are used as well.
        model: Callable invoked by ``run`` as ``model(mode='infer', ...)``.

    Raises:
        AttributeError: if ``config`` has no ``collect_input_args`` attribute
            (raised after the other four attributes have been stored).
    """
    # Collaborators are stored verbatim; nothing is copied or validated here.
    self.model = model
    self.dataloader = dataloader
    self.runner_info = runner_info
    self.config = config
    # Cached so collect_input does not have to go through the config each batch.
    self.collect_input_args = self.config.collect_input_args
def collect_input(self, batch_data, device=None):
    """Pick the model inputs out of a batch and move them to a device.

    An entry of ``batch_data`` is kept only when its value is a
    ``torch.Tensor`` AND its key is in ``self.collect_input_args``. Every
    other entry, including non-tensor metadata, is dropped.

    Args:
        batch_data (dict): One batch as yielded by the dataloader.
        device (str | torch.device | None): Destination for the kept tensors.
            ``None`` (the default) preserves the historical behaviour of
            calling ``Tensor.cuda()``, i.e. the current CUDA device. Pass e.g.
            ``'cpu'`` to run on a machine without a GPU.

    Returns:
        dict: Key -> moved tensor, for every kept entry, in ``batch_data``
        iteration order.
    """
    def _move(tensor):
        # Keep the exact original call when no device is requested.
        return tensor.cuda() if device is None else tensor.to(device)

    wanted = self.collect_input_args
    return {
        key: _move(value)
        for key, value in batch_data.items()
        if isinstance(value, torch.Tensor) and key in wanted
    }
def run(self, cai_mode='p16', process_num=4): | |
results = [] | |
dataset = self.dataloader.dataset | |
loader_indices = self.dataloader.batch_sampler | |
rank, world_size = get_dist_info() | |
if self.runner_info.rank == 0: | |
prog_bar = mmengine.utils.ProgressBar(len(dataset)) | |
for idx, (batch_indices, batch_data) in enumerate(zip(loader_indices, self.dataloader)): | |
batch_data_collect = self.collect_input(batch_data) | |
result, log_dict = self.model(mode='infer', cai_mode=cai_mode, process_num=process_num, **batch_data_collect) # might use test/val to split cases | |
if self.runner_info.save: | |
color_pred = colorize(result, cmap='magma_r')[:, :, [2, 1, 0]] | |
cv2.imwrite(os.path.join(self.runner_info.work_dir, '{}.png'.format(batch_data['img_file_basename'][0])), color_pred) | |
# Save as PNG | |
raw_depth = Image.fromarray((result.clone().squeeze().detach().cpu().numpy()*256).astype('uint16')) | |
raw_depth.save(os.path.join(self.runner_info.work_dir, '{}_uint16.png'.format(batch_data['img_file_basename'][0]))) | |
if batch_data_collect.get('depth_gt', None) is not None: | |
metrics = dataset.get_metrics( | |
batch_data_collect['depth_gt'], | |
result, | |
seg_image=batch_data_collect.get('seg_image', None), | |
disp_gt_edges=batch_data.get('boundary', None), | |
image_hr=batch_data.get('image_hr', None)) | |
results.extend([metrics]) | |
if self.runner_info.rank == 0: | |
batch_size = len(result) * world_size | |
for _ in range(batch_size): | |
prog_bar.update() | |
if batch_data_collect.get('depth_gt', None) is not None: | |
results = collect_results_gpu(results, len(dataset)) | |
if self.runner_info.rank == 0: | |
ret_dict = dataset.evaluate(results) | |