File size: 3,605 Bytes
2512c83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import random

import numpy as np
import torch
import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm

from vidar.core.wrapper import Wrapper
from vidar.utils.config import read_config


def colormap_depth(depth_map, cmap='magma', percentiles=(5, 95)):
    """Colorize a depth map by its inverse depth (disparity) for display.

    Args:
        depth_map: HxW numpy array of depth values. Pixels equal to 0 are
            treated as invalid and rendered white.
        cmap: Matplotlib colormap name used for valid pixels.
        percentiles: (low, high) percentiles of the valid disparities used as
            the colormap's vmin/vmax, so outliers do not wash out the contrast.

    Returns:
        HxWx3 uint8 RGB image. If no pixel is valid, the image is all white.
    """
    depth_map = np.asarray(depth_map)
    valid = depth_map != 0
    colormapped_im = np.full(depth_map.shape + (3,), 255, dtype=np.uint8)
    if not valid.any():
        # Nothing to colorize; np.percentile on an empty selection would raise.
        return colormapped_im

    # Invert only the valid pixels so zero depths do not trigger a
    # divide-by-zero warning; invalid entries stay 0 and remain white below.
    disp_map = np.zeros(depth_map.shape, dtype=np.float64)
    np.divide(1.0, depth_map, out=disp_map, where=valid)

    low, high = percentiles
    vmin, vmax = np.percentile(disp_map[valid], [low, high])
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    rgb = (mapper.to_rgba(disp_map)[..., :3] * 255).astype(np.uint8)
    colormapped_im[valid] = rgb[valid]
    return colormapped_im

def data_to_batch(data):
    """Turn a single dataset sample into a batch the model can consume.

    Adds two leading singleton dimensions (``unsqueeze(0)`` twice) to:
    ``rgb`` contexts 0 and 1, ``intrinsics`` context 0, ``pose`` contexts 0
    and 1, and ``depth`` contexts 0 and 1. Every other entry is passed
    through unchanged.

    The caller's ``data`` is not modified. Each affected per-field container
    is copied before its entries are replaced.

    Args:
        data: Dataset sample mapping field names to containers of torch
            tensors indexed by context id.

    Returns:
        A new dict holding the expanded tensors.

    Raises:
        KeyError / IndexError: if one of the expected fields or contexts is
            missing from ``data``.
    """
    import copy

    # Which context entries of each field receive the extra leading dims.
    expand = {
        'rgb': (0, 1),
        'intrinsics': (0,),
        'pose': (0, 1),
        'depth': (0, 1),
    }
    batch = data.copy()
    for key, contexts in expand.items():
        # data.copy() is shallow: copy the inner container too, otherwise the
        # assignments below would write back into the caller's sample.
        field = copy.copy(batch[key])
        for ctx in contexts:
            field[ctx] = field[ctx].unsqueeze(0).unsqueeze(0)
        batch[key] = field
    return batch


# DIST_MODE is presumably read by vidar's setup to choose the device:
# 'gpu' when CUDA is available, otherwise 'cpu'.
os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

# Config describing the model and datasets. The default path can be
# overridden with the VIDAR_DEMO_CFG environment variable.
cfg_file_path = os.environ.get(
    'VIDAR_DEMO_CFG',
    'configs/papers/define/scannet_temporal_test_context_1.yaml',
)
cfg = read_config(cfg_file_path)

# The wrapper exposes the model (`arch`) and the datasets built from cfg.
wrapper = Wrapper(cfg, verbose=True)

# The demo only runs inference, so the model is put in eval mode.
arch = wrapper.arch
arch.eval()

# The UI samples from the first validation dataset.
val_dataset = wrapper.datasets['validation'][0]
len_val_dataset = len(val_dataset)

def sample_data_idx():
    """Pick a uniformly random index into the validation dataset.

    Returns:
        An int in the half-open range ``[0, len_val_dataset)``.
    """
    return random.randrange(len_val_dataset)

def display_images_from_idx(idx):
    """Return the RGB frames of one validation sample as numpy images.

    Args:
        idx: Sample index; anything ``int`` accepts (the textbox passes a
            string).

    Returns:
        One numpy array per entry of the sample's ``'rgb'`` mapping, with the
        tensor's axes reordered by ``permute(1, 2, 0)``. This gives
        channels-last images, assuming the tensors are channel-first.
    """
    sample = val_dataset[int(idx)]
    channels_last = (frame.permute(1, 2, 0) for frame in sample['rgb'].values())
    return list(map(np.array, channels_last))

def infer_depth_from_idx(idx):
    """Run depth inference on one validation sample and colorize the result.

    Args:
        idx: Sample index; anything ``int`` accepts (the textbox passes a
            string).

    Returns:
        A list with one colorized image (from ``colormap_depth``) per entry of
        ``output['predictions']['depth']``.
    """
    batch = data_to_batch(val_dataset[int(idx)])
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        output = arch(batch, epoch=0)
    depth_images = []
    for output_depth in output['predictions']['depth'].values():
        # Drop the leading singleton dims. `.cpu()` is a no-op for CPU tensors
        # and lets `.numpy()` work if the prediction lives on a GPU.
        depth = output_depth[0].squeeze(0).squeeze(0).squeeze(0)
        depth_images.append(colormap_depth(depth.detach().cpu().numpy()))
    return depth_images
    
with gr.Blocks() as demo:

    # layout
    # NOTE(review): `.style(...)` was removed in Gradio 4; this file targets
    # Gradio 3.x.
    img_box = gr.Gallery(label="Sampled Images").style(grid=[2], height="auto")
    data_idx_box = gr.Textbox(
        label="Sampled Data Index",
        placeholder=f"Number between 0 and {len_val_dataset - 1}",
        interactive=True,
    )
    sample_btn = gr.Button('Sample Dataset')

    depth_box = gr.Gallery(label="Inferred Depth").style(grid=[2], height="auto")
    infer_btn = gr.Button('Depth Infer')

    # actions
    # Sampling writes a random index into the textbox. Only if that step
    # succeeds are the sample's RGB frames shown in the first gallery.
    sample_btn.click(
        fn=sample_data_idx,
        inputs=None,
        outputs=data_idx_box,
    ).success(
        fn=display_images_from_idx,
        inputs=data_idx_box,
        outputs=img_box,
    )

    # Inference uses whatever index is currently in the textbox.
    infer_btn.click(
        fn=infer_depth_from_idx,
        inputs=data_idx_box,
        outputs=depth_box,
    )

# server_name="0.0.0.0" makes the app reachable from other hosts, not just
# localhost.
demo.launch(server_name="0.0.0.0", server_port=7860)