# salad-demo/salad/spaghetti/ui/occ_inference.py
# DveloperY0115's picture
# init repo
# 801501a
from ..custom_types import *
from ..options import Options
from ..utils import train_utils, mcubes_meshing, files_utils, mesh_utils
from ..models.occ_gmm import Spaghetti
from ..models import models_utils
class Inference:
    """Inference-time helper around a trained ``Spaghetti`` model.

    The model and its options are loaded through ``train_utils.model_lc``.
    The class turns latent codes into meshes (via
    ``mcubes_meshing.MarchingCubesMeshing``), exports results under
    ``opt.cp_folder`` and computes per-face attention labels.
    """

    def __init__(self, opt: Options):
        """Load the model described by ``opt`` and prepare the mesher.

        ``train_utils.model_lc`` returns the model together with the options
        object that is kept on the instance (it replaces the ``opt`` passed in).
        """
        model: Tuple[Spaghetti, Options] = train_utils.model_lc(opt)
        self.model, self.opt = model
        self.model.eval()
        # Latents set by ``set_items`` and consumed by ``get_mesh_from_mid``.
        self.mid: Optional[T] = None
        self.gmms: Optional[TN] = None
        self.meshing = mcubes_meshing.MarchingCubesMeshing(self.device, scale=1.)

    @property
    def device(self):
        """Device taken from the loaded options."""
        return self.opt.device

    def set_items(self, items: T):
        """Store ``items`` (moved to the model device) for ``get_mesh_from_mid``."""
        self.mid = items.to(self.device)

    def get_occ_fun(self, z: T):
        """Return a closure that evaluates the occupancy network for latent ``z``.

        A 2-D ``z`` gets a leading batch axis. The returned function adds a
        batch axis to its query points, runs the occupancy network and maps
        the result to ``2 * sigmoid(out) - 1``.
        """
        if z.dim() == 2:
            z = z.unsqueeze(0)

        def forward(x: T) -> T:
            x = x.unsqueeze(0)
            out = self.model.occupancy_network(x, z)[0, :]
            out = 2 * out.sigmoid_() - 1
            return out

        return forward

    @models_utils.torch_no_grad
    def get_mesh(self, z: T, res: int) -> Optional[T_Mesh]:
        """Mesh the occupancy field of latent ``z`` at resolution ``res``.

        Callers in this class treat a ``None`` result as "no mesh produced".
        """
        return self.meshing.occ_meshing(self.get_occ_fun(z), res=res)

    def plot_occ(self, z: Union[T, TS], z_base, gmms: Optional[TS], fixed_items: T,
                 folder_name: str, res=200, verbose=False):
        """Mesh every latent in ``z`` and export the results.

        For each index ``i`` with a non-``None`` mesh, the mesh, the matching
        ``z_base[i]`` (detached, on CPU) and, if ``gmms`` is given, the GMM are
        written under ``{opt.cp_folder}/{folder_name}/occ/`` using the
        zero-padded value of ``fixed_items[i]`` as the file stem.
        """
        out_dir = f'{self.opt.cp_folder}/{folder_name}/occ'
        total = len(z)
        for i in range(total):
            mesh = self.get_mesh(z[i], res)
            name = f'{fixed_items[i]:04d}'
            if mesh is not None:
                target = f'{out_dir}/{name}'
                files_utils.export_mesh(mesh, target)
                files_utils.save_pickle(z_base[i].detach().cpu(), target)
                if gmms is not None:
                    files_utils.export_gmm(gmms, i, target)
            if verbose:
                print(f'done {i + 1:d}/{total:d}')

    def load_file(self, info_path, disclude: Optional[List[int]] = None):
        """Load a saved selection and gather the matching latents and GMMs.

        ``info_path`` is joined with ``''.join`` before loading, so it is
        expected to be a sequence of path fragments. The loaded dict must
        provide ``'ids'`` (key -> gaussian indices) and ``'gmm'``.

        :param disclude: gaussian indices to drop from every entry; the
            filtered lists are written back into the loaded ``'ids'`` dict.
        :return: ``(zh_, gmms, split, ids)`` where ``zh_`` holds the selected
            latents concatenated with a leading batch axis, ``gmms`` the GMM
            items masked along axis 2, ``split`` one label tensor per key
            (labels start at 1), and ``ids`` the (possibly filtered) mapping.
        """
        # NOTE(review): load_pickle presumably unpickles from disk - only use
        # with trusted files.
        info = files_utils.load_pickle(''.join(info_path))
        ids = info['ids']
        keys = list(ids.keys())
        # String keys carry the item index after the first underscore.
        items = [int(key.split('_')[1]) if isinstance(key, str) else key for key in keys]
        items = torch.tensor(items, dtype=torch.int64, device=self.device)
        zh, _, _, _ = self.model.get_embeddings(items)
        gmm_items = list(info['gmm'])
        zh_ = []
        split = []
        gmm_mask = torch.ones(gmm_items[0].shape[2], dtype=torch.bool)
        counter = 0
        for i, key in enumerate(keys):
            gaussian_inds = ids[key]
            if disclude is not None:
                # Mask positions follow the order in which the entries are visited.
                for j, ind in enumerate(gaussian_inds):
                    gmm_mask[j + counter] = ind not in disclude
                counter += len(gaussian_inds)
                gaussian_inds = [ind for ind in gaussian_inds if ind not in disclude]
                ids[key] = gaussian_inds
            num_selected = len(gaussian_inds)
            gaussian_inds = torch.tensor(gaussian_inds, dtype=torch.int64)
            zh_.append(zh[i, gaussian_inds])
            split.append(len(split) + torch.ones(num_selected, dtype=torch.int64, device=self.device))
        zh_ = torch.cat(zh_, dim=0).unsqueeze(0).to(self.device)
        gmms = [item[:, :, gmm_mask].to(self.device) for item in gmm_items]
        return zh_, gmms, split, ids

    @models_utils.torch_no_grad
    def get_z_from_file(self, info_path):
        """Load a selection file and run it through the model's mixing steps.

        :return: ``(zh, zh_, gmms, split)`` - the output of the affine
            transformer, its input (after ``merge_zh_step_a``), the GMMs, and
            the per-key labels concatenated into one tensor.
        """
        zh_, gmms, split, _ = self.load_file(info_path)
        zh_ = self.model.merge_zh_step_a(zh_, [gmms])
        zh, _ = self.model.affine_transformer.forward_with_attention(zh_)
        return zh, zh_, gmms, torch.cat(split)

    def plot_from_info(self, info_path, res):
        """Build the mesh for a selection file plus its per-face labels.

        :return: ``(mesh, attention)``; ``attention`` is ``None`` when no mesh
            was produced.
        """
        zh, zh_, gmms, split = self.get_z_from_file(info_path)
        # get_mesh only accepts (z, res); passing gmms as well raised TypeError.
        mesh = self.get_mesh(zh[0], res)
        if mesh is None:
            return mesh, None
        attention = self.get_attention_faces(mesh, zh, fixed_z=split)
        return mesh, attention

    @staticmethod
    def combine_and_pad(zh_a: T, zh_b: T) -> Tuple[T, TN]:
        """Stack two latent batches along dim 0, zero-padding the shorter on dim 1.

        :return: ``(stacked, mask)``. ``mask`` is ``None`` when both inputs
            already have the same length on dim 1; otherwise it is a
            ``(2, max_len)`` bool tensor that is ``True`` at padded positions.
        """
        len_a, len_b = zh_a.shape[1], zh_b.shape[1]
        if len_a == len_b:
            return torch.cat((zh_a, zh_b), dim=0), None
        mask = torch.zeros(2, max(len_a, len_b), device=zh_a.device, dtype=torch.bool)
        # Match the dtype of the latents so the concatenation does not promote.
        padding = torch.zeros(1, abs(len_a - len_b), zh_a.shape[-1],
                              device=zh_a.device, dtype=zh_a.dtype)
        if len_a > len_b:
            mask[1, len_b:] = True
            zh_b = torch.cat((zh_b, padding), dim=1)
        else:
            mask[0, len_a:] = True
            zh_a = torch.cat((zh_a, padding), dim=1)
        return torch.cat((zh_a, zh_b), dim=0), mask

    @staticmethod
    def get_intersection_z(z_a: T, z_b: T) -> T:
        """Flag the rows of ``z_a[0]`` and ``z_b[0]`` that have a close match in the other.

        Two rows match when their L1 distance is below ``0.1``. The two flag
        vectors are padded with ``False`` to a common length and concatenated
        (``z_a`` flags first).
        """
        diff = (z_a[0, :, None, :] - z_b[0, None]).abs().sum(-1)
        diff_a = diff.min(1)[0].lt(.1)
        diff_b = diff.min(0)[0].lt(.1)
        gap = diff_a.shape[0] - diff_b.shape[0]
        if gap != 0:
            padding = torch.zeros(abs(gap), device=z_a.device, dtype=torch.bool)
            if gap > 0:
                diff_b = torch.cat((diff_b, padding))
            else:
                diff_a = torch.cat((diff_a, padding))
        return torch.cat((diff_a, diff_b))

    def get_attention_points(self, vs: T, zh: T, mask: TN = None, alpha: TN = None):
        """Return, for every point in ``vs``, the index with the highest attention.

        The outputs of ``occupancy_network.forward_attention`` are stacked,
        averaged over the stack axis and over the last axis, flattened per
        point, and reduced with ``argmax``.
        """
        vs = vs.unsqueeze(0)
        attention = self.model.occupancy_network.forward_attention(vs, zh, mask=mask, alpha=alpha)
        attention = torch.stack(attention, 0).mean(0).mean(-1)
        # ``attention.shape[1]`` is read from the tensor before the permute.
        attention = attention.permute(1, 0, 2).reshape(attention.shape[1], -1)
        return attention.argmax(-1)

    @models_utils.torch_no_grad
    def get_attention_faces(self, mesh: T_Mesh, zh: T, mask: TN = None, fixed_z: TN = None, alpha: TN = None):
        """Label every face of ``mesh`` by its highest-attention index.

        Attention is queried at ``mesh[0][mesh[1]].mean(1)`` (presumably the
        face centroids, assuming ``mesh`` is ``(vertices, faces)`` - confirm).
        When ``fixed_z`` is given, the indices are mapped through it and the
        result is moved to the CPU.
        """
        coords = mesh[0][mesh[1]].mean(1).to(zh.device)
        attention_max = self.get_attention_points(coords, zh, mask, alpha)
        if fixed_z is None:
            return attention_max
        return fixed_z[attention_max].cpu()

    @models_utils.torch_no_grad
    def plot_folder(self, *folders, res: int = 256):
        """Export a mesh and per-face labels for every ``.pkl`` in ``folders``.

        Outputs are written to ``{opt.cp_folder}/from_ui/{name}`` (mesh) and
        ``{opt.cp_folder}/from_ui/{name}_faces`` (labels).
        """
        logger = train_utils.Logger()
        for folder in folders:
            paths = files_utils.collect(folder, '.pkl')
            logger.start(len(paths))
            for path in paths:
                name = path[1]
                out_path = f"{self.opt.cp_folder}/from_ui/{name}"
                mesh, colors = self.plot_from_info(path, res)
                if mesh is not None:
                    files_utils.export_mesh(mesh, out_path)
                    files_utils.export_list(colors.tolist(), f"{out_path}_faces")
                logger.reset_iter()
            logger.stop()

    def get_zh_from_idx(self, items: T):
        """Return the merged latents and the GMMs for the given item indices."""
        zh, _, gmms, _ = self.model.get_embeddings(items.to(self.device))
        zh, _ = self.model.merge_zh(zh, gmms)
        return zh, gmms

    def get_new_ids(self, folder_name, nums_sample):
        """Return ``nums_sample`` consecutive ids that are not yet in use.

        Ids start after the largest numeric ``.obj`` stem already present in
        ``{opt.cp_folder}/{folder_name}/occ/`` and never below
        ``opt.dataset_size``.
        """
        names = [int(path[1]) for path in files_utils.collect(f'{self.opt.cp_folder}/{folder_name}/occ/', '.obj')]
        ids = torch.arange(nums_sample)
        if not names:
            return ids + self.opt.dataset_size
        return ids + max(max(names) + 1, self.opt.dataset_size)

    @models_utils.torch_no_grad
    def random_plot(self, folder_name: str, nums_sample, res=200, verbose=False):
        """Sample ``nums_sample`` random latents and export their meshes."""
        zh_base, gmms = self.model.random_samples(nums_sample)
        zh, _ = self.model.merge_zh(zh_base, gmms)
        numbers = self.get_new_ids(folder_name, nums_sample)
        self.plot_occ(zh, zh_base, gmms, numbers, folder_name, verbose=verbose, res=res)

    @models_utils.torch_no_grad
    def plot(self, folder_name: str, nums_sample: int, verbose=False, res: int = 200):
        """Export meshes for dataset items.

        All items are used when the dataset holds fewer than ``nums_sample``
        entries; otherwise ``nums_sample`` indices are drawn with
        ``torch.randint`` (so repeats are possible).
        """
        if self.model.opt.dataset_size < nums_sample:
            fixed_items = torch.arange(self.model.opt.dataset_size)
        else:
            fixed_items = torch.randint(low=0, high=self.opt.dataset_size, size=(nums_sample,))
        # The other call sites in this class unpack four values from get_embeddings.
        zh_base, _, gmms, _ = self.model.get_embeddings(fixed_items.to(self.device))
        zh, _ = self.model.merge_zh(zh_base, gmms)
        self.plot_occ(zh, zh_base, gmms, fixed_items, folder_name, verbose=verbose, res=res)

    def get_mesh_from_mid(self, gmm, included: T, res: int) -> Optional[T_Mesh]:
        """Mesh a selection of the latents stored by ``set_items``.

        ``included`` is indexed as ``included[:, 0]`` and ``included[:, 1]``
        to pick entries of ``self.mid``. Returns ``None`` when no latents
        have been stored yet.
        """
        if self.mid is None:
            return None
        gmm = [elem.to(self.device) for elem in gmm]
        included = included.to(device=self.device)
        mid_ = self.mid[included[:, 0], included[:, 1]].unsqueeze(0)
        zh = self.model.merge_zh(mid_, gmm)[0]
        return self.get_mesh(zh[0], res)