|
import argparse |
|
import os |
|
|
|
import cv2 |
|
import numpy as np |
|
import torch |
|
from tqdm import tqdm |
|
|
|
from model import Generator |
|
from utils import ten2cv, cv2ten |
|
import random |
|
|
|
# Seed every RNG source the script may draw from (Python's `random`, NumPy,
# torch on the CPU, and torch on every CUDA device), in that order, with the
# same value so random draws are repeatable from run to run.
seed = 0
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)
del _seed_rng
|
|
|
|
|
def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
    """Render ``args.pics`` image pairs and write them to ``args.outdir``.

    For each index ``i`` a latent ``z`` is taken from the file named by
    ``args.sample_zs`` (loaded with ``torch.load``) or, when that is None,
    drawn from ``torch.randn(1, args.latent)``. Two images are generated from
    the same ``z`` with ``randomize_noise=False``. The first is generated
    without a style embedding. The second is generated with ``sample_style``
    passed as ``z_embed`` together with ``add_weight_index``. The two results
    are concatenated side by side (``axis=1``) and written as ``{i:06d}.jpg``.

    Args:
        args: parsed CLI namespace; reads ``sample_zs``, ``pics``, ``latent``,
            ``truncation`` and ``outdir``.
        g_ema: generator, called as ``g_ema([z], ...)``; it must return a
            2-tuple whose first element is passed to ``ten2cv`` (the second
            is ignored).
        device: torch device for freshly sampled latents; latents loaded from
            ``args.sample_zs`` are mapped onto it as well.
        mean_latent: forwarded as ``truncation_latent`` (may be None).
        sample_style: forwarded as ``z_embed`` for the second image of each pair.
        add_weight_index: forwarded unchanged to ``g_ema`` for the second image.

    Raises:
        ValueError: if ``args.sample_zs`` holds fewer than ``args.pics`` latents.
        OSError: if an output image cannot be written.
    """
    if args.sample_zs is not None:
        # map_location keeps pre-sampled latents on the model's device even if
        # they were saved from a different one (e.g. CPU).
        sample_zs = torch.load(args.sample_zs, map_location=device)
        # Fail up front instead of with an IndexError after partial output.
        if len(sample_zs) < args.pics:
            raise ValueError(
                f'{args.sample_zs} holds {len(sample_zs)} latents, '
                f'but --pics requests {args.pics}'
            )
    else:
        sample_zs = None

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            if sample_zs is not None:
                # NOTE(review): each entry is assumed to already be shaped like
                # torch.randn(1, args.latent) below -- confirm against the file format.
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            # Same latent rendered twice: without and with the style embedding.
            sample1, _ = g_ema([sample_z],
                               truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)
            sample2, _ = g_ema([sample_z], z_embed=sample_style, add_weight_index=add_weight_index,
                               truncation=args.truncation, truncation_latent=mean_latent,
                               return_latents=False, randomize_noise=False)

            out = np.concatenate([ten2cv(sample1), ten2cv(sample2)], axis=1)

            # os.path.join avoids writing to '/000000.jpg' when outdir is ''.
            out_path = os.path.join(args.outdir, f'{i:06d}.jpg')
            # cv2.imwrite reports failure through its return value, not an exception.
            if not cv2.imwrite(out_path, out):
                raise OSError(f'failed to write image: {out_path}')
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate BlendGAN image pairs: each output places an unstyled '
                    'sample next to the same sample rendered with a style embedding.'
    )
    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--pics', type=int, default=20, help='N_PICS')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    parser.add_argument('--ckpt', type=str, default='', help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None)
    parser.add_argument('--add_weight_index', type=int, default=6)
    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default="")
    # Previously hard-coded to 'cuda'; kept as the default for compatibility.
    parser.add_argument('--device', type=str, default='cuda', help='torch device to run on')
    args = parser.parse_args()

    if not args.ckpt:
        parser.error('--ckpt is required (path to BlendGAN checkpoint)')

    device = args.device

    # os.makedirs('') raises FileNotFoundError, so an empty --outdir (the
    # default) means "current directory". exist_ok makes a pre-check redundant.
    args.outdir = args.outdir or '.'
    os.makedirs(args.outdir, exist_ok=True)

    # Generator hyperparameters are hard-coded here, not read from the checkpoint.
    args.latent = 512
    args.n_mlp = 8

    # map_location puts every tensor in the checkpoint (incl. latent_avg, which is
    # used directly as the truncation centre) on the model's device.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(args.ckpt, map_location=device)
    model_dict = checkpoint['g_ema']
    latent_avg = checkpoint.get('latent_avg')
    # A truncation value stored in the checkpoint overrides --truncation.
    if 'truncation' in checkpoint:
        args.truncation = checkpoint['truncation']

    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.load_state_dict(model_dict)

    # mean_latent is only needed when truncation is active (< 1); prefer the
    # checkpoint's stored average, otherwise estimate it from the generator.
    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
            print('### generate mean_latent with \'g_ema.mean_latent\'')
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')

    if args.style_img is not None:
        img = cv2.imread(args.style_img, 1)
        # cv2.imread signals an unreadable/missing file by returning None.
        if img is None:
            parser.error(f'could not read style image: {args.style_img}')
        img = cv2ten(img, device)
        # Inference only: avoid building an autograd graph for the embedding.
        with torch.no_grad():
            sample_style = g_ema.get_z_embed(img)
    else:
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)

    print('Done!')
|
|