import math
import pdb

import torch
from torch import nn
import torch.nn.functional as F
from vit_pytorch import ViT
from tqdm import tqdm

from utils import *
class Sampler:
    """Gaussian diffusion sampler: precomputes the noise schedule and runs the
    forward noising (q_sample), the training loss (p_losses) and the reverse
    sampling loop (p_sample / p_sample_loop) over motion sequences."""

    def __init__(self, device, mask_ind, emb_f, batch_size, seq_len, channel, fix_mode, timesteps, fixed_frame, **kwargs):
        self.device = device
        self.mask_ind = mask_ind
        self.emb_f = emb_f
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.channel = channel
        self.fix_mode = fix_mode
        self.timesteps = timesteps
        self.fixed_frame = fixed_frame
        self.get_scheduler()
    def set_dataset_and_model(self, dataset, model):
        self.dataset = dataset
        if dataset.load_scene:
            self.grid = dataset.create_meshgrid(batch_size=self.batch_size).to(self.device)
        self.model = model
    def get_scheduler(self):
        betas = linear_beta_schedule(timesteps=self.timesteps)
        # define alphas
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, axis=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.betas = betas
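    # Note: these are the standard DDPM schedule quantities (Ho et al., 2020):
    #   alpha_t = 1 - beta_t,   alpha_bar_t = prod_{s<=t} alpha_s,
    #   posterior variance beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).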
    def q_sample(self, x_start, t, noise):
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
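    # q_sample above is the closed-form forward process
    #   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
    # i.e. x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps with eps ~ N(0, I).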
    def p_losses(self, x_start, obj_points, mat, scene_flag, mask, t, action_label, noise=None, loss_type='huber'):
        if noise is None:
            noise = torch.randn_like(x_start)
        noise[mask] = 0.
        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
        if self.dataset.load_scene:
            with torch.no_grad():
                x_orig = transform_points(self.dataset.denormalize_torch(x_noisy), mat)
                mat_for_query = mat.clone()
                target_ind = self.mask_ind if self.mask_ind != -1 else 0
                mat_for_query[:, :3, 3] = x_orig[:, self.emb_f, target_ind * 3: target_ind * 3 + 3]
                mat_for_query[:, 1, 3] = 0
                query_points = transform_points(self.grid, mat_for_query)
                occ = self.dataset.get_occ_for_points(query_points, obj_points, scene_flag)
                nb_voxels = self.dataset.nb_voxels
                occ = occ.reshape(-1, nb_voxels, nb_voxels, nb_voxels).float()
                # import trimesh
                # print(mat[0])
                # grid_np = self.grid[0].detach().cpu().numpy().reshape((-1, 3))
                # occ_np = occ[0].detach().cpu().numpy().reshape((-1))
                # points = grid_np[occ_np > 0.5]
                # pcd_trimesh = trimesh.PointCloud(vertices=points)
                # scene = trimesh.Scene([pcd_trimesh, trimesh.creation.axis(origin_color=[0, 0, 0])])
                # scene.show()
                occ = occ.permute(0, 2, 1, 3)
        else:
            occ = None
        # x_noisy = torch.cat([x_noisy, occ], dim=-1).detach()
        predicted_noise = self.model(x_noisy, occ, t, action_label, mask)
        mask_inv = torch.logical_not(mask)
        if loss_type == 'l1':
            loss = F.l1_loss(noise[mask_inv], predicted_noise[mask_inv])
        elif loss_type == 'l2':
            loss = F.mse_loss(noise[mask_inv], predicted_noise[mask_inv])
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])
        else:
            raise NotImplementedError()
        return loss
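    # Hedged usage sketch (the names `x_start`, `optimizer` etc. are illustrative, not taken
    # from the training script): a typical training step samples one timestep per sequence
    # and backpropagates the returned loss, e.g.
    #   t = torch.randint(0, sampler.timesteps, (batch_size,), device=device, dtype=torch.long)
    #   loss = sampler.p_losses(x_start, obj_points, mat, scene_flag, mask, t, action_label)
    #   loss.backward(); optimizer.step(); optimizer.zero_grad()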
    # Algorithm 2 (including returning all images)
    def p_sample_loop(self, fixed_points, obj_locs, mat, scene, goal, action_label):
        device = next(self.model.parameters()).device
        shape = (self.batch_size, self.seq_len, self.channel)
        points = torch.randn(shape, device=device)  # + torch.tensor([0., 0.3, 0.] * 22, device=device)
        if self.fix_mode:
            self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True, fix_goal=True)
        imgs = []
        occs = []
        if self.dataset.load_scene:
            x_orig = transform_points(self.dataset.denormalize_torch(points), mat)
            mat_for_query = mat.clone()
            target_ind = self.mask_ind if self.mask_ind != -1 else 0
            mat_for_query[:, :3, 3] = x_orig[:, self.emb_f, target_ind * 3: target_ind * 3 + 3]
            mat_for_query[:, 1, 3] = 0
            query_points = transform_points(self.grid, mat_for_query)
            occ = self.dataset.get_occ_for_points(query_points, obj_locs, scene)
            nb_voxels = self.dataset.nb_voxels
            occ = occ.reshape(-1, nb_voxels, nb_voxels, nb_voxels).float()
            # import trimesh
            # print('\n', mat[0])
            # grid_np = self.grid[0].detach().cpu().numpy().reshape((-1, 3))
            # occ_np = occ[0].detach().cpu().numpy().reshape((-1))
            # pointcloud = grid_np[occ_np > 0.5]
            # pcd_trimesh = trimesh.PointCloud(vertices=pointcloud)
            # np.save('/home/jiangnan/SyntheticHSI/Paper/Teaser/occ.npy', pointcloud)
            # scene = trimesh.Scene([pcd_trimesh, trimesh.creation.axis(origin_color=[0, 0, 0])])
            # scene.show()
            occ = occ.permute(0, 2, 1, 3)
        else:
            occ = None
        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            model_used = self.model
            # if s < 3 or (s == 3 and i < 5) or i < 3:
            #     model_used = model_fix
            # else:
            #     model_used = model
            points, occ = self.p_sample(model_used, points, fixed_points, goal, None, mat, occ,
                                        torch.full((self.batch_size,), i, device=device, dtype=torch.long), i,
                                        action_label, self.mask_ind, self.emb_f, self.fix_mode)
            if self.fix_mode:
                self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True, fix_goal=True)
            # set_fixed_points(points, goal, mat, joint_id=mask_ind)
            # # t = torch.ones(b, device=device, dtype=torch.int64) * i
            # if fixed_points is not None:
            #     points[:, :fixed_points.shape[1], :] = fixed_points  # q_sample(fixed_points, t, None, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod)
            points_orig = transform_points(self.dataset.denormalize_torch(points), mat)
            imgs.append(points_orig)
            if occ is not None:
                occs.append(occ.cpu().numpy())
        return imgs, occs
    def p_sample(self, model, x, fixed_points, goal, obj_points, mat, occ, t, t_index, action_label, mask_ind, emb_f,
                 fix_mode, no_scene=False):
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        # joints_orig = transform_points(synhsi_dataset.denormalize_torch(x), mat)
        # occ = synhsi_dataset.get_occ_for_points(joints_orig, obj_points, scene)
        # x_occ = torch.cat([x, occ], dim=-1).detach()
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, occ, t, action_label, mask=None) / sqrt_one_minus_alphas_cumprod_t
        )
        # model_mean_noact = sqrt_recip_alphas_t * (
        #     x - betas_t * model(x, occ, t, action_label, mask=None, no_action=True) / sqrt_one_minus_alphas_cumprod_t
        # )
        # model_mean = model_mean_noact + (model_mean - model_mean_noact) * 10
        if not fix_mode:
            self.set_fixed_points(model_mean, goal, fixed_points, mat, joint_id=mask_ind, fix_mode=True, fix_goal=False)
        if t_index == 0:
            return model_mean, occ
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise, occ
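    # The mean computed above is the DDPM posterior mean (Eq. 11 in Ho et al., 2020):
    #   mu_theta(x_t, t) = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t)),
    # and for t > 0 a sample is drawn as mu_theta + sqrt(beta_tilde_t) * z with z ~ N(0, I).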
    def set_fixed_points(self, img, goal, fixed_points, mat, joint_id=0, fix_mode=False, fix_goal=True):
        # if joint_id != 0:
        #     goal_len = 2
        goal_len = goal.shape[1]
        # goal_batch = goal.reshape(1, 1, 3).repeat(img.shape[0], 1, 1)
        goal = self.dataset.normalize_torch(transform_points(goal, torch.inverse(mat)))
        # img[:, -1, joint_id * 3: joint_id * 3 + 3] = goal_batch[:, 0]
        if fix_goal:
            img[:, -goal_len:, joint_id * 3] = goal[:, :, 0]
            if joint_id != 0:
                img[:, -goal_len:, joint_id * 3 + 1] = goal[:, :, 1]
            img[:, -goal_len:, joint_id * 3 + 2] = goal[:, :, 2]
        if fixed_points is not None and fix_mode:
            img[:, :fixed_points.shape[1], :] = fixed_points
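    # Hedged usage sketch (argument names follow the signatures above; tensors are assumed to
    # already live on `sampler.device`):
    #   sampler.set_dataset_and_model(dataset, model)
    #   imgs, occs = sampler.p_sample_loop(fixed_points, obj_locs, mat, scene, goal, action_label)
    #   final_motion = imgs[-1]  # denormalized poses after the last denoising step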
class Unet(nn.Module):
    """Transformer-based noise predictor: embeds the noisy motion sequence together with
    timestep, scene (ViT over the occupancy grid) and action conditioning tokens, and
    predicts the noise added at the given diffusion step."""

    def __init__(
        self,
        dim_model,
        num_heads,
        num_layers,
        dropout_p,
        dim_input,
        dim_output,
        nb_voxels=None,
        free_p=0.1,
        nb_actions=0,
        ac_type='',
        no_scene=False,
        no_action=False,
        **kwargs
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.nb_actions = nb_actions
        self.ac_type = ac_type
        self.no_scene = no_scene
        self.no_action = no_action

        # LAYERS
        if not no_scene:
            self.scene_embedding = ViT(
                image_size=nb_voxels,
                patch_size=nb_voxels // 4,
                channels=nb_voxels,
                num_classes=dim_model,
                dim=1024,
                depth=6,
                heads=16,
                mlp_dim=2048,
                dropout=0.1,
                emb_dropout=0.1
            )
        self.free_p = free_p
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding_input = nn.Linear(dim_input, dim_model)
        self.embedding_output = nn.Linear(dim_output, dim_model)
        # self.embedding_action = nn.Parameter(torch.randn(16, dim_model))
        if not no_action and nb_actions > 0:
            if self.ac_type in ['last_add_first_token', 'last_new_token']:
                self.embedding_action = ActionTransformerEncoder(action_number=nb_actions,
                                                                 dim_model=dim_model,
                                                                 nhead=num_heads // 2,
                                                                 num_layers=num_layers,
                                                                 dim_feedforward=dim_model,
                                                                 dropout_p=dropout_p,
                                                                 activation="gelu")
            elif self.ac_type in ['all_add_token']:
                self.embedding_action = nn.Sequential(
                    nn.Linear(nb_actions, dim_model),
                    nn.SiLU(inplace=False),
                    nn.Linear(dim_model, dim_model),
                )
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
                                                   nhead=num_heads,
                                                   dim_feedforward=dim_model,
                                                   dropout=dropout_p,
                                                   activation="gelu")
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=num_layers)
        self.out = nn.Linear(dim_model, dim_output)
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)
    def forward(self, x, cond, timesteps, action, mask, no_action=None):
        # TODO ActionFlag
        # action[action[:, 0] != 0., 0] = 1.
        t_emb = self.embed_timestep(timesteps)  # [b, 1, d]
        if self.no_scene:
            scene_emb = torch.zeros_like(t_emb)
        else:
            scene_emb = self.scene_embedding(cond).reshape(-1, 1, self.dim_model)
        if self.no_action or self.nb_actions == 0:
            action_emb = torch.zeros_like(t_emb)
        else:
            if self.ac_type in ['all_add_token']:
                action_emb = self.embedding_action(action)
            elif self.ac_type in ['last_add_first_token', 'last_new_token']:
                action_emb = self.embedding_action(action)
            else:
                raise NotImplementedError
        t_emb = t_emb.permute(1, 0, 2)
        # classifier-free guidance style dropout: randomly zero the conditioning embeddings
        free_ind = torch.rand(scene_emb.shape[0]).to(scene_emb.device) < self.free_p
        scene_emb[free_ind] = 0.
        # if mask is not None:
        #     x[free_ind][:, mask[0]] = 0.
        if self.ac_type in ['last_add_first_token', 'last_new_token']:
            action_emb[free_ind] = 0.
        scene_emb = scene_emb.permute(1, 0, 2)
        action_emb = action_emb.permute(1, 0, 2)
        if self.ac_type in ['all_add_token', 'last_new_token']:
            emb = t_emb + scene_emb
        elif self.ac_type in ['last_add_first_token']:
            emb = t_emb + scene_emb + action_emb
        x = x.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)
        if self.ac_type in ['all_add_token', 'last_add_first_token']:
            x = torch.cat((emb, x), dim=0)
        elif self.ac_type in ['last_new_token']:
            x = torch.cat((emb, action_emb, x), dim=0)
        if self.ac_type in ['all_add_token']:
            x[1:] = x[1:] + action_emb
        x = self.positional_encoder(x)
        x = self.transformer(x)
        if self.ac_type in ['all_add_token', 'last_add_first_token']:
            output = self.out(x)[1:]
        elif self.ac_type in ['last_new_token']:
            output = self.out(x)[2:]
        output = output.permute(1, 0, 2)
        return output
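    # Conditioning modes selected by `ac_type` (as implemented in forward above):
    #   'all_add_token'        - one prefix token (timestep + scene); action embedding added to every motion token
    #   'last_add_first_token' - action embedding folded into the single prefix conditioning token
    #   'last_new_token'       - action embedding appended as a second prefix token, stripped from the output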
class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)

        # Info
        self.dropout = nn.Dropout(dropout_p)

        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)  # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)  # 1 / 10000^(2i/dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)

        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(inplace=False),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pos_encoding[timesteps])  # .permute(1, 0, 2)
class ActionTransformerEncoder(nn.Module):
    def __init__(self,
                 action_number,
                 dim_model,
                 nhead,
                 num_layers,
                 dim_feedforward,
                 dropout_p,
                 activation="gelu") -> None:
        super().__init__()
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.input_embedder = nn.Linear(action_number, dim_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
                                                   nhead=nhead,
                                                   dim_feedforward=dim_feedforward,
                                                   dropout=dropout_p,
                                                   activation=activation)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer,
                                                         num_layers=num_layers)

    def forward(self, x):
        x = x.permute(1, 0, 2)
        x = self.input_embedder(x)
        x = self.positional_encoder(x)
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2)
        x = torch.mean(x, dim=1, keepdim=True)
        return x
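# Hedged usage sketch: a minimal smoke test of the Unet denoiser interface. The
# hyper-parameters, shapes and the `__main__` guard below are illustrative assumptions
# (e.g. dim_input = dim_output = 22 joints * 3 = 66), not values taken from the training
# configuration. The scene branch is disabled so no occupancy grid is needed; running it
# still assumes the module-level imports (utils, vit_pytorch) resolve.
if __name__ == "__main__":
    batch_size, seq_len, channel, nb_actions = 2, 60, 66, 4

    model = Unet(
        dim_model=256,
        num_heads=4,
        num_layers=4,
        dropout_p=0.1,
        dim_input=channel,
        dim_output=channel,
        nb_actions=nb_actions,
        ac_type='all_add_token',
        no_scene=True,  # skip the ViT scene encoder for this sketch
    )

    x_noisy = torch.randn(batch_size, seq_len, channel)          # noised joint sequences
    t = torch.randint(0, 1000, (batch_size,), dtype=torch.long)  # diffusion timesteps
    action = torch.randn(batch_size, seq_len, nb_actions)        # dummy per-frame action features
    predicted_noise = model(x_noisy, None, t, action, mask=None)
    print(predicted_noise.shape)  # expected: (batch_size, seq_len, channel)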