from pathlib import Path

import torch
import torch_fidelity
from torch_fidelity.generative_model_modulewrapper import GenerativeModelModuleWrapper
import tops

from dp2 import utils

class GeneratorIteratorWrapper(GenerativeModelModuleWrapper):
    """Wraps a dp2 generator so torch_fidelity can draw samples from it.

    Each source batch from `dataloader` is reused for `n_diverse` forward
    passes before the next batch is fetched. If `zero_z` is set, the latent
    vector passed by torch_fidelity is zeroed in place.
    """

    def __init__(self, generator, dataloader, zero_z: bool, n_diverse: int):
        # Unwrap EMA containers to reach the underlying generator module.
        if isinstance(generator, utils.EMA):
            generator = generator.generator
        z_size = generator.z_channels
        super().__init__(generator, z_size, "normal", 0)
        self.zero_z = zero_z
        self.dataloader = iter(dataloader)
        self.n_diverse = n_diverse
        self.cur_div_idx = 0

    @torch.no_grad()
    def forward(self, z, **kwargs):
        # Fetch a new source batch only after n_diverse samples have been
        # generated from the current one.
        if self.cur_div_idx == 0:
            self.batch = next(self.dataloader)
        if self.zero_z:
            z = z.zero_()
        self.cur_div_idx += 1
        self.cur_div_idx = 0 if self.cur_div_idx == self.n_diverse else self.cur_div_idx
        with torch.cuda.amp.autocast(enabled=tops.AMP()):
            img = self.module(**self.batch)["img"]
            # Map images from normalized range to uint8 [0, 255], the format
            # torch_fidelity expects.
            img = (utils.denormalize_img(img) * 255).byte()
        return img

def compute_fid(generator, dataloader, real_directory, n_source, zero_z, n_diverse):
    generator = GeneratorIteratorWrapper(generator, dataloader, zero_z, n_diverse)
    batch_size = dataloader.batch_size
    # Round the total number of generated samples down to a multiple of the batch size.
    num_samples = (n_source * n_diverse) // batch_size * batch_size
    assert n_diverse >= 1
    # A zeroed latent gives deterministic outputs, so diverse sampling is disallowed.
    assert (not zero_z) or n_diverse == 1
    assert num_samples % batch_size == 0
    assert n_source <= batch_size * len(dataloader), (batch_size * len(dataloader), n_source, n_diverse)
    metrics = torch_fidelity.calculate_metrics(
        input1=generator,
        input2=real_directory,
        cuda=torch.cuda.is_available(),
        fid=True,
        # Cache statistics of the real images so repeated runs skip recomputation.
        input2_cache_name="_".join(Path(real_directory).parts) + "_cached",
        input1_model_num_samples=int(num_samples),
        batch_size=batch_size,
    )
    return metrics["frechet_inception_distance"]
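

# A minimal usage sketch (not part of the original file). All names below are
# placeholders, and the dataloader is assumed to yield dicts of keyword
# arguments accepted by the generator's forward:
#
#   fid = compute_fid(
#       generator=generator,            # dp2 generator, optionally EMA-wrapped
#       dataloader=val_dataloader,      # batches of conditioning inputs
#       real_directory="path/to/real_images",
#       n_source=5000,                  # source images drawn from the dataloader
#       zero_z=False,
#       n_diverse=2,                    # generated samples per source batch
#   )
#   print(f"FID: {fid:.3f}")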