File size: 4,865 Bytes
cc6c676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import argparse
import os
import torch
from WiggleGAN import WiggleGAN
#from MyACGAN import MyACGAN
#from MyGAN import MyGAN

"""parsing and configuration"""


def parse_args():
    """Define the command-line options, parse ``sys.argv`` and validate them.

    Every option is declared in the ``option_specs`` table below and added to
    the parser in table order (which is also the ``--help`` order).

    Returns:
        Whatever ``check_args`` returns for the parsed ``argparse.Namespace``.
    """
    # (flag, keyword arguments for ArgumentParser.add_argument)
    option_specs = (
        ('--gan_type', dict(type=str, default='WiggleGAN',
                            choices=['MyACGAN', 'MyGAN', 'WiggleGAN'],
                            help='The type of GAN')),
        ('--dataset', dict(type=str, default='4cam',
                           choices=['mnist', 'fashion-mnist', 'cifar10',
                                    'cifar100', 'svhn', 'stl10', 'lsun-bed',
                                    '4cam'],
                           help='The name of dataset')),
        ('--split', dict(type=str, default='',
                         help='The split flag for svhn and stl10')),
        ('--epoch', dict(type=int, default=50,
                         help='The number of epochs to run')),
        ('--batch_size', dict(type=int, default=16,
                              help='The size of batch')),
        ('--input_size', dict(type=int, default=10,
                              help='The size of input image')),
        ('--save_dir', dict(type=str, default='models',
                            help='Directory name to save the model')),
        ('--result_dir', dict(type=str, default='results',
                              help='Directory name to save the generated images')),
        ('--log_dir', dict(type=str, default='logs',
                           help='Directory name to save training logs')),
        # The remaining options carry no help text.
        ('--lrG', dict(type=float, default=0.0002)),
        ('--lrD', dict(type=float, default=0.001)),
        ('--beta1', dict(type=float, default=0.5)),
        ('--beta2', dict(type=float, default=0.999)),
        ('--gpu_mode', dict(type=str2bool, default=True)),
        ('--benchmark_mode', dict(type=str2bool, default=True)),
        ('--cameras', dict(type=int, default=2)),
        ('--imageDim', dict(type=int, default=128)),
        ('--epochV', dict(type=int, default=0)),
        ('--cIm', dict(type=int, default=4)),
        ('--seedLoad', dict(type=str, default="-0000")),
        ('--zGF', dict(type=float, default=0.2)),
        ('--zDF', dict(type=float, default=0.2)),
        ('--bF', dict(type=float, default=0.2)),
        ('--expandGen', dict(type=int, default=3)),
        ('--expandDis', dict(type=int, default=3)),
        ('--wiggleDepth', dict(type=int, default=-1)),
        ('--visdom', dict(type=str2bool, default=True)),
        ('--lambdaL1', dict(type=int, default=100)),
        ('--clipping', dict(type=float, default=-1)),
        ('--depth', dict(type=str2bool, default=True)),
        ('--recreate', dict(type=str2bool, default=False)),
        ('--name_wiggle', dict(type=str, default='wiggle-result')),
    )

    parser = argparse.ArgumentParser(
        description="Pytorch implementation of GAN collections")
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    namespace = parser.parse_args()
    return check_args(namespace)


"""checking arguments"""

def str2bool(v):
    """Convert a command-line token to a bool; intended as an argparse ``type``.

    Accepted spellings (case-insensitive, surrounding whitespace ignored):
      * True:  ``yes``, ``true``, ``t``, ``y``, ``1``
      * False: ``no``, ``false``, ``f``, ``n``, ``0``
    A value that is already a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value, including non-string
            inputs (which previously escaped as ``AttributeError``), so
            argparse can report a clean usage error.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')
    token = v.strip().lower()  # normalise once instead of per comparison
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def check_args(args):
    """Validate parsed arguments, then create the output directories.

    Args:
        args: namespace from ``parser.parse_args()``; this function reads
            ``epoch``, ``batch_size``, ``save_dir``, ``result_dir`` and
            ``log_dir``.

    Returns:
        ``args`` unchanged when valid. Otherwise prints every problem found
        and returns ``None`` (the caller ``main`` stops when it gets ``None``).
        No directories are created for invalid arguments.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under
    # ``python -O``, and the previous bare ``except`` only printed a message
    # and then let the invalid value through.
    problems = []
    # --epoch
    if args.epoch < 1:
        problems.append('number of epochs must be larger than or equal to one')
    # --batch_size
    if args.batch_size < 1:
        problems.append('batch size must be larger than or equal to one')
    if problems:
        for message in problems:
            print(message)
        return None

    # --save_dir, --result_dir, --log_dir
    # exist_ok=True avoids the exists()/makedirs() race of check-then-create.
    for directory in (args.save_dir, args.result_dir, args.log_dir):
        os.makedirs(directory, exist_ok=True)

    return args


"""main"""


def main():
    """Parse arguments, build the requested GAN and run the selected mode.

    Mode selection (first match wins):
      * ``--recreate`` true       -> ``gan.recreate()``
      * ``--wiggleDepth`` >= 0    -> ``gan.wiggleEf()``
      * otherwise                 -> ``gan.train()``

    Exits with status 1 if ``parse_args`` returns ``None``. Raises
    ``Exception`` for a ``--gan_type`` that is not wired in.
    """
    # sys.exit is always available; the bare ``exit`` builtin is injected by
    # ``site`` and is missing under ``python -S`` or in frozen executables.
    import sys

    # parse arguments
    args = parse_args()
    if args is None:
        sys.exit(1)  # non-zero status: the arguments were rejected

    if args.benchmark_mode:
        torch.backends.cudnn.benchmark = True

    # declare instance for GAN
    if args.gan_type == 'WiggleGAN':
        gan = WiggleGAN(args)
    #elif args.gan_type == 'MyACGAN':
    #    gan = MyACGAN(args)
    #elif args.gan_type == 'MyGAN':
    #    gan = MyGAN(args)
    else:
        raise Exception("[!] There is no option for " + args.gan_type)

    # launch the selected mode; same precedence as the original nested
    # branches: recreate > wiggle (wiggleDepth >= 0) > train
    if args.recreate:
        print(" [*] Dataset recreation Started")
        gan.recreate()
        print(" [*] Dataset recreation finished")
    elif args.wiggleDepth >= 0:
        print(" [*] Wiggle Started!")
        gan.wiggleEf()
        print(" [*] Wiggle finished!")
    else:
        print(" [*] Training Starting!")
        gan.train()
        print(" [*] Training finished!")


if __name__ == "__main__":
    # main() runs only when this file is executed directly, not when imported.
    main()