import argparse
import os
import random

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model import Generator
from utils import ten2cv, cv2ten
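
# Fix all RNG seeds so repeated runs produce identical samples.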
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
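    """Sample `args.pics` latent codes and, for each, save the plain face and
    its style-blended counterpart side by side under `args.outdir`."""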
    if args.sample_zs is not None:
        # Reuse a saved batch of latent codes for reproducible comparisons.
        sample_zs = torch.load(args.sample_zs)
    else:
        sample_zs = None

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            if sample_zs is not None:
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            # Plain sample: no style embedding injected.
            sample1, _ = g_ema(
                [sample_z],
                truncation=args.truncation, truncation_latent=mean_latent,
                return_latents=False, randomize_noise=False,
            )
            # Stylized sample: same latent, with z_embed supplying the style
            # code; add_weight_index is passed through to the generator's
            # blending module.
            sample2, _ = g_ema(
                [sample_z], z_embed=sample_style, add_weight_index=add_weight_index,
                truncation=args.truncation, truncation_latent=mean_latent,
                return_latents=False, randomize_noise=False,
            )

            sample1 = ten2cv(sample1)
            sample2 = ten2cv(sample2)
            # Original on the left, stylized on the right.
            out = np.concatenate([sample1, sample2], axis=1)
            cv2.imwrite(f'{args.outdir}/{str(i).zfill(6)}.jpg', out)
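

# Example invocation (the script and checkpoint names below are illustrative):
#   python generate_image_pairs.py --ckpt pretrained_models/blendgan.pt \
#       --style_img test_imgs/style.jpg --outdir results --pics 10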
if __name__ == '__main__':
    device = 'cuda'

    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=1024, help='output image size')
    parser.add_argument('--pics', type=int, default=20, help='number of images to generate')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default='', help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None, help='path to saved latent codes (optional)')
    parser.add_argument('--add_weight_index', type=int, default=6)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default='', help='directory for output images')
    args = parser.parse_args()
    os.makedirs(args.outdir, exist_ok=True)
    args.latent = 512
    args.n_mlp = 8

    checkpoint = torch.load(args.ckpt)
    model_dict = checkpoint['g_ema']

    # Prefer the statistics stored with the checkpoint, if present.
    latent_avg = checkpoint.get('latent_avg')
    if 'truncation' in checkpoint:
        args.truncation = checkpoint['truncation']
    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)
    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.load_state_dict(model_dict)
    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
            print("### generate mean_latent with 'g_ema.mean_latent'")
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')
    if args.style_img is not None:
        # Encode the reference image into a style embedding.
        img = cv2.imread(args.style_img, 1)
        img = cv2ten(img, device)
        sample_style = g_ema.get_z_embed(img)
    else:
        # No reference image: draw a random style code instead.
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)
    print('Done!')