# genrl/agent/video_utils.py

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import agent.dreamer_utils as common
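
# Video-to-latent "connector" utilities: a residual linear block, a small MLP
# denoiser with U-Net-style skip connections, and a VideoSSM that conditions the
# Dreamer RSSM on (ViCLIP) video embeddings treated as actions.
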
class ResidualLinear(nn.Module):
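    """Linear block with a residual connection: Linear -> Norm -> activation
    (or Norm -> Linear -> activation when ``prenorm=True``), with an identity or
    linear projection on the skip path when the channel count changes."""
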
    def __init__(self, in_channels, out_channels, norm='layer', act='SiLU', prenorm=False):
        super().__init__()
        self.norm_layer = common.NormLayer(norm, in_channels if prenorm else out_channels)
        self.act = common.get_act(act)
        self.layer = nn.Linear(in_channels, out_channels)
        self.prenorm = prenorm
        self.res_proj = nn.Identity() if in_channels == out_channels else nn.Linear(in_channels, out_channels)

    def forward(self, x):
        if self.prenorm:
            h = self.norm_layer(x)
            h = self.layer(h)
        else:
            h = self.layer(x)
            h = self.norm_layer(h)
        h = self.act(h)
        return h + self.res_proj(x)

class UNetDenoiser(nn.Module):
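    """MLP denoiser over flat embedding vectors with a U-Net-style layout:
    a stack of down blocks, mid blocks, and up blocks, where each up block
    receives the matching down-block output through a skip concatenation."""
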
    def __init__(self, in_channels: int, mid_channels: int, n_layers: int, norm='layer', act='SiLU'):
        super().__init__()
        out_channels = in_channels

        self.down = nn.ModuleList()
        for i in range(n_layers):
            if i == (n_layers - 1):
                self.down.append(ResidualLinear(in_channels, mid_channels, norm=norm, act=act))
            else:
                self.down.append(ResidualLinear(in_channels, in_channels, norm=norm, act=act))

        self.mid = nn.ModuleList()
        for i in range(n_layers):
            self.mid.append(ResidualLinear(mid_channels, mid_channels, norm=norm, act=act))

        self.up = nn.ModuleList()
        for i in range(n_layers):
            if i == 0:
                self.up.append(ResidualLinear(mid_channels * 2, out_channels, norm='none', act='Identity'))
            else:
                self.up.append(ResidualLinear(out_channels * 2, out_channels, norm=norm, act=act))

    def forward(self, x):
        # Encode, keeping each down-block output for the skip connections
        down_res = []
        for down_layer in self.down:
            x = down_layer(x)
            down_res.append(x)
        for mid_layer in self.mid:
            x = mid_layer(x)
        # Decode, concatenating the matching down-block output at each step
        down_res.reverse()
        for up_layer, res in zip(self.up, down_res):
            x = up_layer(torch.cat([x, res], dim=-1))
        return x

class VideoSSM(common.EnsembleRSSM):
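    """Connector RSSM that treats (ViCLIP) video embeddings as actions.

    Given per-frame video embeddings, it builds "actions" by appending a one-hot
    temporal index, rolls the RSSM prior forward, and trains it to match the
    world model's posterior states, optionally denoising the embeddings with a
    small autoencoder (``UNetDenoiser``) first."""
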
    def __init__(self, *args,
                 connector_kl={}, temporal_embeds=False, detached_post=True, n_frames=8,
                 token_dropout=0., loss_scale=1, clip_add_noise=0, clip_lafite_noise=0,
                 rescale_embeds=False, denoising_ae=False, learn_initial=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_frames = n_frames
        # By default, action_dim includes n_frames extra dims for the one-hot temporal
        # embedding (harmless if unused, and makes it easy to test whether it helps).
        self.viclip_emb_dim = kwargs['action_dim'] - self.n_frames
        self.temporal_embeds = temporal_embeds
        self.detached_post = detached_post
        self.connector_kl = connector_kl
        self.token_dropout = token_dropout
        self.loss_scale = loss_scale
        self.rescale_embeds = rescale_embeds
        self.clip_add_noise = clip_add_noise
        self.clip_lafite_noise = clip_lafite_noise
        self.clip_const = np.sqrt(self.viclip_emb_dim).item()
        self.denoising_ae = denoising_ae
        if self.denoising_ae:
            self.aligner = UNetDenoiser(self.viclip_emb_dim, self.viclip_emb_dim // 2, n_layers=2, norm='layer', act='SiLU')
        self.learn_initial = learn_initial
        if self.learn_initial:
            # Predicts the initial deterministic state from the (action-sized) video embedding
            self.initial_state_pred = nn.Sequential(
                nn.Linear(kwargs['action_dim'], kwargs['hidden']),
                common.NormLayer(kwargs['norm'], kwargs['hidden']), common.get_act('SiLU'),
                nn.Linear(kwargs['hidden'], kwargs['hidden']),
                common.NormLayer(kwargs['norm'], kwargs['hidden']), common.get_act('SiLU'),
                nn.Linear(kwargs['hidden'], kwargs['deter']),
            )
        # Remove the observation heads inherited from EnsembleRSSM; the connector does not use them
        del self._obs_out
        del self._obs_dist

    def initial(self, batch_size, init_embed=None, ignore_learned=False):
        init = super().initial(batch_size)
        if self.learn_initial and not ignore_learned and hasattr(self, 'initial_state_pred'):
            assert init_embed is not None
            # If a raw video embedding is passed (without the temporal one-hot),
            # zero-pad it to the full action width expected by initial_state_pred.
            if init_embed.shape[-1] == self.viclip_emb_dim:
                patcher = torch.zeros((*init_embed.shape[:-1], self.n_frames), device=self.device)
                init_embed = torch.cat([init_embed, patcher], dim=-1)
            init['deter'] = self.initial_state_pred(init_embed)
            stoch, stats = self.get_stoch_stats_from_deter_state(init)
            init['stoch'] = stoch
            init.update(stats)
        return init

    def get_action(self, video_embed):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        if self.rescale_embeds:
            video_embed = video_embed * self.clip_const
        # One-hot position of each frame within its n_frames chunk (zeroed out when temporal embeds are disabled)
        temporal_embeds = F.one_hot(torch.arange(T).to(video_embed.device) % n_frames, n_frames).reshape(1, T, n_frames).repeat(B, 1, 1)
        if not self.temporal_embeds:
            temporal_embeds *= 0
        # Cast to the embedding dtype so concatenation does not rely on implicit type promotion
        return torch.cat([video_embed, temporal_embeds.to(video_embed.dtype)], dim=-1)

    def update(self, video_embed, wm_post):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        loss = 0
        metrics = {}

        # Keep the embedding of the last frame of each n_frames chunk and broadcast it
        # to every frame of that chunk, so all frames in a chunk share one video embedding.
        video_embed = video_embed[:, n_frames - 1::n_frames]
        video_embed = video_embed.to(self.device)
        video_embed = video_embed.reshape(B, T // n_frames, 1, -1).repeat(1, 1, n_frames, 1).reshape(B, T, -1)
        orig_video_embed = video_embed

        # Optional embedding augmentations: additive Gaussian noise and/or LAFITE-style noise mixing
        if self.clip_add_noise > 0:
            video_embed = video_embed + torch.randn_like(video_embed, device=video_embed.device) * self.clip_add_noise
            video_embed = nn.functional.normalize(video_embed, dim=-1)
        if self.clip_lafite_noise > 0:
            normed_noise = F.normalize(torch.randn_like(video_embed, device=video_embed.device), dim=-1)
            video_embed = (1 - self.clip_lafite_noise) * video_embed + self.clip_lafite_noise * normed_noise
            video_embed = nn.functional.normalize(video_embed, dim=-1)

        if self.denoising_ae:
            assert (self.clip_lafite_noise + self.clip_add_noise) > 0, "Nothing to denoise"
            denoised_embed = self.aligner(video_embed)
            denoised_embed = F.normalize(denoised_embed, dim=-1)
            # Cosine distance to the clean embedding (behaves like an MSE on normalized vectors)
            denoising_loss = 1 - F.cosine_similarity(denoised_embed, orig_video_embed, dim=-1).mean()
            loss += denoising_loss
            metrics['aligner_cosine_distance'] = denoising_loss
            # When a denoiser is used, denoising is its job: condition the connector on the clean
            # embedding (denoised_embed could be used instead for end-to-end training).
            video_embed = orig_video_embed

        embed_actions = self.get_action(video_embed)

        if self.detached_post:
            wm_post = {k: v.reshape(B, T, *v.shape[2:]).detach() for k, v in wm_post.items()}
        else:
            wm_post = {k: v.reshape(B, T, *v.shape[2:]) for k, v in wm_post.items()}

        # Roll the prior forward, injecting the world model's posterior stoch from the previous step
        prior_states = defaultdict(list)
        for t in range(T):
            action = embed_actions[:, t]
            if t == 0:
                prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
            else:
                # Deterministic state comes from the prior, stochastic state from the posterior
                prev_state = prior
                prev_state[self.cell_input] = wm_post[self.cell_input][:, t - 1]
            if self.token_dropout > 0:
                # Randomly zero the previous stoch for a fraction of the batch
                prev_state['stoch'] = torch.einsum('b...,b->b...', prev_state['stoch'], (torch.rand(B, device=action.device) > self.token_dropout).float())
            prior = self.img_step(prev_state, action)
            for k in prior:
                prior_states[k].append(prior[k])

        # Aggregate the per-step priors and compute the connector KL loss
        for k in prior_states:
            prior_states[k] = torch.stack(prior_states[k], dim=1)
        prior = prior_states
        kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
        video_loss = self.loss_scale * kl_loss
        metrics['connector_kl'] = kl_value.mean()
        loss += video_loss

        # Initial-state KL: take a single embedding per chunk (frame 0), skipping the first chunk
        video_embed = video_embed.reshape(B, T // n_frames, n_frames, -1)[:, 1:, 0].reshape(B * (T // n_frames - 1), 1, -1)
        embed_actions = self.get_action(video_embed)
        wm_post = {k: v.reshape(B, T // n_frames, n_frames, *v.shape[2:])[:, 1:, 0].reshape(B * (T // n_frames - 1), *v.shape[2:]) for k, v in wm_post.items()}
        action = embed_actions[:, 0]
        prev_state = self.initial(batch_size=wm_post['stoch'].shape[0], init_embed=action)
        prior = self.img_step(prev_state, action)
        kl_loss, kl_value = self.kl_loss(wm_post, prior, **self.connector_kl)
        metrics['connector_initial_kl'] = kl_value.mean()

        return loss, metrics

    def video_imagine(self, video_embed, dreamer_init=None, sample=True, reset_every_n_frames=True, denoise=False):
        n_frames = self.n_frames
        B, T = video_embed.shape[:2]
        if self.denoising_ae and denoise:
            denoised_embed = self.aligner(video_embed)
            video_embed = F.normalize(denoised_embed, dim=-1)
        action = self.get_action(video_embed)

        # Imagine: from dreamer_init (if given), only the stoch part is carried over
        init = self.initial(batch_size=B, init_embed=action[:, 0])
        if dreamer_init is not None:
            init[self.cell_input] = dreamer_init[self.cell_input]

        if reset_every_n_frames:
            prior_states = defaultdict(list)
            for action_chunk in torch.chunk(action, T // n_frames, dim=1):
                prior = self.imagine(action_chunk, init, sample=sample)
                for k in prior:
                    prior_states[k].append(prior[k])
                # Reset the state between chunks, carrying over only the stoch of the last step
                init = self.initial(batch_size=B, ignore_learned=True)
                init[self.cell_input] = prior[self.cell_input][:, -1]
            # Aggregate the per-chunk priors along time
            for k in prior_states:
                prior_states[k] = torch.cat(prior_states[k], dim=1)
            prior = prior_states
        else:
            prior = self.imagine(action, init, sample=sample)
        return prior