"""Gradio demo: sample a validation sample and visualize predicted depth.

Loads a vidar model and its validation dataset from a DeFiNe ScanNet config,
then serves a small web UI that (1) samples an index from the validation split
and shows its RGB images, and (2) runs the network on that sample and shows a
colormapped inverse-depth image for each predicted depth entry.
"""

import os
import random

import numpy as np
import torch
import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm

from vidar.core.wrapper import Wrapper
from vidar.utils.config import read_config


def colormap_depth(depth_map):
    """Colorize a depth map by its inverse depth (disparity).

    Args:
        depth_map: HxW numpy array of depth values. Entries that are not
            finite and strictly positive are treated as invalid.

    Returns:
        HxWx3 uint8 RGB image. Valid pixels are mapped through the 'magma'
        colormap, normalized to the 5th-95th percentile of valid disparities;
        invalid pixels are white (255).
    """
    depth_map = np.asarray(depth_map, dtype=np.float64)
    valid = np.isfinite(depth_map) & (depth_map > 0)
    colormapped_im = np.full(depth_map.shape + (3,), 255, dtype=np.uint8)
    if not valid.any():
        # Nothing to colorize; np.percentile would raise on an empty array.
        return colormapped_im

    # Invert only valid pixels so invalid ones never yield inf/NaN or a
    # divide-by-zero warning.
    disp_map = np.zeros_like(depth_map)
    disp_map[valid] = 1.0 / depth_map[valid]

    vmin = np.percentile(disp_map[valid], 5)
    vmax = np.percentile(disp_map[valid], 95)
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    rgb = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8)
    colormapped_im[valid] = rgb[valid]
    return colormapped_im


# Fields and context keys that data_to_batch adds leading dimensions to.
# Mirrors the original hand-written list: intrinsics are batched for
# context 0 only.
_BATCH_KEYS = {
    'rgb': (0, 1),
    'intrinsics': (0,),
    'pose': (0, 1),
    'depth': (0, 1),
}


def data_to_batch(data):
    """Turn a single dataset sample into the batched form passed to ``arch``.

    Each tensor listed in ``_BATCH_KEYS`` gets two leading singleton
    dimensions via ``unsqueeze(0).unsqueeze(0)``. The per-field containers are
    copied first so the caller's ``data`` is left untouched (a plain
    ``data.copy()`` is shallow and would share them).
    """
    batch = data.copy()
    for field, ctx_keys in _BATCH_KEYS.items():
        batched = batch[field].copy()
        for ctx in ctx_keys:
            batched[ctx] = batched[ctx].unsqueeze(0).unsqueeze(0)
        batch[field] = batched
    return batch


os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

cfg_file_path = 'configs/papers/define/scannet_temporal_test_context_1.yaml'
cfg = read_config(cfg_file_path)
wrapper = Wrapper(cfg, verbose=True)

arch = wrapper.arch
arch.eval()
val_dataset = wrapper.datasets['validation'][0]
len_val_dataset = len(val_dataset)


def _parse_idx(idx):
    """Convert the textbox value to an in-range int index, or raise gr.Error."""
    try:
        idx = int(idx)
    except (TypeError, ValueError):
        raise gr.Error(f'Data index must be an integer, got {idx!r}') from None
    if not 0 <= idx < len_val_dataset:
        raise gr.Error(
            f'Data index must be between 0 and {len_val_dataset - 1}, got {idx}'
        )
    return idx


def sample_data_idx():
    """Return a uniformly random valid index into the validation dataset."""
    return random.randint(0, len_val_dataset - 1)


def display_images_from_idx(idx):
    """Return the sample's RGB images as HxWxC numpy arrays for the gallery."""
    rgbs = val_dataset[_parse_idx(idx)]['rgb']
    return [rgb.permute(1, 2, 0).cpu().numpy() for rgb in rgbs.values()]


def infer_depth_from_idx(idx):
    """Run the network on one sample and return a colormapped depth per entry."""
    batch = data_to_batch(val_dataset[_parse_idx(idx)])
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = arch(batch, epoch=0)
    output_depths = output['predictions']['depth']
    # .cpu() is required before .numpy() when the model runs on GPU.
    return [
        colormap_depth(
            output_depth[0].squeeze(0).squeeze(0).squeeze(0).detach().cpu().numpy()
        )
        for output_depth in output_depths.values()
    ]


with gr.Blocks() as demo:
    # layout
    img_box = gr.Gallery(label="Sampled Images").style(grid=[2], height="auto")
    data_idx_box = gr.Textbox(
        label="Sampled Data Index",
        placeholder=f"Number between 0 and {len_val_dataset - 1}",
        interactive=True,
    )
    sample_btn = gr.Button('Sample Dataset')
    depth_box = gr.Gallery(label="Inferred Depth").style(grid=[2], height="auto")
    infer_btn = gr.Button('Depth Infer')

    # actions
    sample_btn.click(
        fn=sample_data_idx,
        inputs=None,
        outputs=data_idx_box,
    ).success(
        fn=display_images_from_idx,
        inputs=data_idx_box,
        outputs=img_box,
    )
    infer_btn.click(
        fn=infer_depth_from_idx,
        inputs=data_idx_box,
        outputs=depth_box,
    )

demo.launch(server_name="0.0.0.0", server_port=7860)