File size: 2,841 Bytes
8406293
851751e
 
 
 
 
8406293
1615664
 
8406293
1615664
 
 
 
 
 
 
8406293
 
851751e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e67befd
062d491
851751e
 
1615664
851751e
 
 
 
 
 
 
 
 
 
 
 
1615664
 
851751e
706f88d
851751e
 
 
d10ef69
 
 
 
2014fad
d10ef69
851751e
 
 
 
 
2014fad
 
851751e
 
 
2014fad
851751e
2014fad
851751e
 
 
2014fad
d10ef69
851751e
706f88d
2aa6e78
851751e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import spaces
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

from app_config import CSS, HEADER, FOOTER
from sample_cond import CKPT_PATH, MODEL_CFG, load_model_from_config, sample

# Run on GPU when available; otherwise fall back to CPU (sampling will be slow).
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_model():
    """Instantiate the model from the configured checkpoint.

    Weights are loaded onto CPU first; the caller is responsible for moving
    the model to the target device.
    """
    checkpoint = torch.load(CKPT_PATH, map_location="cpu")
    return load_model_from_config(MODEL_CFG.model, checkpoint["state_dict"])


def create_custom_colormap():
    """Build the green->cyan->blue->magenta->yellow colormap used for range maps."""
    # Each stop is (position, rgb); positions span [0, 1].
    stops = [
        (0, (0, 1, 0)),
        (0.38, (0, 1, 1)),
        (0.6, (0, 0, 1)),
        (0.7, (1, 0, 1)),
        (1, (1, 1, 0)),
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', stops, N=256)


def colorize_depth(depth, log_scale):
    """Render a depth/range image as RGB; zero-valued (invalid) pixels stay black.

    When `log_scale` is true, the 0-255 depth values are first remapped
    logarithmically (56 appears to be the assumed max range in meters;
    5.84 ~= log2(57) — TODO confirm against the dataset).
    """
    if log_scale:
        depth = ((np.log2((depth / 255.) * 56. + 1) / 5.84) * 255.).astype(np.uint8)
    invalid = depth == 0
    rgb = create_custom_colormap()(depth)[..., :3]
    rgb[invalid] = 0.
    return rgb


@spaces.GPU
@torch.no_grad()
def generate_lidar(model, cond):
    """Sample the model under no-grad on the GPU worker; returns (img, pcd)."""
    return sample(model, cond)


def load_camera(image):
    """Convert an input camera image into the model's conditioning tensor.

    The image is normalized to [0, 1], transposed to CHW layout, and split
    horizontally into 4 equal-width chunks treated as separate views,
    yielding a tensor of shape (1, 4, C, H, W/4) on DEVICE.

    Args:
        image: H x W x C array-like image (e.g. uint8 from gr.Image); W must
            be divisible by the number of views.

    Returns:
        torch.Tensor: conditioning batch on DEVICE.

    Raises:
        ValueError: if the image width is not divisible by the view count
            (previously surfaced as an opaque np.split error).
    """
    split_per_view = 4
    camera = np.array(image).astype(np.float32) / 255.
    camera = camera.transpose(2, 0, 1)  # HWC -> CHW
    width = camera.shape[2]
    if width % split_per_view != 0:
        raise ValueError(
            f"Image width {width} must be divisible by {split_per_view} views."
        )
    camera_list = np.split(camera, split_per_view, axis=2)  # split into n chunks as different views
    camera_cond = torch.from_numpy(np.stack(camera_list, axis=0)).unsqueeze(0).to(DEVICE)
    return camera_cond


# Instantiate the model once at import time and move it to the target device.
model = load_model().to(DEVICE)

# --- Gradio UI: image in, colorized range map + 3D point cloud out. ---
with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        # Left: camera image input. Right: generated outputs stacked vertically.
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output', interactive=False)

    # raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab")
    submit = gr.Button("Generate")

    def on_submit(image):
        # Build the conditioning tensor, sample the model, then colorize the
        # returned range image (log scale) for display; pcd goes to Model3D.
        cond = load_camera(image)
        img, pcd = generate_lidar(model, cond)

        # tmp = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
        # pcd.save(tmp.name)

        rgb_img = colorize_depth(img, log_scale=True)

        return [rgb_img, pcd]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd])

    # Populate the examples gallery from the bundled `cam_examples` directory
    # (assumes the directory exists next to this script — verify at deploy time).
    example_files = sorted(os.listdir('cam_examples'))
    example_files = [os.path.join('cam_examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, output_pcd],
                           fn=on_submit, cache_examples=False)

    gr.Markdown(FOOTER)


if __name__ == '__main__':
    # queue() serializes long-running generation requests before launching the app.
    demo.queue().launch()