# NCERL-Diverse-PCG / src/rlkit/envs/mujoco_image_env.py
# (origin commit: eaf2e33)
import cv2
import numpy as np
import torch
from PIL import Image
from collections.__init__ import deque
from gym import Env
from gym.spaces import Box
from rlkit.envs.wrappers import ProxyEnv
class ImageMujocoEnv(ProxyEnv, Env):
    """Wrap a MuJoCo env so observations are flattened camera images.

    Each observation is a stack of the current rendered frame plus the
    ``keep_prev`` most recent frames, flattened into one 1-D array.
    Subclasses may append extra information via ``_get_obs``.
    """

    def __init__(self,
                 wrapped_env,
                 imsize=32,
                 keep_prev=0,
                 init_camera=None,
                 camera_name=None,
                 transpose=False,
                 grayscale=False,
                 normalize=False,
                 ):
        """
        :param wrapped_env: MuJoCo env exposing ``sim`` for offscreen rendering.
        :param imsize: rendered image width and height, in pixels.
        :param keep_prev: number of previous frames stacked with the current one.
        :param init_camera: optional callable applied to the offscreen viewer's
            camera for one-time setup; None skips creating a render context.
        :param camera_name: MuJoCo camera to render from; None uses the default.
        :param transpose: if True, apply numpy ``.transpose()`` to each frame.
        :param grayscale: if True, convert frames to single-channel ('L') images.
        :param normalize: if True, scale pixel values from [0, 255] to [0, 1].
        """
        # Imported lazily so this module can be imported without mujoco_py.
        import mujoco_py
        super().__init__(wrapped_env)
        self.imsize = imsize
        # Length of one flattened frame: one channel if grayscale, else RGB.
        if grayscale:
            self.image_length = self.imsize * self.imsize
        else:
            self.image_length = 3 * self.imsize * self.imsize
        # NOTE(review): this is only the 2-D (H, W) shape; RGB frames are
        # really (H, W, 3). _get_history pads with frames shaped like the
        # real observations, so this is used only as a fallback.
        self.image_shape = (self.imsize, self.imsize)
        # Frame queue holding the current frame plus keep_prev previous ones.
        self.history_length = keep_prev + 1
        self.history = deque(maxlen=self.history_length)
        # Optionally install an offscreen render context with a custom camera.
        if init_camera is not None:
            sim = self._wrapped_env.sim
            # device_id=-1 selects the default rendering device.
            viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
            init_camera(viewer.cam)
            sim.add_render_context(viewer)
        self.camera_name = camera_name  # None means default camera
        self.transpose = transpose
        self.grayscale = grayscale
        self.normalize = normalize
        self._render_local = False
        self.observation_space = Box(
            low=0.0,
            high=1.0,
            shape=(self.image_length * self.history_length,),
        )

    def step(self, action):
        """Step the wrapped env and return the stacked image observation."""
        # The wrapped env's state observation is passed to _get_obs so
        # subclasses can include it; the base class ignores it.
        true_state, reward, done, info = super().step(action)
        observation = self._image_observation()
        self.history.append(observation)
        history = self._get_history().flatten()
        full_obs = self._get_obs(history, true_state)
        return full_obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and the frame history."""
        true_state = super().reset(**kwargs)
        # Drop stale frames from the previous episode.
        self.history.clear()
        observation = self._image_observation()
        self.history.append(observation)
        history = self._get_history().flatten()
        full_obs = self._get_obs(history, true_state)
        return full_obs

    def get_image(self):
        """Return the current frame as a numpy array (not flattened)."""
        return self._image_observation()

    def _get_obs(self, history_flat, true_state):
        """Return the final observation.

        The base class returns only the flattened image history; subclasses
        (e.g. ImageMujocoWithObsEnv) may append ``true_state``.
        """
        return history_flat

    def _image_observation(self):
        """Render and post-process one frame as a numpy array."""
        image_obs = self._wrapped_env.sim.render(width=self.imsize,
                                                 height=self.imsize,
                                                 camera_name=self.camera_name)
        if self._render_local:
            # Debug view; enabled via enable_render().
            cv2.imshow('env', image_obs)
            cv2.waitKey(1)
        if self.grayscale:
            image_obs = Image.fromarray(image_obs).convert('L')
            image_obs = np.array(image_obs)
        if self.normalize:
            image_obs = image_obs / 255.0
        if self.transpose:
            # NOTE(review): .transpose() reverses ALL axes, so an RGB
            # (H, W, 3) frame becomes (3, W, H) — confirm this is intended
            # for non-grayscale use.
            image_obs = image_obs.transpose()
        return image_obs

    def _get_history(self):
        """Return the frame history stacked into one array.

        Pads with zero frames so the stack always has ``history_length``
        entries (relevant right after reset when keep_prev > 0).
        """
        observations = list(self.history)
        # Fix: pad with frames matching the shape AND dtype of a real
        # observation. The previous code used np.zeros(self.image_shape),
        # which is 2-D (H, W) and float64 — it crashed the stack for RGB
        # or transposed frames and silently upcast uint8 frames.
        if observations:
            dummy = np.zeros_like(observations[0])
        else:
            dummy = np.zeros(self.image_shape)
        for _ in range(self.history_length - len(observations)):
            observations.append(dummy)
        return np.c_[observations]

    def retrieve_images(self):
        """Return the frame history as unflattened PIL images."""
        # NOTE(review): torch_to_pil is not defined in this class;
        # presumably provided by a mixin or subclass — verify before use.
        images = []
        for image_obs in self.history:
            pil_image = self.torch_to_pil(torch.Tensor(image_obs))
            images.append(pil_image)
        return images

    def split_obs(self, obs):
        """Split a batched observation tensor into (image part, state part).

        :param obs: torch tensor; reshaped to (-1, observation dim).
        :return: (image_obs, fc_obs) where fc_obs is None when the
            observation contains no extra state.
        """
        imlength = self.image_length * self.history_length
        obs_length = self.observation_space.low.size
        obs = obs.view(-1, obs_length)
        # NOTE(review): 'dimension=' is the legacy keyword for
        # torch.Tensor.narrow (modern torch uses 'dim=') — confirm the
        # pinned torch version accepts it.
        image_obs = obs.narrow(start=0,
                               length=imlength,
                               dimension=1)
        if obs_length == imlength:
            return image_obs, None
        fc_obs = obs.narrow(start=imlength,
                            length=obs.shape[1] - imlength,
                            dimension=1)
        return image_obs, fc_obs

    def enable_render(self):
        """Enable a local cv2 debug window for each rendered frame."""
        self._render_local = True
class ImageMujocoWithObsEnv(ImageMujocoEnv):
    """Image env variant whose observation also carries the true state.

    The observation is the flattened image history produced by the parent
    class, followed by the wrapped env's raw state vector.
    """

    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)
        # Widen the observation space: image stack plus the raw state dims.
        total_dim = (self.image_length * self.history_length
                     + self.wrapped_env.obs_dim)
        self.observation_space = Box(low=0.0, high=1.0, shape=(total_dim,))

    def _get_obs(self, history_flat, true_state):
        """Append the true state vector to the flattened image history."""
        return np.concatenate([history_flat, true_state])