# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import time
import hashlib
import random
import pickle
import copy
import uuid
from urllib.parse import urlparse
import numpy as np
import torch
from tools import dnnlib
from tools.utils.dataset import video_to_image_dataset_kwargs

#----------------------------------------------------------------------------

class MetricOptions:
    """Bundle of settings shared by the feature-statistics routines below.

    Attributes set here: `G`, `G_kwargs`, `dataset_kwargs`, `num_gpus`, `rank`,
    `device`, `progress`, `cache`, `gen_dataset_kwargs`, `generator_as_dataset`.
    """

    def __init__(self, G=None, G_kwargs=None, dataset_kwargs=None, num_gpus=1, rank=0, device=None,
                 progress=None, cache=True, gen_dataset_kwargs=None, generator_as_dataset=False):
        """Store the options.

        `G_kwargs`, `dataset_kwargs` and `gen_dataset_kwargs` default to an
        empty mapping. `None` is used as the sentinel so that instances never
        share one mutable default dict.
        """
        assert 0 <= rank < num_gpus
        self.G = G
        # EasyDict(...) makes a copy, so the caller's dict is not aliased.
        self.G_kwargs = dnnlib.EasyDict({} if G_kwargs is None else G_kwargs)
        self.dataset_kwargs = dnnlib.EasyDict({} if dataset_kwargs is None else dataset_kwargs)
        self.num_gpus = num_gpus
        self.rank = rank
        self.device = device if device is not None else torch.device('cuda', rank)
        # Only rank 0 inherits the caller's progress monitor; other ranks get a
        # fresh default ProgressMonitor (no tag, no progress_fn).
        self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
        self.cache = cache
        # Stored as given (not copied); a fresh dict is created when omitted.
        self.gen_dataset_kwargs = {} if gen_dataset_kwargs is None else gen_dataset_kwargs
        self.generator_as_dataset = generator_as_dataset

#----------------------------------------------------------------------------

# Loaded detectors, keyed by (url, device).
_feature_detector_cache = dict()

def get_feature_detector_name(url):
    """Return the last '/'-separated component of `url` with its extension removed."""
    return os.path.splitext(url.split('/')[-1])[0]

def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
    """Load (or fetch from the in-process cache) the detector stored at `url`.

    Paths ending in '.pkl' are loaded with `pickle`; anything else is loaded
    with `torch.jit.load(...).eval()`. The result is moved to `device` and
    memoized under `(url, device)`.

    With `num_gpus > 1`, rank 0 opens the URL first while the other ranks wait
    on a barrier, then rank 0 waits on the matching barrier while they load.
    """
    assert 0 <= rank < num_gpus
    key = (url, device)
    if key not in _feature_detector_cache:
        is_leader = (rank == 0)
        if not is_leader and num_gpus > 1:
            torch.distributed.barrier() # leader goes first
        with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
            if urlparse(url).path.endswith('.pkl'):
                # NOTE(security): pickle.load executes arbitrary code from the
                # file; only use detector URLs from trusted sources.
                _feature_detector_cache[key] = pickle.load(f).to(device)
            else:
                _feature_detector_cache[key] = torch.jit.load(f).eval().to(device)
        if is_leader and num_gpus > 1:
            torch.distributed.barrier() # others follow
    return _feature_detector_cache[key]

#----------------------------------------------------------------------------

