trumans / models /synhsi.py
jnnan's picture
Upload 68 files
aeba71c verified
raw
history blame
18.3 kB
import math
import pdb
import torch
from torch import nn
import torch.nn.functional as F
from vit_pytorch import ViT
from tqdm import tqdm
from utils import *
class Sampler:
    """Gaussian-diffusion helper: noise schedule, training loss and reverse sampling.

    The schedule tensors are built once in ``get_scheduler`` from
    ``linear_beta_schedule`` (provided by ``utils``).  A dataset and a noise
    prediction model must be attached with ``set_dataset_and_model`` before
    ``p_losses`` or ``p_sample_loop`` are used.

    Attributes set in ``__init__``:
        device: device the occupancy query grid is moved to.
        mask_ind: index of the joint used for goal fixing and for centring the
            occupancy grid (``-1`` is treated as joint ``0`` for the grid).
        emb_f: frame index whose joint position centres the occupancy grid.
        batch_size, seq_len, channel: shape of the sampled tensor.
        fix_mode: when truthy, goal and fixed frames are re-imposed on the
            sample before and after every reverse step.
        timesteps: number of diffusion steps.
        fixed_frame: stored but not read inside this class.
    """

    def __init__(self, device, mask_ind, emb_f, batch_size, seq_len, channel, fix_mode, timesteps, fixed_frame,
                 **kwargs):
        self.device = device
        self.mask_ind = mask_ind
        self.emb_f = emb_f
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.channel = channel
        self.fix_mode = fix_mode
        self.timesteps = timesteps
        self.fixed_frame = fixed_frame
        self.get_scheduler()

    def set_dataset_and_model(self, dataset, model):
        """Attach the dataset and the noise-prediction model.

        When ``dataset.load_scene`` is truthy the occupancy query grid is
        created once here and moved to ``self.device``.
        """
        self.dataset = dataset
        if dataset.load_scene:
            self.grid = dataset.create_meshgrid(batch_size=self.batch_size).to(self.device)
        self.model = model

    def get_scheduler(self):
        """Precompute the per-timestep schedule tensors used for noising and sampling."""
        betas = linear_beta_schedule(timesteps=self.timesteps)

        # define alphas
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted by one step, with 1.0 standing in for "step -1".
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.betas = betas

    def q_sample(self, x_start, t, noise):
        """Return ``x_start`` noised to timestep ``t``.

        Computes ``sqrt(alphas_cumprod_t) * x_start + sqrt(1 - alphas_cumprod_t) * noise``.
        Fresh Gaussian noise is drawn when ``noise`` is ``None``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    @torch.no_grad()
    def _query_occupancy(self, points, mat, obj_points, scene):
        """Build the occupancy volume the model is conditioned on.

        The query grid is placed at the position of joint ``mask_ind`` (joint 0
        when ``mask_ind == -1``) in frame ``emb_f`` of ``points``, with the
        second translation component forced to 0, and the dataset is asked
        for occupancy at the resulting grid points.

        Returns a float tensor reshaped to
        ``(-1, nb_voxels, nb_voxels, nb_voxels)`` with axes 1 and 2 swapped.
        """
        points_world = transform_points(self.dataset.denormalize_torch(points), mat)
        mat_for_query = mat.clone()
        target_ind = self.mask_ind if self.mask_ind != -1 else 0
        mat_for_query[:, :3, 3] = points_world[:, self.emb_f, target_ind * 3: target_ind * 3 + 3]
        mat_for_query[:, 1, 3] = 0
        query_points = transform_points(self.grid, mat_for_query)
        occ = self.dataset.get_occ_for_points(query_points, obj_points, scene)
        nb_voxels = self.dataset.nb_voxels
        occ = occ.reshape(-1, nb_voxels, nb_voxels, nb_voxels).float()
        return occ.permute(0, 2, 1, 3)

    def p_losses(self, x_start, obj_points, mat, scene_flag, mask, t, action_label, noise=None, loss_type='huber'):
        """Compute the noise-prediction training loss.

        Args:
            x_start: clean sample to be noised.
            obj_points, scene_flag: forwarded to ``dataset.get_occ_for_points``.
            mat: transform used to bring the noised points into the frame in
                which occupancy is queried.
            mask: boolean index; noise is zeroed where it is True and the loss
                is evaluated only where it is False.
            t: timestep indices.
            action_label: forwarded to the model.
            noise: optional noise tensor.  NOTE: it is zeroed in place at
                ``mask``, so a caller-supplied tensor is modified.
            loss_type: ``'l1'``, ``'l2'`` or ``'huber'``.

        Raises:
            NotImplementedError: for any other ``loss_type``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        # Masked entries receive no noise, so they stay a pure rescaling of x_start.
        noise[mask] = 0.

        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

        if self.dataset.load_scene:
            occ = self._query_occupancy(x_noisy, mat, obj_points, scene_flag)
        else:
            occ = None

        predicted_noise = self.model(x_noisy, occ, t, action_label, mask)

        mask_inv = torch.logical_not(mask)
        if loss_type == 'l1':
            loss = F.l1_loss(noise[mask_inv], predicted_noise[mask_inv])
        elif loss_type == 'l2':
            loss = F.mse_loss(noise[mask_inv], predicted_noise[mask_inv])
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise[mask_inv], predicted_noise[mask_inv])
        else:
            raise NotImplementedError()

        return loss

    @torch.no_grad()
    def p_sample_loop(self, fixed_points, obj_locs, mat, scene, goal, action_label):
        """Run the full reverse process starting from Gaussian noise.

        Returns:
            ``(imgs, occs)``: ``imgs`` holds the denormalized, ``mat``-transformed
            sample after every reverse step (last entry is the final sample);
            ``occs`` holds the occupancy volume as a NumPy array once per step
            (empty when the dataset has no scene).  The occupancy volume is
            computed once from the initial noise and is not updated between
            steps.
        """
        device = next(self.model.parameters()).device
        shape = (self.batch_size, self.seq_len, self.channel)
        points = torch.randn(shape, device=device)

        if self.fix_mode:
            self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True,
                                  fix_goal=True)

        imgs = []
        occs = []

        if self.dataset.load_scene:
            occ = self._query_occupancy(points, mat, obj_locs, scene)
        else:
            occ = None

        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            points, occ = self.p_sample(self.model, points, fixed_points, goal, None, mat, occ,
                                        torch.full((self.batch_size,), i, device=device, dtype=torch.long), i,
                                        action_label, self.mask_ind, self.emb_f, self.fix_mode)
            if self.fix_mode:
                # Re-impose the constraints that the stochastic step may have perturbed.
                self.set_fixed_points(points, goal, fixed_points, mat, joint_id=self.mask_ind, fix_mode=True,
                                      fix_goal=True)

            points_orig = transform_points(self.dataset.denormalize_torch(points), mat)
            imgs.append(points_orig)
            if occ is not None:
                occs.append(occ.cpu().numpy())

        return imgs, occs

    @torch.no_grad()
    def p_sample(self, model, x, fixed_points, goal, obj_points, mat, occ, t, t_index, action_label, mask_ind, emb_f,
                 fix_mode, no_scene=False):
        """Perform one reverse diffusion step.

        ``obj_points``, ``emb_f`` and ``no_scene`` are accepted for interface
        compatibility but are not read here.

        Returns:
            ``(x_prev, occ)`` where ``occ`` is passed through unchanged.  At
            ``t_index == 0`` the predicted mean is returned without added noise.
        """
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * model(x, occ, t, action_label, mask=None) / sqrt_one_minus_alphas_cumprod_t
        )

        if not fix_mode:
            # Outside fix mode only the leading fixed frames are written back
            # into the predicted mean; the goal is left untouched.
            self.set_fixed_points(model_mean, goal, fixed_points, mat, joint_id=mask_ind, fix_mode=True,
                                  fix_goal=False)

        if t_index == 0:
            return model_mean, occ

        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise, occ

    def set_fixed_points(self, img, goal, fixed_points, mat, joint_id=0, fix_mode=False, fix_goal=True):
        """Overwrite parts of ``img`` in place with known values.

        Args:
            img: sample tensor, modified in place.
            goal: goal positions; its second dimension gives how many trailing
                frames of ``img`` are constrained.  Only read when ``fix_goal``.
            fixed_points: optional tensor copied over the leading frames of
                ``img`` when ``fix_mode`` is truthy.
            mat: transform whose inverse maps ``goal`` into the frame of ``img``
                before normalization.
            joint_id: joint whose channels receive the goal.  For joint 0 only
                the first and third coordinates are written; for any other
                joint all three are.
            fix_mode: enable writing ``fixed_points``.
            fix_goal: enable writing the goal.
        """
        if fix_goal:
            # The inverse transform and normalization are only needed on this path.
            goal_len = goal.shape[1]
            goal = self.dataset.normalize_torch(transform_points(goal, torch.inverse(mat)))
            img[:, -goal_len:, joint_id * 3] = goal[:, :, 0]
            if joint_id != 0:
                img[:, -goal_len:, joint_id * 3 + 1] = goal[:, :, 1]
            img[:, -goal_len:, joint_id * 3 + 2] = goal[:, :, 2]

        if fixed_points is not None and fix_mode:
            img[:, :fixed_points.shape[1], :] = fixed_points
