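# Single-image-to-3D inference script: samples PrimX primitive parameters with an
# image-conditioned diffusion model, then optionally extracts a textured GLB mesh.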
import os
import sys
import io
import logging

import torch
import numpy as np
import PIL.Image
from PIL import Image
import rembg
import mcubes
import xatlas
import nvdiffrast.torch as dr
import cv2
from omegaconf import OmegaConf
from tqdm import tqdm
from scipy.ndimage import binary_dilation, binary_erosion
from sklearn.neighbors import NearestNeighbors

from dva.ray_marcher import RayMarcher
from dva.io import load_from_config
from dva.utils import to_device
from dva.visualize import visualize_primvolume, visualize_video_primvolume
from models.diffusion import create_diffusion
from utils.meshutils import clean_mesh, decimate_mesh
from utils.mesh import Mesh
from utils.uv_unwrap import box_projection_uv_unwrap, compute_vertex_normal

logger = logging.getLogger("inference.py")
glctx = dr.RasterizeCudaContext()
def remove_background(
    image: PIL.Image.Image,
    rembg_session=None,
    force: bool = False,
    **rembg_kwargs,
) -> PIL.Image.Image:
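    """Remove the background with rembg.

    Matting is skipped when the input already carries a non-trivial alpha
    channel, unless `force` is set.
    """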
    do_remove = True
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        # the image already has transparency; keep its alpha as-is
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
def resize_foreground(
    image: PIL.Image.Image,
    ratio: float,
) -> PIL.Image.Image:
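    """Crop an RGBA image to its foreground bounding box, pad it to a square,
    then pad again so the content occupies `ratio` of the output side length."""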
    image = np.array(image)
    assert image.shape[-1] == 4
    alpha = np.where(image[..., 3] > 0)
    y1, y2, x1, x2 = (
        alpha[0].min(),
        alpha[0].max(),
        alpha[1].min(),
        alpha[1].max(),
    )
    # crop the foreground
    fg = image[y1:y2, x1:x2]
    # pad to square
    size = max(fg.shape[0], fg.shape[1])
    ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2
    ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0
    new_image = np.pad(
        fg,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    # compute padding according to the ratio
    new_size = int(new_image.shape[0] / ratio)
    # pad to size, double side
    ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2
    ph1, pw1 = new_size - size - ph0, new_size - size - pw0
    new_image = np.pad(
        new_image,
        ((ph0, ph1), (pw0, pw1), (0, 0)),
        mode="constant",
        constant_values=0,
    )
    return PIL.Image.fromarray(new_image)
def extract_texmesh(args, model, output_path, device):
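    """Extract a textured GLB mesh from a fitted PrimX model.

    Steps: filter out stray primitives, sample the SDF on a dense grid, run
    marching cubes, unwrap UVs (xatlas or fast box projection), bake
    albedo/roughness/metallic textures by querying the model at surface
    points, and write `pbr_mesh.glb` to `output_path`.
    """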
    ins_dir = output_path
    # Noise filter: drop isolated primitives whose nearest neighbor lies
    # farther away than the two primitives' combined scales.
    raw_srt_param = model.srt_param.clone()
    raw_feat_param = model.feat_param.clone()
    prim_position = raw_srt_param[:, 1:4]
    prim_scale = raw_srt_param[:, 0:1]
    dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
    # add 1 to the diagonal so a primitive is never its own nearest neighbor
    dist += torch.eye(prim_position.shape[0]).to(raw_srt_param)
    min_dist, min_indices = dist.min(1)
    dst_prim_scale = prim_scale[min_indices, :]
    min_scale_coverage = prim_scale + dst_prim_scale
    prim_mask = min_dist < min_scale_coverage[:, 0]
    filtered_srt_param = raw_srt_param[prim_mask, :]
    filtered_feat_param = raw_feat_param[prim_mask, ...]
    model.srt_param.data = filtered_srt_param
    model.feat_param.data = filtered_feat_param
    print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
    # Get SDFs
    with torch.no_grad():
        xx = torch.linspace(-1, 1, args.mc_resolution, device=device)
        pts = torch.stack(torch.meshgrid(xx, xx, xx, indexing='ij'), dim=-1).reshape(-1, 3)
        chunks = torch.split(pts, args.batch_size)
        dists = []
        for chunk_pts in tqdm(chunks):
            preds = model(chunk_pts)
            dists.append(preds['sdf'].detach())
        dists = torch.cat(dists, dim=0)
    grid = dists.reshape(args.mc_resolution, args.mc_resolution, args.mc_resolution)
    # Meshify
    vertices, triangles = mcubes.marching_cubes(grid.cpu().numpy(), 0.0)
    # Resize + recenter from grid indices to [-1, 1]^3
    b_min_np = np.array([-1., -1., -1.])
    b_max_np = np.array([1., 1., 1.])
    vertices = vertices / (args.mc_resolution - 1.0) * (b_max_np - b_min_np) + b_min_np
    vertices, triangles = clean_mesh(vertices, triangles, min_f=8, min_d=5, repair=True, remesh=False)
    if args.decimate > 0 and triangles.shape[0] > args.decimate:
        vertices, triangles = decimate_mesh(vertices, triangles, args.decimate, remesh=args.remesh)
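    # Bake PBR textures: unwrap UVs, rasterize the mesh in UV space with
    # nvdiffrast, and query the model at the 3D position of each covered texel.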
    h0 = 1024
    w0 = 1024
    ssaa = 1
    fp16 = True
    v_np = vertices.astype(np.float32)
    f_np = triangles.astype(np.int64)
    v = torch.from_numpy(v_np).float().contiguous().to(device)
    f = torch.from_numpy(f_np).contiguous().to(device)
    if args.fast_unwrap:
        print(f'[INFO] running box-based fast unwrapping to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
        v_normal = compute_vertex_normal(v, f)
        uv, indices = box_projection_uv_unwrap(v, v_normal, f, 0.02)
        # split vertices per-face so every face owns its own UVs
        indv_v = v[f].reshape(-1, 3)
        indv_faces = torch.arange(indv_v.shape[0], device=device, dtype=f.dtype).reshape(-1, 3)
        uv_flat = uv[indices].reshape((-1, 2))
        v = indv_v.contiguous()
        f = indv_faces.contiguous()
        ft_np = f.cpu().numpy()
        vt_np = uv_flat.cpu().numpy()
    else:
        print(f'[INFO] running xatlas to unwrap UVs for mesh: v={v_np.shape} f={f_np.shape}')
        # unwrap uv in contracted space
        atlas = xatlas.Atlas()
        atlas.add_mesh(v_np, f_np)
        chart_options = xatlas.ChartOptions()
        chart_options.max_iterations = 0  # disable merge_chart for faster unwrap...
        pack_options = xatlas.PackOptions()
        atlas.generate(chart_options=chart_options, pack_options=pack_options)
        _, ft_np, vt_np = atlas[0]  # [N], [M, 3], [N, 2]
    vt = torch.from_numpy(vt_np.astype(np.float32)).float().contiguous().to(device)
    ft = torch.from_numpy(ft_np.astype(np.int64)).int().contiguous().to(device)
    uv = vt * 2.0 - 1.0  # uvs to range [-1, 1]
    uv = torch.cat((uv, torch.zeros_like(uv[..., :1]), torch.ones_like(uv[..., :1])), dim=-1)  # [N, 4]
    if ssaa > 1:
        h = int(h0 * ssaa)
        w = int(w0 * ssaa)
    else:
        h, w = h0, w0
    # rasterize in UV space, then interpolate the 3D position of every texel
    rast, _ = dr.rasterize(glctx, uv.unsqueeze(0), ft, (h, w))  # [1, h, w, 4]
    xyzs, _ = dr.interpolate(v.unsqueeze(0), rast, f.int())  # [1, h, w, 3]
    mask, _ = dr.interpolate(torch.ones_like(v[:, :1]).unsqueeze(0), rast, f.int())  # [1, h, w, 1]
    # masked query
    xyzs = xyzs.view(-1, 3)
    mask = (mask > 0).view(-1)
    feats = torch.zeros(h * w, 6, device=device, dtype=torch.float32)
    if mask.any():
        xyzs = xyzs[mask]  # [M, 3]
        # batched inference to avoid OOM
        all_feats = []
        head = 0
        chunk_size = args.batch_size
        while head < xyzs.shape[0]:
            tail = min(head + chunk_size, xyzs.shape[0])
            with torch.cuda.amp.autocast(enabled=fp16):
                preds = model(xyzs[head:tail])
            # channel layout: [R, G, B, unused(occlusion), roughness, metallic]
            all_feats.append(torch.cat([
                preds['tex'].float(),
                torch.zeros_like(preds['tex'])[..., 0:1].float(),
                preds['mat'].float(),
            ], dim=-1))
            head += chunk_size
        feats[mask] = torch.cat(all_feats, dim=0)
    feats = feats.view(h, w, -1)  # 6 channels
    mask = mask.view(h, w)
    # scale [0.0, 1.0] to [0, 255]
    feats = feats.cpu().numpy()
    feats = feats * 255
    # nearest-neighbor inpainting of a dilated band around the UV islands,
    # which acts as cheap antialiasing at chart seams
    mask = mask.cpu().numpy()
    inpaint_region = binary_dilation(mask, iterations=32)  # pad width
    inpaint_region[mask] = 0
    search_region = mask.copy()
    not_search_region = binary_erosion(search_region, iterations=3)
    search_region[not_search_region] = 0
    search_coords = np.stack(np.nonzero(search_region), axis=-1)
    inpaint_coords = np.stack(np.nonzero(inpaint_region), axis=-1)
    knn = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(search_coords)
    _, indices = knn.kneighbors(inpaint_coords)
    feats[tuple(inpaint_coords.T)] = feats[tuple(search_coords[indices[:, 0]].T)]
    target_mesh = Mesh(
        v=torch.from_numpy(v_np).contiguous(),
        f=torch.from_numpy(f_np).contiguous(),
        ft=ft.contiguous(),
        vt=torch.from_numpy(vt_np).contiguous(),
        albedo=torch.from_numpy(feats[..., :3]) / 255,
        metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255,
    )
    target_mesh.write(os.path.join(ins_dir, 'pbr_mesh.glb'))
    # restore the unfiltered primitives
    model.srt_param.data = raw_srt_param
    model.feat_param.data = raw_feat_param
def main(config):
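    """Run image-conditioned diffusion sampling for every image in the input
    directory, then optionally extract textured GLB meshes."""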
    logging.basicConfig(level=logging.INFO)
    ddim_steps = config.inference.ddim
    use_ddim = ddim_steps > 0
    cfg_scale = config.inference.get("cfg", 0.0)
    inference_dir = f"{config.output_dir}/inference_folder"
    os.makedirs(inference_dir, exist_ok=True)
    amp = False
    precision = config.inference.get("precision", 'fp16')
    if precision == 'tf32':
        precision_dtype = torch.float32
    elif precision == 'fp16':
        amp = True
        precision_dtype = torch.float16
    else:
        raise NotImplementedError("{} precision is not supported".format(precision))
    device = torch.device("cuda:0")
    seed = config.inference.seed
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    model = load_from_config(config.model.generator)
    vae = load_from_config(config.model.vae)
    conditioner = load_from_config(config.model.conditioner)
    vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
    vae.load_state_dict(vae_state_dict['model_state_dict'])
    if config.checkpoint_path:
        state_dict = torch.load(config.checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict['ema'])
    vae = vae.to(device)
    conditioner = conditioner.to(device)
    model = model.to(device)
    config.diffusion.pop("timestep_respacing")
    respacing = "ddim{}".format(ddim_steps) if use_ddim else ""
    diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)  # default: 1000 steps, linear noise schedule
    sample_fn = diffusion.ddim_sample_loop_progressive if use_ddim else diffusion.p_sample_loop_progressive
    fwd_fn = model.forward_with_cfg if cfg_scale > 0 else model.forward
    rm = RayMarcher(
        config.image_height,
        config.image_width,
        **config.rm,
    ).to(device)
    perchannel_norm = False
    if "latent_mean" in config.model:
        latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
        latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
        assert latent_mean.shape[-1] == config.model.generator.in_channels
        perchannel_norm = True
    model.eval()
    examples_dir = config.inference.input_dir
    img_list = os.listdir(examples_dir)
    rembg_session = rembg.new_session()
    logger.info("Starting Inference...")
    for img_path in img_list:
        full_img_path = os.path.join(examples_dir, img_path)
        img_name = os.path.splitext(img_path)[0]
        current_output_dir = os.path.join(inference_dir, img_name)
        os.makedirs(current_output_dir, exist_ok=True)
        input_image = Image.open(full_img_path)
        input_image = remove_background(input_image, rembg_session)
        input_image = resize_foreground(input_image, 0.85)
        raw_image = np.array(input_image)
        # zero out background pixels using the alpha channel
        mask = (raw_image[..., -1][..., None] > 0) * 1
        raw_image = raw_image[..., :3] * mask
        input_cond = torch.from_numpy(raw_image[None, ...]).to(device)
        with torch.no_grad():
            # template tensor; only its shape is used when reshaping VAE latents
            latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
            batch = {}
            inf_bs = 1
            # 68 = 4 (srt) + 64 (flattened 1x4x4x4 latent)
            inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
            y = conditioner.encoder(input_cond)
            model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
            if cfg_scale > 0:
                model_kwargs['cfg_scale'] = cfg_scale
            sampled_count = -1
            for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
                sampled_count += 1
                # only decode every 10th step plus the final step
                if not (sampled_count % 10 == 0 or sampled_count == diffusion.num_timesteps - 1):
                    continue
                recon_param = samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
                if perchannel_norm:
                    recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
                recon_srt_param = recon_param[:, :, 0:4]
                recon_feat_param = recon_param[:, :, 4:]  # [inf_bs, num_prims, latent_dim]
                recon_feat_param_list = []
                # decode one-by-one to avoid oom
                for inf_bidx in range(inf_bs):
                    if not perchannel_norm:
                        decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1 * config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
                    else:
                        decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1 * config.model.num_prims, *latent.shape[-4:]))
                    recon_feat_param_list.append(decoded.detach())
                recon_feat_param = torch.cat(recon_feat_param_list, dim=0)
                # invert normalization
                if not perchannel_norm:
                    recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
                    recon_feat_param[:, 0:1, ...] /= 5.
                    recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
                recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
                recon_param = torch.cat([recon_srt_param, recon_feat_param], dim=-1)
                visualize_primvolume("{}/dstep{:04d}_recon.jpg".format(current_output_dir, sampled_count), batch, recon_param, rm, device)
            # recon_param now holds the final denoised parameters
            visualize_video_primvolume(current_output_dir, batch, recon_param, 60, rm, device)
            prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
            torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(current_output_dir))
    if config.inference.export_glb:
        logger.info("Starting GLB Mesh Extraction...")
        # config.model is reused as the PrimX config; drop diffusion-only keys
        config.model.pop("vae")
        config.model.pop("vae_checkpoint_path")
        config.model.pop("conditioner")
        config.model.pop("generator")
        config.model.pop("latent_nf")
        config.model.pop("latent_mean")
        config.model.pop("latent_std")
        model_primx = load_from_config(config.model)
        for img_path in img_list:
            img_name = os.path.splitext(img_path)[0]
            output_path = os.path.join(inference_dir, img_name)
            denoise_param_path = os.path.join(inference_dir, img_name, 'denoised.pt')
            ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
            model_primx.load_state_dict(ckpt_weight)
            model_primx.to(device)
            model_primx.eval()
            with torch.no_grad():
                # shrink primitive positions to match the foreground resize ratio
                model_primx.srt_param[:, 1:4] *= 0.85
                extract_texmesh(config.inference, model_primx, output_path, device)
if __name__ == "__main__":
    torch.backends.cudnn.benchmark = True
    # manually enable tf32 to get speedup on A100 GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # set config: base YAML from the first CLI argument, then CLI overrides
    config = OmegaConf.load(str(sys.argv[1]))
    config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])
    if config_cli:
        logger.info("overriding with following values from args:")
        logger.info(OmegaConf.to_yaml(config_cli))
        config = OmegaConf.merge(config, config_cli)
    main(config)