import torch
import torch.nn as nn
from models.modules.resunet import ResUnet_DirectAttenMultiImg_Cond
from models.modules.parpoints_encoder import ParPoint_Encoder
from models.modules.PointEMB import PointEmbed
from models.modules.utils import StackedRandomGenerator
from models.modules.diffusion_sampler import edm_sampler
from models.modules.encoder import DiagonalGaussianDistribution
import numpy as np


class EDMLoss_MultiImgCond:
    """EDM denoising loss (Karras et al., 2022) conditioned on multi-view images."""

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5, use_par=False):
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.use_par = use_par

    def __call__(self, net, data_batch, classifier_free=False):
        inputs = data_batch['input']
        image = data_batch['image']
        proj_mat = data_batch['proj_mat']
        valid_frames = data_batch['valid_frames']
        par_points = data_batch['par_points']
        category_code = data_batch['category_code']

        # Sample per-example noise levels from the EDM log-normal prior.
        rnd_normal = torch.randn([inputs.shape[0], 1, 1, 1], device=inputs.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()  # B,1,1,1
        # EDM loss weight: (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
        weight = (sigma ** 2 + self.sigma_data ** 2) / (self.sigma_data * sigma) ** 2
        y = inputs
        n = torch.randn_like(y) * sigma

        # Classifier-free guidance: randomly drop the image condition for about
        # half of the training iterations.
        # if classifier_free and np.random.random() < 0.5:
        #     net.par_feat = torch.zeros((inputs.shape[0], 32, inputs.shape[2], inputs.shape[3])).float().to(inputs.device)
        if classifier_free and np.random.random() < 0.5:
            image = torch.zeros_like(image)  # zeros_like keeps dtype and device
        # net is wrapped in (Distributed)DataParallel, hence the .module indirection.
        net.module.extract_img_feat(image)
        net.module.set_proj_matrix(proj_mat)
        net.module.set_valid_frames(valid_frames)
        net.module.set_category_code(category_code)
        if self.use_par:
            net.module.extract_point_feat(par_points)

        D_yn = net(y + n, sigma)  # denoised prediction of y from y + n
        loss = weight * ((D_yn - y) ** 2)
        return loss
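
# A minimal training-step sketch (not part of the original module). `net` is
# assumed to be the Triplane_Diff_MultiImgCond_EDM model wrapped in
# (Distributed)DataParallel, since the loss accesses net.module; `data_batch` is
# assumed to be the dict returned by prepare_data, and `optimizer` comes from
# the surrounding training loop.
def _example_training_step(net, loss_fn, data_batch, optimizer):
    optimizer.zero_grad()
    loss = loss_fn(net, data_batch, classifier_free=True).mean()  # reduce per-element loss
    loss.backward()
    optimizer.step()
    return loss.item()
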

class Triplane_Diff_MultiImgCond_EDM(nn.Module):
    """EDM-preconditioned triplane diffusion model conditioned on multi-view images."""

    def __init__(self, opt):
        super().__init__()
        self.diff_reso = opt['diff_reso']
        self.diff_dim = opt['output_channel']
        self.use_cat_embedding = opt['use_cat_embedding']
        self.use_fp16 = False
        self.sigma_data = 0.5
        self.sigma_max = float("inf")
        self.sigma_min = 0
        self.use_par = opt['use_par']
        self.triplane_padding = opt['triplane_padding']
        self.block_type = opt['block_type']
        # self.use_bn = opt['use_bn']
        if opt['backbone'] == "resunet_multiimg_direct_atten":
            self.denoise_model = ResUnet_DirectAttenMultiImg_Cond(
                channel=opt['input_channel'],
                output_channel=opt['output_channel'],
                use_par=opt['use_par'],
                par_channel=opt['par_channel'],
                img_in_channels=opt['img_in_channels'],
                vit_reso=opt['vit_reso'],
                triplane_padding=self.triplane_padding,
                norm=opt['norm'],
                use_cat_embedding=self.use_cat_embedding,
                block_type=self.block_type)
        else:
            raise NotImplementedError
        if opt['use_par']:  # use partial point cloud as an additional condition
            par_emb_dim = opt['par_emb_dim']
            par_args = opt['par_point_encoder']
            self.point_embedder = PointEmbed(hidden_dim=par_emb_dim)
            self.par_points_encoder = ParPoint_Encoder(c_dim=par_args['plane_latent_dim'], dim=par_emb_dim,
                                                       plane_resolution=par_args['plane_reso'],
                                                       unet_kwargs=par_args['unet'])
        self.unflatten = torch.nn.Unflatten(1, (16, 16))

    def prepare_data(self, data_batch):
        """Assemble the training inputs: sample a triplane latent from the stored
        Gaussian posterior and move all conditioning tensors to the GPU."""
        device = torch.device("cuda")
        means, logvars = data_batch['triplane_mean'].to(device, non_blocking=True), \
            data_batch['triplane_logvar'].to(device, non_blocking=True)
        distribution = DiagonalGaussianDistribution(means, logvars)
        plane_feat = distribution.sample()

        image = data_batch['image'].to(device, non_blocking=True)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames = data_batch['valid_frames'].to(device, non_blocking=True)
        par_points = data_batch['par_points'].to(device, non_blocking=True)
        category_code = data_batch['category_code'].to(device, non_blocking=True)
        input_dict = {"input": plane_feat.float(),
                      "image": image.float(),
                      "par_points": par_points.float(),
                      "proj_mat": proj_mat.float(),
                      "category_code": category_code.float(),
                      "valid_frames": valid_frames.float()}

        return input_dict

    def prepare_sample_data(self, data_batch):
        device = torch.device("cuda")
        image = data_batch['image'].to(device, non_blocking=True)
        proj_mat = data_batch['proj_mat'].to(device, non_blocking=True)
        valid_frames = data_batch['valid_frames'].to(device, non_blocking=True)
        par_points = data_batch['par_points'].to(device, non_blocking=True)
        category_code = data_batch['category_code'].to(device, non_blocking=True)
        sample_dict = {
            "image": image.float(),
            "proj_mat": proj_mat.float(),
            "valid_frames": valid_frames.float(),
            "category_code": category_code.float(),
            "par_points": par_points.float(),
        }
        return sample_dict

    def prepare_eval_data(self, data_batch):
        device = torch.device("cuda")
        samples = data_batch['points'].to(device, non_blocking=True)
        labels = data_batch['labels'].to(device, non_blocking=True)

        eval_dict = {
            "samples": samples,
            "labels": labels,
        }
        return eval_dict

    def extract_point_feat(self, par_points):
        # Embed the partial point cloud and encode it into plane features.
        par_emb = self.point_embedder(par_points)
        self.par_feat = self.par_points_encoder(par_points, par_emb)

    def extract_img_feat(self, image):
        # Cache the image conditioning for use in forward().
        self.image_emb = image

    def set_proj_matrix(self, proj_matrix):
        self.proj_matrix = proj_matrix

    def set_valid_frames(self, valid_frames):
        self.valid_frames = valid_frames

    def set_category_code(self, category_code):
        self.category_code = category_code

    def forward(self, x, sigma, force_fp32=False):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)  # B,1,1,1
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        # EDM preconditioning (Karras et al., 2022):
        # D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4  # B,1,1,1; flattened below for the UNet's noise-level embedding

        if self.use_par:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
                                     self.valid_frames, self.category_code, self.par_feat)
        else:
            F_x = self.denoise_model((c_in * x).to(dtype), c_noise.flatten(), self.image_emb, self.proj_matrix,
                                     self.valid_frames, self.category_code)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x
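
    # Worked example: with sigma = sigma_data = 0.5 the coefficients are
    #   c_skip  = 0.25 / 0.5       = 0.5
    #   c_out   = 0.25 / sqrt(0.5) ≈ 0.354
    #   c_in    = 1 / sqrt(0.5)    ≈ 1.414
    #   c_noise = ln(0.5) / 4      ≈ -0.173
    # i.e. at moderate noise the network refines x rather than replacing it.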

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

    @torch.no_grad()
    def sample(self, input_batch, batch_seeds=None, ret_all=False, num_steps=18):
        """Draw triplane latents with the EDM sampler, conditioned on the input batch."""
        img_cond = input_batch['image']
        proj_mat = input_batch['proj_mat']
        valid_frames = input_batch['valid_frames']
        category_code = input_batch['category_code']
        if img_cond is not None:
            batch_size, device = img_cond.shape[0], img_cond.device
            if batch_seeds is None:
                batch_seeds = torch.arange(batch_size)
        else:
            device = batch_seeds.device
            batch_size = batch_seeds.shape[0]

        self.extract_img_feat(img_cond)
        self.set_proj_matrix(proj_mat)
        self.set_valid_frames(valid_frames)
        self.set_category_code(category_code)
        if self.use_par:
            par_points = input_batch['par_points']
            self.extract_point_feat(par_points)
        # One seeded generator per sample keeps results reproducible across batch sizes.
        rnd = StackedRandomGenerator(device, batch_seeds)
        # The three planes are stacked along the height axis: (B, C, 3R, R).
        latents = rnd.randn([batch_size, self.diff_dim, self.diff_reso * 3, self.diff_reso], device=device)

        return edm_sampler(self, latents, randn_like=rnd.randn_like, ret_all=ret_all,
                           sigma_min=0.002, sigma_max=80, num_steps=num_steps)
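

# A minimal end-to-end sampling sketch (illustration only). The option values
# below are hypothetical placeholders; the real settings live in the project's
# training configs. `data_batch` is assumed to hold the raw tensors read by
# prepare_sample_data (image, proj_mat, valid_frames, par_points, category_code).
def _example_sampling(data_batch):
    opt = {
        'diff_reso': 32, 'input_channel': 8, 'output_channel': 8,
        'use_cat_embedding': True, 'use_par': False, 'triplane_padding': 0.1,
        'block_type': 'resblock', 'backbone': 'resunet_multiimg_direct_atten',
        'par_channel': 32, 'img_in_channels': 1280, 'vit_reso': 16, 'norm': 'batch',
    }  # hypothetical values, for illustration only
    model = Triplane_Diff_MultiImgCond_EDM(opt).cuda().eval()
    sample_dict = model.prepare_sample_data(data_batch)
    # edm_sampler is expected to return triplane latents of shape
    # (B, output_channel, 3 * diff_reso, diff_reso).
    return model.sample(sample_dict, num_steps=18)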