import numpy as np
from collections import deque

import gym
from gym import spaces
import cv2

'''
Atari wrappers copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
'''


class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.

        This object should only be converted to numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        # Concatenate the stored frames along the channel axis the first time the
        # observation is actually needed, then drop the per-frame list.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]


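# An illustrative helper, not part of the original wrappers: assuming four 84x84
# single-channel frames, it shows how LazyFrames defers the channel-wise
# concatenation until the object is converted to a numpy array.
def _lazy_frames_demo():
    frames = [np.zeros((84, 84, 1), dtype=np.uint8) for _ in range(4)]
    lazy = LazyFrames(frames)
    stacked = np.array(lazy)  # __array__ concatenates along axis 2
    return stacked.shape      # (84, 84, 4)

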
class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame doesn't matter.
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class WarpFrameNoResize(gym.ObservationWrapper):
    def __init__(self, env):
        """Convert frames to single-channel grayscale without resizing them."""
        gym.ObservationWrapper.__init__(self, env)
        # Declare the grayscale observation space so that downstream wrappers
        # (e.g. FrameStack) see the shape that observation() actually returns.
        obs_shape = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return frame[:, :, None]


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k),
                                            dtype=env.observation_space.dtype)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis from last to first (HWC -> CHW) for PyTorch."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0,
                                                shape=(old_shape[-1], old_shape[0], old_shape[1]),
                                                dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Scale uint8 pixel observations to floats in [0, 1]."""

    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape,
                                                dtype=np.float32)

    def observation(self, observation):
        # Careful! This undoes the memory optimization of LazyFrames,
        # so only use it with smaller replay buffers.
        return np.array(observation).astype(np.float32) / 255.0


class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


def make_starpilot(render=False):
    print("Environment: Starpilot")
    if render:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
    else:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
    env = WarpFrameNoResize(env)
    # Stack frames while observations are still channel-last (HWC), since FrameStack
    # concatenates along axis 2; only then move channels first for PyTorch.
    env = FrameStack(env, 4)
    env = ImageToPyTorch(env)
    return env
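

# A minimal usage sketch, not part of the original module: it assumes the `procgen`
# package is installed so gym.make can resolve "procgen:procgen-starpilot-v0", and it
# runs a few random steps to sanity-check the stacked observation shape.
if __name__ == "__main__":
    env = make_starpilot(render=False)
    obs = env.reset()
    # Procgen frames are 64x64 RGB by default, so after grayscale conversion, 4-frame
    # stacking, and the channel-first transpose this should print (4, 64, 64).
    print("stacked observation shape:", np.array(obs).shape)
    for _ in range(10):
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            obs = env.reset()
    env.close()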