|
import gradio as gr |
|
from models import build_model |
|
from PIL import Image |
|
import numpy as np |
|
import torchvision |
|
import math |
|
import ninja |
|
import torch |
|
from tqdm import trange |
|
import imageio |
|
import requests |
|
import argparse |
|
import imageio |
|
from scipy.spatial.transform import Rotation |
|
|
|
from gradio_draggable import Draggable |
|
|
|
# Checkpoint file read at import time; it must contain the
# 'model_kwargs_init' and 'models' entries accessed below.
checkpoint = 'clevr.pth'

# NOTE(review): torch.load unpickles the file, so only point `checkpoint`
# at a file from a trusted source.
state = torch.load(checkpoint, map_location='cpu')

G = build_model(**state['model_kwargs_init']['generator_smooth'])

# strict=False tolerates key mismatches between checkpoint and model.
# Report whatever comes back instead of silently discarding it, so a wrong
# checkpoint is visible in the log.  (For a torch.nn.Module the pair is
# (missing_keys, unexpected_keys) -- presumably what build_model returns.)
o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
if o0 or o1:
    print('load_state_dict mismatch:', o0, o1, flush=True)

G.eval().cuda()

# Reset the synthesis input offsets to zero.
G.backbone.synthesis.input.x_offset = 0
G.backbone.synthesis.input.y_offset = 0

# NOTE(review): G_kwargs is not passed to G anywhere in the visible code;
# kept because other code may still reference it -- confirm.
G_kwargs = dict(noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None)

print('prepare finish', flush=True)
|
|
|
|
|
# Attribute vocabularies.  objs_to_canvas uses each entry's position (via
# list.index) to pick the one-hot channel it sets for an object.
# NOTE(review): 'purple' is listed twice, so list.index() can never return
# position 6 and np.random.choice picks purple twice as often -- confirm
# whether the second entry was meant to be a different colour.
COLOR_NAME_LIST = [
    'cyan', 'green', 'purple', 'red',
    'yellow', 'gray', 'purple', 'blue',
]
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']

# Divisors applied to the incoming object x / y in objs_to_canvas.
canvas_x, canvas_y = 800, 200

batch_size = 1

# One latent vector, sampled once and reused by every prediction call.
code = torch.randn(1, G.z_dim).cuda()

# Converts a CHW tensor in [0, 1] to a PIL image.
to_pil = torchvision.transforms.ToPILImage()

# Flat 25-value tensor handed to G.  predict_local_view_video reads the
# first 16 entries as a 4x4 matrix and carries the last 9 through unchanged
# (they look like a 3x3 matrix -- presumably camera intrinsics; confirm).
RT = torch.tensor(
    [[
        -1.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.5000, -0.8660, 10.3923,
        0.0000, -0.8660, -0.5000, 6.0000,
        0.0000, 0.0000, 0.0000, 1.0000,
        262.5000, 0.0000, 32.0000,
        0.0000, 262.5000, 32.0000,
        0.0000, 0.0000, 1.0000,
    ]],
    device='cuda',
)

# Remembers [color, shape, material, rot] per object id across calls.
obj_dict = {}
|
|
|
def trans(x, y, z, length):
    """Map scene coordinates to canvas pixel coordinates.

    Args:
        x: Scene x coordinate; the result decreases as ``x`` grows.
        y: Scene y coordinate; the result increases as ``y`` grows.
        z: Scene-space size, rescaled by the same 256 / 9 factor.
        length: Canvas side length in pixels, used to centre the result.

    Returns:
        Tuple ``(x, y, z)`` in pixel units.  The constants 9 and 256 are
        hard-coded (presumably scene extent and window size -- confirm).
    """
    # Centre of the canvas minus half of the 256-pixel window.
    origin = 0.5 * length - 128
    px = origin + 256 - (x / 9 + .5) * 256
    py = origin + (y / 9 + .5) * 256
    pz = z / 9 * 256
    return px, py, pz
