# Xin Liu
# test
# 2492d81 unverified
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser
from time import gmtime, strftime
from shutil import copy
from frames_dataset import FramesDataset
from modules.inpainting_network import InpaintingNetwork
from modules.keypoint_detector import KPDetector
from modules.bg_motion_predictor import BGMotionPredictor
from modules.dense_motion import DenseMotionNetwork
from modules.avd_network import AVDNetwork
import torch
from train import train
from train_avd import train_avd
from reconstruction import reconstruction
import os
def parse_args(argv=None):
    """Parse this script's command-line options.

    Args:
        argv: argument list to parse; ``None`` (default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``config``, ``mode``, ``log_dir``, ``checkpoint``
        and ``device_ids`` (list of ints parsed from a comma-separated string).
    """
    parser = ArgumentParser()
    parser.add_argument("--config", default="config/vox-256.yaml", help="path to config")
    parser.add_argument("--mode", default="train", choices=["train", "reconstruction", "train_avd"])
    parser.add_argument("--log_dir", default='log', help="path to log into")
    parser.add_argument("--checkpoint", default=None, help="path to checkpoint to restore")
    # argparse also runs `type` on a string default, so the default becomes [0, 1].
    parser.add_argument("--device_ids", default="0,1", type=lambda x: list(map(int, x.split(','))),
                        help="Names of the devices comma separated.")
    return parser.parse_args(argv)


def load_config(config_path):
    """Load the YAML config file at *config_path* and return the parsed object.

    ``yaml.safe_load`` is used because a bare ``yaml.load(f)`` raises TypeError
    on PyYAML >= 6 (``Loader`` became mandatory), and older PyYAML versions
    silently fell back to a less restrictive loader.
    """
    with open(config_path, encoding="utf-8") as f:
        return yaml.safe_load(f)


def resolve_log_dir(config_path, checkpoint, log_root):
    """Return the directory that logs/checkpoints for this run are written to.

    Args:
        config_path: path of the YAML config; its basename (up to the first '.')
            names a fresh run directory.
        checkpoint: checkpoint path to restore, or ``None`` for a fresh run.
        log_root: parent directory for fresh runs.

    Returns:
        The checkpoint's own directory when resuming. A checkpoint given as a
        bare filename maps to ``os.curdir`` rather than ``''``, which
        ``os.makedirs`` would reject. Otherwise
        ``<log_root>/<config name> <UTC timestamp>``.
    """
    if checkpoint is not None:
        return os.path.dirname(checkpoint) or os.curdir
    config_name = os.path.basename(config_path).split('.')[0]
    return os.path.join(log_root, config_name) + ' ' + strftime("%d_%m_%y_%H.%M.%S", gmtime())


def prepare_log_dir(log_dir, config_path):
    """Create *log_dir* if needed and copy the config into it unless already present."""
    os.makedirs(log_dir, exist_ok=True)
    if not os.path.exists(os.path.join(log_dir, os.path.basename(config_path))):
        copy(config_path, log_dir)


def build_networks(config, mode, device):
    """Instantiate the networks required for *mode* and move them to *device*.

    Networks are constructed in the same order as before (inpainting, keypoint
    detector, dense motion, background predictor, AVD), so any RNG-driven
    weight initialisation is unchanged.

    Args:
        config: parsed config; reads ``config['model_params']``.
        mode: one of the ``--mode`` choices; the AVD network is built only for
            ``"train_avd"``.
        device: ``torch.device`` to move the networks to, or ``None`` to leave
            them where they were created.

    Returns:
        ``(inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network)``.
        ``bg_predictor`` is ``None`` unless ``common_params['bg']`` is truthy.
        ``avd_network`` is ``None`` unless *mode* is ``"train_avd"``.
    """
    model_params = config['model_params']
    common_params = model_params['common_params']

    inpainting = InpaintingNetwork(**model_params['generator_params'], **common_params)
    kp_detector = KPDetector(**common_params)
    dense_motion_network = DenseMotionNetwork(**common_params, **model_params['dense_motion_params'])
    bg_predictor = BGMotionPredictor() if common_params['bg'] else None
    avd_network = None
    if mode == "train_avd":
        avd_network = AVDNetwork(num_tps=common_params['num_tps'],
                                 **model_params['avd_network_params'])

    if device is not None:
        for network in (inpainting, kp_detector, dense_motion_network, bg_predictor, avd_network):
            if network is not None:
                network.to(device)

    return inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network


def main(argv=None):
    """Parse options, build the networks and dataset, then run the selected mode."""
    opt = parse_args(argv)
    config = load_config(opt.config)
    log_dir = resolve_log_dir(opt.config, opt.checkpoint, opt.log_dir)

    # Use the first listed CUDA device when available; otherwise leave the
    # networks where they were constructed.
    device = torch.device('cuda:' + str(opt.device_ids[0])) if torch.cuda.is_available() else None
    inpainting, kp_detector, bg_predictor, dense_motion_network, avd_network = \
        build_networks(config, opt.mode, device)

    dataset = FramesDataset(is_train=opt.mode.startswith('train'), **config['dataset_params'])
    prepare_log_dir(log_dir, opt.config)

    if opt.mode == 'train':
        print("Training...")
        train(config, inpainting, kp_detector, bg_predictor, dense_motion_network,
              opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'train_avd':
        print("Training Animation via Disentanglement...")
        train_avd(config, inpainting, kp_detector, bg_predictor, dense_motion_network,
                  avd_network, opt.checkpoint, log_dir, dataset)
    elif opt.mode == 'reconstruction':
        print("Reconstruction...")
        reconstruction(config, inpainting, kp_detector, bg_predictor, dense_motion_network,
                       opt.checkpoint, log_dir, dataset)


if __name__ == "__main__":
    if sys.version_info[0] < 3:
        raise Exception("You must use Python 3 or higher. Recommended version is Python 3.9")
    main()