# augan/main.py
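"""Command-line entry point for AUGAN.

Parses the options below, builds the AUGAN model in a TensorFlow v1 session,
and runs the requested phase: train, test, or convert.
"""
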
import argparse
import tensorflow as tf
import os
from utils import *
from AUGAN import AUGAN
from ops import *
import time
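

def str2bool(value):
    """Parse boolean command-line flags.

    argparse's ``type=bool`` would treat any non-empty string (including "0"
    and "False") as True, so the boolean options below use this helper instead.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in ("1", "true", "yes", "y"):
        return True
    if value.lower() in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)

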
parser = argparse.ArgumentParser(description="AUGAN training and testing options")
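# Dataset and experiment naming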
parser.add_argument(
"--dataset_dir", dest="dataset_dir", default="bdd100k", help="path of the dataset"
)
parser.add_argument(
"--experiment_name",
dest="experiment_name",
type=str,
default="bdd_exp",
help="name of experiment",
)
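# Training schedule and data volume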
parser.add_argument("--epoch", dest="epoch", type=int, default=20, help="# of epoch")
parser.add_argument(
"--epoch_step",
dest="epoch_step",
type=int,
default=10,
help="# of epoch to decay lr",
)
parser.add_argument(
"--batch_size", dest="batch_size", type=int, default=1, help="# images in batch"
)
parser.add_argument(
"--train_size",
dest="train_size",
type=int,
    default=int(1e8),
help="# images used to train",
)
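# Image preprocessing sizes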
parser.add_argument(
"--load_size",
dest="load_size",
type=int,
default=286,
help="scale images to this size",
)
parser.add_argument(
"--fine_size",
dest="fine_size",
type=int,
default=256,
help="then crop to this size",
)
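# Generator and discriminator architecture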
parser.add_argument(
"--ngf",
dest="ngf",
type=int,
default=64,
help="# of gen filters in first conv layer",
)
parser.add_argument(
"--ndf",
dest="ndf",
type=int,
default=64,
help="# of discri filters in first conv layer",
)
parser.add_argument(
"--n_d", dest="n_d", type=int, default=2, help="# of discriminators"
)
parser.add_argument(
"--n_scale", dest="n_scale", type=int, default=2, help="# of scales"
)
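# Device and image channels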
parser.add_argument(
"--gpu", dest="gpu", type=int, default=0, help="# index of gpu device"
)
parser.add_argument(
"--input_nc", dest="input_nc", type=int, default=3, help="# of input image channels"
)
parser.add_argument(
"--output_nc",
dest="output_nc",
type=int,
default=3,
help="# of output image channels",
)
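# Adam optimizer settings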
parser.add_argument(
"--lr", dest="lr", type=float, default=0.0002, help="initial learning rate for adam"
)
parser.add_argument(
"--beta1", dest="beta1", type=float, default=0.5, help="momentum term of adam"
)
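# Translation direction, run phase, and logging frequency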
parser.add_argument(
"--which_direction", dest="which_direction", default="AtoB", help="AtoB or BtoA "
)
parser.add_argument("--phase", dest="phase", default="test", help="train, test")
parser.add_argument(
"--save_freq",
dest="save_freq",
type=int,
default=1000,
help="save a model every save_freq iterations",
)
parser.add_argument(
"--print_freq",
dest="print_freq",
type=int,
default=100,
help="print the debug information every print_freq iterations",
)
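# Loss weights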
parser.add_argument(
"--L1_lambda",
dest="L1_lambda",
type=float,
default=10.0,
help="weight on L1 term in objective",
)
parser.add_argument(
"--conf_lambda",
dest="conf_lambda",
type=float,
default=1.0,
help="weight on L1 term in objective",
)
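# Model and loss variants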
parser.add_argument(
    "--use_resnet",
    dest="use_resnet",
    type=str2bool,
    default=True,
    help="use residual blocks in the generator network",
)
parser.add_argument(
    "--use_lsgan",
    dest="use_lsgan",
    type=str2bool,
    default=True,
    help="use the least-squares (LSGAN) adversarial loss",
)
parser.add_argument(
    "--use_uncertainty",
    dest="use_uncertainty",
    type=str2bool,
    default=True,
    help="enable the uncertainty (confidence) estimation branch",
)
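# Image pool, checkpoint resuming, and test outputs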
parser.add_argument(
"--max_size",
dest="max_size",
type=int,
default=50,
help="max size of image pool, 0 means do not use image pool",
)
parser.add_argument(
    "--continue_train",
    dest="continue_train",
    type=str2bool,
    default=False,
    help="continue training from the latest checkpoint (1: true, 0: false)",
)
parser.add_argument(
    "--save_conf",
    dest="save_conf",
    type=str2bool,
    default=False,
    help="save confidence maps in the test phase",
)
args = parser.parse_args()
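# Expose only the requested GPU to TensorFlow (must be set before the CUDA context is created).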
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)


def main(_):
set_path(args, args.experiment_name)
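    # allow_soft_placement lets ops without a GPU kernel fall back to the CPU;
    # allow_growth makes TensorFlow claim GPU memory on demand instead of all at once.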
tfconfig = tf.compat.v1.ConfigProto(allow_soft_placement=True)
tfconfig.gpu_options.allow_growth = True
with tf.compat.v1.Session(config=tfconfig) as sess:
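        # Build the AUGAN model in this session and dispatch on the requested phase.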
model = AUGAN(sess, args)
# show_all_variables()
# model.train(args) if args.phase == 'train' \
# else model.test(args)
if args.phase == "train":
model.train(args)
elif args.phase == "test":
model.test(args)
elif args.phase == "convert":
model.convert_image(args, "inf_data/b1ca2e5d-84cf9134.jpg", "out")
        else:
            raise ValueError(
                "Unknown phase %r: expected 'train', 'test' or 'convert'" % args.phase
            )


if __name__ == "__main__":
tf.compat.v1.app.run()