"""Command-line entry point for training, testing and converting with AUGAN."""
import argparse
import tensorflow as tf
import os
from utils import *
from AUGAN import AUGAN
from ops import *
import time


def _str2bool(value):
    """Parse a command-line boolean option.

    argparse's ``type=bool`` is a trap: ``bool("False")`` and ``bool("0")``
    are both ``True``, so a flag declared that way can never be turned off.
    This accepts yes/no, true/false, t/f, y/n and 1/0 (case-insensitive).

    Args:
        value: the raw string from the command line (or an existing bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "Boolean value expected, got {!r}".format(value)
    )


parser = argparse.ArgumentParser(description="")
parser.add_argument(
    "--dataset_dir", dest="dataset_dir", default="bdd100k", help="path of the dataset"
)
parser.add_argument(
    "--experiment_name",
    dest="experiment_name",
    type=str,
    default="bdd_exp",
    help="name of experiment",
)
parser.add_argument("--epoch", dest="epoch", type=int, default=20, help="# of epoch")
parser.add_argument(
    "--epoch_step",
    dest="epoch_step",
    type=int,
    default=10,
    help="# of epoch to decay lr",
)
parser.add_argument(
    "--batch_size", dest="batch_size", type=int, default=1, help="# images in batch"
)
parser.add_argument(
    "--train_size",
    dest="train_size",
    type=int,
    # argparse only applies `type` to string defaults, so the default must
    # already be an int (the old `1e8` left args.train_size as a float).
    default=int(1e8),
    help="# images used to train",
)
parser.add_argument(
    "--load_size",
    dest="load_size",
    type=int,
    default=286,
    help="scale images to this size",
)
parser.add_argument(
    "--fine_size",
    dest="fine_size",
    type=int,
    default=256,
    help="then crop to this size",
)
parser.add_argument(
    "--ngf",
    dest="ngf",
    type=int,
    default=64,
    help="# of gen filters in first conv layer",
)
parser.add_argument(
    "--ndf",
    dest="ndf",
    type=int,
    default=64,
    help="# of discri filters in first conv layer",
)
parser.add_argument(
    "--n_d", dest="n_d", type=int, default=2, help="# of discriminators"
)
parser.add_argument(
    "--n_scale", dest="n_scale", type=int, default=2, help="# of scales"
)
parser.add_argument(
    "--gpu", dest="gpu", type=int, default=0, help="index of gpu device"
)
parser.add_argument(
    "--input_nc", dest="input_nc", type=int, default=3, help="# of input image channels"
)
parser.add_argument(
    "--output_nc",
    dest="output_nc",
    type=int,
    default=3,
    help="# of output image channels",
)
parser.add_argument(
    "--lr", dest="lr", type=float, default=0.0002, help="initial learning rate for adam"
)
parser.add_argument(
    "--beta1", dest="beta1", type=float, default=0.5, help="momentum term of adam"
)
parser.add_argument(
    "--which_direction", dest="which_direction", default="AtoB", help="AtoB or BtoA "
)
parser.add_argument(
    "--phase", dest="phase", default="test", help="train, test or convert"
)
parser.add_argument(
    "--save_freq",
    dest="save_freq",
    type=int,
    default=1000,
    help="save a model every save_freq iterations",
)
parser.add_argument(
    "--print_freq",
    dest="print_freq",
    type=int,
    default=100,
    help="print the debug information every print_freq iterations",
)
parser.add_argument(
    "--L1_lambda",
    dest="L1_lambda",
    type=float,
    default=10.0,
    help="weight on L1 term in objective",
)
parser.add_argument(
    "--conf_lambda",
    dest="conf_lambda",
    type=float,
    default=1.0,
    help="weight on confidence term in objective",
)
parser.add_argument(
    "--use_resnet",
    dest="use_resnet",
    type=_str2bool,
    default=True,
    help="generation network using residual block (1: true, 0: false)",
)
parser.add_argument(
    "--use_lsgan",
    dest="use_lsgan",
    type=_str2bool,
    default=True,
    help="gan loss defined in lsgan (1: true, 0: false)",
)
parser.add_argument(
    "--use_uncertainty",
    dest="use_uncertainty",
    type=_str2bool,
    default=True,
    help="whether the model uses uncertainty (1: true, 0: false)",
)
parser.add_argument(
    "--max_size",
    dest="max_size",
    type=int,
    default=50,
    help="max size of image pool, 0 means do not use image pool",
)
parser.add_argument(
    "--continue_train",
    dest="continue_train",
    type=_str2bool,
    default=False,
    help="if continue training, load the latest model: 1: true, 0: false",
)
parser.add_argument(
    "--save_conf",
    dest="save_conf",
    type=_str2bool,
    default=False,
    help="save conf map in test phase (1: true, 0: false)",
)
# Inputs for the `convert` phase; the defaults are the values that used to be
# hard-coded in main(), so existing invocations behave the same.
parser.add_argument(
    "--convert_input",
    dest="convert_input",
    type=str,
    default="inf_data/b1ca2e5d-84cf9134.jpg",
    help="image path passed to convert_image in the convert phase",
)
parser.add_argument(
    "--convert_output",
    dest="convert_output",
    type=str,
    default="out",
    help="output location passed to convert_image in the convert phase",
)

args = parser.parse_args()
# Restrict TensorFlow to the selected GPU. This must be set before the first
# Session is created in main().
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

_PHASES = ("train", "test", "convert")


def main(_):
    """Build AUGAN inside a TF1-style session and run the requested phase.

    Args:
        _: argv forwarded by ``tf.compat.v1.app.run``; unused, since all
            options come from the module-level ``args``.

    Raises:
        ValueError: if ``args.phase`` is not one of ``train``, ``test`` or
            ``convert``. This is checked before any session or model is
            built, so a typo fails fast.
    """
    if args.phase not in _PHASES:
        raise ValueError(
            "Unknown phase {!r}; expected one of: {}".format(
                args.phase, ", ".join(_PHASES)
            )
        )

    # `set_path` comes from the `utils` star import; presumably it prepares
    # per-experiment paths on `args` (TODO confirm against utils).
    set_path(args, args.experiment_name)

    tfconfig = tf.compat.v1.ConfigProto(allow_soft_placement=True)
    # Allocate GPU memory on demand instead of reserving it all up front.
    tfconfig.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=tfconfig) as sess:
        model = AUGAN(sess, args)
        if args.phase == "train":
            model.train(args)
        elif args.phase == "test":
            model.test(args)
        else:  # "convert" -- guaranteed by the validation above
            model.convert_image(args, args.convert_input, args.convert_output)


if __name__ == "__main__":
    tf.compat.v1.app.run()