DQN playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
8068833
import optuna | |
from gym.spaces import Box | |
from typing import Any, Dict | |
from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
VecEnv, | |
single_action_space, | |
) | |
def sample_on_policy_hyperparams(
    trial: optuna.Trial, policy_hparams: Dict[str, Any], env: VecEnv
) -> Dict[str, Any]:
    """Sample on-policy policy hyperparameters from an Optuna trial.

    Always samples ``init_layers_orthogonal`` (True/False) and
    ``activation_fn`` ("tanh"/"relu"). When the env's single action space is a
    ``gym.spaces.Box``, also samples ``log_std_init`` (float in [-5, 0.5]) and
    ``use_sde``. If ``use_sde`` ends up truthy, ``squash_output`` is sampled;
    otherwise any existing ``squash_output`` entry is removed.

    Args:
        trial: Optuna trial used to suggest every value.
        policy_hparams: Policy hyperparameter dict. It is updated in place.
        env: Vectorized env whose single action space decides which
            hyperparameters are sampled.

    Returns:
        The same ``policy_hparams`` object, after mutation.
    """
    act_space = single_action_space(env)

    policy_hparams["init_layers_orthogonal"] = trial.suggest_categorical(
        "init_layers_orthogonal", [True, False]
    )
    policy_hparams["activation_fn"] = trial.suggest_categorical(
        "activation_fn", ["tanh", "relu"]
    )

    # log_std_init / use_sde are only sampled for Box (continuous) action spaces.
    if isinstance(act_space, Box):
        policy_hparams["log_std_init"] = trial.suggest_float("log_std_init", -5, 0.5)
        policy_hparams["use_sde"] = trial.suggest_categorical("use_sde", [False, True])

    # .get: for non-Box spaces use_sde is not sampled above, so the key may be
    # absent (or may have been provided by the caller's dict).
    if policy_hparams.get("use_sde", False):
        policy_hparams["squash_output"] = trial.suggest_categorical(
            "squash_output", [False, True]
        )
    else:
        # squash_output is only kept alongside use_sde; drop any preexisting
        # value. pop(..., None) is a no-op when the key is missing.
        policy_hparams.pop("squash_output", None)

    return policy_hparams