File size: 4,483 Bytes
06bba7e
 
 
 
 
53625b9
 
 
 
 
 
 
 
 
 
 
 
 
06bba7e
53625b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
06bba7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
from huggingface_hub import HfFolder
# Store the Hugging Face token taken from the Streamlit secret 'etoken'.
# NOTE(review): this runs before the remaining imports, presumably so that
# `openshape` can fetch model weights with it -- confirm before reordering.
HfFolder().save_token(st.secrets['etoken'])


import numpy
import trimesh
import objaverse
import openshape
import misc_utils
import plotly.graph_objects as go


@st.cache_resource
def load_openshape(name):
    """Load the OpenShape point-cloud encoder identified by ``name``.

    The function is wrapped in ``st.cache_resource``, so the loaded encoder is
    kept by Streamlit and reused for later calls with the same ``name``.
    """
    encoder = openshape.load_pc_encoder(name)
    return encoder


# Output dtype used for point clouds handed to the encoders.
f32 = numpy.float32
# Go through the `st.cache_resource`-wrapped loader rather than calling
# `openshape.load_pc_encoder` directly: Streamlit re-executes this script on
# every interaction, and the direct call reloaded all three encoders each time.
model_b32 = load_openshape('openshape-pointbert-vitb32-rgb')
model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')


# Page header and the input widgets. `load_data` reads `objaid`, `model`,
# `npy` and `swap_yz_axes`; `prog` is the shared progress bar updated by every
# task. Statement order here is the order the widgets appear on the page.
st.title("OpenShape Demo")
objaid = st.text_input("Enter an Objaverse ID")
model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")
prog = st.progress(0.0, "Idle")


