import os
import sys
import time
import traceback

# The bundled dust3r checkout must be importable before the dust3r imports below.
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))

import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch

from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func


class MeshBaker:
    """Align reference images to renders of a mesh and bake them into its texture.

    Alignment is delegated to ``infer_warp_mesh_img`` (driven by a DUSt3R
    ``AsymmetricCroCo3DStereo`` model) and baking to ``back_projection``.
    """

    def __init__(
        self,
        align_model="third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device="cuda:0",
        align_times=1,
        iou_thresh=0.8,
        force_baking_ele_list=None,
        save_memory=False,
    ):
        """Load the alignment model and store the baking configuration.

        Args:
            align_model: Identifier handed to
                ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: Device the alignment model runs on; also forwarded to
                ``back_projection``.
            align_times: Number of alignment passes used for the front view.
            iou_thresh: A view is baked only when its best ``mask_iou`` is
                strictly greater than this value.
            force_baking_ele_list: Values out of ``(0, 180)`` for which baking
                is attempted regardless of the IoU check. Entries are coerced
                to ``int``. ``None`` means no forced views.
            save_memory: When true the model is kept off ``device`` between
                calls and only moved there for the duration of ``__call__``.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        if not save_memory:
            self.align_model = self.align_model.to(device)
        self.align_times = align_times
        self.align_model.eval()
        self.iou_thresh = iou_thresh
        forced = [] if force_baking_ele_list is None else force_baking_ele_list
        self.force_baking_ele_list = [int(ele) for ele in forced]
        set_parameter_grad_false(self.align_model)
        print('baking align model', get_parameter_number(self.align_model))

    def align_and_check(self, src, dst, align_times=3):
        """Repeatedly warp ``src`` towards ``dst`` and keep the best result.

        Each pass feeds the previous pass's output back into
        ``infer_warp_mesh_img``. The loop stops early as soon as a pass yields
        a ``mask_iou`` below ``self.iou_thresh``.

        Args:
            src: Image to align (returned unchanged when no pass beats the
                initial placeholder score of ``iou_thresh - 0.1``).
            dst: Target image the source is aligned to.
            align_times: Maximum number of alignment passes.

        Returns:
            ``(best_image, best_info, should_bake)`` where ``should_bake`` is
            true when the best ``mask_iou`` exceeds ``self.iou_thresh``.
            On any exception the error is printed and ``(None, None, None)``
            is returned, so callers must check for ``None``.
        """
        try:
            best_aligned_image = aligned_image = src
            # Placeholder score: a pass must beat (iou_thresh - 0.1) to replace src.
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh - 0.1}
            for i in range(align_times):
                aligned_image, info = infer_warp_mesh_img(
                    aligned_image, dst, self.align_model, vis=False
                )
                aligned_image = Image.fromarray(aligned_image)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_aligned_image, best_info = aligned_image, info
                if info['mask_iou'] < self.iou_thresh:
                    break
            print(f"Best Baking Info:{best_info['mask_iou']}")
            best_baking_flag = best_info['mask_iou'] > self.iou_thresh
            return best_aligned_image, best_info, best_baking_flag
        except Exception as e:
            # Deliberate best-effort: a failed alignment is reported, not raised.
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run ``call`` with the alignment model resident on ``self.device``.

        In ``save_memory`` mode the model is moved to the device first and
        moved back to the CPU afterwards. The offload and the CUDA cache
        flush happen in a ``finally`` block so they also run when ``call``
        raises.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    def call(self, save_folder):
        """Bake the front and back views found in ``save_folder`` onto its mesh.

        Reads ``mesh.obj`` and ``texture.png`` plus either ``views.jpg``
        (together with ``img_nobg.png``) or ``views.gif`` from
        ``save_folder``. Aligned images are written next to them as
        ``aligned_*.jpg`` and each baked view goes to ``view_<idx>/bake``.

        Args:
            save_folder: Directory holding the inputs; outputs are written
                there as well.

        Returns:
            Path of the most recently baked ``mesh.obj``, or the original
            mesh path when every view was skipped.

        Raises:
            FileNotFoundError: If neither ``views.jpg`` nor ``views.gif``
                exists.
        """
        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        views_jpg_path = os.path.join(save_folder, "views.jpg")
        views_gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg_path):
            # views.jpg is a 3x2 grid of tiles; split it and reorder the tiles.
            with Image.open(views_jpg_path) as grid_img:
                grid = np.asarray(grid_img, dtype=np.uint8)
            tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0, 2, 4, 5, 3, 1]]
            with Image.open(cond_path) as cond_img:
                cond_pil = cond_img.resize((512, 512))
        elif os.path.exists(views_gif_path):
            # First GIF frame is the condition image, the rest are the views.
            with Image.open(views_gif_path) as gif_img:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif_img)]
            cond_pil, views = frames[0], frames[1:]
        else:
            raise FileNotFoundError("views file not found")

        rendered_views = render_func(obj_path, elev=0, n_views=2)

        for ele_idx, ele in enumerate([0, 180]):
            if ele == 0:
                aligned_cond, cond_info, _ = self.align_and_check(
                    cond_pil, rendered_views[0], align_times=self.align_times
                )
                if aligned_cond is not None:
                    aligned_cond.save(os.path.join(save_folder, 'aligned_cond.jpg'))

                aligned_img, info, _ = self.align_and_check(
                    views[0], rendered_views[0], align_times=self.align_times
                )
                if aligned_img is not None:
                    aligned_img.save(os.path.join(save_folder, f'aligned_{ele}.jpg'))

                # Prefer the condition image when it aligned better, or when it
                # is the only alignment that succeeded.
                if aligned_cond is not None and (
                    aligned_img is None or info['mask_iou'] < cond_info['mask_iou']
                ):
                    print("Using Cond Image to bake front view")
                    aligned_img = aligned_cond
                    info = cond_info
                need_baking = aligned_img is not None and info['mask_iou'] > self.iou_thresh
            else:
                # NOTE(review): this call relies on align_and_check's default
                # align_times (3) rather than self.align_times -- confirm intended.
                # views[ele // 60] presumably assumes views spaced 60 apart -- TODO confirm.
                aligned_img, info, need_baking = self.align_and_check(
                    views[ele // 60], rendered_views[ele_idx]
                )
                if aligned_img is not None:
                    aligned_img.save(os.path.join(save_folder, f'aligned_{ele}.jpg'))

            # Without an aligned image there is nothing to project, even if forced.
            if aligned_img is not None and (need_baking or ele in self.force_baking_ele_list):
                st = time.time()
                back_projection(
                    obj_file=obj_path,
                    init_texture_file=raw_texture_path,
                    front_view_file=aligned_img,
                    dst_dir=os.path.join(save_folder, f"view_{ele_idx}"),
                    render_resolution=aligned_img.size[0],
                    uv_resolution=1024,
                    views=[[0, ele]],
                    device=self.device,
                )
                print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
                # Later views bake on top of this view's output.
                obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
                raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
            else:
                print(f"Skip view_{ele_idx} elevation_{ele} baking")

        print("Baking Finished")
        return obj_path


if __name__ == "__main__":
    baker = MeshBaker()
    obj_path = baker("./outputs/test")
    print(obj_path)