Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,346 Bytes
062b5a5 37aeb5b 5a3e910 37aeb5b 69ac8ac 531ccc1 69ac8ac 37aeb5b 531ccc1 37aeb5b 69ac8ac 37aeb5b 69ac8ac 37aeb5b 5807069 37aeb5b 69ac8ac 37aeb5b 5807069 37aeb5b c07e086 37aeb5b 69ac8ac 37aeb5b e4f6021 5807069 37aeb5b 5807069 37aeb5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import spaces
import os
import gradio as gr
from PIL import Image
from pytorch3d.structures import Meshes
from gradio_app.utils import clean_up
from gradio_app.custom_models.mvimg_prediction import run_mvprediction
from gradio_app.custom_models.normal_prediction import predict_normals
from scripts.refine_lr_to_sr import run_sr_fast
from scripts.utils import save_glb_and_video
# from scripts.multiview_inference import geo_reconstruct
from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3
@spaces.GPU(duration=100)
def run_mv(preview_img, input_processing, seed):
    """Predict multi-view RGB images from a single front-view image.

    Images whose width is 512 px or less are first upscaled with
    ``run_sr_fast``; the (possibly upscaled) image is then handed to
    ``run_mvprediction``.

    Args:
        preview_img: front-view PIL image (``.size[0]`` is read as its width).
        input_processing: forwarded to ``run_mvprediction`` as ``remove_bg``.
        seed: random seed; coerced with ``int()`` before use.

    Returns:
        The two values unpacked from ``run_mvprediction``:
        ``(rgb_pils, front_pil)``.
    """
    # run_sr_fast takes a list of images and returns a list; wrap and unwrap.
    source = run_sr_fast([preview_img])[0] if preview_img.size[0] <= 512 else preview_img
    # Multi-view prediction step (timed at roughly 6 s in the original note).
    views, front = run_mvprediction(source, remove_bg=input_processing, seed=int(seed))
    return views, front
def _load_image_file(path):
    """Read the image at *path* fully into memory and close the file.

    ``Image.open`` only parses the header and keeps the file handle open until
    the pixel data is loaded. Copying inside a ``with`` block loads the pixels
    and releases the handle immediately, so repeated API calls that pass file
    paths do not leak descriptors.
    """
    with Image.open(path) as img:
        return img.copy()


def _to_export_frame(meshes):
    """Rescale and mirror a reconstructed mesh before export.

    Vertices are halved and then multiplied by 1.35. Their x and z coordinates
    are then negated, which is equivalent to a 180-degree rotation about the y
    axis. Faces and textures are passed through unchanged.

    Assumes a single-mesh batch: ``verts_packed()`` concatenates all vertices,
    which are re-wrapped as one mesh next to the full ``faces_list()``.
    """
    verts = meshes.verts_packed() / 2 * 1.35
    # The arithmetic above already produced a fresh tensor, so this in-place
    # sign flip does not modify the input mesh's vertex storage.
    verts[..., [0, 2]] = -verts[..., [0, 2]]
    return Meshes(verts=[verts], faces=meshes.faces_list(), textures=meshes.textures)