class FeatureStats:
    """Accumulator for 2-D feature batches (rows = items, columns = features).

    Depending on the flags it keeps every feature row (`capture_all`) and/or
    running sums from which mean and covariance are derived
    (`capture_mean_cov`). At most `max_items` rows are accepted when
    `max_items` is not None.
    """

    def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
        self.capture_all = capture_all
        self.capture_mean_cov = capture_mean_cov
        self.max_items = max_items
        self.num_items = 0
        # The fields below are allocated lazily by set_num_features().
        self.num_features = None
        self.all_features = None
        self.raw_mean = None
        self.raw_cov = None

    def set_num_features(self, num_features):
        """Fix the feature dimensionality on first use; assert it matches afterwards."""
        if self.num_features is not None:
            assert num_features == self.num_features
        else:
            self.num_features = num_features
            self.all_features = []
            self.raw_mean = np.zeros([num_features], dtype=np.float64)
            self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)

    def is_full(self):
        """Return True once `max_items` is set and that many rows have been accepted."""
        return (self.max_items is not None) and (self.num_items >= self.max_items)

    def append(self, x):
        """Add the rows of the 2-D array-like `x`, truncating at `max_items`.

        `x` is converted to float32. Rows beyond the capacity are dropped; if
        the accumulator is already full the call is a no-op.
        """
        x = np.asarray(x, dtype=np.float32)
        assert x.ndim == 2
        if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
            if self.num_items >= self.max_items:
                return
            x = x[:self.max_items - self.num_items]

        self.set_num_features(x.shape[1])
        self.num_items += x.shape[0]
        if self.capture_all:
            self.all_features.append(x)
        if self.capture_mean_cov:
            # Sums are accumulated in float64; they are normalized in get_mean_cov().
            x64 = x.astype(np.float64)
            self.raw_mean += x64.sum(axis=0)
            self.raw_cov += x64.T @ x64

    def append_torch(self, x, num_gpus=1, rank=0):
        """Append a 2-D tensor, first gathering it from all ranks when `num_gpus > 1`.

        Each rank broadcasts its tensor in turn; the copies are stacked along
        dim 1 and flattened so that samples from different ranks interleave.
        Every rank therefore appends the same rows.
        """
        assert isinstance(x, torch.Tensor) and x.ndim == 2
        assert 0 <= rank < num_gpus
        if num_gpus > 1:
            ys = []
            for src in range(num_gpus):
                y = x.clone()
                torch.distributed.broadcast(y, src=src)
                ys.append(y)
            x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
        self.append(x.cpu().numpy())

    def get_all(self):
        """Return all captured rows concatenated along axis 0 (requires `capture_all`)."""
        assert self.capture_all
        return np.concatenate(self.all_features, axis=0)

    def get_all_torch(self):
        """Return `get_all()` wrapped as a torch tensor."""
        return torch.from_numpy(self.get_all())

    def get_mean_cov(self):
        """Return `(mean, cov)` of the accepted rows (requires `capture_mean_cov`).

        `cov` is computed as E[x x^T] - mean mean^T, i.e. normalized by
        `num_items` rather than `num_items - 1`.
        """
        assert self.capture_mean_cov
        mean = self.raw_mean / self.num_items
        cov = self.raw_cov / self.num_items
        cov = cov - np.outer(mean, mean)
        return mean, cov

    def save(self, pkl_file):
        """Pickle this object's `__dict__` to `pkl_file`."""
        with open(pkl_file, 'wb') as f:
            pickle.dump(self.__dict__, f)

    @staticmethod
    def load(pkl_file):
        """Rebuild a `FeatureStats` from a file written by `save()`.

        NOTE(review): this always constructs a base `FeatureStats`, even when
        invoked through a subclass (e.g. via `feature_stats_cls.load`).
        NOTE(security): uses pickle.load; only load files you trust.
        """
        with open(pkl_file, 'rb') as f:
            s = dnnlib.EasyDict(pickle.load(f))
        obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
        obj.__dict__.update(s)
        return obj

#----------------------------------------------------------------------------

class ProgressMonitor:
    """Throttled progress reporter.

    Optionally prints a status line (`verbose` and `tag` set) and/or forwards
    progress to `progress_fn(value, pfn_total)`, where `value` is mapped
    linearly into the `[pfn_lo, pfn_hi]` range.
    """

    def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
        self.tag = tag
        self.num_items = num_items
        self.verbose = verbose
        self.flush_interval = flush_interval
        self.progress_fn = progress_fn
        self.pfn_lo = pfn_lo
        self.pfn_hi = pfn_hi
        self.pfn_total = pfn_total
        self.start_time = time.time()
        self.batch_time = self.start_time
        self.batch_items = 0
        # Report the starting position immediately.
        if self.progress_fn is not None:
            self.progress_fn(self.pfn_lo, self.pfn_total)

    def update(self, cur_items: int):
        """Record that `cur_items` items are done; report if enough progress was made.

        A report is emitted when at least `flush_interval` items were processed
        since the previous report, or when `cur_items` reaches `num_items`.
        """
        assert (self.num_items is None) or (cur_items <= self.num_items), f"Wrong `items` values: cur_items={cur_items}, self.num_items={self.num_items}"
        if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
            return
        cur_time = time.time()
        total_time = cur_time - self.start_time
        # max(..., 1) guards against division by zero when no new items arrived.
        time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
        if (self.verbose) and (self.tag is not None):
            print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
        self.batch_time = cur_time
        self.batch_items = cur_items

        if (self.progress_fn is not None) and (self.num_items is not None):
            self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)

    def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
        """Create a child monitor covering the `[rel_lo, rel_hi]` fraction of this one's range.

        The child shares `verbose`, `progress_fn` and `pfn_total` with its parent.
        """
        return ProgressMonitor(
            tag             = tag,
            num_items       = num_items,
            flush_interval  = flush_interval,
            verbose         = self.verbose,
            progress_fn     = self.progress_fn,
            pfn_lo          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
            pfn_hi          = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
            pfn_total       = self.pfn_total,
        )

