# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
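# Generate images from a domain-expanded generator: for each random sample, render
# the base image and one edited image per (sampled) repurposed latent dimension.
#
# Example invocation (illustrative; script and checkpoint names are not from the repo):
#   python generate_images.py --ckpt expanded-snapshot.pkl --out_dir out --num 4 --truncation_psi 0.7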

import sys

# Make dnnlib, legacy, and expansion_utils importable when running from this subdirectory.
sys.path.append('..')

import argparse
from pathlib import Path

import torch
import torch.nn.functional as F
import torchvision.transforms as T

import dnnlib
import legacy
from expansion_utils import io_utils, latent_operations

def generate_images(ckpt, num_samples, truncation_psi):
    """Sample num_samples latents; return, per sample, the base image plus edited variants as PIL images."""
    device = torch.device('cuda')
    with dnnlib.util.open_url(ckpt) as f:
        snapshot_dict = legacy.load_network_pkl(f)
        G = snapshot_dict['G_ema'].to(device)

    # The expanded snapshot also stores the latent basis, the step size along
    # repurposed directions, and the indices of the repurposed dimensions.
    latent_basis = snapshot_dict['latent_basis'].to(device)
    subspace_distance = snapshot_dict['subspace_distance']
    repurposed_dims = snapshot_dict['repurposed_dims'].cpu()

    def norm_fn(tensor):
        # Min-max normalize the image tensor from its own range to [0, 1].
        min_from, max_from = tensor.min(), tensor.max()
        return (tensor - min_from) / (max_from - min_from)

    topil = T.ToPILImage(mode='RGB')

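    # Each entry of all_imgs is one sample's list of PIL images: the base image
    # first, followed by the rendered edits.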
    all_imgs = []
    for i in range(num_samples):
        per_sample_imgs = []
        z = torch.randn((1, G.z_dim), device=device)
        w = G.mapping(z, None, truncation_psi=truncation_psi)

        # Decompose w into a base latent plus one edited latent per repurposed dimension.
        base_w, edit_ws = latent_operations.project_to_subspaces(
            w, latent_basis, repurposed_dims, step_size=subspace_distance, mean=G.mapping.w_avg)
        edit_ws = edit_ws[0]  # Take the single edit step.
        base_img = G.synthesis(base_w, noise_mode='const')
        per_sample_imgs.append(topil(norm_fn(base_img.squeeze())))

        # Render every fourth repurposed dimension to keep the per-sample count small.
        for idx, (dim_num, edit_w) in enumerate(zip(repurposed_dims, edit_ws)):
            if idx % 4 == 0:
                edit_img = G.synthesis(edit_w, noise_mode='const')
                per_sample_imgs.append(topil(norm_fn(edit_img.squeeze())))

        all_imgs.append(per_sample_imgs)
    
    return all_imgs
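
# Optional helper (a minimal sketch, not part of the original script): tile one
# sample's images (base first, then edits) into a horizontal strip for inspection.
# Uses PIL directly, which torchvision's ToPILImage already depends on.
def make_contact_sheet(pil_imgs):
    from PIL import Image
    height = max(im.height for im in pil_imgs)
    sheet = Image.new('RGB', (sum(im.width for im in pil_imgs), height))
    x_offset = 0
    for im in pil_imgs:
        sheet.paste(im, (x_offset, 0))
        x_offset += im.width
    return sheet
# Example: make_contact_sheet(all_imgs[0]).save('sample_00000_strip.png')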

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument('--ckpt', help='Network pickle filename', required=True)
    parser.add_argument('--out_dir', help='Where to save the output images', type=str, required=True, metavar='DIR')
    parser.add_argument('--num', help='Number of independent samples', type=int, default=10)
    parser.add_argument('--truncation_psi', help='Coefficient for truncation', type=float, default=1)

    args = parser.parse_args()

    with torch.no_grad():
        all_imgs = generate_images(args.ckpt, args.num, args.truncation_psi)

    # Save each sample's base image followed by its edited variants.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, sample_imgs in enumerate(all_imgs):
        for j, img in enumerate(sample_imgs):
            img.save(out_dir / f'sample_{i:05d}_{j:03d}.png')