# StdGEN / app.py — last commit e08235d ("update") by YulianSa; 5.93 kB.
# (Header captured from the Hugging Face file viewer: raw / history / blame.)
import spaces
import gradio as gr
import numpy as np
import glob
import torch
import random
from tempfile import NamedTemporaryFile
from PIL import Image
import os
import subprocess
def install_cuda_toolkit():
    """Download and silently install the CUDA 12.2 toolkit, then export CUDA env vars.

    Steps:
      1. ``wget`` the NVIDIA runfile into ``/tmp``, mark it executable and run it
         with ``--silent --toolkit``.
      2. Point ``CUDA_HOME`` at ``/usr/local/cuda`` and prepend its ``bin`` /
         ``lib`` directories to ``PATH`` / ``LD_LIBRARY_PATH``.
      3. Pin ``TORCH_CUDA_ARCH_LIST`` (see comment below).

    Installation is best-effort: a failing command is reported on stdout and the
    remaining install commands are skipped, but the environment variables are
    still exported and no exception is raised.
    """
    # Alternative CUDA 11.8 runfile, kept for reference:
    # CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"
    CUDA_TOOLKIT_URL = "https://developer.download.nvidia.com/compute/cuda/12.2.0/local_installers/cuda_12.2.0_535.54.03_linux.run"
    CUDA_TOOLKIT_FILE = "/tmp/%s" % os.path.basename(CUDA_TOOLKIT_URL)

    install_steps = (
        ["wget", "-q", CUDA_TOOLKIT_URL, "-O", CUDA_TOOLKIT_FILE],
        ["chmod", "+x", CUDA_TOOLKIT_FILE],
        [CUDA_TOOLKIT_FILE, "--silent", "--toolkit"],
    )
    for cmd in install_steps:
        returncode = subprocess.call(cmd)
        if returncode != 0:
            # Each step depends on the previous one, so stop here rather than
            # emitting follow-on errors (e.g. chmod on a file that never downloaded).
            print(f"install_cuda_toolkit: {cmd[0]} exited with code {returncode}; "
                  "skipping remaining CUDA toolkit install steps")
            break

    def _prepend(var, entry):
        # Only keep a ':' separator when there is an existing value: an empty
        # component in PATH / LD_LIBRARY_PATH means "current directory".
        current = os.environ.get(var, "")
        os.environ[var] = f"{entry}:{current}" if current else entry

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    _prepend("PATH", f"{cuda_home}/bin")
    _prepend("LD_LIBRARY_PATH", f"{cuda_home}/lib")

    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    # Pinning the target architectures avoids auto-detecting them; the empty
    # list is presumably caused by no GPU being visible at build time — confirm.
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
# Install the toolkit before importing the inference code; presumably
# infer_api needs CUDA_HOME / nvcc at import time (e.g. to build extensions) — confirm.
install_cuda_toolkit()
from infer_api import InferAPI  # noqa: E402  (must run after the toolkit setup above)

# Per-stage configuration handed to InferAPI: canonicalization (A-pose),
# multi-view generation, S-LRM mesh reconstruction, and refinement.
config_canocalize = {
    'config_path': './configs/canonicalization-infer.yaml',
}
config_multiview = {}
config_slrm = {
    'config_path': './configs/mesh-slrm-infer.yaml'
}
config_refine = {}

# glob order depends on the filesystem; sort so the example galleries are stable.
EXAMPLE_IMAGES = sorted(glob.glob("./input_cases/*"))
EXAMPLE_APOSE_IMAGES = sorted(glob.glob("./input_cases_apose/*"))

# Single shared pipeline instance used by every Gradio callback below.
infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine)

# Markdown shown at the top of the UI. Sub-bullets are indented so they nest
# under item 1 when rendered.
REMINDER = """
### Reminder:
1. **Reference Image**:
   - You can upload any reference image (with or without background).
   - If the image has an alpha channel (transparency), background segmentation will be automatically performed.
   - Alternatively, you can pre-segment the background using other tools and upload the result directly.
   - A-pose images are also supported.
2. Real person images generally work well, but note that normals may appear smoother than expected. You can try to use other monocular normal estimation models.
3. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions.
"""
# Gradio callbacks: each one wraps a stage of the InferAPI pipeline.
def arbitrary_to_apose(image, seed):
    """Stage 1: convert an arbitrary reference image into an A-pose image.

    Args:
        image: numpy array from the ``input_image`` component (``type="numpy"``),
            or None when nothing has been uploaded.
        seed: seed value from the ``seed_input`` number box.

    Returns:
        The result of ``infer_api.genStage1``, shown in the ``a_pose_image`` component.

    Raises:
        gr.Error: if no image was provided. This shows a readable message in the
            UI instead of a crash inside ``Image.fromarray(None)``.
    """
    if image is None:
        raise gr.Error("Please upload a reference image first.")
    return infer_api.genStage1(Image.fromarray(image), seed)
