Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
torch.backends.cuda.matmul.allow_tf32 = True | |
torch.backends.cudnn.allow_tf32 = True | |
import torch.nn.functional as F | |
import torch.distributed as dist | |
from torch.utils.data import DataLoader | |
from torch.utils.data.distributed import DistributedSampler | |
from torchvision import transforms | |
from tqdm import tqdm | |
import os | |
from PIL import Image | |
import numpy as np | |
import argparse | |
import itertools | |
from skimage.metrics import peak_signal_noise_ratio as psnr_loss | |
from skimage.metrics import structural_similarity as ssim_loss | |
from dataset.augmentation import center_crop_arr | |
from dataset.build import build_dataset | |
from tokenizer.tokenizer_image.vq_model import VQ_models | |
def create_npz_from_sample_folder(sample_dir, num=50000): | |
""" | |
Builds a single .npz file from a folder of .png samples. | |
""" | |
samples = [] | |
for i in tqdm(range(num), desc="Building .npz file from samples"): | |
sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") | |
sample_np = np.asarray(sample_pil).astype(np.uint8) | |
samples.append(sample_np) | |
samples = np.stack(samples) | |
assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) | |
npz_path = f"{sample_dir}.npz" | |
np.savez(npz_path, arr_0=samples) | |
print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") | |
return npz_path | |
def main(args): | |
# Setup PyTorch: | |
assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" | |
torch.set_grad_enabled(False) | |
# Setup DDP: | |
dist.init_process_group("nccl") | |
rank = dist.get_rank() | |
device = rank % torch.cuda.device_count() | |
seed = args.global_seed * dist.get_world_size() + rank | |
torch.manual_seed(seed) | |
torch.cuda.set_device(device) | |
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") | |
# create and load model | |
vq_model = VQ_models[args.vq_model]( | |
codebook_size=args.codebook_size, | |
codebook_embed_dim=args.codebook_embed_dim) | |
vq_model.to(device) | |
vq_model.eval() | |
checkpoint = torch.load(args.vq_ckpt, map_location="cpu") | |
if "ema" in checkpoint: # ema | |
model_weight = checkpoint["ema"] | |
elif "model" in checkpoint: # ddp | |
model_weight = checkpoint["model"] | |
elif "state_dict" in checkpoint: | |
model_weight = checkpoint["state_dict"] | |
else: | |
raise Exception("please check model weight") | |
vq_model.load_state_dict(model_weight) | |
del checkpoint | |
# Create folder to save samples: | |
folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}" | |
f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}") | |
sample_folder_dir = f"{args.sample_dir}/{folder_name}" | |
if rank == 0: | |
os.makedirs(sample_folder_dir, exist_ok=True) | |
print(f"Saving .png samples at {sample_folder_dir}") | |
dist.barrier() | |
# Setup data: | |
transform = transforms.Compose([ | |
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) | |
]) | |
if args.dataset == 'imagenet': | |
dataset = build_dataset(args, transform=transform) | |
num_fid_samples = 50000 | |
elif args.dataset == 'coco': | |
dataset = build_dataset(args, transform=transform) | |
num_fid_samples = 5000 | |
elif args.dataset == 'imagenet_code': | |
dataset = build_dataset(args) | |
num_fid_samples = 50000 | |
else: | |
raise Exception("please check dataset") | |
sampler = DistributedSampler( | |
dataset, | |
num_replicas=dist.get_world_size(), | |
rank=rank, | |
shuffle=False, | |
seed=args.global_seed | |
) | |
loader = DataLoader( | |
dataset, | |
batch_size=args.per_proc_batch_size, | |
shuffle=False, | |
sampler=sampler, | |
num_workers=args.num_workers, | |
pin_memory=True, | |
drop_last=False | |
) | |
# Figure out how many samples we need to generate on each GPU and how many iterations we need to run: | |
n = args.per_proc_batch_size | |
global_batch_size = n * dist.get_world_size() | |
psnr_val_rgb = [] | |
ssim_val_rgb = [] | |
loader = tqdm(loader) if rank == 0 else loader | |
total = 0 | |
# for x, _ in loader: | |
for batch in loader: | |
x = batch['condition_imgs'].repeat(1,3,1,1) | |
# import pdb | |
# pdb.set_trace() | |
if args.image_size_eval != args.image_size: | |
rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') | |
else: | |
rgb_gts = x | |
rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] | |
x = x.to(device, non_blocking=True) | |
with torch.no_grad(): | |
latent, _, [_, _, indices] = vq_model.encode(x.float()) | |
import pdb;pdb.set_trace() | |
samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1] | |
if args.image_size_eval != args.image_size: | |
samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') | |
samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() | |
# Save samples to disk as individual .png files | |
for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): | |
index = i * dist.get_world_size() + rank + total | |
# Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") | |
# metric | |
rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] | |
psnr = psnr_loss(rgb_restored, rgb_gt) | |
ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) | |
psnr_val_rgb.append(psnr) | |
ssim_val_rgb.append(ssim) | |
total += global_batch_size | |
# ------------------------------------ | |
# Summary | |
# ------------------------------------ | |
# Make sure all processes have finished saving their samples | |
dist.barrier() | |
world_size = dist.get_world_size() | |
gather_psnr_val = [None for _ in range(world_size)] | |
gather_ssim_val = [None for _ in range(world_size)] | |
dist.all_gather_object(gather_psnr_val, psnr_val_rgb) | |
dist.all_gather_object(gather_ssim_val, ssim_val_rgb) | |
if rank == 0: | |
gather_psnr_val = list(itertools.chain(*gather_psnr_val)) | |
gather_ssim_val = list(itertools.chain(*gather_ssim_val)) | |
psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) | |
ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) | |
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) | |
result_file = f"{sample_folder_dir}_results.txt" | |
print("writing results to {}".format(result_file)) | |
with open(result_file, 'w') as f: | |
print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) | |
create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) | |
print("Done.") | |
dist.barrier() | |
dist.destroy_process_group() | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--data-path", type=str, default=None) | |
parser.add_argument("--code-path", type=str, required=True) | |
parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco', 'imagenet_code'], default='imagenet') | |
parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") | |
parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") | |
parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") | |
parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") | |
parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256) | |
parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) | |
parser.add_argument("--sample-dir", type=str, default="reconstructions") | |
parser.add_argument("--per-proc-batch-size", type=int, default=32) | |
parser.add_argument("--global-seed", type=int, default=0) | |
parser.add_argument("--num-workers", type=int, default=4) | |
parser.add_argument("--condition", type=str, choices=['canny', 'hed'], default='canny') | |
parser.add_argument("--get-condition-img", type=bool, default=False) | |
args = parser.parse_args() | |
main(args) |