import gradio as gr
import os
from PIL import Image
import subprocess
from gradio_model4dgs import Model4DGS
import numpy
import hashlib
import shlex

# Install the prebuilt rasterizer wheel at startup, before the project modules
# below are imported (presumably some of them depend on it — confirm).
subprocess.run(shlex.split("pip install wheels/diff_gaussian_rasterization-0.0.0-cp310-cp310-linux_x86_64.whl"))

import rembg
import glob
import cv2
import numpy as np
import torch
from diffusers import StableVideoDiffusionPipeline
from scripts.gen_vid import *

import sys
# The LGM code base lives in ./lgm; make its `core`/`mvdream` packages importable.
sys.path.append('lgm')
from safetensors.torch import load_file
from kiui.cam import orbit_camera
from core.options import config_defaults, Options
from core.models import LGM
from mvdream.pipeline_mvdream import MVDreamPipeline
from infer_demo import process as process_lgm
from main_4d_demo import process as process_dg4d
import spaces
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors")

# Injected into the page on load: force the light theme by rewriting the
# `__theme` query parameter and reloading if it is not already 'light'.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""

device = torch.device('cuda')
session = rembg.new_session(model_name='u2net')

# Image-to-video model used by `gen_vid`.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", torch_dtype=torch.float16, variant="fp16"
)
pipe.to(device)

# LGM (image-to-3D) model used by `optimize_stage_1`.
opt = config_defaults['big']
opt.resume = ckpt_path

model = LGM(opt)
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        ckpt = torch.load(opt.resume, map_location='cpu')
    model.load_state_dict(ckpt, strict=False)
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print(f'[WARN] model randomly initialized, are you sure?')

model = model.half().to(device)
model.eval()

rays_embeddings = model.prepare_default_rays(device)

