File size: 4,865 Bytes
cc6c676 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import argparse
import os
import torch
from WiggleGAN import WiggleGAN
#from MyACGAN import MyACGAN
#from MyGAN import MyGAN
"""parsing and configuration"""
def parse_args(argv=None):
    """Build the command-line parser and return the checked arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` exactly as before;
            passing a list lets the configuration be built programmatically
            (e.g. from a notebook or a test) without touching ``sys.argv``.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    desc = "Pytorch implementation of GAN collections"
    parser = argparse.ArgumentParser(description=desc)

    # Model / dataset selection.
    parser.add_argument('--gan_type', type=str, default='WiggleGAN',
                        choices=['MyACGAN', 'MyGAN', 'WiggleGAN'],
                        help='The type of GAN')
    parser.add_argument('--dataset', type=str, default='4cam',
                        choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed', '4cam'],
                        help='The name of dataset')
    parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')

    # Training loop size.
    parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=16, help='The size of batch')
    parser.add_argument('--input_size', type=int, default=10, help='The size of input image')

    # Output directories (check_args creates them when missing).
    parser.add_argument('--save_dir', type=str, default='models',
                        help='Directory name to save the model')
    parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')

    # Optimizer settings. Names suggest per-network learning rates and
    # Adam-style betas; they are consumed inside WiggleGAN -- TODO confirm.
    parser.add_argument('--lrG', type=float, default=0.0002)
    parser.add_argument('--lrD', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)

    # Runtime switches. str2bool lets these accept yes/no, true/false,
    # t/f, y/n, 1/0 (case-insensitive). main() enables
    # torch.backends.cudnn.benchmark when --benchmark_mode is true.
    parser.add_argument('--gpu_mode', type=str2bool, default=True)
    parser.add_argument('--benchmark_mode', type=str2bool, default=True)

    # Model-specific options; their meaning is defined inside WiggleGAN
    # (not visible here). In main(), --recreate selects dataset recreation
    # and a non-negative --wiggleDepth selects the wiggle effect instead of
    # training.
    parser.add_argument('--cameras', type=int, default=2)
    parser.add_argument('--imageDim', type=int, default=128)
    parser.add_argument('--epochV', type=int, default=0)
    parser.add_argument('--cIm', type=int, default=4)
    parser.add_argument('--seedLoad', type=str, default="-0000")
    parser.add_argument('--zGF', type=float, default=0.2)
    parser.add_argument('--zDF', type=float, default=0.2)
    parser.add_argument('--bF', type=float, default=0.2)
    parser.add_argument('--expandGen', type=int, default=3)
    parser.add_argument('--expandDis', type=int, default=3)
    parser.add_argument('--wiggleDepth', type=int, default=-1)
    parser.add_argument('--visdom', type=str2bool, default=True)
    parser.add_argument('--lambdaL1', type=int, default=100)
    parser.add_argument('--clipping', type=float, default=-1)
    parser.add_argument('--depth', type=str2bool, default=True)
    parser.add_argument('--recreate', type=str2bool, default=False)
    parser.add_argument('--name_wiggle', type=str, default='wiggle-result')

    return check_args(parser.parse_args(argv))


"""checking arguments"""
def str2bool(v):
    """Interpret a command-line token as a boolean (for argparse ``type=``).

    A value that is already a ``bool`` is returned unchanged. Otherwise the
    value is lower-cased and matched against the accepted spellings:
    yes/true/t/y/1 -> True, no/false/f/n/0 -> False.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    spellings = {
        'yes': True, 'true': True, 't': True, 'y': True, '1': True,
        'no': False, 'false': False, 'f': False, 'n': False, '0': False,
    }
    verdict = spellings.get(v.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return verdict
def check_args(args):
    """Validate parsed arguments and create the output directories.

    Args:
        args: Namespace from ``parse_args``; must provide ``epoch``,
            ``batch_size``, ``save_dir``, ``result_dir`` and ``log_dir``.

    Returns:
        ``args`` unchanged when every check passes. When a check fails, each
        problem is printed and ``None`` is returned (``main`` treats ``None``
        as "exit"); no directories are created in that case.
    """
    # Explicit checks instead of ``assert`` inside a bare ``except``: asserts
    # are stripped under ``python -O`` and the bare except also swallowed
    # unrelated errors (e.g. a missing attribute). Invalid values are now
    # actually rejected instead of only being reported.
    valid = True
    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        valid = False
    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        valid = False
    if not valid:
        return None

    # --save_dir, --result_dir, --log_dir
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    for directory in (args.save_dir, args.result_dir, args.log_dir):
        os.makedirs(directory, exist_ok=True)
    return args


"""main"""
def main():
    """Parse CLI arguments, build the selected GAN and run one mode.

    Modes are mutually exclusive and checked in this order:
      * ``--recreate`` true        -> ``gan.recreate()``
      * ``--wiggleDepth`` >= 0     -> ``gan.wiggleEf()``
      * otherwise                  -> ``gan.train()``

    Raises:
        SystemExit: with status 1 when ``parse_args`` returns ``None``.
        ValueError: when ``--gan_type`` names a model that is not wired up.
    """
    # parse arguments
    args = parse_args()
    if args is None:
        # The problem has already been reported; exit with a failure status.
        # SystemExit is raised directly because the ``exit`` builtin is
        # injected by the ``site`` module and is not guaranteed to exist
        # (e.g. under ``python -S`` or in frozen executables).
        raise SystemExit(1)

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN
    if args.gan_type == 'WiggleGAN':
        gan = WiggleGAN(args)
    # 'MyACGAN' and 'MyGAN' pass argparse's choices, but their imports are
    # commented out, so they currently fall through to the error below.
    #elif args.gan_type == 'MyACGAN':
    #    gan = MyACGAN(args)
    #elif args.gan_type == 'MyGAN':
    #    gan = MyGAN(args)
    else:
        # ValueError subclasses Exception, so existing handlers still match.
        raise ValueError("[!] There is no option for " + args.gan_type)

    # Run exactly one mode; --recreate takes precedence over --wiggleDepth.
    if args.recreate:
        print(" [*] Dataset recreation Started")
        gan.recreate()
        print(" [*] Dataset recreation finished")
    elif args.wiggleDepth >= 0:
        print(" [*] Wiggle Started!")
        gan.wiggleEf()
        print(" [*] Wiggle finished!")
    else:
        print(" [*] Training Starting!")
        gan.train()
        print(" [*] Training finished!")


if __name__ == '__main__':
    main()
|