import torch
from agent.dreamer import DreamerAgent, ActorCritic, stop_gradient, env_reward
import agent.dreamer_utils as common
import agent.video_utils as video_utils
from tools.genrl_utils import *
def connector_update_fn(self, module_name, data, outputs, metrics):
    """Run one update step of the video connector stored at ``self.<module_name>``.

    This function is registered on the world model via
    ``add_module_to_update`` (see ``GenRLAgent.__init__``), so ``self`` is
    presumably the world model. It must expose ``module_name`` and ``cfg``,
    plus ``viclip_model`` unless ``cfg.viclip_encode`` is set.

    Args:
        module_name: attribute name of the connector on ``self``; the
            connector must expose ``n_frames`` and ``update``.
        data: batch dict. ``'observation'`` is always read for its leading
            (B, T) dims. When ``cfg.viclip_encode`` is set, ``'clip_video'``
            is passed to the connector as-is. Otherwise the observations are
            divided by 255 and encoded with ViCLIP.
        outputs: world-model outputs; ``outputs['post']`` is forwarded to the
            connector.
        metrics: unused here; kept for the update-function calling convention.

    Returns:
        Whatever ``connector.update(video_embed, wm_post)`` returns.
    """
    connector = getattr(self, module_name)
    n_frames = connector.n_frames
    obs = data['observation']
    B, T = obs.shape[:2]

    # The video embeddings play the role of the connector's actions.
    if getattr(self.cfg, "viclip_encode", False):
        # Embeddings were precomputed and shipped with the batch.
        video_embed = data['clip_video']
    else:
        # Encode on the fly; computed under no_grad so no gradient reaches ViCLIP.
        with torch.no_grad():
            viclip_model = self.viclip_model
            frames = obs.reshape(B * T, *obs.shape[2:]) / 255
            processed_obs = viclip_model.preprocess_transf(frames)
            # Group consecutive frames into clips of n_frames each. This
            # requires T to be a multiple of n_frames (asserted on
            # batch_length in GenRLAgent.__init__).
            reshaped_obs = processed_obs.reshape(B * (T // n_frames), n_frames, 3, 224, 224)
            # NOTE(review): assumes viclip_model accepts inputs on the same
            # device as the observations -- confirm.
            video_embed = viclip_model.get_vid_features(reshaped_obs)

    # Posterior states from the world model, passed alongside the embeddings.
    wm_post = outputs['post']
    return connector.update(video_embed, wm_post)


class GenRLAgent(DreamerAgent):
    """Dreamer agent extended with a video-conditioned connector.

    On top of ``DreamerAgent`` this class:

    * registers a ``video_utils.VideoSSM`` connector on the world model. The
      connector is updated by ``connector_update_fn`` from video embeddings
      and world-model posteriors.
    * optionally builds a second actor-critic (``_imag_behavior``) whose
      reward function is looked up by name from ``cfg.imag_reward_fn``.
    * adds a ``'video_clip_pred'`` video to ``report``.
    """

    def __init__(self, **kwargs):
        # cfg, obs_space, act_spec, device and wm are used throughout this
        # class but never assigned here, so the base initialiser must run first.
        super().__init__(**kwargs)

        self.n_frames = 8  # NOTE: this should become a hyperparameter if changing the model
        self.viclip_emb_dim = 512  # NOTE: this should become a hyperparameter if changing the model

        # Clips of n_frames are cut from each training sequence.
        assert self.cfg.batch_length % self.n_frames == 0, "Fix batch length param"

        # Use the size of precomputed embeddings when they are part of the observations.
        if 'clip_video' in self.obs_space:
            self.viclip_emb_dim = self.obs_space['clip_video'].shape[0]

        # NOTE(review): action_dim is embedding size + n_frames; presumably an
        # extra per-frame code is appended to each embedding -- confirm in VideoSSM.
        connector = video_utils.VideoSSM(
            **self.cfg.connector,
            **self.cfg.connector_rssm,
            connector_kl=self.cfg.connector_kl,
            n_frames=self.n_frames,
            action_dim=self.viclip_emb_dim + self.n_frames,
            clip_add_noise=self.cfg.clip_add_noise,
            clip_lafite_noise=self.cfg.clip_lafite_noise,
            device=self.device,
            cell_input='stoch',
        )
        self.wm.add_module_to_update(
            'connector', connector, connector_update_fn,
            detached=self.cfg.connector.detached_post,
        )

        if getattr(self.cfg, 'imag_reward_fn', None) is not None:
            self.instantiate_imag_behavior()

    def instantiate_imag_behavior(self):
        """Create the imagination actor-critic and attach its reward normaliser."""
        self._imag_behavior = ActorCritic(self.cfg, self.act_spec, self.wm.inp_size, name='imag').to(self.device)
        self._imag_behavior.rewnorm = common.StreamNorm(**self.cfg.imag_reward_norm, device=self.device)

    def finetune_mode(self,):
        """Act with the imagination policy and clear the world model's extra update hooks.

        Requires ``_imag_behavior``, i.e. ``cfg.imag_reward_fn`` must have been
        set at construction time. Both update-function registries are emptied
        (presumably including the connector update registered in ``__init__``).
        """
        self._acting_behavior = self._imag_behavior
        self.wm.detached_update_fns = {}
        self.wm.e2e_update_fns = {}

    def update_wm(self, data, step):
        """Delegate the world-model update to ``DreamerAgent.update_wm``."""
        return super().update_wm(data, step)

    def report(self, data, key='observation', nvid=8):
        """Extend the base report with a connector video-prediction video.

        The first ``n_frames`` steps of the first ``nvid`` sequences are
        filtered by the world model. From the last filtered state, the
        connector imagines the remaining steps, conditioned on video
        embeddings of those steps.

        Args:
            data: batch dict. Reads ``'observation'``, ``key``, ``'action'``
                and ``'is_first'``, plus ``'clip_video'`` when
                ``cfg.viclip_encode`` is set.
            key: data / decoder key to reconstruct.
            nvid: number of sequences to visualise.

        Returns:
            The base report dict with an added ``'video_clip_pred'`` entry.
            That entry concatenates truth, prediction and error along dim 3.
        """
        wm = self.wm
        n_frames = wm.connector.n_frames

        # Steps after the conditioning window: these are what the connector predicts.
        obs = data['observation'][:nvid, n_frames:]
        B, T = obs.shape[:2]

        report_data = super().report(data)

        # Init is same as Dreamer for reporting: filter the first n_frames steps.
        truth = data[key][:nvid] / 255
        decoder = wm.heads['decoder']  # B, T, C, H, W
        preprocessed_data = wm.preprocess(data)
        embed = wm.encoder(preprocessed_data)
        states, _ = wm.rssm.observe(
            embed[:nvid, :n_frames],
            data['action'][:nvid, :n_frames],
            data['is_first'][:nvid, :n_frames],
        )
        recon = decoder(wm.decoder_input_fn(states))[key].mean[:nvid]  # mode
        dreamer_init = {k: v[:, -1] for k, v in states.items()}

        # Video embeddings act as the connector's actions.
        if getattr(self.cfg, "viclip_encode", False):
            # One precomputed embedding per n_frames window of `obs`, taken at
            # each window's last step.
            video_embed = data['clip_video'][:nvid, n_frames * 2 - 1::n_frames]
        else:
            # Encode on the fly; gradients are not needed for reporting.
            with torch.no_grad():
                processed_obs = wm.viclip_model.preprocess_transf(obs.reshape(B * T, *obs.shape[2:]) / 255)
                reshaped_obs = processed_obs.reshape(B * (T // n_frames), n_frames, 3, 224, 224)
                video_embed = wm.viclip_model.get_vid_features(reshaped_obs)
        # Make sure the embeddings live on the agent's device before imagining.
        video_embed = video_embed.to(self.device)

        # Get actions: repeat each window's embedding for every step of the window.
        video_embed = (
            video_embed.reshape(B, T // n_frames, -1)
            .unsqueeze(2)
            .repeat(1, 1, n_frames, 1)
            .reshape(B, T, -1)
        )
        prior = wm.connector.video_imagine(video_embed, dreamer_init, reset_every_n_frames=False)
        prior_recon = decoder(wm.decoder_input_fn(prior))[key].mean  # mode

        # Join filtered and imagined reconstructions along time. They are
        # shifted by +0.5 and clipped to [0, 1]; presumably the decoder
        # predicts zero-centred pixels -- confirm.
        model = torch.clip(torch.cat([recon[:, :n_frames] + 0.5, prior_recon + 0.5], 1), 0, 1)
        error = (model - truth + 1) / 2

        # Add video to logs: truth / model / error concatenated along dim 3.
        video = torch.cat([truth, model, error], 3)
        report_data['video_clip_pred'] = video
        return report_data

    def update_imag_behavior(self, state=None, outputs=None, metrics=None, seq_data=None,):
        """Train the imagination actor-critic starting from world-model posteriors.

        Args:
            state: unused.
            outputs: world-model outputs. When given, ``outputs['post']`` and
                ``outputs['is_terminal']`` are used directly.
            metrics: dict updated in place with the actor-critic metrics. A
                fresh dict is created when omitted.
            seq_data: raw batch used to recompute posteriors when ``outputs``
                is None.

        Returns:
            ``(start, metrics)``, where ``start`` is the gradient-stopped
            posterior used as the imagination start. When
            ``cfg.imag_reward_fn`` is unset, returns
            ``(outputs['post'], metrics)`` without training.
        """
        # Fresh dict per call: a shared default would accumulate metrics across calls.
        if metrics is None:
            metrics = {}
        if getattr(self.cfg, 'imag_reward_fn', None) is None:
            return outputs['post'], metrics

        if outputs is not None:
            post = outputs['post']
            is_terminal = outputs['is_terminal']
        else:
            seq_data = self.wm.preprocess(seq_data)
            embed = self.wm.encoder(seq_data)
            post, _ = self.wm.rssm.observe(embed, seq_data['action'], seq_data['is_first'])
            is_terminal = seq_data['is_terminal']

        start = {k: stop_gradient(v) for k, v in post.items()}

        # cfg.imag_reward_fn names a function in this module's globals,
        # presumably one of the imported reward functions (e.g. env_reward or
        # a tools.genrl_utils export). It is resolved at call time.
        def imag_reward_fn(seq):
            return globals()[self.cfg.imag_reward_fn](self, seq, **self.cfg.imag_reward_args)

        metrics.update(self._imag_behavior.update(self.wm, start, is_terminal, imag_reward_fn))
        return start, metrics