File size: 3,658 Bytes
0145b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os

import cv2
import numpy as np
import torch
from tqdm import tqdm

from model import Generator
from utils import ten2cv, cv2ten
import random

seed = 0

# Seed every random source this script draws from (Python's `random`, NumPy,
# PyTorch on CPU and on all CUDA devices) so sampling is repeatable across runs.
# torch.cuda.manual_seed_all is a no-op when CUDA is unavailable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)


def generate(args, g_ema, device, mean_latent, sample_style, add_weight_index):
    """Render ``args.pics`` image pairs with ``g_ema`` and save each pair as a JPEG.

    For every index ``i`` the generator is called twice on the same latent:
    once without a style embedding, and once with ``sample_style`` passed as
    ``z_embed`` together with ``add_weight_index``. Both outputs go through
    ``ten2cv`` and are concatenated along axis 1. Assuming ``ten2cv`` returns an
    H x W x C array (TODO confirm), the two images end up side by side. The
    result is written to ``<args.outdir>/<i zero-padded to 6 digits>.jpg``.

    Args:
        args: parsed CLI namespace. Reads ``sample_zs``, ``pics``, ``latent``,
            ``truncation`` and ``outdir``.
        g_ema: the generator; it is put into eval mode here.
        device: device on which random latents are drawn and onto which a
            ``--sample_zs`` file is loaded.
        mean_latent: forwarded as ``truncation_latent`` (``None`` when the
            caller disables truncation).
        sample_style: forwarded as ``z_embed`` for the second image of each pair.
        add_weight_index: forwarded unchanged to ``g_ema``.

    Raises:
        ValueError: if ``args.sample_zs`` holds fewer than ``args.pics`` entries.
        OSError: if OpenCV fails to write an output image.
    """
    sample_zs = None
    if args.sample_zs is not None:
        # map_location places the stored latents on the working device no
        # matter which device they were saved from.
        sample_zs = torch.load(args.sample_zs, map_location=device)
        # Fail before rendering anything rather than partway through the loop.
        if len(sample_zs) < args.pics:
            raise ValueError(
                f'{args.sample_zs} holds {len(sample_zs)} latents, '
                f'but --pics={args.pics} were requested'
            )

    # Keyword arguments shared by both generator calls.
    gen_kwargs = dict(
        truncation=args.truncation,
        truncation_latent=mean_latent,
        return_latents=False,
        randomize_noise=False,
    )

    with torch.no_grad():
        g_ema.eval()
        for i in tqdm(range(args.pics)):
            if sample_zs is not None:
                # NOTE(review): the per-entry shape of the saved latents is not
                # visible here; each entry is passed to g_ema as-is.
                sample_z = sample_zs[i]
            else:
                sample_z = torch.randn(1, args.latent, device=device)

            sample1, _ = g_ema([sample_z], **gen_kwargs)
            sample2, _ = g_ema([sample_z], z_embed=sample_style,
                               add_weight_index=add_weight_index, **gen_kwargs)

            out = np.concatenate([ten2cv(sample1), ten2cv(sample2)], axis=1)

            # os.path.join maps an empty outdir to the current directory
            # instead of producing a path rooted at '/'.
            out_path = os.path.join(args.outdir, f'{i:06d}.jpg')
            # cv2.imwrite reports failure via its return value, not an exception.
            if not cv2.imwrite(out_path, out):
                raise OSError(f'cv2.imwrite failed for {out_path}')


if __name__ == '__main__':
    # Script entry point: load a BlendGAN checkpoint, derive a style embedding
    # (from --style_img or at random) and hand everything to generate().
    device = 'cuda'

    parser = argparse.ArgumentParser()

    parser.add_argument('--size', type=int, default=1024)
    parser.add_argument('--pics', type=int, default=20, help='N_PICS')
    parser.add_argument('--truncation', type=float, default=0.75)
    parser.add_argument('--truncation_mean', type=int, default=4096)
    # The former default '' could never be loaded, so the flag is now required.
    parser.add_argument('--ckpt', type=str, required=True, help='path to BlendGAN checkpoint')
    parser.add_argument('--style_img', type=str, default=None, help='path to style image')
    parser.add_argument('--sample_zs', type=str, default=None)
    parser.add_argument('--add_weight_index', type=int, default=6)

    parser.add_argument('--channel_multiplier', type=int, default=2)
    parser.add_argument('--outdir', type=str, default='.',
                        help='output directory (created if missing)')

    args = parser.parse_args()

    # os.makedirs('') raises FileNotFoundError, so treat an empty --outdir as
    # the current directory. exist_ok makes a separate existence check redundant.
    args.outdir = args.outdir or '.'
    os.makedirs(args.outdir, exist_ok=True)

    args.latent = 512
    args.n_mlp = 8

    # map_location keeps every tensor in the checkpoint on the working device.
    checkpoint = torch.load(args.ckpt, map_location=device)
    model_dict = checkpoint['g_ema']
    latent_avg = checkpoint.get('latent_avg')
    # A truncation value stored in the checkpoint overrides --truncation.
    if 'truncation' in checkpoint:
        args.truncation = checkpoint['truncation']

    print('ckpt: ', args.ckpt)
    print('truncation: ', args.truncation)

    g_ema = Generator(
        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
    ).to(device)
    g_ema.load_state_dict(model_dict)
    # Switch to inference mode before mean_latent / get_z_embed run. A freshly
    # built module is in train mode.
    g_ema.eval()

    if args.truncation < 1:
        if latent_avg is not None:
            mean_latent = latent_avg
            print('### use mean_latent in ckpt["latent_avg"]')
        else:
            with torch.no_grad():
                mean_latent = g_ema.mean_latent(args.truncation_mean)
                print('### generate mean_latent with \'g_ema.mean_latent\'')
    else:
        mean_latent = None
        print('### args.truncation = 1, mean_latent is None')

    if args.style_img is not None:
        img = cv2.imread(args.style_img, cv2.IMREAD_COLOR)
        # cv2.imread signals a missing or unreadable file by returning None.
        if img is None:
            parser.error(f'could not read style image: {args.style_img}')
        img = cv2ten(img, device)
        # Inference only: no autograd graph is needed for the embedding.
        with torch.no_grad():
            sample_style = g_ema.get_z_embed(img)
    else:
        sample_style = torch.randn(1, args.latent, device=device)

    generate(args, g_ema, device, mean_latent, sample_style, args.add_weight_index)

    print('Done!')