sgoodfriend's picture
PPO playing HalfCheetahBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
f050c92
raw
history blame
2.95 kB
from copy import deepcopy
import optuna
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
from rl_algo_impls.shared.vec_env import make_eval_env
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
def sample_params(
    trial: optuna.Trial,
    base_hyperparams: Hyperparams,
    base_config: Config,
) -> Hyperparams:
    """Sample one Optuna trial's hyperparameters on top of a base configuration.

    ``base_hyperparams`` is deep-copied, so the caller's object is never
    modified. A single-env evaluation environment is built only for the
    env/policy sampling helpers. It is closed as soon as they return, even
    if one of them raises.

    Args:
        trial: Optuna trial used for every ``suggest_*`` call. Derived values
            (``n_steps``, ``gamma``, ``gae_lambda``, and ``sde_sample_freq``
            when SDE is enabled) are also recorded with
            ``trial.set_user_attr``.
        base_hyperparams: Hyperparameters to start from.
        base_config: Run config passed to ``make_eval_env``.

    Returns:
        The deep copy of ``base_hyperparams`` with sampled values written into
        ``algo_hyperparams``, plus whatever ``sample_env_hyperparams`` and
        ``sample_on_policy_hyperparams`` wrote into their sections.
    """
    hyperparams = deepcopy(base_hyperparams)

    base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
    env = make_eval_env(
        base_config,
        base_env_hyperparams,
        override_hparams={"n_envs": 1},
    )
    try:
        # env_hyperparams: the return value was never used. The helper is
        # assumed to update hyperparams.env_hyperparams in place
        # (TODO confirm).
        sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
        # policy_hyperparams: only read below, for the "use_sde" flag.
        policy_hyperparams = sample_on_policy_hyperparams(
            trial, hyperparams.policy_hyperparams, env
        )
    finally:
        # The env is only needed by the two helpers above. Close it even if
        # sampling raises, so a failed trial does not leak the environment.
        env.close()

    # algo_hyperparams: mutated in place on the deep copy.
    algo_hyperparams = hyperparams.algo_hyperparams

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 2e-3, log=True)
    learning_rate_decay = trial.suggest_categorical(
        "learning_rate_decay", ["none", "linear"]
    )
    # n_steps is sampled as a power of two in [2**1, 2**10]. The exponent is
    # also reused below to bound sde_sample_freq.
    n_steps_exp = trial.suggest_int("n_steps_exp", 1, 10)
    n_steps = 2**n_steps_exp
    trial.set_user_attr("n_steps", n_steps)
    # gamma is sampled as "one minus gamma" on a log scale, giving
    # gamma in [0.9, 0.9999] with finer resolution near 1.
    gamma = 1.0 - trial.suggest_float("gamma_om", 1e-4, 1e-1, log=True)
    trial.set_user_attr("gamma", gamma)
    # NOTE(review): unlike gamma_om, gae_lambda_om is sampled on a linear
    # scale. Switching it to log=True would change the parameter's
    # distribution, which Optuna rejects for studies that already recorded
    # it, so it is left as-is.
    gae_lambda = 1.0 - trial.suggest_float("gae_lambda_om", 1e-4, 1e-1)
    trial.set_user_attr("gae_lambda", gae_lambda)
    ent_coef = trial.suggest_float("ent_coef", 1e-8, 2.5e-2, log=True)
    ent_coef_decay = trial.suggest_categorical("ent_coef_decay", ["none", "linear"])
    vf_coef = trial.suggest_float("vf_coef", 0.1, 0.7)
    max_grad_norm = trial.suggest_float("max_grad_norm", 1e-1, 1e1, log=True)
    use_rms_prop = trial.suggest_categorical("use_rms_prop", [True, False])
    normalize_advantage = trial.suggest_categorical(
        "normalize_advantage", [True, False]
    )

    algo_hyperparams.update(
        {
            "learning_rate": learning_rate,
            "learning_rate_decay": learning_rate_decay,
            "n_steps": n_steps,
            "gamma": gamma,
            "gae_lambda": gae_lambda,
            "ent_coef": ent_coef,
            "ent_coef_decay": ent_coef_decay,
            "vf_coef": vf_coef,
            "max_grad_norm": max_grad_norm,
            "use_rms_prop": use_rms_prop,
            "normalize_advantage": normalize_advantage,
        }
    )

    if policy_hyperparams.get("use_sde", False):
        # sde_sample_freq is a power of two no larger than n_steps.
        sde_sample_freq = 2 ** trial.suggest_int("sde_sample_freq_exp", 0, n_steps_exp)
        trial.set_user_attr("sde_sample_freq", sde_sample_freq)
        algo_hyperparams["sde_sample_freq"] = sde_sample_freq
    else:
        # SDE is off: drop any sde_sample_freq inherited from the base config.
        algo_hyperparams.pop("sde_sample_freq", None)

    return hyperparams