class Unet(nn.Module):
    """Transformer-encoder noise predictor conditioned on timestep, scene and action.

    Despite the class name, the network built here is a stack of
    ``nn.TransformerEncoderLayer``s (``model_type`` is ``"Transformer"``).

    Args:
        dim_model: width of every token embedding.
        num_heads: attention heads of the main transformer (the action
            encoder uses ``num_heads // 2``).
        num_layers: number of encoder layers.
        dropout_p: dropout probability for the transformer and positional encoder.
        dim_input: feature size of each input frame.
        dim_output: feature size of each output frame.
        nb_voxels: side length of the occupancy volume; required unless ``no_scene``.
        free_p: probability with which a batch element's scene (and, for the
            ``last_*`` action types, action) embedding is zeroed in ``forward``.
        nb_actions: size of the action vector; ``0`` disables action conditioning.
        ac_type: how the action embedding is injected, one of
            ``'all_add_token'``, ``'last_add_first_token'``, ``'last_new_token'``.
        no_scene: skip the scene encoder and use a zero scene embedding.
        no_action: use a zero action embedding.
    """

    # ac_type values for which forward() knows how to assemble the token sequence.
    _SUPPORTED_AC_TYPES = ('all_add_token', 'last_add_first_token', 'last_new_token')

    def __init__(
        self,
        dim_model,
        num_heads,
        num_layers,
        dropout_p,
        dim_input,
        dim_output,
        nb_voxels=None,
        free_p=0.1,
        nb_actions=0,
        ac_type='',
        no_scene=False,
        no_action=False,
        **kwargs
    ):
        super().__init__()

        # INFO
        self.model_type = "Transformer"
        self.dim_model = dim_model
        self.nb_actions = nb_actions
        self.ac_type = ac_type
        self.no_scene = no_scene
        self.no_action = no_action

        # LAYERS
        if not no_scene:
            # The occupancy volume is fed to a 2-D ViT with one voxel axis
            # passed as the channel dimension.
            self.scene_embedding = ViT(
                image_size=nb_voxels,
                patch_size=nb_voxels // 4,
                channels=nb_voxels,
                num_classes=dim_model,
                dim=1024,
                depth=6,
                heads=16,
                mlp_dim=2048,
                dropout=0.1,
                emb_dropout=0.1
            )
        self.free_p = free_p

        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.embedding_input = nn.Linear(dim_input, dim_model)
        # Not used in forward(); kept so the module's parameter set is unchanged.
        self.embedding_output = nn.Linear(dim_output, dim_model)

        if not no_action and nb_actions > 0:
            if self.ac_type in ['last_add_first_token', 'last_new_token']:
                self.embedding_action = ActionTransformerEncoder(action_number=nb_actions,
                                                                 dim_model=dim_model,
                                                                 nhead=num_heads // 2,
                                                                 num_layers=num_layers,
                                                                 dim_feedforward=dim_model,
                                                                 dropout_p=dropout_p,
                                                                 activation="gelu")
            elif self.ac_type in ['all_add_token']:
                self.embedding_action = nn.Sequential(
                    nn.Linear(nb_actions, dim_model),
                    nn.SiLU(inplace=False),
                    nn.Linear(dim_model, dim_model),
                )

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim_model,
                                                   nhead=num_heads,
                                                   dim_feedforward=dim_model,
                                                   dropout=dropout_p,
                                                   activation="gelu")
        self.transformer = nn.TransformerEncoder(encoder_layer,
                                                 num_layers=num_layers
                                                 )

        self.out = nn.Linear(dim_model, dim_output)

        # Timestep embedding reuses the sinusoidal table of the positional encoder.
        self.embed_timestep = TimestepEmbedder(self.dim_model, self.positional_encoder)

    def forward(self, x, cond, timesteps, action, mask, no_action=None):
        """Predict the noise for a batch of noised sequences.

        Args:
            x: noised input, batch-first; permuted to sequence-first internally.
            cond: scene input for the ViT; ignored when ``no_scene`` is set.
            timesteps: diffusion step indices used to look up the sinusoidal table.
            action: action input; ignored when action conditioning is disabled.
            mask: accepted for interface compatibility, not read here.
            no_action: accepted for interface compatibility, not read here.

        Returns:
            Tensor in batch-first layout with the conditioning tokens removed.

        Raises:
            NotImplementedError: if ``ac_type`` is not one of the supported values.
        """
        # Fail early with a clear error: an unsupported ac_type would otherwise
        # skip every branch below and crash on an unbound local.
        if self.ac_type not in self._SUPPORTED_AC_TYPES:
            raise NotImplementedError(
                f"Unsupported ac_type {self.ac_type!r}; expected one of {self._SUPPORTED_AC_TYPES}"
            )

        # The lookup table has shape (max_len, 1, d), so indexing with a 1-D
        # batch of timesteps yields (b, 1, d) -- assumes timesteps is 1-D.
        t_emb = self.embed_timestep(timesteps)

        if self.no_scene:
            scene_emb = torch.zeros_like(t_emb)
        else:
            scene_emb = self.scene_embedding(cond).reshape(-1, 1, self.dim_model)

        if self.no_action or self.nb_actions == 0:
            action_emb = torch.zeros_like(t_emb)
        else:
            action_emb = self.embedding_action(action)

        t_emb = t_emb.permute(1, 0, 2)

        # Randomly drop the conditioning for a fraction free_p of the batch.
        # NOTE(review): this is not gated on self.training, so it also happens
        # at sampling time -- confirm that is intended.
        free_ind = torch.rand(scene_emb.shape[0]).to(scene_emb.device) < self.free_p
        scene_emb[free_ind] = 0.
        if self.ac_type in ['last_add_first_token', 'last_new_token']:
            action_emb[free_ind] = 0.

        scene_emb = scene_emb.permute(1, 0, 2)
        action_emb = action_emb.permute(1, 0, 2)

        # Conditioning token: the action is folded in only for 'last_add_first_token'.
        if self.ac_type == 'last_add_first_token':
            emb = t_emb + scene_emb + action_emb
        else:
            emb = t_emb + scene_emb

        x = x.permute(1, 0, 2)
        x = self.embedding_input(x) * math.sqrt(self.dim_model)

        # 'last_new_token' prepends two tokens (condition, action); the other
        # modes prepend one.
        if self.ac_type == 'last_new_token':
            nb_prefix = 2
            x = torch.cat((emb, action_emb, x), dim=0)
        else:
            nb_prefix = 1
            x = torch.cat((emb, x), dim=0)

        if self.ac_type == 'all_add_token':
            # Add the action embedding to every sequence token.
            x[1:] = x[1:] + action_emb

        x = self.positional_encoder(x)
        x = self.transformer(x)

        # Strip the conditioning tokens and return to batch-first layout.
        output = self.out(x)[nb_prefix:]
        output = output.permute(1, 0, 2)
        return output
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for sequence-first inputs, followed by dropout.

    Modified version from:
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    The table is stored in the ``pos_encoding`` buffer with shape
    ``(max_len, 1, dim_model)``.

    Args:
        dim_model: embedding width; both even and odd values are supported.
        dropout_p: dropout probability applied after adding the encoding.
        max_len: number of positions in the table.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        # Column vector of positions 0 .. max_len - 1.
        positions_list = torch.arange(0, max_len, dtype=torch.float).reshape(-1, 1)
        # 1 / 10000^(2i / dim_model), one entry per sine column.
        division_term = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        angles = positions_list * division_term

        pos_encoding = torch.zeros(max_len, dim_model)
        # PE(pos, 2i) = sin(pos / 10000^(2i / dim_model))
        pos_encoding[:, 0::2] = torch.sin(angles)
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i / dim_model))
        # An odd dim_model has one fewer cosine column than sine columns, so
        # the frequency list is trimmed to match.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :dim_model // 2])

        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Add the encoding for the first ``token_embedding.size(0)`` positions, then apply dropout."""
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])
class TimestepEmbedder(nn.Module):
    """Map diffusion timestep indices to embeddings.

    Rows of the ``pos_encoding`` table held by ``sequence_pos_encoder`` are
    selected by timestep and passed through a Linear -> SiLU -> Linear stack
    whose width equals ``latent_dim`` throughout.
    """

    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        width = self.latent_dim
        layers = [
            nn.Linear(self.latent_dim, width),
            nn.SiLU(inplace=False),
            nn.Linear(width, width),
        ]
        self.time_embed = nn.Sequential(*layers)

    def forward(self, timesteps):
        """Return the embedded table rows selected by ``timesteps``."""
        table = self.sequence_pos_encoder.pos_encoding
        selected = table[timesteps]
        return self.time_embed(selected)
class ActionTransformerEncoder(nn.Module):
    """Encode a batch-first action sequence into a single averaged embedding.

    The input is linearly embedded, given sinusoidal positions, run through
    a transformer encoder and mean-pooled over the sequence axis, so the
    result keeps a singleton sequence dimension.
    """

    def __init__(self,
                 action_number,
                 dim_model,
                 nhead,
                 num_layers,
                 dim_feedforward,
                 dropout_p,
                 activation="gelu") -> None:
        super().__init__()
        self.positional_encoder = PositionalEncoding(
            dim_model=dim_model, dropout_p=dropout_p, max_len=5000
        )
        self.input_embedder = nn.Linear(action_number, dim_model)

        layer = nn.TransformerEncoderLayer(
            d_model=dim_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout_p,
            activation=activation,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        """Return the sequence-averaged encoding of ``x`` with the pooled axis kept as size 1."""
        # Swap the first two axes: the encoder stack works sequence-first.
        seq_first = x.permute(1, 0, 2)
        embedded = self.input_embedder(seq_first)
        encoded = self.transformer_encoder(self.positional_encoder(embedded))
        # Swap back, then average over the sequence axis.
        batch_first = encoded.permute(1, 0, 2)
        return batch_first.mean(dim=1, keepdim=True)