|
|
|
def objs_to_canvas(lst, length=256, scale=2.6):
    """Rasterise the placed objects into a multi-channel top-down canvas.

    Args:
        lst: Iterable of dicts providing ``'x'``, ``'y'`` and ``'id'``.
            ``None`` is treated as an empty scene.
        length: Canvas height in pixels; also forwarded to ``trans``.
        scale: Width factor; the canvas is ``int(length * scale)`` wide.

    Returns:
        ``float32`` array of shape ``(14, length, int(length * scale))``.

    Side effects:
        The first time an id is seen, a random colour (with shape 'cube',
        material 'rubber', rotation 0) is stored for it in ``obj_dict``.
    """
    # The UI component may supply no value before any object is placed.
    if lst is None:
        lst = []

    objs = []
    for each in lst:
        x, y, obj_id = each['x'], each['y'], each['id']

        if obj_id not in obj_dict:
            color = np.random.choice(COLOR_NAME_LIST)
            shape = 'cube'
            material = 'rubber'
            rot = 0
            obj_dict[obj_id] = [color, shape, material, rot]

        color, shape, material, rot = obj_dict[obj_id]

        # Rescale and shift the widget coordinates into scene coordinates.
        x = -x / canvas_x * 16
        y = y / canvas_y * 2
        y *= 2
        x += 1.0
        y -= 1.5
        z = 0.35
        objs.append([x, y, z, shape, color, material, rot])

    h, w = length, int(length * scale)
    # 1 presence flag + 8 colour + 3 shape + 2 material channels.
    nc = 14
    canvas = np.zeros([h, w, nc])
    # 1-based row / column index grids.
    xx = np.ones([h, w]).cumsum(0)
    yy = np.ones([h, w]).cumsum(1)

    color_base = 1
    shape_base = color_base + len(COLOR_NAME_LIST)
    material_base = shape_base + len(SHAPE_NAME_LIST)

    for x, y, z, shape, color, material, rot in objs:
        # trans returns (x, y, z); the first two are deliberately swapped
        # here, exactly as in the original, because xx indexes rows.
        y, x, z = trans(x, y, z, length)

        feat = [0] * nc
        feat[0] = 1
        feat[COLOR_NAME_LIST.index(color) + color_base] = 1
        feat[SHAPE_NAME_LIST.index(shape) + shape_base] = 1
        feat[MATERIAL_NAME_LIST.index(material) + material_base] = 1
        feat = np.array(feat)

        rot_sin = np.sin(rot / 180 * np.pi)
        rot_cos = np.cos(rot / 180 * np.pi)

        if shape == 'cube':
            # Rotated square with half-side z.
            mask = (
                (np.abs(+rot_cos * (xx - x) + rot_sin * (yy - y)) <= z)
                & (np.abs(-rot_sin * (xx - x) + rot_cos * (yy - y)) <= z)
            )
        else:
            # Disc of radius z.
            mask = ((xx - x) ** 2 + (y - yy) ** 2) ** 0.5 <= z
        canvas[mask] = feat

    canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
    return canvas
|
|
|
@torch.no_grad()
def predict_local_view(lst):
    """Generate a single image for the leftmost 256 columns of the canvas.

    Args:
        lst: Object list forwarded unchanged to ``objs_to_canvas``.

    Returns:
        PIL image built from the first generated sample, with the output
        mapped through ``* .5 + .5`` (presumably [-1, 1] -> [0, 1]).
    """
    full_canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
    bevs = full_canvas[..., :256]
    image = G(code, RT, bevs)['gen_output']['image'][0]
    return to_pil(image * .5 + .5)
|
|
|
@torch.no_grad()
def predict_local_view_video(lst):
    """Render a 50-frame clip of the local view with a wobbling viewpoint.

    The first 16 values of ``RT`` are read as a 4x4 matrix.  Two of the
    Euler angles of its upper-left 3x3 block are perturbed along a closed
    curve while everything else in ``RT`` is kept.

    Args:
        lst: Object list forwarded unchanged to ``objs_to_canvas``.

    Returns:
        The path ``'tmp.mp4'`` of the written video (25 fps).
    """
    canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
    bevs = canvas[..., 0:0 + 256]

    RT_array = np.array(RT[0].cpu())
    pose = RT_array[:16].reshape(4, 4)
    # Named `rest` rather than `trans` so the module-level function
    # `trans` is not shadowed inside this function.
    rest = RT_array[16:]
    pose_new = pose.copy()

    angles = Rotation.from_matrix(pose[:3, :3]).as_euler("zyx", degrees=True)
    v_mean, h_mean = angles[1], angles[2]

    writer = imageio.get_writer('tmp.mp4', fps=25)
    try:
        for t in np.linspace(0, 1, 50):
            # +/-0.5 and +/-1 degree offsets around the base angles.
            angles[1] = 0.5 * np.cos(t * 2 * math.pi) + v_mean
            angles[2] = 1 * np.sin(t * 2 * math.pi) + h_mean
            r = Rotation.from_euler("zyx", angles, degrees=True)
            pose_new[:3, :3] = r.as_matrix()
            new_RT = torch.tensor(
                np.concatenate([pose_new.flatten(), rest])[None]
            ).cuda().float()
            gen = G(code, new_RT, bevs)
            rgb = gen['gen_output']['image'][0] * .5 + .5
            writer.append_data(np.array(to_pil(rgb)))
    finally:
        # Always release the writer, even when generation fails midway.
        writer.close()
    return 'tmp.mp4'
