import os

import gradio as gr
# NOTE: `spaces` is kept imported before torch, matching the original order
# (ZeroGPU Spaces expect it to be imported before CUDA is initialized).
import spaces
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

from app_config import CSS, HEADER, FOOTER
from sample_cond import CKPT_PATH, MODEL_CFG, load_model_from_config, sample

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# The input image is split along its width into this many camera views.
NUM_VIEWS = 4
EXAMPLES_DIR = 'cam_examples'


def load_model():
    """Build the model from MODEL_CFG and the checkpoint at CKPT_PATH.

    The checkpoint is loaded onto the CPU; the caller is responsible for
    moving the returned model to DEVICE.
    """
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
    # checkpoint stores non-tensor objects, loading may need weights_only=False
    # (only do that for trusted checkpoints) — confirm against the torch version used.
    pl_sd = torch.load(CKPT_PATH, map_location="cpu")
    model = load_model_from_config(MODEL_CFG.model, pl_sd["state_dict"])
    return model


def create_custom_colormap():
    """Return a 256-entry colormap running green -> cyan -> blue -> magenta -> yellow."""
    colors = [(0, 1, 0), (0, 1, 1), (0, 0, 1), (1, 0, 1), (1, 1, 0)]
    positions = [0, 0.38, 0.6, 0.7, 1]
    custom_cmap = LinearSegmentedColormap.from_list(
        'custom_colormap', list(zip(positions, colors)), N=256)
    return custom_cmap


def colorize_depth(depth, log_scale):
    """Map a 2-D depth/range image to RGB using the custom colormap.

    Args:
        depth: 2-D array of depth values. Presumably uint8 in [0, 255] as
            produced by ``sample`` — TODO confirm. Integer values index the
            256-entry colormap directly.
        log_scale: if True, compress the values logarithmically (and cast to
            uint8) before coloring.

    Returns:
        Float array of shape (H, W, 3) with RGB values in [0, 1]; pixels whose
        (possibly rescaled) value is 0 are painted black.
    """
    if log_scale:
        # log2(56 + 1) ~= 5.83, so dividing by 5.84 keeps the result below 255.
        depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
    mask = depth == 0
    colormap = create_custom_colormap()
    rgb = colormap(depth)[:, :, :3]  # drop the alpha channel
    rgb[mask] = 0.
    return rgb


@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Run ``sample`` on the camera condition without autograd.

    Returns:
        ``(img, pcd)`` exactly as returned by ``sample``. ``img`` is colorized
        as a range map and ``pcd`` is handed to ``gr.Model3D`` (presumably a
        path to a 3D file — confirm in ``sample_cond``).
    """
    img, pcd = sample(model, cond)
    return img, pcd


def load_camera(image, split_per_view=NUM_VIEWS):
    """Convert an (H, W, C) image into a batched multi-view camera tensor.

    The image is scaled to [0, 1], moved to channel-first layout and split
    along its width into ``split_per_view`` equal chunks, one per view.

    Returns:
        Float32 tensor of shape (1, split_per_view, C, H, W // split_per_view)
        on DEVICE.

    Raises:
        ValueError: if the image width is not divisible by ``split_per_view``.
    """
    camera = np.array(image).astype(np.float32) / 255.
    camera = camera.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
    # split into n chunks as different views
    camera_list = np.split(camera, split_per_view, axis=2)
    camera_cond = torch.from_numpy(np.stack(camera_list, axis=0)).unsqueeze(0).to(DEVICE)
    return camera_cond


model = load_model().to(DEVICE)

with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output',
                                    interactive=False)
            # TODO: optionally also expose the raw point cloud as a downloadable
            # .txt file (viewable in MeshLab) via gr.File.

    submit = gr.Button("Generate")

    def on_submit(image):
        """Generate a colorized range map and a point cloud from a camera image.

        Raises:
            gr.Error: if no image was provided or its width cannot be split
                evenly into NUM_VIEWS views (shown to the user in the UI).
        """
        if image is None:
            raise gr.Error("Please upload an input image first.")
        if image.shape[1] % NUM_VIEWS != 0:
            raise gr.Error(
                f"Input image width ({image.shape[1]} px) must be divisible by "
                f"{NUM_VIEWS} (one horizontal chunk per camera view).")
        cond = load_camera(image)
        img, pcd = generate_lidar(model, cond)
        rgb_img = colorize_depth(img, log_scale=True)
        return [rgb_img, pcd]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd])

    example_files = sorted(os.listdir(EXAMPLES_DIR))
    example_files = [os.path.join(EXAMPLES_DIR, filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image],
                           outputs=[output_image, output_pcd], fn=on_submit,
                           cache_examples=False)

    gr.Markdown(FOOTER)


if __name__ == '__main__':
    demo.queue().launch()