|
import importlib
import os
import subprocess
import sys
import tempfile

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from safetensors.torch import load_file

import rembg

import gradio as gr
|
|
|
|
|
from huggingface_hub import hf_hub_download |
|
# Fetch (or reuse from the local HF cache) the fp16 LGM weights.
ckpt_path = hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16.safetensors")

# diff_gaussian_rasterization may not be preinstalled; if it is missing, install
# it from the local ./diff-gaussian-rasterization directory before the project
# modules below are imported (presumably core.models depends on it — confirm).
try:
    import diff_gaussian_rasterization  # noqa: F401  (availability probe only)
except ImportError:
    # Install with the interpreter running this script: a bare `pip` on PATH may
    # belong to a different environment. Report a failed install instead of
    # discarding the exit status, but keep going (best-effort, as before).
    _install = subprocess.run(
        [sys.executable, "-m", "pip", "install", "./diff-gaussian-rasterization"],
        check=False,
    )
    if _install.returncode != 0:
        print(f'[WARN] failed to install diff-gaussian-rasterization (exit code {_install.returncode})')
    # The package was created after interpreter start-up; make sure the import
    # system's finder caches notice it.
    importlib.invalidate_caches()
|
|
|
import kiui |
|
from kiui.op import recenter |
|
|
|
from core.options import Options |
|
from core.models import LGM |
|
from mvdream.pipeline_mvdream import MVDreamPipeline |
|
|
|
# Per-channel mean/std used to normalize the generated views before they are
# passed to the LGM network (see TF.normalize in run()).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Directory for generated .ply files. Use the platform temp dir instead of a
# hard-coded '/tmp': it honours TMPDIR, works on non-POSIX systems, and Gradio
# only serves returned files located in the working directory, the system temp
# directory, or explicitly allowed paths.
TMP_DIR = tempfile.gettempdir()
os.makedirs(TMP_DIR, exist_ok=True)
|
|
|
|
|
# LGM configuration. NOTE(review): these values must match the architecture the
# checkpoint at ckpt_path was trained with; batch_size, gradient_accumulation_steps
# and mixed_precision look training-oriented and are presumably unused here.
opt = Options(
    input_size=256,
    up_channels=(1024, 1024, 512, 256, 128),
    up_attention=(True, True, True, False, False),
    splat_size=128,
    output_size=512,
    batch_size=8,
    num_views=8,
    gradient_accumulation_steps=1,
    mixed_precision='bf16',
    resume=ckpt_path,
)

model = LGM(opt)

# Load pretrained weights: safetensors files via safetensors, anything else via torch.load.
if opt.resume is not None:
    if opt.resume.endswith('safetensors'):
        ckpt = load_file(opt.resume, device='cpu')
    else:
        ckpt = torch.load(opt.resume, map_location='cpu')
    # strict=False skips mismatched keys without complaint; surface them so a
    # wrong or truncated checkpoint does not silently leave weights at init.
    load_result = model.load_state_dict(ckpt, strict=False)
    if load_result.missing_keys:
        print(f'[WARN] {len(load_result.missing_keys)} model keys not found in checkpoint '
              f'(left at initialization), e.g. {load_result.missing_keys[:5]}')
    if load_result.unexpected_keys:
        print(f'[WARN] {len(load_result.unexpected_keys)} checkpoint keys not used by the model, '
              f'e.g. {load_result.unexpected_keys[:5]}')
    print(f'[INFO] Loaded checkpoint from {opt.resume}')
else:
    print('[WARN] model randomly initialized, are you sure?')

# Run in fp16 on the GPU when available. NOTE(review): the fp16 model on the CPU
# fallback is likely not usable for inference — confirm before relying on it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.half().to(device)
model.eval()
|
|
|
# 4x4 perspective projection built from opt.fovy (degrees) and the opt.znear /
# opt.zfar clip planes. NOTE(review): not referenced elsewhere in this file; the
# negated focal terms and the [3, 2] / [2, 3] placement suggest a transposed
# (row-vector) layout — confirm against the renderer before reusing it.
tan_half_fov = np.tan(0.5 * np.deg2rad(opt.fovy))
proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
proj_matrix[0, 0] = proj_matrix[1, 1] = -1 / tan_half_fov
proj_matrix[2, 2] = -(opt.zfar + opt.znear) / (opt.zfar - opt.znear)
proj_matrix[3, 2] = -(opt.zfar * opt.znear) / (opt.zfar - opt.znear)
proj_matrix[2, 3] = 1
|
|
|
|
|
def _load_mvdream_pipeline(repo_id):
    """Load an MVDreamPipeline checkpoint in fp16 (allowing remote code) and move it to ``device``."""
    pipeline = MVDreamPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    return pipeline.to(device)


# Text-conditioned multi-view pipeline.
# NOTE(review): run() in this file only calls pipe_image; confirm pipe_text is
# needed before paying the download / GPU memory cost of loading it.
pipe_text = _load_mvdream_pipeline('ashawkey/mvdream-sd2.1-diffusers')

# Image-conditioned multi-view pipeline used by run().
pipe_image = _load_mvdream_pipeline("ashawkey/imagedream-ipmv-diffusers")

