|
import numpy as np
|
|
import torch
|
|
import random
|
|
|
|
|
|
|
|
def perspective(fovx=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 float32 perspective projection matrix.

    Args:
        fovx: field of view in radians; the x scale is ``1 / tan(fovx / 2)``.
        aspect: multiplies the y scale, which is ``-aspect / tan(fovx / 2)``
            (note the negative sign on the y row).
        n: near plane distance.
        f: far plane distance.
        device: torch device for the returned tensor.

    Returns:
        A ``(4, 4)`` ``torch.float32`` tensor.
    """
    half_tan = np.tan(fovx / 2)
    depth_span = f - n
    rows = [
        [1 / half_tan, 0, 0, 0],
        [0, -aspect / half_tan, 0, 0],
        [0, 0, -(f + n) / depth_span, -(2 * f * n) / depth_span],
        [0, 0, -1, 0],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
|
|
|
def translate(x, y, z, device=None):
    """Return a 4x4 float32 homogeneous translation matrix by ``(x, y, z)``.

    The upper-left 3x3 block is the identity and the offsets occupy the last
    column of the first three rows.
    """
    rows = [[float(r == col) for col in range(4)] for r in range(4)]
    for r, offset in enumerate((x, y, z)):
        rows[r][3] = offset
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
|
|
|
def rotate_x(a, device=None):
    """Return a 4x4 float32 homogeneous rotation of ``a`` radians about the x axis."""
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = (
        (1, 0, 0, 0),
        (0, cos_a, -sin_a, 0),
        (0, sin_a, cos_a, 0),
        (0, 0, 0, 1),
    )
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
|
|
|
def rotate_y(a, device=None):
    """Return a 4x4 float32 homogeneous rotation of ``a`` radians about the y axis."""
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = (
        (cos_a, 0, sin_a, 0),
        (0, 1, 0, 0),
        (-sin_a, 0, cos_a, 0),
        (0, 0, 0, 1),
    )
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
|
|
|
def rotate_z(a, device=None):
    """Return a 4x4 float32 homogeneous rotation of ``a`` radians about the z axis."""
    sin_a = np.sin(a)
    cos_a = np.cos(a)
    rows = (
        (cos_a, -sin_a, 0, 0),
        (sin_a, cos_a, 0, 0),
        (0, 0, 1, 0),
        (0, 0, 0, 1),
    )
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
|
@torch.no_grad()
def batch_random_rotation_translation(b, t, device=None):
    """Sample ``b`` random 4x4 rigid transforms as a float32 tensor of shape ``(b, 4, 4)``.

    The rotation rows are built from Gaussian noise: row 1 becomes
    ``row0 x row2`` and row 2 becomes ``row0 x row1``, which makes the three
    rows mutually orthogonal before they are normalised to unit length. The
    translation column is drawn uniformly from ``[-t, t)``. Uses the global
    NumPy RNG: first ``normal`` for the rotations, then ``uniform`` for the
    translations.
    """
    rot = np.random.normal(size=[b, 3, 3])
    rot[:, 1] = np.cross(rot[:, 0], rot[:, 2])
    rot[:, 2] = np.cross(rot[:, 0], rot[:, 1])
    rot /= np.linalg.norm(rot, axis=2, keepdims=True)

    mats = np.zeros([b, 4, 4])
    mats[:, :3, :3] = rot
    mats[:, 3, 3] = 1.0
    mats[:, :3, 3] = np.random.uniform(-t, t, size=[b, 3])
    return torch.tensor(mats, dtype=torch.float32, device=device)
|
|
|
|
@torch.no_grad()
def random_rotation_translation(t, device=None):
    """Sample one random 4x4 rigid transform as a float32 tensor.

    The rotation rows come from Gaussian noise made mutually orthogonal via
    cross products (row1 = row0 x row2, then row2 = row0 x row1) and then
    normalised. The translation column is uniform in ``[-t, t)``. Uses the
    global NumPy RNG (``normal`` first, then ``uniform``).
    """
    basis = np.random.normal(size=[3, 3])
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    basis /= np.linalg.norm(basis, axis=1, keepdims=True)

    mat = np.zeros([4, 4])
    mat[:3, :3] = basis
    mat[3, 3] = 1.0
    mat[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(mat, dtype=torch.float32, device=device)
|
|
|
|
|
|
@torch.no_grad()
def random_rotation(device=None):
    """Sample one random 4x4 homogeneous rotation (zero translation) as float32.

    The rotation rows are built from Gaussian noise made mutually orthogonal
    via cross products and normalised to unit length. Uses the global NumPy
    RNG.
    """
    basis = np.random.normal(size=[3, 3])
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    basis /= np.linalg.norm(basis, axis=1, keepdims=True)

    # np.eye(4) already supplies the zero translation column and the 1 at [3, 3].
    mat = np.eye(4)
    mat[:3, :3] = basis
    return torch.tensor(mat, dtype=torch.float32, device=device)
|
|
|
|
|
|
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Dot product over the last axis, keeping that axis with size 1."""
    product = x * y
    return product.sum(dim=-1, keepdim=True)
|
|
|
|
|
|
def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Euclidean norm over the last axis (kept with size 1).

    The squared norm is clamped from below at ``eps`` before the square root,
    so zero vectors yield ``sqrt(eps)`` instead of 0.
    """
    squared = (x * x).sum(dim=-1, keepdim=True)
    return squared.clamp(min=eps).sqrt()
|
|
|
|
|
|
def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
    """Scale ``x`` to unit length along its last axis.

    The denominator is ``sqrt(max(|x|^2, eps))``, so zero vectors stay zero
    instead of producing NaN.
    """
    squared = (x * x).sum(dim=-1, keepdim=True)
    norm = squared.clamp(min=eps).sqrt()
    return x / norm
|
|
|
|
|
|
def lr_schedule(iter, warmup_iter, scheduler_decay):
    """Learning-rate multiplier: linear warmup followed by exponential decay.

    For ``iter < warmup_iter`` the factor ramps linearly as
    ``iter / warmup_iter``. Afterwards it is
    ``10 ** (-(iter - warmup_iter) * scheduler_decay)``, floored at 0.0.
    """
    if iter < warmup_iter:
        return iter / warmup_iter
    exponent = -(iter - warmup_iter) * scheduler_decay
    return max(0.0, 10 ** exponent)
|
|
|
|
|
|
def trans_depth(depth):
    """Convert the first entry of a depth tensor into an 8-bit image.

    Pixels with depth > 0 are treated as valid. They are shifted so that their
    minimum becomes 0 and scaled so that their maximum becomes 255. Non-valid
    pixels (<= 0 or NaN) are left untouched before the final cast.

    Args:
        depth: tensor indexed as ``depth[0]``; the remaining shape is kept.

    Returns:
        A ``uint8`` NumPy array with the shape of ``depth[0]``. If no pixel is
        valid the values are cast unchanged. If all valid pixels share one
        depth they map to 0.
    """
    # .copy(): for a CPU tensor, .numpy() shares memory with the tensor, so the
    # in-place edits below would otherwise overwrite the caller's data.
    depth = depth[0].detach().cpu().numpy().copy()
    valid = depth > 0
    if not valid.any():
        # Nothing to normalise; .min() on an empty selection would raise.
        return depth.astype('uint8')
    values = depth[valid]
    values = values - values.min()
    peak = values.max()
    if peak > 0:
        # Guard against 0/0 (NaN) when all valid depths are equal.
        values = values / peak * 255
    depth[valid] = values
    return depth.astype('uint8')
|
|
|
|
|
|
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):
    """Replace NaN with 0 and clamp every value into ``[neginf, posinf]``.

    ``posinf``/``neginf`` default to the largest/smallest finite value of the
    input's floating dtype, so +/-inf are clamped to those limits. Only
    ``nan == 0`` is supported (asserted). ``out`` is forwarded to
    ``torch.clamp``.
    """
    assert isinstance(input, torch.Tensor)
    # finfo is only consulted when a default limit is needed.
    upper = torch.finfo(input.dtype).max if posinf is None else posinf
    lower = torch.finfo(input.dtype).min if neginf is None else neginf
    assert nan == 0
    # nansum over a singleton leading axis is an identity that turns NaN into 0.
    without_nans = input.unsqueeze(0).nansum(0)
    return torch.clamp(without_nans, min=lower, max=upper, out=out)
|
|
|
|
|
|
def load_item(filepath):
    """Return the set of whitespace-stripped lines in the text file ``filepath``."""
    with open(filepath, 'r') as handle:
        return {line.strip() for line in handle}
|
|
|
|
def load_prompt(filepath):
    """Parse ``key,value`` lines from ``filepath`` into a dict.

    The key is everything before the first comma, taken verbatim (it is not
    stripped). The value is everything after that comma, which may itself
    contain commas, with surrounding whitespace stripped. A line without a
    comma maps its raw text (newline included) to ``''``. Later duplicate keys
    overwrite earlier ones.
    """
    mapping = {}
    with open(filepath, 'r') as handle:
        for raw in handle:
            key, _, remainder = raw.partition(',')
            mapping[key] = remainder.strip()
    return mapping
|
|
|
|
def resize_and_center_image(image_tensor, scale=0.95, c = 0, shift = 0, rgb=False, aug_shift = 0):
    """Shrink a batch of images by ``scale`` and paste them onto a constant canvas.

    Args:
        image_tensor: 4-D tensor unpacked as ``(B, C, H, W)``.
        scale: resize factor. When exactly 1 the input tensor itself is
            returned and no shift or jitter is applied. Values > 1 are not
            supported (the paste window would not fit the canvas).
        c: fill value for the canvas around the pasted image.
        shift: if non-zero, each image is pasted at a random integer offset in
            ``[-shift, shift]`` per axis from the centred position (drawn with
            the ``random`` module). Offsets are clamped so the image stays
            inside the canvas.
        rgb: when truthy, batch entries 0, 2 and 4 are never shifted. Their
            offsets are still drawn, so RNG consumption is unchanged.
        aug_shift: if non-zero, each (image, channel) plane gets an additive
            random offset in ``[-aug_shift/255, aug_shift/255)``.

    Returns:
        A new tensor with the shape, dtype and device of ``image_tensor``.
    """
    if scale == 1:
        return image_tensor
    B, C, H, W = image_tensor.shape
    new_H, new_W = int(H * scale), int(W * scale)
    # Keep the batch dimension. A `.squeeze(0)` here collapses it when B == 1,
    # after which `resized_image[i]` would select channel 0, not image 0.
    resized_image = torch.nn.functional.interpolate(image_tensor, size=(new_H, new_W), mode='bilinear', align_corners=False)
    background = torch.zeros_like(image_tensor) + c
    start_y, start_x = (H - new_H) // 2, (W - new_W) // 2
    if shift == 0:
        background[:, :, start_y:start_y + new_H, start_x:start_x + new_W] = resized_image
    else:
        for i in range(B):
            randx = random.randint(-shift, shift)
            randy = random.randint(-shift, shift)
            if rgb and i in (0, 2, 4):
                randx = 0
                randy = 0
            # Clamp the window into the canvas. A negative start would wrap
            # around and an end past H/W would be truncated, and either makes
            # the slice shape disagree with the pasted image.
            y0 = min(max(start_y + randy, 0), H - new_H)
            x0 = min(max(start_x + randx, 0), W - new_W)
            background[i, :, y0:y0 + new_H, x0:x0 + new_W] = resized_image[i]
    if aug_shift == 0:
        return background
    for i in range(B):
        for j in range(C):
            background[i, j, :, :] += (random.random() - 0.5) * 2 * aug_shift / 255
    return background
|
|
|
|
def get_tri(triview_color, dim = 1, blender=True, c = 0, scale=0.95, shift = 0, fix = False, rgb=False, aug_shift = 0):
    """Resize six views, re-orient each one, and tile them into a 2x3 grid.

    Args:
        triview_color: 4-D tensor (``resize_and_center_image`` unpacks it as
            ``(B, C, H, W)``). Entries 0-5 along dim 0 are read, so at least
            six views are required. When ``fix`` is set, channels 0-2 are
            indexed, so C >= 3 in that case.
        dim: axis used to join the two rows of three views. Each view is 3-D
            ``(C, H, W)`` after indexing, so ``dim=1`` stacks the rows along
            height.
        blender: only the literal ``False`` (checked with ``is False``) selects
            the first view ordering/orientation; any other value, including
            ``None`` or ``0``, selects the second. Presumably the second one
            matches Blender-rendered inputs; TODO confirm.
        c, scale, shift, rgb, aug_shift: forwarded to
            ``resize_and_center_image``.
        fix: when ``== True``, channels are zeroed per view: views 0 and 3
            keep only channel 0, views 1 and 4 keep only channel 2, and
            views 2 and 5 keep only channel 1. This uses ``x * 0``, so NaN
            values stay NaN.

    Returns:
        The tiled tensor: rows ``(view0, view1, view2)`` and
        ``(view3, view4, view5)`` are each concatenated along dim 2 (width),
        and the two rows are then concatenated along ``dim``.
        NOTE(review): odd-``k`` rot90 swaps H and W, so the concatenation
        presumably requires square views (H == W); confirm with callers.
    """
    triview_color = resize_and_center_image(triview_color, scale=scale, c = c, shift=shift,rgb=rgb, aug_shift = aug_shift)
    if blender is False:
        # Layout for non-Blender inputs: source-index permutation 0,4,5,3,1,2.
        triview_color0 = torch.rot90(triview_color[0],k=2,dims=[1,2])
        triview_color1 = torch.rot90(triview_color[4],k=1,dims=[1,2]).flip(2).flip(1)
        triview_color2 = torch.rot90(triview_color[5],k=1,dims=[1,2]).flip(2)
        triview_color3 = torch.rot90(triview_color[3],k=2,dims=[1,2]).flip(2)
        triview_color4 = torch.rot90(triview_color[1],k=3,dims=[1,2]).flip(1)
        triview_color5 = torch.rot90(triview_color[2],k=3,dims=[1,2]).flip(1).flip(2)
    else:
        # Layout for the default (blender) case: source-index permutation 2,4,0,5,1,3.
        triview_color0 = torch.rot90(triview_color[2],k=2,dims=[1,2])
        # NOTE: k=0 rot90 leaves the orientation unchanged; only the flips act.
        triview_color1 = torch.rot90(triview_color[4],k=0,dims=[1,2]).flip(2).flip(1)
        triview_color2 = torch.rot90(torch.rot90(triview_color[0],k=3,dims=[1,2]).flip(2), k=2,dims=[1,2])
        triview_color3 = torch.rot90(torch.rot90(triview_color[5],k=2,dims=[1,2]).flip(2), k=2,dims=[1,2])
        # NOTE: the two consecutive .flip(1) calls cancel each other out.
        triview_color4 = torch.rot90(triview_color[1],k=2,dims=[1,2]).flip(1).flip(1).flip(2)
        triview_color5 = torch.rot90(triview_color[3],k=1,dims=[1,2]).flip(1).flip(2)
    if fix == True:
        # Views 0 and 3: zero channels 1 and 2 (keep channel 0).
        triview_color0[1] = triview_color0[1] * 0
        triview_color0[2] = triview_color0[2] * 0
        triview_color3[1] = triview_color3[1] * 0
        triview_color3[2] = triview_color3[2] * 0

        # Views 1 and 4: zero channels 0 and 1 (keep channel 2).
        triview_color1[0] = triview_color1[0] * 0
        triview_color1[1] = triview_color1[1] * 0
        triview_color4[0] = triview_color4[0] * 0
        triview_color4[1] = triview_color4[1] * 0

        # Views 2 and 5: zero channels 0 and 2 (keep channel 1).
        triview_color2[0] = triview_color2[0] * 0
        triview_color2[2] = triview_color2[2] * 0
        triview_color5[0] = triview_color5[0] * 0
        triview_color5[2] = triview_color5[2] * 0
    # Two rows of three views, each row joined along width (dim 2 of a (C, H, W) view).
    color_tensor1_gt = torch.cat((triview_color0, triview_color1, triview_color2), dim=2)
    color_tensor2_gt = torch.cat((triview_color3, triview_color4, triview_color5), dim=2)
    color_tensor_gt = torch.cat((color_tensor1_gt, color_tensor2_gt), dim = dim)
    return color_tensor_gt
|
|
|
|
|