Hunyuan3D-1 / third_party /mesh_baker.py
Huiwenshi's picture
Upload folder using huggingface_hub
68cd723 verified
raw
history blame
6.36 kB
import os, sys, time, traceback
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch
from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func
class MeshBaker:
    """Align generated view images to mesh renders and bake them into the texture.

    Each candidate image is aligned against a render of the mesh via
    ``infer_warp_mesh_img`` (driven by an ``AsymmetricCroCo3DStereo`` model).
    Views whose resulting ``mask_iou`` exceeds ``iou_thresh`` (or that are
    explicitly forced) are passed to ``back_projection`` to update the mesh
    texture.
    """

    def __init__(
        self,
        align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device = "cuda:0",
        align_times = 1,
        iou_thresh = 0.8,
        force_baking_ele_list = None,
        save_memory = False
    ):
        """Load the alignment model and store the baking configuration.

        Args:
            align_model: Identifier/path handed to
                ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: Device used for the alignment model and for back projection.
            align_times: Number of alignment passes used for the front view
                (``ele == 0``) in ``call``.
            iou_thresh: ``mask_iou`` value an alignment must exceed to be baked.
            force_baking_ele_list: ``ele`` values (``call`` iterates 0 and 180)
                that are baked even when the IoU check fails. Entries are
                coerced to ``int``; ``None`` means no forced views.
            save_memory: When true, the model is not moved to ``device`` here;
                ``__call__`` moves it there for the duration of each call and
                back to the CPU afterwards.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        self.align_model = self.align_model if save_memory else self.align_model.to(device)
        self.align_times = align_times
        self.align_model.eval()
        self.iou_thresh = iou_thresh
        self.force_baking_ele_list = [] if force_baking_ele_list is None else force_baking_ele_list
        self.force_baking_ele_list = [int(_) for _ in self.force_baking_ele_list]
        set_parameter_grad_false(self.align_model)
        print('baking align model', get_parameter_number(self.align_model))

    def align_and_check(self, src, dst, align_times=3):
        """Repeatedly align ``src`` to ``dst`` and keep the best result.

        Args:
            src: Image to align (returned unchanged if no pass beats the
                initial ``iou_thresh - 0.1`` baseline).
            dst: Target image, passed to ``infer_warp_mesh_img``.
            align_times: Maximum number of alignment passes; each pass is fed
                the output of the previous one.

        Returns:
            ``(best_image, best_info, baking_flag)`` where ``baking_flag`` is
            true when the best ``mask_iou`` exceeds ``self.iou_thresh``.
            On any exception the error is printed and
            ``(None, None, None)`` is returned (best-effort contract).
        """
        try:
            best_aligned_image = aligned_image = src
            # Baseline: a pass must reach at least iou_thresh - 0.1 to replace src.
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
            for i in range(align_times):
                aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
                aligned_image = Image.fromarray(aligned_image)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_aligned_image, best_info = aligned_image, info
                # Stop iterating as soon as a pass falls below the threshold.
                if info['mask_iou'] < self.iou_thresh:
                    break
            print(f"Best Baking Info:{best_info['mask_iou']}")
            best_baking_flag = best_info['mask_iou'] > self.iou_thresh
            return best_aligned_image, best_info, best_baking_flag
        except Exception as e:
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run ``call`` and manage model placement / CUDA cache around it.

        With ``save_memory`` enabled the alignment model is moved to
        ``self.device`` before the call and always moved back to the CPU
        afterwards, even if ``call`` raises.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            # Always release GPU memory, including on failure.
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    @staticmethod
    def _load_views(save_folder):
        """Load the condition image and the list of view images from ``save_folder``.

        Returns:
            ``(cond_pil, views)``.

        Raises:
            FileNotFoundError: If neither ``views.jpg`` nor ``views.gif`` exists.
        """
        views_jpg_path = os.path.join(save_folder, "views.jpg")
        views_gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg_path):
            with Image.open(views_jpg_path) as grid_img:
                # Normalise to 3 channels so the (h, w, c) pattern below always matches.
                grid = np.asarray(grid_img.convert('RGB'), dtype=np.uint8)
            # The grid holds 3 rows x 2 columns of tiles; split and reorder them.
            tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0,2,4,5,3,1]]
            with Image.open(cond_path) as cond_img:
                cond_pil = cond_img.resize((512,512))
            return cond_pil, views

        if os.path.exists(views_gif_path):
            with Image.open(views_gif_path) as gif_img:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif_img)]
            # First frame is used as the condition image, the rest as views.
            return frames[0], frames[1:]

        raise FileNotFoundError("views file not found")

    def _align_front(self, cond_pil, front_view, rendered_front, save_folder):
        """Align both the condition image and the front view; return the better one.

        Returns:
            ``(aligned_image, info)``; both are ``None`` when every alignment
            attempt failed.
        """
        aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_front, align_times=self.align_times)
        if aligned_cond is not None:
            aligned_cond.save(os.path.join(save_folder, 'aligned_cond.jpg'))

        aligned_img, info, _ = self.align_and_check(front_view, rendered_front, align_times=self.align_times)
        if aligned_img is not None:
            aligned_img.save(os.path.join(save_folder, 'aligned_0.jpg'))

        if aligned_cond is None:
            # Condition alignment failed: fall back to the view (may be None too).
            return aligned_img, info
        if aligned_img is None or info['mask_iou'] < cond_info['mask_iou']:
            print("Using Cond Image to bake front view")
            return aligned_cond, cond_info
        return aligned_img, info

    def call(self, save_folder):
        """Bake the views with ``ele`` 0 and 180 into the mesh found in ``save_folder``.

        Reads ``mesh.obj``, ``texture.png``, ``img_nobg.png`` and either
        ``views.jpg`` or ``views.gif`` from ``save_folder``. Aligned images
        are written as ``aligned_*.jpg`` and each baked view is written
        under ``view_<idx>/``.

        Args:
            save_folder: Directory holding the inputs and receiving the outputs.

        Returns:
            Path of the final mesh: the last baked ``view_<idx>/bake/mesh.obj``
            or the original ``mesh.obj`` when no view was baked.

        Raises:
            FileNotFoundError: If no views file is present.
        """
        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        cond_pil, views = self._load_views(save_folder)

        rendered_views = render_func(obj_path, elev=0, n_views=2)

        for ele_idx, ele in enumerate([0, 180]):
            if ele == 0:
                aligned_img, info = self._align_front(cond_pil, views[0], rendered_views[0], save_folder)
                need_baking = info is not None and info['mask_iou'] > self.iou_thresh
            else:
                # NOTE(review): this uses align_and_check's default align_times (3),
                # not self.align_times as the front view does — confirm intended.
                aligned_img, info, need_baking = self.align_and_check(views[ele//60], rendered_views[ele_idx])
                if aligned_img is not None:
                    aligned_img.save(os.path.join(save_folder, f'aligned_{ele}.jpg'))

            # A failed alignment yields no image, so even a forced view cannot be baked.
            if aligned_img is None or not (need_baking or (ele in self.force_baking_ele_list)):
                print(f"Skip view_{ele_idx} elevation_{ele} baking")
                continue

            st = time.time()
            back_projection(
                obj_file = obj_path,
                init_texture_file = raw_texture_path,
                front_view_file = aligned_img,
                dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
                render_resolution = aligned_img.size[0],
                uv_resolution = 1024,
                views = [[0, ele]],
                device = self.device
            )
            print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
            # Chain the result: the next view bakes on top of this output.
            obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
            raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")

        print("Baking Finished")
        return obj_path
if __name__ == "__main__":
    # Manual smoke run: bake the assets stored in the sample output folder
    # and print the path of the resulting mesh.
    demo_baker = MeshBaker()
    baked_obj_path = demo_baker("./outputs/test")
    print(baked_obj_path)