Spaces:
Sleeping
Sleeping
import streamlit as st | |
from huggingface_hub import HfFolder | |
HfFolder().save_token(st.secrets['etoken']) | |
import numpy | |
import trimesh | |
import objaverse | |
import openshape | |
import misc_utils | |
import plotly.graph_objects as go | |
def load_openshape(name):
    """Load the OpenShape point-cloud encoder identified by ``name``.

    Thin wrapper: ``name`` is forwarded unchanged to
    ``openshape.load_pc_encoder`` and that call's result is returned.
    """
    encoder = openshape.load_pc_encoder(name)
    return encoder
# Shorthand for the dtype that load_data() casts its result to.
f32 = numpy.float32

# One point-cloud encoder per tab below: captioning uses the ViT-B/32
# variant, image generation the ViT-L/14 one and classification the ViT-G/14 one.
model_b32 = load_openshape('openshape-pointbert-vitb32-rgb')
model_l14 = load_openshape('openshape-pointbert-vitl14-rgb')
model_g14 = load_openshape('openshape-pointbert-vitg14-rgb')

# Page header and the three alternative 3D inputs read by load_data().
st.title("OpenShape Demo")
objaid = st.text_input("Enter an Objaverse ID")
model = st.file_uploader("Or upload a model (.glb/.obj/.ply)")
npy = st.file_uploader("Or upload a point cloud numpy array (.npy of Nx3 XYZ or Nx6 XYZRGB)")
swap_yz_axes = st.checkbox("Swap Y/Z axes of input (Y is up for OpenShape)")

# Shared progress bar, updated by load_data() and by each tab's action.
prog = st.progress(0.0, "Idle")
def load_data():
    """Build the normalized XYZRGB point cloud for the current user input.

    Input sources are tried in this order: the uploaded ``.npy`` array, the
    uploaded mesh file, then the Objaverse ID typed into the text box. The
    shared progress bar is updated along the way.

    Returns:
        A float32 array with 6 columns (XYZ followed by RGB). XYZ is centered
        on its mean and scaled so that the farthest point lies at distance 1
        from the origin; when the input has only 3 columns, the RGB columns
        are filled with ones.

    Raises:
        ValueError: if no 3D input was supplied, or if the point cloud is
            not 2-D, does not have 3 or 6 channels, or contains no points.
    """
    prog.progress(0.05, "Preparing Point Cloud")
    if npy is not None:
        pc = numpy.load(npy)
    elif model is not None:
        # trimesh cannot infer the format of a file-like object, so pass the
        # extension of the uploaded file name explicitly.
        file_type = model.name.rsplit(".", 1)[-1].lower()
        mesh = trimesh.load(model, file_type=file_type)
        pc = misc_utils.model_to_pc(misc_utils.as_mesh(mesh))
    elif objaid:
        prog.progress(0.1, "Downloading Objaverse Object")
        # Use a distinct local name: assigning to ``model`` here would make it
        # local to the whole function and turn the ``model is not None`` check
        # above into an UnboundLocalError.
        model_path = objaverse.load_objects([objaid])[objaid]
        prog.progress(0.2, "Preparing Point Cloud")
        pc = misc_utils.model_to_pc(misc_utils.as_mesh(trimesh.load(model_path)))
    else:
        raise ValueError("You have to supply 3D input!")

    prog.progress(0.25, "Preprocessing Point Cloud")
    # Explicit raises instead of asserts so validation survives ``python -O``.
    if pc.ndim != 2:
        raise ValueError("invalid pc shape: ndim = %d != 2" % pc.ndim)
    if pc.shape[1] not in (3, 6):
        raise ValueError("invalid pc shape: should have 3/6 channels, got %d" % pc.shape[1])
    if pc.shape[0] == 0:
        raise ValueError("invalid pc shape: point cloud has no points")

    # Integer arrays would silently truncate the centered/scaled coordinates
    # when they are written back in place, so switch to floats first.
    if not numpy.issubdtype(pc.dtype, numpy.floating):
        pc = pc.astype(f32)

    if swap_yz_axes:
        # The fancy-indexed right-hand side is a copy, so the swap is safe.
        pc[:, [1, 2]] = pc[:, [2, 1]]

    xyz = pc[:, :3] - numpy.mean(pc[:, :3], axis=0)
    radius = numpy.linalg.norm(xyz, axis=-1).max()
    if radius > 0:
        # Skip scaling when all points coincide to avoid dividing by zero.
        xyz = xyz / radius
    pc[:, :3] = xyz

    if pc.shape[1] == 3:
        # No colors supplied: append an all-ones RGB block.
        pc = numpy.concatenate([pc, numpy.ones_like(pc)], axis=-1)
    prog.progress(0.3, "Preprocessed Point Cloud")
    return pc.astype(f32)