#----------------------------------------------------------------------------

@torch.no_grad()
def compute_feature_stats_for_dataset(
    opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64,
    data_loader_kwargs=None, max_items=None, temporal_detector=False,
    use_image_dataset=False, feature_stats_cls=FeatureStats, **stats_kwargs):
    """Run the detector over a random subset of the dataset and collect feature stats.

    Args:
        opts: a `MetricOptions`-like object (`dataset_kwargs`, `cache`, `rank`,
            `num_gpus`, `device`, `progress` are read).
        detector_url: passed to `get_feature_detector`.
        detector_kwargs: extra keyword arguments for the detector call.
        rel_lo, rel_hi: sub-range of `opts.progress` covered by this call.
        batch_size: DataLoader batch size.
        data_loader_kwargs: DataLoader options; defaults to
            `pin_memory=True, num_workers=3, prefetch_factor=2`.
        max_items: upper bound on the number of dataset items sampled and on
            the number of feature rows kept.
        temporal_detector: if True, batches are permuted from
            [batch, t, c, h, w] to [batch, c, t, h, w]; otherwise all leading
            dims are flattened into the batch dim.
        use_image_dataset: if True, the dataset kwargs are first converted with
            `video_to_image_dataset_kwargs`.
        feature_stats_cls: accumulator class to instantiate / load.
        **stats_kwargs: forwarded to `feature_stats_cls`.

    Returns:
        The populated (or cache-loaded) stats object.
    """
    dataset_kwargs = video_to_image_dataset_kwargs(opts.dataset_kwargs) if use_image_dataset else opts.dataset_kwargs
    dataset = dnnlib.util.construct_class_by_name(**dataset_kwargs)

    if data_loader_kwargs is None:
        data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)

    # Try to lookup from cache.
    cache_file = None
    if opts.cache:
        # Choose cache file name.
        # NOTE(review): the key does not include max_items, use_image_dataset or
        # temporal_detector, so calls differing only in those share a cache file.
        args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs, feature_stats_cls=feature_stats_cls.__name__)
        md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
        cache_tag = f'{dataset.name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
        cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')

        # Check if the file exists (all processes must agree).
        flag = os.path.isfile(cache_file) if opts.rank == 0 else False
        if opts.num_gpus > 1:
            flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
            torch.distributed.broadcast(tensor=flag, src=0)
            flag = (float(flag.cpu()) != 0)

        # Load.
        if flag:
            return feature_stats_cls.load(cache_file)

    # Initialize.
    num_items = len(dataset)
    data_length = len(dataset)
    if max_items is not None:
        num_items = min(num_items, max_items)
    stats = feature_stats_cls(max_items=num_items, **stats_kwargs)
    progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    # item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)] # original stylegan-v code
    item_subset = random.sample(range(data_length), num_items) # added by xin, randomly selected 2048 videos
    for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=batch_size, **data_loader_kwargs):
        images = batch['image']
        if temporal_detector:
            images = images.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]

            # images = images.float() / 255
            # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
            # images = torch.nn.functional.interpolate(images, size=(images.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
            # images = (images * 255).to(torch.uint8)
        else:
            images = images.view(-1, *images.shape[-3:]) # [-1, c, h, w]

        # Replicate a single channel three times along dim 1.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
        features = detector(images.to(opts.device), **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)

        # Stop as soon as the accumulator is full: further batches would be
        # discarded by append(). This matters when num_gpus > 1 (each step adds
        # num_gpus batches) or when one dataset item yields several feature
        # rows. append_torch() keeps stats.num_items identical on every rank,
        # so all ranks leave the loop at the same step.
        if stats.is_full():
            break

    # Save to cache.
    if cache_file is not None and opts.rank == 0:
        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        # Write to a uniquely named temp file, then rename over the target.
        temp_file = cache_file + '.' + uuid.uuid4().hex
        stats.save(temp_file)
        os.replace(temp_file, cache_file) # atomic
    return stats

