import copy
import random
import warnings
import torch
import cv2
import numpy as np
from gym.spaces import Box, Dict
import rlkit.torch.pytorch_util as ptu
from multiworld.core.multitask_env import MultitaskEnv
from multiworld.envs.env_util import get_stat_in_paths, create_stats_ordered_dict
from rlkit.envs.wrappers import ProxyEnv


class VAEWrappedEnv(ProxyEnv, MultitaskEnv):
"""This class wraps an image-based environment with a VAE.
Assumes you get flattened (channels,84,84) observations from wrapped_env.
This class adheres to the "Silent Multitask Env" semantics: on reset,
it resamples a goal.
"""
def __init__(
self,
wrapped_env,
vae,
vae_input_key_prefix='image',
sample_from_true_prior=False,
decode_goals=False,
render_goals=False,
render_rollouts=False,
reward_params=None,
goal_sampling_mode="vae_prior",
imsize=84,
obs_size=None,
norm_order=2,
epsilon=20,
presampled_goals=None,
):
if reward_params is None:
reward_params = dict()
super().__init__(wrapped_env)
self.vae = vae
self.representation_size = self.vae.representation_size
self.input_channels = self.vae.input_channels
self.sample_from_true_prior = sample_from_true_prior
self._decode_goals = decode_goals
self.render_goals = render_goals
self.render_rollouts = render_rollouts
        self.default_kwargs = dict(
decode_goals=decode_goals,
render_goals=render_goals,
render_rollouts=render_rollouts,
)
self.imsize = imsize
self.reward_params = reward_params
self.reward_type = self.reward_params.get("type", 'latent_distance')
self.norm_order = self.reward_params.get("norm_order", norm_order)
self.epsilon = self.reward_params.get("epsilon", epsilon)
self.reward_min_variance = self.reward_params.get("min_variance", 0)
latent_space = Box(
-10 * np.ones(obs_size or self.representation_size),
10 * np.ones(obs_size or self.representation_size),
dtype=np.float32,
)
spaces = self.wrapped_env.observation_space.spaces
spaces['observation'] = latent_space
spaces['desired_goal'] = latent_space
spaces['achieved_goal'] = latent_space
spaces['latent_observation'] = latent_space
spaces['latent_desired_goal'] = latent_space
spaces['latent_achieved_goal'] = latent_space
self.observation_space = Dict(spaces)
self._presampled_goals = presampled_goals
if self._presampled_goals is None:
self.num_goals_presampled = 0
else:
            self.num_goals_presampled = presampled_goals[
                random.choice(list(presampled_goals))
            ].shape[0]
self.vae_input_key_prefix = vae_input_key_prefix
assert vae_input_key_prefix in {'image', 'image_proprio'}
self.vae_input_observation_key = vae_input_key_prefix + '_observation'
self.vae_input_achieved_goal_key = vae_input_key_prefix + '_achieved_goal'
self.vae_input_desired_goal_key = vae_input_key_prefix + '_desired_goal'
self._mode_map = {}
self.desired_goal = {'latent_desired_goal': latent_space.sample()}
self._initial_obs = None
self._custom_goal_sampler = None
self._goal_sampling_mode = goal_sampling_mode
def reset(self):
obs = self.wrapped_env.reset()
goal = self.sample_goal()
self.set_goal(goal)
self._initial_obs = obs
return self._update_obs(obs)
def step(self, action):
obs, reward, done, info = self.wrapped_env.step(action)
new_obs = self._update_obs(obs)
self._update_info(info, new_obs)
reward = self.compute_reward(
action,
{'latent_achieved_goal': new_obs['latent_achieved_goal'],
'latent_desired_goal': new_obs['latent_desired_goal']}
)
self.try_render(new_obs)
return new_obs, reward, done, info
def _update_obs(self, obs):
latent_obs = self._encode_one(obs[self.vae_input_observation_key])
obs['latent_observation'] = latent_obs
obs['latent_achieved_goal'] = latent_obs
obs['observation'] = latent_obs
obs['achieved_goal'] = latent_obs
obs = {**obs, **self.desired_goal}
return obs
def _update_info(self, info, obs):
latent_distribution_params = self.vae.encode(
ptu.from_numpy(obs[self.vae_input_observation_key].reshape(1,-1))
)
        latent_obs = ptu.get_numpy(latent_distribution_params[0])[0]
        logvar = ptu.get_numpy(latent_distribution_params[1])[0]
# assert (latent_obs == obs['latent_observation']).all()
latent_goal = self.desired_goal['latent_desired_goal']
dist = latent_goal - latent_obs
var = np.exp(logvar.flatten())
var = np.maximum(var, self.reward_min_variance)
        err = dist * dist / 2 / var
        mdist = np.sum(err)  # half the squared Mahalanobis distance (diagonal covariance)
info["vae_mdist"] = mdist
info["vae_success"] = 1 if mdist < self.epsilon else 0
info["vae_dist"] = np.linalg.norm(dist, ord=self.norm_order)
info["vae_dist_l1"] = np.linalg.norm(dist, ord=1)
info["vae_dist_l2"] = np.linalg.norm(dist, ord=2)
"""
Multitask functions
"""
def sample_goals(self, batch_size):
# TODO: make mode a parameter you pass in
if self._goal_sampling_mode == 'custom_goal_sampler':
return self.custom_goal_sampler(batch_size)
elif self._goal_sampling_mode == 'presampled':
idx = np.random.randint(0, self.num_goals_presampled, batch_size)
sampled_goals = {
k: v[idx] for k, v in self._presampled_goals.items()
}
# ensures goals are encoded using latest vae
if 'image_desired_goal' in sampled_goals:
sampled_goals['latent_desired_goal'] = self._encode(sampled_goals['image_desired_goal'])
return sampled_goals
elif self._goal_sampling_mode == 'env':
goals = self.wrapped_env.sample_goals(batch_size)
latent_goals = self._encode(goals[self.vae_input_desired_goal_key])
elif self._goal_sampling_mode == 'reset_of_env':
assert batch_size == 1
goal = self.wrapped_env.get_goal()
goals = {k: v[None] for k, v in goal.items()}
latent_goals = self._encode(
goals[self.vae_input_desired_goal_key]
)
elif self._goal_sampling_mode == 'vae_prior':
goals = {}
latent_goals = self._sample_vae_prior(batch_size)
else:
raise RuntimeError("Invalid: {}".format(self._goal_sampling_mode))
if self._decode_goals:
decoded_goals = self._decode(latent_goals)
else:
decoded_goals = None
image_goals, proprio_goals = self._image_and_proprio_from_decoded(
decoded_goals
)
goals['desired_goal'] = latent_goals
goals['latent_desired_goal'] = latent_goals
if proprio_goals is not None:
goals['proprio_desired_goal'] = proprio_goals
if image_goals is not None:
goals['image_desired_goal'] = image_goals
if decoded_goals is not None:
goals[self.vae_input_desired_goal_key] = decoded_goals
return goals
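    # Illustrative sketch of the returned goal dict (shapes follow from the code
    # above; ``decode_goals`` must be enabled for image goals to be filled in):
    #
    #     env.goal_sampling_mode = 'vae_prior'
    #     env.decode_goals = True
    #     goals = env.sample_goals(8)
    #     goals['latent_desired_goal'].shape  # (8, representation_size)
    #     goals['image_desired_goal']         # decoded images from the VAE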
def get_goal(self):
return self.desired_goal
def compute_reward(self, action, obs):
actions = action[None]
next_obs = {
k: v[None] for k, v in obs.items()
}
return self.compute_rewards(actions, next_obs)[0]
def compute_rewards(self, actions, obs):
# TODO: implement log_prob/mdist
if self.reward_type == 'latent_distance':
achieved_goals = obs['latent_achieved_goal']
desired_goals = obs['latent_desired_goal']
dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
return -dist
elif self.reward_type == 'vectorized_latent_distance':
achieved_goals = obs['latent_achieved_goal']
desired_goals = obs['latent_desired_goal']
return -np.abs(desired_goals - achieved_goals)
elif self.reward_type == 'latent_sparse':
achieved_goals = obs['latent_achieved_goal']
desired_goals = obs['latent_desired_goal']
dist = np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
            # sparse reward: 0 if within epsilon of the goal, else -1 (vectorized
            # so it also works for batched distances)
            reward = -(dist >= self.epsilon).astype(float)
            return reward
elif self.reward_type == 'state_distance':
achieved_goals = obs['state_achieved_goal']
desired_goals = obs['state_desired_goal']
            return -np.linalg.norm(desired_goals - achieved_goals, ord=self.norm_order, axis=1)
elif self.reward_type == 'wrapped_env':
return self.wrapped_env.compute_rewards(actions, obs)
else:
raise NotImplementedError
@property
def goal_dim(self):
return self.representation_size
def set_goal(self, goal):
"""
Assume goal contains both image_desired_goal and any goals required for wrapped envs
:param goal:
:return:
"""
self.desired_goal = goal
# TODO: fix this hack / document this
if self._goal_sampling_mode in {'presampled', 'env'}:
self.wrapped_env.set_goal(goal)
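    # Illustrative goal dict for set_goal; the exact keys depend on the wrapped
    # env, and ``latent_vector`` / ``flat_image`` / ``state_vector`` are
    # hypothetical arrays:
    #
    #     env.set_goal({
    #         'latent_desired_goal': latent_vector,
    #         'image_desired_goal': flat_image,
    #         'state_desired_goal': state_vector,
    #     })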
def get_diagnostics(self, paths, **kwargs):
statistics = self.wrapped_env.get_diagnostics(paths, **kwargs)
for stat_name_in_paths in ["vae_mdist", "vae_success", "vae_dist"]:
stats = get_stat_in_paths(paths, 'env_infos', stat_name_in_paths)
statistics.update(create_stats_ordered_dict(
stat_name_in_paths,
stats,
always_show_all_stats=True,
))
final_stats = [s[-1] for s in stats]
statistics.update(create_stats_ordered_dict(
"Final " + stat_name_in_paths,
final_stats,
always_show_all_stats=True,
))
return statistics
"""
Other functions
"""
@property
def goal_sampling_mode(self):
return self._goal_sampling_mode
@goal_sampling_mode.setter
def goal_sampling_mode(self, mode):
assert mode in [
'custom_goal_sampler',
'presampled',
'vae_prior',
'env',
'reset_of_env'
], "Invalid env mode"
self._goal_sampling_mode = mode
if mode == 'custom_goal_sampler':
test_goals = self.custom_goal_sampler(1)
if test_goals is None:
self._goal_sampling_mode = 'vae_prior'
warnings.warn(
"self.goal_sampler returned None. " + \
"Defaulting to vae_prior goal sampling mode"
)
@property
def custom_goal_sampler(self):
return self._custom_goal_sampler
@custom_goal_sampler.setter
def custom_goal_sampler(self, new_custom_goal_sampler):
        assert self.custom_goal_sampler is None, (
            "Cannot override custom goal sampler"
        )
self._custom_goal_sampler = new_custom_goal_sampler
@property
def decode_goals(self):
return self._decode_goals
@decode_goals.setter
def decode_goals(self, _decode_goals):
self._decode_goals = _decode_goals
def get_env_update(self):
"""
For online-parallel. Gets updates to the environment since the last time
the env was serialized.
subprocess_env.update_env(**env.get_env_update())
"""
return dict(
mode_map=self._mode_map,
gpu_info=dict(
use_gpu=ptu._use_gpu,
gpu_id=ptu._gpu_id,
),
vae_state=self.vae.__getstate__(),
)
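    # Illustrative round-trip for online-parallel training; ``subprocess_env`` is
    # a hypothetical copy of this env living in another process, as described in
    # the docstring above:
    #
    #     update = env.get_env_update()
    #     subprocess_env.update_env(**update)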
def update_env(self, mode_map, vae_state, gpu_info):
self._mode_map = mode_map
self.vae.__setstate__(vae_state)
gpu_id = gpu_info['gpu_id']
use_gpu = gpu_info['use_gpu']
ptu.device = torch.device("cuda:" + str(gpu_id) if use_gpu else "cpu")
self.vae.to(ptu.device)
def enable_render(self):
self._decode_goals = True
self.render_goals = True
self.render_rollouts = True
def disable_render(self):
self._decode_goals = False
self.render_goals = False
self.render_rollouts = False
def try_render(self, obs):
        if self.render_rollouts:
            # reshape the flat observation to (channels, imsize, imsize) and
            # transpose it for display with cv2
            img = obs['image_observation'].reshape(
                self.input_channels,
                self.imsize,
                self.imsize,
            ).transpose()
cv2.imshow('env', img)
cv2.waitKey(1)
reconstruction = self._reconstruct_img(obs['image_observation']).transpose()
cv2.imshow('env_reconstruction', reconstruction)
cv2.waitKey(1)
init_img = self._initial_obs['image_observation'].reshape(
self.input_channels,
self.imsize,
self.imsize,
).transpose()
cv2.imshow('initial_state', init_img)
cv2.waitKey(1)
init_reconstruction = self._reconstruct_img(
self._initial_obs['image_observation']
).transpose()
cv2.imshow('init_reconstruction', init_reconstruction)
cv2.waitKey(1)
if self.render_goals:
goal = obs['image_desired_goal'].reshape(
self.input_channels,
self.imsize,
self.imsize,
).transpose()
cv2.imshow('goal', goal)
cv2.waitKey(1)
    def _sample_vae_prior(self, batch_size):
        if self.sample_from_true_prior:
            mu, sigma = 0, 1  # sample from the standard normal prior N(0, 1)
        else:
            # sample from the VAE's fitted latent distribution instead
            mu, sigma = self.vae.dist_mu, self.vae.dist_std
        n = np.random.randn(batch_size, self.representation_size)
        return sigma * n + mu
def _decode(self, latents):
reconstructions, _ = self.vae.decode(ptu.from_numpy(latents))
decoded = ptu.get_numpy(reconstructions)
return decoded
def _encode_one(self, img):
return self._encode(img[None])[0]
def _encode(self, imgs):
latent_distribution_params = self.vae.encode(ptu.from_numpy(imgs))
return ptu.get_numpy(latent_distribution_params[0])
def _reconstruct_img(self, flat_img):
latent_distribution_params = self.vae.encode(ptu.from_numpy(flat_img.reshape(1,-1)))
reconstructions, _ = self.vae.decode(latent_distribution_params[0])
imgs = ptu.get_numpy(reconstructions)
imgs = imgs.reshape(
1, self.input_channels, self.imsize, self.imsize
)
return imgs[0]
def _image_and_proprio_from_decoded(self, decoded):
if decoded is None:
return None, None
        if self.vae_input_key_prefix == 'image_proprio':
            # image_length is expected to be provided by the wrapped env
            # (ProxyEnv forwards unknown attribute lookups to it)
            images = decoded[:, :self.image_length]
            proprio = decoded[:, self.image_length:]
return images, proprio
elif self.vae_input_key_prefix == 'image':
return decoded, None
else:
raise AssertionError("Bad prefix for the vae input key.")
def __getstate__(self):
state = super().__getstate__()
state = copy.copy(state)
state['_custom_goal_sampler'] = None
        warnings.warn('VAEWrappedEnv.custom_goal_sampler is not saved.')
return state
def __setstate__(self, state):
        warnings.warn('VAEWrappedEnv.custom_goal_sampler was not loaded.')
super().__setstate__(state)


def temporary_mode(env, mode, func, args=None, kwargs=None):
if args is None:
args = []
if kwargs is None:
kwargs = {}
cur_mode = env.cur_mode
env.mode(env._mode_map[mode])
return_val = func(*args, **kwargs)
env.mode(cur_mode)
return return_val
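# Illustrative usage of temporary_mode; ``'video_env'`` is a hypothetical key in
# env._mode_map and ``collect_rollout`` a hypothetical callable:
#
#     stats = temporary_mode(env, 'video_env', collect_rollout,
#                            kwargs=dict(horizon=50))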