def render_pc(ncols, col, pc):
    """Create ``ncols`` columns and draw a 3D preview of ``pc`` in column ``col``.

    Columns 0-2 of ``pc`` are plotted as x/y/z. The remaining columns are
    multiplied by 255, cast to uint8 and used as each marker's RGB color.

    Returns:
        The columns created by ``st.columns``, so the caller can fill the
        other ones.
    """
    columns = st.columns(ncols)
    with columns[col]:
        colors_u8 = (pc[:, 3:] * 255).astype(numpy.uint8)
        marker_colors = [
            f'rgb({point[0]}, {point[1]}, {point[2]})' for point in colors_u8
        ]
        scatter = go.Scatter3d(
            x=pc[:, 0],
            y=pc[:, 1],
            z=pc[:, 2],
            mode='markers',
            marker=dict(size=2, color=marker_colors),
        )
        st.plotly_chart(go.Figure(data=[scatter]))
        st.caption("Point Cloud Preview")
    return columns
# Main page body: three tabs, each with its own controls and action button.
# Any failure while handling an action is shown on the page as an error.
try:
    tab_cls, tab_pc2img, tab_cap = st.tabs(["Classification", "Point Cloud to Image Generation", "Point Cloud Captioning"])

    # --- Classification: list the five first entries of the LVIS prediction.
    with tab_cls:
        if st.button("Run Classification on LVIS Categories"):
            pc = load_data()
            _, col2 = render_pc(2, 0, pc)
            prog.progress(0.5, "Running Classification")
            with col2:
                pred = openshape.pred_lvis_sims(model_g14, pc)
                for cat, sim in list(pred.items())[:5]:
                    st.text(cat)
                    st.caption("Similarity %.4f" % sim)
            prog.progress(1.0, "Idle")

    # --- Point cloud to image: diffusion conditioned on the point cloud.
    with tab_pc2img:
        prompt = st.text_input("Prompt")
        noise_scale = st.slider('Variation Level', 0, 5)
        cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
        steps = st.slider('Diffusion Steps', 2, 80, 10)
        width = st.slider('Width', 128, 512, step=32)
        height = st.slider('Height', 128, 512, step=32)
        if st.button("Generate"):
            pc = load_data()
            _, col2 = render_pc(2, 0, pc)
            prog.progress(0.49, "Running Generation")

            def on_diffusion_step(i, t, _):
                # Map step i onto the second half of the progress bar.
                return prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)

            img = openshape.pc_to_image(
                model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
                on_diffusion_step
            )
            with col2:
                st.image(img)
            prog.progress(1.0, "Idle")

    # --- Captioning: generate a text caption for the point cloud.
    with tab_cap:
        cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 1.0)
        if st.button("Generate a Caption"):
            pc = load_data()
            _, col2 = render_pc(2, 0, pc)
            prog.progress(0.5, "Running Generation")
            cap = openshape.pc_caption(model_b32, pc, cond_scale)
            with col2:
                st.text(cap)
            prog.progress(1.0, "Idle")
except Exception as exc:
    st.error(repr(exc))