YulianSa committed
Commit 914c133 · 1 Parent(s): e6f7ceb
Files changed (3):
  1. app.py +67 -20
  2. infer_refine.py +67 -66
  3. pre-requirements.txt +1 -0
app.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
import glob
import torch
import random
+import imagehash
from tempfile import NamedTemporaryFile
from PIL import Image
import os
@@ -74,29 +75,73 @@ If you find our work useful for your research or applications, please cite using
If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>.
"""

+cache_arbitrary = {}
+cache_multiview = [ {}, {}, {} ]
+cache_slrm = {}
+cache_refine = {}
+
+tmp_path = '/tmp'
+
# Example placeholder function - replace with the actual model
def arbitrary_to_apose(image, seed):
    # convert image to PIL.Image
    image = Image.fromarray(image)
-    return infer_api.genStage1(image, seed)
+    image_hash = str(imagehash.average_hash(image)) + '_' + str(seed)
+    if image_hash not in cache_arbitrary:
+        apose_img = infer_api.genStage1(image, seed)
+        apose_img.save(f'{tmp_path}/{image_hash}.png')
+        cache_arbitrary[image_hash] = f'{tmp_path}/{image_hash}.png'
+        print(f'cached apose image: {image_hash}')
+        return apose_img
+    else:
+        apose_img = Image.open(cache_arbitrary[image_hash])
+        print(f'loaded cached apose image: {image_hash}')
+        return apose_img

def apose_to_multiview(apose_img, seed):
    # convert image to PIL.Image
    apose_img = Image.fromarray(apose_img)
-    results = infer_api.genStage2(apose_img, seed, num_levels=1)
-    infer_api.add_results(results)
-    return results[0]["images"]
-
-def multiview_to_mesh(images):
-    mesh_files = infer_api.genStage3(images)
-    return mesh_files
-
-def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed):
+    image_hash = str(imagehash.average_hash(apose_img)) + '_' + str(seed)
+    if image_hash not in cache_multiview[0]:
+        results = infer_api.genStage2(apose_img, seed, num_levels=1)
+        for idx, img in enumerate(results[0]["images"]):
+            img.save(f'{tmp_path}/{image_hash}_images_{idx}.png')
+        for idx, img in enumerate(results[0]["normals"]):
+            img.save(f'{tmp_path}/{image_hash}_normals_{idx}.png')
+        cache_multiview[0][image_hash] = {
+            "images": [f'{tmp_path}/{image_hash}_images_{idx}.png' for idx in range(len(results[0]["images"]))],
+            "normals": [f'{tmp_path}/{image_hash}_normals_{idx}.png' for idx in range(len(results[0]["normals"]))]
+        }
+        print(f'cached multiview images: {image_hash}')
+        return results[0]["images"], image_hash
+    else:
+        print(f'loaded cached multiview images: {image_hash}')
+        return [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]], image_hash
+
+def multiview_to_mesh(images, image_hash):
+    if image_hash not in cache_slrm:
+        mesh_files = infer_api.genStage3(images)
+        cache_slrm[image_hash] = mesh_files
+        print(f'cached slrm files: {image_hash}')
+    else:
+        mesh_files = cache_slrm[image_hash]
+        print(f'loaded cached slrm files: {image_hash}')
+    return *mesh_files, image_hash
+
+def refine_mesh(apose_img, mesh1, mesh2, mesh3, seed, image_hash):
    apose_img = Image.fromarray(apose_img)
-    results = infer_api.genStage2(apose_img, seed, num_levels=2)
-    infer_api.add_results(results)
-    print(infer_api.results.keys())
-    refined = infer_api.genStage4([mesh1, mesh2, mesh3], infer_api.results)
+    if image_hash not in cache_refine:
+        results = infer_api.genStage2(apose_img, seed, num_levels=2)
+        results[0] = {}
+        results[0]["images"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["images"]]
+        results[0]["normals"] = [Image.open(img_path) for img_path in cache_multiview[0][image_hash]["normals"]]
+        refined = infer_api.genStage4([mesh1, mesh2, mesh3], results)
+        cache_refine[image_hash] = refined
+        print(f'cached refined mesh: {image_hash}')
+    else:
+        refined = cache_refine[image_hash]
+        print(f'loaded cached refined mesh: {image_hash}')
+
    return refined

with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo:
@@ -112,7 +157,7 @@ with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation fr
        )
        seed_input = gr.Number(
            label="Seed",
-            value=50,
+            value=52,
            precision=0,
            interactive=True
        )
@@ -131,6 +176,7 @@ with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation fr
            precision=0,
            interactive=True
        )
+        state2 = gr.State(value="")
        view_btn = gr.Button("Generate Multi-view Images")

        with gr.Column():
@@ -141,6 +187,7 @@ with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation fr
            interactive=False,
            height="None"
        )
+        state3 = gr.State(value="")
        mesh_btn = gr.Button("Reconstruct")

    with gr.Row():
@@ -165,20 +212,20 @@ with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation fr
    view_btn.click(
        apose_to_multiview,
        inputs=[a_pose_image, seed_input2],
-        outputs=multiview_gallery
+        outputs=[multiview_gallery, state2]
    )

    mesh_btn.click(
        multiview_to_mesh,
-        inputs=multiview_gallery,
-        outputs=[*mesh_cols, full_mesh]
+        inputs=[multiview_gallery, state2],
+        outputs=[*mesh_cols, full_mesh, state3]
    )

    refine_btn.click(
        refine_mesh,
-        inputs=[a_pose_image, *mesh_cols, seed_input2],
+        inputs=[a_pose_image, *mesh_cols, seed_input2, state3],
        outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh]
    )

if __name__ == "__main__":
-    demo.launch()
+    demo.launch(server_name="0.0.0.0", share=True, server_port=24527)
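Note on the caching this commit introduces: each stage in app.py now keys its output on a perceptual hash of the input image plus the seed, so a re-run with a visually identical input loads the saved PNG instead of calling the model. A minimal self-contained sketch of the pattern, where expensive_generate is a hypothetical stand-in for a call like infer_api.genStage1:

import imagehash
from PIL import Image

cache = {}          # hash key -> path of the cached PNG
tmp_path = '/tmp'

def expensive_generate(image: Image.Image, seed: int) -> Image.Image:
    # Hypothetical stand-in for a slow model call such as infer_api.genStage1.
    return image.rotate(seed % 360)

def cached_generate(image: Image.Image, seed: int) -> Image.Image:
    # average_hash is a short perceptual fingerprint, so visually identical
    # images with the same seed collapse to the same cache key.
    key = f'{imagehash.average_hash(image)}_{seed}'
    if key not in cache:
        result = expensive_generate(image, seed)
        result.save(f'{tmp_path}/{key}.png')
        cache[key] = f'{tmp_path}/{key}.png'
        return result
    return Image.open(cache[key])

Because average_hash tolerates small pixel differences, near-duplicate uploads also hit the cache; that is fine for a demo, but worth knowing if exact reproducibility matters.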
infer_refine.py CHANGED
@@ -16,16 +16,16 @@ from sklearn.neighbors import KDTree

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

-sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
-generator = SamAutomaticMaskGenerator(
-    model=sam,
-    points_per_side=64,
-    pred_iou_thresh=0.80,
-    stability_score_thresh=0.92,
-    crop_n_layers=1,
-    crop_n_points_downscale_factor=2,
-    min_mask_region_area=100,
-)
+# sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
+# generator = SamAutomaticMaskGenerator(
+#     model=sam,
+#     points_per_side=64,
+#     pred_iou_thresh=0.80,
+#     stability_score_thresh=0.92,
+#     crop_n_layers=1,
+#     crop_n_points_downscale_factor=2,
+#     min_mask_region_area=100,
+# )


def fix_vert_color_glb(mesh_path):
@@ -49,28 +49,59 @@ def srgb_to_linear(c_srgb):
    return c_linear.clip(0, 1.)


+import trimesh
+import numpy as np
+from PIL import Image
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import TexturesUV
+
def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
-    # convert from pytorch3d meshes to trimesh mesh
+    # Convert from pytorch3d meshes to trimesh mesh
    vertices = meshes.verts_packed().cpu().float().numpy()
    triangles = meshes.faces_packed().cpu().long().numpy()
-    np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
+
+    # Check if the mesh uses TexturesUV
+    if isinstance(meshes.textures, TexturesUV):
+        # Extract UV coordinates and texture map
+        verts_uvs = meshes.textures.verts_uvs_padded()[0].cpu().numpy()  # UV coordinates (N, 2)
+        faces_uvs = meshes.textures.faces_uvs_padded()[0].cpu().numpy()  # UV face indices (M, 3); unused, trimesh expects per-vertex UVs
+        texture_map = meshes.textures.maps_padded()[0].cpu().numpy()  # Texture map (H, W, 3 or 4)
+
+        # Convert the texture map to a trimesh-compatible PIL image
+        if apply_sRGB_to_LinearRGB:
+            texture_map = srgb_to_linear(texture_map)
+        texture_map = np.clip(texture_map, 0, 1)  # Ensure values are in [0, 1]
+        texture_image = Image.fromarray((texture_map * 255).astype(np.uint8))
+        material = trimesh.visual.texture.SimpleMaterial(image=texture_image, diffuse=(255, 255, 255))
+
+        # Create a trimesh.Trimesh object with UVs and texture
+        mesh = trimesh.Trimesh(
+            vertices=vertices,
+            faces=triangles,
+            visual=trimesh.visual.TextureVisuals(
+                uv=verts_uvs,  # UV coordinates
+                image=texture_image,  # Texture map
+                material=material  # Material with texture
+            )
+        )
+    else:
+        # Fall back to vertex colors if TexturesUV is not used
+        np_color = meshes.textures.verts_features_packed().cpu().float().numpy()
+        if apply_sRGB_to_LinearRGB:
+            np_color = srgb_to_linear(np_color)
+        np_color = np.clip(np_color, 0, 1)
+        mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
+
+    # Rotate 180 degrees along +Y if saving as GLB; mutate mesh.vertices, since
+    # trimesh copies the input array and flipping `vertices` here would be a no-op
    if save_glb_path.endswith(".glb"):
-        # rotate 180 along +Y
-        vertices[:, [0, 2]] = -vertices[:, [0, 2]]
+        mesh.vertices[:, [0, 2]] = -mesh.vertices[:, [0, 2]]

-    if apply_sRGB_to_LinearRGB:
-        np_color = srgb_to_linear(np_color)
-    assert vertices.shape[0] == np_color.shape[0]
-    assert np_color.shape[1] == 3
-    assert 0 <= np_color.min() and np_color.max() <= 1.001, f"min={np_color.min()}, max={np_color.max()}"
-    np_color = np.clip(np_color, 0, 1)
-    mesh = trimesh.Trimesh(vertices=vertices, faces=triangles, vertex_colors=np_color)
+    # Remove unreferenced vertices
    mesh.remove_unreferenced_vertices()
-    # save mesh
+
+    # Save mesh
    mesh.export(save_glb_path)
-    if save_glb_path.endswith(".glb"):
-        fix_vert_color_glb(save_glb_path)
-    print(f"saving to {save_glb_path}")
+    # if save_glb_path.endswith(".glb"):
+    #     fix_vert_color_glb(save_glb_path)
+    print(f"Saving to {save_glb_path}")


def calc_horizontal_offset(target_img, source_img):
@@ -124,43 +155,7 @@ def get_distract_mask(color_0, color_1, normal_0=None, normal_1=None, thres=0.25
    max_x, max_y = bbox.max(axis=0)
    distract_bbox[min_x:max_x, min_y:max_y] = 1

-    points = np.array(random_sampled_points)[:, ::-1]
-    labels = np.ones(len(points), dtype=np.int32)
-
-    masks = generator.generate((color_1 * 255).astype(np.uint8))
-
-    outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres
-
-    final_mask = np.zeros_like(distract_mask)
-    for iii, mask in enumerate(masks):
-        mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5
-        intersection = np.logical_and(mask['segmentation'], distract_mask).sum()
-        total = mask['segmentation'].sum()
-        iou = intersection / total
-        outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum()
-        outside_total = mask['segmentation'].sum()
-        outside_iou = outside_intersection / outside_total
-        if iou > ratio and outside_iou < outside_ratio:
-            final_mask |= mask['segmentation']
-
-    # calculate coverage
-    intersection = np.logical_and(final_mask, distract_mask).sum()
-    total = distract_mask.sum()
-    coverage = intersection / total
-
-    if coverage < 0.8:
-        # use original distract mask
-        final_mask = (distract_mask.copy() * 255).astype(np.uint8)
-        final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3)
-        labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask)
-        for i in range(num_features_dilate + 1):
-            if np.sum(labeled_array_dilate == i) < 200:
-                final_mask[labeled_array_dilate == i] = 255
-
-        final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3)
-        final_mask = final_mask > 127
-
-    return distract_mask, distract_bbox, random_sampled_points, final_mask
+    return distract_mask, distract_bbox, None, None  # SAM refinement disabled; keep the 4-tuple shape


if __name__ == '__main__':
@@ -172,6 +167,9 @@ if __name__ == '__main__':
    parser.add_argument('--no_decompose', action='store_true')
    args = parser.parse_args()

+    import time
+    start_time = time.time()
+
    for test_idx in os.listdir(args.input_mv_dir):
        mv_root_dir = os.path.join(args.input_mv_dir, test_idx)
        obj_dir = os.path.join(args.input_obj_dir, test_idx)
@@ -228,7 +226,7 @@ if __name__ == '__main__':
            normals.append(normal)

            if last_front_color is not None and level == 0:
-                original_mask, distract_bbox, _, distract_mask = get_distract_mask(last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=args.outside_ratio)
+                distract_mask, distract_bbox, _, _ = get_distract_mask(last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=args.outside_ratio)
                cv2.imwrite(f'{args.output_dir}/{test_idx}/distract_mask.png', distract_mask.astype(np.uint8) * 255)
            else:
                distract_mask = None
@@ -275,7 +273,7 @@ if __name__ == '__main__':
            # my mesh flow weight by nearest vertices
            try:
                if fixed_v is not None and fixed_f is not None and level != 0:
-                    new_mesh_v = new_mesh.verts_packed().cpu().numpy()
+                    new_mesh_v = new_mesh.vertices.copy()

                    fixed_v_cpu = fixed_v.cpu().numpy()
                    kdtree_anchor = KDTree(fixed_v_cpu)
@@ -297,14 +295,13 @@ if __name__ == '__main__':
                    weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1)  # V, 3
                    new_mesh_v += weighted_vec_anchor.cpu().numpy()

-                    # replace new_mesh verts with new_mesh_v
-                    new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
+                    new_mesh.vertices = new_mesh_v

            except Exception as e:
                pass

            os.makedirs(f'{args.output_dir}/{test_idx}', exist_ok=True)
-            save_py3dmesh_with_trimesh_fast(new_mesh, f'{args.output_dir}/{test_idx}/out_{level}.glb', apply_sRGB_to_LinearRGB=False)
+            new_mesh.export(f'{args.output_dir}/{test_idx}/out_{level}.glb')

            if fixed_v is None:
                fixed_v, fixed_f = simp_v, simp_f
@@ -312,6 +309,10 @@ if __name__ == '__main__':
                fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
                fixed_v = torch.cat([fixed_v, simp_v], dim=0)

+            # input("Press Enter to continue...")
+
+            print('finish', time.time() - start_time)
+

        else:
            mesh = trimesh.load(obj_dir + f'_0.obj')
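The rewritten save_py3dmesh_with_trimesh_fast above routes TexturesUV meshes through trimesh's texture pipeline. A minimal standalone sketch of that trimesh path, using a unit quad with a random texture (all names here are illustrative, not from the repository):

import numpy as np
import trimesh
from PIL import Image

# A unit quad (two triangles) with per-vertex UV coordinates.
vertices = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.], [0., 1., 0.]])
faces = np.array([[0, 1, 2], [0, 2, 3]])
uv = np.array([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])

# trimesh's material/GLB export path expects the texture as a PIL image.
texture_image = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype(np.uint8))
material = trimesh.visual.texture.SimpleMaterial(image=texture_image)

mesh = trimesh.Trimesh(
    vertices=vertices,
    faces=faces,
    visual=trimesh.visual.TextureVisuals(uv=uv, image=texture_image, material=material),
)
mesh.export('textured_quad.glb')  # GLB embeds the texture

This is also why the save function converts the float texture map to a PIL image before building SimpleMaterial: trimesh wants per-vertex UVs plus an image, and has no direct use for the per-face faces_uvs indices that pytorch3d provides.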
pre-requirements.txt CHANGED
@@ -23,3 +23,4 @@ scikit-learn
pygltflib
pymeshlab==2022.2.post3
pytorch_lightning
+imagehash