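# Gradio demo for 3DTopia (hongfz16/3DTopia): two-stage text-to-3D generation.
# Stage 1 samples a triplane latent with a diffusion model and renders a
# turntable preview video; stage 2 extracts a colored mesh via marching cubes,
# refines it with the external `threefiner` command, and renders it with `kire`.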
import os
import cv2
import time
import json
import torch
import mcubes
import trimesh
import datetime
import argparse
import subprocess
import numpy as np
import gradio as gr
from tqdm import tqdm
import imageio.v2 as imageio
import pytorch_lightning as pl
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler
from utility.initialize import instantiate_from_config, get_obj_from_str
from utility.triplane_renderer.eg3d_renderer import sample_from_planes, generate_planes
from utility.triplane_renderer.renderer import get_rays, to8b
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
def add_text(rgb, caption):
    """Overlay a caption on an image, wrapping every 30 characters."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    # top-left margin and per-line spacing in pixels
    gap = 10
    fontScale = 0.3
    # blue in BGR
    color = (255, 0, 0)
    thickness = 1
    # break the caption into 30-character lines
    break_caption = []
    for i in range(len(caption) // 30 + 1):
        break_caption_i = caption[i*30:(i+1)*30]
        break_caption.append(break_caption_i)
    for i, bci in enumerate(break_caption):
        cv2.putText(rgb, bci, (gap, gap*(i+1)), font, fontScale, color, thickness, cv2.LINE_AA)
    return rgb
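
# Global setup: load the diffusion config, fetch the 3DTopia weights from the
# Hugging Face Hub, and build the DDIM sampler once at startup.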
config = "configs/default.yaml"
# ckpt = "checkpoints/3dtopia_diffusion_state_dict.ckpt"
ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
configs = OmegaConf.load(config)
os.makedirs("tmp", exist_ok=True)

if ckpt.endswith(".ckpt"):
    model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
elif ckpt.endswith(".safetensors"):
    model = get_obj_from_str(configs.model["target"])(**configs.model.params)
    model_ckpt = load_file(ckpt)
    model.load_state_dict(model_ckpt)
else:
    raise NotImplementedError

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
sampler = DDIMSampler(model)

img_size = configs.model.params.unet_config.params.image_size
channels = configs.model.params.unet_config.params.in_channels
# triplane latent: three planes concatenated along the width
shape = [channels, img_size, img_size * 3]
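
# Precompute one batch of camera rays per pose file so every inference call can
# reuse them. Each pose is a 4x4 camera-to-world matrix, rendered at H x H with
# intrinsics scaled down from the original 512-pixel calibration resolution.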
pose_folder = 'assets/sample_data/pose'
poses_fname = sorted([os.path.join(pose_folder, f) for f in os.listdir(pose_folder)])
batch_rays_list = []
H = 128
ratio = 512 // H
for p in poses_fname:
    c2w = np.loadtxt(p).reshape(4, 4)
    # push the camera further from the origin
    c2w[:3, 3] *= 2.2
    # axis swap into the renderer's coordinate convention
    c2w = np.array([
        [1, 0, 0, 0],
        [0, 0, -1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]
    ]) @ c2w
    # focal length 560 at 512x512, scaled to the render resolution
    k = np.array([
        [560 / ratio, 0, H * 0.5],
        [0, 560 / ratio, H * 0.5],
        [0, 0, 1]
    ])
    rays_o, rays_d = get_rays(H, H, torch.Tensor(k), torch.Tensor(c2w[:3, :4]))
    coords = torch.stack(torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, H-1, H), indexing='ij'), -1)
    coords = torch.reshape(coords, [-1, 2]).long()
    rays_o = rays_o[coords[:, 0], coords[:, 1]]
    rays_d = rays_d[coords[:, 0], coords[:, 1]]
    batch_rays = torch.stack([rays_o, rays_d], 0)
    batch_rays_list.append(batch_rays)
batch_rays_list = torch.stack(batch_rays_list, 0)
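
# Stage-2 helper: densely query the decoded triplane for density, run marching
# cubes on the resulting volume, then color each vertex by ray-marching toward
# it from six axis-aligned camera origins, keeping the color from the viewpoint
# whose rendered depth best matches the true camera-to-vertex distance.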
def marching_cube(b, text, global_info):
    # prepare the density volume for marching cubes
    res = 128
    assert 'decode_res' in global_info
    decode_res = global_info['decode_res']
    c_list = torch.linspace(-1.2, 1.2, steps=res)
    grid_x, grid_y, grid_z = torch.meshgrid(
        c_list, c_list, c_list, indexing='ij'
    )
    coords = torch.stack([grid_x, grid_y, grid_z], -1).to(device)
    plane_axes = generate_planes()
    feats = sample_from_planes(
        plane_axes, decode_res[b:b+1].reshape(1, 3, -1, 256, 256), coords.reshape(1, -1, 3), padding_mode='zeros', box_warp=2.4
    )
    fake_dirs = torch.zeros_like(coords)
    fake_dirs[..., 0] = 1
    out = model.first_stage_model.triplane_decoder.decoder(feats, fake_dirs)
    u = out['sigma'].reshape(res, res, res).detach().cpu().numpy()
    del out

    # marching cubes at a density threshold of 10
    vertices, triangles = mcubes.marching_cubes(u, 10)
    # map voxel indices back to world coordinates in [-1.2, 1.2]^3
    min_bound = np.array([-1.2, -1.2, -1.2])
    max_bound = np.array([1.2, 1.2, 1.2])
    vertices = vertices / (res - 1) * (max_bound - min_bound)[None, :] + min_bound[None, :]
    pt_vertices = torch.from_numpy(vertices).to(device)

    # extract vertex colors
    res_triplane = 256
    render_kwargs = {
        'depth_resolution': 128,
        'disparity_space_sampling': False,
        'box_warp': 2.4,
        'depth_resolution_importance': 128,
        'clamp_mode': 'softplus',
        'white_back': True,
        'det': True
    }
    rays_o_list = [
        np.array([0, 0, 2]),
        np.array([0, 0, -2]),
        np.array([0, 2, 0]),
        np.array([0, -2, 0]),
        np.array([2, 0, 0]),
        np.array([-2, 0, 0]),
    ]
    rgb_final = None
    diff_final = None
    for rays_o in tqdm(rays_o_list):
        # shoot one ray from this camera origin through every vertex
        rays_o = torch.from_numpy(rays_o.reshape(1, 3)).repeat(vertices.shape[0], 1).float().to(device)
        rays_d = pt_vertices.reshape(-1, 3) - rays_o
        rays_d = rays_d / torch.norm(rays_d, dim=-1).reshape(-1, 1)
        dist = torch.norm(pt_vertices.reshape(-1, 3) - rays_o, dim=-1).cpu().numpy().reshape(-1)
        render_out = model.first_stage_model.triplane_decoder(
            decode_res[b:b+1].reshape(1, 3, -1, res_triplane, res_triplane),
            rays_o.unsqueeze(0), rays_d.unsqueeze(0), render_kwargs,
            whole_img=False, tvloss=False
        )
        rgb = render_out['rgb_marched'].reshape(-1, 3).detach().cpu().numpy()
        depth = render_out['depth_final'].reshape(-1).detach().cpu().numpy()
        # a small depth error means the vertex is visible from this viewpoint
        depth_diff = np.abs(dist - depth)
        if rgb_final is None:
            rgb_final = rgb.copy()
            diff_final = depth_diff.copy()
        else:
            ind = diff_final > depth_diff
            rgb_final[ind] = rgb[ind]
            diff_final[ind] = depth_diff[ind]

    # BGR to RGB
    rgb_final = np.stack([
        rgb_final[:, 2], rgb_final[:, 1], rgb_final[:, 0]
    ], -1)

    # export to .ply
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=(rgb_final * 255).astype(np.uint8))
    path = os.path.join('tmp', f"{text.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.ply")
    trimesh.exchange.export.export_mesh(mesh, path, file_type='ply')

    del vertices, triangles, rgb_final
    torch.cuda.empty_cache()
    return path
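
# Stage 1: sample `samples` triplane latents for the prompt with
# classifier-free guidance, decode them, render a sweep of the precomputed
# poses per candidate, and write a tiled 2x2 preview video.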
def infer(prompt, samples, steps, scale, seed, global_info):
    prompt = prompt.replace('/', '')
    pl.seed_everything(seed)
    batch_size = samples
    with torch.no_grad():
        noise = None
        c = model.get_learned_conditioning([prompt])
        unconditional_c = torch.zeros_like(c)
        sample, _ = sampler.sample(
            S=steps,
            batch_size=batch_size,
            shape=shape,
            verbose=False,
            x_T=noise,
            conditioning=c.repeat(batch_size, 1, 1),
            unconditional_guidance_scale=scale,
            unconditional_conditioning=unconditional_c.repeat(batch_size, 1, 1)
        )
        decode_res = model.decode_first_stage(sample)
        big_video_list = []
        global_info['decode_res'] = decode_res
        for b in range(batch_size):
            def render_img(v):
                rgb_sample, _ = model.first_stage_model.render_triplane_eg3d_decoder(
                    decode_res[b:b+1], batch_rays_list[v:v+1].to(device), torch.zeros(1, H, H, 3).to(device),
                )
                rgb_sample = to8b(rgb_sample.detach().cpu().numpy())[0]
                rgb_sample = np.stack(
                    [rgb_sample[..., 2], rgb_sample[..., 1], rgb_sample[..., 0]], -1
                )
                rgb_sample = add_text(rgb_sample, str(b))
                return rgb_sample

            view_num = len(batch_rays_list)
            video_list = []
            for v in tqdm(range(view_num//8*3, view_num//8*5, 2)):
                rgb_sample = render_img(v)
                video_list.append(rgb_sample)
            big_video_list.append(video_list)
    # pad to four candidates with white frames, then tile frames into a 2x2 grid
    for _ in range(4 - batch_size):
        big_video_list.append(
            [np.zeros_like(f) + 255 for f in big_video_list[0]]
        )
    cat_video_list = [
        np.concatenate([
            np.concatenate([big_video_list[0][i], big_video_list[1][i]], 1),
            np.concatenate([big_video_list[2][i], big_video_list[3][i]], 1),
        ], 0)
        for i in range(len(big_video_list[0]))
    ]
    path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
    imageio.mimwrite(path, np.stack(cat_video_list, 0))
    return global_info, path
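
# Stage 2: export the selected candidate as a colored mesh, refine it with the
# external `threefiner if2` command, and render a turntable preview of the
# refined .glb with `kire`.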
def infer_stage2(prompt, selection, seed, global_info):
    prompt = prompt.replace('/', '')
    mesh_path = marching_cube(int(selection), prompt, global_info)
    mesh_name = mesh_path.split('/')[-1][:-4]

    if2_cmd = f"threefiner if2 --mesh {mesh_path} --prompt \"{prompt}\" --outdir tmp --save {mesh_name}_if2.glb --text_dir --front_dir=-y"
    print(if2_cmd)
    subprocess.Popen(if2_cmd, shell=True).wait()
    torch.cuda.empty_cache()

    video_path = f"tmp/{prompt.replace(' ', '_')}_{str(datetime.datetime.now()).replace(' ', '_')}.mp4"
    render_cmd = f"kire {os.path.join('tmp', mesh_name + '_if2.glb')} --save_video {video_path} --wogui --force_cuda_rast --H 256 --W 256"
    print(render_cmd)
    subprocess.Popen(render_cmd, shell=True).wait()
    torch.cuda.empty_cache()

    return video_path, os.path.join('tmp', mesh_name + '_if2.glb')
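
# Gradio UI: the left column drives stage-1 generation (prompt, sampler
# settings, preview video); the right column selects one of the four
# candidates, runs stage-2 refinement, and offers the refined mesh download.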
block = gr.Blocks()
with block:
    global_info = gr.State(dict())
    with gr.Row():
        with gr.Column():
            with gr.Row():
                text = gr.Textbox(
                    label="Enter your prompt",
                    max_lines=1,
                    placeholder="Enter your prompt",
                    container=False,
                )
                btn = gr.Button("Generate 3D")
            gallery = gr.Video(height=512)
            advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
            with gr.Row(elem_id="advanced-options"):
                samples = gr.Slider(label="Number of Samples", minimum=1, maximum=4, value=4, step=1)
                steps = gr.Slider(label="Steps", minimum=1, maximum=500, value=50, step=1)
                scale = gr.Slider(
                    label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=2147483647,
                    step=1,
                    randomize=True,
                )
            gr.on([text.submit, btn.click], infer, inputs=[text, samples, steps, scale, seed, global_info], outputs=[global_info, gallery])
            advanced_button.click(
                None,
                [],
                text,
            )
        with gr.Column():
            with gr.Row():
                dropdown = gr.Dropdown(
                    ['0', '1', '2', '3'], label="Choose a Candidate For Stage2", value='0'
                )
                btn_stage2 = gr.Button("Start Refinement")
            # stage-2 preview video and mesh download
            gallery = gr.Video(height=512)
            download = gr.File(label="Download Mesh", file_count="single", height=100)
            gr.on([btn_stage2.click], infer_stage2, inputs=[text, dropdown, seed, global_info], outputs=[gallery, download])

block.launch(share=True)