Spaces:
Runtime error
Runtime error
import os | |
import random | |
import numpy as np | |
import torch | |
import gradio as gr | |
import matplotlib as mpl | |
import matplotlib.cm as cm | |
from vidar.core.wrapper import Wrapper | |
from vidar.utils.config import read_config | |
def colormap_depth(depth_map, cmap='magma', percentiles=(5, 95)):
    """Colorize a depth map by pushing its inverse depth (disparity) through a colormap.

    Pixels whose depth is exactly 0 are treated as invalid and painted white.
    The color scale is bounded by percentiles of the valid disparities, so a
    few extreme pixels do not wash out the rest of the image.

    Args:
        depth_map: numpy array of depth values (typically H x W); 0 marks
            missing depth.
        cmap: Name of the matplotlib colormap to use (default ``'magma'``).
        percentiles: ``(low, high)`` percentiles of the valid disparities used
            as the lower and upper bounds of the color scale.

    Returns:
        ``uint8`` RGB image of shape ``depth_map.shape + (3,)``.
    """
    depth_map = np.asarray(depth_map)
    mask = depth_map != 0

    # Start from an all-white canvas; only valid pixels are overwritten below.
    colormapped_im = np.full(depth_map.shape + (3,), 255, dtype=np.uint8)
    if not mask.any():
        # No valid depth at all: np.percentile cannot work on an empty selection.
        return colormapped_im

    # Invert only valid pixels, so zero depths never trigger divide-by-zero.
    disp_valid = 1.0 / depth_map[mask]
    vmin, vmax = np.percentile(disp_valid, percentiles)
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    # to_rgba on the 1-D vector of valid disparities returns N x 4; drop alpha.
    colormapped_im[mask] = (mapper.to_rgba(disp_valid)[:, :3] * 255).astype(np.uint8)
    return colormapped_im
def data_to_batch(data):
    """Turn a single dataset sample into a batch for ``arch``.

    Each tensor listed in ``expand`` below gets two leading singleton
    dimensions prepended (two ``unsqueeze(0)`` calls). All other keys and
    context entries are passed through unchanged.

    The per-key containers are copied, so, unlike a plain shallow
    ``dict.copy()``, the caller's sample is not modified.

    Args:
        data: Sample mapping as returned by the dataset. It must contain
            ``rgb``/``pose``/``depth`` entries for contexts 0 and 1, and an
            ``intrinsics`` entry for context 0.

    Returns:
        A new mapping with the expanded tensors.

    Raises:
        KeyError: if any of the required entries is missing.
    """
    import copy

    # key -> contexts whose tensors get the extra leading dims
    expand = {
        'rgb': (0, 1),
        'intrinsics': (0,),
        'pose': (0, 1),
        'depth': (0, 1),
    }
    batch = data.copy()
    for key, contexts in expand.items():
        # Copy the nested container so the original sample stays untouched.
        entry = copy.copy(data[key])
        for ctx in contexts:
            entry[ctx] = entry[ctx].unsqueeze(0).unsqueeze(0)
        batch[key] = entry
    return batch
# Select the execution mode from the available hardware. This is set before the
# Wrapper is built; presumably vidar reads DIST_MODE during construction.
os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

# Model + dataset configuration served by this demo.
cfg_file_path = 'configs/papers/define/scannet_temporal_test_context_1.yaml'
cfg = read_config(cfg_file_path)
wrapper = Wrapper(cfg, verbose=True)

# Network used for inference; eval() puts it in inference mode.
arch = wrapper.arch
arch.eval()

# First validation dataset; the UI picks samples from it by index.
val_dataset = wrapper.datasets['validation'][0]
len_val_dataset = len(val_dataset)
def sample_data_idx():
    """Return a uniformly random valid index into the validation dataset.

    The result lies in ``[0, len_val_dataset - 1]``. An empty dataset raises
    ``ValueError``.
    """
    # randrange(n) draws from [0, n), the same range as randint(0, n - 1).
    return random.randrange(len_val_dataset)
def _parse_idx(idx):
    """Convert the textbox value into a validated validation-dataset index.

    Raises:
        gr.Error: if ``idx`` is not an integer in ``[0, len_val_dataset - 1]``.
            Gradio shows the message to the user instead of a generic error.
    """
    try:
        index = int(idx)
    except (TypeError, ValueError):
        raise gr.Error(f"Invalid data index {idx!r}: please enter an integer.") from None
    if not 0 <= index < len_val_dataset:
        raise gr.Error(
            f"Data index {index} is out of range: it must be between 0 and {len_val_dataset - 1}."
        )
    return index


def display_images_from_idx(idx):
    """Return the RGB images of validation sample ``idx`` as H x W x C numpy arrays.

    One image is produced per entry of the sample's ``'rgb'`` mapping, in
    iteration order. Each tensor is permuted from C x H x W to H x W x C.
    """
    rgbs = val_dataset[_parse_idx(idx)]['rgb']
    # detach()/cpu() are no-ops for plain CPU tensors and make numpy conversion safe otherwise.
    return [np.array(rgb.detach().cpu().permute(1, 2, 0)) for rgb in rgbs.values()]


def infer_depth_from_idx(idx):
    """Run the network on validation sample ``idx`` and colorize its depth predictions.

    Returns one colorized image (from ``colormap_depth``) per entry of
    ``output['predictions']['depth']``; only the first element
    (``output_depth[0]``) of each entry is visualized.
    """
    batch = data_to_batch(val_dataset[_parse_idx(idx)])
    # Pure inference: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        output = arch(batch, epoch=0)
    output_depths = output['predictions']['depth']

    images = []
    for output_depth in output_depths.values():
        # Strip up to three leading singleton dims, then move to host memory,
        # since numpy/matplotlib cannot read CUDA tensors.
        depth = output_depth[0].squeeze(0).squeeze(0).squeeze(0)
        images.append(colormap_depth(depth.detach().cpu().numpy()))
    return images
with gr.Blocks() as demo:
    # ---- layout ----
    # Gallery sizing goes to the constructor: the `.style(grid=..., height=...)`
    # method was deprecated in Gradio 3.x and removed in Gradio 4.
    img_box = gr.Gallery(label="Sampled Images", columns=2, height="auto")
    data_idx_box = gr.Textbox(
        label="Sampled Data Index",
        placeholder="Number between {} and {}".format(0, len_val_dataset-1),
        interactive=True,
    )
    sample_btn = gr.Button('Sample Dataset')
    depth_box = gr.Gallery(label="Inferred Depth", columns=2, height="auto")
    infer_btn = gr.Button('Depth Infer')

    # ---- actions ----
    # Sampling writes a random index into the textbox; only if that step
    # succeeds are the sample's RGB images shown.
    sample_btn.click(
        fn=sample_data_idx,
        inputs=None,
        outputs=data_idx_box,
    ).success(
        fn=display_images_from_idx,
        inputs=data_idx_box,
        outputs=img_box,
    )
    # Inference uses whatever index is currently in the textbox (sampled or typed).
    infer_btn.click(
        fn=infer_depth_from_idx,
        inputs=data_idx_box,
        outputs=depth_box,
    )

# Listen on all interfaces, as the hosting container requires.
demo.launch(server_name="0.0.0.0", server_port=7860)