Wuvin committed
Commit 69ac8ac
1 Parent(s): 9d4fa56

Better ZeroGPU utilization

app.py CHANGED
@@ -41,9 +41,9 @@ _DESCRIPTION = '''
 
  * The demo is still under construction, and more features are expected to be implemented soon.
 
- * The demo takes around 50 seconds on L4.
+ * The demo takes around 50 seconds on L4, and about 60 seconds on Huggingface ZeroGPU.
 
- * If the Gradio Demo unfortunately hangs or is very crowded, you can use the Gradio Demo or Online Demo. The Online Demo is free to try, and the registration invitation code is `aiuni24`. However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, and the generation results is less stable, but the quality of the texture is better.
+ * If the Huggingface Demo unfortunately hangs or is very crowded, you can use the Gradio Demo or Online Demo. The Online Demo is free to try, and the registration invitation code is `aiuni24`. However, the Online Demo is slightly different from the Gradio Demo, in that the inference speed is slower, and the generation results is less stable, but the quality of the texture is better.
 
 
 '''
@@ -53,7 +53,7 @@ def launch():
 
     with gr.Blocks(
         title=_TITLE,
-        theme=gr.themes.Monochrome(),
+        # theme=gr.themes.Monochrome(),
     ) as demo:
         with gr.Row():
             with gr.Column(scale=1):
gradio_app/gradio_3dgen.py CHANGED
@@ -8,19 +8,30 @@ from gradio_app.custom_models.mvimg_prediction import run_mvprediction
 from gradio_app.custom_models.normal_prediction import predict_normals
 from scripts.refine_lr_to_sr import run_sr_fast
 from scripts.utils import save_glb_and_video
-from scripts.multiview_inference import geo_reconstruct
+# from scripts.multiview_inference import geo_reconstruct
+from scripts.multiview_inference import geo_reconstruct_part1, geo_reconstruct_part2, geo_reconstruct_part3
+
+@spaces.GPU
+def run_mv(preview_img, input_processing, seed):
+    if preview_img.size[0] <= 512:
+        preview_img = run_sr_fast([preview_img])[0]
+    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
+    return rgb_pils, front_pil
 
-@spaces.GPU(duration=100)
 def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refine=True, expansion_weight=0.1, init_type="std"):
     if preview_img is None:
-        raise gr.Error("preview_img is none")
+        raise gr.Error("The input image is none!")
     if isinstance(preview_img, str):
         preview_img = Image.open(preview_img)
 
-    if preview_img.size[0] <= 512:
-        preview_img = run_sr_fast([preview_img])[0]
-    rgb_pils, front_pil = run_mvprediction(preview_img, remove_bg=input_processing, seed=int(seed)) # 6s
-    new_meshes = geo_reconstruct(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
+    rgb_pils, front_pil = run_mv(preview_img, input_processing, seed)
+
+    vertices, faces, img_list = geo_reconstruct_part1(rgb_pils, None, front_pil, do_refine=do_refine, predict_normal=True, expansion_weight=expansion_weight, init_type=init_type)
+
+    meshes = geo_reconstruct_part2(vertices, faces)
+
+    new_meshes = geo_reconstruct_part3(meshes, img_list)
+
     vertices = new_meshes.verts_packed()
     vertices = vertices / 2 * 1.35
     vertices[..., [0, 2]] = - vertices[..., [0, 2]]
@@ -32,7 +43,7 @@ def generate3dv2(preview_img, input_processing, seed, render_video=True, do_refi
 #######################################
 def create_ui(concurrency_id="wkl"):
     with gr.Row():
-        with gr.Column(scale=2):
+        with gr.Column(scale=1):
            input_image = gr.Image(type='pil', image_mode='RGBA', label='Frontview')
 
            example_folder = os.path.join(os.path.dirname(__file__), "./examples")
@@ -46,7 +57,7 @@ def create_ui(concurrency_id="wkl"):
            )
 
 
-        with gr.Column(scale=3):
+        with gr.Column(scale=1):
            # export mesh display
            output_mesh = gr.Model3D(value=None, label="Mesh Model", show_label=True, height=320)
            output_video = gr.Video(label="Preview", show_label=True, show_share_button=True, height=320, visible=False)
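
The refactor above follows the usual Hugging Face ZeroGPU pattern: only the CUDA-bound steps are wrapped in `@spaces.GPU`, so a GPU is attached just for the duration of each decorated call and released in between. Below is a minimal sketch of that pattern, assuming the documented behaviour of the `spaces` package; the stage functions and tensors are illustrative, not part of this repository.

import spaces
import torch

@spaces.GPU  # GPU is attached only while this call runs (no effect outside ZeroGPU Spaces)
def gpu_heavy_step(x: torch.Tensor) -> torch.Tensor:
    # stands in for a GPU stage such as multiview or normal prediction
    return (x.to("cuda") * 2).cpu()

def cpu_only_step(x: torch.Tensor) -> torch.Tensor:
    # undecorated: runs on the host and does not hold a GPU slot
    return x + 1

@spaces.GPU(duration=120)  # a longer window can be requested for slower stages
def gpu_final_step(x: torch.Tensor) -> torch.Tensor:
    return (x.to("cuda") ** 2).cpu()

def pipeline(x: torch.Tensor) -> torch.Tensor:
    x = gpu_heavy_step(x)     # GPU held here
    x = cpu_only_step(x)      # GPU released here
    return gpu_final_step(x)  # GPU held again

Splitting the pipeline this way is what the commit message means by "better ZeroGPU utilization": the CPU-only middle stage no longer counts against the GPU time budget.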
scripts/multiview_inference.py CHANGED
@@ -95,6 +95,62 @@ def geo_reconstruct(rgb_pils, normal_pils, front_pil, do_refine=False, predict_n
     normal_stg2 = [img.resize((1024, 1024)) for img in rm_normals] # reduce computation on huggingface demo, use 1024 instead of 2048
 
     vertices, faces = run_mesh_refine(vertices, faces, normal_stg2, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False)
+
     meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25).to("cuda")
+
+    new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1))
+    return new_meshes
+
+########################
+import spaces
+
+@spaces.GPU
+def geo_reconstruct_part1(rgb_pils, normal_pils, front_pil, do_refine=False, predict_normal=True, expansion_weight=0.1, init_type="std"):
+    if front_pil.size[0] <= 512:
+        front_pil = run_sr_fast([front_pil])[0]
+    if do_refine:
+        refined_rgbs = refine_rgb(rgb_pils, front_pil) # 6s
+    else:
+        refined_rgbs = [rgb.resize((512, 512), resample=Image.LANCZOS) for rgb in rgb_pils]
+    img_list = [front_pil] + run_sr_fast(refined_rgbs[1:])
+
+    if predict_normal:
+        rm_normals = predict_normals([img.resize((512, 512), resample=Image.LANCZOS) for img in img_list], guidance_scale=1.5)
+    else:
+        rm_normals = simple_remove([img.resize((512, 512), resample=Image.LANCZOS) for img in normal_pils])
+    # transfer the alpha channel of rm_normals to img_list
+    for idx, img in enumerate(rm_normals):
+        if idx == 0 and img_list[0].mode == "RGBA":
+            temp = img_list[0].resize((2048, 2048))
+            rm_normals[0] = Image.fromarray(np.concatenate([np.array(rm_normals[0])[:, :, :3], np.array(temp)[:, :, 3:4]], axis=-1))
+            continue
+        img_list[idx] = Image.fromarray(np.concatenate([np.array(img_list[idx]), np.array(img)[:, :, 3:4]], axis=-1))
+    assert img_list[0].mode == "RGBA"
+    assert np.mean(np.array(img_list[0])[..., 3]) < 250
+
+    img_list = [img_list[0]] + erode_alpha(img_list[1:])
+    normal_stg1 = [img.resize((512, 512)) for img in rm_normals]
+    if init_type in ["std", "thin"]:
+        meshes = fast_geo(normal_stg1[0], normal_stg1[2], normal_stg1[1], init_type=init_type)
+        _ = multiview_color_projection(meshes, rgb_pils, resolution=512, device="cuda", complete_unseen=False, confidence_threshold=0.1)    # just check for validation, may throw error
+        vertices, faces, _ = from_py3d_mesh(meshes)
+        vertices, faces = reconstruct_stage1(normal_stg1, steps=200, vertices=vertices, faces=faces, start_edge_len=0.1, end_edge_len=0.02, gain=0.05, return_mesh=False, loss_expansion_weight=expansion_weight)
+    elif init_type in ["ball"]:
+        vertices, faces = reconstruct_stage1(normal_stg1, steps=200, end_edge_len=0.01, return_mesh=False, loss_expansion_weight=expansion_weight)
+
+    normal_stg2 = [img.resize((1024, 1024)) for img in rm_normals] # reduce computation on huggingface demo, use 1024 instead of 2048
+
+    vertices, faces = run_mesh_refine(vertices, faces, normal_stg2, steps=100, start_edge_len=0.02, end_edge_len=0.005, decay=0.99, update_normal_interval=20, update_warmup=5, return_mesh=False, process_inputs=False, process_outputs=False)
+
+    return vertices, faces, img_list
+
+# no GPU
+def geo_reconstruct_part2(vertices, faces):
+    meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=1, apply_sub_divide=True, sub_divide_threshold=0.25)
+    return meshes
+
+@spaces.GPU
+def geo_reconstruct_part3(meshes, img_list):
+    meshes = meshes.to("cuda")
     new_meshes = multiview_color_projection(meshes, img_list, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([0, 90, 180, 270], "cuda", focal=1))
     return new_meshes
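
Taken together with gradio_3dgen.py, this split lets the demo run the three stages back to back while holding a GPU only during parts 1 and 3; part 2 does the PyMeshLab-style cleanup on the CPU and only part 3 moves the mesh to CUDA. A usage sketch mirroring how generate3dv2 composes the stages (the input images are placeholders supplied by the caller):

from PIL import Image
from scripts.multiview_inference import (
    geo_reconstruct_part1,
    geo_reconstruct_part2,
    geo_reconstruct_part3,
)

def reconstruct(rgb_pils: list, front_pil: Image.Image):
    # Part 1 (@spaces.GPU): normal prediction plus stage-1/stage-2 geometry refinement
    vertices, faces, img_list = geo_reconstruct_part1(
        rgb_pils, None, front_pil,
        do_refine=True, predict_normal=True,
        expansion_weight=0.1, init_type="std",
    )
    # Part 2 (CPU only): mesh cleaning and subdivision, no GPU slot is held
    meshes = geo_reconstruct_part2(vertices, faces)
    # Part 3 (@spaces.GPU): multiview color projection onto the cleaned mesh
    return geo_reconstruct_part3(meshes, img_list)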