#----------------------------------------------------------------------------

@torch.no_grad()
def compute_feature_stats_for_generator(
    opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size: int=16,
    batch_gen=None, jit=False, temporal_detector=False, num_video_frames: int=16,
    feature_stats_cls=FeatureStats, subsample_factor: int=1, **stats_kwargs):
    """Sample from `opts.G`, run the detector on the samples and collect feature stats.

    Args:
        opts: a `MetricOptions`-like object (`G`, `G_kwargs`, `dataset_kwargs`,
            `device`, `num_gpus`, `rank`, `progress` are read).
        detector_url: passed to `get_feature_detector`.
        detector_kwargs: extra keyword arguments for the detector call.
        rel_lo, rel_hi: sub-range of `opts.progress` covered by this call.
        batch_size: number of generator samples per detector call.
        batch_gen: number of samples per generator call; defaults to
            `min(batch_size, 4)` and must divide `batch_size`.
        jit: if True, the generation function is traced with `torch.jit.trace`.
        temporal_detector: if True, generator output is regrouped into
            [batch, c, t, h, w] clips of `num_video_frames` frames.
        num_video_frames: number of time steps requested per sample.
        feature_stats_cls: accumulator class to instantiate.
        subsample_factor: stride between consecutive requested time steps.
        **stats_kwargs: forwarded to `feature_stats_cls`; must set `max_items`.

    Returns:
        The populated stats object.
    """
    if batch_gen is None:
        batch_gen = min(batch_size, 4)
    assert batch_size % batch_gen == 0

    # Setup generator and load labels.
    G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
    dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)

    # Image generation func.
    def run_generator(z, c, t):
        img = G(z=z, c=c, t=t, **opts.G_kwargs)
        bt, c, h, w = img.shape
        if temporal_detector:
            img = img.view(bt // num_video_frames, num_video_frames, c, h, w) # [batch_size, t, c, h, w]
            img = img.permute(0, 2, 1, 3, 4).contiguous() # [batch_size, c, t, h, w]

            # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 128, 128), mode='trilinear', align_corners=False) # downsample
            # img = torch.nn.functional.interpolate(img, size=(img.shape[2], 256, 256), mode='trilinear', align_corners=False) # upsample
        # Scale by 127.5, offset by 128, clamp to [0, 255] and cast to uint8.
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return img

    # JIT.
    if jit:
        z = torch.zeros([batch_gen, G.z_dim], device=opts.device)
        c = torch.zeros([batch_gen, G.c_dim], device=opts.device)
        t = torch.zeros([batch_gen, G.cfg.sampling.num_frames_per_video], device=opts.device)
        run_generator = torch.jit.trace(run_generator, [z, c, t], check_trace=False)

    # Initialize.
    stats = feature_stats_cls(**stats_kwargs)
    assert stats.max_items is not None
    progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
    detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)

    # Main loop.
    while not stats.is_full():
        images = []
        for _i in range(batch_size // batch_gen):
            z = torch.randn([batch_gen, G.z_dim], device=opts.device)
            # Conditioning labels are taken from randomly chosen dataset entries.
            cond_sample_idx = [np.random.randint(len(dataset)) for _ in range(batch_gen)]
            c = [dataset.get_label(i) for i in cond_sample_idx]
            c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
            # Every sample requests the same time steps: 0, s, 2s, ... (s = subsample_factor).
            t = [list(range(0, num_video_frames * subsample_factor, subsample_factor)) for _i in range(batch_gen)]
            t = torch.from_numpy(np.stack(t)).pin_memory().to(opts.device)
            images.append(run_generator(z, c, t))
        images = torch.cat(images)
        # Replicate a single channel three times along dim 1.
        if images.shape[1] == 1:
            images = images.repeat([1, 3, *([1] * (images.ndim - 2))])
        features = detector(images, **detector_kwargs)
        stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
        progress.update(stats.num_items)
    return stats

#----------------------------------------------------------------------------

def rewrite_opts_for_gen_dataset(opts):
    """
    Updates dataset arguments in the opts to enable the second dataset stats computation.

    Returns a deep copy of `opts` whose `dataset_kwargs` is replaced by its
    `gen_dataset_kwargs` and whose `cache` is disabled; `opts` itself is not
    modified.
    """
    new_opts = copy.deepcopy(opts)
    new_opts.dataset_kwargs = new_opts.gen_dataset_kwargs
    new_opts.cache = False

    return new_opts

#----------------------------------------------------------------------------