# Author: sgoodfriend
# DQN playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# (commit 923ccaf)
import os
from dataclasses import astuple
from typing import Callable, Optional
import gym
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.wrappers.frame_stack import FrameStack
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.resize_observation import ResizeObservation
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv, NoopResetEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
from rl_algo_impls.shared.vec_env.utils import (
import_for_env_id,
is_atari,
is_bullet_env,
is_car_racing,
is_gym_procgen,
is_microrts,
)
from rl_algo_impls.wrappers.action_mask_wrapper import SingleActionMaskWrapper
from rl_algo_impls.wrappers.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireOnLifeStarttEnv,
)
from rl_algo_impls.wrappers.episode_record_video import EpisodeRecordVideo
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.initial_step_truncate_wrapper import (
InitialStepTruncateWrapper,
)
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.no_reward_timeout import NoRewardTimeout
from rl_algo_impls.wrappers.noop_env_seed import NoopEnvSeed
from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
from rl_algo_impls.wrappers.sync_vector_env_render_compat import (
SyncVectorEnvRenderCompat,
)
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
from rl_algo_impls.wrappers.video_compat_wrapper import VideoCompatWrapper
def make_vec_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a vectorized environment for ``config.env_id`` from ``hparams``.

    Each of the ``n_envs`` sub-environments is created with ``gym.make`` and
    wrapped according to its environment family (Atari, CarRacing, procgen,
    microrts). The sub-environments are combined into an SB3 vec env
    (``env_type == "sb3vec"``) or a gym vector env (``env_type == "gymvec"``).
    The result is then optionally wrapped for action masking, episode-stat
    logging and observation/reward normalization.

    Args:
        config: Run configuration; supplies ``env_id``, ``seed(training=...)``
            and ``video_prefix``.
        hparams: Environment hyperparameters. They are unpacked positionally
            with ``astuple``, so the field order of ``EnvHyperparams`` matters.
        training: Whether the env is used for training. Gates training-only
            wrappers (video recording, initial-step truncation, episode-stat
            writing, gymlike reward normalization). It is also forwarded to
            wrappers that take a ``training`` flag.
        render: Whether rendering was requested; adjusts ``make_kwargs`` for
            bullet and procgen envs.
        normalize_load_path: Directory holding a saved ``VecNormalize``
            (``VEC_NORMALIZE_FILENAME``) to restore. Only consulted by the
            ``"sb3"`` normalize type.
        tb_writer: TensorBoard writer for ``EpisodeStatsWriter``. Required when
            ``training`` is True.

    Returns:
        The wrapped vectorized environment.

    Raises:
        ValueError: if ``training`` is True without ``tb_writer``, or if
            ``env_type``, ``vec_env_class`` or ``normalize_type`` is
            unsupported.
    """
    (
        env_type,
        n_envs,
        frame_stack,
        make_kwargs,
        no_reward_timeout_steps,
        no_reward_fire_steps,
        vec_env_class,
        normalize,
        normalize_kwargs,
        rolling_length,
        train_record_video,
        video_step_interval,
        initial_steps_to_truncate,
        clip_atari_rewards,
        normalize_type,
        mask_actions,
        _,  # bots
        _,  # self_play_kwargs
        _,  # selfplay_bots
    ) = astuple(hparams)

    # Validate cheap preconditions before any sub-environment (or, for the
    # "async" classes, any worker process) is created. An explicit raise is
    # used instead of `assert` so the check survives `python -O`.
    if training and tb_writer is None:
        raise ValueError("tb_writer is required when training=True")
    if env_type == "sb3vec":
        vec_env_classes = {"sync": DummyVecEnv, "async": SubprocVecEnv}
    elif env_type == "gymvec":
        vec_env_classes = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}
    else:
        raise ValueError(f"env_type {env_type} unsupported")
    if vec_env_class not in vec_env_classes:
        raise ValueError(
            f"vec_env_class {vec_env_class} not supported (sync or async)"
        )
    VecEnvClass = vec_env_classes[vec_env_class]

    import_for_env_id(config.env_id)

    seed = config.seed(training=training)

    # Per-family adjustments to the kwargs passed to gym.make.
    make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
    if is_bullet_env(config) and render:
        make_kwargs["render"] = True
    if is_car_racing(config):
        make_kwargs["verbose"] = 0
    if is_gym_procgen(config) and not render:
        make_kwargs["render_mode"] = "rgb_array"

    def make(idx: int) -> Callable[[], gym.Env]:
        """Return a thunk that builds and wraps sub-environment ``idx``."""

        def _make() -> gym.Env:
            env = gym.make(config.env_id, **make_kwargs)
            env = gym.wrappers.RecordEpisodeStatistics(env)
            env = VideoCompatWrapper(env)
            # Only the first sub-environment records training videos.
            if training and train_record_video and idx == 0:
                env = EpisodeRecordVideo(
                    env,
                    config.video_prefix,
                    step_increment=n_envs,
                    video_step_interval=int(video_step_interval),
                )
            if training and initial_steps_to_truncate:
                # The truncation point scales with idx, presumably to stagger
                # episode boundaries across sub-environments.
                env = InitialStepTruncateWrapper(
                    env, idx * initial_steps_to_truncate // n_envs
                )
            if is_atari(config):  # type: ignore
                env = NoopResetEnv(env, noop_max=30)
                env = MaxAndSkipEnv(env, skip=4)
                env = EpisodicLifeEnv(env, training=training)
                action_meanings = env.unwrapped.get_action_meanings()
                if "FIRE" in action_meanings:  # type: ignore
                    env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
                if clip_atari_rewards:
                    env = ClipRewardEnv(env, training=training)
                env = ResizeObservation(env, (84, 84))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif is_car_racing(config):
                env = ResizeObservation(env, (64, 64))
                env = GrayScaleObservation(env, keep_dim=False)
                env = FrameStack(env, frame_stack)
            elif is_gym_procgen(config):
                env = NoopEnvSeed(env)
                env = HwcToChwObservation(env)
                if frame_stack > 1:
                    env = FrameStack(env, frame_stack)
            elif is_microrts(config):
                env = HwcToChwObservation(env)

            if no_reward_timeout_steps:
                env = NoRewardTimeout(
                    env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
                )

            if seed is not None:
                # Offset by idx so each sub-environment gets a distinct seed.
                env.seed(seed + idx)
                env.action_space.seed(seed + idx)
                env.observation_space.seed(seed + idx)
            return env

        return _make

    envs = VecEnvClass([make(i) for i in range(n_envs)])
    if env_type == "gymvec" and vec_env_class == "sync":
        envs = SyncVectorEnvRenderCompat(envs)
    if env_type == "sb3vec":
        envs = IsVectorEnv(envs)
    if mask_actions:
        envs = SingleActionMaskWrapper(envs)
    if training:
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,  # type: ignore  # non-None: validated at the top
            training=training,
            rolling_length=rolling_length,
        )

    if normalize:
        if normalize_type is None:
            normalize_type = "sb3" if env_type == "sb3vec" else "gymlike"
        normalize_kwargs = normalize_kwargs or {}
        if normalize_type == "sb3":
            if normalize_load_path:
                envs = VecNormalize.load(
                    os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
                    envs,  # type: ignore
                )
                # VecNormalize.load restores the `training` flag the stats were
                # saved with (normally True). Apply the requested mode so
                # evaluation envs don't keep updating the restored statistics.
                envs.training = training
            else:
                envs = VecNormalize(
                    envs,  # type: ignore
                    training=training,
                    **normalize_kwargs,
                )
            if not training:
                envs.norm_reward = False
        elif normalize_type == "gymlike":
            # NOTE(review): normalize_load_path is not consulted on this path;
            # presumably gymlike stats are restored elsewhere -- verify.
            if normalize_kwargs.get("norm_obs", True):
                envs = NormalizeObservation(
                    envs,
                    training=training,
                    clip=normalize_kwargs.get("clip_obs", 10.0),
                )
            if training and normalize_kwargs.get("norm_reward", True):
                envs = NormalizeReward(
                    envs,
                    training=training,
                    clip=normalize_kwargs.get("clip_reward", 10.0),
                )
        else:
            raise ValueError(
                f"normalize_type {normalize_type} not supported (sb3 or gymlike)"
            )
    return envs