Linoy Tsaban committed on
Commit
ba508b5
1 Parent(s): 9aade41

Update preprocess_utils.py

Files changed (1)
  1. preprocess_utils.py +290 -145
preprocess_utils.py CHANGED
@@ -22,158 +22,303 @@ def get_timesteps(scheduler, num_inference_steps, strength, device):
     timesteps = scheduler.timesteps[t_start:]
 
     return timesteps, num_inference_steps - t_start
-
-@torch.no_grad()
-def decode_latents(pipe, latents):
-    decoded = []
-    batch_size = 8
-    for b in range(0, latents.shape[0], batch_size):
-        latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
-        imgs = pipe.vae.decode(latents_batch).sample
-        imgs = (imgs / 2 + 0.5).clamp(0, 1)
-        decoded.append(imgs)
-    return torch.cat(decoded)
-
-@torch.no_grad()
-def ddim_inversion(pipe, cond, latent_frames, batch_size, save_latents=True, timesteps_to_save=None):
-
-    timesteps = reversed(pipe.scheduler.timesteps)
-    timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
-    for i, t in enumerate(tqdm(timesteps)):
-        for b in range(0, latent_frames.shape[0], batch_size):
-            x_batch = latent_frames[b:b + batch_size]
-            model_input = x_batch
-            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
-            #remove comment from commented block to support controlnet
-            # if self.sd_version == 'depth':
-            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
-            #     model_input = torch.cat([x_batch, depth_maps],dim=1)
-
-            alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
-            alpha_prod_t_prev = (
-                pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
-                if i > 0 else pipe.scheduler.final_alpha_cumprod
-            )

-            mu = alpha_prod_t ** 0.5
-            mu_prev = alpha_prod_t_prev ** 0.5
-            sigma = (1 - alpha_prod_t) ** 0.5
-            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

-
-            #remove line below and replace with commented block to support controlnet
-            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
-            # if self.sd_version != 'ControlNet':
-            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
-            # else:
-            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
-
-            pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
-            latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps

-            # if save_latents and t in timesteps_to_save:
-            #     torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
-        # torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
-    return latent_frames
-
-@torch.no_grad()
-def ddim_sample(pipe, x, cond, batch_size):
-    timesteps = pipe.scheduler.timesteps
-    for i, t in enumerate(tqdm(timesteps)):
-        for b in range(0, x.shape[0], batch_size):
-            x_batch = x[b:b + batch_size]
-            model_input = x_batch
-            cond_batch = cond.repeat(x_batch.shape[0], 1, 1)

-            #remove comment from commented block to support controlnet
-            # if self.sd_version == 'depth':
-            #     depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
-            #     model_input = torch.cat([x_batch, depth_maps],dim=1)
-
-            alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
-            alpha_prod_t_prev = (
-                pipe.scheduler.alphas_cumprod[timesteps[i + 1]]
-                if i < len(timesteps) - 1
-                else pipe.scheduler.final_alpha_cumprod
            )
-            mu = alpha_prod_t ** 0.5
-            sigma = (1 - alpha_prod_t) ** 0.5
-            mu_prev = alpha_prod_t_prev ** 0.5
-            sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
-
-            #remove line below and replace with commented block to support controlnet
-            eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
-            # if self.sd_version != 'ControlNet':
-            #     eps = pipe.unet(model_input, t, encoder_hidden_states=cond_batch).sample
-            # else:
-            #     eps = self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
-
-            pred_x0 = (x_batch - sigma * eps) / mu
-            x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
-    return x
-
-
-@torch.no_grad()
-def get_text_embeds(pipe, prompt, negative_prompt, batch_size=1, device="cuda"):
-    # Tokenize text and get embeddings
-    text_input = pipe.tokenizer(prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
-                                truncation=True, return_tensors='pt')
-    text_embeddings = pipe.text_encoder(text_input.input_ids.to(pipe.device))[0]
-
-    # Do the same for unconditional embeddings
-    uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=pipe.tokenizer.model_max_length,
-                                  return_tensors='pt')
-
-    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(pipe.device))[0]
-
-    # Cat for final embeddings
-    text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size)
-    return text_embeddings
-
-@torch.no_grad()
-def extract_latents(pipe,
-                    num_steps,
-                    latent_frames,
-                    batch_size,
-                    timesteps_to_save,
-                    inversion_prompt=''):
-    pipe.scheduler.set_timesteps(num_steps)
-    cond = get_text_embeds(pipe, inversion_prompt, "", device=pipe.device)[1].unsqueeze(0)
-    # latent_frames = self.latents
-
-    inverted_latents = ddim_inversion(pipe, cond,
-                                      latent_frames,
-                                      batch_size=batch_size,
-                                      save_latents=False,
-                                      timesteps_to_save=timesteps_to_save)
-
-    # latent_reconstruction = ddim_sample(pipe, inverted_latents, cond, batch_size=batch_size)

-    # rgb_reconstruction = decode_latents(pipe, latent_reconstruction)

-    # return rgb_reconstruction
-    return inverted_latents

-@torch.no_grad()
-def encode_imgs(pipe, imgs, batch_size=10, deterministic=True):
-    imgs = 2 * imgs - 1
-    latents = []
-    for i in range(0, len(imgs), batch_size):
-        posterior = pipe.vae.encode(imgs[i:i + batch_size]).latent_dist
-        latent = posterior.mean if deterministic else posterior.sample()
-        latents.append(latent * 0.18215)
-    latents = torch.cat(latents)
-    return latents

-def get_data(pipe, frames, n_frames):
-    """
-    converts frames to tensors, saves to device and encodes to obtain latents
-    """
-    frames = frames[:n_frames]
-    if frames[0].size[0] == frames[0].size[1]:
-        frames = [frame.convert("RGB").resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
-    stacked_tensor_frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(pipe.device)
-    # encode to latents
-    latents = encode_imgs(pipe, stacked_tensor_frames, deterministic=True).to(torch.float16).to(pipe.device)
-    return stacked_tensor_frames, latents



+class Preprocess(nn.Module):
+    def __init__(self, device, opt, hf_key=None):
+        super().__init__()

+        self.device = device
+        self.sd_version = opt["sd_version"]
+        self.use_depth = False
+        self.config = opt
+
+        print(f'[INFO] loading stable diffusion...')
+        if hf_key is not None:
+            print(f'[INFO] using hugging face custom model key: {hf_key}')
+            model_key = hf_key
+        elif self.sd_version == '2.1':
+            model_key = "stabilityai/stable-diffusion-2-1-base"
+        elif self.sd_version == '2.0':
+            model_key = "stabilityai/stable-diffusion-2-base"
+        elif self.sd_version == '1.5' or self.sd_version == 'ControlNet':
+            model_key = "runwayml/stable-diffusion-v1-5"
+        elif self.sd_version == 'depth':
+            model_key = "stabilityai/stable-diffusion-2-depth"
+        else:
+            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')
+        self.model_key = model_key
+        # Create model
+        self.vae = AutoencoderKL.from_pretrained(model_key, subfolder="vae", revision="fp16",
+                                                 torch_dtype=torch.float16).to(self.device)
+        self.tokenizer = CLIPTokenizer.from_pretrained(model_key, subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained(model_key, subfolder="text_encoder", revision="fp16",
+                                                          torch_dtype=torch.float16).to(self.device)
+        self.unet = UNet2DConditionModel.from_pretrained(model_key, subfolder="unet", revision="fp16",
+                                                         torch_dtype=torch.float16).to(self.device)
+        self.total_inverted_latents = {}
+
+        self.paths, self.frames, self.latents = self.get_data(self.config["data_path"], self.config["n_frames"])
+        print("self.frames", self.frames.shape)
+        print("self.latents", self.latents.shape)
+
+
+        if self.sd_version == 'ControlNet':
+            from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
+            controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16).to(self.device)
+            control_pipe = StableDiffusionControlNetPipeline.from_pretrained(
+                "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+            ).to(self.device)
+            self.unet = control_pipe.unet
+            self.controlnet = control_pipe.controlnet
+            self.canny_cond = self.get_canny_cond()
+        elif self.sd_version == 'depth':
+            self.depth_maps = self.prepare_depth_maps()
+        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+
+        # self.unet.enable_xformers_memory_efficient_attention()
+        print(f'[INFO] loaded stable diffusion!')
+
+    @torch.no_grad()
+    def prepare_depth_maps(self, model_type='DPT_Large', device='cuda'):
+        depth_maps = []
+        midas = torch.hub.load("intel-isl/MiDaS", model_type)
+        midas.to(device)
+        midas.eval()
+
+        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
+
+        if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
+            transform = midas_transforms.dpt_transform
+        else:
+            transform = midas_transforms.small_transform
+
+        for i in range(len(self.paths)):
+            img = cv2.imread(self.paths[i])
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+            latent_h = img.shape[0] // 8
+            latent_w = img.shape[1] // 8

+            input_batch = transform(img).to(device)
+            prediction = midas(input_batch)
+
+            depth_map = torch.nn.functional.interpolate(
+                prediction.unsqueeze(1),
+                size=(latent_h, latent_w),
+                mode="bicubic",
+                align_corners=False,
            )
+            depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+            depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+            depth_map = 2.0 * (depth_map - depth_min) / (depth_max - depth_min) - 1.0
+            depth_maps.append(depth_map)

+        return torch.cat(depth_maps).to(self.device).to(torch.float16)
+
+    @torch.no_grad()
+    def get_canny_cond(self):
+        canny_cond = []
+        for image in self.frames.cpu().permute(0, 2, 3, 1):
+            image = np.uint8(np.array(255 * image))
+            low_threshold = 100
+            high_threshold = 200

+            image = cv2.Canny(image, low_threshold, high_threshold)
+            image = image[:, :, None]
+            image = np.concatenate([image, image, image], axis=2)
+            image = torch.from_numpy((image.astype(np.float32) / 255.0))
+            canny_cond.append(image)
+        canny_cond = torch.stack(canny_cond).permute(0, 3, 1, 2).to(self.device).to(torch.float16)
+        return canny_cond
+
+    def controlnet_pred(self, latent_model_input, t, text_embed_input, controlnet_cond):
+        down_block_res_samples, mid_block_res_sample = self.controlnet(
+            latent_model_input,
+            t,
+            encoder_hidden_states=text_embed_input,
+            controlnet_cond=controlnet_cond,
+            conditioning_scale=1,
+            return_dict=False,
+        )
+
+        # apply the denoising network
+        noise_pred = self.unet(
+            latent_model_input,
+            t,
+            encoder_hidden_states=text_embed_input,
+            cross_attention_kwargs={},
+            down_block_additional_residuals=down_block_res_samples,
+            mid_block_additional_residual=mid_block_res_sample,
+            return_dict=False,
+        )[0]
+        return noise_pred

+    @torch.no_grad()
+    def get_text_embeds(self, prompt, negative_prompt, device="cuda"):
+        text_input = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                    truncation=True, return_tensors='pt')
+        text_embeddings = self.text_encoder(text_input.input_ids.to(device))[0]
+        uncond_input = self.tokenizer(negative_prompt, padding='max_length', max_length=self.tokenizer.model_max_length,
+                                      return_tensors='pt')
+        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]
+        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+        return text_embeddings
+
+    @torch.no_grad()
+    def decode_latents(self, latents):
+        decoded = []
+        batch_size = 8
+        for b in range(0, latents.shape[0], batch_size):
+            latents_batch = 1 / 0.18215 * latents[b:b + batch_size]
+            imgs = self.vae.decode(latents_batch).sample
+            imgs = (imgs / 2 + 0.5).clamp(0, 1)
+            decoded.append(imgs)
+        return torch.cat(decoded)
+
+    @torch.no_grad()
+    def encode_imgs(self, imgs, batch_size=10, deterministic=True):
+        imgs = 2 * imgs - 1
+        latents = []
+        for i in range(0, len(imgs), batch_size):
+            posterior = self.vae.encode(imgs[i:i + batch_size]).latent_dist
+            latent = posterior.mean if deterministic else posterior.sample()
+            latents.append(latent * 0.18215)
+        latents = torch.cat(latents)
+        return latents
+
+    def get_data(self, frames_path, n_frames):
+
+        # load frames
+        if not self.config["frames"]:
+            paths = [f"{frames_path}/%05d.png" % i for i in range(n_frames)]
+            print(paths)
+            if not os.path.exists(paths[0]):
+                paths = [f"{frames_path}/%05d.jpg" % i for i in range(n_frames)]
+            self.paths = paths
+            frames = [Image.open(path).convert('RGB') for path in paths]
+            if frames[0].size[0] == frames[0].size[1]:
+                frames = [frame.resize((512, 512), resample=Image.Resampling.LANCZOS) for frame in frames]
+        else:
+            frames = self.config["frames"][:n_frames]
+        frames = torch.stack([T.ToTensor()(frame) for frame in frames]).to(torch.float16).to(self.device)
+        # encode to latents
+        latents = self.encode_imgs(frames, deterministic=True).to(torch.float16).to(self.device)
+        print("frames", frames.shape)
+        print("latents", latents.shape)
+
+        if not self.config["frames"]:
+            return paths, frames, latents
+        else:
+            return None, frames, latents
+
+    @torch.no_grad()
+    def ddim_inversion(self, cond, latent_frames, save_path, batch_size, save_latents=True, timesteps_to_save=None):
+        timesteps = reversed(self.scheduler.timesteps)
+        timesteps_to_save = timesteps_to_save if timesteps_to_save is not None else timesteps
+
+        return_inverted_latents = self.config["frames"] is not None
+        for i, t in enumerate(tqdm(timesteps)):
+            for b in range(0, latent_frames.shape[0], batch_size):
+                x_batch = latent_frames[b:b + batch_size]
+                model_input = x_batch
+                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
+                if self.sd_version == 'depth':
+                    depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
+                    model_input = torch.cat([x_batch, depth_maps],dim=1)
+
+                alpha_prod_t = self.scheduler.alphas_cumprod[t]
+                alpha_prod_t_prev = (
+                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
+                    if i > 0 else self.scheduler.final_alpha_cumprod
+                )
+
+                mu = alpha_prod_t ** 0.5
+                mu_prev = alpha_prod_t_prev ** 0.5
+                sigma = (1 - alpha_prod_t) ** 0.5
+                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+                eps = self.unet(model_input, t, encoder_hidden_states=cond_batch).sample if self.sd_version != 'ControlNet' \
+                    else self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
+                pred_x0 = (x_batch - sigma_prev * eps) / mu_prev
+                latent_frames[b:b + batch_size] = mu * pred_x0 + sigma * eps
+
+            if return_inverted_latents and t in timesteps_to_save:
+                self.total_inverted_latents[f'noisy_latents_{t}'] = latent_frames.clone()

+            if save_latents and t in timesteps_to_save:
+                torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
+
+        if save_latents:
+            torch.save(latent_frames, os.path.join(save_path, 'latents', f'noisy_latents_{t}.pt'))
+        if return_inverted_latents:
+            self.total_inverted_latents[f'noisy_latents_{t}'] = latent_frames.clone()
+
+        return latent_frames
+
+    @torch.no_grad()
+    def ddim_sample(self, x, cond, batch_size):
+        timesteps = self.scheduler.timesteps
+        for i, t in enumerate(tqdm(timesteps)):
+            for b in range(0, x.shape[0], batch_size):
+                x_batch = x[b:b + batch_size]
+                model_input = x_batch
+                cond_batch = cond.repeat(x_batch.shape[0], 1, 1)
+
+                if self.sd_version == 'depth':
+                    depth_maps = torch.cat([self.depth_maps[b: b + batch_size]])
+                    model_input = torch.cat([x_batch, depth_maps],dim=1)
+
+                alpha_prod_t = self.scheduler.alphas_cumprod[t]
+                alpha_prod_t_prev = (
+                    self.scheduler.alphas_cumprod[timesteps[i + 1]]
+                    if i < len(timesteps) - 1
+                    else self.scheduler.final_alpha_cumprod
+                )
+                mu = alpha_prod_t ** 0.5
+                sigma = (1 - alpha_prod_t) ** 0.5
+                mu_prev = alpha_prod_t_prev ** 0.5
+                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+
+                eps = self.unet(model_input, t, encoder_hidden_states=cond_batch).sample if self.sd_version != 'ControlNet' \
+                    else self.controlnet_pred(x_batch, t, cond_batch, torch.cat([self.canny_cond[b: b + batch_size]]))
+
+                pred_x0 = (x_batch - sigma * eps) / mu
+                x[b:b + batch_size] = mu_prev * pred_x0 + sigma_prev * eps
+        return x
+
+    @torch.no_grad()
+    def extract_latents(self,
+                        num_steps,
+                        save_path,
+                        batch_size,
+                        timesteps_to_save,
+                        inversion_prompt='',
+                        reconstruct=False):
+        self.scheduler.set_timesteps(num_steps)
+        cond = self.get_text_embeds(inversion_prompt, "")[1].unsqueeze(0)
+        latent_frames = self.latents
+        print("latent_frames", latent_frames.shape)
+
+        inverted_x = self.ddim_inversion(cond,
+                                         latent_frames,
+                                         save_path,
+                                         batch_size=batch_size,
+                                         save_latents=True if save_path else False,
+                                         timesteps_to_save=timesteps_to_save)
+
+
+
+        # print("total_inverted_latents", len(total_inverted_latents.keys()))
+
+        if reconstruct:
+            latent_reconstruction = self.ddim_sample(inverted_x, cond, batch_size=batch_size)
+
+            rgb_reconstruction = self.decode_latents(latent_reconstruction)
+            return self.frames, self.latents, self.total_inverted_latents, rgb_reconstruction
+
+        return self.frames, self.latents, self.total_inverted_latents, None
+
+
+
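
For reference, the step that both `ddim_inversion` and `ddim_sample` compute is the deterministic DDIM update, written with the cumulative schedule the code reads from `scheduler.alphas_cumprod` (in the code, `mu = alpha_prod_t ** 0.5` and `sigma = (1 - alpha_prod_t) ** 0.5`). One inversion step maps the latent from timestep t-1 to t:

\hat{x}_0 = \frac{x_{t-1} - \sqrt{1 - \bar\alpha_{t-1}}\,\epsilon_\theta(x_{t-1}, t)}{\sqrt{\bar\alpha_{t-1}}},
\qquad
x_t = \sqrt{\bar\alpha_t}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_{t-1}, t)

`ddim_sample` applies the same relation in the opposite direction to reconstruct clean latents from the inverted ones.

A minimal driver sketch for the updated module, assuming `Preprocess` is imported from preprocess_utils.py as committed here; the config values (SD version, data path, frame count, prompt, step counts) are illustrative placeholders, not part of this commit:

import torch
from preprocess_utils import Preprocess

# Config keys read by Preprocess: sd_version, data_path, n_frames, frames.
opt = {
    "sd_version": "2.1",            # '1.5', '2.0', '2.1', 'ControlNet', or 'depth'
    "data_path": "data/my_video",   # placeholder folder of 00000.png / 00000.jpg frames
    "n_frames": 40,
    "frames": None,                 # or a list of pre-loaded PIL frames
}

model = Preprocess(device="cuda", opt=opt)

# save_path=None skips writing noisy latents to disk; the returned dict of
# inverted latents is only populated when opt["frames"] is provided
# (see return_inverted_latents in ddim_inversion).
frames, latents, inverted_latents, rgb_rec = model.extract_latents(
    num_steps=50,
    save_path=None,
    batch_size=8,
    timesteps_to_save=None,         # defaults to all inversion timesteps
    inversion_prompt="",
    reconstruct=False,
)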