File size: 1,351 Bytes
6f3bdf9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import optuna
from typing import Any, Dict
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, single_observation_space
def sample_env_hyperparams(
    trial: optuna.Trial, env_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Sample environment-level hyperparameters for one Optuna trial.

    Draws the number of parallel environments and the observation/reward
    normalization settings from ``trial`` and writes them into
    ``env_hparams``.

    Args:
        trial: The Optuna trial used to sample values. It is also used to
            record ``n_envs`` as a user attribute.
        env_hparams: Environment hyperparameter dict. It is modified in place.
        env: Vectorized environment. Only its single-environment observation
            space is inspected.

    Returns:
        The same ``env_hparams`` object, updated with:

        - ``"n_envs"``: a power of two, ``2 ** k`` for ``k`` sampled from
          [1, 5].
        - ``"normalize"``: a sampled bool.
        - ``"normalize_kwargs"``: ``norm_obs``/``norm_reward`` flags, present
          only when ``normalize`` is True. Any pre-existing entry is removed
          when ``normalize`` is False.
    """
    observation_space = single_observation_space(env)

    n_envs_exponent = trial.suggest_int("n_envs_exp", 1, 5)
    num_envs = 2**n_envs_exponent
    # Store the derived count (not just the exponent) on the trial for reporting.
    trial.set_user_attr("n_envs", num_envs)
    env_hparams["n_envs"] = num_envs

    use_normalize = trial.suggest_categorical("normalize", [False, True])
    env_hparams["normalize"] = use_normalize

    if not use_normalize:
        # Normalization is off for this trial: discard any leftover
        # normalize_kwargs so they do not linger in the config.
        env_hparams.pop("normalize_kwargs", None)
        return env_hparams

    # Start from any caller-provided kwargs so unrelated keys are preserved.
    norm_kwargs = env_hparams.get("normalize_kwargs", {})
    if len(observation_space.shape) == 3:
        # 3-dimensional observations (presumably images — confirm against
        # callers): observation normalization is fixed off, reward
        # normalization fixed on, and neither flag is sampled.
        obs_flag, reward_flag = False, True
    else:
        obs_flag = trial.suggest_categorical("norm_obs", [True, False])
        reward_flag = trial.suggest_categorical("norm_reward", [True, False])
    norm_kwargs["norm_obs"] = obs_flag
    norm_kwargs["norm_reward"] = reward_flag
    env_hparams["normalize_kwargs"] = norm_kwargs
    return env_hparams
|