Delik committed on
Commit 668dbc7 · verified · 1 Parent(s): 276eac5

Delete code

code/LIA_Model.py DELETED
@@ -1,39 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from networks.encoder import Encoder
4
- from networks.styledecoder import Synthesis
5
-
6
- # This part is modified from: https://github.com/wyhsirius/LIA
7
- class LIA_Model(torch.nn.Module):
8
- def __init__(self, size = 256, style_dim = 512, motion_dim = 20, channel_multiplier=1, blur_kernel=[1, 3, 3, 1], fusion_type=''):
9
- super().__init__()
10
- self.enc = Encoder(size, style_dim, motion_dim, fusion_type)
11
- self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)
12
-
13
- def get_start_direction_code(self, x_start, x_target, x_face, x_aug):
14
- enc_dic = self.enc(x_start, x_target, x_face, x_aug)
15
-
16
- wa, alpha, feats = enc_dic['h_source'], enc_dic['h_motion'], enc_dic['feats']
17
-
18
- return wa, alpha, feats
19
-
20
- def render(self, start, direction, feats):
21
- return self.dec(start, direction, feats)
22
-
23
- def load_lightning_model(self, lia_pretrained_model_path):
24
- selfState = self.state_dict()
25
-
26
- state = torch.load(lia_pretrained_model_path, map_location='cpu')
27
- for name, param in state.items():
28
- origName = name
29
-
30
- if name not in selfState:
31
- name = name.replace("lia.", "")
32
- if name not in selfState:
33
- print("%s is not in the model."%origName)
34
- # You can ignore those errors as some parameters are only used for training
35
- continue
36
- if selfState[name].size() != state[origName].size():
37
- print("Wrong parameter length: %s, model: %s, loaded: %s"%(origName, selfState[name].size(), state[origName].size()))
38
- continue
39
- selfState[name].copy_(param)
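
For context, demo.py and app.py (deleted in this same commit, below) drive this class roughly as follows. This is a minimal sketch, not part of the commit: the checkpoint path, device, and the 256x256 / [-1, 1] preprocessing are taken from those scripts, and the zero tensor merely stands in for a real preprocessed image.

import torch
from LIA_Model import LIA_Model

device = 'cuda'
lia = LIA_Model(motion_dim=20, fusion_type='weighted_sum')
lia.load_lightning_model('./ckpts/stage1.ckpt')  # Stage-1 weights, path as used in app.py
lia.to(device)

# stand-in for a 256x256 RGB image normalized to [-1, 1]
img = torch.zeros(1, 3, 256, 256).to(device)
start, direction, feats = lia.get_start_direction_code(img, img, img, img)
frame = lia.render(start, direction, feats)  # reconstructed frame; app.py clamps it to [-1, 1] before saving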
 
code/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .LIA_Model import *
 
 
code/app.py DELETED
@@ -1,403 +0,0 @@
1
-
2
- import argparse
3
- from datetime import datetime
4
- from pathlib import Path
5
- import numpy as np
6
- import torch
7
- from PIL import Image
8
- import gradio as gr
9
- import shutil
10
- import librosa
11
- import python_speech_features
12
- import time
13
- from LIA_Model import LIA_Model
14
- import os
15
- from tqdm import tqdm
16
- import argparse
17
- import numpy as np
18
- from torchvision import transforms
19
- from templates import *
20
- import argparse
21
- import shutil
22
- from moviepy.editor import *
23
- import librosa
24
- import python_speech_features
25
- import importlib.util
26
- import time
27
- import os
28
- import time
29
- import numpy as np
30
-
31
-
32
-
33
- # Disable Gradio analytics to avoid network-related issues
34
- gr.analytics_enabled = False
35
-
36
-
37
- def check_package_installed(package_name):
38
- package_spec = importlib.util.find_spec(package_name)
39
- if package_spec is None:
40
- print(f"{package_name} is not installed.")
41
- return False
42
- else:
43
- print(f"{package_name} is installed.")
44
- return True
45
-
46
- def frames_to_video(input_path, audio_path, output_path, fps=25):
47
- image_files = [os.path.join(input_path, img) for img in sorted(os.listdir(input_path))]
48
- clips = [ImageClip(m).set_duration(1/fps) for m in image_files]
49
- video = concatenate_videoclips(clips, method="compose")
50
-
51
- audio = AudioFileClip(audio_path)
52
- final_video = video.set_audio(audio)
53
- final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
54
-
55
- def load_image(filename, size):
56
- img = Image.open(filename).convert('RGB')
57
- img = img.resize((size, size))
58
- img = np.asarray(img)
59
- img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
60
- return img / 255.0
61
-
62
- def img_preprocessing(img_path, size):
63
- img = load_image(img_path, size) # [0, 1]
64
- img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
65
- imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
66
- return imgs_norm
67
-
68
- def saved_image(img_tensor, img_path):
69
- toPIL = transforms.ToPILImage()
70
- img = toPIL(img_tensor.detach().cpu().squeeze(0)) # squeeze(0) removes the batch dimension
71
- img.save(img_path)
72
-
73
- def main(args):
74
- frames_result_saved_path = os.path.join(args.result_path, 'frames')
75
- os.makedirs(frames_result_saved_path, exist_ok=True)
76
- test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
77
- audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
78
- predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
79
- predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')
80
-
81
- #======Loading Stage 1 model=========
82
- lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
83
- lia.load_lightning_model(args.stage1_checkpoint_path)
84
- lia.to(args.device)
85
- #============================
86
-
87
- conf = ffhq256_autoenc()
88
- conf.seed = args.seed
89
- conf.decoder_layers = args.decoder_layers
90
- conf.infer_type = args.infer_type
91
- conf.motion_dim = args.motion_dim
92
-
93
- if args.infer_type == 'mfcc_full_control':
94
- conf.face_location=True
95
- conf.face_scale=True
96
- conf.mfcc = True
97
- elif args.infer_type == 'mfcc_pose_only':
98
- conf.face_location=False
99
- conf.face_scale=False
100
- conf.mfcc = True
101
- elif args.infer_type == 'hubert_pose_only':
102
- conf.face_location=False
103
- conf.face_scale=False
104
- conf.mfcc = False
105
- elif args.infer_type == 'hubert_audio_only':
106
- conf.face_location=False
107
- conf.face_scale=False
108
- conf.mfcc = False
109
- elif args.infer_type == 'hubert_full_control':
110
- conf.face_location=True
111
- conf.face_scale=True
112
- conf.mfcc = False
113
- else:
114
- print('Type NOT Found!')
115
- exit(0)
116
-
117
- if not os.path.exists(args.test_image_path):
118
- print(f'{args.test_image_path} does not exist!')
119
- exit(0)
120
-
121
- if not os.path.exists(args.test_audio_path):
122
- print(f'{args.test_audio_path} does not exist!')
123
- exit(0)
124
-
125
- img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
126
- one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(img_source, img_source, img_source, img_source)
127
-
128
- #======Loading Stage 2 model=========
129
- model = LitModel(conf)
130
- state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
131
- model.load_state_dict(state, strict=True)
132
- model.ema_model.eval()
133
- model.ema_model.to(args.device)
134
- #=================================
135
-
136
- #======Audio Input=========
137
- if conf.infer_type.startswith('mfcc'):
138
- # MFCC features
139
- wav, sr = librosa.load(args.test_audio_path, sr=16000)
140
- input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
141
- d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
142
- d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
143
- audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
144
- frame_start, frame_end = 0, int(audio_driven_obj.shape[0]/4)
145
- audio_start, audio_end = int(frame_start * 4), int(frame_end * 4) # video is fixed at 25 fps and MFCC features at 100 Hz, i.e. 4 audio steps per video frame
146
-
147
- audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
148
-
149
- elif conf.infer_type.startswith('hubert'):
150
- # Hubert features
151
- if not os.path.exists(args.test_hubert_path):
152
-
153
- if not check_package_installed('transformers'):
154
- print('Please install transformers module first.')
155
- exit(0)
156
- hubert_model_path = './ckpts/chinese-hubert-large'
157
- if not os.path.exists(hubert_model_path):
158
- print('Please download the hubert weight into the ckpts path first.')
159
- exit(0)
160
- print('Audio features were not extracted in advance; extracting them online now, which will increase processing delay.')
161
-
162
- start_time = time.time()
163
-
164
- # load hubert model
165
- from transformers import Wav2Vec2FeatureExtractor, HubertModel
166
- audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
167
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
168
- audio_model.feature_extractor._freeze_parameters()
169
- audio_model.eval()
170
-
171
- # hubert model forward pass
172
- audio, sr = librosa.load(args.test_audio_path, sr=16000)
173
- input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
174
- input_values = input_values.to(args.device)
175
- ws_feats = []
176
- with torch.no_grad():
177
- outputs = audio_model(input_values, output_hidden_states=True)
178
- for i in range(len(outputs.hidden_states)):
179
- ws_feats.append(outputs.hidden_states[i].detach().cpu().numpy())
180
- ws_feat_obj = np.array(ws_feats)
181
- ws_feat_obj = np.squeeze(ws_feat_obj, 1)
182
- ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge') # align the audio length with video frame
183
-
184
- execution_time = time.time() - start_time
185
- print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")
186
-
187
- audio_driven_obj = ws_feat_obj
188
- else:
189
- print(f'Using audio feature from path: {args.test_hubert_path}')
190
- audio_driven_obj = np.load(args.test_hubert_path)
191
-
192
- frame_start, frame_end = 0, int(audio_driven_obj.shape[1]/2)
193
- audio_start, audio_end = int(frame_start * 2), int(frame_end * 2) # video is fixed at 25 fps and HuBERT features at 50 Hz, i.e. 2 audio steps per video frame
194
-
195
- audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
196
- #============================
197
-
198
- # Diffusion Noise
199
- noisyT = torch.randn((1,frame_end, args.motion_dim)).to(args.device)
200
-
201
- #======Inputs for Attribute Control=========
202
- if os.path.exists(args.pose_driven_path):
203
- pose_obj = np.load(args.pose_driven_path)
204
-
205
- if len(pose_obj.shape) != 2:
206
- print('please check your pose information. The shape must be like (T, 3).')
207
- exit(0)
208
- if pose_obj.shape[1] != 3:
209
- print('please check your pose information. The shape must be like (T, 3).')
210
- exit(0)
211
-
212
- if pose_obj.shape[0] >= frame_end:
213
- pose_obj = pose_obj[:frame_end,:]
214
- else:
215
- padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
216
- pose_obj = np.vstack((pose_obj, padding))
217
-
218
- pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90 # 90 is for normalization here
219
- else:
220
- yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
221
- pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
222
- roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
223
- pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
224
-
225
- pose_signal = torch.clamp(pose_signal, -1, 1)
226
-
227
- face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
228
- face_scae_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
229
- #===========================================
230
-
231
- start_time = time.time()
232
-
233
- #======Diffusion Denoising Process=========
234
- generated_directions = model.render(one_shot_lia_start, one_shot_lia_direction, audio_driven, face_location_signal, face_scae_signal, pose_signal, noisyT, args.step_T, control_flag=args.control_flag)
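- # generated_directions: one denoised motion code per output frame, same shape as noisyT (1, frame_end, motion_dim)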
235
- #=========================================
236
-
237
- execution_time = time.time() - start_time
238
- print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")
239
-
240
- generated_directions = generated_directions.detach().cpu().numpy()
241
-
242
- start_time = time.time()
243
- #======Rendering images frame-by-frame=========
244
- for pred_index in tqdm(range(generated_directions.shape[1])):
245
- ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to(args.device), feats)
246
- ori_img_recon = ori_img_recon.clamp(-1, 1)
247
- wav_pred = (ori_img_recon.detach() + 1) / 2
248
- saved_image(wav_pred, os.path.join(frames_result_saved_path, "%06d.png"%(pred_index)))
249
- #==============================================
250
-
251
- execution_time = time.time() - start_time
252
- print(f"Renderer Model: {execution_time:.2f} Seconds")
253
-
254
- frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)
255
-
256
- shutil.rmtree(frames_result_saved_path)
257
-
258
- # Enhancer
259
- if args.face_sr and check_package_installed('gfpgan'):
260
- from face_sr.face_enhancer import enhancer_list
261
- import imageio
262
-
263
- # Super-resolution
264
- imageio.mimsave(predicted_video_512_path+'.tmp.mp4', enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None), fps=float(25))
265
-
266
- # Merge audio and video
267
- video_clip = VideoFileClip(predicted_video_512_path+'.tmp.mp4')
268
- audio_clip = AudioFileClip(predicted_video_256_path)
269
- final_clip = video_clip.set_audio(audio_clip)
270
- final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
271
-
272
- os.remove(predicted_video_512_path+'.tmp.mp4')
273
-
274
- if args.face_sr:
275
- return predicted_video_256_path, predicted_video_512_path
276
- else:
277
- return predicted_video_256_path, predicted_video_256_path
278
-
279
- def generate_video(uploaded_img, uploaded_audio, infer_type,
280
- pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed):
281
- if uploaded_img is None or uploaded_audio is None:
282
- return None, gr.Markdown("Error: Input image or audio file is empty. Please check and upload both files.")
283
-
284
- model_mapping = {
285
- "mfcc_pose_only": "./ckpts/stage2_pose_only_mfcc.ckpt",
286
- "mfcc_full_control": "./ckpts/stage2_more_controllable_mfcc.ckpt",
287
- "hubert_audio_only": "./ckpts/stage2_audio_only_hubert.ckpt",
288
- "hubert_pose_only": "./ckpts/stage2_pose_only_hubert.ckpt",
289
- "hubert_full_control": "./ckpts/stage2_full_control_hubert.ckpt",
290
- }
291
-
292
- # if face_crop:
293
- # uploaded_img_path = Path(uploaded_img)
294
- # cropped_img_path = uploaded_img_path.with_name(uploaded_img_path.stem + "_crop" + uploaded_img_path.suffix)
295
- # crop_image(uploaded_img, cropped_img_path)
296
- # uploaded_img = str(cropped_img_path)
297
-
298
- # import pdb;pdb.set_trace()
299
-
300
- stage2_checkpoint_path = model_mapping.get(infer_type, "default_checkpoint.ckpt")
301
- try:
302
- args = argparse.Namespace(
303
- infer_type=infer_type,
304
- test_image_path=uploaded_img,
305
- test_audio_path=uploaded_audio,
306
- test_hubert_path='',
307
- result_path='./outputs/',
308
- stage1_checkpoint_path='./ckpts/stage1.ckpt',
309
- stage2_checkpoint_path=stage2_checkpoint_path,
310
- seed=seed,
311
- control_flag=True,
312
- pose_yaw=pose_yaw,
313
- pose_pitch=pose_pitch,
314
- pose_roll=pose_roll,
315
- face_location=face_location,
316
- pose_driven_path='not_supported_in_this_mode',
317
- face_scale=face_scale,
318
- step_T=step_T,
319
- image_size=256,
320
- device=device,
321
- motion_dim=20,
322
- decoder_layers=2,
323
- face_sr=face_sr
324
- )
325
-
326
- # Save the uploaded audio to the expected path
327
- # shutil.copy(uploaded_audio, args.test_audio_path)
328
-
329
- # Run the main function
330
- output_256_video_path, output_512_video_path = main(args)
331
-
332
- # Check if the output video file exists
333
- if not os.path.exists(output_256_video_path):
334
- return None, gr.Markdown("Error: Video generation failed. Please check your inputs and try again.")
335
- if output_256_video_path == output_512_video_path:
336
- return gr.Video(value=output_256_video_path), None, gr.Markdown("Video (256*256 only) generated successfully!")
337
- return gr.Video(value=output_256_video_path), gr.Video(value=output_512_video_path), gr.Markdown("Video generated successfully!")
338
-
339
- except Exception as e:
340
- return None, None, gr.Markdown(f"Error: An unexpected error occurred - {str(e)}")
341
-
342
- default_values = {
343
- "pose_yaw": 0,
344
- "pose_pitch": 0,
345
- "pose_roll": 0,
346
- "face_location": 0.5,
347
- "face_scale": 0.5,
348
- "step_T": 50,
349
- "seed": 0,
350
- "device": "cuda"
351
- }
352
-
353
- with gr.Blocks() as demo:
354
- gr.Markdown('# AniTalker')
355
- gr.Markdown('![]()')
356
- with gr.Row():
357
- with gr.Column():
358
- uploaded_img = gr.Image(type="filepath", label="Reference Image")
359
- uploaded_audio = gr.Audio(type="filepath", label="Input Audio")
360
- with gr.Column():
361
- output_video_256 = gr.Video(label="Generated Video (256)")
362
- output_video_512 = gr.Video(label="Generated Video (512)")
363
- output_message = gr.Markdown()
364
-
365
-
366
-
367
- generate_button = gr.Button("Generate Video")
368
-
369
- with gr.Accordion("Configuration", open=True):
370
- infer_type = gr.Dropdown(
371
- label="Inference Type",
372
- choices=['mfcc_pose_only', 'mfcc_full_control', 'hubert_audio_only', 'hubert_pose_only'],
373
- value='hubert_audio_only'
374
- )
375
- face_sr = gr.Checkbox(label="Enable Face Super-Resolution (512*512)", value=False)
376
- # face_crop = gr.Checkbox(label="Face Crop (Dlib)", value=False)
377
- # face_crop = False # TODO
378
- seed = gr.Number(label="Seed", value=default_values["seed"])
379
- pose_yaw = gr.Slider(label="pose_yaw", minimum=-1, maximum=1, value=default_values["pose_yaw"])
380
- pose_pitch = gr.Slider(label="pose_pitch", minimum=-1, maximum=1, value=default_values["pose_pitch"])
381
- pose_roll = gr.Slider(label="pose_roll", minimum=-1, maximum=1, value=default_values["pose_roll"])
382
- face_location = gr.Slider(label="face_location", minimum=0, maximum=1, value=default_values["face_location"])
383
- face_scale = gr.Slider(label="face_scale", minimum=0, maximum=1, value=default_values["face_scale"])
384
- step_T = gr.Slider(label="step_T", minimum=1, maximum=100, step=1, value=default_values["step_T"])
385
- device = gr.Radio(label="Device", choices=["cuda", "cpu"], value=default_values["device"])
386
-
387
-
388
- generate_button.click(
389
- generate_video,
390
- inputs=[
391
- uploaded_img, uploaded_audio, infer_type,
392
- pose_yaw, pose_pitch, pose_roll, face_location, face_scale, step_T, device, face_sr, seed
393
- ],
394
- outputs=[output_video_256, output_video_512, output_message]
395
- )
396
-
397
- if __name__ == '__main__':
398
- parser = argparse.ArgumentParser(description='AniTalker')
399
- parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
400
- parser.add_argument('--server_port', type=int, default=3001, help='Server port')
401
- args = parser.parse_args()
402
-
403
- demo.launch(server_name=args.server_name, server_port=args.server_port)
 
code/choices.py DELETED
@@ -1,179 +0,0 @@
1
- from enum import Enum
2
- from torch import nn
3
-
4
-
5
- class TrainMode(Enum):
6
- # manipulate mode = training the classifier
7
- manipulate = 'manipulate'
8
- # default training mode!
9
- diffusion = 'diffusion'
10
- # default latent training mode!
11
- # fitting a DDPM to a given latent
12
- latent_diffusion = 'latentdiffusion'
13
-
14
- def is_manipulate(self):
15
- return self in [
16
- TrainMode.manipulate,
17
- ]
18
-
19
- def is_diffusion(self):
20
- return self in [
21
- TrainMode.diffusion,
22
- TrainMode.latent_diffusion,
23
- ]
24
-
25
- def is_autoenc(self):
26
- # the network possibly does autoencoding
27
- return self in [
28
- TrainMode.diffusion,
29
- ]
30
-
31
- def is_latent_diffusion(self):
32
- return self in [
33
- TrainMode.latent_diffusion,
34
- ]
35
-
36
- def use_latent_net(self):
37
- return self.is_latent_diffusion()
38
-
39
- def require_dataset_infer(self):
40
- """
41
- Whether training in this mode requires the latent variables to be available.
42
- """
43
- # this will precalculate all the latents before hand
44
- # and the dataset will be all the predicted latents
45
- return self in [
46
- TrainMode.latent_diffusion,
47
- TrainMode.manipulate,
48
- ]
49
-
50
-
51
- class ManipulateMode(Enum):
52
- """
53
- how to train the classifier to manipulate
54
- """
55
- # train on whole celeba attr dataset
56
- celebahq_all = 'celebahq_all'
57
- # celeba with D2C's crop
58
- d2c_fewshot = 'd2cfewshot'
59
- d2c_fewshot_allneg = 'd2cfewshotallneg'
60
-
61
- def is_celeba_attr(self):
62
- return self in [
63
- ManipulateMode.d2c_fewshot,
64
- ManipulateMode.d2c_fewshot_allneg,
65
- ManipulateMode.celebahq_all,
66
- ]
67
-
68
- def is_single_class(self):
69
- return self in [
70
- ManipulateMode.d2c_fewshot,
71
- ManipulateMode.d2c_fewshot_allneg,
72
- ]
73
-
74
- def is_fewshot(self):
75
- return self in [
76
- ManipulateMode.d2c_fewshot,
77
- ManipulateMode.d2c_fewshot_allneg,
78
- ]
79
-
80
- def is_fewshot_allneg(self):
81
- return self in [
82
- ManipulateMode.d2c_fewshot_allneg,
83
- ]
84
-
85
-
86
- class ModelType(Enum):
87
- """
88
- Kinds of the backbone models
89
- """
90
-
91
- # unconditional ddpm
92
- ddpm = 'ddpm'
93
- # autoencoding ddpm cannot do unconditional generation
94
- autoencoder = 'autoencoder'
95
-
96
- def has_autoenc(self):
97
- return self in [
98
- ModelType.autoencoder,
99
- ]
100
-
101
- def can_sample(self):
102
- return self in [ModelType.ddpm]
103
-
104
-
105
- class ModelName(Enum):
106
- """
107
- List of all supported model classes
108
- """
109
-
110
- beatgans_ddpm = 'beatgans_ddpm'
111
- beatgans_autoenc = 'beatgans_autoenc'
112
-
113
-
114
- class ModelMeanType(Enum):
115
- """
116
- Which type of output the model predicts.
117
- """
118
-
119
- eps = 'eps' # the model predicts epsilon
120
-
121
-
122
- class ModelVarType(Enum):
123
- """
124
- What is used as the model's output variance.
125
-
126
- The LEARNED_RANGE option has been added to allow the model to predict
127
- values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128
- """
129
-
130
- # posterior beta_t
131
- fixed_small = 'fixed_small'
132
- # beta_t
133
- fixed_large = 'fixed_large'
134
-
135
-
136
- class LossType(Enum):
137
- mse = 'mse' # use raw MSE loss (and KL when learning variances)
138
- l1 = 'l1'
139
-
140
-
141
- class GenerativeType(Enum):
142
- """
143
- How a sample is generated
144
- """
145
-
146
- ddpm = 'ddpm'
147
- ddim = 'ddim'
148
-
149
-
150
- class OptimizerType(Enum):
151
- adam = 'adam'
152
- adamw = 'adamw'
153
-
154
-
155
- class Activation(Enum):
156
- none = 'none'
157
- relu = 'relu'
158
- lrelu = 'lrelu'
159
- silu = 'silu'
160
- tanh = 'tanh'
161
-
162
- def get_act(self):
163
- if self == Activation.none:
164
- return nn.Identity()
165
- elif self == Activation.relu:
166
- return nn.ReLU()
167
- elif self == Activation.lrelu:
168
- return nn.LeakyReLU(negative_slope=0.2)
169
- elif self == Activation.silu:
170
- return nn.SiLU()
171
- elif self == Activation.tanh:
172
- return nn.Tanh()
173
- else:
174
- raise NotImplementedError()
175
-
176
-
177
- class ManipulateLossType(Enum):
178
- bce = 'bce'
179
- mse = 'mse'
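
A small usage note (not part of the commit): the Activation enum is consumed by calling get_act() to obtain the corresponding torch module, roughly as follows.

import torch
from choices import Activation

act = Activation.silu.get_act()   # returns an nn.SiLU() instance
hidden = act(torch.randn(4, 8))   # usable like any torch.nn module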
 
code/config.py DELETED
@@ -1,388 +0,0 @@
1
- from model.unet import ScaleAt
2
- from model.latentnet import *
3
- from diffusion.resample import UniformSampler
4
- from diffusion.diffusion import space_timesteps
5
- from typing import Tuple
6
-
7
- from torch.utils.data import DataLoader
- from torch import distributed  # explicit import; make_loader below calls distributed.is_initialized()
8
-
9
- from config_base import BaseConfig
10
- from diffusion import *
11
- from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
12
- from model import *
13
- from choices import *
14
- from multiprocessing import get_context
15
- import os
16
- from dataset_util import *
17
- from torch.utils.data.distributed import DistributedSampler
18
- from dataset import LatentDataLoader
19
-
20
- @dataclass
21
- class PretrainConfig(BaseConfig):
22
- name: str
23
- path: str
24
-
25
-
26
- @dataclass
27
- class TrainConfig(BaseConfig):
28
- # random seed
29
- seed: int = 0
30
- train_mode: TrainMode = TrainMode.diffusion
31
- train_cond0_prob: float = 0
32
- train_pred_xstart_detach: bool = True
33
- train_interpolate_prob: float = 0
34
- train_interpolate_img: bool = False
35
- manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
36
- manipulate_cls: str = None
37
- manipulate_shots: int = None
38
- manipulate_loss: ManipulateLossType = ManipulateLossType.bce
39
- manipulate_znormalize: bool = False
40
- manipulate_seed: int = 0
41
- accum_batches: int = 1
42
- autoenc_mid_attn: bool = True
43
- batch_size: int = 16
44
- batch_size_eval: int = None
45
- beatgans_gen_type: GenerativeType = GenerativeType.ddim
46
- beatgans_loss_type: LossType = LossType.mse
47
- beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
48
- beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
49
- beatgans_rescale_timesteps: bool = False
50
- latent_infer_path: str = None
51
- latent_znormalize: bool = False
52
- latent_gen_type: GenerativeType = GenerativeType.ddim
53
- latent_loss_type: LossType = LossType.mse
54
- latent_model_mean_type: ModelMeanType = ModelMeanType.eps
55
- latent_model_var_type: ModelVarType = ModelVarType.fixed_large
56
- latent_rescale_timesteps: bool = False
57
- latent_T_eval: int = 1_000
58
- latent_clip_sample: bool = False
59
- latent_beta_scheduler: str = 'linear'
60
- beta_scheduler: str = 'linear'
61
- data_name: str = ''
62
- data_val_name: str = None
63
- diffusion_type: str = None
64
- dropout: float = 0.1
65
- ema_decay: float = 0.9999
66
- eval_num_images: int = 5_000
67
- eval_every_samples: int = 200_000
68
- eval_ema_every_samples: int = 200_000
69
- fid_use_torch: bool = True
70
- fp16: bool = False
71
- grad_clip: float = 1
72
- img_size: int = 64
73
- lr: float = 0.0001
74
- optimizer: OptimizerType = OptimizerType.adam
75
- weight_decay: float = 0
76
- model_conf: ModelConfig = None
77
- model_name: ModelName = None
78
- model_type: ModelType = None
79
- net_attn: Tuple[int] = None
80
- net_beatgans_attn_head: int = 1
81
- # not necessarily the same as the number of style channels
82
- net_beatgans_embed_channels: int = 512
83
- net_resblock_updown: bool = True
84
- net_enc_use_time: bool = False
85
- net_enc_pool: str = 'adaptivenonzero'
86
- net_beatgans_gradient_checkpoint: bool = False
87
- net_beatgans_resnet_two_cond: bool = False
88
- net_beatgans_resnet_use_zero_module: bool = True
89
- net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
90
- net_beatgans_resnet_cond_channels: int = None
91
- net_ch_mult: Tuple[int] = None
92
- net_ch: int = 64
93
- net_enc_attn: Tuple[int] = None
94
- net_enc_k: int = None
95
- # number of resblocks for the encoder (half-unet)
96
- net_enc_num_res_blocks: int = 2
97
- net_enc_channel_mult: Tuple[int] = None
98
- net_enc_grad_checkpoint: bool = False
99
- net_autoenc_stochastic: bool = False
100
- net_latent_activation: Activation = Activation.silu
101
- net_latent_channel_mult: Tuple[int] = (1, 2, 4)
102
- net_latent_condition_bias: float = 0
103
- net_latent_dropout: float = 0
104
- net_latent_layers: int = None
105
- net_latent_net_last_act: Activation = Activation.none
106
- net_latent_net_type: LatentNetType = LatentNetType.none
107
- net_latent_num_hid_channels: int = 1024
108
- net_latent_num_time_layers: int = 2
109
- net_latent_skip_layers: Tuple[int] = None
110
- net_latent_time_emb_channels: int = 64
111
- net_latent_use_norm: bool = False
112
- net_latent_time_last_act: bool = False
113
- net_num_res_blocks: int = 2
114
- # number of resblocks for the UNET
115
- net_num_input_res_blocks: int = None
116
- net_enc_num_cls: int = None
117
- num_workers: int = 4
118
- parallel: bool = False
119
- postfix: str = ''
120
- sample_size: int = 64
121
- sample_every_samples: int = 20_000
122
- save_every_samples: int = 100_000
123
- style_ch: int = 512
124
- T_eval: int = 1_000
125
- T_sampler: str = 'uniform'
126
- T: int = 1_000
127
- total_samples: int = 10_000_000
128
- warmup: int = 0
129
- pretrain: PretrainConfig = None
130
- continue_from: PretrainConfig = None
131
- eval_programs: Tuple[str] = None
132
- # if present load the checkpoint from this path instead
133
- eval_path: str = None
134
- base_dir: str = 'checkpoints'
135
- use_cache_dataset: bool = False
136
- data_cache_dir: str = os.path.expanduser('~/cache')
137
- work_cache_dir: str = os.path.expanduser('~/mycache')
138
- # to be overridden
139
- name: str = ''
140
-
141
- def __post_init__(self):
142
- self.batch_size_eval = self.batch_size_eval or self.batch_size
143
- self.data_val_name = self.data_val_name or self.data_name
144
-
145
- def scale_up_gpus(self, num_gpus, num_nodes=1):
146
- self.eval_ema_every_samples *= num_gpus * num_nodes
147
- self.eval_every_samples *= num_gpus * num_nodes
148
- self.sample_every_samples *= num_gpus * num_nodes
149
- self.batch_size *= num_gpus * num_nodes
150
- self.batch_size_eval *= num_gpus * num_nodes
151
- return self
152
-
153
- @property
154
- def batch_size_effective(self):
155
- return self.batch_size * self.accum_batches
156
-
157
- @property
158
- def fid_cache(self):
159
- # we try to use the local dirs to reduce the load over network drives
160
- # hopefully, this would reduce the disconnection problems with sshfs
161
- return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
162
-
163
- @property
164
- def data_path(self):
165
- # may use the cache dir
166
- path = data_paths[self.data_name]
167
- if self.use_cache_dataset and path is not None:
168
- path = use_cached_dataset_path(
169
- path, f'{self.data_cache_dir}/{self.data_name}')
170
- return path
171
-
172
- @property
173
- def logdir(self):
174
- return f'{self.base_dir}/{self.name}'
175
-
176
- @property
177
- def generate_dir(self):
178
- # we try to use the local dirs to reduce the load over network drives
179
- # hopefully, this would reduce the disconnection problems with sshfs
180
- return f'{self.work_cache_dir}/gen_images/{self.name}'
181
-
182
- def _make_diffusion_conf(self, T=None):
183
- if self.diffusion_type == 'beatgans':
184
- # can use T < self.T for evaluation
185
- # follows the guided-diffusion repo conventions
186
- # t's are evenly spaced
187
- if self.beatgans_gen_type == GenerativeType.ddpm:
188
- section_counts = [T]
189
- elif self.beatgans_gen_type == GenerativeType.ddim:
190
- section_counts = f'ddim{T}'
191
- else:
192
- raise NotImplementedError()
193
-
194
- return SpacedDiffusionBeatGansConfig(
195
- gen_type=self.beatgans_gen_type,
196
- model_type=self.model_type,
197
- betas=get_named_beta_schedule(self.beta_scheduler, self.T),
198
- model_mean_type=self.beatgans_model_mean_type,
199
- model_var_type=self.beatgans_model_var_type,
200
- loss_type=self.beatgans_loss_type,
201
- rescale_timesteps=self.beatgans_rescale_timesteps,
202
- use_timesteps=space_timesteps(num_timesteps=self.T,
203
- section_counts=section_counts),
204
- fp16=self.fp16,
205
- )
206
- else:
207
- raise NotImplementedError()
208
-
209
- def _make_latent_diffusion_conf(self, T=None):
210
- # can use T < self.T for evaluation
211
- # follows the guided-diffusion repo conventions
212
- # t's are evenly spaced
213
- if self.latent_gen_type == GenerativeType.ddpm:
214
- section_counts = [T]
215
- elif self.latent_gen_type == GenerativeType.ddim:
216
- section_counts = f'ddim{T}'
217
- else:
218
- raise NotImplementedError()
219
-
220
- return SpacedDiffusionBeatGansConfig(
221
- train_pred_xstart_detach=self.train_pred_xstart_detach,
222
- gen_type=self.latent_gen_type,
223
- # latent's model is always ddpm
224
- model_type=ModelType.ddpm,
225
- # latent shares the beta scheduler and full T
226
- betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
227
- model_mean_type=self.latent_model_mean_type,
228
- model_var_type=self.latent_model_var_type,
229
- loss_type=self.latent_loss_type,
230
- rescale_timesteps=self.latent_rescale_timesteps,
231
- use_timesteps=space_timesteps(num_timesteps=self.T,
232
- section_counts=section_counts),
233
- fp16=self.fp16,
234
- )
235
-
236
- @property
237
- def model_out_channels(self):
238
- return 3
239
-
240
- def make_T_sampler(self):
241
- if self.T_sampler == 'uniform':
242
- return UniformSampler(self.T)
243
- else:
244
- raise NotImplementedError()
245
-
246
- def make_diffusion_conf(self):
247
- return self._make_diffusion_conf(self.T)
248
-
249
- def make_eval_diffusion_conf(self):
250
- return self._make_diffusion_conf(T=self.T_eval)
251
-
252
- def make_latent_diffusion_conf(self):
253
- return self._make_latent_diffusion_conf(T=self.T)
254
-
255
- def make_latent_eval_diffusion_conf(self):
256
- # latent can have different eval T
257
- return self._make_latent_diffusion_conf(T=self.latent_T_eval)
258
-
259
- def make_dataset(self, path=None, **kwargs):
260
- return LatentDataLoader(self.window_size,
261
- self.frame_jpgs,
262
- self.lmd_feats_prefix,
263
- self.audio_prefix,
264
- self.raw_audio_prefix,
265
- self.motion_latents_prefix,
266
- self.pose_prefix,
267
- self.db_name,
268
- audio_hz=self.audio_hz)
269
-
270
- def make_loader(self,
271
- dataset,
272
- shuffle: bool,
273
- num_worker: bool = None,
274
- drop_last: bool = True,
275
- batch_size: int = None,
276
- parallel: bool = False):
277
- if parallel and distributed.is_initialized():
278
- # drop last to make sure that there is no added special indexes
279
- sampler = DistributedSampler(dataset,
280
- shuffle=shuffle,
281
- drop_last=True)
282
- else:
283
- sampler = None
284
- return DataLoader(
285
- dataset,
286
- batch_size=batch_size or self.batch_size,
287
- sampler=sampler,
288
- # with sampler, use the sample instead of this option
289
- shuffle=False if sampler else shuffle,
290
- num_workers=num_worker or self.num_workers,
291
- pin_memory=True,
292
- drop_last=drop_last,
293
- multiprocessing_context=get_context('fork'),
294
- )
295
-
296
- def make_model_conf(self):
297
- if self.model_name == ModelName.beatgans_ddpm:
298
- self.model_type = ModelType.ddpm
299
- self.model_conf = BeatGANsUNetConfig(
300
- attention_resolutions=self.net_attn,
301
- channel_mult=self.net_ch_mult,
302
- conv_resample=True,
303
- dims=2,
304
- dropout=self.dropout,
305
- embed_channels=self.net_beatgans_embed_channels,
306
- image_size=self.img_size,
307
- in_channels=3,
308
- model_channels=self.net_ch,
309
- num_classes=None,
310
- num_head_channels=-1,
311
- num_heads_upsample=-1,
312
- num_heads=self.net_beatgans_attn_head,
313
- num_res_blocks=self.net_num_res_blocks,
314
- num_input_res_blocks=self.net_num_input_res_blocks,
315
- out_channels=self.model_out_channels,
316
- resblock_updown=self.net_resblock_updown,
317
- use_checkpoint=self.net_beatgans_gradient_checkpoint,
318
- use_new_attention_order=False,
319
- resnet_two_cond=self.net_beatgans_resnet_two_cond,
320
- resnet_use_zero_module=self.
321
- net_beatgans_resnet_use_zero_module,
322
- )
323
- elif self.model_name in [
324
- ModelName.beatgans_autoenc,
325
- ]:
326
- cls = BeatGANsAutoencConfig
327
- # supports both autoenc and vaeddpm
328
- if self.model_name == ModelName.beatgans_autoenc:
329
- self.model_type = ModelType.autoencoder
330
- else:
331
- raise NotImplementedError()
332
-
333
- if self.net_latent_net_type == LatentNetType.none:
334
- latent_net_conf = None
335
- elif self.net_latent_net_type == LatentNetType.skip:
336
- latent_net_conf = MLPSkipNetConfig(
337
- num_channels=self.style_ch,
338
- skip_layers=self.net_latent_skip_layers,
339
- num_hid_channels=self.net_latent_num_hid_channels,
340
- num_layers=self.net_latent_layers,
341
- num_time_emb_channels=self.net_latent_time_emb_channels,
342
- activation=self.net_latent_activation,
343
- use_norm=self.net_latent_use_norm,
344
- condition_bias=self.net_latent_condition_bias,
345
- dropout=self.net_latent_dropout,
346
- last_act=self.net_latent_net_last_act,
347
- num_time_layers=self.net_latent_num_time_layers,
348
- time_last_act=self.net_latent_time_last_act,
349
- )
350
- else:
351
- raise NotImplementedError()
352
-
353
- self.model_conf = cls(
354
- attention_resolutions=self.net_attn,
355
- channel_mult=self.net_ch_mult,
356
- conv_resample=True,
357
- dims=2,
358
- dropout=self.dropout,
359
- embed_channels=self.net_beatgans_embed_channels,
360
- enc_out_channels=self.style_ch,
361
- enc_pool=self.net_enc_pool,
362
- enc_num_res_block=self.net_enc_num_res_blocks,
363
- enc_channel_mult=self.net_enc_channel_mult,
364
- enc_grad_checkpoint=self.net_enc_grad_checkpoint,
365
- enc_attn_resolutions=self.net_enc_attn,
366
- image_size=self.img_size,
367
- in_channels=3,
368
- model_channels=self.net_ch,
369
- num_classes=None,
370
- num_head_channels=-1,
371
- num_heads_upsample=-1,
372
- num_heads=self.net_beatgans_attn_head,
373
- num_res_blocks=self.net_num_res_blocks,
374
- num_input_res_blocks=self.net_num_input_res_blocks,
375
- out_channels=self.model_out_channels,
376
- resblock_updown=self.net_resblock_updown,
377
- use_checkpoint=self.net_beatgans_gradient_checkpoint,
378
- use_new_attention_order=False,
379
- resnet_two_cond=self.net_beatgans_resnet_two_cond,
380
- resnet_use_zero_module=self.
381
- net_beatgans_resnet_use_zero_module,
382
- latent_net_conf=latent_net_conf,
383
- resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
384
- )
385
- else:
386
- raise NotImplementedError(self.model_name)
387
-
388
- return self.model_conf
 
code/config_base.py DELETED
@@ -1,72 +0,0 @@
1
- import json
2
- import os
3
- from copy import deepcopy
4
- from dataclasses import dataclass
5
-
6
-
7
- @dataclass
8
- class BaseConfig:
9
- def clone(self):
10
- return deepcopy(self)
11
-
12
- def inherit(self, another):
13
- """inherit common keys from a given config"""
14
- common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15
- for k in common_keys:
16
- setattr(self, k, getattr(another, k))
17
-
18
- def propagate(self):
19
- """push down the configuration to all members"""
20
- for k, v in self.__dict__.items():
21
- if isinstance(v, BaseConfig):
22
- v.inherit(self)
23
- v.propagate()
24
-
25
- def save(self, save_path):
26
- """save config to json file"""
27
- dirname = os.path.dirname(save_path)
28
- if not os.path.exists(dirname):
29
- os.makedirs(dirname)
30
- conf = self.as_dict_jsonable()
31
- with open(save_path, 'w') as f:
32
- json.dump(conf, f)
33
-
34
- def load(self, load_path):
35
- """load json config"""
36
- with open(load_path) as f:
37
- conf = json.load(f)
38
- self.from_dict(conf)
39
-
40
- def from_dict(self, dict, strict=False):
41
- for k, v in dict.items():
42
- if not hasattr(self, k):
43
- if strict:
44
- raise ValueError(f"loading extra '{k}'")
45
- else:
46
- print(f"loading extra '{k}'")
47
- continue
48
- if isinstance(self.__dict__[k], BaseConfig):
49
- self.__dict__[k].from_dict(v)
50
- else:
51
- self.__dict__[k] = v
52
-
53
- def as_dict_jsonable(self):
54
- conf = {}
55
- for k, v in self.__dict__.items():
56
- if isinstance(v, BaseConfig):
57
- conf[k] = v.as_dict_jsonable()
58
- else:
59
- if jsonable(v):
60
- conf[k] = v
61
- else:
62
- # ignore not jsonable
63
- pass
64
- return conf
65
-
66
-
67
- def jsonable(x):
68
- try:
69
- json.dumps(x)
70
- return True
71
- except TypeError:
72
- return False
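
A brief sketch of the intended save/load round trip (illustrative only; ExampleConfig and the path are made up for this note):

from dataclasses import dataclass
from config_base import BaseConfig

@dataclass
class ExampleConfig(BaseConfig):
    lr: float = 1e-4
    name: str = 'demo'

conf = ExampleConfig(lr=3e-4)
conf.save('./checkpoints/demo/config.json')   # only JSON-serializable fields are written

restored = ExampleConfig()
restored.load('./checkpoints/demo/config.json')
assert restored.lr == 3e-4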
 
code/dataset.py DELETED
@@ -1,218 +0,0 @@
1
- import os
2
- import librosa
3
- from PIL import Image
4
- from torchvision import transforms
5
- import python_speech_features
6
- import random
7
- import os
8
- import numpy as np
9
- from tqdm import tqdm
10
- import torchvision
11
- import torchvision.transforms as transforms
12
- from PIL import Image
13
-
14
- class LatentDataLoader(object):
15
-
16
- def __init__(
17
- self,
18
- window_size,
19
- frame_jpgs,
20
- lmd_feats_prefix,
21
- audio_prefix,
22
- raw_audio_prefix,
23
- motion_latents_prefix,
24
- pose_prefix,
25
- db_name,
26
- video_fps=25,
27
- audio_hz=50,
28
- size=256,
29
- mfcc_mode=False,
30
- ):
31
- self.window_size = window_size
32
- self.lmd_feats_prefix = lmd_feats_prefix
33
- self.audio_prefix = audio_prefix
34
- self.pose_prefix = pose_prefix
35
- self.video_fps = video_fps
36
- self.audio_hz = audio_hz
37
- self.db_name = db_name
38
- self.raw_audio_prefix = raw_audio_prefix
39
- self.mfcc_mode = mfcc_mode
40
-
41
-
42
- self.transform = torchvision.transforms.Compose([
43
- transforms.Resize((size, size)),
44
- transforms.ToTensor(),
45
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
46
- )
47
-
48
- self.data = []
49
- for db_name in [ 'VoxCeleb2', 'HDTF' ]:
50
- db_png_path = os.path.join(frame_jpgs, db_name)
51
- for clip_name in tqdm(os.listdir(db_png_path)):
52
-
53
- item_dict = dict()
54
- item_dict['clip_name'] = clip_name
55
- item_dict['frame_count'] = len(list(os.listdir(os.path.join(frame_jpgs, db_name, clip_name))))
56
- item_dict['hubert_path'] = os.path.join(audio_prefix, db_name, clip_name +".npy")
57
- item_dict['wav_path'] = os.path.join(raw_audio_prefix, db_name, clip_name +".wav")
58
-
59
- item_dict['yaw_pitch_roll_path'] = os.path.join(pose_prefix, db_name, 'raw_videos_pose_yaw_pitch_roll', clip_name +".npy")
60
- if not os.path.exists(item_dict['yaw_pitch_roll_path']):
61
- print(f"{db_name}'s {clip_name} miss yaw_pitch_roll_path")
62
- continue
63
-
64
- item_dict['yaw_pitch_roll'] = np.load(item_dict['yaw_pitch_roll_path'])
65
- item_dict['yaw_pitch_roll'] = np.clip(item_dict['yaw_pitch_roll'], -90, 90) / 90.0
66
-
67
- if not os.path.exists(item_dict['wav_path']):
68
- print(f"{db_name}'s {clip_name} miss wav_path")
69
- continue
70
-
71
- if not os.path.exists(item_dict['hubert_path']):
72
- print(f"{db_name}'s {clip_name} miss hubert_path")
73
- continue
74
-
75
-
76
- if self.mfcc_mode:
77
- wav, sr = librosa.load(item_dict['wav_path'], sr=16000)
78
- input_values = python_speech_features.mfcc(signal=wav,samplerate=sr,numcep=13,winlen=0.025,winstep=0.01)
79
- d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
80
- d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
81
- input_values = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
82
- item_dict['hubert_obj'] = input_values
83
- else:
84
- item_dict['hubert_obj'] = np.load(item_dict['hubert_path'], mmap_mode='r')
85
- item_dict['lmd_path'] = os.path.join(lmd_feats_prefix, db_name, clip_name +".txt")
86
- item_dict['lmd_obj_full'] = self.read_landmark_info(item_dict['lmd_path'], upper_face=False)
87
-
88
- motion_start_path = os.path.join(motion_latents_prefix, db_name, 'motions', clip_name +".npy")
89
- motion_direction_path = os.path.join(motion_latents_prefix, db_name, 'directions', clip_name +".npy")
90
-
91
- if not os.path.exists(motion_start_path):
92
- print(f"{db_name}'s {clip_name} miss motion_start_path")
93
- continue
94
- if not os.path.exists(motion_direction_path):
95
- print(f"{db_name}'s {clip_name} miss motion_direction_path")
96
- continue
97
-
98
- item_dict['motion_start_obj'] = np.load(motion_start_path)
99
- item_dict['motion_direction_obj'] = np.load(motion_direction_path)
100
-
101
- if self.mfcc_mode:
102
- min_len = min(
103
- item_dict['lmd_obj_full'].shape[0],
104
- item_dict['yaw_pitch_roll'].shape[0],
105
- item_dict['motion_start_obj'].shape[0],
106
- item_dict['motion_direction_obj'].shape[0],
107
- int(item_dict['hubert_obj'].shape[0]/4),
108
- item_dict['frame_count']
109
- )
110
- item_dict['frame_count'] = min_len
111
- item_dict['hubert_obj'] = item_dict['hubert_obj'][:min_len*4,:]
112
- else:
113
- min_len = min(
114
- item_dict['lmd_obj_full'].shape[0],
115
- item_dict['yaw_pitch_roll'].shape[0],
116
- item_dict['motion_start_obj'].shape[0],
117
- item_dict['motion_direction_obj'].shape[0],
118
- int(item_dict['hubert_obj'].shape[1]/2),
119
- item_dict['frame_count']
120
- )
121
-
122
- item_dict['frame_count'] = min_len
123
- item_dict['hubert_obj'] = item_dict['hubert_obj'][:, :min_len*2, :]
124
-
125
- if min_len < self.window_size * self.video_fps + 5:
126
- continue
127
-
128
- self.data.append(item_dict)  # keep this clip; without the append the dataset would always stay empty
- print('Db count:', len(self.data))
129
-
130
- def get_single_image(self, image_path):
131
- img_source = Image.open(image_path).convert('RGB')
132
- img_source = self.transform(img_source)
133
- return img_source
134
-
135
- def get_multiple_ranges(self, lists, multi_ranges):
136
- # Ensure that multi_ranges is a list of tuples
137
- if not all(isinstance(item, tuple) and len(item) == 2 for item in multi_ranges):
138
- raise ValueError("multi_ranges must be a list of (start, end) tuples with exactly two elements each")
139
- extracted_elements = [lists[start:end] for start, end in multi_ranges]
140
- flat_list = [item for sublist in extracted_elements for item in sublist]
141
- return flat_list
142
-
143
-
144
- def read_landmark_info(self, lmd_path, upper_face=True):
145
- with open(lmd_path, 'r') as file:
146
- lmd_lines = file.readlines()
147
- lmd_lines.sort()
148
-
149
- total_lmd_obj = []
150
- for i, line in enumerate(lmd_lines):
151
- # Split the coordinates and filter out any empty strings
152
- coords = [c for c in line.strip().split(' ') if c]
153
- coords = coords[1:] # do not include the file name in the first row
154
- lmd_obj = []
155
- if upper_face:
156
- # Ensure that the coordinates are parsed as integers
157
- for coord_pair in self.get_multiple_ranges(coords, [(0, 3), (14, 27), (36, 48)]): # 28 points in total
158
- x, y = coord_pair.split('_')
159
- lmd_obj.append((int(x)/512, int(y)/512))
160
- else:
161
- for coord_pair in coords:
162
- x, y = coord_pair.split('_')
163
- lmd_obj.append((int(x)/512, int(y)/512))
164
- total_lmd_obj.append(lmd_obj)
165
-
166
- return np.array(total_lmd_obj, dtype=np.float32)
167
-
168
- def calculate_face_height(self, landmarks):
169
- forehead_center = (landmarks[ :, 21, :] + landmarks[:, 22, :]) / 2
170
- chin_bottom = landmarks[:, 8, :]
171
- distances = np.linalg.norm(forehead_center - chin_bottom, axis=1, keepdims=True)
172
- return distances
173
-
174
- def __getitem__(self, index):
175
-
176
- data_item = self.data[index]
177
- hubert_obj = data_item['hubert_obj']
178
- frame_count = data_item['frame_count']
179
- lmd_obj_full = data_item['lmd_obj_full']
180
- yaw_pitch_roll = data_item['yaw_pitch_roll']
181
- motion_start_obj = data_item['motion_start_obj']
182
- motion_direction_obj = data_item['motion_direction_obj']
183
-
184
- frame_end_index = random.randint(self.window_size * self.video_fps + 1, frame_count - 1)
185
- frame_start_index = frame_end_index - self.window_size * self.video_fps
186
- frame_hint_index = frame_start_index - 1
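- # frame_hint_index is the frame just before the sampled window; its motion latent is used as the conditioning start frame below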
187
-
188
- audio_start_index = int(frame_start_index * (self.audio_hz / self.video_fps))
189
- audio_end_index = int(frame_end_index * (self.audio_hz / self.video_fps))
190
-
191
- if self.mfcc_mode:
192
- audio_feats = hubert_obj[audio_start_index:audio_end_index, :]
193
- else:
194
- audio_feats = hubert_obj[:, audio_start_index:audio_end_index, :]
195
-
196
- lmd_obj_full = lmd_obj_full[frame_hint_index:frame_end_index, :]
197
-
198
- yaw_pitch_roll = yaw_pitch_roll[frame_start_index:frame_end_index, :]
199
-
200
- motion_start = motion_start_obj[frame_hint_index]
201
- motion_direction_start = motion_direction_obj[frame_hint_index]
202
- motion_direction = motion_direction_obj[frame_start_index:frame_end_index, :]
203
-
204
-
205
-
206
- return {
207
- 'motion_start': motion_start,
208
- 'motion_direction': motion_direction,
209
- 'audio_feats': audio_feats,
210
- 'face_location': lmd_obj_full[1:, 30, 0], # '1:' drops the hint frame used as the driving reference; 30 is the nose-tip landmark; 0 selects the x coordinate
211
- 'face_scale': self.calculate_face_height(lmd_obj_full[1:,:,:]),
212
- 'yaw_pitch_roll': yaw_pitch_roll,
213
- 'motion_direction_start': motion_direction_start,
214
- }
215
-
216
- def __len__(self):
217
- return len(self.data)
218
-
 
code/dataset_util.py DELETED
@@ -1,13 +0,0 @@
1
- import shutil
2
- import os
3
- from dist_utils import *
4
-
5
-
6
- def use_cached_dataset_path(source_path, cache_path):
7
- if get_rank() == 0:
8
- if not os.path.exists(cache_path):
9
- # shutil.rmtree(cache_path)
10
- print(f'copying the data: {source_path} to {cache_path}')
11
- shutil.copytree(source_path, cache_path)
12
- barrier()
13
- return cache_path
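
config.py's data_path property calls this helper as below; the paths here are placeholders, not values from the commit.

from dataset_util import use_cached_dataset_path

# rank 0 copies the dataset into the local cache once, then every rank returns the cached path
local_path = use_cached_dataset_path('/remote/datasets/HDTF', '/home/user/cache/HDTF')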
 
code/demo.py DELETED
@@ -1,295 +0,0 @@
1
- from LIA_Model import LIA_Model
2
- import torch
3
- import numpy as np
4
- import os
5
- from PIL import Image
6
- from tqdm import tqdm
7
- import argparse
8
- import numpy as np
9
- from torchvision import transforms
10
- from templates import *
11
- import argparse
12
- import shutil
13
- from moviepy.editor import *
14
- import librosa
15
- import python_speech_features
16
- import importlib.util
17
- import time
18
-
19
- def check_package_installed(package_name):
20
- package_spec = importlib.util.find_spec(package_name)
21
- if package_spec is None:
22
- print(f"{package_name} is not installed.")
23
- return False
24
- else:
25
- print(f"{package_name} is installed.")
26
- return True
27
-
28
- def frames_to_video(input_path, audio_path, output_path, fps=25):
29
- image_files = [os.path.join(input_path, img) for img in sorted(os.listdir(input_path))]
30
- clips = [ImageClip(m).set_duration(1/fps) for m in image_files]
31
- video = concatenate_videoclips(clips, method="compose")
32
-
33
- audio = AudioFileClip(audio_path)
34
- final_video = video.set_audio(audio)
35
- final_video.write_videofile(output_path, fps=fps, codec='libx264', audio_codec='aac')
36
-
37
- def load_image(filename, size):
38
- img = Image.open(filename).convert('RGB')
39
- img = img.resize((size, size))
40
- img = np.asarray(img)
41
- img = np.transpose(img, (2, 0, 1)) # 3 x 256 x 256
42
- return img / 255.0
43
-
44
- def img_preprocessing(img_path, size):
45
- img = load_image(img_path, size) # [0, 1]
46
- img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
47
- imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
48
- return imgs_norm
49
-
50
- def saved_image(img_tensor, img_path):
51
- toPIL = transforms.ToPILImage()
52
- img = toPIL(img_tensor.detach().cpu().squeeze(0)) # squeeze(0) removes the batch dimension
53
- img.save(img_path)
54
-
55
- def main(args):
56
- frames_result_saved_path = os.path.join(args.result_path, 'frames')
57
- os.makedirs(frames_result_saved_path, exist_ok=True)
58
- test_image_name = os.path.splitext(os.path.basename(args.test_image_path))[0]
59
- audio_name = os.path.splitext(os.path.basename(args.test_audio_path))[0]
60
- predicted_video_256_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}.mp4')
61
- predicted_video_512_path = os.path.join(args.result_path, f'{test_image_name}-{audio_name}_SR.mp4')
62
-
63
- #======Loading Stage 1 model=========
64
- lia = LIA_Model(motion_dim=args.motion_dim, fusion_type='weighted_sum')
65
- lia.load_lightning_model(args.stage1_checkpoint_path)
66
- lia.to(args.device)
67
- #============================
68
-
69
- conf = ffhq256_autoenc()
70
- conf.seed = args.seed
71
- conf.decoder_layers = args.decoder_layers
72
- conf.infer_type = args.infer_type
73
- conf.motion_dim = args.motion_dim
74
-
75
- if args.infer_type == 'mfcc_full_control':
76
- conf.face_location=True
77
- conf.face_scale=True
78
- conf.mfcc = True
79
-
80
- elif args.infer_type == 'mfcc_pose_only':
81
- conf.face_location=False
82
- conf.face_scale=False
83
- conf.mfcc = True
84
-
85
- elif args.infer_type == 'hubert_pose_only':
86
- conf.face_location=False
87
- conf.face_scale=False
88
- conf.mfcc = False
89
-
90
- elif args.infer_type == 'hubert_audio_only':
91
- conf.face_location=False
92
- conf.face_scale=False
93
- conf.mfcc = False
94
-
95
- elif args.infer_type == 'hubert_full_control':
96
- conf.face_location=True
97
- conf.face_scale=True
98
- conf.mfcc = False
99
-
100
- else:
101
- print('Type NOT Found!')
102
- exit(0)
103
-
104
- if not os.path.exists(args.test_image_path):
105
- print(f'{args.test_image_path} does not exist!')
106
- exit(0)
107
-
108
- if not os.path.exists(args.test_audio_path):
109
- print(f'{args.test_audio_path} does not exist!')
110
- exit(0)
111
-
112
- img_source = img_preprocessing(args.test_image_path, args.image_size).to(args.device)
113
- one_shot_lia_start, one_shot_lia_direction, feats = lia.get_start_direction_code(img_source, img_source, img_source, img_source)
114
-
115
-
116
- #======Loading Stage 2 model=========
117
- model = LitModel(conf)
118
- state = torch.load(args.stage2_checkpoint_path, map_location='cpu')
119
- model.load_state_dict(state, strict=True)
120
- model.ema_model.eval()
121
- model.ema_model.to(args.device)
122
- #=================================
123
-
124
-
125
- #======Audio Input=========
126
- if conf.infer_type.startswith('mfcc'):
127
- # MFCC features
128
- wav, sr = librosa.load(args.test_audio_path, sr=16000)
129
- input_values = python_speech_features.mfcc(signal=wav, samplerate=sr, numcep=13, winlen=0.025, winstep=0.01)
130
- d_mfcc_feat = python_speech_features.base.delta(input_values, 1)
131
- d_mfcc_feat2 = python_speech_features.base.delta(input_values, 2)
132
- audio_driven_obj = np.hstack((input_values, d_mfcc_feat, d_mfcc_feat2))
133
- frame_start, frame_end = 0, int(audio_driven_obj.shape[0]/4)
134
- audio_start, audio_end = int(frame_start * 4), int(frame_end * 4) # The video frame is fixed to 25 hz and the audio is fixed to 100 hz
135
-
136
- audio_driven = torch.Tensor(audio_driven_obj[audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
137
-
138
- elif conf.infer_type.startswith('hubert'):
139
- # Hubert features
140
- if not os.path.exists(args.test_hubert_path):
141
-
142
- if not check_package_installed('transformers'):
143
- print('Please install transformers module first.')
144
- exit(0)
145
- hubert_model_path = 'ckpts/chinese-hubert-large'
146
- if not os.path.exists(hubert_model_path):
147
- print('Please download the hubert weight into the ckpts path first.')
148
- exit(0)
149
- print('Audio features were not extracted in advance; extracting them online now, which will increase processing delay.')
150
-
151
- start_time = time.time()
152
-
153
- # load hubert model
154
- from transformers import Wav2Vec2FeatureExtractor, HubertModel
155
- audio_model = HubertModel.from_pretrained(hubert_model_path).to(args.device)
156
- feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(hubert_model_path)
157
- audio_model.feature_extractor._freeze_parameters()
158
- audio_model.eval()
159
-
160
- # hubert model forward pass
161
- audio, sr = librosa.load(args.test_audio_path, sr=16000)
162
- input_values = feature_extractor(audio, sampling_rate=16000, padding=True, do_normalize=True, return_tensors="pt").input_values
163
- input_values = input_values.to(args.device)
164
- ws_feats = []
165
- with torch.no_grad():
166
- outputs = audio_model(input_values, output_hidden_states=True)
167
- for i in range(len(outputs.hidden_states)):
168
- ws_feats.append(outputs.hidden_states[i].detach().cpu().numpy())
169
- ws_feat_obj = np.array(ws_feats)
170
- ws_feat_obj = np.squeeze(ws_feat_obj, 1)
171
- ws_feat_obj = np.pad(ws_feat_obj, ((0, 0), (0, 1), (0, 0)), 'edge') # align the audio length with video frame
172
-
173
- execution_time = time.time() - start_time
174
- print(f"Extraction Audio Feature: {execution_time:.2f} Seconds")
175
-
176
- audio_driven_obj = ws_feat_obj
177
- else:
178
- print(f'Using audio feature from path: {args.test_hubert_path}')
179
- audio_driven_obj = np.load(args.test_hubert_path)
180
-
181
- frame_start, frame_end = 0, int(audio_driven_obj.shape[1]/2)
182
- audio_start, audio_end = int(frame_start * 2), int(frame_end * 2) # The video runs at a fixed 25 fps and the hubert features at 50 Hz, i.e. 2 feature rows per frame
183
-
184
- audio_driven = torch.Tensor(audio_driven_obj[:,audio_start:audio_end,:]).unsqueeze(0).float().to(args.device)
185
- #============================
186
-
187
- # Diffusion Noise
188
- noisyT = th.randn((1,frame_end, args.motion_dim)).to(args.device)
189
-
190
- #======Inputs for Attribute Control=========
191
- if os.path.exists(args.pose_driven_path):
192
- pose_obj = np.load(args.pose_driven_path)
193
-
194
-
195
- if len(pose_obj.shape) != 2:
196
- print('Please check your pose information: the expected shape is (T, 3).')
197
- exit(0)
198
- if pose_obj.shape[1] != 3:
199
- print('Please check your pose information: the expected shape is (T, 3).')
200
- exit(0)
201
-
202
- if pose_obj.shape[0] >= frame_end:
203
- pose_obj = pose_obj[:frame_end,:]
204
- else:
205
- padding = np.tile(pose_obj[-1, :], (frame_end - pose_obj.shape[0], 1))
206
- pose_obj = np.vstack((pose_obj, padding))
207
-
208
- pose_signal = torch.Tensor(pose_obj).unsqueeze(0).to(args.device) / 90 # divide by 90 to normalize angles (in degrees) into [-1, 1]
209
- else:
210
- yaw_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_yaw
211
- pitch_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_pitch
212
- roll_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.pose_roll
213
- pose_signal = torch.cat((yaw_signal, pitch_signal, roll_signal), dim=-1)
214
-
215
- pose_signal = torch.clamp(pose_signal, -1, 1)
216
-
217
- face_location_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_location
218
- face_scale_signal = torch.zeros(1, frame_end, 1).to(args.device) + args.face_scale
219
- #===========================================
220
-
221
- start_time = time.time()
222
-
223
- #======Diffusion Denoising Process=========
224
- generated_directions = model.render(one_shot_lia_start, one_shot_lia_direction, audio_driven, face_location_signal, face_scale_signal, pose_signal, noisyT, args.step_T, control_flag=args.control_flag)
225
- #=========================================
226
-
227
- execution_time = time.time() - start_time
228
- print(f"Motion Diffusion Model: {execution_time:.2f} Seconds")
229
-
230
- generated_directions = generated_directions.detach().cpu().numpy()
231
-
232
- start_time = time.time()
233
- #======Rendering images frame-by-frame=========
234
- for pred_index in tqdm(range(generated_directions.shape[1])):
235
- ori_img_recon = lia.render(one_shot_lia_start, torch.Tensor(generated_directions[:,pred_index,:]).to(args.device), feats)
236
- ori_img_recon = ori_img_recon.clamp(-1, 1)
237
- img_pred = (ori_img_recon.detach() + 1) / 2
238
- saved_image(img_pred, os.path.join(frames_result_saved_path, "%06d.png"%(pred_index)))
239
- #==============================================
240
-
241
- execution_time = time.time() - start_time
242
- print(f"Renderer Model: {execution_time:.2f} Seconds")
243
-
244
- frames_to_video(frames_result_saved_path, args.test_audio_path, predicted_video_256_path)
245
-
246
- shutil.rmtree(frames_result_saved_path)
247
-
248
-
249
- # Enhancer
250
- # Code is modified from https://github.com/OpenTalker/SadTalker/blob/cd4c0465ae0b54a6f85af57f5c65fec9fe23e7f8/src/utils/face_enhancer.py#L26
251
-
252
- if args.face_sr and check_package_installed('gfpgan'):
253
- from face_sr.face_enhancer import enhancer_list
254
- import imageio
255
-
256
- # Super-resolution
257
- imageio.mimsave(predicted_video_512_path+'.tmp.mp4', enhancer_list(predicted_video_256_path, method='gfpgan', bg_upsampler=None), fps=float(25))
258
-
259
- # Merge audio and video
260
- video_clip = VideoFileClip(predicted_video_512_path+'.tmp.mp4')
261
- audio_clip = AudioFileClip(predicted_video_256_path)
262
- final_clip = video_clip.set_audio(audio_clip)
263
- final_clip.write_videofile(predicted_video_512_path, codec='libx264', audio_codec='aac')
264
-
265
- os.remove(predicted_video_512_path+'.tmp.mp4')
266
-
267
- if __name__ == '__main__':
268
- parser = argparse.ArgumentParser()
269
- parser.add_argument('--infer_type', type=str, default='mfcc_pose_only', help='mfcc_pose_only, mfcc_full_control, hubert_pose_only, hubert_audio_only or hubert_full_control')
270
- parser.add_argument('--test_image_path', type=str, default='./test_demos/portraits/monalisa.jpg', help='Path to the portrait')
271
- parser.add_argument('--test_audio_path', type=str, default='./test_demos/audios/english_female.wav', help='Path to the driven audio')
272
- parser.add_argument('--test_hubert_path', type=str, default='./test_demos/audios_hubert/english_female.npy', help='Path to the driven audio features (hubert type). Not needed for MFCC')
273
- parser.add_argument('--result_path', type=str, default='./results/', help='Path to save the generated results')
274
- parser.add_argument('--stage1_checkpoint_path', type=str, default='./ckpts/stage1.ckpt', help='Path to the checkpoint of Stage1')
275
- parser.add_argument('--stage2_checkpoint_path', type=str, default='./ckpts/pose_only.ckpt', help='Path to the checkpoint of Stage2')
276
- parser.add_argument('--seed', type=int, default=0, help='seed for generations')
277
- parser.add_argument('--control_flag', action='store_true', help='Whether to use control signal or not')
278
- parser.add_argument('--pose_yaw', type=float, default=0.25, help='range from -1 to 1 (-90 ~ 90 angles)')
279
- parser.add_argument('--pose_pitch', type=float, default=0, help='range from -1 to 1 (-90 ~ 90 angles)')
280
- parser.add_argument('--pose_roll', type=float, default=0, help='range from -1 to 1 (-90 ~ 90 angles)')
281
- parser.add_argument('--face_location', type=float, default=0.5, help='range from 0 to 1 (from left to right)')
282
- parser.add_argument('--pose_driven_path', type=str, default='xxx', help='Path to a pose numpy array of shape (T, 3). See https://github.com/liutaocode/talking_face_preprocessing for extracting yaw, pitch and roll.')
283
- parser.add_argument('--face_scale', type=float, default=0.5, help='range from 0 to 1 (from small to large)')
284
- parser.add_argument('--step_T', type=int, default=50, help='Step T for diffusion denoising process')
285
- parser.add_argument('--image_size', type=int, default=256, help='Size of the image. Do not change.')
286
- parser.add_argument('--device', type=str, default='cuda:0', help='Device for computation')
287
- parser.add_argument('--motion_dim', type=int, default=20, help='Dimension of motion. Do not change.')
288
- parser.add_argument('--decoder_layers', type=int, default=2, help='Layer number for the conformer.')
289
- parser.add_argument('--face_sr', action='store_true', help='Face super-resolution (Optional). Please install GFPGAN first')
290
-
291
-
292
-
293
- args = parser.parse_args()
294
-
295
- main(args)
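For context on the frame/feature alignment used above: the renderer runs at a fixed 25 fps, MFCC features at 100 Hz (4 feature rows per video frame, 13 cepstra plus two delta orders, 39 dims in total) and HuBERT features at 50 Hz (2 rows per frame). A minimal sketch of that bookkeeping, using placeholder arrays rather than real audio:

import numpy as np

VIDEO_FPS = 25
MFCC_RATE = 100    # MFCC rows per second, as in the script above
HUBERT_RATE = 50   # HuBERT rows per second, as in the script above

def frames_from_mfcc(mfcc_feats: np.ndarray) -> int:
    # mfcc_feats: (T_audio, 39) = 13 MFCCs + delta + delta-delta
    return mfcc_feats.shape[0] // (MFCC_RATE // VIDEO_FPS)

def frames_from_hubert(hubert_feats: np.ndarray) -> int:
    # hubert_feats: (n_layers, T_audio, C) hidden states stacked over layers
    return hubert_feats.shape[1] // (HUBERT_RATE // VIDEO_FPS)

# 10 seconds of audio -> 1000 MFCC rows -> 250 rendered frames at 25 fps
assert frames_from_mfcc(np.zeros((1000, 39))) == 250
assert frames_from_hubert(np.zeros((25, 500, 1024))) == 250
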
code/diffusion/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from typing import Union
2
-
3
- from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
4
-
5
- Sampler = Union[SpacedDiffusionBeatGans]
6
- SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
code/diffusion/base.py DELETED
@@ -1,1128 +0,0 @@
1
- """
2
- This code started out as a PyTorch port of Ho et al's diffusion models:
3
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
4
-
5
- Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
6
- """
7
-
8
- from model.unet_autoenc import AutoencReturn
9
- from config_base import BaseConfig
10
- import enum
11
- import math
12
-
13
- import numpy as np
14
- import torch as th
15
- from model import *
16
- from model.nn import mean_flat
17
- from typing import NamedTuple, Tuple
18
- from choices import *
19
- from torch.cuda.amp import autocast
20
- import torch.nn.functional as F
21
-
22
- from dataclasses import dataclass
23
-
24
-
25
- @dataclass
26
- class GaussianDiffusionBeatGansConfig(BaseConfig):
27
- gen_type: GenerativeType
28
- betas: Tuple[float]
29
- model_type: ModelType
30
- model_mean_type: ModelMeanType
31
- model_var_type: ModelVarType
32
- loss_type: LossType
33
- rescale_timesteps: bool
34
- fp16: bool
35
- train_pred_xstart_detach: bool = True
36
-
37
- def make_sampler(self):
38
- return GaussianDiffusionBeatGans(self)
39
-
40
-
41
- class GaussianDiffusionBeatGans:
42
- """
43
- Utilities for training and sampling diffusion models.
44
-
45
- Ported directly from the repository below, and then adapted over time for further experimentation.
46
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
47
-
48
- :param betas: a 1-D numpy array of betas for each diffusion timestep,
49
- starting at T and going to 1.
50
- :param model_mean_type: a ModelMeanType determining what the model outputs.
51
- :param model_var_type: a ModelVarType determining how variance is output.
52
- :param loss_type: a LossType determining the loss function to use.
53
- :param rescale_timesteps: if True, pass floating point timesteps into the
54
- model so that they are always scaled like in the
55
- original paper (0 to 1000).
56
- """
57
- def __init__(self, conf: GaussianDiffusionBeatGansConfig):
58
- self.conf = conf
59
- self.model_mean_type = conf.model_mean_type
60
- self.model_var_type = conf.model_var_type
61
- self.loss_type = conf.loss_type
62
- self.rescale_timesteps = conf.rescale_timesteps
63
-
64
- # Use float64 for accuracy.
65
- betas = np.array(conf.betas, dtype=np.float64)
66
- self.betas = betas
67
- assert len(betas.shape) == 1, "betas must be 1-D"
68
- assert (betas > 0).all() and (betas <= 1).all()
69
-
70
- self.num_timesteps = int(betas.shape[0])
71
-
72
- alphas = 1.0 - betas
73
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
74
- self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
75
- self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
76
- assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )
77
-
78
- # calculations for diffusion q(x_t | x_{t-1}) and others
79
- self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
80
- self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
81
- self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
82
- self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
83
- self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
84
- 1)
85
-
86
- # calculations for posterior q(x_{t-1} | x_t, x_0)
87
- self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
88
- (1.0 - self.alphas_cumprod))
89
- # log calculation clipped because the posterior variance is 0 at the
90
- # beginning of the diffusion chain.
91
- self.posterior_log_variance_clipped = np.log(
92
- np.append(self.posterior_variance[1], self.posterior_variance[1:]))
93
- self.posterior_mean_coef1 = (betas *
94
- np.sqrt(self.alphas_cumprod_prev) /
95
- (1.0 - self.alphas_cumprod))
96
- self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
97
- np.sqrt(alphas) /
98
- (1.0 - self.alphas_cumprod))
99
-
100
- def training_losses(self,
101
- model,
102
- motion_direction_start: th.Tensor,
103
- motion_target: th.Tensor,
104
- motion_start: th.Tensor,
105
- audio_feats: th.Tensor,
106
- face_location: th.Tensor,
107
- face_scale: th.Tensor,
108
- yaw_pitch_roll: th.Tensor,
109
- t: th.Tensor,
110
- model_kwargs=None,
111
- noise: th.Tensor = None):
112
- """
113
- Compute training losses for a single timestep.
114
-
115
- :param model: the model to evaluate loss on.
116
- :param x_start: the [N x C x ...] tensor of inputs.
117
- :param t: a batch of timestep indices.
118
- :param model_kwargs: if not None, a dict of extra keyword arguments to
119
- pass to the model. This can be used for conditioning.
120
- :param noise: if specified, the specific Gaussian noise to try to remove.
121
- :return: a dict with the key "loss" containing a tensor of shape [N].
122
- Some mean or variance settings may also have other keys.
123
- """
124
- if model_kwargs is None:
125
- model_kwargs = {}
126
- if noise is None:
127
- noise = th.randn_like(motion_target)
128
-
129
- x_t = self.q_sample(motion_target, t, noise=noise)
130
-
131
- terms = {'x_t': x_t}
132
-
133
- if self.loss_type in [
134
- LossType.mse,
135
- LossType.l1,
136
- ]:
137
- with autocast(self.conf.fp16):
138
- # x_t is static w.r.t. the diffusion process
139
- predicted_direction, predicted_location, predicted_scale, predicted_pose = model.forward(motion_start,
140
- motion_direction_start,
141
- audio_feats,
142
- face_location,
143
- face_scale,
144
- yaw_pitch_roll,
145
- x_t.detach(),
146
- self._scale_timesteps(t),
147
- control_flag=False)
148
-
149
-
150
- target_types = {
151
- ModelMeanType.eps: noise,
152
- }
153
- target = target_types[self.model_mean_type]
154
- assert predicted_direction.shape == target.shape == motion_target.shape
155
-
156
- if self.loss_type == LossType.mse:
157
- if self.model_mean_type == ModelMeanType.eps:
158
-
159
- direction_loss = mean_flat((target - predicted_direction)**2)
160
- # import pdb;pdb.set_trace()
161
- location_loss = mean_flat((face_location.unsqueeze(-1) - predicted_location)**2)
162
- scale_loss = mean_flat((face_scale - predicted_scale)**2)
163
- pose_loss = mean_flat((yaw_pitch_roll - predicted_pose)**2)
164
-
165
- terms["mse"] = direction_loss + location_loss + scale_loss + pose_loss
166
-
167
- else:
168
- raise NotImplementedError()
169
- elif self.loss_type == LossType.l1:
170
- # (n, c, h, w) => (n, )
171
- terms["mse"] = mean_flat((target - predicted_direction).abs())
172
- else:
173
- raise NotImplementedError()
174
-
175
- if "vb" in terms:
176
- # if learning the variance also use the vlb loss
177
- terms["loss"] = terms["mse"] + terms["vb"]
178
- else:
179
- terms["loss"] = terms["mse"]
180
- else:
181
- raise NotImplementedError(self.loss_type)
182
-
183
-
184
- return terms
185
-
186
- def sample(self,
187
- model: Model,
188
- shape=None,
189
- noise=None,
190
- cond=None,
191
- x_start=None,
192
- clip_denoised=True,
193
- model_kwargs=None,
194
- progress=False):
195
- """
196
- Args:
197
- x_start: given for the autoencoder
198
- """
199
- if model_kwargs is None:
200
- model_kwargs = {}
201
- if self.conf.model_type.has_autoenc():
202
- model_kwargs['x_start'] = x_start
203
- model_kwargs['cond'] = cond
204
-
205
- if self.conf.gen_type == GenerativeType.ddpm:
206
- return self.p_sample_loop(model,
207
- shape=shape,
208
- noise=noise,
209
- clip_denoised=clip_denoised,
210
- model_kwargs=model_kwargs,
211
- progress=progress)
212
- elif self.conf.gen_type == GenerativeType.ddim:
213
- return self.ddim_sample_loop(model,
214
- shape=shape,
215
- noise=noise,
216
- clip_denoised=clip_denoised,
217
- model_kwargs=model_kwargs,
218
- progress=progress)
219
- else:
220
- raise NotImplementedError()
221
-
222
- def q_mean_variance(self, x_start, t):
223
- """
224
- Get the distribution q(x_t | x_0).
225
-
226
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
227
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
228
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
229
- """
230
- mean = (
231
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
232
- x_start)
233
- variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t,
234
- x_start.shape)
235
- log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod,
236
- t, x_start.shape)
237
- return mean, variance, log_variance
238
-
239
- def q_sample(self, x_start, t, noise=None):
240
- """
241
- Diffuse the data for a given number of diffusion steps.
242
-
243
- In other words, sample from q(x_t | x_0).
244
-
245
- :param x_start: the initial data batch.
246
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
247
- :param noise: if specified, the split-out normal noise.
248
- :return: A noisy version of x_start.
249
- """
250
- if noise is None:
251
- noise = th.randn_like(x_start)
252
- assert noise.shape == x_start.shape
253
- return (
254
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
255
- x_start + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
256
- t, x_start.shape) * noise)
257
-
258
- def q_posterior_mean_variance(self, x_start, x_t, t):
259
- """
260
- Compute the mean and variance of the diffusion posterior:
261
-
262
- q(x_{t-1} | x_t, x_0)
263
-
264
- """
265
- assert x_start.shape == x_t.shape
266
- posterior_mean = (
267
- _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
268
- x_start +
269
- _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) *
270
- x_t)
271
- posterior_variance = _extract_into_tensor(self.posterior_variance, t,
272
- x_t.shape)
273
- posterior_log_variance_clipped = _extract_into_tensor(
274
- self.posterior_log_variance_clipped, t, x_t.shape)
275
- assert (posterior_mean.shape[0] == posterior_variance.shape[0] ==
276
- posterior_log_variance_clipped.shape[0] == x_start.shape[0])
277
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
278
-
279
- def p_mean_variance(self,
280
- model,
281
- x,
282
- t,
283
- clip_denoised=True,
284
- denoised_fn=None,
285
- model_kwargs=None):
286
- """
287
- Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
288
- the initial x, x_0.
289
-
290
- :param model: the model, which takes a signal and a batch of timesteps
291
- as input.
292
- :param x: the [N x C x ...] tensor at time t.
293
- :param t: a 1-D Tensor of timesteps.
294
- :param clip_denoised: if True, clip the denoised signal into [-1, 1].
295
- :param denoised_fn: if not None, a function which applies to the
296
- x_start prediction before it is used to sample. Applies before
297
- clip_denoised.
298
- :param model_kwargs: if not None, a dict of extra keyword arguments to
299
- pass to the model. This can be used for conditioning.
300
- :return: a dict with the following keys:
301
- - 'mean': the model mean output.
302
- - 'variance': the model variance output.
303
- - 'log_variance': the log of 'variance'.
304
- - 'pred_xstart': the prediction for x_0.
305
- """
306
- if model_kwargs is None:
307
- model_kwargs = {}
308
-
309
- motion_start = model_kwargs['start']
310
- audio_feats = model_kwargs['audio_driven']
311
- face_location = model_kwargs['face_location']
312
- face_scale = model_kwargs['face_scale']
313
- yaw_pitch_roll = model_kwargs['yaw_pitch_roll']
314
- motion_direction_start = model_kwargs['motion_direction_start']
315
- control_flag = model_kwargs['control_flag']
316
-
317
- B, C = x.shape[:2]
318
- assert t.shape == (B, )
319
- with autocast(self.conf.fp16):
320
- model_forward, _, _, _ = model.forward(motion_start,
321
- motion_direction_start,
322
- audio_feats,
323
- face_location,
324
- face_scale,
325
- yaw_pitch_roll,
326
- x,
327
- self._scale_timesteps(t),
328
- control_flag)
329
- model_output = model_forward
330
-
331
- if self.model_var_type in [
332
- ModelVarType.fixed_large, ModelVarType.fixed_small
333
- ]:
334
- model_variance, model_log_variance = {
335
- # for fixedlarge, we set the initial (log-)variance like so
336
- # to get a better decoder log likelihood.
337
- ModelVarType.fixed_large: (
338
- np.append(self.posterior_variance[1], self.betas[1:]),
339
- np.log(
340
- np.append(self.posterior_variance[1], self.betas[1:])),
341
- ),
342
- ModelVarType.fixed_small: (
343
- self.posterior_variance,
344
- self.posterior_log_variance_clipped,
345
- ),
346
- }[self.model_var_type]
347
- model_variance = _extract_into_tensor(model_variance, t, x.shape)
348
- model_log_variance = _extract_into_tensor(model_log_variance, t,
349
- x.shape)
350
-
351
- def process_xstart(x):
352
- if denoised_fn is not None:
353
- x = denoised_fn(x)
354
- if clip_denoised:
355
- return x.clamp(-1, 1)
356
- return x
357
-
358
- if self.model_mean_type in [
359
- ModelMeanType.eps,
360
- ]:
361
- if self.model_mean_type == ModelMeanType.eps:
362
- pred_xstart = process_xstart(
363
- self._predict_xstart_from_eps(x_t=x, t=t,
364
- eps=model_output))
365
- else:
366
- raise NotImplementedError()
367
- model_mean, _, _ = self.q_posterior_mean_variance(
368
- x_start=pred_xstart, x_t=x, t=t)
369
- else:
370
- raise NotImplementedError(self.model_mean_type)
371
-
372
- assert (model_mean.shape == model_log_variance.shape ==
373
- pred_xstart.shape == x.shape)
374
- return {
375
- "mean": model_mean,
376
- "variance": model_variance,
377
- "log_variance": model_log_variance,
378
- "pred_xstart": pred_xstart,
379
- 'model_forward': model_forward,
380
- }
381
-
382
- def _predict_xstart_from_eps(self, x_t, t, eps):
383
- assert x_t.shape == eps.shape
384
- return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
385
- x_t.shape) * x_t -
386
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
387
- x_t.shape) * eps)
388
-
389
- def _predict_xstart_from_xprev(self, x_t, t, xprev):
390
- assert x_t.shape == xprev.shape
391
- return ( # (xprev - coef2*x_t) / coef1
392
- _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape)
393
- * xprev - _extract_into_tensor(
394
- self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
395
- x_t.shape) * x_t)
396
-
397
- def _predict_xstart_from_scaled_xstart(self, t, scaled_xstart):
398
- return scaled_xstart * _extract_into_tensor(
399
- self.sqrt_recip_alphas_cumprod, t, scaled_xstart.shape)
400
-
401
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
402
- return (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
403
- x_t.shape) * x_t -
404
- pred_xstart) / _extract_into_tensor(
405
- self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
406
-
407
- def _predict_eps_from_scaled_xstart(self, x_t, t, scaled_xstart):
408
- """
409
- Args:
410
- scaled_xstart: is supposed to be sqrt(alphacum) * x_0
411
- """
412
- # 1 / sqrt(1-alphabar) * (x_t - scaled xstart)
413
- return (x_t - scaled_xstart) / _extract_into_tensor(
414
- self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
415
-
416
- def _scale_timesteps(self, t):
417
- if self.rescale_timesteps:
418
- # scale t to be maxed out at 1000 steps
419
- return t.float() * (1000.0 / self.num_timesteps)
420
- return t
421
-
422
- def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
423
- """
424
- Compute the mean for the previous step, given a function cond_fn that
425
- computes the gradient of a conditional log probability with respect to
426
- x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
427
- condition on y.
428
-
429
- This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
430
- """
431
- gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
432
- new_mean = (p_mean_var["mean"].float() +
433
- p_mean_var["variance"] * gradient.float())
434
- return new_mean
435
-
436
- def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
437
- """
438
- Compute what the p_mean_variance output would have been, should the
439
- model's score function be conditioned by cond_fn.
440
-
441
- See condition_mean() for details on cond_fn.
442
-
443
- Unlike condition_mean(), this instead uses the conditioning strategy
444
- from Song et al (2020).
445
- """
446
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
447
-
448
- eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
449
- eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
450
- x, self._scale_timesteps(t), **model_kwargs)
451
-
452
- out = p_mean_var.copy()
453
- out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
454
- out["mean"], _, _ = self.q_posterior_mean_variance(
455
- x_start=out["pred_xstart"], x_t=x, t=t)
456
- return out
457
-
458
- def p_sample(
459
- self,
460
- model: Model,
461
- x,
462
- t,
463
- clip_denoised=True,
464
- denoised_fn=None,
465
- cond_fn=None,
466
- model_kwargs=None,
467
- ):
468
- """
469
- Sample x_{t-1} from the model at the given timestep.
470
-
471
- :param model: the model to sample from.
472
- :param x: the current tensor at x_{t-1}.
473
- :param t: the value of t, starting at 0 for the first diffusion step.
474
- :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
475
- :param denoised_fn: if not None, a function which applies to the
476
- x_start prediction before it is used to sample.
477
- :param cond_fn: if not None, this is a gradient function that acts
478
- similarly to the model.
479
- :param model_kwargs: if not None, a dict of extra keyword arguments to
480
- pass to the model. This can be used for conditioning.
481
- :return: a dict containing the following keys:
482
- - 'sample': a random sample from the model.
483
- - 'pred_xstart': a prediction of x_0.
484
- """
485
- out = self.p_mean_variance(
486
- model,
487
- x,
488
- t,
489
- clip_denoised=clip_denoised,
490
- denoised_fn=denoised_fn,
491
- model_kwargs=model_kwargs,
492
- )
493
- noise = th.randn_like(x)
494
- nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
495
- ) # no noise when t == 0
496
- if cond_fn is not None:
497
- out["mean"] = self.condition_mean(cond_fn,
498
- out,
499
- x,
500
- t,
501
- model_kwargs=model_kwargs)
502
- sample = out["mean"] + nonzero_mask * th.exp(
503
- 0.5 * out["log_variance"]) * noise
504
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
505
-
506
- def p_sample_loop(
507
- self,
508
- model: Model,
509
- shape=None,
510
- noise=None,
511
- clip_denoised=True,
512
- denoised_fn=None,
513
- cond_fn=None,
514
- model_kwargs=None,
515
- device=None,
516
- progress=False,
517
- ):
518
- """
519
- Generate samples from the model.
520
-
521
- :param model: the model module.
522
- :param shape: the shape of the samples, (N, C, H, W).
523
- :param noise: if specified, the noise from the encoder to sample.
524
- Should be of the same shape as `shape`.
525
- :param clip_denoised: if True, clip x_start predictions to [-1, 1].
526
- :param denoised_fn: if not None, a function which applies to the
527
- x_start prediction before it is used to sample.
528
- :param cond_fn: if not None, this is a gradient function that acts
529
- similarly to the model.
530
- :param model_kwargs: if not None, a dict of extra keyword arguments to
531
- pass to the model. This can be used for conditioning.
532
- :param device: if specified, the device to create the samples on.
533
- If not specified, use a model parameter's device.
534
- :param progress: if True, show a tqdm progress bar.
535
- :return: a non-differentiable batch of samples.
536
- """
537
- final = None
538
- for sample in self.p_sample_loop_progressive(
539
- model,
540
- shape,
541
- noise=noise,
542
- clip_denoised=clip_denoised,
543
- denoised_fn=denoised_fn,
544
- cond_fn=cond_fn,
545
- model_kwargs=model_kwargs,
546
- device=device,
547
- progress=progress,
548
- ):
549
- final = sample
550
- return final["sample"]
551
-
552
- def p_sample_loop_progressive(
553
- self,
554
- model: Model,
555
- shape=None,
556
- noise=None,
557
- clip_denoised=True,
558
- denoised_fn=None,
559
- cond_fn=None,
560
- model_kwargs=None,
561
- device=None,
562
- progress=False,
563
- ):
564
- """
565
- Generate samples from the model and yield intermediate samples from
566
- each timestep of diffusion.
567
-
568
- Arguments are the same as p_sample_loop().
569
- Returns a generator over dicts, where each dict is the return value of
570
- p_sample().
571
- """
572
- if device is None:
573
- device = next(model.parameters()).device
574
- if noise is not None:
575
- img = noise
576
- else:
577
- assert isinstance(shape, (tuple, list))
578
- img = th.randn(*shape, device=device)
579
- indices = list(range(self.num_timesteps))[::-1]
580
-
581
- if progress:
582
- # Lazy import so that we don't depend on tqdm.
583
- from tqdm.auto import tqdm
584
-
585
- indices = tqdm(indices)
586
-
587
- for i in indices:
588
- # t = th.tensor([i] * shape[0], device=device)
589
- t = th.tensor([i] * len(img), device=device)
590
- with th.no_grad():
591
- out = self.p_sample(
592
- model,
593
- img,
594
- t,
595
- clip_denoised=clip_denoised,
596
- denoised_fn=denoised_fn,
597
- cond_fn=cond_fn,
598
- model_kwargs=model_kwargs,
599
- )
600
- yield out
601
- img = out["sample"]
602
-
603
- def ddim_sample(
604
- self,
605
- model: Model,
606
- x,
607
- t,
608
- clip_denoised=True,
609
- denoised_fn=None,
610
- cond_fn=None,
611
- model_kwargs=None,
612
- eta=0.0,
613
- ):
614
- """
615
- Sample x_{t-1} from the model using DDIM.
616
-
617
- Same usage as p_sample().
618
- """
619
- out = self.p_mean_variance(
620
- model,
621
- x,
622
- t,
623
- clip_denoised=clip_denoised,
624
- denoised_fn=denoised_fn,
625
- model_kwargs=model_kwargs,
626
- )
627
- if cond_fn is not None:
628
- out = self.condition_score(cond_fn,
629
- out,
630
- x,
631
- t,
632
- model_kwargs=model_kwargs)
633
-
634
- # Usually our model outputs epsilon, but we re-derive it
635
- # in case we used x_start or x_prev prediction.
636
- eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
637
-
638
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
639
- alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t,
640
- x.shape)
641
- sigma = (eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) *
642
- th.sqrt(1 - alpha_bar / alpha_bar_prev))
643
- # Equation 12.
644
- noise = th.randn_like(x)
645
- mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_prev) +
646
- th.sqrt(1 - alpha_bar_prev - sigma**2) * eps)
647
- nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
648
- ) # no noise when t == 0
649
- sample = mean_pred + nonzero_mask * sigma * noise
650
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
651
-
652
- def ddim_reverse_sample(
653
- self,
654
- model: Model,
655
- x,
656
- t,
657
- clip_denoised=True,
658
- denoised_fn=None,
659
- model_kwargs=None,
660
- eta=0.0,
661
- ):
662
- """
663
- Sample x_{t+1} from the model using DDIM reverse ODE.
664
- NOTE: this appears to be unused.
665
- """
666
- assert eta == 0.0, "Reverse ODE only for deterministic path"
667
- out = self.p_mean_variance(
668
- model,
669
- x,
670
- t,
671
- clip_denoised=clip_denoised,
672
- denoised_fn=denoised_fn,
673
- model_kwargs=model_kwargs,
674
- )
675
- # Usually our model outputs epsilon, but we re-derive it
676
- # in case we used x_start or x_prev prediction.
677
- eps = (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape)
678
- * x - out["pred_xstart"]) / _extract_into_tensor(
679
- self.sqrt_recipm1_alphas_cumprod, t, x.shape)
680
- alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t,
681
- x.shape)
682
-
683
- # Equation 12. reversed (DDIM paper) (th.sqrt == torch.sqrt)
684
- mean_pred = (out["pred_xstart"] * th.sqrt(alpha_bar_next) +
685
- th.sqrt(1 - alpha_bar_next) * eps)
686
-
687
- return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
688
-
689
- def ddim_reverse_sample_loop(
690
- self,
691
- model: Model,
692
- x,
693
- clip_denoised=True,
694
- denoised_fn=None,
695
- model_kwargs=None,
696
- eta=0.0,
697
- device=None,
698
- ):
699
- if device is None:
700
- device = next(model.parameters()).device
701
- sample_t = []
702
- xstart_t = []
703
- T = []
704
- indices = list(range(self.num_timesteps))
705
- sample = x
706
- for i in indices:
707
- t = th.tensor([i] * len(sample), device=device)
708
- with th.no_grad():
709
- out = self.ddim_reverse_sample(model,
710
- sample,
711
- t=t,
712
- clip_denoised=clip_denoised,
713
- denoised_fn=denoised_fn,
714
- model_kwargs=model_kwargs,
715
- eta=eta)
716
- sample = out['sample']
717
- # [1, ..., T]
718
- sample_t.append(sample)
719
- # [0, ...., T-1]
720
- xstart_t.append(out['pred_xstart'])
721
- # [0, ..., T-1] ready to use
722
- T.append(t)
723
-
724
- return {
725
- # x_T
726
- 'sample': sample,
727
- # (1, ..., T)
728
- 'sample_t': sample_t,
729
- # xstart here is a bit different from sampling from T = T-1 to T = 0
730
- # may not be exact
731
- 'xstart_t': xstart_t,
732
- 'T': T,
733
- }
734
-
735
- def ddim_sample_loop(
736
- self,
737
- model: Model,
738
- shape=None,
739
- noise=None,
740
- clip_denoised=True,
741
- denoised_fn=None,
742
- cond_fn=None,
743
- model_kwargs=None,
744
- device=None,
745
- progress=False,
746
- eta=0.0,
747
- ):
748
- """
749
- Generate samples from the model using DDIM.
750
-
751
- Same usage as p_sample_loop().
752
- """
753
- final = None
754
- for sample in self.ddim_sample_loop_progressive(
755
- model,
756
- shape,
757
- noise=noise,
758
- clip_denoised=clip_denoised,
759
- denoised_fn=denoised_fn,
760
- cond_fn=cond_fn,
761
- model_kwargs=model_kwargs,
762
- device=device,
763
- progress=progress,
764
- eta=eta,
765
- ):
766
- final = sample
767
- return final["sample"]
768
-
769
- def ddim_sample_loop_progressive(
770
- self,
771
- model: Model,
772
- shape=None,
773
- noise=None,
774
- clip_denoised=True,
775
- denoised_fn=None,
776
- cond_fn=None,
777
- model_kwargs=None,
778
- device=None,
779
- progress=False,
780
- eta=0.0,
781
- ):
782
- """
783
- Use DDIM to sample from the model and yield intermediate samples from
784
- each timestep of DDIM.
785
-
786
- Same usage as p_sample_loop_progressive().
787
- """
788
- if device is None:
789
- device = next(model.parameters()).device
790
- if noise is not None:
791
- img = noise
792
- else:
793
- assert isinstance(shape, (tuple, list))
794
- img = th.randn(*shape, device=device)
795
- indices = list(range(self.num_timesteps))[::-1]
796
-
797
- if progress:
798
- # Lazy import so that we don't depend on tqdm.
799
- from tqdm.auto import tqdm
800
-
801
- indices = tqdm(indices)
802
-
803
- for i in indices:
804
-
805
- if isinstance(model_kwargs, list):
806
- # index dependent model kwargs
807
- # (T-1, ..., 0)
808
- _kwargs = model_kwargs[i]
809
- else:
810
- _kwargs = model_kwargs
811
-
812
- t = th.tensor([i] * len(img), device=device)
813
- with th.no_grad():
814
- out = self.ddim_sample(
815
- model,
816
- img,
817
- t,
818
- clip_denoised=clip_denoised,
819
- denoised_fn=denoised_fn,
820
- cond_fn=cond_fn,
821
- model_kwargs=_kwargs,
822
- eta=eta,
823
- )
824
- out['t'] = t
825
- yield out
826
- img = out["sample"]
827
-
828
- def _vb_terms_bpd(self,
829
- model: Model,
830
- x_start,
831
- x_t,
832
- t,
833
- clip_denoised=True,
834
- model_kwargs=None):
835
- """
836
- Get a term for the variational lower-bound.
837
-
838
- The resulting units are bits (rather than nats, as one might expect).
839
- This allows for comparison to other papers.
840
-
841
- :return: a dict with the following keys:
842
- - 'output': a shape [N] tensor of NLLs or KLs.
843
- - 'pred_xstart': the x_0 predictions.
844
- """
845
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
846
- x_start=x_start, x_t=x_t, t=t)
847
- out = self.p_mean_variance(model,
848
- x_t,
849
- t,
850
- clip_denoised=clip_denoised,
851
- model_kwargs=model_kwargs)
852
- kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"],
853
- out["log_variance"])
854
- kl = mean_flat(kl) / np.log(2.0)
855
-
856
- decoder_nll = -discretized_gaussian_log_likelihood(
857
- x_start, means=out["mean"], log_scales=0.5 * out["log_variance"])
858
- assert decoder_nll.shape == x_start.shape
859
- decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
860
-
861
- # At the first timestep return the decoder NLL,
862
- # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
863
- output = th.where((t == 0), decoder_nll, kl)
864
- return {
865
- "output": output,
866
- "pred_xstart": out["pred_xstart"],
867
- 'model_forward': out['model_forward'],
868
- }
869
-
870
- def _prior_bpd(self, x_start):
871
- """
872
- Get the prior KL term for the variational lower-bound, measured in
873
- bits-per-dim.
874
-
875
- This term can't be optimized, as it only depends on the encoder.
876
-
877
- :param x_start: the [N x C x ...] tensor of inputs.
878
- :return: a batch of [N] KL values (in bits), one per batch element.
879
- """
880
- batch_size = x_start.shape[0]
881
- t = th.tensor([self.num_timesteps - 1] * batch_size,
882
- device=x_start.device)
883
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
884
- kl_prior = normal_kl(mean1=qt_mean,
885
- logvar1=qt_log_variance,
886
- mean2=0.0,
887
- logvar2=0.0)
888
- return mean_flat(kl_prior) / np.log(2.0)
889
-
890
- def calc_bpd_loop(self,
891
- model: Model,
892
- x_start,
893
- clip_denoised=True,
894
- model_kwargs=None):
895
- """
896
- Compute the entire variational lower-bound, measured in bits-per-dim,
897
- as well as other related quantities.
898
-
899
- :param model: the model to evaluate loss on.
900
- :param x_start: the [N x C x ...] tensor of inputs.
901
- :param clip_denoised: if True, clip denoised samples.
902
- :param model_kwargs: if not None, a dict of extra keyword arguments to
903
- pass to the model. This can be used for conditioning.
904
-
905
- :return: a dict containing the following keys:
906
- - total_bpd: the total variational lower-bound, per batch element.
907
- - prior_bpd: the prior term in the lower-bound.
908
- - vb: an [N x T] tensor of terms in the lower-bound.
909
- - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
910
- - mse: an [N x T] tensor of epsilon MSEs for each timestep.
911
- """
912
- device = x_start.device
913
- batch_size = x_start.shape[0]
914
-
915
- vb = []
916
- xstart_mse = []
917
- mse = []
918
- for t in list(range(self.num_timesteps))[::-1]:
919
- t_batch = th.tensor([t] * batch_size, device=device)
920
- noise = th.randn_like(x_start)
921
- x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
922
- # Calculate VLB term at the current timestep
923
- with th.no_grad():
924
- out = self._vb_terms_bpd(
925
- model,
926
- x_start=x_start,
927
- x_t=x_t,
928
- t=t_batch,
929
- clip_denoised=clip_denoised,
930
- model_kwargs=model_kwargs,
931
- )
932
- vb.append(out["output"])
933
- xstart_mse.append(mean_flat((out["pred_xstart"] - x_start)**2))
934
- eps = self._predict_eps_from_xstart(x_t, t_batch,
935
- out["pred_xstart"])
936
- mse.append(mean_flat((eps - noise)**2))
937
-
938
- vb = th.stack(vb, dim=1)
939
- xstart_mse = th.stack(xstart_mse, dim=1)
940
- mse = th.stack(mse, dim=1)
941
-
942
- prior_bpd = self._prior_bpd(x_start)
943
- total_bpd = vb.sum(dim=1) + prior_bpd
944
- return {
945
- "total_bpd": total_bpd,
946
- "prior_bpd": prior_bpd,
947
- "vb": vb,
948
- "xstart_mse": xstart_mse,
949
- "mse": mse,
950
- }
951
-
952
-
953
- def _extract_into_tensor(arr, timesteps, broadcast_shape):
954
- """
955
- Extract values from a 1-D numpy array for a batch of indices.
956
-
957
- :param arr: the 1-D numpy array.
958
- :param timesteps: a tensor of indices into the array to extract.
959
- :param broadcast_shape: a larger shape of K dimensions with the batch
960
- dimension equal to the length of timesteps.
961
- :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
962
- """
963
- res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
964
- while len(res.shape) < len(broadcast_shape):
965
- res = res[..., None]
966
- return res.expand(broadcast_shape)
967
-
968
-
969
- def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
970
- """
971
- Get a pre-defined beta schedule for the given name.
972
-
973
- The beta schedule library consists of beta schedules which remain similar
974
- in the limit of num_diffusion_timesteps.
975
- Beta schedules may be added, but should not be removed or changed once
976
- they are committed to maintain backwards compatibility.
977
- """
978
- if schedule_name == "linear":
979
- # Linear schedule from Ho et al, extended to work for any number of
980
- # diffusion steps.
981
- scale = 1000 / num_diffusion_timesteps
982
- beta_start = scale * 0.0001
983
- beta_end = scale * 0.02
984
- return np.linspace(beta_start,
985
- beta_end,
986
- num_diffusion_timesteps,
987
- dtype=np.float64)
988
- elif schedule_name == "cosine":
989
- return betas_for_alpha_bar(
990
- num_diffusion_timesteps,
991
- lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
992
- )
993
- elif schedule_name == "const0.01":
994
- scale = 1000 / num_diffusion_timesteps
995
- return np.array([scale * 0.01] * num_diffusion_timesteps,
996
- dtype=np.float64)
997
- elif schedule_name == "const0.015":
998
- scale = 1000 / num_diffusion_timesteps
999
- return np.array([scale * 0.015] * num_diffusion_timesteps,
1000
- dtype=np.float64)
1001
- elif schedule_name == "const0.008":
1002
- scale = 1000 / num_diffusion_timesteps
1003
- return np.array([scale * 0.008] * num_diffusion_timesteps,
1004
- dtype=np.float64)
1005
- elif schedule_name == "const0.0065":
1006
- scale = 1000 / num_diffusion_timesteps
1007
- return np.array([scale * 0.0065] * num_diffusion_timesteps,
1008
- dtype=np.float64)
1009
- elif schedule_name == "const0.0055":
1010
- scale = 1000 / num_diffusion_timesteps
1011
- return np.array([scale * 0.0055] * num_diffusion_timesteps,
1012
- dtype=np.float64)
1013
- elif schedule_name == "const0.0045":
1014
- scale = 1000 / num_diffusion_timesteps
1015
- return np.array([scale * 0.0045] * num_diffusion_timesteps,
1016
- dtype=np.float64)
1017
- elif schedule_name == "const0.0035":
1018
- scale = 1000 / num_diffusion_timesteps
1019
- return np.array([scale * 0.0035] * num_diffusion_timesteps,
1020
- dtype=np.float64)
1021
- elif schedule_name == "const0.0025":
1022
- scale = 1000 / num_diffusion_timesteps
1023
- return np.array([scale * 0.0025] * num_diffusion_timesteps,
1024
- dtype=np.float64)
1025
- elif schedule_name == "const0.0015":
1026
- scale = 1000 / num_diffusion_timesteps
1027
- return np.array([scale * 0.0015] * num_diffusion_timesteps,
1028
- dtype=np.float64)
1029
- else:
1030
- raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
1031
-
1032
-
1033
- def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
1034
- """
1035
- Create a beta schedule that discretizes the given alpha_t_bar function,
1036
- which defines the cumulative product of (1-beta) over time from t = [0,1].
1037
-
1038
- :param num_diffusion_timesteps: the number of betas to produce.
1039
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
1040
- produces the cumulative product of (1-beta) up to that
1041
- part of the diffusion process.
1042
- :param max_beta: the maximum beta to use; use values lower than 1 to
1043
- prevent singularities.
1044
- """
1045
- betas = []
1046
- for i in range(num_diffusion_timesteps):
1047
- t1 = i / num_diffusion_timesteps
1048
- t2 = (i + 1) / num_diffusion_timesteps
1049
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
1050
- return np.array(betas)
1051
-
1052
-
1053
- def normal_kl(mean1, logvar1, mean2, logvar2):
1054
- """
1055
- Compute the KL divergence between two gaussians.
1056
-
1057
- Shapes are automatically broadcasted, so batches can be compared to
1058
- scalars, among other use cases.
1059
- """
1060
- tensor = None
1061
- for obj in (mean1, logvar1, mean2, logvar2):
1062
- if isinstance(obj, th.Tensor):
1063
- tensor = obj
1064
- break
1065
- assert tensor is not None, "at least one argument must be a Tensor"
1066
-
1067
- # Force variances to be Tensors. Broadcasting helps convert scalars to
1068
- # Tensors, but it does not work for th.exp().
1069
- logvar1, logvar2 = [
1070
- x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
1071
- for x in (logvar1, logvar2)
1072
- ]
1073
-
1074
- return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) +
1075
- ((mean1 - mean2)**2) * th.exp(-logvar2))
1076
-
1077
-
1078
- def approx_standard_normal_cdf(x):
1079
- """
1080
- A fast approximation of the cumulative distribution function of the
1081
- standard normal.
1082
- """
1083
- return 0.5 * (
1084
- 1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
1085
-
1086
-
1087
- def discretized_gaussian_log_likelihood(x, *, means, log_scales):
1088
- """
1089
- Compute the log-likelihood of a Gaussian distribution discretizing to a
1090
- given image.
1091
-
1092
- :param x: the target images. It is assumed that this was uint8 values,
1093
- rescaled to the range [-1, 1].
1094
- :param means: the Gaussian mean Tensor.
1095
- :param log_scales: the Gaussian log stddev Tensor.
1096
- :return: a tensor like x of log probabilities (in nats).
1097
- """
1098
- assert x.shape == means.shape == log_scales.shape
1099
- centered_x = x - means
1100
- inv_stdv = th.exp(-log_scales)
1101
- plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
1102
- cdf_plus = approx_standard_normal_cdf(plus_in)
1103
- min_in = inv_stdv * (centered_x - 1.0 / 255.0)
1104
- cdf_min = approx_standard_normal_cdf(min_in)
1105
- log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
1106
- log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
1107
- cdf_delta = cdf_plus - cdf_min
1108
- log_probs = th.where(
1109
- x < -0.999,
1110
- log_cdf_plus,
1111
- th.where(x > 0.999, log_one_minus_cdf_min,
1112
- th.log(cdf_delta.clamp(min=1e-12))),
1113
- )
1114
- assert log_probs.shape == x.shape
1115
- return log_probs
1116
-
1117
-
1118
- class DummyModel(th.nn.Module):
1119
- def __init__(self, pred):
1120
- super().__init__()
1121
- self.pred = pred
1122
-
1123
- def forward(self, *args, **kwargs):
1124
- return DummyReturn(pred=self.pred)
1125
-
1126
-
1127
- class DummyReturn(NamedTuple):
1128
- pred: th.Tensor
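As a quick, self-contained sanity check on the beta-schedule helpers above (a sketch only, not part of the deleted module), the "linear" schedule from get_named_beta_schedule feeds the cumulative products that q_sample relies on:

import numpy as np

# the "linear" branch of get_named_beta_schedule, written out for T = 1000
T = 1000
scale = 1000 / T
betas = np.linspace(scale * 0.0001, scale * 0.02, T, dtype=np.float64)

alphas_cumprod = np.cumprod(1.0 - betas)

# q(x_t | x_0): x_t = sqrt(acp[t]) * x_0 + sqrt(1 - acp[t]) * noise (see q_sample)
sqrt_acp = np.sqrt(alphas_cumprod)
sqrt_one_minus_acp = np.sqrt(1.0 - alphas_cumprod)

# early timesteps keep the signal almost intact, late ones are almost pure noise
assert sqrt_acp[0] > 0.999 and sqrt_one_minus_acp[0] < 0.02
assert sqrt_acp[-1] < 0.01 and sqrt_one_minus_acp[-1] > 0.999
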
code/diffusion/diffusion.py DELETED
@@ -1,156 +0,0 @@
1
- from .base import *
2
- from dataclasses import dataclass
3
-
4
-
5
- def space_timesteps(num_timesteps, section_counts):
6
- """
7
- Create a list of timesteps to use from an original diffusion process,
8
- given the number of timesteps we want to take from equally-sized portions
9
- of the original process.
10
-
11
- For example, if there's 300 timesteps and the section counts are [10,15,20]
12
- then the first 100 timesteps are strided to be 10 timesteps, the second 100
13
- are strided to be 15 timesteps, and the final 100 are strided to be 20.
14
-
15
- If the stride is a string starting with "ddim", then the fixed striding
16
- from the DDIM paper is used, and only one section is allowed.
17
-
18
- :param num_timesteps: the number of diffusion steps in the original
19
- process to divide up.
20
- :param section_counts: either a list of numbers, or a string containing
21
- comma-separated numbers, indicating the step count
22
- per section. As a special case, use "ddimN" where N
23
- is a number of steps to use the striding from the
24
- DDIM paper.
25
- :return: a set of diffusion steps from the original process to use.
26
- """
27
- if isinstance(section_counts, str):
28
- if section_counts.startswith("ddim"):
29
- desired_count = int(section_counts[len("ddim"):])
30
- for i in range(1, num_timesteps):
31
- if len(range(0, num_timesteps, i)) == desired_count:
32
- return set(range(0, num_timesteps, i))
33
- raise ValueError(
34
- f"cannot create exactly {desired_count} steps with an integer stride"
35
- )
36
- section_counts = [int(x) for x in section_counts.split(",")]
37
- size_per = num_timesteps // len(section_counts)
38
- extra = num_timesteps % len(section_counts)
39
- start_idx = 0
40
- all_steps = []
41
- for i, section_count in enumerate(section_counts):
42
- size = size_per + (1 if i < extra else 0)
43
- if size < section_count:
44
- raise ValueError(
45
- f"cannot divide section of {size} steps into {section_count}")
46
- if section_count <= 1:
47
- frac_stride = 1
48
- else:
49
- frac_stride = (size - 1) / (section_count - 1)
50
- cur_idx = 0.0
51
- taken_steps = []
52
- for _ in range(section_count):
53
- taken_steps.append(start_idx + round(cur_idx))
54
- cur_idx += frac_stride
55
- all_steps += taken_steps
56
- start_idx += size
57
- return set(all_steps)
58
-
59
-
60
- @dataclass
61
- class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62
- use_timesteps: Tuple[int] = None
63
-
64
- def make_sampler(self):
65
- return SpacedDiffusionBeatGans(self)
66
-
67
-
68
- class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69
- """
70
- A diffusion process which can skip steps in a base diffusion process.
71
-
72
- :param use_timesteps: a collection (sequence or set) of timesteps from the
73
- original diffusion process to retain.
74
- :param kwargs: the kwargs to create the base diffusion process.
75
- """
76
- def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77
- self.conf = conf
78
- self.use_timesteps = set(conf.use_timesteps)
79
- # how the new t's mapped to the old t's
80
- self.timestep_map = []
81
- self.original_num_steps = len(conf.betas)
82
-
83
- base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84
- last_alpha_cumprod = 1.0
85
- new_betas = []
86
- for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87
- if i in self.use_timesteps:
88
- # getting the new betas of the new timesteps
89
- new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90
- last_alpha_cumprod = alpha_cumprod
91
- self.timestep_map.append(i)
92
- conf.betas = np.array(new_betas)
93
- super().__init__(conf)
94
-
95
- def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96
- return super().p_mean_variance(self._wrap_model(model), *args,
97
- **kwargs)
98
-
99
- def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100
- return super().training_losses(self._wrap_model(model), *args,
101
- **kwargs)
102
-
103
- def condition_mean(self, cond_fn, *args, **kwargs):
104
- return super().condition_mean(self._wrap_model(cond_fn), *args,
105
- **kwargs)
106
-
107
- def condition_score(self, cond_fn, *args, **kwargs):
108
- return super().condition_score(self._wrap_model(cond_fn), *args,
109
- **kwargs)
110
-
111
- def _wrap_model(self, model: Model):
112
- if isinstance(model, _WrappedModel):
113
- return model
114
- return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115
- self.original_num_steps)
116
-
117
- def _scale_timesteps(self, t):
118
- # Scaling is done by the wrapped model.
119
- return t
120
-
121
-
122
- class _WrappedModel:
123
- """
124
- Converts the supplied timesteps into the original diffusion process's timestep scale.
125
- """
126
- def __init__(self, model, timestep_map, rescale_timesteps,
127
- original_num_steps):
128
- self.model = model
129
- self.timestep_map = timestep_map
130
- self.rescale_timesteps = rescale_timesteps
131
- self.original_num_steps = original_num_steps
132
-
133
- def forward(self, motion_start, motion_direction_start, audio_feats, face_location, face_scale, yaw_pitch_roll, x_t, t, control_flag=False):
134
- """
135
- Args:
136
- t: t's with different ranges (can be << T due to a smaller eval T); they need to be converted to the original t's
137
- t_cond: the same as t but can be of different values
138
- """
139
- map_tensor = th.tensor(self.timestep_map,
140
- device=t.device,
141
- dtype=t.dtype)
142
-
143
- def do(t):
144
- new_ts = map_tensor[t]
145
- if self.rescale_timesteps:
146
- new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147
- return new_ts
148
-
149
- return self.model(motion_start, motion_direction_start, audio_feats, face_location, face_scale, yaw_pitch_roll, x_t, do(t), control_flag=control_flag)
150
-
151
- def __getattr__(self, name):
152
- # allow for calling the model's methods
153
- if hasattr(self.model, name):
154
- func = getattr(self.model, name)
155
- return func
156
- raise AttributeError(name)
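A small usage sketch for space_timesteps, mirroring the examples in its docstring (illustrative only; assumes the function is imported from this module):

# 300 original steps, three equal sections keeping 10, 15 and 20 steps respectively
steps = space_timesteps(300, [10, 15, 20])
assert len(steps) == 45 and min(steps) == 0 and max(steps) == 299

# the "ddimN" shortcut picks an integer stride: 1000 steps -> every 20th step
ddim_steps = space_timesteps(1000, "ddim50")
assert ddim_steps == set(range(0, 1000, 20))
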
code/diffusion/resample.py DELETED
@@ -1,63 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- import numpy as np
4
- import torch as th
5
- import torch.distributed as dist
6
-
7
-
8
- def create_named_schedule_sampler(name, diffusion):
9
- """
10
- Create a ScheduleSampler from a library of pre-defined samplers.
11
-
12
- :param name: the name of the sampler.
13
- :param diffusion: the diffusion object to sample for.
14
- """
15
- if name == "uniform":
16
- return UniformSampler(diffusion.num_timesteps)
17
- else:
18
- raise NotImplementedError(f"unknown schedule sampler: {name}")
19
-
20
-
21
- class ScheduleSampler(ABC):
22
- """
23
- A distribution over timesteps in the diffusion process, intended to reduce
24
- variance of the objective.
25
-
26
- By default, samplers perform unbiased importance sampling, in which the
27
- objective's mean is unchanged.
28
- However, subclasses may override sample() to change how the resampled
29
- terms are reweighted, allowing for actual changes in the objective.
30
- """
31
- @abstractmethod
32
- def weights(self):
33
- """
34
- Get a numpy array of weights, one per diffusion step.
35
-
36
- The weights needn't be normalized, but must be positive.
37
- """
38
-
39
- def sample(self, batch_size, device):
40
- """
41
- Importance-sample timesteps for a batch.
42
-
43
- :param batch_size: the number of timesteps.
44
- :param device: the torch device to save to.
45
- :return: a tuple (timesteps, weights):
46
- - timesteps: a tensor of timestep indices.
47
- - weights: a tensor of weights to scale the resulting losses.
48
- """
49
- w = self.weights()
50
- p = w / np.sum(w)
51
- indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52
- indices = th.from_numpy(indices_np).long().to(device)
53
- weights_np = 1 / (len(p) * p[indices_np])
54
- weights = th.from_numpy(weights_np).float().to(device)
55
- return indices, weights
56
-
57
-
58
- class UniformSampler(ScheduleSampler):
59
- def __init__(self, num_timesteps):
60
- self._weights = np.ones([num_timesteps])
61
-
62
- def weights(self):
63
- return self._weights
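Two notes on the sampler above. First, create_named_schedule_sampler hands the diffusion object to UniformSampler, whose constructor expects a step count, so callers presumably supplied the number of timesteps (or the constructor was meant to read it off the diffusion object); treat that as an inconsistency in the deleted code rather than a documented contract. Second, with uniform weights the importance weights returned by sample() are all ones. A short sketch that inlines the same sampling logic; the step count and batch size are placeholders.

    import numpy as np
    import torch as th

    num_timesteps = 1000                      # placeholder diffusion length
    w = np.ones([num_timesteps])              # UniformSampler.weights()
    p = w / np.sum(w)

    batch_size = 8
    indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
    indices = th.from_numpy(indices_np).long()
    weights = th.from_numpy(1 / (len(p) * p[indices_np])).float()  # all ones for uniform p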
code/dist_utils.py DELETED
@@ -1,42 +0,0 @@
1
- from typing import List
2
- from torch import distributed
3
-
4
-
5
- def barrier():
6
- if distributed.is_initialized():
7
- distributed.barrier()
8
- else:
9
- pass
10
-
11
-
12
- def broadcast(data, src):
13
- if distributed.is_initialized():
14
- distributed.broadcast(data, src)
15
- else:
16
- pass
17
-
18
-
19
- def all_gather(data: List, src):
20
- if distributed.is_initialized():
21
- distributed.all_gather(data, src)
22
- else:
23
- data[0] = src
24
-
25
-
26
- def get_rank():
27
- if distributed.is_initialized():
28
- return distributed.get_rank()
29
- else:
30
- return 0
31
-
32
-
33
- def get_world_size():
34
- if distributed.is_initialized():
35
- return distributed.get_world_size()
36
- else:
37
- return 1
38
-
39
-
40
- def chunk_size(size, rank, world_size):
41
- extra = rank < size % world_size
42
- return size // world_size + extra
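These helpers are written to degrade gracefully on a single process: every function falls back to a no-op or a trivial value when torch.distributed is not initialized. The only non-obvious one is chunk_size, which hands the remainder items to the lowest-ranked workers; a quick check with illustrative numbers:

    def chunk_size(size, rank, world_size):
        extra = rank < size % world_size
        return size // world_size + extra

    # 10 items across 3 workers -> ranks get 4, 3, 3
    print([chunk_size(10, r, 3) for r in range(3)])  # [4, 3, 3]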
code/experiment.py DELETED
@@ -1,356 +0,0 @@
1
- import copy
2
- import os
3
-
4
- import numpy as np
5
- import pytorch_lightning as pl
6
- import torch
7
- from pytorch_lightning import loggers as pl_loggers
8
- from pytorch_lightning.callbacks import *
9
- from torch.cuda import amp
10
- from torch.optim.optimizer import Optimizer
11
- from torch.utils.data.dataset import TensorDataset
12
- from model.seq2seq import DiffusionPredictor
13
-
14
- from config import *
15
- from dist_utils import *
16
- from renderer import *
17
-
18
- # This part is modified from: https://github.com/phizaz/diffae/blob/master/experiment.py
19
- class LitModel(pl.LightningModule):
20
- def __init__(self, conf: TrainConfig):
21
- super().__init__()
22
- assert conf.train_mode != TrainMode.manipulate
23
- if conf.seed is not None:
24
- pl.seed_everything(conf.seed)
25
-
26
- self.save_hyperparameters(conf.as_dict_jsonable())
27
-
28
- self.conf = conf
29
-
30
- self.model = DiffusionPredictor(conf)
31
-
32
- self.ema_model = copy.deepcopy(self.model)
33
- self.ema_model.requires_grad_(False)
34
- self.ema_model.eval()
35
-
36
- self.sampler = conf.make_diffusion_conf().make_sampler()
37
- self.eval_sampler = conf.make_eval_diffusion_conf().make_sampler()
38
-
39
- # this is shared for both model and latent
40
- self.T_sampler = conf.make_T_sampler()
41
-
42
- if conf.train_mode.use_latent_net():
43
- self.latent_sampler = conf.make_latent_diffusion_conf(
44
- ).make_sampler()
45
- self.eval_latent_sampler = conf.make_latent_eval_diffusion_conf(
46
- ).make_sampler()
47
- else:
48
- self.latent_sampler = None
49
- self.eval_latent_sampler = None
50
-
51
- # initial variables for consistent sampling
52
- self.register_buffer(
53
- 'x_T',
54
- torch.randn(conf.sample_size, 3, conf.img_size, conf.img_size))
55
-
56
-
57
- def render(self, start, motion_direction_start, audio_driven, face_location, face_scale, ypr_info, noisyT, step_T, control_flag):
58
- if step_T is None:
59
- sampler = self.eval_sampler
60
- else:
61
- sampler = self.conf._make_diffusion_conf(step_T).make_sampler()
62
-
63
- pred_img = render_condition(self.conf,
64
- self.ema_model,
65
- sampler, start, motion_direction_start, audio_driven, face_location, face_scale, ypr_info, noisyT, control_flag)
66
- return pred_img
67
-
68
- def forward(self, noise=None, x_start=None, ema_model: bool = False):
69
- with amp.autocast(False):
70
- if not self.disable_ema:
71
- model = self.ema_model
72
- else:
73
- model = self.model
74
- gen = self.eval_sampler.sample(model=model,
75
- noise=noise,
76
- x_start=x_start)
77
- return gen
78
-
79
- def setup(self, stage=None) -> None:
80
- """
81
- make datasets & seeding each worker separately
82
- """
83
- ##############################################
84
- # NEED TO SET THE SEED SEPARATELY HERE
85
- if self.conf.seed is not None:
86
- seed = self.conf.seed * get_world_size() + self.global_rank
87
- np.random.seed(seed)
88
- torch.manual_seed(seed)
89
- torch.cuda.manual_seed(seed)
90
- print('local seed:', seed)
91
- ##############################################
92
-
93
- self.train_data = self.conf.make_dataset()
94
- print('train data:', len(self.train_data))
95
- self.val_data = self.train_data
96
- print('val data:', len(self.val_data))
97
-
98
- def _train_dataloader(self, drop_last=True):
99
- """
100
- really make the dataloader
101
- """
102
- # make sure to use the fraction of batch size
103
- # the batch size is global!
104
- conf = self.conf.clone()
105
- conf.batch_size = self.batch_size
106
-
107
- dataloader = conf.make_loader(self.train_data,
108
- shuffle=True,
109
- drop_last=drop_last)
110
- return dataloader
111
-
112
- def train_dataloader(self):
113
- """
114
- return the dataloader, if diffusion mode => return image dataset
115
- if latent mode => return the inferred latent dataset
116
- """
117
- print('on train dataloader start ...')
118
- if self.conf.train_mode.require_dataset_infer():
119
- if self.conds is None:
120
- # usually we load self.conds from a file
121
- # so we do not need to do this again!
122
- self.conds = self.infer_whole_dataset()
123
- # need to use float32! unless the mean & std will be off!
124
- # (1, c)
125
- self.conds_mean.data = self.conds.float().mean(dim=0,
126
- keepdim=True)
127
- self.conds_std.data = self.conds.float().std(dim=0,
128
- keepdim=True)
129
- print('mean:', self.conds_mean.mean(), 'std:',
130
- self.conds_std.mean())
131
-
132
- # return the dataset with pre-calculated conds
133
- conf = self.conf.clone()
134
- conf.batch_size = self.batch_size
135
- data = TensorDataset(self.conds)
136
- return conf.make_loader(data, shuffle=True)
137
- else:
138
- return self._train_dataloader()
139
-
140
- @property
141
- def batch_size(self):
142
- """
143
- local batch size for each worker
144
- """
145
- ws = get_world_size()
146
- assert self.conf.batch_size % ws == 0
147
- return self.conf.batch_size // ws
148
-
149
- @property
150
- def num_samples(self):
151
- """
152
- (global) batch size * iterations
153
- """
154
- # batch size here is global!
155
- # global_step already takes into account the accum batches
156
- return self.global_step * self.conf.batch_size_effective
157
-
158
- def is_last_accum(self, batch_idx):
159
- """
160
- is it the last gradient accumulation loop?
161
- used with gradient_accum > 1 and to see if the optimizer will perform "step" in this iteration or not
162
- """
163
- return (batch_idx + 1) % self.conf.accum_batches == 0
164
-
165
- def training_step(self, batch, batch_idx):
166
- """
167
- given an input, calculate the loss function
168
- no optimization at this stage.
169
- """
170
- with amp.autocast(False):
171
- motion_start = batch['motion_start'] # torch.Size([B, 512])
172
- motion_direction = batch['motion_direction'] # torch.Size([B, 125, 20])
173
- audio_feats = batch['audio_feats'].float() # torch.Size([B, 25, 250, 1024])
174
- face_location = batch['face_location'].float() # torch.Size([B, 125])
175
- face_scale = batch['face_scale'].float() # torch.Size([B, 125, 1])
176
- yaw_pitch_roll = batch['yaw_pitch_roll'].float() # torch.Size([B, 125, 3])
177
- motion_direction_start = batch['motion_direction_start'].float() # torch.Size([B, 20])
178
-
179
- # import pdb; pdb.set_trace()
180
- if self.conf.train_mode == TrainMode.diffusion:
181
- """
182
- main training mode!!!
183
- """
184
- # with numpy seed we have the problem that the sample t's are related!
185
- t, weight = self.T_sampler.sample(len(motion_start), motion_start.device)
186
- losses = self.sampler.training_losses(model=self.model,
187
- motion_direction_start=motion_direction_start,
188
- motion_target=motion_direction,
189
- motion_start=motion_start,
190
- audio_feats=audio_feats,
191
- face_location=face_location,
192
- face_scale=face_scale,
193
- yaw_pitch_roll=yaw_pitch_roll,
194
- t=t)
195
- else:
196
- raise NotImplementedError()
197
-
198
- loss = losses['loss'].mean()
199
- # divide by accum batches to make the accumulated gradient exact!
200
- for key in losses.keys():
201
- losses[key] = self.all_gather(losses[key]).mean()
202
-
203
- if self.global_rank == 0:
204
- self.logger.experiment.add_scalar('loss', losses['loss'],
205
- self.num_samples)
206
- for key in losses:
207
- self.logger.experiment.add_scalar(
208
- f'loss/{key}', losses[key], self.num_samples)
209
-
210
- return {'loss': loss}
211
-
212
- def on_train_batch_end(self, outputs, batch, batch_idx: int,
213
- dataloader_idx: int) -> None:
214
- """
215
- after each training step ...
216
- """
217
- if self.is_last_accum(batch_idx):
218
-
219
- if self.conf.train_mode == TrainMode.latent_diffusion:
220
- # it trains only the latent hence change only the latent
221
- ema(self.model.latent_net, self.ema_model.latent_net,
222
- self.conf.ema_decay)
223
- else:
224
- ema(self.model, self.ema_model, self.conf.ema_decay)
225
-
226
- def on_before_optimizer_step(self, optimizer: Optimizer,
227
- optimizer_idx: int) -> None:
228
- # fix the fp16 + clip grad norm problem with pytorch lightning
229
- # this is the currently correct way to do it
230
- if self.conf.grad_clip > 0:
231
- # from trainer.params_grads import grads_norm, iter_opt_params
232
- params = [
233
- p for group in optimizer.param_groups for p in group['params']
234
- ]
235
- torch.nn.utils.clip_grad_norm_(params,
236
- max_norm=self.conf.grad_clip)
237
- def configure_optimizers(self):
238
- out = {}
239
- if self.conf.optimizer == OptimizerType.adam:
240
- optim = torch.optim.Adam(self.model.parameters(),
241
- lr=self.conf.lr,
242
- weight_decay=self.conf.weight_decay)
243
- elif self.conf.optimizer == OptimizerType.adamw:
244
- optim = torch.optim.AdamW(self.model.parameters(),
245
- lr=self.conf.lr,
246
- weight_decay=self.conf.weight_decay)
247
- else:
248
- raise NotImplementedError()
249
- out['optimizer'] = optim
250
- if self.conf.warmup > 0:
251
- sched = torch.optim.lr_scheduler.LambdaLR(optim,
252
- lr_lambda=WarmupLR(
253
- self.conf.warmup))
254
- out['lr_scheduler'] = {
255
- 'scheduler': sched,
256
- 'interval': 'step',
257
- }
258
- return out
259
-
260
- def split_tensor(self, x):
261
- """
262
- extract the tensor for a corresponding "worker" in the batch dimension
263
-
264
- Args:
265
- x: (n, c)
266
-
267
- Returns: x: (n_local, c)
268
- """
269
- n = len(x)
270
- rank = self.global_rank
271
- world_size = get_world_size()
272
- # print(f'rank: {rank}/{world_size}')
273
- per_rank = n // world_size
274
- return x[rank * per_rank:(rank + 1) * per_rank]
275
-
276
- def ema(source, target, decay):
277
- source_dict = source.state_dict()
278
- target_dict = target.state_dict()
279
- for key in source_dict.keys():
280
- target_dict[key].data.copy_(target_dict[key].data * decay +
281
- source_dict[key].data * (1 - decay))
282
-
283
-
284
- class WarmupLR:
285
- def __init__(self, warmup) -> None:
286
- self.warmup = warmup
287
-
288
- def __call__(self, step):
289
- return min(step, self.warmup) / self.warmup
290
-
291
-
292
- def is_time(num_samples, every, step_size):
293
- closest = (num_samples // every) * every
294
- return num_samples - closest < step_size
295
-
296
-
297
- def train(conf: TrainConfig, gpus, nodes=1, mode: str = 'train'):
298
- print('conf:', conf.name)
299
- # assert not (conf.fp16 and conf.grad_clip > 0
300
- # ), 'pytorch lightning has bug with amp + gradient clipping'
301
- model = LitModel(conf)
302
-
303
- if not os.path.exists(conf.logdir):
304
- os.makedirs(conf.logdir)
305
- checkpoint = ModelCheckpoint(dirpath=f'{conf.logdir}',
306
- save_last=True,
307
- save_top_k=-1,
308
- every_n_epochs=10)
309
- checkpoint_path = f'{conf.logdir}/last.ckpt'
310
- print('ckpt path:', checkpoint_path)
311
- if os.path.exists(checkpoint_path):
312
- resume = checkpoint_path
313
- print('resume!')
314
- else:
315
- if conf.continue_from is not None:
316
- # continue from a checkpoint
317
- resume = conf.continue_from.path
318
- else:
319
- resume = None
320
-
321
- tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
322
- name=None,
323
- version='')
324
-
325
- # from pytorch_lightning.
326
-
327
- plugins = []
328
- if len(gpus) == 1 and nodes == 1:
329
- accelerator = None
330
- else:
331
- accelerator = 'ddp'
332
- from pytorch_lightning.plugins import DDPPlugin
333
-
334
- # important for working with gradient checkpoint
335
- plugins.append(DDPPlugin(find_unused_parameters=True))
336
-
337
- trainer = pl.Trainer(
338
- max_steps=conf.total_samples // conf.batch_size_effective,
339
- resume_from_checkpoint=resume,
340
- gpus=gpus,
341
- num_nodes=nodes,
342
- accelerator=accelerator,
343
- precision=16 if conf.fp16 else 32,
344
- callbacks=[
345
- checkpoint,
346
- LearningRateMonitor(),
347
- ],
348
- # clip in the model instead
349
- # gradient_clip_val=conf.grad_clip,
350
- replace_sampler_ddp=True,
351
- logger=tb_logger,
352
- accumulate_grad_batches=conf.accum_batches,
353
- plugins=plugins,
354
- )
355
-
356
- trainer.fit(model)
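After each effective optimizer step, on_train_batch_end pushes the current weights into an exponential moving average copy via ema(). The update rule itself is just a convex blend of the two state dicts; a standalone sketch on plain tensors with an illustrative decay value:

    import torch

    decay = 0.999                          # illustrative EMA decay
    source = {'w': torch.ones(3)}          # current model weights
    target = {'w': torch.zeros(3)}         # EMA copy

    for key in source:
        target[key].data.copy_(target[key].data * decay + source[key].data * (1 - decay))

    print(target['w'])                     # tensor([0.0010, 0.0010, 0.0010])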
code/face_sr/face_enhancer.py DELETED
@@ -1,123 +0,0 @@
1
- import os
2
- import torch
3
-
4
- from gfpgan import GFPGANer
5
-
6
- from tqdm import tqdm
7
-
8
- from .videoio import load_video_to_cv2
9
-
10
- import cv2
11
-
12
-
13
- class GeneratorWithLen(object):
14
- """ From https://stackoverflow.com/a/7460929 """
15
-
16
- def __init__(self, gen, length):
17
- self.gen = gen
18
- self.length = length
19
-
20
- def __len__(self):
21
- return self.length
22
-
23
- def __iter__(self):
24
- return self.gen
25
-
26
- def enhancer_list(images, method='gfpgan', bg_upsampler='realesrgan'):
27
- gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
28
- return list(gen)
29
-
30
- def enhancer_generator_with_len(images, method='gfpgan', bg_upsampler='realesrgan'):
31
- """ Provide a generator with a __len__ method so that it can be passed to functions that
32
- call len()"""
33
-
34
- if os.path.isfile(images): # handle video to images
35
- # TODO: Create a generator version of load_video_to_cv2
36
- images = load_video_to_cv2(images)
37
-
38
- gen = enhancer_generator_no_len(images, method=method, bg_upsampler=bg_upsampler)
39
- gen_with_len = GeneratorWithLen(gen, len(images))
40
- return gen_with_len
41
-
42
- def enhancer_generator_no_len(images, method='gfpgan', bg_upsampler='realesrgan'):
43
- """ Provide a generator function so that all of the enhanced images don't need
44
- to be stored in memory at the same time. This can save tons of RAM compared to
45
- the enhancer function. """
46
-
47
- print('face enhancer....')
48
- if not isinstance(images, list) and os.path.isfile(images): # handle video to images
49
- images = load_video_to_cv2(images)
50
-
51
- # ------------------------ set up GFPGAN restorer ------------------------
52
- if method == 'gfpgan':
53
- arch = 'clean'
54
- channel_multiplier = 2
55
- model_name = 'GFPGANv1.4'
56
- url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
57
- elif method == 'RestoreFormer':
58
- arch = 'RestoreFormer'
59
- channel_multiplier = 2
60
- model_name = 'RestoreFormer'
61
- url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
62
- elif method == 'codeformer': # TODO:
63
- arch = 'CodeFormer'
64
- channel_multiplier = 2
65
- model_name = 'CodeFormer'
66
- url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
67
- else:
68
- raise ValueError(f'Wrong model version {method}.')
69
-
70
-
71
- # ------------------------ set up background upsampler ------------------------
72
- if bg_upsampler == 'realesrgan':
73
- if not torch.cuda.is_available(): # CPU
74
- import warnings
75
- warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
76
- 'If you really want to use it, please modify the corresponding codes.')
77
- bg_upsampler = None
78
- else:
79
- from basicsr.archs.rrdbnet_arch import RRDBNet
80
- from realesrgan import RealESRGANer
81
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
82
- bg_upsampler = RealESRGANer(
83
- scale=2,
84
- model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
85
- model=model,
86
- tile=400,
87
- tile_pad=10,
88
- pre_pad=0,
89
- half=True) # need to set False in CPU mode
90
- else:
91
- bg_upsampler = None
92
-
93
- # determine model paths
94
- model_path = os.path.join('gfpgan/weights', model_name + '.pth')
95
-
96
- if not os.path.isfile(model_path):
97
- model_path = os.path.join('checkpoints', model_name + '.pth')
98
-
99
- if not os.path.isfile(model_path):
100
- # download pre-trained models from url
101
- model_path = url
102
-
103
- restorer = GFPGANer(
104
- model_path=model_path,
105
- upscale=2,
106
- arch=arch,
107
- channel_multiplier=channel_multiplier,
108
- bg_upsampler=bg_upsampler)
109
-
110
- # ------------------------ restore ------------------------
111
- for idx in tqdm(range(len(images)), 'Face Enhancer:'):
112
-
113
- img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
114
-
115
- # restore faces and background if necessary
116
- cropped_faces, restored_faces, r_img = restorer.enhance(
117
- img,
118
- has_aligned=False,
119
- only_center_face=False,
120
- paste_back=True)
121
-
122
- r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
123
- yield r_img
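The three entry points differ only in how results are consumed: enhancer_list materializes everything, while the generator variants yield frames lazily so a long video never has to sit in memory twice (enhancer_generator_with_len additionally probes os.path.isfile, so it is meant to be handed a video path rather than an in-memory list). A hedged usage sketch of the lazy variant; the import path, the placeholder frames, and a working gfpgan installation are assumptions, and the first call may download the GFPGAN weights.

    import numpy as np
    from face_sr.face_enhancer import enhancer_generator_no_len  # assumed import path

    frames = [np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(4)]  # placeholder RGB frames
    for enhanced in enhancer_generator_no_len(frames, method='gfpgan', bg_upsampler=None):
        print(enhanced.shape)   # frames are restored one at a time, upscaled 2x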
code/face_sr/videoio.py DELETED
@@ -1,41 +0,0 @@
1
- import shutil
2
- import uuid
3
-
4
- import os
5
-
6
- import cv2
7
-
8
- def load_video_to_cv2(input_path):
9
- video_stream = cv2.VideoCapture(input_path)
10
- fps = video_stream.get(cv2.CAP_PROP_FPS)
11
- full_frames = []
12
- while 1:
13
- still_reading, frame = video_stream.read()
14
- if not still_reading:
15
- video_stream.release()
16
- break
17
- full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
18
- return full_frames
19
-
20
- def save_video_with_watermark(video, audio, save_path, watermark=False):
21
- temp_file = str(uuid.uuid4())+'.mp4'
22
- cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
23
- os.system(cmd)
24
-
25
- if watermark is False:
26
- shutil.move(temp_file, save_path)
27
- else:
28
- # watermark
29
- try:
30
- ##### check if stable-diffusion-webui
31
- import webui
32
- from modules import paths
33
- watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
34
- except:
35
- # get the root path of sadtalker.
36
- dir_path = os.path.dirname(os.path.realpath(__file__))
37
- watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
38
-
39
- cmd = r'ffmpeg -y -hide_banner -loglevel error -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
40
- os.system(cmd)
41
- os.remove(temp_file)
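save_video_with_watermark shells out to ffmpeg at most twice: once to mux the audio onto the silent video, and once more to overlay the logo. A sketch of the same mux step using subprocess with an argument list, which avoids the quoting pitfalls of os.system; the file names are placeholders, and the ffmpeg flags match the command built above.

    import subprocess

    video, audio, out = 'silent.mp4', 'speech.wav', 'final.mp4'   # placeholder paths
    subprocess.run([
        'ffmpeg', '-y', '-hide_banner', '-loglevel', 'error',
        '-i', video, '-i', audio,
        '-vcodec', 'copy', out,
    ], check=True)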
code/model/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from typing import Union
2
- from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3
- from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4
-
5
- Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6
- ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
code/model/base.py DELETED
@@ -1,37 +0,0 @@
1
- # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
- # This program is free software; you can redistribute it and/or modify
3
- # it under the terms of the MIT License.
4
- # This program is distributed in the hope that it will be useful,
5
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
- # MIT License for more details.
8
-
9
- import numpy as np
10
- import torch
11
-
12
-
13
- class BaseModule(torch.nn.Module):
14
- def __init__(self):
15
- super(BaseModule, self).__init__()
16
-
17
- @property
18
- def nparams(self):
19
- """
20
- Returns number of trainable parameters of the module.
21
- """
22
- num_params = 0
23
- for name, param in self.named_parameters():
24
- if param.requires_grad:
25
- num_params += np.prod(param.detach().cpu().numpy().shape)
26
- return num_params
27
-
28
-
29
- def relocate_input(self, x: list):
30
- """
31
- Relocates provided tensors to the same device set for the module.
32
- """
33
- device = next(self.parameters()).device
34
- for i in range(len(x)):
35
- if isinstance(x[i], torch.Tensor) and x[i].device != device:
36
- x[i] = x[i].to(device)
37
- return x
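nparams simply counts trainable parameters; the same number can be obtained with a one-liner, which may be easier to reuse outside this class (the Linear layer below is only an example module):

    import torch

    layer = torch.nn.Linear(4, 2)
    n = sum(p.numel() for p in layer.parameters() if p.requires_grad)
    print(n)   # 4*2 weights + 2 biases = 10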
code/model/blocks.py DELETED
@@ -1,567 +0,0 @@
1
- import math
2
- from abc import abstractmethod
3
- from dataclasses import dataclass
4
- from numbers import Number
5
-
6
- import torch as th
7
- import torch.nn.functional as F
8
- from choices import *
9
- from config_base import BaseConfig
10
- from torch import nn
11
-
12
- from .nn import (avg_pool_nd, conv_nd, linear, normalization,
13
- timestep_embedding, torch_checkpoint, zero_module)
14
-
15
-
16
- class ScaleAt(Enum):
17
- after_norm = 'afternorm'
18
-
19
-
20
- class TimestepBlock(nn.Module):
21
- """
22
- Any module where forward() takes timestep embeddings as a second argument.
23
- """
24
- @abstractmethod
25
- def forward(self, x, emb=None, cond=None, lateral=None):
26
- """
27
- Apply the module to `x` given `emb` timestep embeddings.
28
- """
29
-
30
-
31
- class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32
- """
33
- A sequential module that passes timestep embeddings to the children that
34
- support it as an extra input.
35
- """
36
- def forward(self, x, emb=None, cond=None, lateral=None):
37
- for layer in self:
38
- if isinstance(layer, TimestepBlock):
39
- x = layer(x, emb=emb, cond=cond, lateral=lateral)
40
- else:
41
- x = layer(x)
42
- return x
43
-
44
-
45
- @dataclass
46
- class ResBlockConfig(BaseConfig):
47
- channels: int
48
- emb_channels: int
49
- dropout: float
50
- out_channels: int = None
51
- # condition the resblock with time (and encoder's output)
52
- use_condition: bool = True
53
- # whether to use 3x3 conv for skip path when the channels aren't matched
54
- use_conv: bool = False
55
- # dimension of conv (always 2 = 2d)
56
- dims: int = 2
57
- # gradient checkpoint
58
- use_checkpoint: bool = False
59
- up: bool = False
60
- down: bool = False
61
- # whether to condition with both time & encoder's output
62
- two_cond: bool = False
63
- # number of encoders' output channels
64
- cond_emb_channels: int = None
65
- # suggest: False
66
- has_lateral: bool = False
67
- lateral_channels: int = None
68
- # whether to init the convolution with zero weights
69
- # this is default from BeatGANs and seems to help learning
70
- use_zero_module: bool = True
71
-
72
- def __post_init__(self):
73
- self.out_channels = self.out_channels or self.channels
74
- self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
75
-
76
- def make_model(self):
77
- return ResBlock(self)
78
-
79
-
80
- class ResBlock(TimestepBlock):
81
- """
82
- A residual block that can optionally change the number of channels.
83
-
84
- total layers:
85
- in_layers
86
- - norm
87
- - act
88
- - conv
89
- out_layers
90
- - norm
91
- - (modulation)
92
- - act
93
- - conv
94
- """
95
- def __init__(self, conf: ResBlockConfig):
96
- super().__init__()
97
- self.conf = conf
98
-
99
- #############################
100
- # IN LAYERS
101
- #############################
102
- assert conf.lateral_channels is None
103
- layers = [
104
- normalization(conf.channels),
105
- nn.SiLU(),
106
- conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
107
- ]
108
- self.in_layers = nn.Sequential(*layers)
109
-
110
- self.updown = conf.up or conf.down
111
-
112
- if conf.up:
113
- self.h_upd = Upsample(conf.channels, False, conf.dims)
114
- self.x_upd = Upsample(conf.channels, False, conf.dims)
115
- elif conf.down:
116
- self.h_upd = Downsample(conf.channels, False, conf.dims)
117
- self.x_upd = Downsample(conf.channels, False, conf.dims)
118
- else:
119
- self.h_upd = self.x_upd = nn.Identity()
120
-
121
- #############################
122
- # OUT LAYERS CONDITIONS
123
- #############################
124
- if conf.use_condition:
125
- # condition layers for the out_layers
126
- self.emb_layers = nn.Sequential(
127
- nn.SiLU(),
128
- linear(conf.emb_channels, 2 * conf.out_channels),
129
- )
130
-
131
- if conf.two_cond:
132
- self.cond_emb_layers = nn.Sequential(
133
- nn.SiLU(),
134
- linear(conf.cond_emb_channels, conf.out_channels),
135
- )
136
- #############################
137
- # OUT LAYERS (ignored when there is no condition)
138
- #############################
139
- # original version
140
- conv = conv_nd(conf.dims,
141
- conf.out_channels,
142
- conf.out_channels,
143
- 3,
144
- padding=1)
145
- if conf.use_zero_module:
146
- # zero out the weights
147
- # it seems to help training
148
- conv = zero_module(conv)
149
-
150
- # construct the layers
151
- # - norm
152
- # - (modulation)
153
- # - act
154
- # - dropout
155
- # - conv
156
- layers = []
157
- layers += [
158
- normalization(conf.out_channels),
159
- nn.SiLU(),
160
- nn.Dropout(p=conf.dropout),
161
- conv,
162
- ]
163
- self.out_layers = nn.Sequential(*layers)
164
-
165
- #############################
166
- # SKIP LAYERS
167
- #############################
168
- if conf.out_channels == conf.channels:
169
- # cannot be used with gatedconv, also gatedconv is always used as the first block
170
- self.skip_connection = nn.Identity()
171
- else:
172
- if conf.use_conv:
173
- kernel_size = 3
174
- padding = 1
175
- else:
176
- kernel_size = 1
177
- padding = 0
178
-
179
- self.skip_connection = conv_nd(conf.dims,
180
- conf.channels,
181
- conf.out_channels,
182
- kernel_size,
183
- padding=padding)
184
-
185
- def forward(self, x, emb=None, cond=None, lateral=None):
186
- """
187
- Apply the block to a Tensor, conditioned on a timestep embedding.
188
-
189
- Args:
190
- x: input
191
- lateral: lateral connection from the encoder
192
- """
193
- return torch_checkpoint(self._forward, (x, emb, cond, lateral),
194
- self.conf.use_checkpoint)
195
-
196
- def _forward(
197
- self,
198
- x,
199
- emb=None,
200
- cond=None,
201
- lateral=None,
202
- ):
203
- """
204
- Args:
205
- lateral: required if "has_lateral" is set and the block is non-gated; with gated blocks it may be supplied optionally
206
- """
207
- if self.conf.has_lateral:
208
- # lateral may be supplied even if it isn't required
209
- # the model will take the lateral only if "has_lateral"
210
- assert lateral is not None
211
- x = th.cat([x, lateral], dim=1)
212
-
213
- if self.updown:
214
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
215
- h = in_rest(x)
216
- h = self.h_upd(h)
217
- x = self.x_upd(x)
218
- h = in_conv(h)
219
- else:
220
- h = self.in_layers(x)
221
-
222
- if self.conf.use_condition:
223
- # it's possible that the network may not receive the time emb
224
- # this happens with autoenc and setting the time_at
225
- if emb is not None:
226
- emb_out = self.emb_layers(emb).type(h.dtype)
227
- else:
228
- emb_out = None
229
-
230
- if self.conf.two_cond:
231
- # it's possible that the network is two_cond
232
- # but it doesn't get the second condition
233
- # in which case, we ignore the second condition
234
- # and treat as if the network has one condition
235
- if cond is None:
236
- cond_out = None
237
- else:
238
- cond_out = self.cond_emb_layers(cond).type(h.dtype)
239
-
240
- if cond_out is not None:
241
- while len(cond_out.shape) < len(h.shape):
242
- cond_out = cond_out[..., None]
243
- else:
244
- cond_out = None
245
-
246
- # this is the new refactored code
247
- h = apply_conditions(
248
- h=h,
249
- emb=emb_out,
250
- cond=cond_out,
251
- layers=self.out_layers,
252
- scale_bias=1,
253
- in_channels=self.conf.out_channels,
254
- up_down_layer=None,
255
- )
256
-
257
- return self.skip_connection(x) + h
258
-
259
-
260
- def apply_conditions(
261
- h,
262
- emb=None,
263
- cond=None,
264
- layers: nn.Sequential = None,
265
- scale_bias: float = 1,
266
- in_channels: int = 512,
267
- up_down_layer: nn.Module = None,
268
- ):
269
- """
270
- apply conditions on the feature maps
271
-
272
- Args:
273
- emb: time conditional (ready to scale + shift)
274
- cond: encoder's conditional (read to scale + shift)
275
- """
276
- two_cond = emb is not None and cond is not None
277
-
278
- if emb is not None:
279
- # adjusting shapes
280
- while len(emb.shape) < len(h.shape):
281
- emb = emb[..., None]
282
-
283
- if two_cond:
284
- # adjusting shapes
285
- while len(cond.shape) < len(h.shape):
286
- cond = cond[..., None]
287
- # time first
288
- scale_shifts = [emb, cond]
289
- else:
290
- # "cond" is not used with single cond mode
291
- scale_shifts = [emb]
292
-
293
- # support scale, shift or shift only
294
- for i, each in enumerate(scale_shifts):
295
- if each is None:
296
- # special case: the condition is not provided
297
- a = None
298
- b = None
299
- else:
300
- if each.shape[1] == in_channels * 2:
301
- a, b = th.chunk(each, 2, dim=1)
302
- else:
303
- a = each
304
- b = None
305
- scale_shifts[i] = (a, b)
306
-
307
- # condition scale bias could be a list
308
- if isinstance(scale_bias, Number):
309
- biases = [scale_bias] * len(scale_shifts)
310
- else:
311
- # a list
312
- biases = scale_bias
313
-
314
- # default, the scale & shift are applied after the group norm but BEFORE SiLU
315
- pre_layers, post_layers = layers[0], layers[1:]
316
-
317
- # split the post layer to be able to scale up or down before conv
318
- # post layers will contain only the conv
319
- mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
320
-
321
- h = pre_layers(h)
322
- # scale and shift for each condition
323
- for i, (scale, shift) in enumerate(scale_shifts):
324
- # if scale is None, it indicates that the condition is not provided
325
- if scale is not None:
326
- h = h * (biases[i] + scale)
327
- if shift is not None:
328
- h = h + shift
329
- h = mid_layers(h)
330
-
331
- # upscale or downscale if any just before the last conv
332
- if up_down_layer is not None:
333
- h = up_down_layer(h)
334
- h = post_layers(h)
335
- return h
336
-
337
-
338
- class Upsample(nn.Module):
339
- """
340
- An upsampling layer with an optional convolution.
341
-
342
- :param channels: channels in the inputs and outputs.
343
- :param use_conv: a bool determining if a convolution is applied.
344
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
345
- upsampling occurs in the inner-two dimensions.
346
- """
347
- def __init__(self, channels, use_conv, dims=2, out_channels=None):
348
- super().__init__()
349
- self.channels = channels
350
- self.out_channels = out_channels or channels
351
- self.use_conv = use_conv
352
- self.dims = dims
353
- if use_conv:
354
- self.conv = conv_nd(dims,
355
- self.channels,
356
- self.out_channels,
357
- 3,
358
- padding=1)
359
-
360
- def forward(self, x):
361
- assert x.shape[1] == self.channels
362
- if self.dims == 3:
363
- x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
364
- mode="nearest")
365
- else:
366
- x = F.interpolate(x, scale_factor=2, mode="nearest")
367
- if self.use_conv:
368
- x = self.conv(x)
369
- return x
370
-
371
-
372
- class Downsample(nn.Module):
373
- """
374
- A downsampling layer with an optional convolution.
375
-
376
- :param channels: channels in the inputs and outputs.
377
- :param use_conv: a bool determining if a convolution is applied.
378
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
379
- downsampling occurs in the inner-two dimensions.
380
- """
381
- def __init__(self, channels, use_conv, dims=2, out_channels=None):
382
- super().__init__()
383
- self.channels = channels
384
- self.out_channels = out_channels or channels
385
- self.use_conv = use_conv
386
- self.dims = dims
387
- stride = 2 if dims != 3 else (1, 2, 2)
388
- if use_conv:
389
- self.op = conv_nd(dims,
390
- self.channels,
391
- self.out_channels,
392
- 3,
393
- stride=stride,
394
- padding=1)
395
- else:
396
- assert self.channels == self.out_channels
397
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
398
-
399
- def forward(self, x):
400
- assert x.shape[1] == self.channels
401
- return self.op(x)
402
-
403
-
404
- class AttentionBlock(nn.Module):
405
- """
406
- An attention block that allows spatial positions to attend to each other.
407
-
408
- Originally ported from here, but adapted to the N-d case.
409
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
410
- """
411
- def __init__(
412
- self,
413
- channels,
414
- num_heads=1,
415
- num_head_channels=-1,
416
- use_checkpoint=False,
417
- use_new_attention_order=False,
418
- ):
419
- super().__init__()
420
- self.channels = channels
421
- if num_head_channels == -1:
422
- self.num_heads = num_heads
423
- else:
424
- assert (
425
- channels % num_head_channels == 0
426
- ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
427
- self.num_heads = channels // num_head_channels
428
- self.use_checkpoint = use_checkpoint
429
- self.norm = normalization(channels)
430
- self.qkv = conv_nd(1, channels, channels * 3, 1)
431
- if use_new_attention_order:
432
- # split qkv before split heads
433
- self.attention = QKVAttention(self.num_heads)
434
- else:
435
- # split heads before split qkv
436
- self.attention = QKVAttentionLegacy(self.num_heads)
437
-
438
- self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
439
-
440
- def forward(self, x):
441
- return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
442
-
443
- def _forward(self, x):
444
- b, c, *spatial = x.shape
445
- x = x.reshape(b, c, -1)
446
- qkv = self.qkv(self.norm(x))
447
- h = self.attention(qkv)
448
- h = self.proj_out(h)
449
- return (x + h).reshape(b, c, *spatial)
450
-
451
-
452
- def count_flops_attn(model, _x, y):
453
- """
454
- A counter for the `thop` package to count the operations in an
455
- attention operation.
456
- Meant to be used like:
457
- macs, params = thop.profile(
458
- model,
459
- inputs=(inputs, timestamps),
460
- custom_ops={QKVAttention: QKVAttention.count_flops},
461
- )
462
- """
463
- b, c, *spatial = y[0].shape
464
- num_spatial = int(np.prod(spatial))
465
- # We perform two matmuls with the same number of ops.
466
- # The first computes the weight matrix, the second computes
467
- # the combination of the value vectors.
468
- matmul_ops = 2 * b * (num_spatial**2) * c
469
- model.total_ops += th.DoubleTensor([matmul_ops])
470
-
471
-
472
- class QKVAttentionLegacy(nn.Module):
473
- """
474
- A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
475
- """
476
- def __init__(self, n_heads):
477
- super().__init__()
478
- self.n_heads = n_heads
479
-
480
- def forward(self, qkv):
481
- """
482
- Apply QKV attention.
483
-
484
- :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
485
- :return: an [N x (H * C) x T] tensor after attention.
486
- """
487
- bs, width, length = qkv.shape
488
- assert width % (3 * self.n_heads) == 0
489
- ch = width // (3 * self.n_heads)
490
- q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
491
- dim=1)
492
- scale = 1 / math.sqrt(math.sqrt(ch))
493
- weight = th.einsum(
494
- "bct,bcs->bts", q * scale,
495
- k * scale) # More stable with f16 than dividing afterwards
496
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
497
- a = th.einsum("bts,bcs->bct", weight, v)
498
- return a.reshape(bs, -1, length)
499
-
500
- @staticmethod
501
- def count_flops(model, _x, y):
502
- return count_flops_attn(model, _x, y)
503
-
504
-
505
- class QKVAttention(nn.Module):
506
- """
507
- A module which performs QKV attention and splits in a different order.
508
- """
509
- def __init__(self, n_heads):
510
- super().__init__()
511
- self.n_heads = n_heads
512
-
513
- def forward(self, qkv):
514
- """
515
- Apply QKV attention.
516
-
517
- :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
518
- :return: an [N x (H * C) x T] tensor after attention.
519
- """
520
- bs, width, length = qkv.shape
521
- assert width % (3 * self.n_heads) == 0
522
- ch = width // (3 * self.n_heads)
523
- q, k, v = qkv.chunk(3, dim=1)
524
- scale = 1 / math.sqrt(math.sqrt(ch))
525
- weight = th.einsum(
526
- "bct,bcs->bts",
527
- (q * scale).view(bs * self.n_heads, ch, length),
528
- (k * scale).view(bs * self.n_heads, ch, length),
529
- ) # More stable with f16 than dividing afterwards
530
- weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
531
- a = th.einsum("bts,bcs->bct", weight,
532
- v.reshape(bs * self.n_heads, ch, length))
533
- return a.reshape(bs, -1, length)
534
-
535
- @staticmethod
536
- def count_flops(model, _x, y):
537
- return count_flops_attn(model, _x, y)
538
-
539
-
540
- class AttentionPool2d(nn.Module):
541
- """
542
- Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
543
- """
544
- def __init__(
545
- self,
546
- spacial_dim: int,
547
- embed_dim: int,
548
- num_heads_channels: int,
549
- output_dim: int = None,
550
- ):
551
- super().__init__()
552
- self.positional_embedding = nn.Parameter(
553
- th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
554
- self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
555
- self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
556
- self.num_heads = embed_dim // num_heads_channels
557
- self.attention = QKVAttention(self.num_heads)
558
-
559
- def forward(self, x):
560
- b, c, *_spatial = x.shape
561
- x = x.reshape(b, c, -1) # NC(HW)
562
- x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
563
- x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
564
- x = self.qkv_proj(x)
565
- x = self.attention(x)
566
- x = self.c_proj(x)
567
- return x[:, :, 0]
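The central trick in this file is the scale-and-shift conditioning applied by apply_conditions: the time embedding is projected to 2*out_channels, chunked into a scale and a shift, and applied right after the group norm as h * (scale_bias + scale) + shift. A minimal numeric sketch with an illustrative channel count:

    import torch as th

    C = 8
    h = th.randn(2, C, 16, 16)                  # feature map after normalization
    emb = th.randn(2, 2 * C)[..., None, None]   # projected time embedding, broadcast to NCHW

    scale, shift = th.chunk(emb, 2, dim=1)
    h = h * (1 + scale) + shift                 # scale_bias = 1, as used by ResBlock above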
code/model/diffusion.py DELETED
@@ -1,294 +0,0 @@
1
- # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
- # This program is free software; you can redistribute it and/or modify
3
- # it under the terms of the MIT License.
4
- # This program is distributed in the hope that it will be useful,
5
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
- # MIT License for more details.
8
-
9
- import math
10
- import torch
11
- from einops import rearrange
12
-
13
- from model.base import BaseModule
14
-
15
-
16
- class Mish(BaseModule):
17
- def forward(self, x):
18
- return x * torch.tanh(torch.nn.functional.softplus(x))
19
-
20
-
21
- class Upsample(BaseModule):
22
- def __init__(self, dim):
23
- super(Upsample, self).__init__()
24
- self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
25
-
26
- def forward(self, x):
27
- return self.conv(x)
28
-
29
-
30
- class Downsample(BaseModule):
31
- def __init__(self, dim):
32
- super(Downsample, self).__init__()
33
- self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)
34
-
35
- def forward(self, x):
36
- return self.conv(x)
37
-
38
-
39
- class Rezero(BaseModule):
40
- def __init__(self, fn):
41
- super(Rezero, self).__init__()
42
- self.fn = fn
43
- self.g = torch.nn.Parameter(torch.zeros(1))
44
-
45
- def forward(self, x):
46
- return self.fn(x) * self.g
47
-
48
-
49
- class Block(BaseModule):
50
- def __init__(self, dim, dim_out, groups=8):
51
- super(Block, self).__init__()
52
- self.block = torch.nn.Sequential(torch.nn.Conv2d(dim, dim_out, 3,
53
- padding=1), torch.nn.GroupNorm(
54
- groups, dim_out), Mish())
55
-
56
- def forward(self, x, mask):
57
- output = self.block(x * mask)
58
- return output * mask
59
-
60
-
61
- class ResnetBlock(BaseModule):
62
- def __init__(self, dim, dim_out, time_emb_dim, groups=8):
63
- super(ResnetBlock, self).__init__()
64
- self.mlp = torch.nn.Sequential(Mish(), torch.nn.Linear(time_emb_dim,
65
- dim_out))
66
-
67
- self.block1 = Block(dim, dim_out, groups=groups)
68
- self.block2 = Block(dim_out, dim_out, groups=groups)
69
- if dim != dim_out:
70
- self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
71
- else:
72
- self.res_conv = torch.nn.Identity()
73
-
74
- def forward(self, x, mask, time_emb):
75
- h = self.block1(x, mask)
76
- h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
77
- h = self.block2(h, mask)
78
- output = h + self.res_conv(x * mask)
79
- return output
80
-
81
-
82
- class LinearAttention(BaseModule):
83
- def __init__(self, dim, heads=4, dim_head=32):
84
- super(LinearAttention, self).__init__()
85
- self.heads = heads
86
- hidden_dim = dim_head * heads
87
- self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
88
- self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
89
-
90
- def forward(self, x):
91
- b, c, h, w = x.shape
92
- qkv = self.to_qkv(x)
93
- q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
94
- heads = self.heads, qkv=3)
95
- k = k.softmax(dim=-1)
96
- context = torch.einsum('bhdn,bhen->bhde', k, v)
97
- out = torch.einsum('bhde,bhdn->bhen', context, q)
98
- out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
99
- heads=self.heads, h=h, w=w)
100
- return self.to_out(out)
101
-
102
-
103
- class Residual(BaseModule):
104
- def __init__(self, fn):
105
- super(Residual, self).__init__()
106
- self.fn = fn
107
-
108
- def forward(self, x, *args, **kwargs):
109
- output = self.fn(x, *args, **kwargs) + x
110
- return output
111
-
112
-
113
- class SinusoidalPosEmb(BaseModule):
114
- def __init__(self, dim):
115
- super(SinusoidalPosEmb, self).__init__()
116
- self.dim = dim
117
-
118
- def forward(self, x, scale=1000):
119
- device = x.device
120
- half_dim = self.dim // 2
121
- emb = math.log(10000) / (half_dim - 1)
122
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
123
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
124
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
125
- return emb
126
-
127
-
128
- class GradLogPEstimator2d(BaseModule):
129
- def __init__(self, dim, dim_mults=(1, 2, 4), groups=8,
130
- n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
131
- super(GradLogPEstimator2d, self).__init__()
132
- self.dim = dim
133
- self.dim_mults = dim_mults
134
- self.groups = groups
135
- self.n_spks = n_spks if not isinstance(n_spks, type(None)) else 1
136
- self.spk_emb_dim = spk_emb_dim
137
- self.pe_scale = pe_scale
138
-
139
- if n_spks > 1:
140
- self.spk_mlp = torch.nn.Sequential(torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(),
141
- torch.nn.Linear(spk_emb_dim * 4, n_feats))
142
- self.time_pos_emb = SinusoidalPosEmb(dim)
143
- self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
144
- torch.nn.Linear(dim * 4, dim))
145
-
146
- dims = [2 + (1 if n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
147
- in_out = list(zip(dims[:-1], dims[1:]))
148
- self.downs = torch.nn.ModuleList([])
149
- self.ups = torch.nn.ModuleList([])
150
- num_resolutions = len(in_out)
151
-
152
- for ind, (dim_in, dim_out) in enumerate(in_out):
153
- is_last = ind >= (num_resolutions - 1)
154
- self.downs.append(torch.nn.ModuleList([
155
- ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
156
- ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
157
- Residual(Rezero(LinearAttention(dim_out))),
158
- Downsample(dim_out) if not is_last else torch.nn.Identity()]))
159
-
160
- mid_dim = dims[-1]
161
- self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
162
- self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
163
- self.mid_block2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
164
-
165
- for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
166
- self.ups.append(torch.nn.ModuleList([
167
- ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
168
- ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
169
- Residual(Rezero(LinearAttention(dim_in))),
170
- Upsample(dim_in)]))
171
- self.final_block = Block(dim, dim)
172
- self.final_conv = torch.nn.Conv2d(dim, 1, 1)
173
-
174
- def forward(self, x, mask, mu, t, spk=None):
175
- if not isinstance(spk, type(None)):
176
- s = self.spk_mlp(spk)
177
-
178
- t = self.time_pos_emb(t, scale=self.pe_scale)
179
- t = self.mlp(t)
180
-
181
- if self.n_spks < 2:
182
- x = torch.stack([mu, x], 1)
183
- else:
184
- s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
185
- x = torch.stack([mu, x, s], 1)
186
- mask = mask.unsqueeze(1)
187
-
188
- hiddens = []
189
- masks = [mask]
190
- for resnet1, resnet2, attn, downsample in self.downs:
191
- mask_down = masks[-1]
192
- x = resnet1(x, mask_down, t)
193
- x = resnet2(x, mask_down, t)
194
- x = attn(x)
195
- hiddens.append(x)
196
- x = downsample(x * mask_down)
197
- masks.append(mask_down[:, :, :, ::2])
198
-
199
- masks = masks[:-1]
200
- mask_mid = masks[-1]
201
- x = self.mid_block1(x, mask_mid, t)
202
- x = self.mid_attn(x)
203
- x = self.mid_block2(x, mask_mid, t)
204
-
205
- for resnet1, resnet2, attn, upsample in self.ups:
206
- mask_up = masks.pop()
207
- x = torch.cat((x, hiddens.pop()), dim=1)
208
- x = resnet1(x, mask_up, t)
209
- x = resnet2(x, mask_up, t)
210
- x = attn(x)
211
- x = upsample(x * mask_up)
212
-
213
- x = self.final_block(x, mask)
214
- output = self.final_conv(x * mask)
215
-
216
- return (output * mask).squeeze(1)
217
-
218
-
219
- def get_noise(t, beta_init, beta_term, cumulative=False):
220
- if cumulative:
221
- noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
222
- else:
223
- noise = beta_init + (beta_term - beta_init)*t
224
- return noise
225
-
226
-
227
- class Diffusion(BaseModule):
228
- def __init__(self, n_feats, dim,
229
- n_spks=1, spk_emb_dim=64,
230
- beta_min=0.05, beta_max=20, pe_scale=1000):
231
- super(Diffusion, self).__init__()
232
- self.n_feats = n_feats
233
- self.dim = dim
234
- self.n_spks = n_spks
235
- self.spk_emb_dim = spk_emb_dim
236
- self.beta_min = beta_min
237
- self.beta_max = beta_max
238
- self.pe_scale = pe_scale
239
-
240
- self.estimator = GradLogPEstimator2d(dim, n_spks=n_spks,
241
- spk_emb_dim=spk_emb_dim,
242
- pe_scale=pe_scale)
243
-
244
- def forward_diffusion(self, x0, mask, mu, t):
245
- time = t.unsqueeze(-1).unsqueeze(-1)
246
- cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
247
- mean = x0*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
248
- variance = 1.0 - torch.exp(-cum_noise)
249
- z = torch.randn(x0.shape, dtype=x0.dtype, device=x0.device,
250
- requires_grad=False)
251
- xt = mean + z * torch.sqrt(variance)
252
- return xt * mask, z * mask
253
-
254
- @torch.no_grad()
255
- def reverse_diffusion(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
256
- h = 1.0 / n_timesteps
257
- xt = z * mask
258
- for i in range(n_timesteps):
259
- t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
260
- device=z.device)
261
- time = t.unsqueeze(-1).unsqueeze(-1)
262
- noise_t = get_noise(time, self.beta_min, self.beta_max,
263
- cumulative=False)
264
- if stoc: # adds stochastic term
265
- dxt_det = 0.5 * (mu - xt) - self.estimator(xt, mask, mu, t, spk)
266
- dxt_det = dxt_det * noise_t * h
267
- dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
268
- requires_grad=False)
269
- dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
270
- dxt = dxt_det + dxt_stoc
271
- else:
272
- dxt = 0.5 * (mu - xt - self.estimator(xt, mask, mu, t, spk))
273
- dxt = dxt * noise_t * h
274
- xt = (xt - dxt) * mask
275
- return xt
276
-
277
- @torch.no_grad()
278
- def forward(self, z, mask, mu, n_timesteps, stoc=False, spk=None):
279
- return self.reverse_diffusion(z, mask, mu, n_timesteps, stoc, spk)
280
-
281
- def loss_t(self, x0, mask, mu, t, spk=None):
282
- xt, z = self.forward_diffusion(x0, mask, mu, t)
283
- time = t.unsqueeze(-1).unsqueeze(-1)
284
- cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
285
- noise_estimation = self.estimator(xt, mask, mu, t, spk)
286
- noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
287
- loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feats)
288
- return loss, xt
289
-
290
- def compute_loss(self, x0, mask, mu, spk=None, offset=1e-5):
291
- t = torch.rand(x0.shape[0], dtype=x0.dtype, device=x0.device,
292
- requires_grad=False)
293
- t = torch.clamp(t, offset, 1.0 - offset)
294
- return self.loss_t(x0, mask, mu, t, spk)
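The schedule in get_noise is linear in t, so its integral (the cumulative=True branch) is beta_min*t + 0.5*(beta_max - beta_min)*t**2, and forward_diffusion draws x_t with mean x0*exp(-0.5*cum) + mu*(1 - exp(-0.5*cum)) and variance 1 - exp(-cum). A tiny numeric check of those coefficients; beta_min and beta_max are the defaults above and t is just an example time.

    import torch

    beta_min, beta_max = 0.05, 20.0
    t = torch.tensor([0.5])                               # example time in [0, 1]
    cum_noise = beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2

    mean_coeff = torch.exp(-0.5 * cum_noise)              # weight on x0 (~0.28 here)
    variance = 1.0 - torch.exp(-cum_noise)                # noise variance (~0.92 here)
    print(mean_coeff.item(), variance.item())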
code/model/latentnet.py DELETED
@@ -1,193 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- from enum import Enum
4
- from typing import NamedTuple, Tuple
5
-
6
- import torch
7
- from choices import *
8
- from config_base import BaseConfig
9
- from torch import nn
10
- from torch.nn import init
11
-
12
- from .blocks import *
13
- from .nn import timestep_embedding
14
- from .unet import *
15
-
16
-
17
- class LatentNetType(Enum):
18
- none = 'none'
19
- # injecting inputs into the hidden layers
20
- skip = 'skip'
21
-
22
-
23
- class LatentNetReturn(NamedTuple):
24
- pred: torch.Tensor = None
25
-
26
-
27
- @dataclass
28
- class MLPSkipNetConfig(BaseConfig):
29
- """
30
- default MLP for the latent DPM in the paper!
31
- """
32
- num_channels: int
33
- skip_layers: Tuple[int]
34
- num_hid_channels: int
35
- num_layers: int
36
- num_time_emb_channels: int = 64
37
- activation: Activation = Activation.silu
38
- use_norm: bool = True
39
- condition_bias: float = 1
40
- dropout: float = 0
41
- last_act: Activation = Activation.none
42
- num_time_layers: int = 2
43
- time_last_act: bool = False
44
-
45
- def make_model(self):
46
- return MLPSkipNet(self)
47
-
48
-
49
- class MLPSkipNet(nn.Module):
50
- """
51
- concat x to hidden layers
52
-
53
- default MLP for the latent DPM in the paper!
54
- """
55
- def __init__(self, conf: MLPSkipNetConfig):
56
- super().__init__()
57
- self.conf = conf
58
-
59
- layers = []
60
- for i in range(conf.num_time_layers):
61
- if i == 0:
62
- a = conf.num_time_emb_channels
63
- b = conf.num_channels
64
- else:
65
- a = conf.num_channels
66
- b = conf.num_channels
67
- layers.append(nn.Linear(a, b))
68
- if i < conf.num_time_layers - 1 or conf.time_last_act:
69
- layers.append(conf.activation.get_act())
70
- self.time_embed = nn.Sequential(*layers)
71
-
72
- self.layers = nn.ModuleList([])
73
- for i in range(conf.num_layers):
74
- if i == 0:
75
- act = conf.activation
76
- norm = conf.use_norm
77
- cond = True
78
- a, b = conf.num_channels, conf.num_hid_channels
79
- dropout = conf.dropout
80
- elif i == conf.num_layers - 1:
81
- act = Activation.none
82
- norm = False
83
- cond = False
84
- a, b = conf.num_hid_channels, conf.num_channels
85
- dropout = 0
86
- else:
87
- act = conf.activation
88
- norm = conf.use_norm
89
- cond = True
90
- a, b = conf.num_hid_channels, conf.num_hid_channels
91
- dropout = conf.dropout
92
-
93
- if i in conf.skip_layers:
94
- a += conf.num_channels
95
-
96
- self.layers.append(
97
- MLPLNAct(
98
- a,
99
- b,
100
- norm=norm,
101
- activation=act,
102
- cond_channels=conf.num_channels,
103
- use_cond=cond,
104
- condition_bias=conf.condition_bias,
105
- dropout=dropout,
106
- ))
107
- self.last_act = conf.last_act.get_act()
108
-
109
- def forward(self, x, t, **kwargs):
110
- t = timestep_embedding(t, self.conf.num_time_emb_channels)
111
- cond = self.time_embed(t)
112
- h = x
113
- for i in range(len(self.layers)):
114
- if i in self.conf.skip_layers:
115
- # injecting input into the hidden layers
116
- h = torch.cat([h, x], dim=1)
117
- h = self.layers[i].forward(x=h, cond=cond)
118
- h = self.last_act(h)
119
- return LatentNetReturn(h)
120
-
121
-
122
- class MLPLNAct(nn.Module):
123
- def __init__(
124
- self,
125
- in_channels: int,
126
- out_channels: int,
127
- norm: bool,
128
- use_cond: bool,
129
- activation: Activation,
130
- cond_channels: int,
131
- condition_bias: float = 0,
132
- dropout: float = 0,
133
- ):
134
- super().__init__()
135
- self.activation = activation
136
- self.condition_bias = condition_bias
137
- self.use_cond = use_cond
138
-
139
- self.linear = nn.Linear(in_channels, out_channels)
140
- self.act = activation.get_act()
141
- if self.use_cond:
142
- self.linear_emb = nn.Linear(cond_channels, out_channels)
143
- self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144
- if norm:
145
- self.norm = nn.LayerNorm(out_channels)
146
- else:
147
- self.norm = nn.Identity()
148
-
149
- if dropout > 0:
150
- self.dropout = nn.Dropout(p=dropout)
151
- else:
152
- self.dropout = nn.Identity()
153
-
154
- self.init_weights()
155
-
156
- def init_weights(self):
157
- for module in self.modules():
158
- if isinstance(module, nn.Linear):
159
- if self.activation == Activation.relu:
160
- init.kaiming_normal_(module.weight,
161
- a=0,
162
- nonlinearity='relu')
163
- elif self.activation == Activation.lrelu:
164
- init.kaiming_normal_(module.weight,
165
- a=0.2,
166
- nonlinearity='leaky_relu')
167
- elif self.activation == Activation.silu:
168
- init.kaiming_normal_(module.weight,
169
- a=0,
170
- nonlinearity='relu')
171
- else:
172
- # leave it as default
173
- pass
174
-
175
- def forward(self, x, cond=None):
176
- x = self.linear(x)
177
- if self.use_cond:
178
- # (n, c) or (n, c * 2)
179
- cond = self.cond_layers(cond)
180
- cond = (cond, None)
181
-
182
- # scale shift first
183
- x = x * (self.condition_bias + cond[0])
184
- if cond[1] is not None:
185
- x = x + cond[1]
186
- # then norm
187
- x = self.norm(x)
188
- else:
189
- # no condition
190
- x = self.norm(x)
191
- x = self.act(x)
192
- x = self.dropout(x)
193
- return x
code/model/nn.py DELETED
@@ -1,137 +0,0 @@
1
- """
2
- Various utilities for neural networks.
3
- """
4
-
5
- from enum import Enum
6
- import math
7
- from typing import Optional
8
-
9
- import torch as th
10
- import torch.nn as nn
11
- import torch.utils.checkpoint
12
-
13
- import torch.nn.functional as F
14
-
15
-
16
- # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17
- class SiLU(nn.Module):
18
- # @th.jit.script
19
- def forward(self, x):
20
- return x * th.sigmoid(x)
21
-
22
-
23
- class GroupNorm32(nn.GroupNorm):
24
- def forward(self, x):
25
- return super().forward(x.float()).type(x.dtype)
26
-
27
-
28
- def conv_nd(dims, *args, **kwargs):
29
- """
30
- Create a 1D, 2D, or 3D convolution module.
31
- """
32
- if dims == 1:
33
- return nn.Conv1d(*args, **kwargs)
34
- elif dims == 2:
35
- return nn.Conv2d(*args, **kwargs)
36
- elif dims == 3:
37
- return nn.Conv3d(*args, **kwargs)
38
- raise ValueError(f"unsupported dimensions: {dims}")
39
-
40
-
41
- def linear(*args, **kwargs):
42
- """
43
- Create a linear module.
44
- """
45
- return nn.Linear(*args, **kwargs)
46
-
47
-
48
- def avg_pool_nd(dims, *args, **kwargs):
49
- """
50
- Create a 1D, 2D, or 3D average pooling module.
51
- """
52
- if dims == 1:
53
- return nn.AvgPool1d(*args, **kwargs)
54
- elif dims == 2:
55
- return nn.AvgPool2d(*args, **kwargs)
56
- elif dims == 3:
57
- return nn.AvgPool3d(*args, **kwargs)
58
- raise ValueError(f"unsupported dimensions: {dims}")
59
-
60
-
61
- def update_ema(target_params, source_params, rate=0.99):
62
- """
63
- Update target parameters to be closer to those of source parameters using
64
- an exponential moving average.
65
-
66
- :param target_params: the target parameter sequence.
67
- :param source_params: the source parameter sequence.
68
- :param rate: the EMA rate (closer to 1 means slower).
69
- """
70
- for targ, src in zip(target_params, source_params):
71
- targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72
-
73
-
74
- def zero_module(module):
75
- """
76
- Zero out the parameters of a module and return it.
77
- """
78
- for p in module.parameters():
79
- p.detach().zero_()
80
- return module
81
-
82
-
83
- def scale_module(module, scale):
84
- """
85
- Scale the parameters of a module and return it.
86
- """
87
- for p in module.parameters():
88
- p.detach().mul_(scale)
89
- return module
90
-
91
-
92
- def mean_flat(tensor):
93
- """
94
- Take the mean over all non-batch dimensions.
95
- """
96
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
97
-
98
-
99
- def normalization(channels):
100
- """
101
- Make a standard normalization layer.
102
-
103
- :param channels: number of input channels.
104
- :return: an nn.Module for normalization.
105
- """
106
- return GroupNorm32(min(32, channels), channels)
107
-
108
-
109
- def timestep_embedding(timesteps, dim, max_period=10000):
110
- """
111
- Create sinusoidal timestep embeddings.
112
-
113
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
114
- These may be fractional.
115
- :param dim: the dimension of the output.
116
- :param max_period: controls the minimum frequency of the embeddings.
117
- :return: an [N x dim] Tensor of positional embeddings.
118
- """
119
- half = dim // 2
120
- freqs = th.exp(-math.log(max_period) *
121
- th.arange(start=0, end=half, dtype=th.float32) /
122
- half).to(device=timesteps.device)
123
- args = timesteps[:, None].float() * freqs[None]
124
- embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
125
- if dim % 2:
126
- embedding = th.cat(
127
- [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128
- return embedding
129
-
130
-
131
- def torch_checkpoint(func, args, flag, preserve_rng_state=False):
132
- # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
133
- if flag:
134
- return torch.utils.checkpoint.checkpoint(
135
- func, *args, preserve_rng_state=preserve_rng_state)
136
- else:
137
- return func(*args)
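
For orientation, a minimal usage sketch of the sinusoidal timestep_embedding helper above; the import path, batch of timesteps, and dimensions are illustrative assumptions, not part of this repository's scripts:

    import torch as th
    from model.nn import timestep_embedding  # assumed import path; adjust to where code/model/nn.py lives

    t = th.tensor([0.0, 10.0, 999.0])         # N = 3 timesteps, possibly fractional
    emb = timestep_embedding(t, dim=128)      # -> [3, 128]: 64 cosine + 64 sine features
    emb_odd = timestep_embedding(t, dim=129)  # odd dim is zero-padded to [3, 129]
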
code/model/seq2seq.py DELETED
@@ -1,141 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from model.base import BaseModule
4
- from espnet.nets.pytorch_backend.conformer.encoder import Encoder as ConformerEncoder
5
- import torch.nn.functional as F
6
-
7
- class LSTM(nn.Module):
8
- def __init__(self, motion_dim, output_dim, num_layers=2, hidden_dim=128):
9
- super().__init__()
10
- self.lstm = nn.LSTM(input_size=motion_dim, hidden_size=hidden_dim,
11
- num_layers=num_layers, batch_first=True)
12
- self.fc = nn.Linear(hidden_dim, output_dim)
13
-
14
- def forward(self, x):
15
- x, _ = self.lstm(x)
16
- return self.fc(x)
17
-
18
- class DiffusionPredictor(BaseModule):
19
- def __init__(self, conf):
20
- super(DiffusionPredictor, self).__init__()
21
-
22
- self.infer_type = conf.infer_type
23
-
24
- self.initialize_layers(conf)
25
- print(f'infer_type: {self.infer_type}')
26
-
27
- def create_conformer_encoder(self, attention_dim, num_blocks):
28
- return ConformerEncoder(
29
- idim=0, attention_dim=attention_dim, attention_heads=2, linear_units=attention_dim,
30
- num_blocks=num_blocks, input_layer=None, dropout_rate=0.2, positional_dropout_rate=0.2,
31
- attention_dropout_rate=0.2, normalize_before=False, concat_after=False,
32
- positionwise_layer_type="linear", positionwise_conv_kernel_size=3, macaron_style=True,
33
- pos_enc_layer_type="rel_pos", selfattention_layer_type="rel_selfattn", use_cnn_module=True,
34
- cnn_module_kernel=13)
35
-
36
- def initialize_layers(self, conf, mfcc_dim=39, hubert_dim=1024, speech_layers=4, speech_dim=512, decoder_dim=1024, motion_start_dim=512, HAL_layers=25):
37
-
38
- self.conf = conf
39
- # Speech downsampling
40
- if self.infer_type.startswith('mfcc'):
41
- # from 100 hz to 25 hz
42
- self.down_sample1 = nn.Conv1d(mfcc_dim, 256, kernel_size=3, stride=2, padding=1)
43
- self.down_sample2 = nn.Conv1d(256, speech_dim, kernel_size=3, stride=2, padding=1)
44
- elif self.infer_type.startswith('hubert'):
45
- # from 50 hz to 25 hz
46
- self.down_sample1 = nn.Conv1d(hubert_dim, speech_dim, kernel_size=3, stride=2, padding=1)
47
-
48
- self.weights = nn.Parameter(torch.zeros(HAL_layers))
49
- self.speech_encoder = self.create_conformer_encoder(speech_dim, speech_layers)
50
- else:
51
- print('infer_type not supported')
52
-
53
- # Encoders & Decoders
54
- self.coarse_decoder = self.create_conformer_encoder(decoder_dim, conf.decoder_layers)
55
-
56
- # LSTM predictors for Variance Adapter
57
- if self.infer_type != 'hubert_audio_only':
58
- self.pose_predictor = LSTM(speech_dim, 3)
59
- self.pose_encoder = LSTM(3, speech_dim)
60
-
61
- if 'full_control' in self.infer_type:
62
- self.location_predictor = LSTM(speech_dim, 1)
63
- self.location_encoder = LSTM(1, speech_dim)
64
- self.face_scale_predictor = LSTM(speech_dim, 1)
65
- self.face_scale_encoder = LSTM(1, speech_dim)
66
-
67
- # Linear transformations
68
- self.init_code_proj = nn.Sequential(nn.Linear(motion_start_dim, 128))
69
- self.noisy_encoder = nn.Sequential(nn.Linear(conf.motion_dim, 128))
70
- self.t_encoder = nn.Sequential(nn.Linear(1, 128))
71
- self.encoder_direction_code = nn.Linear(conf.motion_dim, 128)
72
-
73
- self.out_proj = nn.Linear(decoder_dim, conf.motion_dim)
74
-
75
-
76
- def forward(self, initial_code, direction_code, seq_input_vector, face_location, face_scale, yaw_pitch_roll, noisy_x, t_emb, control_flag=False):
77
-
78
- if self.infer_type.startswith('mfcc'):
79
- x = self.mfcc_speech_downsample(seq_input_vector)
80
- elif self.infer_type.startswith('hubert'):
81
- norm_weights = F.softmax(self.weights, dim=-1)
82
- weighted_feature = (norm_weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) * seq_input_vector).sum(dim=1)
83
- x = self.down_sample1(weighted_feature.transpose(1,2)).transpose(1,2)
84
- x, _ = self.speech_encoder(x, masks=None)
85
- predicted_location, predicted_scale, predicted_pose = face_location, face_scale, yaw_pitch_roll
86
- if self.infer_type != 'hubert_audio_only':
87
- print(f'pose controllable. control_flag: {control_flag}')
88
- x, predicted_location, predicted_scale, predicted_pose = self.adjust_features(x, face_location, face_scale, yaw_pitch_roll, control_flag)
89
- concatenated_features = self.combine_features(x, initial_code, direction_code, noisy_x, t_emb) # initial_code and direction_code serve as a motion guide extracted from the reference image. This aims to tell the model what the starting motion should be.
90
- outputs = self.decode_features(concatenated_features)
91
- return outputs, predicted_location, predicted_scale, predicted_pose
92
-
93
- def mfcc_speech_downsample(self, seq_input_vector):
94
- x = self.down_sample1(seq_input_vector.transpose(1,2))
95
- return self.down_sample2(x).transpose(1,2)
96
-
97
- def adjust_features(self, x, face_location, face_scale, yaw_pitch_roll, control_flag):
98
- predicted_location, predicted_scale = 0, 0
99
- if 'full_control' in self.infer_type:
100
- print(f'full controllable. control_flag: {control_flag}')
101
- x_residual, predicted_location = self.adjust_location(x, face_location, control_flag)
102
- x = x + x_residual
103
-
104
- x_residual, predicted_scale = self.adjust_scale(x, face_scale, control_flag)
105
- x = x + x_residual
106
-
107
- x_residual, predicted_pose= self.adjust_pose(x, yaw_pitch_roll, control_flag)
108
- x = x + x_residual
109
- return x, predicted_location, predicted_scale, predicted_pose
110
-
111
- def adjust_location(self, x, face_location, control_flag):
112
- if control_flag:
113
- predicted_location = face_location
114
- else:
115
- predicted_location = self.location_predictor(x)
116
- return self.location_encoder(predicted_location), predicted_location
117
-
118
- def adjust_scale(self, x, face_scale, control_flag):
119
- if control_flag:
120
- predicted_face_scale = face_scale
121
- else:
122
- predicted_face_scale = self.face_scale_predictor(x)
123
- return self.face_scale_encoder(predicted_face_scale), predicted_face_scale
124
-
125
- def adjust_pose(self, x, yaw_pitch_roll, control_flag):
126
- if control_flag:
127
- predicted_pose = yaw_pitch_roll
128
- else:
129
- predicted_pose = self.pose_predictor(x)
130
- return self.pose_encoder(predicted_pose), predicted_pose
131
-
132
- def combine_features(self, x, initial_code, direction_code, noisy_x, t_emb):
133
- init_code_proj = self.init_code_proj(initial_code).unsqueeze(1).repeat(1, x.size(1), 1)
134
- noisy_feature = self.noisy_encoder(noisy_x)
135
- t_emb_feature = self.t_encoder(t_emb.unsqueeze(1).float()).unsqueeze(1).repeat(1, x.size(1), 1)
136
- direction_code_feature = self.encoder_direction_code(direction_code).unsqueeze(1).repeat(1, x.size(1), 1)
137
- return torch.cat((x, direction_code_feature, init_code_proj, noisy_feature, t_emb_feature), dim=-1)
138
-
139
- def decode_features(self, concatenated_features):
140
- outputs, _ = self.coarse_decoder(concatenated_features, masks=None)
141
- return self.out_proj(outputs)
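
As a rough illustration of the LSTM heads used by the variance adapter above, here is a minimal sketch assuming code/model/seq2seq.py is importable as model.seq2seq; the batch size, frame count, and dimensions are placeholders:

    import torch
    from model.seq2seq import LSTM  # assumed import path

    pose_predictor = LSTM(motion_dim=512, output_dim=3)  # e.g. yaw/pitch/roll per 25 Hz frame
    speech_feats = torch.randn(2, 100, 512)              # (batch, frames, speech_dim)
    pose = pose_predictor(speech_feats)                  # -> (2, 100, 3)
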
code/model/unet.py DELETED
@@ -1,552 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- from numbers import Number
4
- from typing import NamedTuple, Tuple, Union
5
-
6
- import numpy as np
7
- import torch as th
8
- from torch import nn
9
- import torch.nn.functional as F
10
- from choices import *
11
- from config_base import BaseConfig
12
- from .blocks import *
13
-
14
- from .nn import (conv_nd, linear, normalization, timestep_embedding,
15
- torch_checkpoint, zero_module)
16
-
17
-
18
- @dataclass
19
- class BeatGANsUNetConfig(BaseConfig):
20
- image_size: int = 64
21
- in_channels: int = 3
22
- # base channels, will be multiplied
23
- model_channels: int = 64
24
- # output of the unet
25
- # suggest: 3
26
- # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27
- out_channels: int = 3
28
- # how many repeating resblocks per resolution
29
- # the decoding side would have "one more" resblock
30
- # default: 2
31
- num_res_blocks: int = 2
32
- # you can also set the number of resblocks specifically for the input blocks
33
- # default: None = above
34
- num_input_res_blocks: int = None
35
- # number of time embed channels and style channels
36
- embed_channels: int = 512
37
- # at what resolutions you want to do self-attention of the feature maps
38
- # attentions generally improve performance
39
- # default: [16]
40
- # beatgans: [32, 16, 8]
41
- attention_resolutions: Tuple[int] = (16, )
42
- # number of time embed channels
43
- time_embed_channels: int = None
44
- # dropout applies to the resblocks (on feature maps)
45
- dropout: float = 0.1
46
- channel_mult: Tuple[int] = (1, 2, 4, 8)
47
- input_channel_mult: Tuple[int] = None
48
- conv_resample: bool = True
49
- # always 2 = 2d conv
50
- dims: int = 2
51
- # don't use this, legacy from BeatGANs
52
- num_classes: int = None
53
- use_checkpoint: bool = False
54
- # number of attention heads
55
- num_heads: int = 1
56
- # or specify the number of channels per attention head
57
- num_head_channels: int = -1
58
- # what's this?
59
- num_heads_upsample: int = -1
60
- # use resblock for upscale/downscale blocks (expensive)
61
- # default: True (BeatGANs)
62
- resblock_updown: bool = True
63
- # never tried
64
- use_new_attention_order: bool = False
65
- resnet_two_cond: bool = False
66
- resnet_cond_channels: int = None
67
- # init the decoding conv layers with zero weights, this speeds up training
68
- # default: True (BeatGANs)
69
- resnet_use_zero_module: bool = True
70
- # gradient checkpoint the attention operation
71
- attn_checkpoint: bool = False
72
-
73
- def make_model(self):
74
- return BeatGANsUNetModel(self)
75
-
76
-
77
- class BeatGANsUNetModel(nn.Module):
78
- def __init__(self, conf: BeatGANsUNetConfig):
79
- super().__init__()
80
- self.conf = conf
81
-
82
- if conf.num_heads_upsample == -1:
83
- self.num_heads_upsample = conf.num_heads
84
-
85
- self.dtype = th.float32
86
-
87
- self.time_emb_channels = conf.time_embed_channels or conf.model_channels
88
- self.time_embed = nn.Sequential(
89
- linear(self.time_emb_channels, conf.embed_channels),
90
- nn.SiLU(),
91
- linear(conf.embed_channels, conf.embed_channels),
92
- )
93
-
94
- if conf.num_classes is not None:
95
- self.label_emb = nn.Embedding(conf.num_classes,
96
- conf.embed_channels)
97
-
98
- ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
99
- self.input_blocks = nn.ModuleList([
100
- TimestepEmbedSequential(
101
- conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
102
- ])
103
-
104
- kwargs = dict(
105
- use_condition=True,
106
- two_cond=conf.resnet_two_cond,
107
- use_zero_module=conf.resnet_use_zero_module,
108
- # style channels for the resnet block
109
- cond_emb_channels=conf.resnet_cond_channels,
110
- )
111
-
112
- self._feature_size = ch
113
-
114
- # input_block_chans = [ch]
115
- input_block_chans = [[] for _ in range(len(conf.channel_mult))]
116
- input_block_chans[0].append(ch)
117
-
118
- # number of blocks at each resolution
119
- self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
120
- self.input_num_blocks[0] = 1
121
- self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
122
-
123
- ds = 1
124
- resolution = conf.image_size
125
- for level, mult in enumerate(conf.input_channel_mult
126
- or conf.channel_mult):
127
- for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
128
- layers = [
129
- ResBlockConfig(
130
- ch,
131
- conf.embed_channels,
132
- conf.dropout,
133
- out_channels=int(mult * conf.model_channels),
134
- dims=conf.dims,
135
- use_checkpoint=conf.use_checkpoint,
136
- **kwargs,
137
- ).make_model()
138
- ]
139
- ch = int(mult * conf.model_channels)
140
- if resolution in conf.attention_resolutions:
141
- layers.append(
142
- AttentionBlock(
143
- ch,
144
- use_checkpoint=conf.use_checkpoint
145
- or conf.attn_checkpoint,
146
- num_heads=conf.num_heads,
147
- num_head_channels=conf.num_head_channels,
148
- use_new_attention_order=conf.
149
- use_new_attention_order,
150
- ))
151
- self.input_blocks.append(TimestepEmbedSequential(*layers))
152
- self._feature_size += ch
153
- # input_block_chans.append(ch)
154
- input_block_chans[level].append(ch)
155
- self.input_num_blocks[level] += 1
156
- # print(input_block_chans)
157
- if level != len(conf.channel_mult) - 1:
158
- resolution //= 2
159
- out_ch = ch
160
- self.input_blocks.append(
161
- TimestepEmbedSequential(
162
- ResBlockConfig(
163
- ch,
164
- conf.embed_channels,
165
- conf.dropout,
166
- out_channels=out_ch,
167
- dims=conf.dims,
168
- use_checkpoint=conf.use_checkpoint,
169
- down=True,
170
- **kwargs,
171
- ).make_model() if conf.
172
- resblock_updown else Downsample(ch,
173
- conf.conv_resample,
174
- dims=conf.dims,
175
- out_channels=out_ch)))
176
- ch = out_ch
177
- # input_block_chans.append(ch)
178
- input_block_chans[level + 1].append(ch)
179
- self.input_num_blocks[level + 1] += 1
180
- ds *= 2
181
- self._feature_size += ch
182
-
183
- self.middle_block = TimestepEmbedSequential(
184
- ResBlockConfig(
185
- ch,
186
- conf.embed_channels,
187
- conf.dropout,
188
- dims=conf.dims,
189
- use_checkpoint=conf.use_checkpoint,
190
- **kwargs,
191
- ).make_model(),
192
- AttentionBlock(
193
- ch,
194
- use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
195
- num_heads=conf.num_heads,
196
- num_head_channels=conf.num_head_channels,
197
- use_new_attention_order=conf.use_new_attention_order,
198
- ),
199
- ResBlockConfig(
200
- ch,
201
- conf.embed_channels,
202
- conf.dropout,
203
- dims=conf.dims,
204
- use_checkpoint=conf.use_checkpoint,
205
- **kwargs,
206
- ).make_model(),
207
- )
208
- self._feature_size += ch
209
-
210
- self.output_blocks = nn.ModuleList([])
211
- for level, mult in list(enumerate(conf.channel_mult))[::-1]:
212
- for i in range(conf.num_res_blocks + 1):
213
- # print(input_block_chans)
214
- # ich = input_block_chans.pop()
215
- try:
216
- ich = input_block_chans[level].pop()
217
- except IndexError:
218
- # this happens only when num_res_block > num_enc_res_block
219
- # we will not have enough lateral (skip) connections for all decoder blocks
220
- ich = 0
221
- # print('pop:', ich)
222
- layers = [
223
- ResBlockConfig(
224
- # only direct channels when gated
225
- channels=ch + ich,
226
- emb_channels=conf.embed_channels,
227
- dropout=conf.dropout,
228
- out_channels=int(conf.model_channels * mult),
229
- dims=conf.dims,
230
- use_checkpoint=conf.use_checkpoint,
231
- # lateral channels are described here when gated
232
- has_lateral=True if ich > 0 else False,
233
- lateral_channels=None,
234
- **kwargs,
235
- ).make_model()
236
- ]
237
- ch = int(conf.model_channels * mult)
238
- if resolution in conf.attention_resolutions:
239
- layers.append(
240
- AttentionBlock(
241
- ch,
242
- use_checkpoint=conf.use_checkpoint
243
- or conf.attn_checkpoint,
244
- num_heads=self.num_heads_upsample,
245
- num_head_channels=conf.num_head_channels,
246
- use_new_attention_order=conf.
247
- use_new_attention_order,
248
- ))
249
- if level and i == conf.num_res_blocks:
250
- resolution *= 2
251
- out_ch = ch
252
- layers.append(
253
- ResBlockConfig(
254
- ch,
255
- conf.embed_channels,
256
- conf.dropout,
257
- out_channels=out_ch,
258
- dims=conf.dims,
259
- use_checkpoint=conf.use_checkpoint,
260
- up=True,
261
- **kwargs,
262
- ).make_model() if (
263
- conf.resblock_updown
264
- ) else Upsample(ch,
265
- conf.conv_resample,
266
- dims=conf.dims,
267
- out_channels=out_ch))
268
- ds //= 2
269
- self.output_blocks.append(TimestepEmbedSequential(*layers))
270
- self.output_num_blocks[level] += 1
271
- self._feature_size += ch
272
-
273
- # print(input_block_chans)
274
- # print('inputs:', self.input_num_blocks)
275
- # print('outputs:', self.output_num_blocks)
276
-
277
- if conf.resnet_use_zero_module:
278
- self.out = nn.Sequential(
279
- normalization(ch),
280
- nn.SiLU(),
281
- zero_module(
282
- conv_nd(conf.dims,
283
- input_ch,
284
- conf.out_channels,
285
- 3,
286
- padding=1)),
287
- )
288
- else:
289
- self.out = nn.Sequential(
290
- normalization(ch),
291
- nn.SiLU(),
292
- conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
293
- )
294
-
295
- def forward(self, x, t, y=None, **kwargs):
296
- """
297
- Apply the model to an input batch.
298
-
299
- :param x: an [N x C x ...] Tensor of inputs.
300
- :param t: a 1-D batch of timesteps.
301
- :param y: an [N] Tensor of labels, if class-conditional.
302
- :return: an [N x C x ...] Tensor of outputs.
303
- """
304
- assert (y is not None) == (
305
- self.conf.num_classes is not None
306
- ), "must specify y if and only if the model is class-conditional"
307
-
308
- # hs = []
309
- hs = [[] for _ in range(len(self.conf.channel_mult))]
310
- emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
311
-
312
- if self.conf.num_classes is not None:
313
- raise NotImplementedError()
314
- # assert y.shape == (x.shape[0], )
315
- # emb = emb + self.label_emb(y)
316
-
317
- # new code supports input_num_blocks != output_num_blocks
318
- h = x.type(self.dtype)
319
- k = 0
320
- for i in range(len(self.input_num_blocks)):
321
- for j in range(self.input_num_blocks[i]):
322
- h = self.input_blocks[k](h, emb=emb)
323
- # print(i, j, h.shape)
324
- hs[i].append(h)
325
- k += 1
326
- assert k == len(self.input_blocks)
327
-
328
- h = self.middle_block(h, emb=emb)
329
- k = 0
330
- for i in range(len(self.output_num_blocks)):
331
- for j in range(self.output_num_blocks[i]):
332
- # take the lateral connection from the same layer (in reserve)
333
- # until there is no more, use None
334
- try:
335
- lateral = hs[-i - 1].pop()
336
- # print(i, j, lateral.shape)
337
- except IndexError:
338
- lateral = None
339
- # print(i, j, lateral)
340
- h = self.output_blocks[k](h, emb=emb, lateral=lateral)
341
- k += 1
342
-
343
- h = h.type(x.dtype)
344
- pred = self.out(h)
345
- return Return(pred=pred)
346
-
347
-
348
- class Return(NamedTuple):
349
- pred: th.Tensor
350
-
351
-
352
- @dataclass
353
- class BeatGANsEncoderConfig(BaseConfig):
354
- image_size: int
355
- in_channels: int
356
- model_channels: int
357
- out_hid_channels: int
358
- out_channels: int
359
- num_res_blocks: int
360
- attention_resolutions: Tuple[int]
361
- dropout: float = 0
362
- channel_mult: Tuple[int] = (1, 2, 4, 8)
363
- use_time_condition: bool = True
364
- conv_resample: bool = True
365
- dims: int = 2
366
- use_checkpoint: bool = False
367
- num_heads: int = 1
368
- num_head_channels: int = -1
369
- resblock_updown: bool = False
370
- use_new_attention_order: bool = False
371
- pool: str = 'adaptivenonzero'
372
-
373
- def make_model(self):
374
- return BeatGANsEncoderModel(self)
375
-
376
-
377
- class BeatGANsEncoderModel(nn.Module):
378
- """
379
- The half UNet model with attention and timestep embedding.
380
-
381
- For usage, see UNet.
382
- """
383
- def __init__(self, conf: BeatGANsEncoderConfig):
384
- super().__init__()
385
- self.conf = conf
386
- self.dtype = th.float32
387
-
388
- if conf.use_time_condition:
389
- time_embed_dim = conf.model_channels * 4
390
- self.time_embed = nn.Sequential(
391
- linear(conf.model_channels, time_embed_dim),
392
- nn.SiLU(),
393
- linear(time_embed_dim, time_embed_dim),
394
- )
395
- else:
396
- time_embed_dim = None
397
-
398
- ch = int(conf.channel_mult[0] * conf.model_channels)
399
- self.input_blocks = nn.ModuleList([
400
- TimestepEmbedSequential(
401
- conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
402
- ])
403
- self._feature_size = ch
404
- input_block_chans = [ch]
405
- ds = 1
406
- resolution = conf.image_size
407
- for level, mult in enumerate(conf.channel_mult):
408
- for _ in range(conf.num_res_blocks):
409
- layers = [
410
- ResBlockConfig(
411
- ch,
412
- time_embed_dim,
413
- conf.dropout,
414
- out_channels=int(mult * conf.model_channels),
415
- dims=conf.dims,
416
- use_condition=conf.use_time_condition,
417
- use_checkpoint=conf.use_checkpoint,
418
- ).make_model()
419
- ]
420
- ch = int(mult * conf.model_channels)
421
- if resolution in conf.attention_resolutions:
422
- layers.append(
423
- AttentionBlock(
424
- ch,
425
- use_checkpoint=conf.use_checkpoint,
426
- num_heads=conf.num_heads,
427
- num_head_channels=conf.num_head_channels,
428
- use_new_attention_order=conf.
429
- use_new_attention_order,
430
- ))
431
- self.input_blocks.append(TimestepEmbedSequential(*layers))
432
- self._feature_size += ch
433
- input_block_chans.append(ch)
434
- if level != len(conf.channel_mult) - 1:
435
- resolution //= 2
436
- out_ch = ch
437
- self.input_blocks.append(
438
- TimestepEmbedSequential(
439
- ResBlockConfig(
440
- ch,
441
- time_embed_dim,
442
- conf.dropout,
443
- out_channels=out_ch,
444
- dims=conf.dims,
445
- use_condition=conf.use_time_condition,
446
- use_checkpoint=conf.use_checkpoint,
447
- down=True,
448
- ).make_model() if (
449
- conf.resblock_updown
450
- ) else Downsample(ch,
451
- conf.conv_resample,
452
- dims=conf.dims,
453
- out_channels=out_ch)))
454
- ch = out_ch
455
- input_block_chans.append(ch)
456
- ds *= 2
457
- self._feature_size += ch
458
-
459
- self.middle_block = TimestepEmbedSequential(
460
- ResBlockConfig(
461
- ch,
462
- time_embed_dim,
463
- conf.dropout,
464
- dims=conf.dims,
465
- use_condition=conf.use_time_condition,
466
- use_checkpoint=conf.use_checkpoint,
467
- ).make_model(),
468
- AttentionBlock(
469
- ch,
470
- use_checkpoint=conf.use_checkpoint,
471
- num_heads=conf.num_heads,
472
- num_head_channels=conf.num_head_channels,
473
- use_new_attention_order=conf.use_new_attention_order,
474
- ),
475
- ResBlockConfig(
476
- ch,
477
- time_embed_dim,
478
- conf.dropout,
479
- dims=conf.dims,
480
- use_condition=conf.use_time_condition,
481
- use_checkpoint=conf.use_checkpoint,
482
- ).make_model(),
483
- )
484
- self._feature_size += ch
485
- if conf.pool == "adaptivenonzero":
486
- self.out = nn.Sequential(
487
- normalization(ch),
488
- nn.SiLU(),
489
- nn.AdaptiveAvgPool2d((1, 1)),
490
- conv_nd(conf.dims, ch, conf.out_channels, 1),
491
- nn.Flatten(),
492
- )
493
- else:
494
- raise NotImplementedError(f"Unexpected {conf.pool} pooling")
495
-
496
- def forward(self, x, t=None, return_2d_feature=False):
497
- """
498
- Apply the model to an input batch.
499
-
500
- :param x: an [N x C x ...] Tensor of inputs.
501
- :param t: a 1-D batch of timesteps.
502
- :return: an [N x K] Tensor of outputs.
503
- """
504
- if self.conf.use_time_condition:
505
- emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
506
- else:
507
- emb = None
508
-
509
- results = []
510
- h = x.type(self.dtype)
511
- for module in self.input_blocks:
512
- h = module(h, emb=emb)
513
- if self.conf.pool.startswith("spatial"):
514
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
515
- h = self.middle_block(h, emb=emb)
516
- if self.conf.pool.startswith("spatial"):
517
- results.append(h.type(x.dtype).mean(dim=(2, 3)))
518
- h = th.cat(results, axis=-1)
519
- else:
520
- h = h.type(x.dtype)
521
-
522
- h_2d = h
523
- h = self.out(h)
524
-
525
- if return_2d_feature:
526
- return h, h_2d
527
- else:
528
- return h
529
-
530
- def forward_flatten(self, x):
531
- """
532
- transform the last 2d feature into a flatten vector
533
- """
534
- h = self.out(x)
535
- return h
536
-
537
-
538
- class SuperResModel(BeatGANsUNetModel):
539
- """
540
- A UNetModel that performs super-resolution.
541
-
542
- Expects an extra kwarg `low_res` to condition on a low-resolution image.
543
- """
544
- def __init__(self, image_size, in_channels, *args, **kwargs):
545
- super().__init__(image_size, in_channels * 2, *args, **kwargs)
546
-
547
- def forward(self, x, timesteps, low_res=None, **kwargs):
548
- _, _, new_height, new_width = x.shape
549
- upsampled = F.interpolate(low_res, (new_height, new_width),
550
- mode="bilinear")
551
- x = th.cat([x, upsampled], dim=1)
552
- return super().forward(x, timesteps, **kwargs)
code/model/unet_autoenc.py DELETED
@@ -1,283 +0,0 @@
1
- from enum import Enum
2
-
3
- import torch
4
- from torch import Tensor
5
- from torch.nn.functional import silu
6
-
7
- from .latentnet import *
8
- from .unet import *
9
- from choices import *
10
-
11
-
12
- @dataclass
13
- class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14
- # number of style channels
15
- enc_out_channels: int = 512
16
- enc_attn_resolutions: Tuple[int] = None
17
- enc_pool: str = 'depthconv'
18
- enc_num_res_block: int = 2
19
- enc_channel_mult: Tuple[int] = None
20
- enc_grad_checkpoint: bool = False
21
- latent_net_conf: MLPSkipNetConfig = None
22
-
23
- def make_model(self):
24
- return BeatGANsAutoencModel(self)
25
-
26
-
27
- class BeatGANsAutoencModel(BeatGANsUNetModel):
28
- def __init__(self, conf: BeatGANsAutoencConfig):
29
- super().__init__(conf)
30
- self.conf = conf
31
-
32
- # having only time, cond
33
- self.time_embed = TimeStyleSeperateEmbed(
34
- time_channels=conf.model_channels,
35
- time_out_channels=conf.embed_channels,
36
- )
37
-
38
- self.encoder = BeatGANsEncoderConfig(
39
- image_size=conf.image_size,
40
- in_channels=conf.in_channels,
41
- model_channels=conf.model_channels,
42
- out_hid_channels=conf.enc_out_channels,
43
- out_channels=conf.enc_out_channels,
44
- num_res_blocks=conf.enc_num_res_block,
45
- attention_resolutions=(conf.enc_attn_resolutions
46
- or conf.attention_resolutions),
47
- dropout=conf.dropout,
48
- channel_mult=conf.enc_channel_mult or conf.channel_mult,
49
- use_time_condition=False,
50
- conv_resample=conf.conv_resample,
51
- dims=conf.dims,
52
- use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53
- num_heads=conf.num_heads,
54
- num_head_channels=conf.num_head_channels,
55
- resblock_updown=conf.resblock_updown,
56
- use_new_attention_order=conf.use_new_attention_order,
57
- pool=conf.enc_pool,
58
- ).make_model()
59
-
60
- if conf.latent_net_conf is not None:
61
- self.latent_net = conf.latent_net_conf.make_model()
62
-
63
- def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64
- """
65
- Reparameterization trick to sample from N(mu, var) from
66
- N(0,1).
67
- :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68
- :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
69
- :return: (Tensor) [B x D]
70
- """
71
- assert self.conf.is_stochastic
72
- std = torch.exp(0.5 * logvar)
73
- eps = torch.randn_like(std)
74
- return eps * std + mu
75
-
76
- def sample_z(self, n: int, device):
77
- assert self.conf.is_stochastic
78
- return torch.randn(n, self.conf.enc_out_channels, device=device)
79
-
80
- def noise_to_cond(self, noise: Tensor):
81
- raise NotImplementedError()
82
- assert self.conf.noise_net_conf is not None
83
- return self.noise_net.forward(noise)
84
-
85
- def encode(self, x):
86
- cond = self.encoder.forward(x)
87
- return {'cond': cond}
88
-
89
- @property
90
- def stylespace_sizes(self):
91
- modules = list(self.input_blocks.modules()) + list(
92
- self.middle_block.modules()) + list(self.output_blocks.modules())
93
- sizes = []
94
- for module in modules:
95
- if isinstance(module, ResBlock):
96
- linear = module.cond_emb_layers[-1]
97
- sizes.append(linear.weight.shape[0])
98
- return sizes
99
-
100
- def encode_stylespace(self, x, return_vector: bool = True):
101
- """
102
- encode to style space
103
- """
104
- modules = list(self.input_blocks.modules()) + list(
105
- self.middle_block.modules()) + list(self.output_blocks.modules())
106
- # (n, c)
107
- cond = self.encoder.forward(x)
108
- S = []
109
- for module in modules:
110
- if isinstance(module, ResBlock):
111
- # (n, c')
112
- s = module.cond_emb_layers.forward(cond)
113
- S.append(s)
114
-
115
- if return_vector:
116
- # (n, sum_c)
117
- return torch.cat(S, dim=1)
118
- else:
119
- return S
120
-
121
- def forward(self,
122
- x,
123
- t,
124
- y=None,
125
- x_start=None,
126
- cond=None,
127
- style=None,
128
- noise=None,
129
- t_cond=None,
130
- **kwargs):
131
- """
132
- Apply the model to an input batch.
133
-
134
- Args:
135
- x_start: the original image to encode
136
- cond: output of the encoder
137
- noise: random noise (to predict the cond)
138
- """
139
-
140
- if t_cond is None:
141
- t_cond = t
142
-
143
- if noise is not None:
144
- # if the noise is given, we predict the cond from noise
145
- cond = self.noise_to_cond(noise)
146
-
147
- if cond is None:
148
- if x is not None:
149
- assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
150
-
151
- tmp = self.encode(x_start)
152
- cond = tmp['cond']
153
-
154
- if t is not None:
155
- _t_emb = timestep_embedding(t, self.conf.model_channels)
156
- _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
157
- else:
158
- # this happens when training only autoenc
159
- _t_emb = None
160
- _t_cond_emb = None
161
-
162
- if self.conf.resnet_two_cond:
163
- res = self.time_embed.forward(
164
- time_emb=_t_emb,
165
- cond=cond,
166
- time_cond_emb=_t_cond_emb,
167
- )
168
- else:
169
- raise NotImplementedError()
170
-
171
- if self.conf.resnet_two_cond:
172
- # two cond: first = time emb, second = cond_emb
173
- emb = res.time_emb
174
- cond_emb = res.emb
175
- else:
176
- # one cond = combined of both time and cond
177
- emb = res.emb
178
- cond_emb = None
179
-
180
- # override the style if given
181
- style = style or res.style
182
-
183
- assert (y is not None) == (
184
- self.conf.num_classes is not None
185
- ), "must specify y if and only if the model is class-conditional"
186
-
187
- if self.conf.num_classes is not None:
188
- raise NotImplementedError()
189
- # assert y.shape == (x.shape[0], )
190
- # emb = emb + self.label_emb(y)
191
-
192
- # where in the model to supply time conditions
193
- enc_time_emb = emb
194
- mid_time_emb = emb
195
- dec_time_emb = emb
196
- # where in the model to supply style conditions
197
- enc_cond_emb = cond_emb
198
- mid_cond_emb = cond_emb
199
- dec_cond_emb = cond_emb
200
-
201
- # hs = []
202
- hs = [[] for _ in range(len(self.conf.channel_mult))]
203
-
204
- if x is not None:
205
- h = x.type(self.dtype)
206
-
207
- # input blocks
208
- k = 0
209
- for i in range(len(self.input_num_blocks)):
210
- for j in range(self.input_num_blocks[i]):
211
- h = self.input_blocks[k](h,
212
- emb=enc_time_emb,
213
- cond=enc_cond_emb)
214
-
215
- # print(i, j, h.shape)
216
- hs[i].append(h)
217
- k += 1
218
- assert k == len(self.input_blocks)
219
-
220
- # middle blocks
221
- h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
222
- else:
223
- # no lateral connections
224
- # happens when training only the autoencoder
225
- h = None
226
- hs = [[] for _ in range(len(self.conf.channel_mult))]
227
-
228
- # output blocks
229
- k = 0
230
- for i in range(len(self.output_num_blocks)):
231
- for j in range(self.output_num_blocks[i]):
232
- # take the lateral connection from the same layer (in reserve)
233
- # until there is no more, use None
234
- try:
235
- lateral = hs[-i - 1].pop()
236
- # print(i, j, lateral.shape)
237
- except IndexError:
238
- lateral = None
239
- # print(i, j, lateral)
240
-
241
- h = self.output_blocks[k](h,
242
- emb=dec_time_emb,
243
- cond=dec_cond_emb,
244
- lateral=lateral)
245
- k += 1
246
-
247
- pred = self.out(h)
248
- return AutoencReturn(pred=pred, cond=cond)
249
-
250
-
251
- class AutoencReturn(NamedTuple):
252
- pred: Tensor
253
- cond: Tensor = None
254
-
255
-
256
- class EmbedReturn(NamedTuple):
257
- # style and time
258
- emb: Tensor = None
259
- # time only
260
- time_emb: Tensor = None
261
- # style only (but could depend on time)
262
- style: Tensor = None
263
-
264
-
265
- class TimeStyleSeperateEmbed(nn.Module):
266
- # embed only style
267
- def __init__(self, time_channels, time_out_channels):
268
- super().__init__()
269
- self.time_embed = nn.Sequential(
270
- linear(time_channels, time_out_channels),
271
- nn.SiLU(),
272
- linear(time_out_channels, time_out_channels),
273
- )
274
- self.style = nn.Identity()
275
-
276
- def forward(self, time_emb=None, cond=None, **kwargs):
277
- if time_emb is None:
278
- # happens with autoenc training mode
279
- time_emb = None
280
- else:
281
- time_emb = self.time_embed(time_emb)
282
- style = self.style(cond)
283
- return EmbedReturn(emb=style, time_emb=time_emb, style=style)
code/networks/__init__.py DELETED
File without changes
code/networks/discriminator.py DELETED
@@ -1,259 +0,0 @@
1
- import math
2
- import torch
3
- from torch.nn import functional as F
4
- from torch import nn
5
-
6
-
7
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
8
- return F.leaky_relu(input + bias, negative_slope) * scale
9
-
10
-
11
- class FusedLeakyReLU(nn.Module):
12
- def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
- super().__init__()
14
- self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
15
- self.negative_slope = negative_slope
16
- self.scale = scale
17
-
18
- def forward(self, input):
19
- # print("FusedLeakyReLU: ", input.abs().mean())
20
- out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
- # print("FusedLeakyReLU: ", out.abs().mean())
22
- return out
23
-
24
-
25
- def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
26
- _, minor, in_h, in_w = input.shape
27
- kernel_h, kernel_w = kernel.shape
28
-
29
- out = input.view(-1, minor, in_h, 1, in_w, 1)
30
- out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
31
- out = out.view(-1, minor, in_h * up_y, in_w * up_x)
32
-
33
- out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
34
- out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
35
-
36
- # out = out.permute(0, 3, 1, 2)
37
- out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
38
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
39
- out = F.conv2d(out, w)
40
- out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
41
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
42
- # out = out.permute(0, 2, 3, 1)
43
-
44
- return out[:, :, ::down_y, ::down_x]
45
-
46
-
47
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
48
- return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
49
-
50
-
51
- def make_kernel(k):
52
- k = torch.tensor(k, dtype=torch.float32)
53
-
54
- if k.ndim == 1:
55
- k = k[None, :] * k[:, None]
56
-
57
- k /= k.sum()
58
-
59
- return k
60
-
61
-
62
- class Blur(nn.Module):
63
- def __init__(self, kernel, pad, upsample_factor=1):
64
- super().__init__()
65
-
66
- kernel = make_kernel(kernel)
67
-
68
- if upsample_factor > 1:
69
- kernel = kernel * (upsample_factor ** 2)
70
-
71
- self.register_buffer('kernel', kernel)
72
-
73
- self.pad = pad
74
-
75
- def forward(self, input):
76
- return upfirdn2d(input, self.kernel, pad=self.pad)
77
-
78
-
79
- class ScaledLeakyReLU(nn.Module):
80
- def __init__(self, negative_slope=0.2):
81
- super().__init__()
82
-
83
- self.negative_slope = negative_slope
84
-
85
- def forward(self, input):
86
- return F.leaky_relu(input, negative_slope=self.negative_slope)
87
-
88
-
89
- class EqualConv2d(nn.Module):
90
- def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
91
- super().__init__()
92
-
93
- self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
94
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
95
-
96
- self.stride = stride
97
- self.padding = padding
98
-
99
- if bias:
100
- self.bias = nn.Parameter(torch.zeros(out_channel))
101
- else:
102
- self.bias = None
103
-
104
- def forward(self, input):
105
-
106
- return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride,
107
- padding=self.padding, )
108
-
109
- def __repr__(self):
110
- return (
111
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
112
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
113
- )
114
-
115
-
116
- class EqualLinear(nn.Module):
117
- def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
118
- super().__init__()
119
-
120
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
121
-
122
- if bias:
123
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
124
- else:
125
- self.bias = None
126
-
127
- self.activation = activation
128
-
129
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
130
- self.lr_mul = lr_mul
131
-
132
- def forward(self, input):
133
-
134
- if self.activation:
135
- out = F.linear(input, self.weight * self.scale)
136
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
137
- else:
138
- out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
139
-
140
- return out
141
-
142
- def __repr__(self):
143
- return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
144
-
145
-
146
- class ConvLayer(nn.Sequential):
147
- def __init__(
148
- self,
149
- in_channel,
150
- out_channel,
151
- kernel_size,
152
- downsample=False,
153
- blur_kernel=[1, 3, 3, 1],
154
- bias=True,
155
- activate=True,
156
- ):
157
- layers = []
158
-
159
- if downsample:
160
- factor = 2
161
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
162
- pad0 = (p + 1) // 2
163
- pad1 = p // 2
164
-
165
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
166
-
167
- stride = 2
168
- self.padding = 0
169
-
170
- else:
171
- stride = 1
172
- self.padding = kernel_size // 2
173
-
174
- layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
175
- bias=bias and not activate))
176
-
177
- if activate:
178
- if bias:
179
- layers.append(FusedLeakyReLU(out_channel))
180
- else:
181
- layers.append(ScaledLeakyReLU(0.2))
182
-
183
- super().__init__(*layers)
184
-
185
-
186
- class ResBlock(nn.Module):
187
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
188
- super().__init__()
189
-
190
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
191
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
192
-
193
- self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
194
-
195
- def forward(self, input):
196
- out = self.conv1(input)
197
- out = self.conv2(out)
198
-
199
- skip = self.skip(input)
200
- out = (out + skip) / math.sqrt(2)
201
-
202
- return out
203
-
204
-
205
- class Discriminator(nn.Module):
206
- def __init__(self, size, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
207
- super().__init__()
208
-
209
- self.size = size
210
-
211
- channels = {
212
- 4: 512,
213
- 8: 512,
214
- 16: 512,
215
- 32: 512,
216
- 64: 256 * channel_multiplier,
217
- 128: 128 * channel_multiplier,
218
- 256: 64 * channel_multiplier,
219
- 512: 32 * channel_multiplier,
220
- 1024: 16 * channel_multiplier,
221
- }
222
-
223
- convs = [ConvLayer(3, channels[size], 1)]
224
- log_size = int(math.log(size, 2))
225
- in_channel = channels[size]
226
-
227
- for i in range(log_size, 2, -1):
228
- out_channel = channels[2 ** (i - 1)]
229
- convs.append(ResBlock(in_channel, out_channel, blur_kernel))
230
- in_channel = out_channel
231
-
232
- self.convs = nn.Sequential(*convs)
233
-
234
- self.stddev_group = 4
235
- self.stddev_feat = 1
236
-
237
- self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
238
- self.final_linear = nn.Sequential(
239
- EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
240
- EqualLinear(channels[4], 1),
241
- )
242
-
243
- def forward(self, input):
244
- out = self.convs(input)
245
- batch, channel, height, width = out.shape
246
-
247
- group = min(batch, self.stddev_group)
248
- stddev = out.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
249
- stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
250
- stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
251
- stddev = stddev.repeat(group, 1, height, width)
252
- out = torch.cat([out, stddev], 1)
253
-
254
- out = self.final_conv(out)
255
-
256
- out = out.view(batch, -1)
257
- out = self.final_linear(out)
258
-
259
- return out
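
A minimal sketch of how the StyleGAN2-style discriminator above is typically called; the resolution and batch size are placeholders (size must be a key of the channels table, and batches larger than 4 should be divisible by stddev_group):

    import torch
    from networks.discriminator import Discriminator  # assumed import path

    disc = Discriminator(size=256, channel_multiplier=1)
    imgs = torch.randn(4, 3, 256, 256)   # (batch, RGB, H, W) at the configured resolution
    logits = disc(imgs)                  # -> (4, 1) realness scores
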
code/networks/encoder.py DELETED
@@ -1,374 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
-
6
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
7
- return F.leaky_relu(input + bias, negative_slope) * scale
8
-
9
- class FusedLeakyReLU(nn.Module):
10
- def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
11
- super().__init__()
12
- self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
13
- self.negative_slope = negative_slope
14
- self.scale = scale
15
-
16
- def forward(self, input):
17
- out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
18
- return out
19
-
20
-
21
- def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
22
- _, minor, in_h, in_w = input.shape
23
- kernel_h, kernel_w = kernel.shape
24
-
25
- out = input.view(-1, minor, in_h, 1, in_w, 1)
26
- out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
27
- out = out.view(-1, minor, in_h * up_y, in_w * up_x)
28
-
29
- out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
30
- out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
31
- max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
32
-
33
- out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
34
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
35
- out = F.conv2d(out, w)
36
- out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
37
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
38
-
39
- return out[:, :, ::down_y, ::down_x]
40
-
41
-
42
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
43
- return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
44
-
45
-
46
- def make_kernel(k):
47
- k = torch.tensor(k, dtype=torch.float32)
48
-
49
- if k.ndim == 1:
50
- k = k[None, :] * k[:, None]
51
-
52
- k /= k.sum()
53
-
54
- return k
55
-
56
-
57
- class Blur(nn.Module):
58
- def __init__(self, kernel, pad, upsample_factor=1):
59
- super().__init__()
60
-
61
- kernel = make_kernel(kernel)
62
-
63
- if upsample_factor > 1:
64
- kernel = kernel * (upsample_factor ** 2)
65
-
66
- self.register_buffer('kernel', kernel)
67
-
68
- self.pad = pad
69
-
70
- def forward(self, input):
71
- return upfirdn2d(input, self.kernel, pad=self.pad)
72
-
73
-
74
- class ScaledLeakyReLU(nn.Module):
75
- def __init__(self, negative_slope=0.2):
76
- super().__init__()
77
-
78
- self.negative_slope = negative_slope
79
-
80
- def forward(self, input):
81
- return F.leaky_relu(input, negative_slope=self.negative_slope)
82
-
83
-
84
- class EqualConv2d(nn.Module):
85
- def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
86
- super().__init__()
87
-
88
- self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
89
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
90
-
91
- self.stride = stride
92
- self.padding = padding
93
-
94
- if bias:
95
- self.bias = nn.Parameter(torch.zeros(out_channel))
96
- else:
97
- self.bias = None
98
-
99
- def forward(self, input):
100
-
101
- return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
102
-
103
- def __repr__(self):
104
- return (
105
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
106
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
107
- )
108
-
109
-
110
- class EqualLinear(nn.Module):
111
- def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
112
- super().__init__()
113
-
114
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
115
-
116
- if bias:
117
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
118
- else:
119
- self.bias = None
120
-
121
- self.activation = activation
122
-
123
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
124
- self.lr_mul = lr_mul
125
-
126
- def forward(self, input):
127
-
128
- if self.activation:
129
- out = F.linear(input, self.weight * self.scale)
130
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
131
- else:
132
- out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
133
-
134
- return out
135
-
136
- def __repr__(self):
137
- return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
138
-
139
-
140
- class ConvLayer(nn.Sequential):
141
- def __init__(
142
- self,
143
- in_channel,
144
- out_channel,
145
- kernel_size,
146
- downsample=False,
147
- blur_kernel=[1, 3, 3, 1],
148
- bias=True,
149
- activate=True,
150
- ):
151
- layers = []
152
-
153
- if downsample:
154
- factor = 2
155
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
156
- pad0 = (p + 1) // 2
157
- pad1 = p // 2
158
-
159
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
160
-
161
- stride = 2
162
- self.padding = 0
163
-
164
- else:
165
- stride = 1
166
- self.padding = kernel_size // 2
167
-
168
- layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
169
- bias=bias and not activate))
170
-
171
- if activate:
172
- if bias:
173
- layers.append(FusedLeakyReLU(out_channel))
174
- else:
175
- layers.append(ScaledLeakyReLU(0.2))
176
-
177
- super().__init__(*layers)
178
-
179
-
180
- class ResBlock(nn.Module):
181
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
182
- super().__init__()
183
-
184
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
185
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
186
-
187
- self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
188
-
189
- def forward(self, input):
190
- out = self.conv1(input)
191
- out = self.conv2(out)
192
-
193
- skip = self.skip(input)
194
- out = (out + skip) / math.sqrt(2)
195
-
196
- return out
197
-
198
- class WeightedSumLayer(nn.Module):
199
- def __init__(self, num_tensors=8):
200
- super(WeightedSumLayer, self).__init__()
201
-
202
- self.weights = nn.Parameter(torch.randn(num_tensors))
203
-
204
- def forward(self, tensor_list):
205
-
206
- weights = torch.softmax(self.weights, dim=0)
207
- weighted_sum = torch.zeros_like(tensor_list[0])
208
- for tensor, weight in zip(tensor_list, weights):
209
- weighted_sum += tensor * weight
210
-
211
- return weighted_sum
212
-
213
- class EncoderApp(nn.Module):
214
- def __init__(self, size, w_dim=512, fusion_type=''):
215
- super(EncoderApp, self).__init__()
216
-
217
- channels = {
218
- 4: 512,
219
- 8: 512,
220
- 16: 512,
221
- 32: 512,
222
- 64: 256,
223
- 128: 128,
224
- 256: 64,
225
- 512: 32,
226
- 1024: 16
227
- }
228
-
229
- self.w_dim = w_dim
230
- log_size = int(math.log(size, 2))
231
-
232
- self.convs = nn.ModuleList()
233
- self.convs.append(ConvLayer(3, channels[size], 1))
234
-
235
- in_channel = channels[size]
236
- for i in range(log_size, 2, -1):
237
- out_channel = channels[2 ** (i - 1)]
238
- self.convs.append(ResBlock(in_channel, out_channel))
239
- in_channel = out_channel
240
-
241
- self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
242
-
243
- self.fusion_type = fusion_type
244
- assert self.fusion_type == 'weighted_sum'
245
- if self.fusion_type == 'weighted_sum':
246
- print(f'HAL layer is enabled!')
247
- self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
248
- self.fc1 = EqualLinear(64, 512)
249
- self.fc2 = EqualLinear(128, 512)
250
- self.fc3 = EqualLinear(256, 512)
251
- self.ws = WeightedSumLayer()
252
-
253
- def forward(self, x):
254
-
255
- res = []
256
- h = x
257
- pooled_h_lists = []
258
- for i, conv in enumerate(self.convs):
259
- h = conv(h)
260
- if self.fusion_type == 'weighted_sum':
261
- pooled_h = self.adaptive_pool(h).view(x.size(0), -1)
262
- if i == 0:
263
- pooled_h_lists.append(self.fc1(pooled_h))
264
- elif i == 1:
265
- pooled_h_lists.append(self.fc2(pooled_h))
266
- elif i == 2:
267
- pooled_h_lists.append(self.fc3(pooled_h))
268
- else:
269
- pooled_h_lists.append(pooled_h)
270
- res.append(h)
271
-
272
- if self.fusion_type == 'weighted_sum':
273
- last_layer = self.ws(pooled_h_lists)
274
- else:
275
- last_layer = res[-1].squeeze(-1).squeeze(-1)
276
- layer_features = res[::-1][2:]
277
-
278
- return last_layer, layer_features
279
-
280
-
281
- class DecouplingModel(nn.Module):
282
- def __init__(self, input_dim, hidden_dim, output_dim):
283
- super(DecouplingModel, self).__init__()
284
-
285
- # identity_net is called the identity encoder in the paper
286
- self.identity_net = nn.Sequential(
287
- nn.Linear(input_dim, hidden_dim),
288
- nn.ReLU(),
289
- nn.Linear(hidden_dim, output_dim)
290
- )
291
-
292
- self.identity_net_density = nn.Sequential(
293
- nn.Linear(input_dim, hidden_dim),
294
- nn.ReLU(),
295
- nn.Linear(hidden_dim, output_dim)
296
- )
297
-
298
- # identity_excluded_net is called motion encoder in the paper
299
- self.identity_excluded_net = nn.Sequential(
300
- nn.Linear(input_dim, hidden_dim),
301
- nn.ReLU(),
302
- nn.Linear(hidden_dim, output_dim)
303
- )
304
-
305
- def forward(self, x):
306
-
307
- id_, id_rm = self.identity_net(x), self.identity_excluded_net(x)
308
- id_density = self.identity_net_density(id_)
309
- return id_, id_rm, id_density
310
-
311
- class Encoder(nn.Module):
312
- def __init__(self, size, dim=512, dim_motion=20, weighted_sum=False):
313
- super(Encoder, self).__init__()
314
-
315
- # image encoder
316
- self.net_app = EncoderApp(size, dim, weighted_sum)
317
-
318
- # decoupling network
319
- self.net_decouping = DecouplingModel(dim, dim, dim)
320
-
321
- # part of the motion encoder
322
- fc = [EqualLinear(dim, dim)]
323
- for i in range(3):
324
- fc.append(EqualLinear(dim, dim))
325
-
326
- fc.append(EqualLinear(dim, dim_motion))
327
- self.fc = nn.Sequential(*fc)
328
-
329
- def enc_app(self, x):
330
-
331
- h_source = self.net_app(x)
332
-
333
- return h_source
334
-
335
- def enc_motion(self, x):
336
-
337
- h, _ = self.net_app(x)
338
- h_motion = self.fc(h)
339
-
340
- return h_motion
341
-
342
- def encode_image_obj(self, image_obj):
343
- feat, _ = self.net_app(image_obj)
344
- id_emb, idrm_emb, id_density_emb = self.net_decouping(feat)
345
- return id_emb, idrm_emb, id_density_emb
346
-
347
- def forward(self, input_source, input_target, input_face, input_aug):
348
-
349
-
350
- if input_target is not None:
351
-
352
- h_source, feats = self.net_app(input_source)
353
- h_target, _ = self.net_app(input_target)
354
- h_face, _ = self.net_app(input_face)
355
- h_aug, _ = self.net_app(input_aug)
356
-
357
- h_source_id_emb, h_source_idrm_emb, h_source_id_density_emb = self.net_decouping(h_source)
358
- h_target_id_emb, h_target_idrm_emb, h_target_id_density_emb = self.net_decouping(h_target)
359
- h_face_id_emb, h_face_idrm_emb, h_face_id_density_emb = self.net_decouping(h_face)
360
- h_aug_id_emb, h_aug_idrm_emb, h_aug_id_density_emb = self.net_decouping(h_aug)
361
-
362
- h_target_motion_target = self.fc(h_target_idrm_emb)
363
- h_another_face_target = self.fc(h_face_idrm_emb)
364
-
365
- else:
366
- h_source, feats = self.net_app(input_source)
367
-
368
-
369
- return {'h_source':h_source, 'h_motion':h_target_motion_target, 'feats':feats, 'h_another_face_target':h_another_face_target, 'h_face':h_face, \
370
- 'h_source_id_emb':h_source_id_emb, 'h_source_idrm_emb':h_source_idrm_emb, 'h_source_id_density_emb':h_source_id_density_emb, \
371
- 'h_target_id_emb':h_target_id_emb, 'h_target_idrm_emb':h_target_idrm_emb, 'h_target_id_density_emb':h_target_id_density_emb, \
372
- 'h_face_id_emb':h_face_id_emb, 'h_face_idrm_emb':h_face_idrm_emb, 'h_face_id_density_emb':h_face_id_density_emb, \
373
- 'h_aug_id_emb':h_aug_id_emb, 'h_aug_idrm_emb':h_aug_idrm_emb ,'h_aug_id_density_emb':h_aug_id_density_emb, \
374
- }
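A minimal usage sketch for the Encoder defined above (not part of the original files). The import path is assumed from the pre-deletion layout, and note that the third constructor argument, although named weighted_sum, is forwarded to EncoderApp as fusion_type, whose assert only accepts the string 'weighted_sum':

import torch
from networks.encoder import Encoder  # assumed pre-deletion import path

# 'weighted_sum' must be passed as a string; it reaches EncoderApp as fusion_type.
enc = Encoder(size=256, dim=512, dim_motion=20, weighted_sum='weighted_sum')

img = torch.randn(1, 3, 256, 256)      # dummy RGB frame
with torch.no_grad():
    alpha = enc.enc_motion(img)        # (1, 20) motion code from the fc head
    h, feats = enc.net_app(img)        # (1, 512) appearance code + multi-scale features
print(alpha.shape, h.shape, len(feats))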
 
 
code/networks/generator.py DELETED
@@ -1,27 +0,0 @@
1
- from torch import nn
2
- from .encoder import Encoder
3
- from .styledecoder import Synthesis
4
-
5
-
6
- class Generator(nn.Module):
7
- def __init__(self, size, style_dim=512, motion_dim=20, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
8
- super(Generator, self).__init__()
9
-
10
- # encoder
11
- self.enc = Encoder(size, style_dim, motion_dim)
12
- self.dec = Synthesis(size, style_dim, motion_dim, blur_kernel, channel_multiplier)
13
-
14
- def get_direction(self):
15
- return self.dec.direction(None)
16
-
17
- def synthesis(self, wa, alpha, feat):
18
- img = self.dec(wa, alpha, feat)
19
-
20
- return img
21
-
22
- def forward(self, img_source, img_drive, h_start=None):
23
- wa, alpha, feats = self.enc(img_source, img_drive, h_start)
24
- # import pdb;pdb.set_trace()
25
- img_recon = self.dec(wa, alpha, feats)
26
-
27
- return img_recon
 
 
code/networks/styledecoder.py DELETED
@@ -1,527 +0,0 @@
1
- import math
2
- import torch
3
- from torch import nn
4
- from torch.nn import functional as F
5
- import numpy as np
6
-
7
-
8
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
9
- return F.leaky_relu(input + bias, negative_slope) * scale
10
-
11
-
12
- class FusedLeakyReLU(nn.Module):
13
- def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
14
- super().__init__()
15
- self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
16
- self.negative_slope = negative_slope
17
- self.scale = scale
18
-
19
- def forward(self, input):
20
- out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
- return out
22
-
23
-
24
- def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
25
- _, minor, in_h, in_w = input.shape
26
- kernel_h, kernel_w = kernel.shape
27
-
28
- out = input.view(-1, minor, in_h, 1, in_w, 1)
29
- out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
30
- out = out.view(-1, minor, in_h * up_y, in_w * up_x)
31
-
32
- out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
33
- out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
34
- max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
35
-
36
- out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
37
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
38
- out = F.conv2d(out, w)
39
- out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
40
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
41
- return out[:, :, ::down_y, ::down_x]
42
-
43
-
44
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
45
- return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
46
-
47
-
48
- class PixelNorm(nn.Module):
49
- def __init__(self):
50
- super().__init__()
51
-
52
- def forward(self, input):
53
- return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
54
-
55
-
56
- class MotionPixelNorm(nn.Module):
57
- def __init__(self):
58
- super().__init__()
59
-
60
- def forward(self, input):
61
- return input * torch.rsqrt(torch.mean(input ** 2, dim=2, keepdim=True) + 1e-8)
62
-
63
-
64
- def make_kernel(k):
65
- k = torch.tensor(k, dtype=torch.float32)
66
-
67
- if k.ndim == 1:
68
- k = k[None, :] * k[:, None]
69
-
70
- k /= k.sum()
71
-
72
- return k
73
-
74
-
75
- class Upsample(nn.Module):
76
- def __init__(self, kernel, factor=2):
77
- super().__init__()
78
-
79
- self.factor = factor
80
- kernel = make_kernel(kernel) * (factor ** 2)
81
- self.register_buffer('kernel', kernel)
82
-
83
- p = kernel.shape[0] - factor
84
-
85
- pad0 = (p + 1) // 2 + factor - 1
86
- pad1 = p // 2
87
-
88
- self.pad = (pad0, pad1)
89
-
90
- def forward(self, input):
91
- return upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
92
-
93
-
94
- class Downsample(nn.Module):
95
- def __init__(self, kernel, factor=2):
96
- super().__init__()
97
-
98
- self.factor = factor
99
- kernel = make_kernel(kernel)
100
- self.register_buffer('kernel', kernel)
101
-
102
- p = kernel.shape[0] - factor
103
-
104
- pad0 = (p + 1) // 2
105
- pad1 = p // 2
106
-
107
- self.pad = (pad0, pad1)
108
-
109
- def forward(self, input):
110
- return upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
111
-
112
-
113
- class Blur(nn.Module):
114
- def __init__(self, kernel, pad, upsample_factor=1):
115
- super().__init__()
116
-
117
- kernel = make_kernel(kernel)
118
-
119
- if upsample_factor > 1:
120
- kernel = kernel * (upsample_factor ** 2)
121
-
122
- self.register_buffer('kernel', kernel)
123
-
124
- self.pad = pad
125
-
126
- def forward(self, input):
127
- return upfirdn2d(input, self.kernel, pad=self.pad)
128
-
129
-
130
- class EqualConv2d(nn.Module):
131
- def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
132
- super().__init__()
133
-
134
- self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
135
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
136
-
137
- self.stride = stride
138
- self.padding = padding
139
-
140
- if bias:
141
- self.bias = nn.Parameter(torch.zeros(out_channel))
142
- else:
143
- self.bias = None
144
-
145
- def forward(self, input):
146
-
147
- return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding, )
148
-
149
- def __repr__(self):
150
- return (
151
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
152
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
153
- )
154
-
155
-
156
- class EqualLinear(nn.Module):
157
- def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
158
- super().__init__()
159
-
160
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
161
-
162
- if bias:
163
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
164
- else:
165
- self.bias = None
166
-
167
- self.activation = activation
168
-
169
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
170
- self.lr_mul = lr_mul
171
-
172
- def forward(self, input):
173
-
174
- if self.activation:
175
- out = F.linear(input, self.weight * self.scale)
176
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
177
- else:
178
- out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
179
-
180
- return out
181
-
182
- def __repr__(self):
183
- return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
184
-
185
-
186
- class ScaledLeakyReLU(nn.Module):
187
- def __init__(self, negative_slope=0.2):
188
- super().__init__()
189
-
190
- self.negative_slope = negative_slope
191
-
192
- def forward(self, input):
193
- return F.leaky_relu(input, negative_slope=self.negative_slope)
194
-
195
-
196
- class ModulatedConv2d(nn.Module):
197
- def __init__(self, in_channel, out_channel, kernel_size, style_dim, demodulate=True, upsample=False,
198
- downsample=False, blur_kernel=[1, 3, 3, 1], ):
199
- super().__init__()
200
-
201
- self.eps = 1e-8
202
- self.kernel_size = kernel_size
203
- self.in_channel = in_channel
204
- self.out_channel = out_channel
205
- self.upsample = upsample
206
- self.downsample = downsample
207
-
208
- if upsample:
209
- factor = 2
210
- p = (len(blur_kernel) - factor) - (kernel_size - 1)
211
- pad0 = (p + 1) // 2 + factor - 1
212
- pad1 = p // 2 + 1
213
-
214
- self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
215
-
216
- if downsample:
217
- factor = 2
218
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
219
- pad0 = (p + 1) // 2
220
- pad1 = p // 2
221
-
222
- self.blur = Blur(blur_kernel, pad=(pad0, pad1))
223
-
224
- fan_in = in_channel * kernel_size ** 2
225
- self.scale = 1 / math.sqrt(fan_in)
226
- self.padding = kernel_size // 2
227
-
228
- self.weight = nn.Parameter(torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))
229
-
230
- self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
231
- self.demodulate = demodulate
232
-
233
- def __repr__(self):
234
- return (
235
- f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
236
- f'upsample={self.upsample}, downsample={self.downsample})'
237
- )
238
-
239
- def forward(self, input, style):
240
- batch, in_channel, height, width = input.shape
241
-
242
- style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
243
- weight = self.scale * self.weight * style
244
-
245
- if self.demodulate:
246
- demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
247
- weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
248
-
249
- weight = weight.view(batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size)
250
-
251
- if self.upsample:
252
- input = input.view(1, batch * in_channel, height, width)
253
- weight = weight.view(batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size)
254
- weight = weight.transpose(1, 2).reshape(batch * in_channel, self.out_channel, self.kernel_size,
255
- self.kernel_size)
256
- out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
257
- _, _, height, width = out.shape
258
- out = out.view(batch, self.out_channel, height, width)
259
- out = self.blur(out)
260
- elif self.downsample:
261
- input = self.blur(input)
262
- _, _, height, width = input.shape
263
- input = input.view(1, batch * in_channel, height, width)
264
- out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
265
- _, _, height, width = out.shape
266
- out = out.view(batch, self.out_channel, height, width)
267
- else:
268
- input = input.view(1, batch * in_channel, height, width)
269
- out = F.conv2d(input, weight, padding=self.padding, groups=batch)
270
- _, _, height, width = out.shape
271
- out = out.view(batch, self.out_channel, height, width)
272
-
273
- return out
274
-
275
-
276
- class NoiseInjection(nn.Module):
277
- def __init__(self):
278
- super().__init__()
279
-
280
- self.weight = nn.Parameter(torch.zeros(1))
281
-
282
- def forward(self, image, noise=None):
283
-
284
- if noise is None:
285
- return image
286
- else:
287
- return image + self.weight * noise
288
-
289
-
290
- class ConstantInput(nn.Module):
291
- def __init__(self, channel, size=4):
292
- super().__init__()
293
-
294
- self.input = nn.Parameter(torch.randn(1, channel, size, size))
295
-
296
- def forward(self, input):
297
- batch = input.shape[0]
298
- out = self.input.repeat(batch, 1, 1, 1)
299
-
300
- return out
301
-
302
-
303
- class StyledConv(nn.Module):
304
- def __init__(self, in_channel, out_channel, kernel_size, style_dim, upsample=False, blur_kernel=[1, 3, 3, 1],
305
- demodulate=True):
306
- super().__init__()
307
-
308
- self.conv = ModulatedConv2d(
309
- in_channel,
310
- out_channel,
311
- kernel_size,
312
- style_dim,
313
- upsample=upsample,
314
- blur_kernel=blur_kernel,
315
- demodulate=demodulate,
316
- )
317
-
318
- self.noise = NoiseInjection()
319
- self.activate = FusedLeakyReLU(out_channel)
320
-
321
- def forward(self, input, style, noise=None):
322
- out = self.conv(input, style)
323
- out = self.noise(out, noise=noise)
324
- out = self.activate(out)
325
-
326
- return out
327
-
328
-
329
- class ConvLayer(nn.Sequential):
330
- def __init__(
331
- self,
332
- in_channel,
333
- out_channel,
334
- kernel_size,
335
- downsample=False,
336
- blur_kernel=[1, 3, 3, 1],
337
- bias=True,
338
- activate=True,
339
- ):
340
- layers = []
341
-
342
- if downsample:
343
- factor = 2
344
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
345
- pad0 = (p + 1) // 2
346
- pad1 = p // 2
347
-
348
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
349
-
350
- stride = 2
351
- self.padding = 0
352
-
353
- else:
354
- stride = 1
355
- self.padding = kernel_size // 2
356
-
357
- layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
358
- bias=bias and not activate))
359
-
360
- if activate:
361
- if bias:
362
- layers.append(FusedLeakyReLU(out_channel))
363
- else:
364
- layers.append(ScaledLeakyReLU(0.2))
365
-
366
- super().__init__(*layers)
367
-
368
-
369
- class ToRGB(nn.Module):
370
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371
- super().__init__()
372
-
373
- if upsample:
374
- self.upsample = Upsample(blur_kernel)
375
-
376
- self.conv = ConvLayer(in_channel, 3, 1)
377
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378
-
379
- def forward(self, input, skip=None):
380
- out = self.conv(input)
381
- out = out + self.bias
382
-
383
- if skip is not None:
384
- skip = self.upsample(skip)
385
- out = out + skip
386
-
387
- return out
388
-
389
-
390
- class ToFlow(nn.Module):
391
- def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
392
- super().__init__()
393
-
394
- if upsample:
395
- self.upsample = Upsample(blur_kernel)
396
-
397
- self.style_dim = style_dim
398
- self.in_channel = in_channel
399
- self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
400
- self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
401
-
402
- def forward(self, input, style, feat, skip=None): # input is the feature from the previous layer, style is the 512-dim condition, feat is the skip-connection feature from the U-Net
403
- out = self.conv(input, style)
404
- out = out + self.bias
405
-
406
- # warping
407
- xs = np.linspace(-1, 1, input.size(2))
408
-
409
- xs = np.meshgrid(xs, xs)
410
- xs = np.stack(xs, 2)
411
-
412
- xs = torch.tensor(xs, requires_grad=False).float().unsqueeze(0).repeat(input.size(0), 1, 1, 1).to(input.device)
413
- # import pdb;pdb.set_trace()
414
- if skip is not None:
415
- skip = self.upsample(skip)
416
- out = out + skip
417
-
418
- sampler = torch.tanh(out[:, 0:2, :, :])
419
- mask = torch.sigmoid(out[:, 2:3, :, :])
420
- flow = sampler.permute(0, 2, 3, 1) + xs # xs acts as the base grid of sampling locations here
421
-
422
- feat_warp = F.grid_sample(feat, flow) * mask
423
- # import pdb;pdb.set_trace()
424
- return feat_warp, feat_warp + input * (1.0 - mask), out
425
-
426
-
427
- class Direction(nn.Module):
428
- def __init__(self, motion_dim):
429
- super(Direction, self).__init__()
430
-
431
- self.weight = nn.Parameter(torch.randn(512, motion_dim))
432
-
433
- def forward(self, input):
434
- # input: (bs*t) x 512
435
-
436
- weight = self.weight + 1e-8
437
- Q, R = torch.qr(weight) # QR gives an orthonormal basis of motion directions [n1, n2, n3, n4]
438
-
439
- if input is None:
440
- return Q
441
- else:
442
- input_diag = torch.diag_embed(input) # alpha, diagonal matrix
443
- out = torch.matmul(input_diag, Q.T)
444
- out = torch.sum(out, dim=1)
445
-
446
- return out
447
-
448
- class Synthesis(nn.Module):
449
- def __init__(self, size, style_dim, motion_dim, blur_kernel=[1, 3, 3, 1], channel_multiplier=1):
450
- super(Synthesis, self).__init__()
451
-
452
- self.size = size
453
- self.style_dim = style_dim
454
- self.motion_dim = motion_dim
455
-
456
- self.direction = Direction(motion_dim) # Linear Motion Decomposition (LMD) from LIA
457
-
458
- self.channels = {
459
- 4: 512,
460
- 8: 512,
461
- 16: 512,
462
- 32: 512,
463
- 64: 256 * channel_multiplier,
464
- 128: 128 * channel_multiplier,
465
- 256: 64 * channel_multiplier,
466
- 512: 32 * channel_multiplier,
467
- 1024: 16 * channel_multiplier,
468
- }
469
-
470
- self.input = ConstantInput(self.channels[4])
471
- self.conv1 = StyledConv(self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel)
472
- self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
473
-
474
- self.log_size = int(math.log(size, 2))
475
- self.num_layers = (self.log_size - 2) * 2 + 1
476
-
477
- self.convs = nn.ModuleList()
478
- self.upsamples = nn.ModuleList()
479
- self.to_rgbs = nn.ModuleList()
480
- self.to_flows = nn.ModuleList()
481
-
482
- in_channel = self.channels[4]
483
-
484
- for i in range(3, self.log_size + 1):
485
- out_channel = self.channels[2 ** i]
486
-
487
- self.convs.append(StyledConv(in_channel, out_channel, 3, style_dim, upsample=True,
488
- blur_kernel=blur_kernel))
489
- self.convs.append(StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel))
490
- self.to_rgbs.append(ToRGB(out_channel, style_dim))
491
-
492
- self.to_flows.append(ToFlow(out_channel, style_dim))
493
-
494
- in_channel = out_channel
495
-
496
- self.n_latent = self.log_size * 2 - 2
497
-
498
- def forward(self, source_before_decoupling, target_motion, feats):
499
-
500
- directions = self.direction(target_motion)
501
- latent = source_before_decoupling + directions # wa + directions
502
-
503
- inject_index = self.n_latent
504
- latent = latent.unsqueeze(1).repeat(1, inject_index, 1)
505
-
506
- out = self.input(latent)
507
- out = self.conv1(out, latent[:, 0])
508
-
509
- i = 1
510
- for conv1, conv2, to_rgb, to_flow, feat in zip(self.convs[::2], self.convs[1::2], self.to_rgbs,
511
- self.to_flows, feats):
512
- out = conv1(out, latent[:, i])
513
- out = conv2(out, latent[:, i + 1])
514
- if out.size(2) == 8:
515
- out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat)
516
- skip = to_rgb(out_warp)
517
- else:
518
- out_warp, out, skip_flow = to_flow(out, latent[:, i + 2], feat, skip_flow)
519
- skip = to_rgb(out_warp, skip)
520
- i += 2
521
-
522
- img = skip
523
-
524
- return img
525
-
526
-
527
-
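A minimal sketch (not part of the original files) of the Linear Motion Decomposition performed by the Direction module above: a motion code alpha weights the QR-orthonormalized directions, and the result is added to the appearance latent before the styled convolutions, as in Synthesis.forward. The import path is assumed from the pre-deletion layout.

import torch
from networks.styledecoder import Direction  # assumed pre-deletion import path

motion_dim = 20
direction = Direction(motion_dim)

wa = torch.randn(2, 512)               # appearance latent from the encoder
alpha = torch.randn(2, motion_dim)     # motion code

delta = direction(alpha)               # (2, 512): sum_i alpha_i * n_i over the orthonormal basis
latent = wa + delta                    # latent that would drive the styled convolutions
basis = direction(None)                # (512, 20): the orthonormal direction matrix Q
print(delta.shape, latent.shape, basis.shape)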
 
 
code/networks/utils.py DELETED
@@ -1,53 +0,0 @@
1
- from torch import nn
2
- import torch.nn.functional as F
3
- import torch
4
-
5
-
6
- class AntiAliasInterpolation2d(nn.Module):
7
- """
8
- Band-limited downsampling, for better preservation of the input signal.
9
- """
10
-
11
- def __init__(self, channels, scale):
12
- super(AntiAliasInterpolation2d, self).__init__()
13
- sigma = (1 / scale - 1) / 2
14
- kernel_size = 2 * round(sigma * 4) + 1
15
- self.ka = kernel_size // 2
16
- self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
17
-
18
- kernel_size = [kernel_size, kernel_size]
19
- sigma = [sigma, sigma]
20
- # The gaussian kernel is the product of the
21
- # gaussian function of each dimension.
22
- kernel = 1
23
- meshgrids = torch.meshgrid(
24
- [
25
- torch.arange(size, dtype=torch.float32)
26
- for size in kernel_size
27
- ]
28
- )
29
- for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
30
- mean = (size - 1) / 2
31
- kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
32
-
33
- # Make sure sum of values in gaussian kernel equals 1.
34
- kernel = kernel / torch.sum(kernel)
35
- # Reshape to depthwise convolutional weight
36
- kernel = kernel.view(1, 1, *kernel.size())
37
- kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
38
-
39
- self.register_buffer('weight', kernel)
40
- self.groups = channels
41
- self.scale = scale
42
- inv_scale = 1 / scale
43
- self.int_inv_scale = int(inv_scale)
44
-
45
- def forward(self, input):
46
- if self.scale == 1.0:
47
- return input
48
-
49
- out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
50
- out = F.conv2d(out, weight=self.weight, groups=self.groups)
51
- out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
52
-
53
- return out
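A small usage sketch (not part of the original files) of the anti-aliased downsampler above: a Gaussian blur whose sigma is derived from the scale, followed by strided subsampling. The import path is assumed from the pre-deletion layout.

import torch
from networks.utils import AntiAliasInterpolation2d  # assumed pre-deletion import path

down = AntiAliasInterpolation2d(channels=3, scale=0.25)  # sigma = (1/0.25 - 1) / 2 = 1.5 -> 13x13 kernel
x = torch.randn(1, 3, 256, 256)
y = down(x)
print(y.shape)  # torch.Size([1, 3, 64, 64]): blurred, then every 4th pixel kept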
 
 
code/renderer.py DELETED
@@ -1,25 +0,0 @@
1
- from config import *
2
-
3
- def render_condition(
4
- conf: TrainConfig,
5
- model,
6
- sampler, start, motion_direction_start, audio_driven, \
7
- face_location, face_scale, \
8
- yaw_pitch_roll, noisyT, control_flag,
9
- ):
10
- if conf.train_mode == TrainMode.diffusion:
11
- assert conf.model_type.has_autoenc()
12
-
13
- return sampler.sample(model=model,
14
- noise=noisyT,
15
- model_kwargs={
16
- 'motion_direction_start': motion_direction_start,
17
- 'yaw_pitch_roll': yaw_pitch_roll,
18
- 'start': start,
19
- 'audio_driven': audio_driven,
20
- 'face_location': face_location,
21
- 'face_scale': face_scale,
22
- 'control_flag': control_flag
23
- })
24
- else:
25
- raise NotImplementedError()
 
 
code/templates.py DELETED
@@ -1,301 +0,0 @@
1
- from experiment import *
2
-
3
-
4
- def ddpm():
5
- """
6
- base configuration for all DDIM-based models.
7
- """
8
- conf = TrainConfig()
9
- conf.batch_size = 32
10
- conf.beatgans_gen_type = GenerativeType.ddim
11
- conf.beta_scheduler = 'linear'
12
- conf.data_name = 'ffhq'
13
- conf.diffusion_type = 'beatgans'
14
- conf.eval_ema_every_samples = 200_000
15
- conf.eval_every_samples = 200_000
16
- conf.fp16 = True
17
- conf.lr = 1e-4
18
- conf.model_name = ModelName.beatgans_ddpm
19
- conf.net_attn = (16, )
20
- conf.net_beatgans_attn_head = 1
21
- conf.net_beatgans_embed_channels = 512
22
- conf.net_ch_mult = (1, 2, 4, 8)
23
- conf.net_ch = 64
24
- conf.sample_size = 32
25
- conf.T_eval = 20
26
- conf.T = 1000
27
- conf.make_model_conf()
28
- return conf
29
-
30
-
31
- def autoenc_base():
32
- """
33
- base configuration for all Diff-AE models.
34
- """
35
- conf = TrainConfig()
36
- conf.batch_size = 32
37
- conf.beatgans_gen_type = GenerativeType.ddim
38
- conf.beta_scheduler = 'linear'
39
- conf.data_name = 'ffhq'
40
- conf.diffusion_type = 'beatgans'
41
- conf.eval_ema_every_samples = 200_000
42
- conf.eval_every_samples = 200_000
43
- conf.fp16 = True
44
- conf.lr = 1e-4
45
- conf.model_name = ModelName.beatgans_autoenc
46
- conf.net_attn = (16, )
47
- conf.net_beatgans_attn_head = 1
48
- conf.net_beatgans_embed_channels = 512
49
- conf.net_beatgans_resnet_two_cond = True
50
- conf.net_ch_mult = (1, 2, 4, 8)
51
- conf.net_ch = 64
52
- conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53
- conf.net_enc_pool = 'adaptivenonzero'
54
- conf.sample_size = 32
55
- conf.T_eval = 20
56
- conf.T = 1000
57
- conf.make_model_conf()
58
- return conf
59
-
60
-
61
- def ffhq64_ddpm():
62
- conf = ddpm()
63
- conf.data_name = 'ffhqlmdb256'
64
- conf.warmup = 0
65
- conf.total_samples = 72_000_000
66
- conf.scale_up_gpus(4)
67
- return conf
68
-
69
-
70
- def ffhq64_autoenc():
71
- conf = autoenc_base()
72
- conf.data_name = 'ffhqlmdb256'
73
- conf.warmup = 0
74
- conf.total_samples = 72_000_000
75
- conf.net_ch_mult = (1, 2, 4, 8)
76
- conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
77
- conf.eval_every_samples = 1_000_000
78
- conf.eval_ema_every_samples = 1_000_000
79
- conf.scale_up_gpus(4)
80
- conf.make_model_conf()
81
- return conf
82
-
83
-
84
- def celeba64d2c_ddpm():
85
- conf = ffhq128_ddpm()
86
- conf.data_name = 'celebalmdb'
87
- conf.eval_every_samples = 10_000_000
88
- conf.eval_ema_every_samples = 10_000_000
89
- conf.total_samples = 72_000_000
90
- conf.name = 'celeba64d2c_ddpm'
91
- return conf
92
-
93
-
94
- def celeba64d2c_autoenc():
95
- conf = ffhq64_autoenc()
96
- conf.data_name = 'celebalmdb'
97
- conf.eval_every_samples = 10_000_000
98
- conf.eval_ema_every_samples = 10_000_000
99
- conf.total_samples = 72_000_000
100
- conf.name = 'celeba64d2c_autoenc'
101
- return conf
102
-
103
-
104
- def ffhq128_ddpm():
105
- conf = ddpm()
106
- conf.data_name = 'ffhqlmdb256'
107
- conf.warmup = 0
108
- conf.total_samples = 48_000_000
109
- conf.img_size = 128
110
- conf.net_ch = 128
111
- # channels:
112
- # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
113
- # sizes:
114
- # 128 => 128 => 64 => 32 => 16 => 8
115
- conf.net_ch_mult = (1, 1, 2, 3, 4)
116
- conf.eval_every_samples = 1_000_000
117
- conf.eval_ema_every_samples = 1_000_000
118
- conf.scale_up_gpus(4)
119
- conf.eval_ema_every_samples = 10_000_000
120
- conf.eval_every_samples = 10_000_000
121
- conf.make_model_conf()
122
- return conf
123
-
124
-
125
- def ffhq128_autoenc_base():
126
- conf = autoenc_base()
127
- conf.data_name = 'ffhqlmdb256'
128
- conf.scale_up_gpus(4)
129
- conf.img_size = 128
130
- conf.net_ch = 128
131
- # final resolution = 8x8
132
- conf.net_ch_mult = (1, 1, 2, 3, 4)
133
- # final resolution = 4x4
134
- conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
135
- conf.eval_ema_every_samples = 10_000_000
136
- conf.eval_every_samples = 10_000_000
137
- conf.make_model_conf()
138
- return conf
139
-
140
-
141
- def ffhq256_autoenc():
142
- conf = ffhq128_autoenc_base()
143
- conf.img_size = 256
144
- conf.net_ch = 128
145
- conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
146
- conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
147
- conf.eval_every_samples = 10_000_000
148
- conf.eval_ema_every_samples = 10_000_000
149
- conf.total_samples = 200_000_000
150
- conf.batch_size = 64
151
- conf.make_model_conf()
152
- conf.name = 'ffhq256_autoenc'
153
- return conf
154
-
155
-
156
- def ffhq256_autoenc_eco():
157
- conf = ffhq128_autoenc_base()
158
- conf.img_size = 256
159
- conf.net_ch = 128
160
- conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
161
- conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
162
- conf.eval_every_samples = 10_000_000
163
- conf.eval_ema_every_samples = 10_000_000
164
- conf.total_samples = 200_000_000
165
- conf.batch_size = 64
166
- conf.make_model_conf()
167
- conf.name = 'ffhq256_autoenc_eco'
168
- return conf
169
-
170
-
171
- def ffhq128_ddpm_72M():
172
- conf = ffhq128_ddpm()
173
- conf.total_samples = 72_000_000
174
- conf.name = 'ffhq128_ddpm_72M'
175
- return conf
176
-
177
-
178
- def ffhq128_autoenc_72M():
179
- conf = ffhq128_autoenc_base()
180
- conf.total_samples = 72_000_000
181
- conf.name = 'ffhq128_autoenc_72M'
182
- return conf
183
-
184
-
185
- def ffhq128_ddpm_130M():
186
- conf = ffhq128_ddpm()
187
- conf.total_samples = 130_000_000
188
- conf.eval_ema_every_samples = 10_000_000
189
- conf.eval_every_samples = 10_000_000
190
- conf.name = 'ffhq128_ddpm_130M'
191
- return conf
192
-
193
-
194
- def ffhq128_autoenc_130M():
195
- conf = ffhq128_autoenc_base()
196
- conf.total_samples = 130_000_000
197
- conf.eval_ema_every_samples = 10_000_000
198
- conf.eval_every_samples = 10_000_000
199
- conf.name = 'ffhq128_autoenc_130M'
200
- return conf
201
-
202
-
203
- def horse128_ddpm():
204
- conf = ffhq128_ddpm()
205
- conf.data_name = 'horse256'
206
- conf.total_samples = 130_000_000
207
- conf.eval_ema_every_samples = 10_000_000
208
- conf.eval_every_samples = 10_000_000
209
- conf.name = 'horse128_ddpm'
210
- return conf
211
-
212
-
213
- def horse128_autoenc():
214
- conf = ffhq128_autoenc_base()
215
- conf.data_name = 'horse256'
216
- conf.total_samples = 130_000_000
217
- conf.eval_ema_every_samples = 10_000_000
218
- conf.eval_every_samples = 10_000_000
219
- conf.name = 'horse128_autoenc'
220
- return conf
221
-
222
-
223
- def bedroom128_ddpm():
224
- conf = ffhq128_ddpm()
225
- conf.data_name = 'bedroom256'
226
- conf.eval_ema_every_samples = 10_000_000
227
- conf.eval_every_samples = 10_000_000
228
- conf.total_samples = 120_000_000
229
- conf.name = 'bedroom128_ddpm'
230
- return conf
231
-
232
-
233
- def bedroom128_autoenc():
234
- conf = ffhq128_autoenc_base()
235
- conf.data_name = 'bedroom256'
236
- conf.eval_ema_every_samples = 10_000_000
237
- conf.eval_every_samples = 10_000_000
238
- conf.total_samples = 120_000_000
239
- conf.name = 'bedroom128_autoenc'
240
- return conf
241
-
242
-
243
- def pretrain_celeba64d2c_72M():
244
- conf = celeba64d2c_autoenc()
245
- conf.pretrain = PretrainConfig(
246
- name='72M',
247
- path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
248
- )
249
- conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
250
- return conf
251
-
252
-
253
- def pretrain_ffhq128_autoenc72M():
254
- conf = ffhq128_autoenc_base()
255
- conf.postfix = ''
256
- conf.pretrain = PretrainConfig(
257
- name='72M',
258
- path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
259
- )
260
- conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
261
- return conf
262
-
263
-
264
- def pretrain_ffhq128_autoenc130M():
265
- conf = ffhq128_autoenc_base()
266
- conf.pretrain = PretrainConfig(
267
- name='130M',
268
- path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
269
- )
270
- conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
271
- return conf
272
-
273
-
274
- def pretrain_ffhq256_autoenc():
275
- conf = ffhq256_autoenc()
276
- conf.pretrain = PretrainConfig(
277
- name='90M',
278
- path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
279
- )
280
- conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
281
- return conf
282
-
283
-
284
- def pretrain_horse128():
285
- conf = horse128_autoenc()
286
- conf.pretrain = PretrainConfig(
287
- name='82M',
288
- path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
289
- )
290
- conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
291
- return conf
292
-
293
-
294
- def pretrain_bedroom128():
295
- conf = bedroom128_autoenc()
296
- conf.pretrain = PretrainConfig(
297
- name='120M',
298
- path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
299
- )
300
- conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
301
- return conf
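The templates above all follow one pattern: call a base template, override a few fields, set a name, and return the config. A self-contained illustration of that pattern using a simplified, hypothetical stand-in for TrainConfig (the real class lives in the config/experiment modules, which are not shown here):

from dataclasses import dataclass

@dataclass
class ToyTrainConfig:              # hypothetical stand-in for TrainConfig, for illustration only
    img_size: int = 64
    net_ch: int = 64
    total_samples: int = 0
    name: str = ''

def toy_autoenc_base():
    return ToyTrainConfig(net_ch=64)

def toy_ffhq256_autoenc():
    conf = toy_autoenc_base()      # start from the base template...
    conf.img_size = 256            # ...and override only what differs
    conf.net_ch = 128
    conf.total_samples = 200_000_000
    conf.name = 'toy_ffhq256_autoenc'
    return conf

print(toy_ffhq256_autoenc().name)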