def load_data():
    """Build the normalized point cloud from whichever input the user supplied.

    Inputs are tried in this order: the uploaded ``.npy`` array (``npy``), the
    uploaded mesh file (``model``), then the Objaverse ID (``objaid``). These,
    together with ``swap_yz_axes`` and the progress bar ``prog``, are read from
    module scope.

    Returns:
        A ``float32`` array with 6 columns: XYZ centred on the mean and scaled
        so the farthest point has unit norm, followed by RGB (all ones when
        the input only had 3 columns).

    Raises:
        ValueError: if no input was supplied, or the point cloud is not a
            non-empty 2-D array with 3 or 6 columns.
    """
    prog.progress(0.05, "Preparing Point Cloud")
    if npy is not None:
        pc: numpy.ndarray = numpy.load(npy)
    elif model is not None:
        # The upload is a file-like object, so trimesh cannot infer the format
        # from a path; pass the extension of the uploaded file name instead.
        file_type = model.name.rsplit(".", 1)[-1].lower()
        mesh = trimesh.load(model, file_type=file_type)
        pc = misc_utils.model_to_pc(misc_utils.as_mesh(mesh))
    elif objaid:
        prog.progress(0.1, "Downloading Objaverse Object")
        # Do not bind this to `model`: assigning to that name anywhere in the
        # function makes it local, and the `model is not None` test above then
        # raises UnboundLocalError instead of reading the uploader widget.
        objaverse_path = objaverse.load_objects([objaid])[objaid]
        prog.progress(0.2, "Preparing Point Cloud")
        pc = misc_utils.model_to_pc(misc_utils.as_mesh(trimesh.load(objaverse_path)))
    else:
        raise ValueError("You have to supply 3D input!")
    prog.progress(0.25, "Preprocessing Point Cloud")
    # Explicit raises rather than `assert`, which is stripped under `python -O`.
    if pc.ndim != 2:
        raise ValueError("invalid pc shape: ndim = %d != 2" % pc.ndim)
    if pc.shape[1] not in (3, 6):
        raise ValueError("invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1])
    if pc.shape[0] == 0:
        raise ValueError("invalid pc shape: point cloud is empty")
    if not numpy.issubdtype(pc.dtype, numpy.floating):
        # The in-place centring/scaling below would truncate on integer arrays.
        pc = pc.astype(numpy.float64)
    if swap_yz_axes:
        pc[:, [1, 2]] = pc[:, [2, 1]]
    pc[:, :3] = pc[:, :3] - numpy.mean(pc[:, :3], axis=0)
    scale = numpy.linalg.norm(pc[:, :3], axis=-1).max()
    if scale > 0:
        # All points coincide when scale == 0; leave them at the origin
        # instead of dividing by zero and producing NaNs.
        pc[:, :3] = pc[:, :3] / scale
    if pc.shape[1] == 3:
        # No colour supplied: append an all-ones RGB block.
        pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
    prog.progress(0.3, "Preprocessed Point Cloud")
    return pc.astype(f32)


def render_pc(ncols, col, pc):
    """Create ``ncols`` Streamlit columns and preview ``pc`` in column ``col``.

    Columns 0-2 of ``pc`` are plotted as x/y/z of a 3D scatter. The columns
    from index 3 onward are multiplied by 255, cast to ``uint8`` and used as
    the per-point ``rgb(r, g, b)`` marker colours.

    Returns:
        The columns created by ``st.columns(ncols)``, so the caller can fill
        the remaining ones.
    """
    columns = st.columns(ncols)
    with columns[col]:
        colors_u8 = (pc[:, 3:] * 255).astype(numpy.uint8)
        marker_style = dict(
            size=2,
            color=[f'rgb({row[0]}, {row[1]}, {row[2]})' for row in colors_u8],
        )
        scatter = go.Scatter3d(
            x=pc[:, 0],
            y=pc[:, 1],
            z=pc[:, 2],
            mode='markers',
            marker=marker_style,
        )
        st.plotly_chart(go.Figure(data=[scatter]))
        st.caption("Point Cloud Preview")
    return columns


# Main UI: one tab per task. Each task loads the point cloud on button press,
# previews it in the left column and shows its result in the right column.
# Any Exception raised along the way is displayed with st.error rather than
# propagating out of the script.
try:
    tab_cls, tab_pc2img, tab_cap = st.tabs(["Classification", "Point Cloud to Image Generation", "Point Cloud Captioning"])

    with tab_cls:
        if st.button("Run Classification on LVIS Categories"):
            pc = load_data()
            col1, col2 = render_pc(2, 0, pc)
            prog.progress(0.5, "Running Classification")
            with col2:
                pred = openshape.pred_lvis_sims(model_g14, pc)
                # zip with range(5) limits the output to the first five
                # (category, similarity) entries of `pred`.
                # NOTE(review): presumably `pred` is ordered most-similar
                # first -- confirm in openshape.pred_lvis_sims.
                for i, (cat, sim) in zip(range(5), pred.items()):
                    st.text(cat)
                    st.caption("Similarity %.4f" % sim)
            prog.progress(1.0, "Idle")

    with tab_pc2img:
        # Generation parameters forwarded to openshape.pc_to_image below.
        prompt = st.text_input("Prompt")
        noise_scale = st.slider('Variation Level', 0, 5)
        cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
        steps = st.slider('Diffusion Steps', 2, 80, 10)
        width = st.slider('Width', 128, 512, step=32)
        height = st.slider('Height', 128, 512, step=32)
        if st.button("Generate"):
            pc = load_data()
            col1, col2 = render_pc(2, 0, pc)
            prog.progress(0.49, "Running Generation")
            # The callback advances the bar from 0.49 by i / (steps + 1) / 2,
            # i.e. by less than 0.5 in total while i <= steps.
            # NOTE(review): assumes the first callback argument is the step
            # index -- confirm against openshape.pc_to_image.
            img = openshape.pc_to_image(
                model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
                lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
            )
            with col2:
                st.image(img)
            prog.progress(1.0, "Idle")

    with tab_cap:
        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 1.0)
        if st.button("Generate a Caption"):
            pc = load_data()
            col1, col2 = render_pc(2, 0, pc)
            prog.progress(0.5, "Running Generation")
            cap = openshape.pc_caption(model_b32, pc, cond_scale)
            with col2:
                st.text(cap)
            prog.progress(1.0, "Idle")
except Exception as exc:
    # Show the failure (including the ValueError from load_data when no input
    # was given) in the page; repr() keeps the exception type visible.
    st.error(repr(exc))