|
|
|
@torch.no_grad()
def predict_global_view(lst):
    """Stitch a wide image by sliding a 256-column window over the canvas.

    The window advances 10 columns per step.  The first step contributes
    image columns 0..159; every later step contributes columns 128..159.

    Args:
        lst: Object list forwarded unchanged to ``objs_to_canvas``.

    Returns:
        PIL image of the strips concatenated along the width axis.
    """
    canvas = torch.tensor(objs_to_canvas(lst)).cuda()[None]
    total_width = canvas.shape[-1]

    strips = []
    for offset in trange(0, total_width - 256, 10):
        window = canvas[..., offset:offset + 256]
        image = G(code, RT, window)['gen_output']['image']
        left = 0 if offset <= 0 else 128
        strips.append(image[0, ..., left:128 + 32])

    stitched = torch.cat(strips, 2) * .5 + .5
    return to_pil(stitched)
|
|
|
# Gradio UI: a draggable canvas with two action buttons on the left and the
# generated outputs on the right.
with gr.Blocks() as demo:
    # Title, authors and links.
    gr.Markdown(
        """
# BerfScene: Bev-conditioned Equivariant Radiance Fields for Infinite 3D Scene Generation
Qihang Zhang, Yinghao Xu, Yujun Shen, Bo Dai, Bolei Zhou*, Ceyuan Yang* (*Corresponding Author)<br>
[Arxiv Report](https://arxiv.org/abs/2312.02136) | [Project Page](https://zqh0253.github.io/BerfScene/) | [Github](https://github.com/zqh0253/BerfScene)
"""
    )

    # Usage instructions.
    gr.Markdown(
        """
### Quick Start
1. Drag and place objects in the canvas.
2. Click `Add object` to insert object into the canvas.
3. Click `Reset` to clean the canvas.
4. Click `Get local view` to synthesize local 3D scenes.
5. Click `Get global view` to synthesize global 3D scenes.
"""
    )

    with gr.Row():
        with gr.Column():
            drag = Draggable()
            with gr.Row():
                submit_btn_local = gr.Button("Get local view", variant='primary')
                submit_btn_global = gr.Button("Get global view", variant='primary')

        with gr.Column():
            with gr.Row():
                single_view_image = gr.Image(label='single view', interactive=False)
                # Label typo fixed: 'mutli-view' -> 'multi-view'.
                single_view_video = gr.Video(label='multi-view', interactive=False, autoplay=True)

            global_view_image = gr.Image(label='global view', interactive=False)

    # The local button triggers both the still image and the video.
    submit_btn_local.click(fn=predict_local_view, inputs=drag, outputs=single_view_image)
    submit_btn_local.click(fn=predict_local_view_video, inputs=drag, outputs=single_view_video)
    submit_btn_global.click(fn=predict_global_view, inputs=drag, outputs=global_view_image)
|
|
|
|
|
# Command-line options: only the listening port is configurable.
parser = argparse.ArgumentParser()
parser.add_argument('--port', default=7860, type=int, help='The port number')
args = parser.parse_args()

# Enable request queuing, then serve on all interfaces at the chosen port.
demo.queue()
launch_options = dict(
    server_name='0.0.0.0',
    server_port=args.port,
    debug=True,
    show_error=True,
)
demo.launch(**launch_options)
|
|
|
|