Spaces:
Running
Running
File size: 2,841 Bytes
8406293 851751e 8406293 1615664 8406293 1615664 8406293 851751e e67befd 062d491 851751e 1615664 851751e 1615664 851751e 706f88d 851751e d10ef69 2014fad d10ef69 851751e 2014fad 851751e 2014fad 851751e 2014fad 851751e 2014fad d10ef69 851751e 706f88d 2aa6e78 851751e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import gradio as gr
import spaces
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from app_config import CSS, HEADER, FOOTER
from sample_cond import CKPT_PATH, MODEL_CFG, load_model_from_config, sample
# Pick the compute device once at import time: prefer GPU when available.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model():
    """Instantiate the diffusion model and restore its checkpoint weights.

    The checkpoint is always loaded onto CPU first; the caller is
    responsible for moving the returned model to the target device.
    """
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")
    return load_model_from_config(MODEL_CFG.model, checkpoint["state_dict"])
def create_custom_colormap():
    """Build the 256-entry colormap used to render range/depth images.

    Colour stops run green -> cyan -> blue -> magenta -> yellow at uneven
    positions, giving more colour resolution to the lower value range.
    """
    stops = [
        (0.00, (0, 1, 0)),  # green
        (0.38, (0, 1, 1)),  # cyan
        (0.60, (0, 0, 1)),  # blue
        (0.70, (1, 0, 1)),  # magenta
        (1.00, (1, 1, 0)),  # yellow
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', stops, N=256)
def colorize_depth(depth, log_scale):
    """Map a single-channel depth/range image to an RGB visualisation.

    Pixels that are exactly 0 (no return) are painted black. When
    ``log_scale`` is true the values are first remapped with a log2 curve so
    that nearby structure receives more colour resolution.
    NOTE(review): assumes ``depth`` holds values in [0, 255] — confirm at caller.
    """
    if log_scale:
        depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
    invalid = depth == 0
    cmap = create_custom_colormap()
    colored = cmap(depth)[:, :, :3]  # drop the alpha channel
    colored[invalid] = 0.
    return colored
@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Sample one LiDAR scan (no autograd, GPU-scheduled on Spaces).

    Returns the ``(range_image, point_cloud)`` pair produced by ``sample``.
    """
    return sample(model, cond)
def load_camera(image):
    """Turn an uploaded camera image into the conditioning tensor for sampling.

    The image is normalised to [0, 1], rearranged to CHW, then split along the
    width into 4 equal chunks treated as separate views, yielding a batched
    tensor of shape (1, 4, C, H, W/4) on DEVICE.
    NOTE(review): the image width must be divisible by 4 or np.split raises.
    """
    n_views = 4
    chw = np.asarray(image, dtype=np.float32).transpose(2, 0, 1) / 255.
    views = np.split(chw, n_views, axis=2)
    stacked = np.stack(views, axis=0)
    return torch.from_numpy(stacked).unsqueeze(0).to(DEVICE)
# Build the model once at start-up and keep it resident on the device.
model = load_model().to(DEVICE)

with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        # Input camera image on the left; generated range map + point cloud
        # stacked in a column on the right.
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output', interactive=False)
    # raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab")
    submit = gr.Button("Generate")

    def on_submit(image):
        # Convert the upload into conditioning, sample a LiDAR scan, and
        # colorize the range map for display (closure over the global model).
        cond = load_camera(image)
        img, pcd = generate_lidar(model, cond)
        # tmp = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
        # pcd.save(tmp.name)
        rgb_img = colorize_depth(img, log_scale=True)
        return [rgb_img, pcd]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd])

    # Offer every file in cam_examples/ as a clickable example input.
    example_files = sorted(os.listdir('cam_examples'))
    example_files = [os.path.join('cam_examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, output_pcd],
                           fn=on_submit, cache_examples=False)
    gr.Markdown(FOOTER)

if __name__ == '__main__':
    demo.queue().launch()
|