# app.py (define-hf-demo): Gradio app for interactive depth inference with vidar.
import os
import random

import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch

from vidar.core.wrapper import Wrapper
from vidar.utils.config import read_config

def colormap_depth(depth_map):
    # Input: depth_map -> HxW numpy array of depth values (0 marks invalid pixels)
    # Output: colormapped_im -> HxWx3 uint8 array with inverse depth rendered in
    #         the 'magma' colormap; invalid pixels are rendered white
    mask = depth_map != 0
    # Invert depth only where it is valid, to avoid division by zero.
    disp_map = np.zeros_like(depth_map)
    disp_map[mask] = 1.0 / depth_map[mask]
    # Robust normalization: clip to the 5th-95th percentile of valid disparities.
    vmax = np.percentile(disp_map[mask], 95)
    vmin = np.percentile(disp_map[mask], 5)
    normalizer = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    mask = np.repeat(np.expand_dims(mask, -1), 3, -1)
    colormapped_im = (mapper.to_rgba(disp_map)[:, :, :3] * 255).astype(np.uint8)
    colormapped_im[~mask] = 255
    return colormapped_im
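
# A minimal self-check for colormap_depth, provided as a sketch (not part of
# the original demo flow and never called below; run it manually when
# debugging). The synthetic ramp and the 64x64 size are illustrative assumptions.
def _sanity_check_colormap_depth():
    demo_depth = np.linspace(0.5, 5.0, 64 * 64).reshape(64, 64)
    demo_depth[0, 0] = 0.0  # one invalid pixel (depth 0)
    out = colormap_depth(demo_depth)
    assert out.shape == (64, 64, 3) and out.dtype == np.uint8
    assert (out[0, 0] == 255).all()  # invalid pixels render white
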
def data_to_batch(data):
    # Add the two leading singleton dimensions (batch, then camera/context)
    # expected by the model: each per-context tensor, keyed by temporal offset,
    # goes from [...] to [1, 1, ...].
    batch = data.copy()
    for key, contexts in (('rgb', (0, 1)), ('intrinsics', (0,)),
                          ('pose', (0, 1)), ('depth', (0, 1))):
        for ctx in contexts:
            batch[key][ctx] = batch[key][ctx].unsqueeze(0).unsqueeze(0)
    return batch
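
# A hedged shape check for data_to_batch (a sketch, not called by the demo; the
# 8x8 resolution and the exact sample layout are assumptions about the vidar
# ScanNet loader, which stores per-context tensors keyed by temporal offset).
def _sanity_check_data_to_batch():
    sample = {
        'rgb': {0: torch.zeros(3, 8, 8), 1: torch.zeros(3, 8, 8)},
        'intrinsics': {0: torch.eye(3)},
        'pose': {0: torch.eye(4), 1: torch.eye(4)},
        'depth': {0: torch.zeros(1, 8, 8), 1: torch.zeros(1, 8, 8)},
    }
    batch = data_to_batch(sample)
    assert batch['rgb'][0].shape == (1, 1, 3, 8, 8)  # [batch, cam, C, H, W]
    assert batch['intrinsics'][0].shape == (1, 1, 3, 3)
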
# Select CPU or GPU mode for the vidar framework before building the wrapper.
os.environ['DIST_MODE'] = 'gpu' if torch.cuda.is_available() else 'cpu'

# Build the model and datasets from the DeFiNe ScanNet evaluation config.
cfg_file_path = 'configs/papers/define/scannet_temporal_test_context_1.yaml'
cfg = read_config(cfg_file_path)
wrapper = Wrapper(cfg, verbose=True)
arch = wrapper.arch
arch.eval()  # inference only

val_dataset = wrapper.datasets['validation'][0]
len_val_dataset = len(val_dataset)
def sample_data_idx():
    # Pick a random validation sample index (inclusive on both ends).
    return random.randint(0, len_val_dataset - 1)
def display_images_from_idx(idx):
    # Return the target and context RGB frames as HxWx3 arrays for the gallery.
    rgbs = val_dataset[int(idx)]['rgb']
    return [rgb.permute(1, 2, 0).numpy() for rgb in rgbs.values()]
@torch.no_grad()
def infer_depth_from_idx(idx):
    # Run the network on one validation sample and color-code each predicted
    # depth map (one per context view); the squeezes drop the singleton batch,
    # camera, and channel dims to get HxW.
    data_sample = val_dataset[int(idx)]
    batch = data_to_batch(data_sample)
    output = arch(batch, epoch=0)
    output_depths = output['predictions']['depth']
    return [colormap_depth(d[0].squeeze(0).squeeze(0).squeeze(0).cpu().numpy())
            for d in output_depths.values()]
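
# Note on shapes (an inference from the squeezes above, not documented vidar
# API): each value of output['predictions']['depth'] appears to be a list of
# per-scale tensors of shape [1, 1, 1, H, W]; scale 0 is collapsed to HxW and
# color-mapped, so the gallery receives HxWx3 uint8 images.
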
with gr.Blocks() as demo:
    # Layout: two galleries (input frames, predicted depth) plus controls.
    img_box = gr.Gallery(label="Sampled Images").style(grid=[2], height="auto")
    data_idx_box = gr.Textbox(
        label="Sampled Data Index",
        placeholder="Number between {} and {}".format(0, len_val_dataset - 1),
        interactive=True,
    )
    sample_btn = gr.Button('Sample Dataset')
    depth_box = gr.Gallery(label="Inferred Depth").style(grid=[2], height="auto")
    infer_btn = gr.Button('Infer Depth')

    # Actions: sampling picks a random index and then shows its RGB frames;
    # inference runs the model on whichever index the textbox currently holds.
    sample_btn.click(
        fn=sample_data_idx,
        inputs=None,
        outputs=data_idx_box,
    ).success(
        fn=display_images_from_idx,
        inputs=data_idx_box,
        outputs=img_box,
    )
    infer_btn.click(
        fn=infer_depth_from_idx,
        inputs=data_idx_box,
        outputs=depth_box,
    )

demo.launch(server_name="0.0.0.0", server_port=7860)
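
# Optional deployment note (an assumption, not part of the original demo): if
# inference is slow and several users may connect, Gradio's request queue can
# serialize jobs, e.g.:
#   demo.queue().launch(server_name="0.0.0.0", server_port=7860)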