"""Multi-view inference: generate per-view normal and color images with the
PSHuman diffusion pipeline, then reconstruct a textured human mesh."""
import argparse
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Optional

import torch
import torch.nn.functional as F
from accelerate.utils import set_seed
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from rembg import new_session, remove
from torchvision.utils import make_grid
from tqdm.auto import tqdm

from mvdiffusion.pipelines.pipeline_mvdiffusion_unclip import StableUnCLIPImg2ImgPipeline
from econdataset import SMPLDataset
from reconstruct import ReMesh
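
# Shared rembg session for background removal (ONNX Runtime CUDA provider);
# the GPU memory arena is capped at 8 GiB.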
providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,
        'arena_extend_strategy': 'kSameAsRequested',
        'gpu_mem_limit': 8 * 1024 * 1024 * 1024,
        'cudnn_conv_algo_search': 'HEURISTIC',
    })
]
session = new_session(providers=providers)

weight_dtype = torch.float16


@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int

    save_mode: str
    local_rank: int

    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_guidance_scales: float
    validation_grid_nrow: int

    num_views: int
    enable_xformers_memory_efficient_attention: bool
    with_smpl: Optional[bool]

    recon_opt: Dict

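
# Image conversion helpers: map a float tensor in [0, 1] to HWC uint8 / PIL.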
def convert_to_numpy(tensor):
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


def convert_to_pil(tensor):
    return Image.fromarray(convert_to_numpy(tensor))


def save_image(tensor, fp):
    ndarr = convert_to_numpy(tensor)
    save_image_numpy(ndarr, fp)
    return ndarr


def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)
    im.save(fp)

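
# Run multi-view diffusion on every batch, save the predicted views, and in
# 'rgb' mode hand the matted views to the mesh-carving stage.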
def run_inference(dataloader, econdata, pipeline, carving, cfg: TestConfig, save_dir):
    pipeline.set_progress_bar_config(disable=True)

    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    images_cond, pred_cat = [], defaultdict(list)
    for case_id, batch in tqdm(enumerate(dataloader)):
        images_cond.append(batch['imgs_in'][:, 0])
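
        # Duplicate the inputs along the batch dim: the first half is paired
        # with normal-map prompts, the second half with color prompts.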
        imgs_in = torch.cat([batch['imgs_in']] * 2, dim=0)
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        if cfg.with_smpl:
            smpl_in = torch.cat([batch['smpl_imgs_in']] * 2, dim=0)
            smpl_in = rearrange(smpl_in, "B Nv C H W -> (B Nv) C H W")
        else:
            smpl_in = None

        normal_prompt_embeddings, clr_prompt_embeddings = batch['normal_prompt_embeddings'], batch['color_prompt_embeddings']
        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")
        with torch.autocast("cuda"):
            guidance_scale = cfg.validation_guidance_scales
            unet_out = pipeline(
                imgs_in, None, prompt_embeds=prompt_embeddings,
                dino_feature=None, smpl_in=smpl_in,
                generator=generator, guidance_scale=guidance_scale,
                output_type='pt', num_images_per_prompt=1,
                **cfg.pipe_validation_kwargs
            )

        out = unet_out.images
        bsz = out.shape[0] // 2
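
        # Split the doubled batch back into its two halves.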
        normals_pred = out[:bsz]
        images_pred = out[bsz:]
        if cfg.save_mode == 'concat':
            pred_cat[f"cfg{guidance_scale:.1f}"].append(torch.cat([normals_pred, images_pred], dim=-1))
            cur_dir = os.path.join(save_dir, f"cropsize-{cfg.validation_dataset.crop_size}-cfg{guidance_scale:.1f}-seed{cfg.seed}-smpl-{cfg.with_smpl}")
            os.makedirs(cur_dir, exist_ok=True)
            for i in range(bsz // num_views):
                scene = batch['filename'][i].split('.')[0]

                img_in_ = images_cond[-1][i].to(out.device)
                vis_ = [img_in_]
                for j in range(num_views):
                    idx = i * num_views + j
                    normal = normals_pred[idx]
                    color = images_pred[idx]

                    vis_.append(color)
                    vis_.append(normal)

                out_filename = f"{cur_dir}/{scene}.png"
                vis_ = torch.stack(vis_, dim=0)
                vis_ = make_grid(vis_, nrow=len(vis_), padding=0, value_range=(0, 1))
                save_image(vis_, out_filename)
        elif cfg.save_mode == 'rgb':
            for i in range(bsz // num_views):
                scene = batch['filename'][i].split('.')[0]

                img_in_ = images_cond[-1][i].to(out.device)
                normals, colors = [], []
                for j in range(num_views):
                    idx = i * num_views + j
                    normal = normals_pred[idx]
                    if j == 0:
                        # Use the original input image as the front color view.
                        color = imgs_in[0].to(out.device)
                    else:
                        color = images_pred[idx]
                    if j in [3, 4]:
                        # Flip views 3 and 4 horizontally (width axis).
                        normal = torch.flip(normal, dims=[2])
                        color = torch.flip(color, dims=[2])

                    colors.append(color)
                    if j == 6:
                        # Downsample the seventh view (face close-up) to 256x256.
                        normal = F.interpolate(normal.unsqueeze(0), size=(256, 256), mode='bilinear', align_corners=False).squeeze(0)
                    normals.append(normal)
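
                # Paste the downsampled face view into the head region of the
                # front-view normal map, then keep only the six body views.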
                normals[0][:, :256, 256:512] = normals[-1]

                colors = [remove(convert_to_pil(tensor), session=session) for tensor in colors[:6]]
                normals = [remove(convert_to_pil(tensor), session=session) for tensor in normals[:6]]
                pose = econdata[case_id]
                carving.optimize_case(scene, pose, colors, normals)
                torch.cuda.empty_cache()
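

# Build the image-to-multiview pipeline: fp16 weights, moved to GPU if available.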
def load_pshuman_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(cfg.pretrained_model_name_or_path, torch_dtype=weight_dtype)
    if cfg.enable_xformers_memory_efficient_attention:
        pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to('cuda')
    return pipeline


def main(cfg: TestConfig):
    if cfg.seed is not None:
        set_seed(cfg.seed)
    pipeline = load_pshuman_pipeline(cfg)
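
    # The dataset implementation differs depending on whether SMPL condition
    # images accompany each input.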
    if cfg.with_smpl:
        from mvdiffusion.data.testdata_with_smpl import SingleImageDataset
    else:
        from mvdiffusion.data.single_image_dataset import SingleImageDataset
    validation_dataset = SingleImageDataset(**cfg.validation_dataset)
    validation_dataloader = torch.utils.data.DataLoader(
        validation_dataset, batch_size=cfg.validation_batch_size,
        shuffle=False, num_workers=cfg.dataloader_num_workers
    )

    # SMPL pose estimation (PIXIE) over the same image folder; the fitted body
    # guides the mesh-carving stage.
    dataset_param = {'image_dir': validation_dataset.root_dir, 'seg_dir': None, 'colab': False, 'has_det': True, 'hps_type': 'pixie'}
    econdata = SMPLDataset(dataset_param, device='cuda')

    carving = ReMesh(cfg.recon_opt, econ_dataset=econdata)
    run_inference(validation_dataloader, econdata, pipeline, carving, cfg, cfg.save_dir)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    args, extras = parser.parse_known_args()

    from utils.misc import load_config

    # Merge the YAML config (plus CLI overrides) into the typed schema.
    cfg = load_config(args.config, cli_args=extras)
    schema = OmegaConf.structured(TestConfig)
    cfg = OmegaConf.merge(schema, cfg)
    main(cfg)
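
# Example invocation (script and config names are illustrative):
#   python inference.py --config configs/test.yaml
# Unrecognized "key=value" arguments are forwarded to load_config as overrides.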