import torch import hydra import numpy as np from einops import rearrange import random import os def seed_everything(seed: int): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True def transform_points(x, mat): shape = x.shape x = rearrange(x, 'b t (j c) -> b (t j) c', c=3) # B x N x 3 x = torch.einsum('bpc,bck->bpk', mat[:, :3, :3], x.permute(0, 2, 1)) # B x 3 x N N x B x 3 x = x.permute(2, 0, 1) + mat[:, :3, 3] x = x.permute(1, 0, 2) x = x.reshape(shape) return x def create_meshgrid(bbox, size, batch_size=1): x = torch.linspace(bbox[0], bbox[1], size[0]) y = torch.linspace(bbox[2], bbox[3], size[1]) z = torch.linspace(bbox[4], bbox[5], size[2]) xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij') grid = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3) grid = grid.repeat(batch_size, 1, 1) # aug_z = 0.75 + torch.rand(batch_size, 1) * 0.35 # grid[:, :, 2] = grid[:, :, 2] * aug_z return grid def zup_to_yup(coord): # change the coordinate from yup to zup if len(coord.shape) > 1: coord = coord[..., [0, 2, 1]] coord[..., 2] *= -1 else: coord = coord[[0, 2, 1]] coord[2] *= -1 return coord def rigid_transform_3D(A, B, scale=False): assert len(A) == len(B) N = A.shape[0] # total points centroid_A = np.mean(A, axis=0) centroid_B = np.mean(B, axis=0) # center the points AA = A - np.tile(centroid_A, (N, 1)) BB = B - np.tile(centroid_B, (N, 1)) # dot is matrix multiplication for array if scale: H = np.transpose(BB) * AA / N else: H = np.transpose(BB) * AA U, S, Vt = np.linalg.svd(H) R = Vt.T * U.T # special reflection case if np.linalg.det(R) < 0: print("Reflection detected") # return None, None, None Vt[2, :] *= -1 R = Vt.T * U.T if scale: varA = np.var(A, axis=0).sum() c = 1 / (1 / varA * np.sum(S)) # scale factor t = -R * (centroid_B.T * c) + centroid_A.T else: c = 1 t = -R * centroid_B.T + centroid_A.T return c, R, t def find_free_port(): from contextlib import closing import socket with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(('', 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return str(s.getsockname()[1]) def extract(a, t, x_shape): batch_size = t.shape[0] out = a.gather(-1, t.cpu()) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device) def linear_beta_schedule(timesteps): beta_start = 0.0001 beta_end = 0.02 return torch.linspace(beta_start, beta_end, timesteps) def init_model(model_cfg, device, eval, load_state_dict=False): model = hydra.utils.instantiate(model_cfg) if eval: load_state_dict_eval(model, model_cfg.ckpt, device=device) else: model = model.to(device) model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], broadcast_buffers=False, find_unused_parameters=True) if load_state_dict: model.module.load_state_dict(torch.load(model_cfg.ckpt)) model.train() return model def load_state_dict_eval(model, state_dict_path, map_location='cuda:0', device='cuda'): state_dict = torch.load(state_dict_path, map_location=map_location) key_list = [key for key in state_dict.keys()] for old_key in key_list: new_key = old_key.replace('module.', '') state_dict[new_key] = state_dict.pop(old_key) model.load_state_dict(state_dict) model.to(device) model.eval() class dotDict(dict): """dot.notation access to dictionary attributes""" def __getattr__(*args): val = dict.get(*args) return dotDict(val) if type(val) is dict else val __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__