DQN playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
923ccaf
from dataclasses import astuple | |
from typing import Optional | |
import gym | |
import numpy as np | |
from torch.utils.tensorboard.writer import SummaryWriter | |
from rl_algo_impls.runner.config import Config, EnvHyperparams | |
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter | |
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation | |
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv | |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv | |
def make_procgen_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a wrapped, vectorized Procgen environment.

    Creates a ``ProcgenGym3Env`` for ``config.env_id``, extracts the ``"rgb"``
    observation, converts it to a baselines-style vec env and stacks the
    project/gym wrappers on top of it.

    Args:
        config: Supplies ``env_id`` and the seed (``config.seed(training=...)``).
        hparams: Environment hyperparameters. Only ``n_envs``, ``make_kwargs``,
            ``normalize``, ``normalize_kwargs`` and ``rolling_length`` are
            used here; the remaining fields are ignored.
        training: When true, episode statistics are written to ``tb_writer``
            and reward normalization may be applied.
        render: When true, the gym3 env is wrapped in ``ViewerWrapper``.
        normalize_load_path: Accepted but not used by this function.
        tb_writer: TensorBoard writer; must be provided when ``training`` is
            true.

    Returns:
        The fully wrapped environment.

    Raises:
        AssertionError: If ``training`` is true and ``tb_writer`` is falsy.
    """
    # Function-scope imports: gym3/procgen are only loaded when this is called.
    from gym3 import ExtractDictObWrapper, ViewerWrapper
    from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv

    # Positional unpack of every EnvHyperparams field. astuple() hands back
    # copies of the field values, so editing make_kwargs below works on a copy.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        normalize,
        normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        _,  # bots
        _,  # self_play_kwargs
        _,  # selfplay_bots
    ) = astuple(hparams)

    seed = config.seed(training=training)
    has_seed = seed is not None

    # Keyword arguments forwarded to ProcgenGym3Env.
    env_kwargs = make_kwargs if make_kwargs else {}
    env_kwargs["render_mode"] = "rgb_array"
    if has_seed:
        env_kwargs["rand_seed"] = seed

    # gym3-level environment: construct, pull out the "rgb" entry, optionally
    # attach the viewer.
    gym3_env = ProcgenGym3Env(n_envs, config.env_id, **env_kwargs)
    gym3_env = ExtractDictObWrapper(gym3_env, key="rgb")
    if render:
        gym3_env = ViewerWrapper(gym3_env, info_key="rgb")

    # Vec-env level wrappers.
    envs = IsVectorEnv(ToBaselinesVecEnv(gym3_env))
    # TODO: Handle Grayscale and/or FrameStack
    envs = HwcToChwObservation(envs)
    envs = gym.wrappers.RecordEpisodeStatistics(envs)

    if has_seed:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    # Stats writing and reward normalization only apply while training.
    if not training:
        return envs  # type: ignore

    # NOTE: a plain assert, so this check is skipped when running with -O.
    assert tb_writer
    envs = EpisodeStatsWriter(
        envs, tb_writer, training=training, rolling_length=rolling_length
    )

    if not normalize:
        return envs  # type: ignore

    reward_opts = normalize_kwargs if normalize_kwargs else {}
    envs = gym.wrappers.NormalizeReward(envs)
    # Symmetric clip applied to the normalized reward; defaults to 10.0.
    reward_limit = reward_opts.get("clip_reward", 10.0)

    def _clip_reward(r):
        """Clamp a reward (or array of rewards) to +/- reward_limit."""
        return np.clip(r, -reward_limit, reward_limit)

    envs = gym.wrappers.TransformReward(envs, _clip_reward)
    return envs  # type: ignore