Spaces:
Runtime error
Runtime error
import math | |
import os | |
from draggan.viz import renderer | |
import torch | |
from torch import optim | |
from torch.nn import functional as F | |
from torchvision import transforms | |
from PIL import Image | |
from tqdm import tqdm | |
import dataclasses | |
import draggan.dnnlib as dnnlib | |
from .lpips import util | |
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at training progress ``t``.

    ``t`` is the fraction of completed steps.  The rate is ``initial_lr``
    scaled by two factors: a linear warm-up over the first ``rampup``
    fraction and a cosine ramp-down over the last ``rampdown`` fraction.
    """
    # Cosine factor: 1 until the final `rampdown` stretch, then eases to 0 at t == 1.
    decay = 0.5 - 0.5 * math.cos(math.pi * min(1, (1 - t) / rampdown))
    # Linear factor: rises from 0 to 1 across the first `rampup` stretch.
    warmup = min(1, t / rampup)
    return initial_lr * (decay * warmup)
def make_image(tensor):
    """Convert an NCHW image batch in [-1, 1] to a uint8 NHWC numpy array.

    Values are clamped to [-1, 1], mapped to [0, 255] and truncated to
    uint8.  The input tensor is left untouched.
    """
    # Out-of-place clamp: `detach()` shares storage with `tensor`, so the
    # previous `clamp_` silently clamped the caller's tensor as well.
    pixels = tensor.detach().clamp(min=-1, max=1)
    pixels = pixels.add(1).div(2).mul(255).type(torch.uint8)
    return pixels.permute(0, 2, 3, 1).to("cpu").numpy()
class InverseConfig:
    """Hyperparameters for latent inversion (``inverse_image``).

    Defaults live on the class, so ``InverseConfig()`` and class-level
    overrides behave as before.  Keyword arguments override individual
    fields for one instance only, e.g. ``InverseConfig(step=300)``;
    unknown names raise ``TypeError``.
    """

    lr_warmup = 0.05  # warm-up fraction; mirrors get_lr's default `rampup`
    lr_decay = 0.25  # ramp-down fraction; mirrors get_lr's default `rampdown`
    lr = 0.1  # peak Adam learning rate
    noise = 0.05  # initial scale of the Gaussian perturbation on w, relative to w's spread
    noise_decay = 0.75  # fraction of steps over which that perturbation decays to zero
    step = 1000  # number of optimization steps
    noise_regularize = 1e5  # weight of the noise-buffer regularization term
    mse = 0.1  # weight of the pixel MSE term

    def __init__(self, **overrides):
        for key, value in overrides.items():
            # Only allow names that already exist as public class defaults.
            if key.startswith('_') or not hasattr(type(self), key):
                raise TypeError(f"InverseConfig got an unexpected field {key!r}")
            setattr(self, key, value)
def _prepare_inversion_target(image, size):
    """Resize, center-crop to ``size`` x ``size`` and normalize *image*.

    PIL images are converted with ``ToTensor`` first.  Tensor inputs skip
    that step, so they are presumably already float CHW in [0, 1] (TODO
    confirm with callers).  ``Normalize(0.5, 0.5)`` maps [0, 1] to [-1, 1].
    """
    if torch.is_tensor(image):
        image = transforms.functional.resize(image, size)
        steps = [transforms.CenterCrop(size)]
    else:
        steps = [
            transforms.Resize(size),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ]
    steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    return transforms.Compose(steps)(image)


def _init_noise_buffers(noises):
    """Redraw each noise buffer from N(0, 1) in place and make it trainable."""
    with torch.no_grad():
        for buf in noises.values():
            buf.copy_(torch.randn_like(buf))
    for buf in noises.values():
        buf.requires_grad_(True)


def _release_noise_buffers(noises):
    """Stop tracking gradients on the noise buffers and drop any stored grads."""
    for buf in noises.values():
        buf.requires_grad_(False)
        buf.grad = None


def _noise_regularization(noises):
    """Sum of squared one-pixel-shift autocorrelations over a pooled pyramid of each noise map."""
    reg_loss = 0.0
    for v in noises.values():
        noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
        while True:
            reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
            reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
            if noise.shape[2] <= 8:
                break
            noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=None,
    device='cuda:0'
):
    """Project *image* into the latent space of *g_ema* by optimizing w.

    Optimization starts from the mean of ``n_mean_latent`` mapped random z
    vectors.  It minimizes a perceptual loss plus a pixel MSE loss between
    the synthesized and target images, both at most 256 px, plus a
    noise-buffer regularizer.  A Gaussian perturbation is added to w and
    decays over the first ``noise_decay`` fraction of steps.

    The generator's ``noise_const`` buffers are re-randomized, optimized
    jointly and renormalized after every step.  They are modified in place
    on *g_ema* and left with ``requires_grad=False`` on return.

    Args:
        g_ema: generator exposing ``z_dim``, ``mapping`` (with ``num_ws``)
            and ``synthesis``; presumably a StyleGAN2-style network.
        image: PIL image, or a tensor (presumably float CHW in [0, 1]).
        percept: callable ``percept(generated, target)`` returning a
            perceptual distance (LPIPS-like; confirm).
        image_size: nominal image resolution; the target is resized to
            ``min(image_size, 256)``.
        w_plus: optimize one w per synthesis layer instead of a single w.
        config: hyperparameter object (``InverseConfig``-like); ``None``
            means ``InverseConfig()``.
        device: device for the target image and the sampled latents.

    Returns:
        dict with ``"latent"`` (optimized w, detached), ``"sample"`` (the
        generator output for that latent, computed without grad) and
        ``"real"`` (the preprocessed target batch).
    """
    args = InverseConfig() if config is None else config
    n_mean_latent = 10000
    resize = min(image_size, 256)
    imgs = _prepare_inversion_target(image, resize).unsqueeze(0).to(device)

    # Estimate the mean w and its spread from random z samples.
    with torch.no_grad():
        noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
        w_samples = g_ema.mapping(noise_sample, None)[:, :1, :]
        w_avg = w_samples.mean(0)
        w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5

    noises = {
        name: buf
        for (name, buf) in g_ema.synthesis.named_buffers()
        if 'noise_const' in name
    }
    # The old loop only rebound its loop variable, so the buffers were never
    # randomized nor made trainable; mutate them in place instead.
    _init_noise_buffers(noises)

    w_opt = w_avg.detach().clone()
    if w_plus:
        w_opt = w_opt.repeat(1, g_ema.mapping.num_ws, 1)
    w_opt.requires_grad = True

    optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)
    # Honour the config's schedule fields; fall back to get_lr's defaults
    # for duck-typed configs that lack them.
    rampdown = getattr(args, 'lr_decay', 0.25)
    rampup = getattr(args, 'lr_warmup', 0.05)

    try:
        pbar = tqdm(range(args.step))
        for i in pbar:
            t = i / args.step
            lr = get_lr(t, args.lr, rampdown=rampdown, rampup=rampup)
            optimizer.param_groups[0]["lr"] = lr

            noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
            ws = w_opt + torch.randn_like(w_opt) * noise_strength
            if not w_plus:
                ws = ws.repeat([1, g_ema.mapping.num_ws, 1])
            img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)

            # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
            if img_gen.shape[2] > 256:
                img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')

            p_loss = percept(img_gen, imgs).sum()
            reg_loss = _noise_regularization(noises)
            mse_loss = F.mse_loss(img_gen, imgs)
            loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Keep each noise buffer zero-mean and unit-variance.
            with torch.no_grad():
                for buf in noises.values():
                    buf -= buf.mean()
                    buf *= buf.square().mean().rsqrt()

            if (i + 1) % 100 == 0:
                pbar.set_description(
                    (
                        f"perceptual: {p_loss.item():.4f}; noise regularize: {float(reg_loss):.4f};"
                        f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
                    )
                )
    finally:
        # g_ema is shared with the caller: do not leave its buffers tracking grads.
        _release_noise_buffers(noises)

    # Use the final optimized latent; the old every-100-steps snapshot list was
    # empty (IndexError) for step < 100 and stale when step % 100 != 0.
    latent = w_opt.detach().clone()
    ws = latent if w_plus else latent.repeat([1, g_ema.mapping.num_ws, 1])
    with torch.no_grad():
        img_gen = g_ema.synthesis(ws, noise_mode='const')
    return {
        "latent": latent,
        "sample": img_gen,
        "real": imgs,
    }
def toogle_grad(model, flag=True):
    """Set ``requires_grad`` to *flag* on every parameter of *model*."""
    params = model.parameters()
    for param in params:
        param.requires_grad = flag
class PTI:
    """PTI-style inversion: find a latent pivot, then fine-tune the generator on it.

    ``train`` first runs ``inverse_image`` to obtain a latent for the target
    image.  It then optimizes all generator parameters with Adam so that
    the synthesis from that fixed latent matches the target under a
    perceptual + ``l2_lambda``-weighted MSE loss.
    """

    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr
        self.percept = percept  # callable(generated, real) -> perceptual distance (LPIPS-like; confirm)

    def cacl_loss(self, percept, generated_image, real_image):
        """Return summed perceptual loss plus ``l2_lambda`` times the pixel MSE."""
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        return p_loss + self.l2_lambda * mse_loss

    def _device(self):
        """Return the device of the generator's parameters.

        Falls back to CUDA/CPU availability only when the generator has no
        parameters.  Deriving it from the model avoids device mismatches,
        e.g. a CPU model on a machine that has CUDA.
        """
        param = next(self.g_ema.parameters(), None)
        if param is not None:
            return param.device
        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def train(self, img, w_plus=False):
        """Invert *img*, fine-tune ``g_ema`` around the pivot, and return the result.

        Args:
            img: PIL image, or a tensor (presumably float CHW in [0, 1]).
            w_plus: invert into per-layer w instead of a single w.

        Returns:
            ``(generated_image, ws)``: the tuned generator's output, computed
            without grad, and the per-layer latent it was generated from.
        """
        device = self._device()
        resolution = self.g_ema.img_resolution

        if torch.is_tensor(img):
            # The resized tensor is also what gets handed to inverse_image.
            img = transforms.functional.resize(img, resolution)
            steps = [transforms.CenterCrop(resolution)]
        else:
            steps = [
                transforms.Resize(resolution),
                transforms.CenterCrop(resolution),
                transforms.ToTensor(),
            ]
        steps.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
        real_img = transforms.Compose(steps)(img).to(device).unsqueeze(0)

        inversed_result = inverse_image(self.g_ema, img, self.percept, resolution, w_plus, device=device)
        w_pivot = inversed_result['latent']
        ws = w_pivot if w_plus else w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])

        toogle_grad(self.g_ema, True)
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.cacl_loss(self.percept, generated_image, real_img)
            pbar.set_description(
                (
                    f"loss: {loss.item():.4f}"
                )
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
        return generated_image, ws