import cv2
import numpy as np
import torch
from PIL import Image
# Fixed: was `from collections.__init__ import deque` (accidental import path).
from collections import deque
from gym import Env
from gym.spaces import Box

from rlkit.envs.wrappers import ProxyEnv


class ImageMujocoEnv(ProxyEnv, Env):
    """MuJoCo env wrapper that replaces state observations with camera images.

    Each observation is a flattened image history: the most recent
    ``keep_prev + 1`` rendered frames concatenated into a single 1D array,
    zero-padded until enough frames have accumulated.
    """

    def __init__(
            self,
            wrapped_env,
            imsize=32,
            keep_prev=0,
            init_camera=None,
            camera_name=None,
            transpose=False,
            grayscale=False,
            normalize=False,
    ):
        """
        :param wrapped_env: MuJoCo env exposing ``sim`` for offscreen render.
        :param imsize: height and width in pixels of the rendered image.
        :param keep_prev: number of previous frames kept in addition to the
            current one.
        :param init_camera: optional callable applied to ``viewer.cam`` to
            configure an offscreen render context.
        :param camera_name: MuJoCo camera to render from; None means default.
        :param transpose: if True, transpose the image array before returning.
        :param grayscale: if True, convert frames to single-channel.
        :param normalize: if True, scale pixel values from [0, 255] to [0, 1].
        """
        import mujoco_py  # deferred so the module imports without mujoco

        super().__init__(wrapped_env)
        self.imsize = imsize
        # Flattened length of one frame: 1 channel if grayscale, else 3 (RGB).
        if grayscale:
            self.image_length = self.imsize * self.imsize
        else:
            self.image_length = 3 * self.imsize * self.imsize
        # This is torch format rather than PIL image.
        # NOTE(review): for RGB frames the rendered array is (H, W, 3), so the
        # zero frames padded in _get_history (which use this 2D shape) only
        # match grayscale frames -- confirm intended usage.
        self.image_shape = (self.imsize, self.imsize)
        # Queue of the most recent frames; flattened when read.
        self.history_length = keep_prev + 1
        self.history = deque(maxlen=self.history_length)
        # Optionally install an offscreen render context with a
        # caller-configured camera.
        if init_camera is not None:
            sim = self._wrapped_env.sim
            viewer = mujoco_py.MjRenderContextOffscreen(sim, device_id=-1)
            init_camera(viewer.cam)
            sim.add_render_context(viewer)
        self.camera_name = camera_name  # None means default camera
        self.transpose = transpose
        self.grayscale = grayscale
        self.normalize = normalize
        self._render_local = False  # toggled on by enable_render()

        # NOTE(review): bounds assume normalize=True; with normalize=False the
        # actual pixel values lie in [0, 255] -- confirm callers normalize.
        self.observation_space = Box(
            low=0.0,
            high=1.0,
            shape=(self.image_length * self.history_length,),
        )

    def step(self, action):
        """Step the wrapped env; return the flattened image-history obs."""
        true_state, reward, done, info = super().step(action)
        observation = self._image_observation()
        self.history.append(observation)
        history = self._get_history().flatten()
        full_obs = self._get_obs(history, true_state)
        return full_obs, reward, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and drop frames from the previous episode."""
        true_state = super().reset(**kwargs)
        self.history.clear()  # same effect as rebuilding the deque
        observation = self._image_observation()
        self.history.append(observation)
        history = self._get_history().flatten()
        full_obs = self._get_obs(history, true_state)
        return full_obs

    def get_image(self):
        """Return the current (unflattened) rendered frame."""
        return self._image_observation()

    def _get_obs(self, history_flat, true_state):
        # Hook for subclasses to add extra information from true_state into
        # the image observation. Used in ImageMujocoWithObsEnv.
        return history_flat

    def _image_observation(self):
        """Render one frame and return it as a numpy array (torch format)."""
        image_obs = self._wrapped_env.sim.render(
            width=self.imsize,
            height=self.imsize,
            camera_name=self.camera_name,
        )
        if self._render_local:
            # Debug view in a local OpenCV window.
            cv2.imshow('env', image_obs)
            cv2.waitKey(1)
        if self.grayscale:
            image_obs = Image.fromarray(image_obs).convert('L')
            image_obs = np.array(image_obs)
        if self.normalize:
            image_obs = image_obs / 255.0
        if self.transpose:
            # NOTE(review): ndarray.transpose() with no args reverses ALL
            # axes, so an (H, W, C) frame becomes (C, W, H) rather than the
            # usual (C, H, W) torch layout -- confirm this is intended.
            image_obs = image_obs.transpose()
        return image_obs

    def _get_history(self):
        """Stack stored frames, zero-padding up to history_length frames."""
        observations = list(self.history)
        obs_count = len(observations)
        for _ in range(self.history_length - obs_count):
            # Pad with black frames until enough history has accumulated.
            dummy = np.zeros(self.image_shape)
            observations.append(dummy)
        return np.c_[observations]

    def retrieve_images(self):
        """Return the stored frames converted to (unflattened) PIL images."""
        images = []
        for image_obs in self.history:
            # NOTE(review): torch_to_pil is not defined in this file;
            # presumably provided by a mixin/parent -- verify it exists.
            pil_image = self.torch_to_pil(torch.Tensor(image_obs))
            images.append(pil_image)
        return images

    def split_obs(self, obs):
        """Split a batched observation tensor into image and true-state parts.

        :param obs: torch tensor, reshaped here to (batch, obs_length).
        :return: tuple (image_obs, fc_obs); fc_obs is None when the
            observation carries no extra state beyond the image history.
        """
        imlength = self.image_length * self.history_length
        obs_length = self.observation_space.low.size
        obs = obs.view(-1, obs_length)
        # NOTE(review): `dimension=` matches the legacy torch narrow() API;
        # current torch names this argument `dim` -- verify torch version.
        image_obs = obs.narrow(start=0, length=imlength, dimension=1)
        if obs_length == imlength:
            return image_obs, None
        fc_obs = obs.narrow(
            start=imlength,
            length=obs.shape[1] - imlength,
            dimension=1,
        )
        return image_obs, fc_obs

    def enable_render(self):
        """Display every rendered frame in a local OpenCV window."""
        self._render_local = True


class ImageMujocoWithObsEnv(ImageMujocoEnv):
    """Image env whose observations append the wrapped env's true state."""

    def __init__(self, env, **kwargs):
        super().__init__(env, **kwargs)
        # Widen the observation space: image history plus the true state.
        self.observation_space = Box(
            low=0.0,
            high=1.0,
            shape=(
                self.image_length * self.history_length
                + self.wrapped_env.obs_dim,
            ),
        )

    def _get_obs(self, history_flat, true_state):
        # Concatenate the flattened image history with the true state vector.
        return np.concatenate([history_flat, true_state])