File size: 2,611 Bytes
8406293
851751e
 
 
 
 
 
8406293
706f88d
851751e
8406293
851751e
8406293
 
851751e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706f88d
851751e
 
 
d10ef69
 
 
 
851751e
d10ef69
851751e
 
 
 
 
 
 
 
 
 
d10ef69
851751e
d10ef69
851751e
 
 
d10ef69
 
851751e
706f88d
2aa6e78
851751e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
import spaces
import tempfile
import os
import torch
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

from app_config import CSS, HEADER, FOOTER, DEVICE
import sample_cond

# Load the model once at import time; on_submit below reuses this single
# module-level instance for every request instead of reloading it. How and
# where the weights are loaded is decided by sample_cond.load_model (not
# visible here).
model = sample_cond.load_model()


def create_custom_colormap():
    """Build the piecewise-linear colormap used to render range images.

    Colors run green -> cyan -> blue -> magenta -> yellow, with stops at
    0, 0.38, 0.6, 0.7 and 1 of the normalized input range. The map is
    sampled into a 256-entry lookup table.

    Returns:
        A ``LinearSegmentedColormap`` named ``'custom_colormap'``.
    """
    anchors = [
        (0, (0, 1, 0)),     # green
        (0.38, (0, 1, 1)),  # cyan
        (0.6, (0, 0, 1)),   # blue
        (0.7, (1, 0, 1)),   # magenta
        (1, (1, 1, 0)),     # yellow
    ]
    return LinearSegmentedColormap.from_list('custom_colormap', anchors, N=256)


def colorize_depth(depth, log_scale):
    """Map a single-channel range image to RGB using the custom colormap.

    Args:
        depth: 2-D array of depth values. It is presumably scaled to
            [0, 255], since the log branch divides by 255; TODO confirm
            against ``sample_cond.sample``.
        log_scale: when true, compress the values logarithmically and cast
            the result to ``uint8`` before colorizing.

    Returns:
        Array of shape ``depth.shape + (3,)`` with the colormap's RGB values.
        Pixels whose (possibly log-scaled) value is 0 are set to black.
    """
    if log_scale:
        # For depth in [0, 255], log2(1 + 56 * d / 255) spans [0, log2(57) ~= 5.83].
        # Dividing by 5.84 renormalizes that to just under 1 before scaling back to 0..255.
        normalized = depth / 255.
        compressed = np.log2(normalized * 56. + 1) / 5.84
        depth = (compressed * 255.).astype(np.uint8)
    cmap = create_custom_colormap()
    is_empty = depth == 0  # presumably "no return" pixels; painted black below
    colored = cmap(depth)[:, :, :3]  # drop the alpha channel
    colored[is_empty] = 0.
    return colored


@torch.no_grad()
def generate_lidar(model, cond):
    """Run the conditional sampler with autograd disabled.

    Args:
        model: the model to sample from. In this file it is the object
            returned by ``sample_cond.load_model``.
        cond: the conditioning tensor, e.g. as built by ``load_camera``.

    Returns:
        ``(img, pcd)``: the two values produced by ``sample_cond.sample``.
        The caller treats them as a range image and a point cloud.
    """
    result = sample_cond.sample(model, cond)
    range_img, point_cloud = result  # unpack to require exactly two outputs
    return range_img, point_cloud


def load_camera(image, split_per_view=4, device=None):
    """Turn a side-by-side multi-view camera strip into a model condition tensor.

    The image's width is cut into ``split_per_view`` equal-width chunks, one
    per camera view.

    Args:
        image: ``(H, W, C)`` image array-like (e.g. from
            ``gr.Image(type='numpy')``). It is presumably 0..255 valued,
            since it is divided by 255.
        split_per_view: number of equal-width views to cut the width into.
        device: torch device for the result. ``None`` uses ``DEVICE`` from
            ``app_config``.

    Returns:
        float32 tensor of shape ``(1, split_per_view, C, H, W // split_per_view)``
        holding the pixel values divided by 255.

    Raises:
        ValueError: if ``image`` is None, is not 3-D, or its width cannot be
            split into ``split_per_view`` equal views.
    """
    if image is None:
        raise ValueError("No input image provided.")
    camera = np.asarray(image).astype(np.float32) / 255.
    if camera.ndim != 3:
        raise ValueError(f"Expected an (H, W, C) image, got shape {camera.shape}.")
    width = camera.shape[1]
    # Check up front so the caller gets a clear message instead of np.split's.
    if split_per_view < 1 or width % split_per_view:
        raise ValueError(
            f"Image width {width} cannot be split into {split_per_view} equal views."
        )
    camera = camera.transpose(2, 0, 1)  # HWC -> CHW
    views = np.split(camera, split_per_view, axis=2)  # one chunk per view along width
    stacked = torch.from_numpy(np.stack(views, axis=0)).unsqueeze(0)  # add batch dim
    return stacked.to(DEVICE if device is None else device)


with gr.Blocks(css=CSS) as demo:
    gr.Markdown(HEADER)

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        with gr.Column():
            output_image = gr.Image(label="Output Range Map", elem_id='img-display-output')
            output_pcd = gr.Model3D(label="Output Point Cloud", elem_id='pcd-display-output', interactive=False)

    raw_file = gr.File(label="Point Cloud (.txt file). Can be viewed through Meshlab")
    submit = gr.Button("Generate")

    def on_submit(image):
        """Generate a LiDAR scan from a multi-view camera image.

        Args:
            image: numpy image from ``input_image``, or ``None`` when nothing
                has been uploaded.

        Returns:
            ``[rgb_img, pcd, path]``, which fill ``output_image``,
            ``output_pcd`` and ``raw_file`` respectively.

        Raises:
            gr.Error: if no image was provided, so the UI shows a readable
                message instead of a stack trace.
        """
        if image is None:
            raise gr.Error("Please upload a camera image before clicking Generate.")
        cond = load_camera(image)
        img, pcd = generate_lidar(model, cond)

        # delete=False keeps the file on disk after this function returns,
        # because its path is handed to raw_file. The handle is closed right
        # away so that:
        #   - no file descriptor leaks per request;
        #   - pcd.save can reopen the path (on Windows an open
        #     NamedTemporaryFile cannot be opened a second time).
        with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as tmp:
            pcd_path = tmp.name
        pcd.save(pcd_path)

        rgb_img = colorize_depth(img, log_scale=True)

        # NOTE(review): gr.Model3D normally takes a file path. Confirm that the
        # object returned by sample_cond.sample is accepted here.
        return [rgb_img, pcd, pcd_path]

    submit.click(on_submit, inputs=[input_image], outputs=[output_image, output_pcd, raw_file])

    # Every file in cam_examples/ (resolved relative to the working directory)
    # becomes a clickable example, in sorted filename order.
    example_files = sorted(os.listdir('cam_examples'))
    example_files = [os.path.join('cam_examples', filename) for filename in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[output_image, output_pcd, raw_file],
                           fn=on_submit, cache_examples=False)

    gr.Markdown(FOOTER)


if __name__ == '__main__':
    # Enable Gradio's request queue on the app, then start serving it.
    queued_demo = demo.queue()
    queued_demo.launch()