# Splitting this pipeline across several GPU-decorated calls appears to raise a
# `RuntimeError`; until that is fixed, every stage runs inside this single GPU
# allocation.
@spaces.GPU(duration=100)
def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
    """Generate a 3D mesh (and optionally a preview video) from one image.

    Pipeline: multi-view prediction (``run_mv``), three-stage geometry
    reconstruction (``geo_reconstruct_part1..3``), rescale/mirror for export,
    then GLB/video export under ``/tmp/gradio/generated``.

    Args:
        preview_img: front-view PIL image, or a path to an image file.
        input_processing: whether to remove the background (forwarded to the
            multi-view step as ``remove_bg``).
        seed: random seed for multi-view prediction.
        render_video: forwarded to ``save_glb_and_video`` as ``export_video``.
        do_refine: forwarded to ``geo_reconstruct_part1``.
        expansion_weight: forwarded to ``geo_reconstruct_part1``.
        init_type: forwarded to ``geo_reconstruct_part1``.

    Returns:
        ``(ret_mesh, video, update)``. ``ret_mesh`` and ``video`` are the values
        returned by ``save_glb_and_video``. ``update`` is a ``gr.update`` that
        sets a component's value to ``ret_mesh`` and makes it visible.

    Raises:
        gr.Error: if ``preview_img`` is None.
    """
    if preview_img is None:
        raise gr.Error("The input image is none!")
    if isinstance(preview_img, str):
        preview_img = _load_image_file(preview_img)

    rgb_pils, front_pil = run_mv(preview_img, input_processing, seed)
    vertices, faces, img_list = geo_reconstruct_part1(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
    meshes = geo_reconstruct_part2(vertices, faces)
    new_meshes = _to_export_frame(geo_reconstruct_part3(meshes, img_list))

    ret_mesh, video = save_glb_and_video("/tmp/gradio/generated", new_meshes, with_timestamp=True, dist=3.5, fov_in_degrees=2 / 1.35, cam_type="ortho", export_video=render_video)
    return ret_mesh, video, gr.update(value=ret_mesh, visible=True)
#######################################
def create_ui(concurrency_id="wkl"):
    """Lay out the image-to-3D interface and wire the "Generate 3D" button.

    Creates layout components and an event listener directly, so it presumably
    must be called inside an enclosing ``gr.Blocks`` context (not visible here;
    confirm with the caller).

    Args:
        concurrency_id: gradio concurrency group for the generate event.

    Returns:
        The input ``gr.Image`` component.
    """
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
            input_processing = gr.Checkbox(
                value=True,
                label='Remove Background',
                visible=True,
            )
            # Advanced options stay hidden but are still listed as inputs of
            # the click handler below, so their defaults reach generate3dv2.
            do_refine = gr.Checkbox(value=True, label="Refine Multiview Details", visible=False)
            expansion_weight = gr.Slider(minimum=-1., maximum=1.0, value=0.1, step=0.1, label="Expansion Weight", visible=False)
            init_type = gr.Dropdown(choices=["std", "thin"], label="Mesh Initialization", value="std", visible=False)
            setable_seed = gr.Slider(-1, 1000000000, -1, step=1, visible=True, label="Seed")
            render_video = gr.Checkbox(value=False, visible=False, label="generate video")
            fullrunv2_btn = gr.Button('Generate 3D', variant="primary", interactive=True)

            example_folder = os.path.join(os.path.dirname(__file__), "./examples")
            # Every path listed here becomes an example for the image input, so
            # skip hidden entries (e.g. .DS_Store, .gitkeep) and sub-directories
            # that would otherwise break example loading.
            example_fns = sorted(
                os.path.join(example_folder, name)
                for name in os.listdir(example_folder)
                if not name.startswith(".")
                and os.path.isfile(os.path.join(example_folder, name))
            )
            gr.Examples(
                examples=example_fns,
                inputs=[input_image],
                cache_examples=False,
                label='Examples',
                examples_per_page=12
            )
        with gr.Column(scale=1):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320, camera_position=(90, 90, 2))
            # Hidden initially; the handler's third output is a gr.update that
            # sets its value and makes it visible.
            download_button = gr.DownloadButton(label="Download 3D", visible=False)
            output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)

    # Output order mirrors generate3dv2's return tuple. clean_up is chained
    # with .success, so it runs only after a successful generation.
    fullrunv2_btn.click(
        fn=generate3dv2,
        inputs=[input_image, input_processing, setable_seed, render_video, do_refine, expansion_weight, init_type],
        outputs=[output_mesh, output_video, download_button],
        concurrency_id=concurrency_id,
        api_name="generate3dv2",
    ).success(clean_up, api_name=False)
    return input_image
|