# Shared rembg session, reused across requests for background removal.
bg_remover = rembg.new_session()
|
|
|
|
|
def _preprocess_input(input_image):
    """Cut out the foreground of ``input_image``, recenter it, and composite it on white.

    Args:
        input_image: image received from the Gradio ``gr.Image`` component.

    Returns:
        A float32 array with 3 colour channels. Values lie in [0, 1] assuming
        rembg returns a uint8 RGBA array (TODO confirm).

    Raises:
        gr.Error: if background removal leaves no foreground pixels.
    """
    image = np.array(input_image)
    carved_image = rembg.remove(image, session=bg_remover)
    # The last channel of the carved image is alpha; any non-zero alpha is foreground.
    mask = carved_image[..., -1] > 0
    if not mask.any():
        # Fail with a readable message instead of handing recenter() an empty mask.
        raise gr.Error('No foreground object was found in the image. Please try another image.')
    image = recenter(carved_image, mask, border_ratio=0.2)
    image = image.astype(np.float32) / 255.0
    # Alpha-composite the RGB channels onto a white background.
    alpha = image[..., 3:4]
    return image[..., :3] * alpha + (1 - alpha)


def _prepare_lgm_input(mv_image):
    """Build the batched LGM input tensor from the four generated views.

    Args:
        mv_image: output of ``pipe_image``, indexed as four (H, W, C) views
            (assumed numpy arrays in [0, 1] — TODO confirm).

    Returns:
        A float tensor of shape (1, 4, C + R, opt.input_size, opt.input_size) on
        ``device``. R is the channel count of ``model.prepare_default_rays``
        (assumed to return (4, R, input_size, input_size) — confirm).
    """
    # Views are reordered to (1, 2, 3, 0), presumably to match the camera order
    # that prepare_default_rays() assumes — confirm before changing.
    views = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0)
    views = torch.from_numpy(views).permute(0, 3, 1, 2).float().to(device)
    views = F.interpolate(views, size=(opt.input_size, opt.input_size), mode='bilinear', align_corners=False)
    views = TF.normalize(views, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    rays_embeddings = model.prepare_default_rays(device, elevation=0)
    return torch.cat([views, rays_embeddings], dim=1).unsqueeze(0)


def run(input_image, seed=42, num_inference_steps=30, guidance_scale=5.0):
    """Generate 3D Gaussians from a single image and save them as a .ply file.

    Steps: background removal + recentering, multi-view generation with
    ``pipe_image``, then one LGM ``forward_gaussians`` pass.

    Args:
        input_image: image from the Gradio ``gr.Image`` component, or None when
            Generate is pressed without an uploaded image.
        seed: seed passed to ``kiui.seed_everything`` for reproducible output.
        num_inference_steps: number of inference steps for ``pipe_image``.
        guidance_scale: guidance scale passed to ``pipe_image``.

    Returns:
        Path of the written .ply file, ``os.path.join(TMP_DIR, 'output.ply')``.
        The same path is reused on every call, so concurrent calls would
        overwrite each other. NOTE(review): this relies on the Gradio queue
        running one job at a time — confirm.

    Raises:
        gr.Error: if no image was provided or no foreground could be found.
    """
    if input_image is None:
        raise gr.Error('Please upload an image first.')

    prompt_neg = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate"

    kiui.seed_everything(seed)
    output_ply_path = os.path.join(TMP_DIR, 'output.ply')

    image = _preprocess_input(input_image)
    mv_image = pipe_image(
        "",
        image,
        negative_prompt=prompt_neg,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        elevation=0,
    )

    lgm_input = _prepare_lgm_input(mv_image)

    with torch.no_grad():
        # NOTE(review): autocast is hard-wired to CUDA; on the CPU fallback it is
        # not applied to this fp16 model — confirm the CPU path is unsupported.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            gaussians = model.forward_gaussians(lgm_input)

        model.gs.save_ply(gaussians, output_ply_path)

    return output_ply_path
|
|
|
|
|
|
|
_TITLE = '''LGM Mini'''

_DESCRIPTION = '''
<div>
A lightweight version of <a href="https://huggingface.co/spaces/ashawkey/LGM">LGM: Large Multi-View Gaussian Model for High-Resolution 3D Content Creation</a>.
</div>
'''

# Styling for the "Duplicate Space" button (matched via elem_id below).
css = '''
#duplicate-button {
    margin: auto;
    color: white;
    background: #1565c0;
    border-radius: 100vh;
}
'''

block = gr.Blocks(title=_TITLE, css=css)
with block:
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")

    # Header: title and description.
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown('# ' + _TITLE)
            gr.Markdown(_DESCRIPTION)

    # Main panel: input image + button on the left, 3D result on the right.
    with gr.Row(variant='panel'):
        with gr.Column(scale=1):
            input_image = gr.Image(label="image", type='pil', height=300)
            button_gen = gr.Button("Generate")

        with gr.Column(scale=1):
            output_splat = gr.Model3D(label="3D Gaussians")

        button_gen.click(fn=run, inputs=[input_image], outputs=[output_splat])

    # Example images. Results are cached by running `run` on each example,
    # the same callable the Generate button uses (no wrapper lambda needed).
    gr.Examples(
        examples=[
            "data_test/frog_sweater.jpg",
            "data_test/bird.jpg",
            "data_test/boy.jpg",
            "data_test/cat_statue.jpg",
            "data_test/dragontoy.jpg",
            "data_test/gso_rabbit.jpg",
        ],
        inputs=[input_image],
        outputs=[output_splat],
        fn=run,
        cache_examples=True,
        label='Image-to-3D Examples'
    )

block.queue().launch(debug=True, share=True)
|
|