capstonedubtrack committed
Commit 8011cec
1 Parent(s): dbb8251

Upload wav2lip_train.py

Files changed (1)
  1. Wav2Lip/Wav2Lip/wav2lip_train.py +374 -0
Wav2Lip/Wav2Lip/wav2lip_train.py ADDED
@@ -0,0 +1,374 @@
from os.path import dirname, join, basename, isfile
from tqdm import tqdm

from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
import audio

import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np

from glob import glob

import os, random, cv2, argparse
from hparams import hparams, get_image_list

parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model without the visual quality discriminator')

parser.add_argument("--data_root", help="Root folder of the preprocessed LRS2 dataset", required=True, type=str)

parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True, type=str)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True, type=str)

parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None, type=str)

args = parser.parse_args()


global_step = 0
global_epoch = 0
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

syncnet_T = 5
syncnet_mel_step_size = 16
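# Each training sample covers syncnet_T = 5 consecutive video frames and a
# syncnet_mel_step_size = 16-step mel window (~0.2 s at 80 mel frames per second,
# assuming the default hparams of 16 kHz audio with a hop size of 200).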

class Dataset(object):
    def __init__(self, split):
        self.all_videos = get_image_list(args.data_root, split)

    def get_frame_id(self, frame):
        return int(basename(frame).split('.')[0])

    def get_window(self, start_frame):
        start_id = self.get_frame_id(start_frame)
        vidname = dirname(start_frame)

        window_fnames = []
        for frame_id in range(start_id, start_id + syncnet_T):
            frame = join(vidname, '{}.jpg'.format(frame_id))
            if not isfile(frame):
                return None
            window_fnames.append(frame)
        return window_fnames

    def read_window(self, window_fnames):
        if window_fnames is None: return None
        window = []
        for fname in window_fnames:
            img = cv2.imread(fname)
            if img is None:
                return None
            try:
                img = cv2.resize(img, (hparams.img_size, hparams.img_size))
            except Exception as e:
                return None

            window.append(img)

        return window

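    # Maps a video frame index to a 16-step mel window. Assumes the mel spectrogram
    # advances at 80 frames per second of audio (true under the default hparams of
    # 16 kHz sample rate with hop size 200), hence start_idx = 80 * frame / fps.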
    def crop_audio_window(self, spec, start_frame):
        if type(start_frame) == int:
            start_frame_num = start_frame
        else:
            start_frame_num = self.get_frame_id(start_frame)  # 0-indexing ---> 1-indexing
        start_idx = int(80. * (start_frame_num / float(hparams.fps)))

        end_idx = start_idx + syncnet_mel_step_size

        return spec[start_idx : end_idx, :]

    def get_segmented_mels(self, spec, start_frame):
        mels = []
        assert syncnet_T == 5
        start_frame_num = self.get_frame_id(start_frame) + 1  # 0-indexing ---> 1-indexing
        if start_frame_num - 2 < 0: return None
        for i in range(start_frame_num, start_frame_num + syncnet_T):
            m = self.crop_audio_window(spec, i - 2)
            if m.shape[0] != syncnet_mel_step_size:
                return None
            mels.append(m.T)

        mels = np.asarray(mels)

        return mels

    def prepare_window(self, window):
        # 3 x T x H x W
        x = np.asarray(window) / 255.
        x = np.transpose(x, (3, 0, 1, 2))

        return x

    def __len__(self):
        return len(self.all_videos)

    def __getitem__(self, idx):
        while 1:
            idx = random.randint(0, len(self.all_videos) - 1)
            vidname = self.all_videos[idx]
            img_names = list(glob(join(vidname, '*.jpg')))
            if len(img_names) <= 3 * syncnet_T:
                continue

            img_name = random.choice(img_names)
            wrong_img_name = random.choice(img_names)
            while wrong_img_name == img_name:
                wrong_img_name = random.choice(img_names)

            window_fnames = self.get_window(img_name)
            wrong_window_fnames = self.get_window(wrong_img_name)
            if window_fnames is None or wrong_window_fnames is None:
                continue

            window = self.read_window(window_fnames)
            if window is None:
                continue

            wrong_window = self.read_window(wrong_window_fnames)
            if wrong_window is None:
                continue

            try:
                wavpath = join(vidname, "audio.wav")
                wav = audio.load_wav(wavpath, hparams.sample_rate)

                orig_mel = audio.melspectrogram(wav).T
            except Exception as e:
                continue

            mel = self.crop_audio_window(orig_mel.copy(), img_name)

            if (mel.shape[0] != syncnet_mel_step_size):
                continue

            indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
            if indiv_mels is None: continue

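            # Ground truth y keeps the full frames; the lower half of the input window
            # is zeroed out and the mismatched "wrong" window is concatenated along the
            # channel axis as the pose/identity reference, giving a 6-channel input.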
            window = self.prepare_window(window)
            y = window.copy()
            window[:, :, window.shape[2]//2:] = 0.

            wrong_window = self.prepare_window(wrong_window)
            x = np.concatenate([window, wrong_window], axis=0)

            x = torch.FloatTensor(x)
            mel = torch.FloatTensor(mel.T).unsqueeze(0)
            indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
            y = torch.FloatTensor(y)
            return x, indiv_mels, mel, y

def save_sample_images(x, g, gt, global_step, checkpoint_dir):
    x = (x.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
    g = (g.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)
    gt = (gt.detach().cpu().numpy().transpose(0, 2, 3, 4, 1) * 255.).astype(np.uint8)

    refs, inps = x[..., 3:], x[..., :3]
    folder = join(checkpoint_dir, "samples_step{:09d}".format(global_step))
    if not os.path.exists(folder): os.mkdir(folder)
    collage = np.concatenate((refs, inps, g, gt), axis=-2)
    for batch_idx, c in enumerate(collage):
        for t in range(len(c)):
            cv2.imwrite('{}/{}_{}.jpg'.format(folder, batch_idx, t), c[t])

logloss = nn.BCELoss()
def cosine_loss(a, v, y):
    d = nn.functional.cosine_similarity(a, v)
    loss = logloss(d.unsqueeze(1), y)

    return loss

device = torch.device("cuda" if use_cuda else "cpu")
syncnet = SyncNet().to(device)
for p in syncnet.parameters():
    p.requires_grad = False

recon_loss = nn.L1Loss()
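# Sync loss: only the lower (mouth) half of each generated frame is scored; the
# syncnet_T frames are stacked along channels into a (B, 3 * T, H/2, W) input for the
# frozen expert SyncNet, and a BCE-on-cosine-similarity loss pushes the audio and
# video embeddings together.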
def get_sync_loss(mel, g):
    g = g[:, :, :, g.size(3)//2:]
    g = torch.cat([g[:, :, i] for i in range(syncnet_T)], dim=1)
    # B, 3 * T, H//2, W
    a, v = syncnet(mel, g)
    y = torch.ones(g.size(0), 1).float().to(device)
    return cosine_loss(a, v, y)

def train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):

    global global_step, global_epoch
    resumed_step = global_step

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss = 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            optimizer.zero_grad()

            # Move data to CUDA device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            l1loss = recon_loss(g, gt)

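            # Composite objective: the expert sync loss is blended with L1 reconstruction.
            # syncnet_wt starts at 0 under the default hparams, so training begins as pure
            # L1; once the eval sync loss drops below 0.75 it is raised to 0.01 (see the
            # eval branch below).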
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt) * l1loss
            loss.backward()
            optimizer.step()

            if global_step % checkpoint_interval == 0:
                save_sample_images(x, g, gt, global_step, checkpoint_dir)

            global_step += 1
            cur_session_steps = global_step - resumed_step

            running_l1_loss += l1loss.item()
            if hparams.syncnet_wt > 0.:
                running_sync_loss += sync_loss.item()
            else:
                running_sync_loss += 0.

            if global_step == 1 or global_step % checkpoint_interval == 0:
                save_checkpoint(
                    model, optimizer, global_step, checkpoint_dir, global_epoch)

            if global_step == 1 or global_step % hparams.eval_interval == 0:
                with torch.no_grad():
                    average_sync_loss = eval_model(test_data_loader, global_step, device, model, checkpoint_dir)

                    if average_sync_loss < .75:
                        hparams.set_hparam('syncnet_wt', 0.01)  # without image GAN a lesser weight is sufficient

            prog_bar.set_description('L1: {}, Sync Loss: {}'.format(running_l1_loss / (step + 1),
                                                                    running_sync_loss / (step + 1)))

        global_epoch += 1


def eval_model(test_data_loader, global_step, device, model, checkpoint_dir):
    eval_steps = 700
    print('Evaluating for {} steps'.format(eval_steps))
    sync_losses, recon_losses = [], []
    step = 0
    while 1:
        for x, indiv_mels, mel, gt in test_data_loader:
            step += 1
            model.eval()

            # Move data to CUDA device
            x = x.to(device)
            gt = gt.to(device)
            indiv_mels = indiv_mels.to(device)
            mel = mel.to(device)

            g = model(indiv_mels, x)

            sync_loss = get_sync_loss(mel, g)
            l1loss = recon_loss(g, gt)

            sync_losses.append(sync_loss.item())
            recon_losses.append(l1loss.item())

            if step > eval_steps:
                averaged_sync_loss = sum(sync_losses) / len(sync_losses)
                averaged_recon_loss = sum(recon_losses) / len(recon_losses)

                print('L1: {}, Sync loss: {}'.format(averaged_recon_loss, averaged_sync_loss))

                return averaged_sync_loss

def save_checkpoint(model, optimizer, step, checkpoint_dir, epoch):

    checkpoint_path = join(
        checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
    torch.save({
        "state_dict": model.state_dict(),
        "optimizer": optimizer_state,
        "global_step": step,
        "global_epoch": epoch,
    }, checkpoint_path)
    print("Saved checkpoint:", checkpoint_path)


def _load(checkpoint_path):
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint

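# Note: checkpoints saved from nn.DataParallel carry a 'module.' prefix on every
# parameter name; it is stripped below so such checkpoints also load into a
# single-GPU model.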
def load_checkpoint(path, model, optimizer, reset_optimizer=False, overwrite_global_states=True):
    global global_step
    global global_epoch

    print("Load checkpoint from: {}".format(path))
    checkpoint = _load(path)
    s = checkpoint["state_dict"]
    new_s = {}
    for k, v in s.items():
        new_s[k.replace('module.', '')] = v
    model.load_state_dict(new_s)
    if not reset_optimizer:
        optimizer_state = checkpoint["optimizer"]
        if optimizer_state is not None:
            print("Load optimizer state from {}".format(path))
            optimizer.load_state_dict(checkpoint["optimizer"])
    if overwrite_global_states:
        global_step = checkpoint["global_step"]
        global_epoch = checkpoint["global_epoch"]

    return model

if __name__ == "__main__":
    checkpoint_dir = args.checkpoint_dir

    # Dataset and Dataloader setup
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    print('total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)

    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)

    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)

    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    # Train!
    train(device, model, train_data_loader, test_data_loader, optimizer,
          checkpoint_dir=checkpoint_dir,
          checkpoint_interval=hparams.checkpoint_interval,
          nepochs=hparams.nepochs)