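"""Training script for SUNRISE (ensemble SAC) on the online level-generation task.

Builds an ensemble of SAC learners (via the adapted rlkit NeurIPS20SACEnsembleTrainer),
an asynchronous online level-generation environment, and an asynchronous off-policy
training loop with a masked replay memory, then trains for the requested step budget.
"""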
import os
import time
import json
import importlib

import torch
import src.rlkit.torch.pytorch_util as ptu
from analysis.tests import *
from src.env.environments import AsyncOlGenEnv
from src.env.logger import AsyncCsvLogger, AsyncStdLogger, GenResLogger
from src.gan.gankits import *
from src.drl.sunrise.sunrise_adaption import AsyncOffPolicyALgo, MaskedRepMem
from src.olgen.ol_generator import OnlineGenerator
from src.olgen.olg_policy import EnsembleGenPolicy
from src.rlkit.torch.sac.policies import TanhGaussianPolicy, MakeDeterministic
from src.rlkit.torch.sac.neurips20_sac_ensemble import NeurIPS20SACEnsembleTrainer
from src.rlkit.torch.networks import FlattenMlp
from src.smb.asyncsimlt import AsycSimltPool
from src.utils.filesys import auto_dire, getpath
from src.utils.misc import record_time
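
# Default hyper-parameters for the rlkit ensemble SAC trainer.
# The commented-out entries are supplied from command-line arguments at run time instead.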
variant = dict(
algorithm="SAC",
version="normal",
layer_size=256,
replay_buffer_size=int(1e6),
# algorithm_kwargs=dict(
# num_epochs=210,
# num_eval_steps_per_epoch=1000,
# num_trains_per_train_loop=1000,
# num_expl_steps_per_train_loop=1000,
# min_num_steps_before_training=1000,
# max_path_length=1000,
# batch_size=args.batch_size,
# save_frequency=args.save_freq,
# ),
trainer_kwargs=dict(
# discount=0.99,
# soft_target_tau=5e-3,
target_update_period=1,
policy_lr=3E-4,
qf_lr=3E-4,
reward_scale=1,
use_automatic_entropy_tuning=True,
),
# num_ensemble=args.num_ensemble,
# num_layer=args.num_layer,
# seed=args.seed,
# ber_mean=args.ber_mean,
# env=args.env,
# inference_type=args.inference_type,
# temperature=args.temperature,
# log_dir="",
)
def set_SUNRISE_args(parser):
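    """Register SUNRISE-specific command-line arguments on the given parser."""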
parser.add_argument('--name', type=str, default='SUNRISE', help='Name of this algorithm')
    parser.add_argument('--n_workers', type=int, default=20, help='Maximum number of parallel processes in the environment.')
    parser.add_argument('--queuesize', type=int, default=25, help='Size of the waiting queue of the environment.')
parser.add_argument('--eplen', type=int, default=50, help='Episode length of the environment.')
parser.add_argument('--budget', type=int, default=int(1e6), help='Total time steps of training.')
parser.add_argument('--gamma', type=float, default=0.9, help='RL parameter, discount factor')
    # parser.add_argument('--tar_entropy', type=float, default=-nz, help='SAC parameter, target entropy')
    parser.add_argument('--tau', type=float, default=0.02, help='SAC parameter, target network smoothing coefficient')
    parser.add_argument('--update_per', type=int, default=2,
                        help='Number of transitions to collect between two consecutive updates (one batch per update)')
# parser.add_argument('--batch', type=int, default=256, help='Batch size for one update')
parser.add_argument('--mem_size', type=int, default=int(1e6), help='Size of replay memory')
parser.add_argument('--gpuid', type=int, default=0,
help='ID of GPU to train the policy. CPU will be used if gpuid < 0')
parser.add_argument('--rfunc', type=str, default='default', help='Name of the reward function in src/env/rfuncs.py')
    parser.add_argument('--path', type=str, default='',
                        help='Path relative to \'training_data/\' in which to save the training logs. If not specified, a new folder named {name}{id} will be created.')
    # parser.add_argument('--actor_hiddens', type=int, nargs='+', default=[256, 256], help='List of number of units in each hidden layer of the actor net')
    # parser.add_argument('--critic_hiddens', type=int, nargs='+', default=[256, 256], help='List of number of units in each hidden layer of the critic net')
parser.add_argument('--gen_period', type=int, default=20000, help='Period of saving level generation results')
parser.add_argument('--periodic_gen_num', type=int, default=200,
help='Number of levels to be generated for each evaluation')
    parser.add_argument('--redirect', action='store_true', help='If set, redirect the standard output log to log.txt')
# architecture
parser.add_argument('--num_layer', default=2, type=int)
# train
parser.add_argument('--batch_size', default=256, type=int)
# ensemble
parser.add_argument('--num_ensemble', default=5, type=int)
parser.add_argument('--ber_mean', default=0.5, type=float)
# inference
parser.add_argument('--inference_type', default=1.0, type=float)
# corrective feedback
parser.add_argument('--temperature', default=10.0, type=float)
def get_env(args, device):
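    """Build the asynchronous online level-generation environment, backed by a pool of
    simulation workers and the pretrained decoder, using the selected reward function."""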
evalpool = AsycSimltPool(args.n_workers, args.queuesize, args.rfunc, verbose=False)
rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
env = AsyncOlGenEnv(rfunc.get_n(), get_decoder('models/decoder.pth'), evalpool, args.eplen, device=device)
return env
def get_trainer(args, obs_dim, action_dim, path, device):
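    """Build the SUNRISE ensemble: one pair of Q-networks (with target copies) and one
    tanh-Gaussian policy per ensemble member, wrapped in NeurIPS20SACEnsembleTrainer."""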
M = variant['layer_size']
num_layer = args.num_layer
network_structure = [M] * num_layer
L_qf1, L_qf2, L_target_qf1, L_target_qf2, L_policy = [], [], [], [], []
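    # Instantiate an independent set of networks for each ensemble member.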
for _ in range(args.num_ensemble):
qf1 = FlattenMlp(
input_size=obs_dim + action_dim,
output_size=1,
hidden_sizes=network_structure,
)
qf2 = FlattenMlp(
input_size=obs_dim + action_dim,
output_size=1,
hidden_sizes=network_structure,
)
target_qf1 = FlattenMlp(
input_size=obs_dim + action_dim,
output_size=1,
hidden_sizes=network_structure,
)
target_qf2 = FlattenMlp(
input_size=obs_dim + action_dim,
output_size=1,
hidden_sizes=network_structure,
)
policy = TanhGaussianPolicy(
obs_dim=obs_dim,
action_dim=action_dim,
hidden_sizes=network_structure,
)
L_qf1.append(qf1)
L_qf2.append(qf2)
L_target_qf1.append(target_qf1)
L_target_qf2.append(target_qf2)
L_policy.append(policy)
trainer = NeurIPS20SACEnsembleTrainer(
# env=eval_env,
policy=L_policy,
qf1=L_qf1,
qf2=L_qf2,
target_qf1=L_target_qf1,
target_qf2=L_target_qf2,
num_ensemble=args.num_ensemble,
discount=args.gamma,
soft_target_tau=args.tau,
feedback_type=1,
temperature=args.temperature,
temperature_act=0,
expl_gamma=0,
log_dir=path,
device=device,
**variant['trainer_kwargs']
)
    return trainer
def get_algo(args, rfunc, device, path):
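    """Build the asynchronous off-policy training loop with a masked replay memory
    (per-member bootstrap masks drawn with mean args.ber_mean) and attach the loggers."""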
algorithm = AsyncOffPolicyALgo(
MaskedRepMem(args.num_ensemble, args.mem_size, args.ber_mean, device),
update_per=args.update_per,
batch=args.batch_size,
device=device
)
loggers = [
AsyncCsvLogger(f'{path}/log.csv', rfunc),
AsyncStdLogger(rfunc, 2000, f'{path}/log.txt' if args.redirect else '')
]
if args.periodic_gen_num > 0:
loggers.append(GenResLogger(path, args.periodic_gen_num, args.gen_period))
algorithm.set_loggers(*loggers)
return algorithm
def train_SUNRISE(args):
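    """Prepare the output directory, dump the run configuration, then build the environment,
    ensemble trainer, and algorithm, and train for args.budget time steps."""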
if not args.path:
path = auto_dire('training_data', args.name)
else:
path = getpath('training_data/' + args.path)
    try:
        os.makedirs(path)
    except FileExistsError:
        if os.path.exists(f'{path}/policy.pth'):
            print(f'Training is cancelled because `{path}` is already occupied.')
            return
    if os.path.exists(f'{path}/policy.pth'):
        print(f'Training at <{path}> is skipped since a finished trial already exists there.')
        return
device = 'cpu' if args.gpuid < 0 or not torch.cuda.is_available() else f'cuda:{args.gpuid}'
rfunc = importlib.import_module('src.env.rfuncs').__getattribute__(f'{args.rfunc}')()
with open(path + '/run_configuration.txt', 'w') as f:
f.write(time.strftime('%Y-%m-%d %H:%M') + '\n')
f.write(f'---------{args.name}---------\n')
args_strlines = [
f'{key}={val}\n' for key, val in vars(args).items()
if key not in {'name', 'rfunc', 'path', 'entry'}
]
f.writelines(args_strlines)
f.write('-' * 50 + '\n')
f.write(str(rfunc))
N = rfunc.get_n()
with open(f'{path}/cfgs.json', 'w') as f:
data = {'N': N, 'gamma': args.gamma, 'h': args.eplen, 'rfunc': args.rfunc, 'm': args.num_ensemble}
json.dump(data, f)
    # Keep rlkit's global device setting consistent with the device chosen above.
    ptu.set_gpu_mode(device != 'cpu')
env = get_env(args, device)
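    # Observations are the concatenation of the last `histlen` latent vectors; actions are one latent vector.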
obs_dim, action_dim = env.histlen * nz, nz
trainer = get_trainer(args, obs_dim, action_dim, path, device)
algorithm = get_algo(args, rfunc, device, path)
_, timecost = record_time(algorithm.train)(env, trainer, args.budget, args.inference_type, path)
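# Note: running this module directly does nothing; training is presumably launched through the
# project's entry script, which calls set_SUNRISE_args() and then train_SUNRISE().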
if __name__ == '__main__':
pass