import argparse
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Optional, List
from omegaconf import OmegaConf, DictConfig
from PIL import Image
from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid, save_image
from accelerate.utils import set_seed
from tqdm.auto import tqdm
from einops import rearrange, repeat

from multiview.pipeline_multiclass import StableUnCLIPImg2ImgPipeline

weight_dtype = torch.float16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def tensor_to_numpy(tensor):
    # (C, H, W) float tensor in [0, 1] -> (H, W, C) uint8 numpy array
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


# Allow OpenCV to read EXR files (e.g. depth renders).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

def nonzero_normalize_depth(depth, mask=None):
    """Normalize depth to [0, 1], using the minimum over the masked (non-transparent) region."""
    if mask is not None and mask.max() > 0:  # not all transparent
        nonzero_depth_min = depth[mask > 0].min()
    else:
        nonzero_depth_min = 0
    depth = (depth - nonzero_depth_min) / depth.max()
    return np.clip(depth, 0, 1)

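# Illustrative usage (hypothetical paths and mask rule): normalize an EXR depth render
# against its foreground mask before visualization.
#   depth = cv2.imread("depth.exr", cv2.IMREAD_UNCHANGED)[..., 0]
#   mask = (depth > 0).astype(np.float32)
#   depth_vis = nonzero_normalize_depth(depth, mask)
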
class SingleImageData(Dataset):
    def __init__(self,
                 input_dir,
                 prompt_embeds_path='./multiview/fixed_prompt_embeds_6view',
                 image_transforms=[],
                 total_views=6,
                 ext="png",
                 return_paths=True,
                 ) -> None:
        """Create a dataset from a folder of images.

        The input directory is searched for images ending in `ext`
        (a single extension string here).
        """
        self.input_dir = Path(input_dir)
        self.return_paths = return_paths
        self.total_views = total_views
        self.paths = glob.glob(str(self.input_dir / f'*.{ext}'))
        print('============= length of dataset %d =============' % len(self.paths))
        self.tform = image_transforms
        self.normal_text_embeds = torch.load(f'{prompt_embeds_path}/normal_embeds.pt')
        self.color_text_embeds = torch.load(f'{prompt_embeds_path}/clr_embeds.pt')

    def __len__(self):
        return len(self.paths)

    def load_rgb(self, path, color):
        img = plt.imread(path)
        img = Image.fromarray(np.uint8(img * 255.))
        # Resize to height 1024 and paste centered onto a white 1024x1024 canvas.
        new_img = Image.new("RGB", (1024, 1024))
        width, height = img.size
        new_width = int(width / height * 1024)
        img = img.resize((new_width, 1024))
        new_img.paste((255, 255, 255), (0, 0, 1024, 1024))  # white background
        offset = (1024 - new_width) // 2
        new_img.paste(img, (offset, 0))
        return new_img

    def __getitem__(self, index):
        data = {}
        filename = self.paths[index]
        if self.return_paths:
            data["path"] = str(filename)
        color = 1.0
        cond_im_rgb = self.process_im(self.load_rgb(filename, color))
        # Repeat the single conditioning image once per target view.
        cond_im_rgb = torch.stack([cond_im_rgb] * self.total_views, dim=0)
        data["image_cond_rgb"] = cond_im_rgb
        data["normal_prompt_embeddings"] = self.normal_text_embeds
        data["color_prompt_embeddings"] = self.color_text_embeds
        data["filename"] = os.path.basename(filename)
        return data

    def process_im(self, im):
        im = im.convert("RGB")
        return self.tform(im)

    def tensor_to_image(self, tensor):
        return Image.fromarray(np.uint8(tensor.numpy() * 255.))

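# Minimal sketch (illustrative values) of how the dataset is built; main() below constructs
# the equivalent transform chain from the CLI arguments (height=1024, width=576 by default).
#   tform = transforms.Compose([
#       transforms.Resize(1024),
#       transforms.CenterCrop((1024, 576)),
#       transforms.ToTensor(),
#       transforms.Lambda(lambda x: x * 2. - 1),
#   ])
#   ds = SingleImageData(input_dir='./result/apose', image_transforms=tform, total_views=6)
#   sample = ds[0]  # dict with the repeated conditioning image, prompt embeddings, filename
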
@dataclass
class TestConfig:
    pretrained_model_name_or_path: str
    pretrained_unet_path: Optional[str]
    revision: Optional[str]
    validation_dataset: Dict
    save_dir: str
    seed: Optional[int]
    validation_batch_size: int
    dataloader_num_workers: int
    save_mode: str
    local_rank: int
    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    unet_from_pretrained_kwargs: Dict
    validation_grid_nrow: int
    camera_embedding_lr_mult: float
    num_views: int
    camera_embedding_type: str
    pred_type: str
    regress_elevation: bool
    enable_xformers_memory_efficient_attention: bool
    cond_on_normals: bool
    cond_on_colors: bool
    regress_focal_length: bool

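# In this script, main() is handed the argparse Namespace directly, which populates only a
# subset of these fields (seed, num_views, num_levels, pretrained_path, height, width,
# input_dir, output_dir); TestConfig otherwise serves as a type hint for config-driven runs.
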
def convert_to_numpy(tensor):
    return tensor.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()


# Note: this local save_image shadows torchvision.utils.save_image imported above; within
# this script the local version (pad to white square, resize to 1024, save via PIL) is used.
def save_image(tensor, fp):
    ndarr = convert_to_numpy(tensor)
    save_image_numpy(ndarr, fp)
    return ndarr


def save_image_numpy(ndarr, fp):
    im = Image.fromarray(ndarr)
    # pad to square
    if im.size[0] != im.size[1]:
        size = max(im.size)
        new_im = Image.new("RGB", (size, size))
        # set to white
        new_im.paste((255, 255, 255), (0, 0, size, size))
        new_im.paste(im, ((size - im.size[0]) // 2, (size - im.size[1]) // 2))
        im = new_im
    # resize to 1024x1024
    im = im.resize((1024, 1024), Image.LANCZOS)
    im.save(fp)

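# Illustrative usage (hypothetical tensor and path): save one (C, H, W) prediction in [0, 1];
# non-square inputs are padded to a white square and resized to 1024x1024 before saving.
#   img = torch.rand(3, 1024, 576)
#   save_image(img, "preview.png")
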
def run_multiview_infer(dataloader, pipeline, cfg: TestConfig, save_dir, num_levels=3):
    if cfg.seed is None:
        generator = None
    else:
        generator = torch.Generator(device=pipeline.unet.device).manual_seed(cfg.seed)

    images_cond = []
    for _, batch in tqdm(enumerate(dataloader)):
        torch.cuda.empty_cache()
        images_cond.append(batch['image_cond_rgb'][:, 0].cuda())
        # Duplicate the conditioning images: the first half is paired with the normal-map
        # prompt embeddings, the second half with the color prompt embeddings.
        imgs_in = torch.cat([batch['image_cond_rgb']] * 2, dim=0).cuda()
        num_views = imgs_in.shape[1]
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")  # (B*Nv, 3, H, W)
        target_h, target_w = imgs_in.shape[-2], imgs_in.shape[-1]

        normal_prompt_embeddings = batch['normal_prompt_embeddings'].cuda()
        clr_prompt_embeddings = batch['color_prompt_embeddings'].cuda()
        prompt_embeddings = torch.cat([normal_prompt_embeddings, clr_prompt_embeddings], dim=0)
        prompt_embeddings = rearrange(prompt_embeddings, "B Nv N C -> (B Nv) N C")

        # B*Nv images per branch (normals and colors)
        unet_out = pipeline(
            imgs_in, None, prompt_embeds=prompt_embeddings,
            generator=generator, guidance_scale=3.0, output_type='pt', num_images_per_prompt=1,
            height=cfg.height, width=cfg.width,
            num_inference_steps=40, eta=1.0,
            num_levels=num_levels,
        )

        for level in range(num_levels):
            out = unet_out[level].images
            bsz = out.shape[0] // 2
            normals_pred = out[:bsz]
            images_pred = out[bsz:]

            cur_dir = save_dir
            os.makedirs(cur_dir, exist_ok=True)

            for i in range(bsz // num_views):
                scene = batch['filename'][i].split('.')[0]
                scene_dir = os.path.join(cur_dir, scene, f'level{level}')
                os.makedirs(scene_dir, exist_ok=True)

                img_in_ = images_cond[-1][i].to(out.device)
                for j in range(num_views):
                    view = VIEWS[j]
                    idx = i * num_views + j
                    normal = normals_pred[idx]
                    color = images_pred[idx]

                    # save color and normal
                    normal_filename = f"normal_{j}.png"
                    rgb_filename = f"color_{j}.png"
                    save_image(normal, os.path.join(scene_dir, normal_filename))
                    save_image(color, os.path.join(scene_dir, rgb_filename))

        torch.cuda.empty_cache()

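# Resulting directory layout, per input image and refinement level:
#   <save_dir>/<scene>/level{k}/normal_{j}.png and color_{j}.png, with j = 0..num_views-1
# following the VIEWS order defined in the __main__ block below.
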
def load_multiview_pipeline(cfg):
    pipeline = StableUnCLIPImg2ImgPipeline.from_pretrained(
        cfg.pretrained_path,
        torch_dtype=torch.float16,
    )
    pipeline.unet.enable_xformers_memory_efficient_attention()
    if torch.cuda.is_available():
        pipeline.to(device)
    return pipeline

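# Note: enable_xformers_memory_efficient_attention() requires the optional xformers package;
# if it is not installed, drop that call or wrap it in a try/except.
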
def main(cfg: TestConfig):
    set_seed(cfg.seed)
    pipeline = load_multiview_pipeline(cfg)
    if torch.cuda.is_available():
        pipeline.to(device)

    image_transforms = [
        transforms.Resize(int(max(cfg.height, cfg.width))),
        transforms.CenterCrop((cfg.height, cfg.width)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2. - 1),
    ]
    image_transforms = transforms.Compose(image_transforms)

    dataset = SingleImageData(image_transforms=image_transforms, input_dir=cfg.input_dir,
                              total_views=cfg.num_views)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=1,
    )

    os.makedirs(cfg.output_dir, exist_ok=True)
    with torch.no_grad():
        run_multiview_infer(dataloader, pipeline, cfg, cfg.output_dir, num_levels=cfg.num_levels)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_views", type=int, default=6)
    parser.add_argument("--num_levels", type=int, default=3)
    parser.add_argument("--pretrained_path", type=str, default='./ckpt/StdGEN-multiview-1024')
    parser.add_argument("--height", type=int, default=1024)
    parser.add_argument("--width", type=int, default=576)
    parser.add_argument("--input_dir", type=str, default='./result/apose')
    parser.add_argument("--output_dir", type=str, default='./result/multiview')
    cfg = parser.parse_args()

    if cfg.num_views == 6:
        VIEWS = ['front', 'front_right', 'right', 'back', 'left', 'front_left']
    else:
        raise NotImplementedError(f"Number of views {cfg.num_views} not supported")
    main(cfg)
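
# Example invocation (the script filename below is assumed; use this file's name in your
# checkout):
#   python infer_multiview.py \
#       --pretrained_path ./ckpt/StdGEN-multiview-1024 \
#       --input_dir ./result/apose --output_dir ./result/multiview \
#       --num_views 6 --num_levels 3
# Outputs land in <output_dir>/<image_name>/level{0..num_levels-1}/{normal,color}_{j}.png.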