""" Video Face Manipulation Detection Through Ensemble of CNNs Image and Sound Processing Lab - Politecnico di Milano Nicolò Bonettini Edoardo Daniele Cannas Sara Mandelli Luca Bondi Paolo Bestagini """ import argparse import os import shutil import warnings import numpy as np import torch import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') import torch.nn as nn import torch.optim as optim from tensorboardX import SummaryWriter from torch.utils.data import DataLoader from tqdm import tqdm from architectures import tripletnet from train_binclass import save_model, tb_attention from isplutils.data import FrameFaceIterableDataset from isplutils.data_siamese import FrameFaceTripletIterableDataset from isplutils import split, utils def main(): # Args parser = argparse.ArgumentParser() parser.add_argument('--net', type=str, help='Net model class', required=True) parser.add_argument('--traindb', type=str, help='Training datasets', nargs='+', choices=split.available_datasets, required=True) parser.add_argument('--valdb', type=str, help='Validation datasets', nargs='+', choices=split.available_datasets, required=True) parser.add_argument('--dfdc_faces_df_path', type=str, action='store', help='Path to the Pandas Dataframe obtained from extract_faces.py on the DFDC dataset. ' 'Required for training/validating on the DFDC dataset.') parser.add_argument('--dfdc_faces_dir', type=str, action='store', help='Path to the directory containing the faces extracted from the DFDC dataset. ' 'Required for training/validating on the DFDC dataset.') parser.add_argument('--ffpp_faces_df_path', type=str, action='store', help='Path to the Pandas Dataframe obtained from extract_faces.py on the FF++ dataset. ' 'Required for training/validating on the FF++ dataset.') parser.add_argument('--ffpp_faces_dir', type=str, action='store', help='Path to the directory containing the faces extracted from the FF++ dataset. 
def main():
    # Args
    parser = argparse.ArgumentParser()
    parser.add_argument('--net', type=str, help='Net model class', required=True)
    parser.add_argument('--traindb', type=str, help='Training datasets', nargs='+',
                        choices=split.available_datasets, required=True)
    parser.add_argument('--valdb', type=str, help='Validation datasets', nargs='+',
                        choices=split.available_datasets, required=True)
    parser.add_argument('--dfdc_faces_df_path', type=str, action='store',
                        help='Path to the Pandas Dataframe obtained from extract_faces.py on the DFDC dataset. '
                             'Required for training/validating on the DFDC dataset.')
    parser.add_argument('--dfdc_faces_dir', type=str, action='store',
                        help='Path to the directory containing the faces extracted from the DFDC dataset. '
                             'Required for training/validating on the DFDC dataset.')
    parser.add_argument('--ffpp_faces_df_path', type=str, action='store',
                        help='Path to the Pandas Dataframe obtained from extract_faces.py on the FF++ dataset. '
                             'Required for training/validating on the FF++ dataset.')
    parser.add_argument('--ffpp_faces_dir', type=str, action='store',
                        help='Path to the directory containing the faces extracted from the FF++ dataset. '
                             'Required for training/validating on the FF++ dataset.')
    parser.add_argument('--face', type=str, help='Face crop or scale', required=True,
                        choices=['scale', 'tight'])
    parser.add_argument('--size', type=int, help='Train patch size', required=True)

    parser.add_argument('--batch', type=int, help='Batch size to fit in GPU memory', default=12)
    parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
    parser.add_argument('--valint', type=int, help='Validation interval (iterations)', default=500)
    parser.add_argument('--patience', type=int, help='Patience before dropping the LR [validation intervals]',
                        default=10)
    parser.add_argument('--maxiter', type=int, help='Maximum number of iterations', default=20000)
    parser.add_argument('--init', type=str, help='Weight initialization file')
    parser.add_argument('--scratch', action='store_true', help='Train from scratch')

    parser.add_argument('--traintriplets', type=int, help='Limit the number of train triplets per epoch', default=-1)
    parser.add_argument('--valtriplets', type=int, help='Limit the number of validation triplets per epoch',
                        default=2000)

    parser.add_argument('--logint', type=int, help='Training log interval (iterations)', default=100)
    parser.add_argument('--workers', type=int, help='Num workers for data loaders', default=6)
    parser.add_argument('--device', type=int, help='GPU device id', default=0)
    parser.add_argument('--seed', type=int, help='Random seed', default=0)

    parser.add_argument('--debug', action='store_true', help='Activate debug')
    parser.add_argument('--suffix', type=str, help='Suffix to default tag')

    parser.add_argument('--attention', action='store_true',
                        help='Enable Tensorboard log of attention masks')
    parser.add_argument('--embedding', action='store_true',
                        help='Activate embedding visualization in TensorBoard')
    parser.add_argument('--embeddingint', type=int, help='Embedding visualization interval in TensorBoard',
                        default=5000)

    parser.add_argument('--log_dir', type=str, help='Directory for saving the training logs',
                        default='runs/triplet/')
    parser.add_argument('--models_dir', type=str, help='Directory for saving the models weights',
                        default='weights/triplet/')

    args = parser.parse_args()

    # Parse arguments
    net_class = getattr(tripletnet, args.net)
    train_datasets = args.traindb
    val_datasets = args.valdb
    dfdc_df_path = args.dfdc_faces_df_path
    ffpp_df_path = args.ffpp_faces_df_path
    dfdc_faces_dir = args.dfdc_faces_dir
    ffpp_faces_dir = args.ffpp_faces_dir
    face_policy = args.face
    face_size = args.size
    batch_size = args.batch
    initial_lr = args.lr
    validation_interval = args.valint
    patience = args.patience
    max_num_iterations = args.maxiter
    initial_model = args.init
    train_from_scratch = args.scratch
    max_train_triplets = args.traintriplets
    max_val_triplets = args.valtriplets
    log_interval = args.logint
    num_workers = args.workers
    device = torch.device('cuda:{:d}'.format(args.device)) if torch.cuda.is_available() else torch.device('cpu')
    seed = args.seed
    debug = args.debug
    suffix = args.suffix
    enable_attention = args.attention
    enable_embedding = args.embedding
    embedding_interval = args.embeddingint
    weights_folder = args.models_dir
    logs_folder = args.log_dir

    # Random initialization
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    # Load net
    net: nn.Module = net_class().to(device)

    # Loss and optimizers
    criterion = nn.TripletMarginLoss()
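    # nn.TripletMarginLoss computes max(d(a, p) - d(a, n) + margin, 0) per triplet,
    # with Euclidean distance (p=2) and margin=1.0 by default, so the network is
    # expected to return the (anchor, positive, negative) embeddings in that order.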
    min_lr = initial_lr * 1e-5
    optimizer = optim.Adam(net.get_trainable_parameters(), lr=initial_lr)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer=optimizer,
        mode='min',
        factor=0.1,
        patience=patience,
        cooldown=2 * patience,
        min_lr=min_lr,
    )

    tag = utils.make_train_tag(net_class=net_class,
                               traindb=train_datasets,
                               face_policy=face_policy,
                               patch_size=face_size,
                               seed=seed,
                               suffix=suffix,
                               debug=debug,
                               )

    # Model checkpoint paths
    bestval_path = os.path.join(weights_folder, tag, 'bestval.pth')
    last_path = os.path.join(weights_folder, tag, 'last.pth')
    periodic_path = os.path.join(weights_folder, tag, 'it{:06d}.pth')

    os.makedirs(os.path.join(weights_folder, tag), exist_ok=True)

    # Load model
    val_loss = min_val_loss = 20  # arbitrarily high initial loss
    epoch = iteration = 0
    net_state = None
    opt_state = None
    if initial_model is not None:
        # If given, load the initial model
        print('Loading model from: {}'.format(initial_model))
        state = torch.load(initial_model, map_location='cpu')
        net_state = state['net']
    elif not train_from_scratch and os.path.exists(last_path):
        print('Loading model from: {}'.format(last_path))
        state = torch.load(last_path, map_location='cpu')
        net_state = state['net']
        opt_state = state['opt']
        iteration = state['iteration'] + 1
        epoch = state['epoch']
    if not train_from_scratch and os.path.exists(bestval_path):
        state = torch.load(bestval_path, map_location='cpu')
        min_val_loss = state['val_loss']
    if net_state is not None:
        adapt_binclass_model(net_state)
        incomp_keys = net.load_state_dict(net_state, strict=False)
        print(incomp_keys)
    if opt_state is not None:
        for param_group in opt_state['param_groups']:
            param_group['lr'] = initial_lr
        optimizer.load_state_dict(opt_state)

    # Initialize Tensorboard
    logdir = os.path.join(logs_folder, tag)
    if iteration == 0:
        # If training from scratch or from an initialization, remove previous history if it exists
        shutil.rmtree(logdir, ignore_errors=True)

    # TensorboardX instance
    tb = SummaryWriter(logdir=logdir)
    if iteration == 0:
        dummy = torch.randn((1, 3, face_size, face_size), device=device)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            tb.add_graph(net, [dummy, dummy, dummy], verbose=False)

    transformer = utils.get_transformer(face_policy=face_policy, patch_size=face_size,
                                        net_normalizer=net.get_normalizer(), train=True)

    # Datasets and data loaders
    print('Loading data')
    # Check if paths for DFDC and FF++ extracted faces and DataFrames are provided
    for dataset in train_datasets:
        if dataset.split('-')[0] == 'dfdc' and (dfdc_df_path is None or dfdc_faces_dir is None):
            raise RuntimeError('Specify DataFrame and directory for DFDC faces for training!')
        elif dataset.split('-')[0] == 'ff' and (ffpp_df_path is None or ffpp_faces_dir is None):
            raise RuntimeError('Specify DataFrame and directory for FF++ faces for training!')
    for dataset in val_datasets:
        if dataset.split('-')[0] == 'dfdc' and (dfdc_df_path is None or dfdc_faces_dir is None):
            raise RuntimeError('Specify DataFrame and directory for DFDC faces for validation!')
        elif dataset.split('-')[0] == 'ff' and (ffpp_df_path is None or ffpp_faces_dir is None):
            raise RuntimeError('Specify DataFrame and directory for FF++ faces for validation!')

    splits = split.make_splits(dfdc_df=dfdc_df_path, ffpp_df=ffpp_df_path,
                               dfdc_dir=dfdc_faces_dir, ffpp_dir=ffpp_faces_dir,
                               dbs={'train': train_datasets, 'val': val_datasets})
    train_dfs = [splits['train'][db][0] for db in splits['train']]
    train_roots = [splits['train'][db][1] for db in splits['train']]
    val_roots = [splits['val'][db][1] for db in splits['val']]
    val_dfs = [splits['val'][db][0] for db in splits['val']]

    train_dataset = FrameFaceTripletIterableDataset(roots=train_roots,
                                                    dfs=train_dfs,
                                                    scale=face_policy,
                                                    num_triplets=max_train_triplets,
                                                    transformer=transformer,
                                                    size=face_size,
                                                    )
    val_dataset = FrameFaceTripletIterableDataset(roots=val_roots,
                                                  dfs=val_dfs,
                                                  scale=face_policy,
                                                  num_triplets=max_val_triplets,
                                                  transformer=transformer,
                                                  size=face_size,
                                                  )
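
    # Note: the triplet datasets are IterableDataset instances, hence the DataLoaders
    # below take no shuffle/sampler arguments; any shuffling is assumed to happen
    # inside the dataset iterator itself.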
    train_loader = DataLoader(train_dataset, num_workers=num_workers, batch_size=batch_size)
    val_loader = DataLoader(val_dataset, num_workers=num_workers, batch_size=batch_size)

    print('Training triplets: {}'.format(len(train_dataset)))
    print('Validation triplets: {}'.format(len(val_dataset)))

    if len(train_dataset) == 0:
        print('No training triplets. Halt.')
        return
    if len(val_dataset) == 0:
        print('No validation triplets. Halt.')
        return

    # Embedding visualization
    if enable_embedding:
        train_dataset_embedding = FrameFaceIterableDataset(roots=train_roots,
                                                           dfs=train_dfs,
                                                           scale=face_policy,
                                                           num_samples=64,
                                                           transformer=transformer,
                                                           size=face_size,
                                                           )
        train_loader_embedding = DataLoader(train_dataset_embedding, num_workers=num_workers,
                                            batch_size=batch_size)
        val_dataset_embedding = FrameFaceIterableDataset(roots=val_roots,
                                                         dfs=val_dfs,
                                                         scale=face_policy,
                                                         num_samples=64,
                                                         transformer=transformer,
                                                         size=face_size,
                                                         )
        val_loader_embedding = DataLoader(val_dataset_embedding, num_workers=num_workers,
                                          batch_size=batch_size)
    else:
        train_loader_embedding = None
        val_loader_embedding = None

    stop = False
    while not stop:
        # Training
        optimizer.zero_grad()
        train_loss = train_num = 0
        for train_batch in tqdm(train_loader, desc='Epoch {:03d}'.format(epoch), leave=False,
                                total=len(train_loader) // train_loader.batch_size):
            net.train()
            train_batch_num = len(train_batch[0])
            train_num += train_batch_num

            train_batch_loss = batch_forward(net, device, criterion, train_batch)
            if torch.isnan(train_batch_loss):
                raise ValueError('NaN loss')
            train_loss += train_batch_loss.item() * train_batch_num

            # Optimization
            train_batch_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Logging
            if iteration > 0 and (iteration % log_interval == 0):
                train_loss /= train_num
                tb.add_scalar('train/loss', train_loss, iteration)
                tb.add_scalar('lr', optimizer.param_groups[0]['lr'], iteration)
                tb.add_scalar('epoch', epoch, iteration)

                # Checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch, last_path)
                train_loss = train_num = 0

            # Validation
            if iteration > 0 and (iteration % validation_interval == 0):

                # Validation
                val_loss = validation_routine(net, device, val_loader, criterion, tb, iteration, tag='val')
                tb.flush()

                # LR Scheduler
                lr_scheduler.step(val_loss)

                # Model checkpoint
                save_model(net, optimizer, train_loss, val_loss, iteration, batch_size, epoch,
                           periodic_path.format(iteration))
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    shutil.copy(periodic_path.format(iteration), bestval_path)

                # Attention
                if enable_attention and hasattr(net, 'feat_ext') and hasattr(net.feat_ext, 'get_attention'):
                    net.eval()
                    # For each dataframe show the attention for a real/fake couple of frames.
                    # The loop variable is named att_tag to avoid shadowing the training tag.
                    for df, root, sample_idx, att_tag in [
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == False].index[0], 'train/att/real'),
                        (train_dfs[0], train_roots[0],
                         train_dfs[0][train_dfs[0]['label'] == True].index[0], 'train/att/fake'),
                    ]:
                        record = df.loc[sample_idx]
                        tb_attention(tb, att_tag, iteration, net.feat_ext, device, face_size, face_policy,
                                     transformer, root, record)

                if optimizer.param_groups[0]['lr'] <= min_lr:
                    print('Reached minimum learning rate. Stopping.')
                    stop = True
                    break
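
            # Note: iteration counts batches across epochs, not epochs themselves;
            # training stops either when --maxiter iterations are reached or when
            # ReduceLROnPlateau has decayed the learning rate down to min_lr.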
            # Embedding visualization
            if enable_embedding:
                if iteration > 0 and (iteration % embedding_interval == 0):
                    embedding_routine(net=net, device=device, loader=train_loader_embedding, iteration=iteration,
                                      tb=tb, tag=tag + '/train')
                    embedding_routine(net=net, device=device, loader=val_loader_embedding, iteration=iteration,
                                      tb=tb, tag=tag + '/val')

            iteration += 1

            if iteration > max_num_iterations:
                print('Maximum number of iterations reached')
                stop = True
                break

        # End of iteration
        epoch += 1

    # Needed to flush out last events
    tb.close()
    print('Completed')


def adapt_binclass_model(net_state):
    # Check that the model contains at least one key starting with feat_ext.; otherwise adapt
    # the state dict of a binary classifier (train_binclass.py) to the triplet net layout
    found = False
    for key in net_state:
        if key.startswith('feat_ext.'):
            found = True
            break
    if not found:
        # Adapt all keys by prefixing them with feat_ext.
        print('Adapting keys')
        keys = [k for k in net_state]
        for key in keys:
            net_state['feat_ext.{}'.format(key)] = net_state[key]
            del net_state[key]


def batch_forward(net: nn.Module, device, criterion, data: tuple) -> torch.Tensor:
    if torch.cuda.is_available():
        data = [i.cuda(device) for i in data]
    # The net returns the (anchor, positive, negative) embeddings, fed straight to the triplet loss
    out = net(*data)
    loss = criterion(*out)
    return loss


def validation_routine(net, device, val_loader, criterion, tb, iteration, tag):
    net.eval()

    val_num = 0
    val_loss = 0.
    for val_data in tqdm(val_loader, desc='Validation', leave=False,
                         total=len(val_loader) // val_loader.batch_size):
        val_batch_num = len(val_data[0])
        with torch.no_grad():
            val_batch_loss = batch_forward(net, device, criterion, val_data)

        val_num += val_batch_num
        val_loss += val_batch_loss.item() * val_batch_num

    # Logging
    val_loss /= val_num
    tb.add_scalar('{}/loss'.format(tag), val_loss, iteration)

    return val_loss


def embedding_routine(net: nn.Module, device: torch.device, loader: DataLoader, tb: SummaryWriter,
                      iteration: int, tag: str):
    net.eval()

    labels = []
    embeddings = []
    for batch_data in loader:
        batch_faces, batch_labels = batch_data
        if torch.cuda.is_available():
            batch_faces = batch_faces.to(device)
        with torch.no_grad():
            batch_emb = net.features(batch_faces)
        labels.append(batch_labels.numpy().flatten())
        embeddings.append(torch.flatten(batch_emb.cpu(), start_dim=1).numpy())
    labels = list(np.concatenate(labels))
    embeddings = np.concatenate(embeddings)

    # Logging
    tb.add_embedding(mat=embeddings, metadata=labels, tag=tag, global_step=iteration)


if __name__ == '__main__':
    main()