def apose_to_multiview(apose_img, seed):
    """Stage 2: generate multi-view images from an A-pose image.

    Args:
        apose_img: numpy array from the ``a_pose_image`` component, or None
            when it is empty.
        seed: seed value from the ``seed_input2`` number box.

    Returns:
        The ``"images"`` entry of the first element returned by
        ``infer_api.genStage2(..., num_levels=1)``, displayed in the gallery.

    Raises:
        gr.Error: if no A-pose image is available.
    """
    if apose_img is None:
        raise gr.Error("Please provide an A-pose image first (run step 1 or pick an example).")
    results = infer_api.genStage2(Image.fromarray(apose_img), seed, num_levels=1)
    return results[0]["images"]
def multiview_to_mesh(images):
    """Stage 3: reconstruct semantically decomposed meshes from multi-view images.

    Args:
        images: value of the ``multiview_gallery`` component; None or empty
            when step 2 has not been run.

    Returns:
        The result of ``infer_api.genStage3``. The UI wiring maps it onto three
        per-part mesh viewers plus the whole-mesh viewer.

    Raises:
        gr.Error: if the gallery is empty.
    """
    if not images:
        raise gr.Error("Please generate multi-view images first.")
    return infer_api.genStage3(images)
def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
    """Stage 4: refine the three reconstructed meshes.

    Re-runs stage 2 with ``num_levels=2`` and discards its return value. It then
    passes ``infer_api.multiview_infer.results`` to ``genStage4``. Presumably
    ``genStage2`` stores its outputs in that attribute — confirm in infer_api.

    Args:
        apose_img: numpy array from the ``a_pose_image`` component.
        mesh1, mesh2, mesh3: values of the three stage-3 mesh viewers.
        seed: seed value from the ``seed_input2`` number box.

    Returns:
        The result of ``infer_api.genStage4``.

    Raises:
        gr.Error: if the A-pose image or any stage-3 mesh is missing.
    """
    if apose_img is None:
        raise gr.Error("Please provide an A-pose image first.")
    meshes = [mesh1, mesh2, mesh3]
    if any(mesh is None for mesh in meshes):
        raise gr.Error("Please run reconstruction (step 3) before refining.")
    infer_api.genStage2(Image.fromarray(apose_img), seed, num_levels=2)
    # Debug output: which result keys stage 2 produced.
    print(infer_api.multiview_infer.results.keys())
    return infer_api.genStage4(meshes, infer_api.multiview_infer.results)
# Build the four-step Gradio UI and wire each button to its pipeline stage.
with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
    gr.Markdown(REMINDER)
    with gr.Row():
        # Step 1: reference image -> A-pose image.
        with gr.Column():
            gr.Markdown("## 1. Reference Image to A-pose Image")
            input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_IMAGES,
                inputs=input_image,
                label="Click to use sample images",
            )
            seed_input = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                interactive=True,
            )
            pose_btn = gr.Button("Convert")
        # Step 2: A-pose image -> multi-view images.
        with gr.Column():
            gr.Markdown("## 2. Multi-view Generation")
            a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384)
            gr.Examples(
                examples=EXAMPLE_APOSE_IMAGES,
                inputs=a_pose_image,
                label="Click to use sample A-pose images",
            )
            seed_input2 = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                interactive=True,
            )
            view_btn = gr.Button("Generate Multi-view Images")
        # Step 3: multi-view images -> semantically decomposed meshes.
        with gr.Column():
            gr.Markdown("## 3. Semantic-aware Reconstruction")
            multiview_gallery = gr.Gallery(
                label="Multi-view results",
                columns=2,
                interactive=False,
                # Was the string "None", which Gradio treats as a CSS height
                # value (invalid). None means "no fixed height".
                height=None,
            )
            mesh_btn = gr.Button("Reconstruct")
    with gr.Row():
        mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)]
        full_mesh = gr.Model3D(label="Whole Mesh", height=384)
    refine_btn = gr.Button("Refine")
    # Step 4: refinement of the reconstructed meshes.
    gr.Markdown("## 4. Mesh refinement")
    with gr.Row():
        refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)]
        refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384)

    # Event wiring
    pose_btn.click(
        arbitrary_to_apose,
        inputs=[input_image, seed_input],
        outputs=a_pose_image,
    )
    view_btn.click(
        apose_to_multiview,
        inputs=[a_pose_image, seed_input2],
        outputs=multiview_gallery,
    )
    mesh_btn.click(
        multiview_to_mesh,
        inputs=multiview_gallery,
        outputs=[*mesh_cols, full_mesh],
    )
    # NOTE(review): refined outputs are routed as [mesh 3, mesh 1, mesh 2, whole],
    # presumably to match genStage4's return order — confirm before changing.
    refine_btn.click(
        refine_mesh,
        inputs=[a_pose_image, *mesh_cols, seed_input2],
        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh],
    )
def main():
    """Serve the Gradio demo with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    main()