import numpy as np

import rlkit.torch.pytorch_util as ptu
from multiworld.core.image_env import normalize_image
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.data_management.obs_dict_replay_buffer import flatten_dict
from rlkit.data_management.shared_obs_dict_replay_buffer import \
    SharedObsDictRelabelingBuffer
from rlkit.envs.vae_wrapper import VAEWrappedEnv
from rlkit.torch.vae.vae_trainer import (
    compute_p_x_np_to_np,
    relative_probs_from_log_probs,
)


class OnlineVaeRelabelingBuffer(SharedObsDictRelabelingBuffer):
    """
    Replay buffer that stores decoded (image) observations alongside their
    latent encodings, re-encodes everything whenever the VAE is retrained
    (see refresh_latents), and can skew sampling by a VAE-based priority
    for exploration bonuses, goal relabeling, and VAE training batches.
    """

    def __init__(
            self,
            vae,
            *args,
            decoded_obs_key='image_observation',
            decoded_achieved_goal_key='image_achieved_goal',
            decoded_desired_goal_key='image_desired_goal',
            exploration_rewards_type='None',
            exploration_rewards_scale=1.0,
            vae_priority_type='None',
            start_skew_epoch=0,
            power=1.0,
            internal_keys=None,
            priority_function_kwargs=None,
            relabeling_goal_sampling_mode='vae_prior',
            **kwargs
    ):
        if internal_keys is None:
            internal_keys = []
        # Track the decoded keys internally so the raw images remain
        # available for re-encoding after the VAE is updated.
        for key in [
            decoded_obs_key,
            decoded_achieved_goal_key,
            decoded_desired_goal_key
        ]:
            if key not in internal_keys:
                internal_keys.append(key)
        super().__init__(internal_keys=internal_keys, *args, **kwargs)
        assert isinstance(self.env, VAEWrappedEnv)
        self.vae = vae
        self.decoded_obs_key = decoded_obs_key
        self.decoded_desired_goal_key = decoded_desired_goal_key
        self.decoded_achieved_goal_key = decoded_achieved_goal_key
        self.exploration_rewards_type = exploration_rewards_type
        self.exploration_rewards_scale = exploration_rewards_scale
        self.start_skew_epoch = start_skew_epoch
        self.vae_priority_type = vae_priority_type
        self.power = power
        self._relabeling_goal_sampling_mode = relabeling_goal_sampling_mode

        self._give_explr_reward_bonus = (
            exploration_rewards_type != 'None'
            and exploration_rewards_scale != 0.
        )
        self._exploration_rewards = np.zeros((self.max_size, 1))
        self._prioritize_vae_samples = (
            vae_priority_type != 'None'
            and power != 0.
        )
        self._vae_sample_priorities = np.zeros((self.max_size, 1))
        self._vae_sample_probs = None
        # Skewing stays disabled until refresh_latents() runs past
        # start_skew_epoch; initialize here so that sampling before the
        # first refresh does not raise an AttributeError.
        self.skew = False

        type_to_function = {
            'vae_prob': self.vae_prob,
            'None': self.no_reward,
        }

        self.exploration_reward_func = (
            type_to_function[self.exploration_rewards_type]
        )
        self.vae_prioritization_func = (
            type_to_function[self.vae_priority_type]
        )

        if priority_function_kwargs is None:
            self.priority_function_kwargs = dict()
        else:
            self.priority_function_kwargs = priority_function_kwargs

        self.epoch = 0
        # Both arrays are shared across processes by the parent buffer.
        self._register_mp_array("_exploration_rewards")
        self._register_mp_array("_vae_sample_priorities")
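
    # Note on skewed sampling: with vae_priority_type='vae_prob', each
    # sample's priority is an approximate likelihood p(x) under the VAE,
    # and sampling probabilities are made proportional to priority**power
    # (see refresh_latents). A negative `power` up-weights rare images,
    # Skew-Fit style, while power=0. disables prioritization entirely
    # (see _prioritize_vae_samples above).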

    def add_path(self, path):
        self.add_decoded_vae_goals_to_path(path)
        super().add_path(path)

    def add_decoded_vae_goals_to_path(self, path):
        # Decoding the self-sampled VAE goals is done in batch here rather
        # than per-step in the env for efficiency.
        desired_goals = flatten_dict(
            path['observations'],
            [self.desired_goal_key]
        )[self.desired_goal_key]
        desired_decoded_goals = self.env._decode(desired_goals)
        desired_decoded_goals = desired_decoded_goals.reshape(
            len(desired_decoded_goals),
            -1
        )
        for idx in range(len(path['observations'])):
            path['observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]
            path['next_observations'][idx][self.decoded_desired_goal_key] = \
                desired_decoded_goals[idx]

    def get_diagnostics(self):
        if self._vae_sample_probs is None or self._vae_sample_priorities is None:
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                np.zeros(self._size),
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                np.zeros(self._size),
            ))
        else:
            vae_sample_priorities = self._vae_sample_priorities[:self._size]
            vae_sample_probs = self._vae_sample_probs[:self._size]
            stats = create_stats_ordered_dict(
                'VAE Sample Weights',
                vae_sample_priorities,
            )
            stats.update(create_stats_ordered_dict(
                'VAE Sample Probs',
                vae_sample_probs,
            ))
        return stats
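
    # refresh_latents is expected to run after each VAE training phase:
    # the latents stored at collection time came from a stale encoder, so
    # every image in the buffer is re-encoded in chunks with the current
    # VAE, the VAE's empirical latent statistics (dist_mu/dist_std) are
    # refreshed, and the skew priorities are recomputed.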
    def refresh_latents(self, epoch):
        self.epoch = epoch
        self.skew = (self.epoch > self.start_skew_epoch)
        batch_size = 512
        next_idx = min(batch_size, self._size)
        if self.exploration_rewards_type == 'hash_count':
            # You have to count everything first, then compute exploration
            # rewards. NOTE: this branch is vestigial in this version:
            # 'hash_count' is not in type_to_function, so __init__ would
            # raise a KeyError before ever reaching here.
            cur_idx = 0
            next_idx = min(batch_size, self._size)
            while cur_idx < self._size:
                idxs = np.arange(cur_idx, next_idx)
                normalized_imgs = (
                    normalize_image(self._next_obs[self.decoded_obs_key][idxs])
                )
                cur_idx = next_idx
                next_idx += batch_size
                next_idx = min(next_idx, self._size)
            # Reset the chunk pointer; the loop above leaves next_idx stale.
            next_idx = min(batch_size, self._size)

        cur_idx = 0
        obs_sum = np.zeros(self.vae.representation_size)
        obs_square_sum = np.zeros(self.vae.representation_size)
        while cur_idx < self._size:
            idxs = np.arange(cur_idx, next_idx)
            self._obs[self.observation_key][idxs] = \
                self.env._encode(
                    normalize_image(self._obs[self.decoded_obs_key][idxs])
                )
            self._next_obs[self.observation_key][idxs] = \
                self.env._encode(
                    normalize_image(self._next_obs[self.decoded_obs_key][idxs])
                )
            # WARNING: we only refresh the desired/achieved latents for
            # "next_obs". This means that obs[desired/achieved] will be
            # invalid, so make sure there's no code that references this.
            self._next_obs[self.desired_goal_key][idxs] = \
                self.env._encode(normalize_image(
                    self._next_obs[self.decoded_desired_goal_key][idxs]
                ))
            self._next_obs[self.achieved_goal_key][idxs] = \
                self.env._encode(normalize_image(
                    self._next_obs[self.decoded_achieved_goal_key][idxs]
                ))
            normalized_imgs = (
                normalize_image(self._next_obs[self.decoded_obs_key][idxs])
            )
            if self._give_explr_reward_bonus:
                rewards = self.exploration_reward_func(
                    normalized_imgs,
                    idxs,
                    **self.priority_function_kwargs
                )
                self._exploration_rewards[idxs] = rewards.reshape(-1, 1)
            if self._prioritize_vae_samples:
                if (
                        self.exploration_rewards_type == self.vae_priority_type
                        and self._give_explr_reward_bonus
                ):
                    # The priority was already computed as the exploration
                    # reward; reuse it instead of re-running the VAE.
                    self._vae_sample_priorities[idxs] = (
                        self._exploration_rewards[idxs]
                    )
                else:
                    self._vae_sample_priorities[idxs] = (
                        self.vae_prioritization_func(
                            normalized_imgs,
                            idxs,
                            **self.priority_function_kwargs
                        ).reshape(-1, 1)
                    )
            # Accumulate first and second moments of the latents so the
            # VAE's empirical prior statistics can be updated below.
            obs_sum += self._obs[self.observation_key][idxs].sum(axis=0)
            obs_square_sum += np.power(
                self._obs[self.observation_key][idxs], 2
            ).sum(axis=0)

            cur_idx = next_idx
            next_idx += batch_size
            next_idx = min(next_idx, self._size)

        self.vae.dist_mu = obs_sum / self._size
        self.vae.dist_std = np.sqrt(
            obs_square_sum / self._size - np.power(self.vae.dist_mu, 2)
        )

        if self._prioritize_vae_samples:
            # priority**power is computed inside the priority function (for
            # the image_bernoulli_prob or image_gaussian_inv_prob decoder
            # distributions) and directly here otherwise.
            if self.vae_priority_type == 'vae_prob':
                self._vae_sample_priorities[:self._size] = \
                    relative_probs_from_log_probs(
                        self._vae_sample_priorities[:self._size]
                    )
                self._vae_sample_probs = self._vae_sample_priorities[:self._size]
            else:
                self._vae_sample_probs = (
                    self._vae_sample_priorities[:self._size] ** self.power
                )
            p_sum = np.sum(self._vae_sample_probs)
            assert p_sum > 0, "Unnormalized p sum is {}".format(p_sum)
            self._vae_sample_probs /= np.sum(self._vae_sample_probs)
            self._vae_sample_probs = self._vae_sample_probs.flatten()

    def sample_weighted_indices(self, batch_size):
        if (
                self._prioritize_vae_samples and
                self._vae_sample_probs is not None and
                self.skew
        ):
            # Validate before sampling; np.random.choice would otherwise
            # fail first and the assert would never be reached.
            assert (
                np.max(self._vae_sample_probs) <= 1 and
                np.min(self._vae_sample_probs) >= 0
            )
            indices = np.random.choice(
                len(self._vae_sample_probs),
                batch_size,
                p=self._vae_sample_probs,
            )
        else:
            indices = self._sample_indices(batch_size)
        return indices

    def _sample_goals_from_env(self, batch_size):
        self.env.goal_sampling_mode = self._relabeling_goal_sampling_mode
        return self.env.sample_goals(batch_size)

    def sample_buffer_goals(self, batch_size):
        """
        Samples goals from the weighted replay buffer for relabeling or
        exploration. Returns None if the replay buffer is empty.

        Example of what might be returned:
        dict(
            image_desired_goals: image_achieved_goals[weighted_indices],
            latent_desired_goals: latent_desired_goals[weighted_indices],
        )
        """
        if self._size == 0:
            return None
        weighted_idxs = self.sample_weighted_indices(batch_size)
        next_image_obs = normalize_image(
            self._next_obs[self.decoded_obs_key][weighted_idxs]
        )
        next_latent_obs = self._next_obs[self.achieved_goal_key][weighted_idxs]
        return {
            self.decoded_desired_goal_key: next_image_obs,
            self.desired_goal_key: next_latent_obs,
        }
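
    # Both sample_buffer_goals (goal relabeling) and random_vae_training_data
    # (VAE training batches) draw from the same skewed distribution via
    # sample_weighted_indices, so the skew shapes relabeled goals and the
    # VAE's training data alike.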
    def random_vae_training_data(self, batch_size, epoch):
        # epoch is no longer needed: self.skew is consulted inside
        # sample_weighted_indices instead.
        weighted_idxs = self.sample_weighted_indices(batch_size)

        next_image_obs = normalize_image(
            self._next_obs[self.decoded_obs_key][weighted_idxs]
        )
        return dict(
            next_obs=ptu.from_numpy(next_image_obs)
        )

    def vae_prob(self, next_vae_obs, indices, **kwargs):
        return compute_p_x_np_to_np(
            self.vae,
            next_vae_obs,
            power=self.power,
            **kwargs
        )

    def no_reward(self, next_vae_obs, indices):
        return np.zeros((len(next_vae_obs), 1))

    def _get_sorted_idx_and_train_weights(self):
        idx_and_weights = zip(range(len(self._vae_sample_probs)),
                              self._vae_sample_probs)
        return sorted(idx_and_weights, key=lambda x: x[1])
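

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the skewing math used above, for
# illustration only (not part of rlkit). The synthetic `priorities` stand in
# for the p(x) estimates that vae_prob / compute_p_x_np_to_np would produce.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    priorities = rng.uniform(0.01, 1.0, size=1000)  # fake per-sample p(x)
    power = -0.5  # negative exponent up-weights rare samples
    probs = priorities ** power
    probs /= probs.sum()
    skewed = rng.choice(len(probs), size=10000, p=probs)
    # Skewed sampling should prefer low-probability (rare) samples:
    print("mean p(x), uniform:", priorities.mean())
    print("mean p(x), skewed: ", priorities[skewed].mean())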