import copy
import random
import warnings

import torch

import cv2
import numpy as np
from gym.spaces import Box, Dict

import rlkit.torch.pytorch_util as ptu
from multiworld.core.multitask_env import MultitaskEnv
from multiworld.envs.env_util import get_stat_in_paths, create_stats_ordered_dict
from rlkit.envs.wrappers import ProxyEnv


class VAEWrappedEnv(ProxyEnv, MultitaskEnv):
    """This class wraps an image-based environment with a VAE.

    Assumes you get flattened (channels * 84 * 84) image observations from
    wrapped_env. This class adheres to the "Silent Multitask Env" semantics:
    on reset, it resamples a goal.
    """
    def __init__(
            self,
            wrapped_env,
            vae,
            vae_input_key_prefix='image',
            sample_from_true_prior=False,
            decode_goals=False,
            render_goals=False,
            render_rollouts=False,
            reward_params=None,
            goal_sampling_mode="vae_prior",
            imsize=84,
            obs_size=None,
            norm_order=2,
            epsilon=20,
            presampled_goals=None,
    ):
        if reward_params is None:
            reward_params = dict()
        super().__init__(wrapped_env)
        self.vae = vae
        self.representation_size = self.vae.representation_size
        self.input_channels = self.vae.input_channels
        self.sample_from_true_prior = sample_from_true_prior
        self._decode_goals = decode_goals
        self.render_goals = render_goals
        self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
            decode_goals=decode_goals,
            render_goals=render_goals,
            render_rollouts=render_rollouts,
        )
        self.imsize = imsize
        self.reward_params = reward_params
        self.reward_type = self.reward_params.get("type", 'latent_distance')
        self.norm_order = self.reward_params.get("norm_order", norm_order)
        self.epsilon = self.reward_params.get("epsilon", epsilon)
        self.reward_min_variance = self.reward_params.get("min_variance", 0)
        latent_space = Box(
            -10 * np.ones(obs_size or self.representation_size),
            10 * np.ones(obs_size or self.representation_size),
            dtype=np.float32,
        )
        spaces = self.wrapped_env.observation_space.spaces
        spaces['observation'] = latent_space
        spaces['desired_goal'] = latent_space
        spaces['achieved_goal'] = latent_space
        spaces['latent_observation'] = latent_space
        spaces['latent_desired_goal'] = latent_space
        spaces['latent_achieved_goal'] = latent_space
        self.observation_space = Dict(spaces)
        self._presampled_goals = presampled_goals
        if self._presampled_goals is None:
            self.num_goals_presampled = 0
        else:
            self.num_goals_presampled = presampled_goals[
                random.choice(list(presampled_goals))].shape[0]
        self.vae_input_key_prefix = vae_input_key_prefix
        assert vae_input_key_prefix in {'image', 'image_proprio'}
        self.vae_input_observation_key = vae_input_key_prefix + '_observation'
        self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
        self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
        self._mode_map = {}
        self.desired_goal = {'latent_desired_goal': latent_space.sample()}
        self._initial_obs = None
        self._custom_goal_sampler = None
        self._goal_sampling_mode = goal_sampling_mode

    def reset(self):
        obs = self.wrapped_env.reset()
        goal = self.sample_goal()
        self.set_goal(goal)
        self._initial_obs = obs
        return self._update_obs(obs)

    def step(self, action):
        obs, reward, done, info = self.wrapped_env.step(action)
        new_obs = self._update_obs(obs)
        self._update_info(info, new_obs)
        reward = self.compute_reward(
            action,
            {'latent_achieved_goal': new_obs['latent_achieved_goal'],
             'latent_desired_goal': new_obs['latent_desired_goal']}
        )
        self.try_render(new_obs)
        return new_obs, reward, done, info

    def _update_obs(self, obs):
        latent_obs = self._encode_one(obs[self.vae_input_observation_key])
        obs['latent_observation'] = latent_obs
        obs['latent_achieved_goal'] = latent_obs
        obs['observation'] = latent_obs
        obs['achieved_goal'] = latent_obs
        obs = {**obs, **self.desired_goal}
        return obs

    def _update_info(self, info, obs):
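        """Add VAE-based diagnostics to ``info``.

        Encodes the current image observation and compares it to the latent
        desired goal: a diagonal-covariance Mahalanobis distance (using the
        encoder's per-dimension variance, floored at ``reward_min_variance``),
        a thresholded success indicator, and L1/L2/``norm_order`` distances.
        """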
        latent_distribution_params = self.vae.encode(
            ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1, -1))
        )
        latent_obs = ptu.get_numpy(latent_distribution_params[0])[0]
        logvar = ptu.get_numpy(latent_distribution_params[1])[0]
        # assert (latent_obs == obs['latent_observation']).all()
        latent_goal = self.desired_goal['latent_desired_goal']
        dist = latent_goal - latent_obs
        var = np.exp(logvar.flatten())
        var = np.maximum(var, self.reward_min_variance)
        err = dist * dist / 2 / var
        mdist = np.sum(err)  # mahalanobis distance
        info["vae_mdist"] = mdist
        info["vae_success"] = 1 if mdist < self.epsilon else 0
        info["vae_dist"] = np.linalg.norm(dist, ord=self.norm_order)
        info["vae_dist_l1"] = np.linalg.norm(dist, ord=1)
        info["vae_dist_l2"] = np.linalg.norm(dist, ord=2)
""" | |
Multitask functions | |
""" | |
def sample_goals(self, batch_size): | |
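        """Sample a batch of goals according to ``self._goal_sampling_mode``.

        Modes: 'custom_goal_sampler' (user-provided callable), 'presampled'
        (index into a fixed goal set, re-encoding images with the current
        VAE), 'env' / 'reset_of_env' (goals from the wrapped env, encoded
        into latent space), and 'vae_prior' (sample latents from the prior).
        """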
        # TODO: make mode a parameter you pass in
        if self._goal_sampling_mode == 'custom_goal_sampler':
            return self.custom_goal_sampler(batch_size)
        elif self._goal_sampling_mode == 'presampled':
            idx = np.random.randint(0, self.num_goals_presampled, batch_size)
            sampled_goals = {
                k: v[idx] for k, v in self._presampled_goals.items()
            }
            # ensures goals are encoded using latest vae
            if 'image_desired_goal' in sampled_goals:
                sampled_goals['latent_desired_goal'] = self._encode(
                    sampled_goals['image_desired_goal']
                )
            return sampled_goals
        elif self._goal_sampling_mode == 'env':
            goals = self.wrapped_env.sample_goals(batch_size)
            latent_goals = self._encode(goals[self.vae_input_desired_goal_key])
        elif self._goal_sampling_mode == 'reset_of_env':
            assert batch_size == 1
            goal = self.wrapped_env.get_goal()
            goals = {k: v[None] for k, v in goal.items()}
            latent_goals = self._encode(
                goals[self.vae_input_desired_goal_key]
            )
        elif self._goal_sampling_mode == 'vae_prior':
            goals = {}
            latent_goals = self._sample_vae_prior(batch_size)
        else:
            raise RuntimeError("Invalid: {}".format(self._goal_sampling_mode))

        if self._decode_goals:
            decoded_goals = self._decode(latent_goals)
        else:
            decoded_goals = None
        image_goals, proprio_goals = self._image_and_proprio_from_decoded(
            decoded_goals
        )

        goals['desired_goal'] = latent_goals
        goals['latent_desired_goal'] = latent_goals
        if proprio_goals is not None:
            goals['proprio_desired_goal'] = proprio_goals
        if image_goals is not None:
            goals['image_desired_goal'] = image_goals
        if decoded_goals is not None:
            goals[self.vae_input_desired_goal_key] = decoded_goals
        return goals

    def get_goal(self):
        return self.desired_goal

    def compute_reward(self, action, obs):
        actions = action[None]
        next_obs = {
            k: v[None] for k, v in obs.items()
        }
        return self.compute_rewards(actions, next_obs)[0]

    def compute_rewards(self, actions, obs):
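        """Compute a batch of rewards from batched achieved/desired goals.

        Supported reward types: 'latent_distance' (negative norm in latent
        space), 'vectorized_latent_distance' (per-dimension negative absolute
        error), 'latent_sparse' (0 within ``epsilon``, -1 otherwise),
        'state_distance' (negative norm in the underlying state space), and
        'wrapped_env' (delegate to the wrapped environment).
        """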
        # TODO: implement log_prob/mdist
        if self.reward_type == 'latent_distance':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
            return -dist
        elif self.reward_type == 'vectorized_latent_distance':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            return -np.abs(desired_goals - achieved_goals)
        elif self.reward_type == 'latent_sparse':
            achieved_goals = obs['latent_achieved_goal']
            desired_goals = obs['latent_desired_goal']
            dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
            # Vectorized sparse reward: 0 within epsilon, -1 otherwise,
            # so it also works for batch sizes larger than one.
            return -(dist >= self.epsilon).astype(float)
        elif self.reward_type == 'state_distance':
            achieved_goals = obs['state_achieved_goal']
            desired_goals = obs['state_desired_goal']
            return -np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
        elif self.reward_type == 'wrapped_env':
            return self.wrapped_env.compute_rewards(actions, obs)
        else:
            raise NotImplementedError

    @property
    def goal_dim(self):
        return self.representation_size

    def set_goal(self, goal):
        """
        Assume goal contains both image_desired_goal and any goals required
        for wrapped envs.

        :param goal:
        :return:
        """
        self.desired_goal = goal
        # TODO: fix this hack / document this
        if self._goal_sampling_mode in {'presampled', 'env'}:
            self.wrapped_env.set_goal(goal)

    def get_diagnostics(self, paths, **kwargs):
        statistics = self.wrapped_env.get_diagnostics(paths, **kwargs)
        for stat_name_in_paths in ["vae_mdist", "vae_success", "vae_dist"]:
            stats = get_stat_in_paths(paths, 'env_infos', stat_name_in_paths)
            statistics.update(create_stats_ordered_dict(
                stat_name_in_paths,
                stats,
                always_show_all_stats=True,
            ))
            final_stats = [s[-1] for s in stats]
            statistics.update(create_stats_ordered_dict(
                "Final " + stat_name_in_paths,
                final_stats,
                always_show_all_stats=True,
            ))
        return statistics
""" | |
Other functions | |
""" | |
def goal_sampling_mode(self): | |
return self._goal_sampling_mode | |
def goal_sampling_mode(self, mode): | |
assert mode in [ | |
'custom_goal_sampler', | |
'presampled', | |
'vae_prior', | |
'env', | |
'reset_of_env' | |
], "Invalid env mode" | |
self._goal_sampling_mode = mode | |
if mode == 'custom_goal_sampler': | |
test_goals = self.custom_goal_sampler(1) | |
if test_goals is None: | |
self._goal_sampling_mode = 'vae_prior' | |
warnings.warn( | |
"self.goal_sampler returned None. " + \ | |
"Defaulting to vae_prior goal sampling mode" | |
) | |
def custom_goal_sampler(self): | |
return self._custom_goal_sampler | |
def custom_goal_sampler(self, new_custom_goal_sampler): | |
assert self.custom_goal_sampler is None, ( | |
"Cannot override custom goal setter" | |
) | |
self._custom_goal_sampler = new_custom_goal_sampler | |
def decode_goals(self): | |
return self._decode_goals | |
def decode_goals(self, _decode_goals): | |
self._decode_goals = _decode_goals | |

    def get_env_update(self):
        """
        For online-parallel. Gets updates to the environment since the last
        time the env was serialized.

        subprocess_env.update_env(**env.get_env_update())
        """
        return dict(
            mode_map=self._mode_map,
            gpu_info=dict(
                use_gpu=ptu._use_gpu,
                gpu_id=ptu._gpu_id,
            ),
            vae_state=self.vae.__getstate__(),
        )

    def update_env(self, mode_map, vae_state, gpu_info):
        self._mode_map = mode_map
        self.vae.__setstate__(vae_state)
        gpu_id = gpu_info['gpu_id']
        use_gpu = gpu_info['use_gpu']
        ptu.device = torch.device("cuda:" + str(gpu_id) if use_gpu else "cpu")
        self.vae.to(ptu.device)

    def enable_render(self):
        self._decode_goals = True
        self.render_goals = True
        self.render_rollouts = True

    def disable_render(self):
        self._decode_goals = False
        self.render_goals = False
        self.render_rollouts = False

    def try_render(self, obs):
        if self.render_rollouts:
            img = obs['image_observation'].reshape(
                self.input_channels,
                self.imsize,
                self.imsize,
            ).transpose()
            cv2.imshow('env', img)
            cv2.waitKey(1)
            reconstruction = self._reconstruct_img(obs['image_observation']).transpose()
            cv2.imshow('env_reconstruction', reconstruction)
            cv2.waitKey(1)
            init_img = self._initial_obs['image_observation'].reshape(
                self.input_channels,
                self.imsize,
                self.imsize,
            ).transpose()
            cv2.imshow('initial_state', init_img)
            cv2.waitKey(1)
            init_reconstruction = self._reconstruct_img(
                self._initial_obs['image_observation']
            ).transpose()
            cv2.imshow('init_reconstruction', init_reconstruction)
            cv2.waitKey(1)

        if self.render_goals:
            goal = obs['image_desired_goal'].reshape(
                self.input_channels,
                self.imsize,
                self.imsize,
            ).transpose()
            cv2.imshow('goal', goal)
            cv2.waitKey(1)

    def _sample_vae_prior(self, batch_size):
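        """Sample latent goals from the prior.

        Uses N(0, I) when ``sample_from_true_prior`` is set; otherwise uses
        the mean/std stored on the VAE (``vae.dist_mu`` and ``vae.dist_std``),
        typically statistics of previously encoded data.
        """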
        if self.sample_from_true_prior:
            mu, sigma = 0, 1  # sample from prior
        else:
            mu, sigma = self.vae.dist_mu, self.vae.dist_std
        n = np.random.randn(batch_size, self.representation_size)
        return sigma * n + mu

    def _decode(self, latents):
        reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
        decoded = ptu.get_numpy(reconstructions)
        return decoded

    def _encode_one(self, img):
        return self._encode(img[None])[0]

    def _encode(self, imgs):
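        """Encode a batch of flattened images into latents.

        Returns the mean of the encoder's latent distribution (the first
        element of ``vae.encode``'s output) as a numpy array.
        """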
        latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
        return ptu.get_numpy(latent_distribution_params[0])

    def _reconstruct_img(self, flat_img):
        latent_distribution_params = self.vae.encode(
            ptu.from_numpy(flat_img.reshape(1, -1))
        )
        reconstructions, _ = self.vae.decode(latent_distribution_params[0])
        imgs = ptu.get_numpy(reconstructions)
        imgs = imgs.reshape(
            1, self.input_channels, self.imsize, self.imsize
        )
        return imgs[0]

    def _image_and_proprio_from_decoded(self, decoded):
        if decoded is None:
            return None, None
        if self.vae_input_key_prefix == 'image_proprio':
            images = decoded[:, :self.image_length]
            proprio = decoded[:, self.image_length:]
            return images, proprio
        elif self.vae_input_key_prefix == 'image':
            return decoded, None
        else:
            raise AssertionError("Bad prefix for the vae input key.")

    def __getstate__(self):
        state = super().__getstate__()
        state = copy.copy(state)
        state['_custom_goal_sampler'] = None
        warnings.warn('VAEWrappedEnv.custom_goal_sampler is not saved.')
        return state

    def __setstate__(self, state):
        warnings.warn('VAEWrappedEnv.custom_goal_sampler was not loaded.')
        super().__setstate__(state)


def temporary_mode(env, mode, func, args=None, kwargs=None):
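    """Run ``func`` with the environment temporarily switched into ``mode``.

    The mode is looked up in ``env._mode_map``, applied via ``env.mode(...)``,
    and the previous mode is restored after ``func`` returns.
    """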
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    cur_mode = env.cur_mode
    env.mode(env._mode_map[mode])
    return_val = func(*args, **kwargs)
    env.mode(cur_mode)
    return return_val
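

# A minimal usage sketch (assumptions: `make_image_env` and `trained_vae` are
# placeholders for whatever image-based env and trained VAE you actually have;
# they are not defined in this file):
#
#     env = VAEWrappedEnv(
#         make_image_env(imsize=84),
#         trained_vae,
#         reward_params=dict(type='latent_distance'),
#         goal_sampling_mode='vae_prior',
#     )
#     obs = env.reset()
#     next_obs, reward, done, info = env.step(env.action_space.sample())
#     # reward is the negative latent-space distance to the sampled goal.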