# Scraped page header (not part of the program): Spaces - Running on L40S; File size: 6,364 Bytes; 68cd723
import os, sys, time, traceback
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))
import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch
from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func
class MeshBaker:
    """Align reference images to mesh renders and bake them into the texture.

    A DUSt3R ``AsymmetricCroCo3DStereo`` model is used (through
    ``infer_warp_mesh_img``) to warp each reference image onto a render of
    the mesh. Views whose reported ``mask_iou`` exceeds ``iou_thresh`` (or
    that are listed in ``force_baking_ele_list``) are passed to
    ``back_projection``.

    Args:
        align_model: Identifier/path given to
            ``AsymmetricCroCo3DStereo.from_pretrained``.
        device: Torch device used for alignment and baking.
        align_times: Number of alignment passes used for the front view.
        iou_thresh: ``mask_iou`` value a view must exceed to be baked.
        force_baking_ele_list: Values from the ``[0, 180]`` view loop that
            are baked regardless of their ``mask_iou``; entries are cast to
            ``int``. ``None`` means no forced views.
        save_memory: When true the align model is kept off ``device``
            between calls and only moved there for the duration of a call.
    """

    def __init__(
        self,
        align_model="third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device="cuda:0",
        align_times=1,
        iou_thresh=0.8,
        force_baking_ele_list=None,
        save_memory=False,
    ):
        self.device = device
        self.save_memory = save_memory
        self.align_model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        # In save_memory mode the model is moved to the device lazily in __call__.
        self.align_model = self.align_model if save_memory else self.align_model.to(device)
        self.align_times = align_times
        self.align_model.eval()
        self.iou_thresh = iou_thresh
        forced = [] if force_baking_ele_list is None else force_baking_ele_list
        self.force_baking_ele_list = [int(_) for _ in forced]
        set_parameter_grad_false(self.align_model)
        print('baking align model', get_parameter_number(self.align_model))

    @staticmethod
    def _save_jpeg(image, path):
        """Save ``image`` to ``path``, dropping alpha/palette modes first.

        ``align_and_check`` can hand back its input image unchanged, and that
        input may carry an alpha channel or palette, which Pillow's JPEG
        writer rejects. Only the image written to disk is converted; the
        caller's image object is left untouched.
        """
        if image.mode in ("RGBA", "LA", "P", "PA"):
            image = image.convert("RGB")
        image.save(path)

    def align_and_check(self, src, dst, align_times=3):
        """Repeatedly warp ``src`` onto ``dst`` and keep the best result.

        Args:
            src: PIL image to align.
            dst: Target passed through to ``infer_warp_mesh_img``.
            align_times: Maximum number of alignment passes.

        Returns:
            ``(best_image, best_info, best_flag)`` where ``best_flag`` is
            whether ``best_info['mask_iou']`` exceeds ``self.iou_thresh``.
            If no pass beats the initial placeholder score
            (``iou_thresh - 0.1``), ``src`` itself and the placeholder info
            are returned. If any exception occurs, it is printed and
            ``(None, None, None)`` is returned instead of raising.
        """
        try:
            best_aligned_image = aligned_image = src
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh-0.1}
            for i in range(align_times):
                # Each pass re-aligns the output of the previous pass.
                aligned_image, info = infer_warp_mesh_img(aligned_image, dst, self.align_model, vis=False)
                aligned_image = Image.fromarray(aligned_image)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_aligned_image, best_info = aligned_image, info
                # Stop as soon as a pass scores below the threshold.
                if info['mask_iou'] < self.iou_thresh:
                    break
            print(f"Best Baking Info:{best_info['mask_iou']}")
            best_baking_flag = best_info['mask_iou'] > self.iou_thresh
            return best_aligned_image, best_info, best_baking_flag
        except Exception as e:
            # Deliberate best-effort: callers must check for the None triple.
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run :meth:`call`, managing model placement and the CUDA cache.

        In ``save_memory`` mode the align model is moved to ``self.device``
        for the call and back to the CPU afterwards. The cleanup runs in a
        ``finally`` block so a failing bake does not leave the model on the
        device or skip ``torch.cuda.empty_cache()``.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    def call(self, save_folder):
        """Bake the front and back views found in ``save_folder``.

        Reads ``mesh.obj`` and ``texture.png`` plus either ``views.jpg``
        (a 3x2 grid, together with ``img_nobg.png``) or ``views.gif`` (first
        frame is the condition image, remaining frames are the views).
        Aligned previews are written as ``aligned_*.jpg`` into
        ``save_folder``.

        Args:
            save_folder: Directory holding the inputs; outputs are written
                below it.

        Returns:
            Path of the mesh after the last successful bake, or the original
            ``mesh.obj`` path if no view was baked.

        Raises:
            FileNotFoundError: If neither ``views.jpg`` nor ``views.gif``
                exists in ``save_folder``.
        """
        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        views_jpg_path = os.path.join(save_folder, "views.jpg")
        views_gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg_path):
            with Image.open(views_jpg_path) as grid_img:
                grid = np.asarray(grid_img, dtype=np.uint8)
            # Split the 3-row x 2-column grid into six tiles, then reorder them.
            tiles = rearrange(grid, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0,2,4,5,3,1]]
            with Image.open(cond_path) as cond_src:
                cond_pil = cond_src.resize((512,512))
        elif os.path.exists(views_gif_path):
            with Image.open(views_gif_path) as gif:
                frames = [img.convert('RGB') for img in ImageSequence.Iterator(gif)]
            cond_pil, views = frames[0], frames[1:]
        else:
            raise FileNotFoundError("views file not found")

        rendered_views = render_func(obj_path, elev=0, n_views=2)

        for ele_idx, ele in enumerate([0, 180]):
            if ele == 0:
                # Front view: try both the condition image and the first
                # generated view, then keep whichever aligns better.
                aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_views[0], align_times=self.align_times)
                if aligned_cond is not None:
                    self._save_jpeg(aligned_cond, os.path.join(save_folder, 'aligned_cond.jpg'))

                aligned_img, info, _ = self.align_and_check(views[0], rendered_views[0], align_times=self.align_times)
                if aligned_img is not None:
                    self._save_jpeg(aligned_img, os.path.join(save_folder, f'aligned_{ele}.jpg'))

                # align_and_check returns None values on failure; only compare
                # scores when both candidates exist.
                use_cond = aligned_cond is not None and (
                    aligned_img is None or info['mask_iou'] < cond_info['mask_iou']
                )
                if use_cond:
                    print("Using Cond Image to bake front view")
                    aligned_img = aligned_cond
                    info = cond_info
                need_baking = aligned_img is not None and info['mask_iou'] > self.iou_thresh
            else:
                # ele // 60 picks view index 3 for ele == 180 — presumably the
                # views are spaced 60 apart; confirm against the view generator.
                # NOTE(review): this call uses align_and_check's default
                # align_times (3), not self.align_times — confirm intended.
                aligned_img, info, need_baking = self.align_and_check(views[ele//60], rendered_views[ele_idx])
                if aligned_img is not None:
                    self._save_jpeg(aligned_img, os.path.join(save_folder, f'aligned_{ele}.jpg'))

            if aligned_img is None:
                # Nothing to project, even if this view is in the forced list.
                print(f"Skip view_{ele_idx} elevation_{ele} baking: alignment failed")
                continue

            if need_baking or (ele in self.force_baking_ele_list):
                st = time.time()
                back_projection(
                    obj_file = obj_path,
                    init_texture_file = raw_texture_path,
                    front_view_file = aligned_img,
                    dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
                    render_resolution = aligned_img.size[0],
                    uv_resolution = 1024,
                    views = [[0, ele]],
                    device = self.device
                )
                print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
                # Later views bake on top of this view's output.
                obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
                raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
            else:
                print(f"Skip view_{ele_idx} elevation_{ele} baking")

        print("Baking Finished")
        return obj_path
if __name__ == "__main__":
    # Manual smoke run: bake the sample folder with default settings and
    # print the path of the resulting mesh.
    baker = MeshBaker()
    obj_path = baker("./outputs/test")
    print(obj_path)