File size: 6,364 Bytes
68cd723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import os, sys, time, traceback
print("sys path insert", os.path.join(os.path.dirname(__file__), "dust3r"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "dust3r"))

import cv2
import numpy as np
from PIL import Image, ImageSequence
from einops import rearrange
import torch

from infer.utils import seed_everything, timing_decorator
from infer.utils import get_parameter_number, set_parameter_grad_false

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo

from third_party.gen_baking import back_projection
from third_party.dust3r_utils import infer_warp_mesh_img
from svrm.ldm.vis_util import render_func


class MeshBaker:
    """Bake aligned multi-view images back onto a mesh texture.

    For each target view (``ele`` in ``[0, 180]`` inside ``call``), an input
    image is warped onto a render of the mesh with a DUSt3R model
    (``infer_warp_mesh_img``). When the warped result's ``mask_iou`` exceeds
    ``iou_thresh`` (or the view is listed in ``force_baking_ele_list``), it is
    projected into the texture with ``back_projection``.
    """

    def __init__(
        self, 
        align_model = "third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt",
        device = "cuda:0", 
        align_times = 1,
        iou_thresh = 0.8, 
        force_baking_ele_list = None,
        save_memory = False
    ):
        """Load the alignment model and store baking options.

        Args:
            align_model: path/identifier passed to
                ``AsymmetricCroCo3DStereo.from_pretrained``.
            device: torch device used for alignment and passed to
                ``back_projection``.
            align_times: number of iterative alignment passes used for the
                front view (see ``align_and_check``).
            iou_thresh: minimum ``mask_iou`` for a view to be baked.
            force_baking_ele_list: values of ``ele`` (0 / 180 in ``call``) to
                bake even when the IoU check fails; entries are cast to int.
            save_memory: if True, keep the model on CPU and move it to
                ``device`` only for the duration of each ``__call__``.
        """
        self.device = device
        self.save_memory = save_memory
        self.align_times = align_times
        self.iou_thresh = iou_thresh
        forced = [] if force_baking_ele_list is None else force_baking_ele_list
        self.force_baking_ele_list = [int(e) for e in forced]

        model = AsymmetricCroCo3DStereo.from_pretrained(align_model)
        if not save_memory:
            model = model.to(device)
        model.eval()
        set_parameter_grad_false(model)  # inference only
        self.align_model = model
        print('baking align model', get_parameter_number(self.align_model))

    @staticmethod
    def _save_debug_image(image, path):
        """Write an intermediate image to ``path`` (a .jpg) for inspection.

        JPEG cannot store an alpha channel, so the image is converted to RGB
        first. ``align_and_check`` may return its input unchanged, e.g. the
        ``img_nobg.png`` conditioning image, which may carry alpha.
        """
        image.convert("RGB").save(path)

    def align_and_check(self, src, dst, align_times=3):
        """Iteratively warp ``src`` onto ``dst`` and keep the best attempt.

        Each pass feeds the previous warped image back into
        ``infer_warp_mesh_img``. An attempt replaces the current best only if
        its ``mask_iou`` is higher. The best starts as ``src`` with a
        placeholder IoU of ``iou_thresh - 0.1``. Iteration stops early as
        soon as an attempt's ``mask_iou`` drops below ``iou_thresh``.

        Args:
            src: PIL image to align.
            dst: reference image (in ``call``, a render of the mesh).
            align_times: maximum number of alignment passes.

        Returns:
            ``(best_image, best_info, passed)``, where ``passed`` is
            ``best_info['mask_iou'] > iou_thresh``. On any exception, the
            traceback is printed and ``(None, None, None)`` is returned.
        """
        try:
            best_image = current = src
            best_info = {'match_num': 1000, "mask_iou": self.iou_thresh - 0.1}
            for i in range(align_times):
                warped, info = infer_warp_mesh_img(current, dst, self.align_model, vis=False)
                current = Image.fromarray(warped)
                print(f"{i}-th time align process, mask-iou is {info['mask_iou']}")
                if info['mask_iou'] > best_info['mask_iou']:
                    best_image, best_info = current, info
                if info['mask_iou'] < self.iou_thresh:
                    break
            print(f"Best Baking Info:{best_info['mask_iou']}")
            return best_image, best_info, best_info['mask_iou'] > self.iou_thresh
        except Exception as e:
            # Best-effort: callers treat a None image as "alignment failed".
            print(f"Error processing image: {e}")
            traceback.print_exc()
            return None, None, None

    @timing_decorator("baking mesh")
    def __call__(self, *args, **kwargs):
        """Run ``call`` and manage model placement / CUDA cache.

        With ``save_memory`` the model is moved to ``self.device`` for the
        call and back to CPU afterwards. The ``finally`` guarantees the move
        back even if ``call`` raises.
        """
        if self.save_memory:
            self.align_model = self.align_model.to(self.device)
            torch.cuda.empty_cache()
        try:
            return self.call(*args, **kwargs)
        finally:
            if self.save_memory:
                self.align_model = self.align_model.to("cpu")
            torch.cuda.empty_cache()

    def call(self, save_folder):
        """Bake the front (``ele`` 0) and back (``ele`` 180) views into the mesh texture.

        Reads from ``save_folder``:
            - ``mesh.obj`` and ``texture.png``: the mesh to bake into.
            - ``views.jpg``: a 3x2 grid of views, plus ``img_nobg.png`` (the
              conditioning image, resized to 512x512); or, if the jpg is
              absent, ``views.gif``, whose first frame is the conditioning
              image and whose remaining frames are the views.

        Debug images ``aligned_cond.jpg`` / ``aligned_{ele}.jpg`` are written
        to ``save_folder``. Each successful bake writes under
        ``view_{idx}/bake/``, and the next view bakes on top of that result.

        Returns:
            Path to the most recently baked ``mesh.obj``, or the original
            ``mesh.obj`` if nothing was baked.

        Raises:
            FileNotFoundError: if neither ``views.jpg`` nor ``views.gif`` exists.
        """
        obj_path = os.path.join(save_folder, "mesh.obj")
        raw_texture_path = os.path.join(save_folder, "texture.png")
        views_jpg_path = os.path.join(save_folder, "views.jpg")
        views_gif_path = os.path.join(save_folder, "views.gif")
        cond_path = os.path.join(save_folder, "img_nobg.png")

        if os.path.exists(views_jpg_path):
            with Image.open(views_jpg_path) as grid:
                tiles = rearrange(np.asarray(grid, dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
            # '(n m)' flattens row-major, so [0,2,4,5,3,1] walks down the left
            # column and back up the right column.
            views = [Image.fromarray(tiles[idx]).convert('RGB') for idx in [0, 2, 4, 5, 3, 1]]
            with Image.open(cond_path) as cond_file:
                cond_pil = cond_file.resize((512, 512))
        elif os.path.exists(views_gif_path):
            with Image.open(views_gif_path) as gif:
                frames = [frame.convert('RGB') for frame in ImageSequence.Iterator(gif)]
            cond_pil, views = frames[0], frames[1:]
        else:
            raise FileNotFoundError("views file not found")

        # Two renders at elev=0; indexed below by ele_idx (0 -> front, 1 -> back).
        rendered_views = render_func(obj_path, elev=0, n_views=2)

        for ele_idx, ele in enumerate([0, 180]):
            if ele == 0:
                # Front view: align both the conditioning image and the first
                # generated view, then keep whichever matches the render better.
                aligned_cond, cond_info, _ = self.align_and_check(cond_pil, rendered_views[0], align_times=self.align_times)
                if aligned_cond is not None:
                    self._save_debug_image(aligned_cond, os.path.join(save_folder, 'aligned_cond.jpg'))

                aligned_img, info, _ = self.align_and_check(views[0], rendered_views[0], align_times=self.align_times)
                if aligned_img is not None:
                    self._save_debug_image(aligned_img, os.path.join(save_folder, f'aligned_{ele}.jpg'))

                if cond_info is not None and (info is None or info['mask_iou'] < cond_info['mask_iou']):
                    print("Using Cond Image to bake front view")
                    aligned_img, info = aligned_cond, cond_info
                need_baking = info is not None and info['mask_iou'] > self.iou_thresh
            else:
                # NOTE(review): unlike the front view, this relies on
                # align_and_check's default align_times (3), not
                # self.align_times. Confirm this is intended.
                aligned_img, info, need_baking = self.align_and_check(views[ele // 60], rendered_views[ele_idx])
                if aligned_img is not None:
                    self._save_debug_image(aligned_img, os.path.join(save_folder, f'aligned_{ele}.jpg'))

            if aligned_img is None:
                # Alignment raised (traceback already printed). Baking needs
                # aligned_img.size, so skip this view even if it is forced.
                print(f"Skip view_{ele_idx} elevation_{ele} baking: alignment failed")
                continue

            if need_baking or ele in self.force_baking_ele_list:
                st = time.time()
                back_projection(
                    obj_file = obj_path,
                    init_texture_file = raw_texture_path,
                    front_view_file = aligned_img,
                    dst_dir = os.path.join(save_folder, f"view_{ele_idx}"),
                    render_resolution = aligned_img.size[0], 
                    uv_resolution = 1024,
                    views = [[0, ele]],
                    device = self.device
                )
                print(f"view_{ele_idx} elevation_{ele} baking finished at {time.time() - st}")
                # Chain: the next view bakes on top of this view's output.
                obj_path = os.path.join(save_folder, f"view_{ele_idx}/bake/mesh.obj")
                raw_texture_path = os.path.join(save_folder, f"view_{ele_idx}/bake/texture.png")
            else:
                print(f"Skip view_{ele_idx} elevation_{ele} baking")

        print("Baking Finished")
        return obj_path
    

if __name__ == "__main__":
    # Demo entry point. An optional first CLI argument overrides the output
    # folder; with no argument, the previous default "./outputs/test" is used.
    save_folder = sys.argv[1] if len(sys.argv) > 1 else "./outputs/test"
    baker = MeshBaker()
    obj_path = baker(save_folder)
    print(obj_path)