Spaces:
Running
on
L40S
Running
on
L40S
import spaces | |
import gradio as gr | |
import glob | |
import hashlib | |
from PIL import Image | |
import os | |
import shlex | |
import subprocess | |
os.makedirs("./ckpt", exist_ok=True) | |
# download ViT-H SAM model into ./ckpt | |
subprocess.call(["wget", "-q", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "-O", "./ckpt/sam_vit_h_4b8939.pth"]) | |
subprocess.run( | |
shlex.split( | |
"pip install pip==24.0" | |
) | |
) | |
subprocess.run( | |
shlex.split( | |
"pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps" | |
) | |
) | |
from infer_api import InferAPI | |
config_canocalize = { | |
'config_path': './configs/canonicalization-infer.yaml', | |
} | |
config_multiview = {} | |
config_slrm = { | |
'config_path': './configs/mesh-slrm-infer.yaml' | |
} | |
config_refine = {} | |
EXAMPLE_IMAGES = glob.glob("./input_cases/*") | |
EXAMPLE_APOSE_IMAGES = glob.glob("./input_cases_apose/*") | |
infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine) | |
_HEADER_ = ''' | |
<h2><b>[CVPR 2025] StdGEN 🤗 Gradio Demo</b></h2> | |
This is official demo for our CVPR 2025 paper <a href="">StdGEN: Semantic-Decomposed 3D Character Generation from Single Images</a>. | |
Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>. | |
❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with **coarse precision geometry and texture** due to limited online resource. We skip some refinement process and perform only color back-projection to clothes and hair. Please refer to GitHub repo for complete version. | |
1. Refinement stage takes about ~2.5min, and the mesh result may possibly delayed due to the server load, please wait patiently. | |
2. You can upload any reference image (with or without background), A-pose images are also supported (white bkg required). If the image has an alpha channel (transparency), background segmentation will be automatically performed. Alternatively, you can pre-segment the background using other tools and upload the result directly. | |
3. Real person images generally work well, but note that normals may appear smoother than expected. You can try to use other monocular normal estimation models. | |
4. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions. | |
''' | |
_CITE_ = r""" | |
If StdGEN is helpful, please help to ⭐ the <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/hyz317/StdGEN) | |
--- | |
📝 **Citation** | |
If you find our work useful for your research or applications, please cite using this bibtex: | |
```bibtex | |
@article{he2024stdgen, | |
title={StdGEN: Semantic-Decomposed 3D Character Generation from Single Images}, | |
author={He, Yuze and Zhou, Yanning and Zhao, Wang and Wu, Zhongkai and Xiao, Kaiwen and Yang, Wei and Liu, Yong-Jin and Han, Xiao}, | |
journal={arXiv preprint arXiv:2411.05738}, | |
year={2024} | |
} | |
``` | |
📧 **Contact** | |
If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>. | |
""" | |
cache_arbitrary = {} | |
cache_multiview = [ {}, {}, {} ] | |
cache_slrm = {} | |
cache_refine = {} | |
tmp_path = '/tmp' | |
# 示例占位函数 - 需替换实际模型 | |
def arbitrary_to_apose(image, seed): | |
# convert image to PIL.Image | |
image = Image.fromarray(image) | |
image_hash = str(hashlib.md5(image.tobytes()).hexdigest()) + '_' + str(seed) | |
if image_hash not in cache_arbitrary: | |
apose_img = infer_api.genStage1(image, seed) | |
apose_img.save(f'{tmp_path}/{image_hash}.png') | |
cache_arbitrary[image_hash] = f'{tmp_path}/{image_hash}.png' | |
print(f'cached apose image: {image_hash}') | |
return apose_img | |
else: | |
apose_img = Image.open(cache_arbitrary[image_hash]) | |
print(f'loaded cached apose image: {image_hash}') | |
return apose_img | |
def apose_to_multiview(apose_img, seed): | |
# convert image to PIL.Image | |
apose_img = Image.fromarray(apose_img) | |
image_hash = str(hashlib.md5(apose_img.tobytes()).hexdigest()) + '_' + str(seed) | |
if image_hash not in cache_multiview[0]: | |
results = infer_api.genStage2(apose_img, seed, num_levels=1) | |
for idx, img in enumerate(results[0]["images"]): | |
img.save(f'{tmp_path}/{image_hash}_images_{idx}.png') | |
for idx, img in enumerate(results[0]["normals"]): | |
img.save(f'{tmp_path}/{image_hash}_normals_{idx}.png') | |
cache_multiview[0][image_hash] = { | |
"images": [f'{tmp_path}/{image_hash}_images_{idx}.png' for idx in range(len(results[0]["images"]))], | |
"normals": [f'{tmp_path}/{image_hash}_normals_{idx}.png' for idx in range(len(results[0]["normals"]))] | |
} | |
print(f'cached multiview images: {image_hash}') | |
return results[0]["images"], image_hash | |
else: | |
print(f'loaded cached multiview images: {image_hash}') | |
return [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]], image_hash | |
def multiview_to_mesh(images, image_hash): | |
if image_hash not in cache_slrm: | |
mesh_files = infer_api.genStage3(images) | |
cache_slrm[image_hash] = mesh_files | |
print(f'cached slrm files: {image_hash}') | |
else: | |
mesh_files = cache_slrm[image_hash] | |
print(f'loaded cached slrm files: {image_hash}') | |
return *mesh_files, image_hash | |
def refine_mesh(mesh1, mesh2, mesh3, seed, image_hash): | |
apose_img = Image.open(cache_multiview[0][image_hash]["images"][0]) | |
if image_hash not in cache_refine: | |
results = infer_api.genStage2(apose_img, seed, num_levels=2) | |
results[0] = {} | |
results[0]["images"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]] | |
results[0]["normals"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["normals"]] | |
refined = infer_api.genStage4([mesh1, mesh2, mesh3], results) | |
cache_refine[image_hash] = refined | |
print(f'cached refined mesh: {image_hash}') | |
else: | |
refined = cache_refine[image_hash] | |
print(f'loaded cached refined mesh: {image_hash}') | |
return refined | |
with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo: | |
gr.Markdown(_HEADER_) | |
with gr.Row(): | |
with gr.Column(): | |
gr.Markdown("## 1. Reference Image to A-pose Image") | |
input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384) | |
gr.Examples( | |
examples=EXAMPLE_IMAGES, | |
inputs=input_image, | |
label="Click to use sample images", | |
) | |
seed_input = gr.Number( | |
label="Seed", | |
value=52, | |
precision=0, | |
interactive=True | |
) | |
pose_btn = gr.Button("Convert") | |
with gr.Column(): | |
gr.Markdown("## 2. Multi-view Generation") | |
a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384) | |
gr.Examples( | |
examples=EXAMPLE_APOSE_IMAGES, | |
inputs=a_pose_image, | |
label="Click to use sample A-pose images", | |
) | |
seed_input2 = gr.Number( | |
label="Seed", | |
value=50, | |
precision=0, | |
interactive=True | |
) | |
state2 = gr.State(value="") | |
view_btn = gr.Button("Generate Multi-view Images") | |
with gr.Column(): | |
gr.Markdown("## 3. Semantic-aware Reconstruction") | |
multiview_gallery = gr.Gallery( | |
label="Multi-view results", | |
columns=2, | |
interactive=False, | |
height="None" | |
) | |
state3 = gr.State(value="") | |
mesh_btn = gr.Button("Reconstruct") | |
with gr.Row(): | |
mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)] | |
full_mesh = gr.Model3D(label="Whole Mesh", height=384) | |
refine_btn = gr.Button("Refine") | |
gr.Markdown("## 4. Mesh refinement") | |
with gr.Row(): | |
refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)] | |
refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384) | |
gr.Markdown(_CITE_) | |
# 交互逻辑 | |
pose_btn.click( | |
arbitrary_to_apose, | |
inputs=[input_image, seed_input], | |
outputs=a_pose_image | |
) | |
view_btn.click( | |
apose_to_multiview, | |
inputs=[a_pose_image, seed_input2], | |
outputs=[multiview_gallery, state2] | |
) | |
mesh_btn.click( | |
multiview_to_mesh, | |
inputs=[multiview_gallery, state2], | |
outputs=[*mesh_cols, full_mesh, state3] | |
) | |
refine_btn.click( | |
refine_mesh, | |
inputs=[*mesh_cols, seed_input2, state3], | |
outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh] | |
) | |
if __name__ == "__main__": | |
demo.launch(ssr_mode=False) |