File size: 1,927 Bytes
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3582c8a
 
eaf2e33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import argparse
from src.drl.train_async import *
from src.gan.adversarial_train import *
from src.drl.train_sinproc import set_SAC_parser, train_SAC
from src.drl.egsac.train_egsac import set_EGSAC_parser, train_EGSAC
from src.drl.sunrise.train_sunrise import train_SUNRISE, set_SUNRISE_args
from src.drl.dvd import set_DvDSAC_parser, train_DvDSAC


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()

    parser_gan = subparsers.add_parser('gan', help='Train GAN')
    set_GAN_parser(parser_gan)
    parser_gan.set_defaults(entry=train_GAN)

    parser_sac = subparsers.add_parser('sac', help='Train SAC')
    set_SAC_parser(parser_sac)
    parser_sac.set_defaults(entry=train_SAC)

    parser_asyncsac = subparsers.add_parser('asyncsac', help='Train Asynchronous SAC')
    set_AsyncSAC_parser(parser_asyncsac)
    parser_asyncsac.set_defaults(entry=train_AsyncSAC)

    parser_egsac = subparsers.add_parser('egsac', help='Train Episodic Generative SAC')
    set_EGSAC_parser(parser_egsac)
    parser_egsac.set_defaults(entry=train_EGSAC)

    parser_ncesac = subparsers.add_parser('ncesac', help='Train Negatively Correlated Ensemble SAC')
    set_NCESAC_parser(parser_ncesac)
    parser_ncesac.set_defaults(entry=train_NCESAC)

    parser_pmoesac = subparsers.add_parser('pmoe', help='Train PMOE')
    set_PMOESAC_parser(parser_pmoesac)
    parser_pmoesac.set_defaults(entry=train_PMOESAC)

    parser_sunrise = subparsers.add_parser('sunrise', help='Train SUNRISE')
    set_SUNRISE_args(parser_sunrise)
    parser_sunrise.set_defaults(entry=train_SUNRISE)

    parser_dvd = subparsers.add_parser('dvd', help='Train DvD')
    set_DvDSAC_parser(parser_dvd)
    parser_dvd.set_defaults(entry=train_DvDSAC)

    args = parser.parse_args()

    entry = args.entry
    
    # entry是每一个模型的训练入口,具体函数在各个subparser内定义
    entry(args)