import torch
from agent.dreamer import DreamerAgent, ActorCritic, stop_gradient, env_reward
import agent.dreamer_utils as common
import agent.video_utils as video_utils
# NOTE(review): update_imag_behavior resolves reward functions by name through
# globals(); presumably some of them arrive via this wildcard import, so do not
# narrow it without checking the configured `imag_reward_fn` values.
from tools.genrl_utils import *


def connector_update_fn(self, module_name, data, outputs, metrics):
    """Update the connector module, feeding it video embeddings of the batch.

    This function is registered with ``add_module_to_update`` in
    ``GenRLAgent.__init__``. ``self`` is the object that owns the module as
    attribute ``module_name`` (presumably the world model -- confirm against
    the code that invokes the registered update functions).

    Args:
        self: Owner object; must expose ``cfg`` and, when the embeddings are
            not precomputed, ``viclip_model``.
        module_name: Attribute name under which the connector is stored.
        data: Batch dict. ``data['observation']`` is read with leading dims
            ``(B, T)``; ``data['clip_video']`` is read instead of running the
            video encoder when ``cfg.viclip_encode`` is truthy.
        outputs: Dict whose ``'post'`` entry is forwarded to the connector.
        metrics: Unused here; kept for the update-function signature.

    Returns:
        Whatever ``connector.update(video_embed, outputs['post'])`` returns.
    """
    connector = getattr(self, module_name)
    n_frames = connector.n_frames
    frames = data['observation']
    B, T = frames.shape[:2]

    # The video embeddings play the role of actions for the connector.
    if getattr(self.cfg, "viclip_encode", False):
        # Embeddings were precomputed and shipped with the batch.
        video_embed = data['clip_video']
    else:
        # Compute the embeddings on the fly; no gradient flows into the
        # video encoder.
        with torch.no_grad():
            viclip_model = self.viclip_model
            flat_frames = frames.reshape(B * T, *frames.shape[2:]) / 255
            processed_obs = viclip_model.preprocess_transf(flat_frames)
            # Group consecutive frames into clips of `n_frames` frames each.
            clips = processed_obs.reshape(
                B * (T // n_frames), n_frames, 3, 224, 224)
            video_embed = viclip_model.get_vid_features(
                clips.to(viclip_model.device))

    # Posterior states produced by the original world model.
    wm_post = outputs['post']
    return connector.update(video_embed, wm_post)


class GenRLAgent(DreamerAgent):
    """Dreamer agent extended with a video-conditioned connector module.

    On construction a ``video_utils.VideoSSM`` connector is built and
    registered on the world model; an extra "imag" actor-critic is created
    when ``cfg.imag_reward_fn`` is set.
    """

    def __init__(self, **kwargs):
        """Build the base agent, then create and register the connector.

        Args:
            **kwargs: Forwarded unchanged to ``DreamerAgent.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE: both values should become hyperparams if changing the model.
        self.n_frames = 8
        self.viclip_emb_dim = 512

        # Training sequences are split into clips of `n_frames` frames, so
        # the batch length has to be a multiple of it.
        assert self.cfg.batch_length % self.n_frames == 0, "Fix batch length param"

        # Prefer the embedding size declared by the observation space.
        if 'clip_video' in self.obs_space:
            self.viclip_emb_dim = self.obs_space['clip_video'].shape[0]

        connector = video_utils.VideoSSM(
            **self.cfg.connector,
            **self.cfg.connector_rssm,
            connector_kl=self.cfg.connector_kl,
            n_frames=self.n_frames,
            action_dim=self.viclip_emb_dim + self.n_frames,
            clip_add_noise=self.cfg.clip_add_noise,
            clip_lafite_noise=self.cfg.clip_lafite_noise,
            device=self.device,
            cell_input='stoch',
        )
        connector.to(self.device)

        self.wm.add_module_to_update(
            'connector',
            connector,
            connector_update_fn,
            detached=self.cfg.connector.detached_post,
        )

        if getattr(self.cfg, 'imag_reward_fn', None) is not None:
            self.instantiate_imag_behavior()

    def instantiate_imag_behavior(self):
        """Create the "imag" actor-critic and attach its reward normalizer."""
        self._imag_behavior = ActorCritic(
            self.cfg, self.act_spec, self.wm.inp_size, name='imag',
        ).to(self.device)
        self._imag_behavior.rewnorm = common.StreamNorm(
            **self.cfg.imag_reward_norm, device=self.device)

    def finetune_mode(self):
        """Switch to fine-tuning.

        Acts with the "imag" behavior, clears the world model's registered
        update functions (detached and end-to-end) and adds ``'reward'`` to
        the world model's ``grad_heads``.
        """
        self._acting_behavior = self._imag_behavior
        self.wm.detached_update_fns = {}
        self.wm.e2e_update_fns = {}
        self.wm.grad_heads.append('reward')

    def update_wm(self, data, step):
        """Update the world model; plain pass-through to the parent class."""
        return super().update_wm(data, step)

    def report(self, data, key='observation', nvid=8):
        """Extend the parent report with a connector video-prediction clip.

        The first ``n_frames`` steps are reconstructed from the world-model
        posterior; the remaining steps are imagined by the connector from the
        video embeddings, starting from the last posterior state.

        Args:
            data: Batch dict with ``'observation'``, ``'action'`` and
                ``'is_first'`` (plus ``'clip_video'`` when
                ``cfg.viclip_encode`` is truthy).
            key: Decoder output / data entry to visualize.
            nvid: Maximum number of sequences taken from the batch.

        Returns:
            The parent's report dict with ``'video_clip_pred'`` added: ground
            truth, prediction and error concatenated along dim 3.
        """
        wm = self.wm
        n_frames = wm.connector.n_frames

        # Frames after the first clip: these are the ones the connector
        # has to predict.
        obs = data['observation'][:nvid, n_frames:]
        B, T = obs.shape[:2]

        report_data = super().report(data)

        # Initialisation is the same as Dreamer's reporting.
        truth = data[key][:nvid] / 255
        decoder = wm.heads['decoder']  # B, T, C, H, W
        preprocessed_data = wm.preprocess(data)

        embed = wm.encoder(preprocessed_data)
        states, _ = wm.rssm.observe(
            embed[:nvid, :n_frames],
            data['action'][:nvid, :n_frames],
            data['is_first'][:nvid, :n_frames],
        )
        recon = decoder(wm.decoder_input_fn(states))[key].mean[:nvid]  # mode
        # Last posterior state of the observed clip seeds the imagination.
        dreamer_init = {name: value[:, -1] for name, value in states.items()}

        # The video embeddings play the role of actions for the connector.
        if getattr(self.cfg, "viclip_encode", False):
            # Precomputed embeddings: one every `n_frames` steps, starting at
            # the last frame of the second clip.
            video_embed = data['clip_video'][:nvid, n_frames * 2 - 1::n_frames]
        else:
            # Compute the embeddings from the raw frames.
            flat_frames = obs.reshape(B * T, *obs.shape[2:]) / 255
            processed_obs = wm.viclip_model.preprocess_transf(flat_frames)
            clips = processed_obs.reshape(
                B * (T // n_frames), n_frames, 3, 224, 224)
            video_embed = wm.viclip_model.get_vid_features(
                clips.to(wm.viclip_model.device))

        video_embed = video_embed.to(self.device)

        # Repeat each clip embedding for every step of its clip so there is
        # one "action" per time step.
        video_embed = (
            video_embed.reshape(B, T // n_frames, -1)
            .unsqueeze(2)
            .repeat(1, 1, n_frames, 1)
            .reshape(B, T, -1)
        )
        prior = wm.connector.video_imagine(
            video_embed, dreamer_init, reset_every_n_frames=False)
        prior_recon = decoder(wm.decoder_input_fn(prior))[key].mean  # mode

        # Decoder outputs are shifted by +0.5 and clipped to [0, 1] for
        # display (presumably they are centred around zero -- confirm).
        model = torch.clip(
            torch.cat([recon[:, :n_frames] + 0.5, prior_recon + 0.5], 1), 0, 1)
        error = (model - truth + 1) / 2

        # Add video to logs.
        report_data['video_clip_pred'] = torch.cat([truth, model, error], 3)

        return report_data

    def update_imag_behavior(self, state=None, outputs=None, metrics=None, seq_data=None):
        """Train the "imag" behavior from posterior states.

        Args:
            state: Unused; kept for interface compatibility.
            outputs: World-model outputs holding ``'post'`` and
                ``'is_terminal'``. When ``None`` the posterior is recomputed
                from ``seq_data``.
            metrics: Dict updated in place with the behavior metrics. A new
                dict is created when omitted, so results never leak between
                calls through a shared default.
            seq_data: Raw sequence batch, only used when ``outputs`` is
                ``None``.

        Returns:
            ``(start, metrics)`` where ``start`` is the gradient-stopped
            posterior, or ``(outputs['post'], metrics)`` unchanged when no
            ``cfg.imag_reward_fn`` is configured.
        """
        if metrics is None:
            metrics = {}

        if getattr(self.cfg, 'imag_reward_fn', None) is None:
            # NOTE(review): this path indexes `outputs` and therefore still
            # requires the caller to pass it; behaviour kept as in the
            # original.
            return outputs['post'], metrics

        if outputs is not None:
            post = outputs['post']
            is_terminal = outputs['is_terminal']
        else:
            seq_data = self.wm.preprocess(seq_data)
            embed = self.wm.encoder(seq_data)
            post, _ = self.wm.rssm.observe(
                embed, seq_data['action'], seq_data['is_first'])
            is_terminal = seq_data['is_terminal']

        # Behavior learning must not backpropagate into the world model.
        start = {name: stop_gradient(value) for name, value in post.items()}

        def imag_reward_fn(seq):
            # The reward function is selected by name from this module's
            # globals; the name comes from the config.
            reward_fn = globals()[self.cfg.imag_reward_fn]
            return reward_fn(self, seq, **self.cfg.imag_reward_args)

        metrics.update(
            self._imag_behavior.update(self.wm, start, is_terminal, imag_reward_fn))
        return start, metrics