import spaces
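# On Hugging Face ZeroGPU Spaces, `spaces` must be imported before any CUDA
# libraries; this Space runs on dedicated hardware (an L40S), so it goes
# otherwise unused below.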
import gradio as gr
import glob
import hashlib
from PIL import Image
import os
import shlex
import subprocess
os.makedirs("./ckpt", exist_ok=True)
# download ViT-H SAM model into ./ckpt
subprocess.call(["wget", "-q", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "-O", "./ckpt/sam_vit_h_4b8939.pth"])
subprocess.run(shlex.split("pip install pip==24.0"))
subprocess.run(shlex.split("pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"))
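# NOTE: pip is pinned to 24.0 before installing the nvdiffrast wheel bundled
# with the Space, presumably because newer pip versions reject the prebuilt
# wheel's metadata.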
from infer_api import InferAPI
config_canonicalize = {
    'config_path': './configs/canonicalization-infer.yaml',
}
config_multiview = {}
config_slrm = {
    'config_path': './configs/mesh-slrm-infer.yaml'
}
config_refine = {}
EXAMPLE_IMAGES = glob.glob("./input_cases/*")
EXAMPLE_APOSE_IMAGES = glob.glob("./input_cases_apose/*")
infer_api = InferAPI(config_canonicalize, config_multiview, config_slrm, config_refine)
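# InferAPI (defined in infer_api.py) wraps the full pipeline; judging by the
# config names, the four stages are canonicalization to A-pose, multi-view
# generation, SLRM-based mesh reconstruction, and mesh refinement.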
_HEADER_ = '''
<h2><b>[CVPR 2025] StdGEN 🤗 Gradio Demo</b></h2>
This is the official demo for our CVPR 2025 paper <a href="https://arxiv.org/abs/2411.05738">StdGEN: Semantic-Decomposed 3D Character Generation from Single Images</a>.
Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>.
❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with **coarse-precision geometry and texture** due to limited online resources. We skip part of the refinement process and perform only color back-projection onto clothes and hair. Please refer to the GitHub repo for the complete version.
1. The refinement stage takes about 2.5 minutes, and the mesh result may be delayed by server load; please wait patiently.
2. You can upload any reference image (with or without background); A-pose images are also supported (a white background is required). If the image has an alpha channel (transparency), background segmentation is performed automatically. Alternatively, you can pre-segment the background with other tools and upload the result directly.
3. Real-person images generally work well, but note that normals may appear smoother than expected. You can try other monocular normal estimation models.
4. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions.
'''
_CITE_ = r"""
If StdGEN is helpful, please help to ⭐ the <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/hyz317/StdGEN)
---
📝 **Citation**
If you find our work useful for your research or applications, please cite using this bibtex:
```bibtex
@article{he2024stdgen,
  title={StdGEN: Semantic-Decomposed 3D Character Generation from Single Images},
  author={He, Yuze and Zhou, Yanning and Zhao, Wang and Wu, Zhongkai and Xiao, Kaiwen and Yang, Wei and Liu, Yong-Jin and Han, Xiao},
  journal={arXiv preprint arXiv:2411.05738},
  year={2024}
}
```
📧 **Contact**
If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>.
"""
cache_arbitrary = {}
cache_multiview = [ {}, {}, {} ]
cache_slrm = {}
cache_refine = {}
tmp_path = '/tmp'
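# Each stage caches its results on disk under /tmp, keyed by the MD5 hash of the
# input image bytes plus the seed, so identical requests skip re-running inference.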
# Example placeholder functions - replace with the actual models
def arbitrary_to_apose(image, seed):
    # convert image to PIL.Image
    image = Image.fromarray(image)
    image_hash = str(hashlib.md5(image.tobytes()).hexdigest()) + '_' + str(seed)
    if image_hash not in cache_arbitrary:
        apose_img = infer_api.genStage1(image, seed)
        apose_img.save(f'{tmp_path}/{image_hash}.png')
        cache_arbitrary[image_hash] = f'{tmp_path}/{image_hash}.png'
        print(f'cached apose image: {image_hash}')
        return apose_img
    else:
        apose_img = Image.open(cache_arbitrary[image_hash])
        print(f'loaded cached apose image: {image_hash}')
        return apose_img
def apose_to_multiview(apose_img, seed):
    # convert image to PIL.Image
    apose_img = Image.fromarray(apose_img)
    image_hash = str(hashlib.md5(apose_img.tobytes()).hexdigest()) + '_' + str(seed)
    if image_hash not in cache_multiview[0]:
        results = infer_api.genStage2(apose_img, seed, num_levels=1)
        for idx, img in enumerate(results[0]["images"]):
            img.save(f'{tmp_path}/{image_hash}_images_{idx}.png')
        for idx, img in enumerate(results[0]["normals"]):
            img.save(f'{tmp_path}/{image_hash}_normals_{idx}.png')
        cache_multiview[0][image_hash] = {
            "images": [f'{tmp_path}/{image_hash}_images_{idx}.png' for idx in range(len(results[0]["images"]))],
            "normals": [f'{tmp_path}/{image_hash}_normals_{idx}.png' for idx in range(len(results[0]["normals"]))]
        }
        print(f'cached multiview images: {image_hash}')
        return results[0]["images"], image_hash
    else:
        print(f'loaded cached multiview images: {image_hash}')
        return [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]], image_hash
def multiview_to_mesh(images, image_hash):
    if image_hash not in cache_slrm:
        mesh_files = infer_api.genStage3(images)
        cache_slrm[image_hash] = mesh_files
        print(f'cached slrm files: {image_hash}')
    else:
        mesh_files = cache_slrm[image_hash]
        print(f'loaded cached slrm files: {image_hash}')
    # mesh_files holds four paths (three semantic part meshes plus the whole
    # mesh), matching the four Model3D outputs; image_hash rides along in state
    return *mesh_files, image_hash
def refine_mesh(mesh1, mesh2, mesh3, seed, image_hash):
    apose_img = Image.open(cache_multiview[0][image_hash]["images"][0])
    if image_hash not in cache_refine:
        results = infer_api.genStage2(apose_img, seed, num_levels=2)
        # reuse the cached level-0 multi-view images/normals instead of the
        # freshly generated ones
        results[0] = {}
        results[0]["images"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]]
        results[0]["normals"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["normals"]]
        refined = infer_api.genStage4([mesh1, mesh2, mesh3], results)
        cache_refine[image_hash] = refined
        print(f'cached refined mesh: {image_hash}')
    else:
        refined = cache_refine[image_hash]
        print(f'loaded cached refined mesh: {image_hash}')
    return refined
with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
    gr.Markdown(_HEADER_)

    with gr.Row():
        with gr.Column():
            gr.Markdown("## 1. Reference Image to A-pose Image")
            input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=input_image,
                label="Click to use sample images",
            )
            seed_input = gr.Number(
                label="Seed",
                value=52,
                precision=0,
                interactive=True
            )
            pose_btn = gr.Button("Convert")

        with gr.Column():
            gr.Markdown("## 2. Multi-view Generation")
            a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_APOSE_IMAGES,
                inputs=a_pose_image,
                label="Click to use sample A-pose images",
            )
            seed_input2 = gr.Number(
                label="Seed",
                value=50,
                precision=0,
                interactive=True
            )
            state2 = gr.State(value="")
            view_btn = gr.Button("Generate Multi-view Images")

        with gr.Column():
            gr.Markdown("## 3. Semantic-aware Reconstruction")
            multiview_gallery = gr.Gallery(
                label="Multi-view results",
                columns=2,
                interactive=False,
                height=None
            )
            state3 = gr.State(value="")
            mesh_btn = gr.Button("Reconstruct")

    with gr.Row():
        mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
        full_mesh = gr.Model3D(label="Whole Mesh", height=384)

    gr.Markdown("## 4. Mesh refinement")
    refine_btn = gr.Button("Refine")

    with gr.Row():
        refined_meshes = [gr.Model3D(label=f"Refined Mesh {i+1}", height=384) for i in range(3)]
        refined_full_mesh = gr.Model3D(label="Refined Whole Mesh", height=384)

    gr.Markdown(_CITE_)
    # Interaction logic
    pose_btn.click(
        arbitrary_to_apose,
        inputs=[input_image, seed_input],
        outputs=a_pose_image
    )
    view_btn.click(
        apose_to_multiview,
        inputs=[a_pose_image, seed_input2],
        outputs=[multiview_gallery, state2]
    )
    mesh_btn.click(
        multiview_to_mesh,
        inputs=[multiview_gallery, state2],
        outputs=[*mesh_cols, full_mesh, state3]
    )
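    # refine_mesh should return four meshes; the output index order [2], [0], [1]
    # presumably maps genStage4's return order onto the on-screen layout.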
    refine_btn.click(
        refine_mesh,
        inputs=[*mesh_cols, seed_input2, state3],
        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh]
    )
if __name__ == "__main__":
    demo.launch(ssr_mode=False)