# ImageDream multi-view diffusion pipeline, passed to LGM inference.
pipe_mvdream = MVDreamPipeline.from_pretrained(
    "ashawkey/imagedream-ipmv-diffusers",  # remote weights
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
pipe_mvdream = pipe_mvdream.to(device)

# Zero123 guidance, passed to the 4D optimisation in `optimize_stage_2`.
from guidance.zero123_utils import Zero123
guidance_zero123 = Zero123(device, model_key='ashawkey/stable-zero123-diffusers')


def preprocess(path, recenter=True, size=256, border_ratio=0.2):
    """Remove the background of the image at `path` and write `<name>_rgba.png` beside it.

    Args:
        path: input image file. The output goes to the same directory, named
            after the part of the file name before the first '.'.
        recenter: if True, crop to the foreground bounding box and paste it,
            scaled so its longer side is `size * (1 - border_ratio)` pixels,
            into the centre of a transparent `size` x `size` canvas. If False,
            write the background-removed image at its original resolution.
        size: side length in pixels of the recentered output.
        border_ratio: fraction of `size` left as empty margin around the object.

    Raises:
        gr.Error: if the image cannot be read, or (when recentering) no
            foreground pixels remain after background removal.
    """
    out_dir = os.path.dirname(path)
    out_base = os.path.basename(path).split('.')[0]
    out_rgba = os.path.join(out_dir, out_base + '_rgba.png')

    print(f'[INFO] loading image {path}...')
    image = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise gr.Error(f"Could not read image {path}")

    print(f'[INFO] background removal...')
    carved_image = rembg.remove(image, session=session)  # [H, W, 4]
    mask = carved_image[..., -1] > 0

    if recenter:
        print(f'[INFO] recenter...')
        coords = np.nonzero(mask)
        if coords[0].size == 0:
            raise gr.Error("No foreground object found in the input image")
        # +1 turns the inclusive max index into an exclusive slice bound so
        # the last foreground row/column is not cropped away.
        x_min, x_max = coords[0].min(), coords[0].max() + 1
        y_min, y_max = coords[1].min(), coords[1].max() + 1
        h = x_max - x_min
        w = y_max - y_min
        desired_size = int(size * (1 - border_ratio))
        scale = desired_size / max(h, w)
        # Clamp to 1 px so a very thin object never requests a zero-sized resize.
        h2 = max(1, int(h * scale))
        w2 = max(1, int(w * scale))
        x2_min = (size - h2) // 2
        x2_max = x2_min + h2
        y2_min = (size - w2) // 2
        y2_max = y2_min + w2
        final_rgba = np.zeros((size, size, 4), dtype=np.uint8)
        final_rgba[x2_min:x2_max, y2_min:y2_max] = cv2.resize(
            carved_image[x_min:x_max, y_min:y_max], (w2, h2), interpolation=cv2.INTER_AREA
        )
    else:
        final_rgba = carved_image

    cv2.imwrite(out_rgba, final_rgba)


def gen_vid(input_path, seed, bg='white'):
    """Generate a video from the image at `input_path` with Stable Video Diffusion.

    Writes `<name>_generated.mp4` (7 fps) and every frame as
    `<name>_frames/000.png, 001.png, ...` into the input file's directory.

    Args:
        input_path: path of the (background-removed) input image.
        seed: seed given to `torch.manual_seed` for the sampling generator.
        bg: forwarded to `load_image` from scripts.gen_vid (presumably the
            background fill colour — confirm there).
    """
    name = os.path.basename(input_path).split('.')[0]
    input_dir = os.path.dirname(input_path)
    height, width = 512, 512

    image = load_image(input_path, width, height, bg)
    generator = torch.manual_seed(seed)
    frames = pipe(image, height, width, generator=generator).frames[0]

    # os.path.join (not f"{input_dir}/...") so a bare file name never maps to "/".
    imageio.mimwrite(os.path.join(input_dir, f"{name}_generated.mp4"), frames, fps=7)
    frames_dir = os.path.join(input_dir, f"{name}_frames")
    os.makedirs(frames_dir, exist_ok=True)
    for idx, img in enumerate(frames):
        img.save(os.path.join(frames_dir, f"{idx:03}.png"))


def _image_hash(image_block: Image.Image) -> str:
    """Return the cache key (SHA-256 of the raw pixel bytes) for all files derived from `image_block`."""
    return hashlib.sha256(image_block.tobytes()).hexdigest()


def _prepare_rgba(image_block: Image.Image, preprocess_chk: bool) -> str:
    """Make sure `tmp_data/<hash>_rgba.png` exists for `image_block` and return `<hash>`.

    If the file already exists it is reused as-is. Otherwise the image is either
    run through `preprocess` (background removal + recentering) or, when
    `preprocess_chk` is False, saved unchanged as the RGBA input.
    """
    os.makedirs('tmp_data', exist_ok=True)
    img_hash = _image_hash(image_block)
    rgba_path = os.path.join('tmp_data', f'{img_hash}_rgba.png')
    if not os.path.exists(rgba_path):
        if preprocess_chk:
            # Save the raw upload; preprocess() writes <hash>_rgba.png beside it
            # (in-process replacement for the former scripts/process.py call).
            raw_path = os.path.join('tmp_data', f'{img_hash}.png')
            image_block.save(raw_path)
            preprocess(raw_path)
        else:
            image_block.save(rgba_path)
    return img_hash


def check_img_input(control_image):
    """Raise a user-facing error if no input image has been uploaded or selected."""
    if control_image is None:
        raise gr.Error("Please select or upload an input image")


def check_video_3d_input(image_block: Image.Image):
    """Raise a user-facing error unless the video and 3D stages already ran for this image.

    Checks for the image itself, then the stage-0 video, then the stage-1 3D video.
    """
    if image_block is None:
        raise gr.Error("Please select or upload an input image")
    img_hash = _image_hash(image_block)
    if not os.path.exists(os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')):
        raise gr.Error("Please generate a video first")
    if not os.path.exists(os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')):
        raise gr.Error("Please generate a 3D first")


@spaces.GPU()
def optimize_stage_0(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 0: generate the driving video for `image_block`.

    Returns:
        Path of the generated video, `tmp_data/<hash>_rgba_generated.mp4`.
    """
    img_hash = _prepare_rgba(image_block, preprocess_chk)
    gen_vid(os.path.join('tmp_data', f'{img_hash}_rgba.png'), seed_slider)
    return os.path.join('tmp_data', f'{img_hash}_rgba_generated.mp4')


@spaces.GPU()
def optimize_stage_1(image_block: Image.Image, preprocess_chk: bool, seed_slider: int):
    """Stage 1: build the static 3D model with LGM (`infer_demo.process`).

    Returns:
        Path `vis_data/<hash>_rgba_static.mp4`, which is expected to be written
        by `process_lgm` (NOTE(review): confirm the output name there).
    """
    img_hash = _prepare_rgba(image_block, preprocess_chk)
    process_lgm(opt, os.path.join('tmp_data', f'{img_hash}_rgba.png'),
                pipe_mvdream, model, rays_embeddings, seed_slider)
    return os.path.join('vis_data', f'{img_hash}_rgba_static.mp4')


@spaces.GPU(duration=120)
def optimize_stage_2(image_block: Image.Image, seed_slider: int):
    """Stage 2: run the 4D optimisation (`main_4d_demo.process`) with the 4d_demo config.

    `seed_slider` is currently unused; it is kept so the UI event wiring is unchanged.

    Returns:
        (path of `vis_data/<hash>_rgba.mp4`, list of the 28 per-frame
        `logs/<hash>_rgba_frames/NNN.ply` paths).
    """
    img_hash = _image_hash(image_block)
    process_dg4d(os.path.join("configs", "4d_demo.yaml"),
                 os.path.join("tmp_data", f"{img_hash}_rgba.png"),
                 guidance_zero123)
    image_dir = os.path.join('logs', f'{img_hash}_rgba_frames')
    return (os.path.join('vis_data', f'{img_hash}_rgba.mp4'),
            [image_dir + f'/{t:03d}.ply' for t in range(28)])


if __name__ == "__main__":
    _TITLE = '''DreamGaussian4D: Generative 4D Gaussian Splatting'''

    # Kept unindented: Markdown renders 4-space-indented text as a code block.
    _DESCRIPTION = '''
We present DreamGaussian4D, an efficient 4D generation framework that builds on Gaussian Splatting.
'''
    _IMG_USER_GUIDE = "Please upload an image in the block above (or choose an example above), click **Generate Video** and **Generate 3D**. Finally, click **Generate 4D**."

    # Offer every .png in the 'data' folder next to this script as a clickable example.
    example_folder = os.path.join(os.path.dirname(__file__), 'data')
    example_fns = sorted(os.listdir(example_folder))
    examples_full = [os.path.join(example_folder, x) for x in example_fns if x.endswith('.png')]

    # Compose demo layout & data flow
    with gr.Blocks(title=_TITLE, theme=gr.themes.Soft(), js=js_func) as demo:
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown('# ' + _TITLE)
                gr.Markdown(_DESCRIPTION)

        # Image-to-3D
        with gr.Row(variant='panel'):
            with gr.Column(scale=5):
                image_block = gr.Image(type='pil', image_mode='RGBA', height=290, label='Input image')

                seed_slider = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (Video)')
                seed_slider2 = gr.Slider(0, 100000, value=0, step=1, label='Random Seed (3D)')
                gr.Markdown(
                    "random seed for video generation.")

                preprocess_chk = gr.Checkbox(True,
                                             label='Preprocess image automatically (remove background and recenter object)')

                gr.Examples(
                    examples=examples_full,  # NOTE: elements must match inputs list!
                    inputs=[image_block],
                    outputs=[image_block],
                    cache_examples=False,
                    label='Examples (click one of the images below to start)',
                    examples_per_page=40
                )
                img_run_btn = gr.Button("Generate Video")
                threed_run_btn = gr.Button("Generate 3D")
                fourd_run_btn = gr.Button("Generate 4D")
                img_guide_text = gr.Markdown(_IMG_USER_GUIDE, visible=True)

            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=5):
                        driving_video = gr.Video(label="video", height=290)
                    with gr.Column(scale=5):
                        obj3d = gr.Video(label="3D Model", height=290)
                video4d = gr.Video(label="4D video", height=290)
                obj4d = Model4DGS(label="4D Model", height=500, fps=28)

        # Each button first runs a cheap, unqueued input check; the GPU stage
        # only runs if that check succeeds.
        img_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            optimize_stage_0,
            inputs=[image_block, preprocess_chk, seed_slider],
            outputs=[driving_video])
        threed_run_btn.click(check_img_input, inputs=[image_block], queue=False).success(
            optimize_stage_1,
            inputs=[image_block, preprocess_chk, seed_slider2],
            outputs=[obj3d])
        fourd_run_btn.click(check_video_3d_input, inputs=[image_block], queue=False).success(
            optimize_stage_2,
            inputs=[image_block, seed_slider],
            outputs=[video4d, obj4d])

    # Enable the request queue, holding at most 10 pending events.
    demo.queue(max_size=10)
    demo.launch()