Johannes Kolbe committed
Commit: ff2b8e3
Parent(s): 3b72cdb

add original sefa files back in
- SessionState.py +129 -0
- interface.py +128 -0
- models/__init__.py +114 -0
- models/pggan_discriminator.py +402 -0
- models/pggan_generator.py +338 -0
- models/stylegan2_discriminator.py +468 -0
- models/stylegan2_generator.py +996 -0
- models/stylegan_discriminator.py +530 -0
- models/stylegan_generator.py +869 -0
- models/sync_op.py +18 -0
- sefa.py +145 -0
- utils.py +509 -0
SessionState.py
ADDED
@@ -0,0 +1,129 @@
"""Adds pre-session state to StreamLit.

This file is borrowed from
https://gist.github.com/tvst/036da038ab3e999a64497f42de966a92
"""

# pylint: disable=protected-access

try:
    import streamlit.ReportThread as ReportThread
    from streamlit.server.Server import Server
except ModuleNotFoundError:
    # Streamlit >= 0.65.0
    import streamlit.report_thread as ReportThread
    from streamlit.server.server import Server


class SessionState(object):
    """Hack to add per-session state to Streamlit.

    Usage
    -----

    >>> import SessionState
    >>>
    >>> session_state = SessionState.get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """

    def __init__(self, **kwargs):
        """A new SessionState object.

        Parameters
        ----------
        **kwargs : any
            Default values for the session state.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name = 'Mary'
        ''
        >>> session_state.favorite_color
        'black'

        """
        for key, val in kwargs.items():
            setattr(self, key, val)


def get(**kwargs):
    """Gets a SessionState object for the current session.

    Creates a new object if necessary.

    Parameters
    ----------
    **kwargs : any
        Default values you want to add to the session state, if we're creating a
        new one.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since you set user_name above, next time your script runs this will be the
    result:
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'

    """
    # Hack to get the session object from Streamlit.

    ctx = ReportThread.get_report_ctx()

    this_session = None

    current_server = Server.get_current()
    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = Server.get_current()._session_infos.values()
    else:
        session_infos = Server.get_current()._session_info_by_id.values()

    for session_info in session_infos:
        s = session_info.session
        if (
                # Streamlit < 0.54.0
                (hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
                or
                # Streamlit >= 0.54.0
                (not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
                or
                # Streamlit >= 0.65.2
                (not hasattr(s, '_main_dg') and
                 s._uploaded_file_mgr == ctx.uploaded_file_mgr)
        ):
            this_session = s

    if this_session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            'Are you doing something fancy with threads?')

    # Got the session object! Now let's attach some state into it.

    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state

# pylint: enable=protected-access
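
Note: a minimal hypothetical usage sketch (not part of this commit), assuming the script is launched through `streamlit run` so that a report context exists; the pattern mirrors the docstring above.

    # Hypothetical example, not part of the commit.
    import streamlit as st
    import SessionState

    state = SessionState.get(counter=0)   # created once per browser session
    if st.button('Increment'):
        state.counter += 1
    st.write(f'Counter for this session: {state.counter}')
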
interface.py
ADDED
@@ -0,0 +1,128 @@
# python 3.7
"""Demo."""

import numpy as np
import torch
import streamlit as st
import SessionState

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight


@st.cache(allow_output_mutation=True, show_spinner=False)
def get_model(model_name):
    """Gets model by name."""
    return load_generator(model_name)


@st.cache(allow_output_mutation=True, show_spinner=False)
def factorize_model(model, layer_idx):
    """Factorizes semantics from target layers of the given model."""
    return factorize_weight(model, layer_idx)


def sample(model, gan_type, num=1):
    """Samples latent codes."""
    codes = torch.randn(num, model.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = model.layer0.pixel_norm(codes)
    elif gan_type == 'stylegan':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.7,
                                 trunc_layers=8)
    elif gan_type == 'stylegan2':
        codes = model.mapping(codes)['w']
        codes = model.truncation(codes,
                                 trunc_psi=0.5,
                                 trunc_layers=18)
    codes = codes.detach().cpu().numpy()
    return codes


@st.cache(allow_output_mutation=True, show_spinner=False)
def synthesize(model, gan_type, code):
    """Synthesizes an image with the give code."""
    if gan_type == 'pggan':
        image = model(to_tensor(code))['image']
    elif gan_type in ['stylegan', 'stylegan2']:
        image = model.synthesis(to_tensor(code))['image']
    image = postprocess(image)[0]
    return image


def main():
    """Main function (loop for StreamLit)."""
    st.title('Closed-Form Factorization of Latent Semantics in GANs')
    st.sidebar.title('Options')
    reset = st.sidebar.button('Reset')

    model_name = st.sidebar.selectbox(
        'Model to Interpret',
        ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256',
         'pggan_celebahq1024'])

    model = get_model(model_name)
    gan_type = parse_gan_type(model)
    layer_idx = st.sidebar.selectbox(
        'Layers to Interpret',
        ['all', '0-1', '2-5', '6-13'])
    layers, boundaries, eigen_values = factorize_model(model, layer_idx)

    num_semantics = st.sidebar.number_input(
        'Number of semantics', value=10, min_value=0, max_value=None, step=1)
    steps = {sem_idx: 0 for sem_idx in range(num_semantics)}
    if gan_type == 'pggan':
        max_step = 5.0
    elif gan_type == 'stylegan':
        max_step = 2.0
    elif gan_type == 'stylegan2':
        max_step = 15.0
    for sem_idx in steps:
        eigen_value = eigen_values[sem_idx]
        steps[sem_idx] = st.sidebar.slider(
            f'Semantic {sem_idx:03d} (eigen value: {eigen_value:.3f})',
            value=0.0,
            min_value=-max_step,
            max_value=max_step,
            step=0.04 * max_step if not reset else 0.0)

    image_placeholder = st.empty()
    button_placeholder = st.empty()

    try:
        base_codes = np.load(f'latent_codes/{model_name}_latents.npy')
    except FileNotFoundError:
        base_codes = sample(model, gan_type)

    state = SessionState.get(model_name=model_name,
                             code_idx=0,
                             codes=base_codes[0:1])
    if state.model_name != model_name:
        state.model_name = model_name
        state.code_idx = 0
        state.codes = base_codes[0:1]

    if button_placeholder.button('Random', key=0):
        state.code_idx += 1
        if state.code_idx < base_codes.shape[0]:
            state.codes = base_codes[state.code_idx][np.newaxis]
        else:
            state.codes = sample(model, gan_type)

    code = state.codes.copy()
    for sem_idx, step in steps.items():
        if gan_type == 'pggan':
            code += boundaries[sem_idx:sem_idx + 1] * step
        elif gan_type in ['stylegan', 'stylegan2']:
            code[:, layers, :] += boundaries[sem_idx:sem_idx + 1] * step
    image = synthesize(model, gan_type, code)
    image_placeholder.image(image / 255.0)


if __name__ == '__main__':
    main()
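
Note: the demo above is meant to be launched with `streamlit run interface.py`. Below is a hypothetical headless sketch (not part of this commit) of the same editing step, reusing the helpers imported and defined above; it assumes a CUDA device, a StyleGAN checkpoint reachable by `load_generator`, and tolerates the `st.cache` warning emitted outside a Streamlit run.

    # Hypothetical example, not part of the commit.
    model = load_generator('stylegan_animeface512')
    gan_type = parse_gan_type(model)
    layers, boundaries, eigen_values = factorize_weight(model, 'all')

    code = sample(model, gan_type, num=1)          # latent code(s) to edit
    code[:, layers, :] += boundaries[0:1] * 1.5    # move along the top semantic
    image = synthesize(model, gan_type, code)      # post-processed image array
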
models/__init__.py
ADDED
@@ -0,0 +1,114 @@
# python3.7
"""Collects all available models together."""

from .model_zoo import MODEL_ZOO
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator

__all__ = [
    'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator',
    'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator',
    'build_generator', 'build_discriminator', 'build_model'
]

_GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']
_MODULES_ALLOWED = ['generator', 'discriminator']


def build_generator(gan_type, resolution, **kwargs):
    """Builds generator by GAN type.

    Args:
        gan_type: GAN type to which the generator belong.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the generator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Generator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_discriminator(gan_type, resolution, **kwargs):
    """Builds discriminator by GAN type.

    Args:
        gan_type: GAN type to which the discriminator belong.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Discriminator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_model(gan_type, module, resolution, **kwargs):
    """Builds a GAN module (generator/discriminator/etc).

    Args:
        gan_type: GAN type to which the model belong.
        module: GAN module to build, such as generator or discrimiantor.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `module` is not supported.
        NotImplementedError: If the `module` is not implemented.
    """
    if module not in _MODULES_ALLOWED:
        raise ValueError(f'Invalid module: `{module}`!\n'
                         f'Modules allowed: {_MODULES_ALLOWED}.')

    if module == 'generator':
        return build_generator(gan_type, resolution, **kwargs)
    if module == 'discriminator':
        return build_discriminator(gan_type, resolution, **kwargs)
    raise NotImplementedError(f'Unsupported module `{module}`!')


def parse_gan_type(module):
    """Parses GAN type of a given module.

    Args:
        module: The module to parse GAN type from.

    Returns:
        A string, indicating the GAN type.

    Raises:
        ValueError: If the GAN type is unknown.
    """
    if isinstance(module, (PGGANGenerator, PGGANDiscriminator)):
        return 'pggan'
    if isinstance(module, (StyleGANGenerator, StyleGANDiscriminator)):
        return 'stylegan'
    if isinstance(module, (StyleGAN2Generator, StyleGAN2Discriminator)):
        return 'stylegan2'
    raise ValueError(f'Unable to parse GAN type from type `{type(module)}`!')
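
Note: a small hypothetical usage sketch (not part of this commit) for the factory helpers above; it assumes the full `models` package (including `model_zoo.py`, which this commit does not add) is importable.

    # Hypothetical example, not part of the commit.
    from models import build_model, parse_gan_type

    generator = build_model('pggan', 'generator', resolution=128)
    discriminator = build_model('pggan', 'discriminator', resolution=128)
    assert parse_gan_type(generator) == 'pggan'
    assert parse_gan_type(discriminator) == 'pggan'
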
models/pggan_discriminator.py
ADDED
@@ -0,0 +1,402 @@
# python3.7
"""Contains the implementation of discriminator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANDiscriminator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANDiscriminator(nn.Module):
    """Defines the discriminator network in PGGAN.

    NOTE: The discriminator takes images with `RGB` channel order and pixel
    range [-1, 1] as inputs.

    Settings for the network:

    (1) resolution: The resolution of the input image.
    (2) image_channels: Number of channels of the input image. (default: 3)
    (3) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (4) fused_scale: Whether to fused `conv2d` and `downsample` together,
        resulting in `conv2d` with strides. (default: False)
    (5) use_wscale: Whether to use weight scaling. (default: True)
    (6) minibatch_std_group_size: Group size for the minibatch standard
        deviation layer. 0 means disable. (default: 16)
    (7) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 image_channels=3,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 minibatch_std_group_size=16,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.image_channels = image_channels
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.minibatch_std_group_size = minibatch_std_group_size
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            res = 2 ** res_log2
            block_idx = self.final_res_log2 - res_log2

            # Input convolution layer for each resolution.
            self.add_module(
                f'input{block_idx}',
                ConvBlock(in_channels=self.image_channels,
                          out_channels=self.get_nf(res),
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale))
            self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
                f'FromRGB_lod{block_idx}/weight')
            self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
                f'FromRGB_lod{block_idx}/bias')

            # Convolution block for each resolution (except the last one).
            if res != self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res),
                              use_wscale=self.use_wscale))
                tf_layer0_name = 'Conv0'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    ConvBlock(in_channels=self.get_nf(res),
                              out_channels=self.get_nf(res // 2),
                              downsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer1_name = 'Conv1_down' if self.fused_scale else 'Conv1'

            # Convolution block for last resolution.
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(
                        in_channels=self.get_nf(res),
                        out_channels=self.get_nf(res),
                        use_wscale=self.use_wscale,
                        minibatch_std_group_size=self.minibatch_std_group_size))
                tf_layer0_name = 'Conv'
                self.add_module(
                    f'layer{2 * block_idx + 1}',
                    DenseBlock(in_channels=self.get_nf(res) * res * res,
                               out_channels=self.get_nf(res // 2),
                               use_wscale=self.use_wscale))
                tf_layer1_name = 'Dense0'

            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer0_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer0_name}/bias')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer1_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer1_name}/bias')

        # Final dense block.
        self.add_module(
            f'layer{2 * block_idx + 2}',
            DenseBlock(in_channels=self.get_nf(res // 2),
                       out_channels=1 + self.label_size,
                       use_wscale=self.use_wscale,
                       wscale_gain=1.0,
                       activation_type='linear'))
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
            f'{res}x{res}/Dense1/weight')
        self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
            f'{res}x{res}/Dense1/bias')

        self.downsample = DownsamplingLayer()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, image, lod=None, **_unused_kwargs):
        expected_shape = (self.image_channels, self.resolution, self.resolution)
        if image.ndim != 4 or image.shape[1:] != expected_shape:
            raise ValueError(f'The input tensor should be with shape '
                             f'[batch_size, channel, height, width], where '
                             f'`channel` equals to {self.image_channels}, '
                             f'`height`, `width` equal to {self.resolution}!\n'
                             f'But `{image.shape}` is received!')

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        lod = self.lod.cpu().tolist()
        for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
            block_idx = current_lod = self.final_res_log2 - res_log2
            if current_lod <= lod < current_lod + 1:
                x = self.__getattr__(f'input{block_idx}')(image)
            elif current_lod - 1 < lod < current_lod:
                alpha = lod - np.floor(lod)
                x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
                     x * (1 - alpha))
            if lod < current_lod + 1:
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if lod > current_lod:
                image = self.downsample(image)
        x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
        return x


class MiniBatchSTDLayer(nn.Module):
    """Implements the minibatch standard deviation layer."""

    def __init__(self, group_size=16, epsilon=1e-8):
        super().__init__()
        self.group_size = group_size
        self.epsilon = epsilon

    def forward(self, x):
        if self.group_size <= 1:
            return x
        group_size = min(self.group_size, x.shape[0])  # [NCHW]
        y = x.view(group_size, -1, x.shape[1], x.shape[2], x.shape[3])  # [GMCHW]
        y = y - torch.mean(y, dim=0, keepdim=True)  # [GMCHW]
        y = torch.mean(y ** 2, dim=0)  # [MCHW]
        y = torch.sqrt(y + self.epsilon)  # [MCHW]
        y = torch.mean(y, dim=[1, 2, 3], keepdim=True)  # [M111]
        y = y.repeat(group_size, 1, x.shape[2], x.shape[3])  # [N1HW]
        return torch.cat([x, y], dim=1)


class DownsamplingLayer(nn.Module):
    """Implements the downsampling layer.

    Basically, this layer can be used to downsample feature maps with average
    pooling.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.avg_pool2d(x,
                            kernel_size=self.scale_factor,
                            stride=self.scale_factor,
                            padding=0)


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes minibatch standard deviation layer (if
    needed), convolutional layer, activation layer, and downsampling layer (
    if needed) in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 downsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu',
                 minibatch_std_group_size=0):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            downsample: Whether to downsample the result after convolution.
                (default: False)
            fused_scale: Whether to fused `conv2d` and `downsample` together,
                resulting in `conv2d` with strides. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)
            minibatch_std_group_size: Group size for the minibatch standard
                deviation layer. 0 means disable. (default: 0)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        if minibatch_std_group_size > 1:
            in_channels = in_channels + 1
            self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size)
        else:
            self.mbstd = nn.Identity()

        if downsample and not fused_scale:
            self.downsample = DownsamplingLayer()
        else:
            self.downsample = nn.Identity()

        if downsample and fused_scale:
            self.use_stride = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_stride = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.mbstd(x)
        weight = self.weight * self.wscale
        if self.use_stride:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
        x = F.conv2d(x,
                     weight=weight,
                     bias=self.bias,
                     stride=self.stride,
                     padding=self.padding)
        x = self.activate(x)
        x = self.downsample(x)
        return x


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes fully-connected layer, and activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            x = x.view(x.shape[0], -1)
        x = F.linear(x, weight=self.weight * self.wscale, bias=self.bias)
        x = self.activate(x)
        return x
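
Note: a hypothetical shape-check sketch (not part of this commit) for the discriminator above; the smallest allowed resolution keeps it fast on CPU. With the default `lod` buffer at 0, the full-resolution path is used and the output holds one logit per image.

    # Hypothetical example, not part of the commit.
    import torch
    from models.pggan_discriminator import PGGANDiscriminator

    disc = PGGANDiscriminator(resolution=8)
    images = torch.randn(4, 3, 8, 8)   # RGB images, pixel range [-1, 1]
    scores = disc(images)
    print(scores.shape)                # expected: torch.Size([4, 1])
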
models/pggan_generator.py
ADDED
@@ -0,0 +1,338 @@
# python3.7
"""Contains the implementation of generator described in PGGAN.

Paper: https://arxiv.org/pdf/1710.10196.pdf

Official TensorFlow implementation:
https://github.com/tkarras/progressive_growing_of_gans
"""

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['PGGANGenerator']

# Resolutions allowed.
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]

# Initial resolution.
_INIT_RES = 4

# Default gain factor for weight scaling.
_WSCALE_GAIN = np.sqrt(2.0)


class PGGANGenerator(nn.Module):
    """Defines the generator network in PGGAN.

    NOTE: The synthesized images are with `RGB` channel order and pixel range
    [-1, 1].

    Settings for the network:

    (1) resolution: The resolution of the output image.
    (2) z_space_dim: The dimension of the latent space, Z. (default: 512)
    (3) image_channels: Number of channels of the output image. (default: 3)
    (4) final_tanh: Whether to use `tanh` to control the final pixel range.
        (default: False)
    (5) label_size: Size of the additional label for conditional generation.
        (default: 0)
    (6) fused_scale: Whether to fused `upsample` and `conv2d` together,
        resulting in `conv2d_transpose`. (default: False)
    (7) use_wscale: Whether to use weight scaling. (default: True)
    (8) fmaps_base: Factor to control number of feature maps for each layer.
        (default: 16 << 10)
    (9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
    """

    def __init__(self,
                 resolution,
                 z_space_dim=512,
                 image_channels=3,
                 final_tanh=False,
                 label_size=0,
                 fused_scale=False,
                 use_wscale=True,
                 fmaps_base=16 << 10,
                 fmaps_max=512):
        """Initializes with basic settings.

        Raises:
            ValueError: If the `resolution` is not supported.
        """
        super().__init__()

        if resolution not in _RESOLUTIONS_ALLOWED:
            raise ValueError(f'Invalid resolution: `{resolution}`!\n'
                             f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')

        self.init_res = _INIT_RES
        self.init_res_log2 = int(np.log2(self.init_res))
        self.resolution = resolution
        self.final_res_log2 = int(np.log2(self.resolution))
        self.z_space_dim = z_space_dim
        self.image_channels = image_channels
        self.final_tanh = final_tanh
        self.label_size = label_size
        self.fused_scale = fused_scale
        self.use_wscale = use_wscale
        self.fmaps_base = fmaps_base
        self.fmaps_max = fmaps_max

        # Number of convolutional layers.
        self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2

        # Level of detail (used for progressive training).
        self.register_buffer('lod', torch.zeros(()))
        self.pth_to_tf_var_mapping = {'lod': 'lod'}

        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            res = 2 ** res_log2
            block_idx = res_log2 - self.init_res_log2

            # First convolution layer for each resolution.
            if res == self.init_res:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.z_space_dim + self.label_size,
                              out_channels=self.get_nf(res),
                              kernel_size=self.init_res,
                              padding=self.init_res - 1,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Dense'
            else:
                self.add_module(
                    f'layer{2 * block_idx}',
                    ConvBlock(in_channels=self.get_nf(res // 2),
                              out_channels=self.get_nf(res),
                              upsample=True,
                              fused_scale=self.fused_scale,
                              use_wscale=self.use_wscale))
                tf_layer_name = 'Conv0_up' if self.fused_scale else 'Conv0'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Second convolution layer for each resolution.
            self.add_module(
                f'layer{2 * block_idx + 1}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.get_nf(res),
                          use_wscale=self.use_wscale))
            tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
                f'{res}x{res}/{tf_layer_name}/weight')
            self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
                f'{res}x{res}/{tf_layer_name}/bias')

            # Output convolution layer for each resolution.
            self.add_module(
                f'output{block_idx}',
                ConvBlock(in_channels=self.get_nf(res),
                          out_channels=self.image_channels,
                          kernel_size=1,
                          padding=0,
                          use_wscale=self.use_wscale,
                          wscale_gain=1.0,
                          activation_type='linear'))
            self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
            self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
                f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')

        self.upsample = UpsamplingLayer()
        self.final_activate = nn.Tanh() if self.final_tanh else nn.Identity()

    def get_nf(self, res):
        """Gets number of feature maps according to current resolution."""
        return min(self.fmaps_base // res, self.fmaps_max)

    def forward(self, z, label=None, lod=None, **_unused_kwargs):
        if z.ndim != 2 or z.shape[1] != self.z_space_dim:
            raise ValueError(f'Input latent code should be with shape '
                             f'[batch_size, latent_dim], where '
                             f'`latent_dim` equals to {self.z_space_dim}!\n'
                             f'But `{z.shape}` is received!')
        z = self.layer0.pixel_norm(z)
        if self.label_size:
            if label is None:
                raise ValueError(f'Model requires an additional label '
                                 f'(with size {self.label_size}) as input, '
                                 f'but no label is received!')
            if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
                raise ValueError(f'Input label should be with shape '
                                 f'[batch_size, label_size], where '
                                 f'`batch_size` equals to that of '
                                 f'latent codes ({z.shape[0]}) and '
                                 f'`label_size` equals to {self.label_size}!\n'
                                 f'But `{label.shape}` is received!')
            z = torch.cat((z, label), dim=1)

        lod = self.lod.cpu().tolist() if lod is None else lod
        if lod + self.init_res_log2 > self.final_res_log2:
            raise ValueError(f'Maximum level-of-detail (lod) is '
                             f'{self.final_res_log2 - self.init_res_log2}, '
                             f'but `{lod}` is received!')

        x = z.view(z.shape[0], self.z_space_dim + self.label_size, 1, 1)
        for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
            current_lod = self.final_res_log2 - res_log2
            if lod < current_lod + 1:
                block_idx = res_log2 - self.init_res_log2
                x = self.__getattr__(f'layer{2 * block_idx}')(x)
                x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
            if current_lod - 1 < lod <= current_lod:
                image = self.__getattr__(f'output{block_idx}')(x)
            elif current_lod < lod < current_lod + 1:
                alpha = np.ceil(lod) - lod
                image = (self.__getattr__(f'output{block_idx}')(x) * alpha +
                         self.upsample(image) * (1 - alpha))
            elif lod >= current_lod + 1:
                image = self.upsample(image)
        image = self.final_activate(image)

        results = {
            'z': z,
            'label': label,
            'image': image,
        }
        return results


class PixelNormLayer(nn.Module):
    """Implements pixel-wise feature vector normalization layer."""

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.eps = epsilon

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
        return x / norm


class UpsamplingLayer(nn.Module):
    """Implements the upsampling layer.

    Basically, this layer can be used to upsample feature maps with nearest
    neighbor interpolation.
    """

    def __init__(self, scale_factor=2):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        if self.scale_factor <= 1:
            return x
        return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


class ConvBlock(nn.Module):
    """Implements the convolutional block.

    Basically, this block executes pixel-wise normalization layer, upsampling
    layer (if needed), convolutional layer, and activation layer in sequence.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 add_bias=True,
                 upsample=False,
                 fused_scale=False,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            kernel_size: Size of the convolutional kernels. (default: 3)
            stride: Stride parameter for convolution operation. (default: 1)
            padding: Padding parameter for convolution operation. (default: 1)
            add_bias: Whether to add bias onto the convolutional result.
                (default: True)
            upsample: Whether to upsample the input tensor before convolution.
                (default: False)
            fused_scale: Whether to fused `upsample` and `conv2d` together,
                resulting in `conv2d_transpose`. (default: False)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()

        self.pixel_norm = PixelNormLayer()

        if upsample and not fused_scale:
            self.upsample = UpsamplingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            weight_shape = (in_channels, out_channels, kernel_size, kernel_size)
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
            self.stride = stride
            self.padding = padding

        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape))
            self.wscale = wscale
        else:
            self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
            self.wscale = 1.0

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        x = self.pixel_norm(x)
        x = self.upsample(x)
        weight = self.weight * self.wscale
        if self.use_conv2d_transpose:
            weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
            weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                      weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
            x = F.conv_transpose2d(x,
                                   weight=weight,
                                   bias=self.bias,
                                   stride=self.stride,
                                   padding=self.padding)
        else:
            x = F.conv2d(x,
                         weight=weight,
                         bias=self.bias,
                         stride=self.stride,
                         padding=self.padding)
        x = self.activate(x)
        return x
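
Note: a hypothetical smoke-test sketch (not part of this commit) for the generator above, mirroring the discriminator example; the forward pass returns a dict holding the latent code, the label, and the synthesized image.

    # Hypothetical example, not part of the commit.
    import torch
    from models.pggan_generator import PGGANGenerator

    gen = PGGANGenerator(resolution=8)
    z = torch.randn(4, gen.z_space_dim)
    outputs = gen(z)
    print(outputs['image'].shape)      # expected: torch.Size([4, 3, 8, 8])
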
models/stylegan2_discriminator.py
ADDED
@@ -0,0 +1,468 @@
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of discriminator described in StyleGAN2.
|
3 |
+
|
4 |
+
Compared to that of StyleGAN, the discriminator in StyleGAN2 mainly adds skip
|
5 |
+
connections, increases model size and disables progressive growth. This script
|
6 |
+
ONLY supports config F in the original paper.
|
7 |
+
|
8 |
+
Paper: https://arxiv.org/pdf/1912.04958.pdf
|
9 |
+
|
10 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
11 |
+
"""
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
__all__ = ['StyleGAN2Discriminator']
|
20 |
+
|
21 |
+
# Resolutions allowed.
|
22 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
23 |
+
|
24 |
+
# Initial resolution.
|
25 |
+
_INIT_RES = 4
|
26 |
+
|
27 |
+
# Architectures allowed.
|
28 |
+
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
29 |
+
|
30 |
+
# Default gain factor for weight scaling.
|
31 |
+
_WSCALE_GAIN = 1.0
|
32 |
+
|
33 |
+
|
34 |
+
class StyleGAN2Discriminator(nn.Module):
|
35 |
+
"""Defines the discriminator network in StyleGAN2.
|
36 |
+
|
37 |
+
NOTE: The discriminator takes images with `RGB` channel order and pixel
|
38 |
+
range [-1, 1] as inputs.
|
39 |
+
|
40 |
+
Settings for the network:
|
41 |
+
|
42 |
+
(1) resolution: The resolution of the input image.
|
43 |
+
(2) image_channels: Number of channels of the input image. (default: 3)
|
44 |
+
(3) label_size: Size of the additional label for conditional generation.
|
45 |
+
(default: 0)
|
46 |
+
(4) architecture: Type of architecture. Support `origin`, `skip`, and
|
47 |
+
`resnet`. (default: `resnet`)
|
48 |
+
(5) use_wscale: Whether to use weight scaling. (default: True)
|
49 |
+
(6) minibatch_std_group_size: Group size for the minibatch standard
|
50 |
+
deviation layer. 0 means disable. (default: 4)
|
51 |
+
(7) minibatch_std_channels: Number of new channels after the minibatch
|
52 |
+
standard deviation layer. (default: 1)
|
53 |
+
(8) fmaps_base: Factor to control number of feature maps for each layer.
|
54 |
+
(default: 32 << 10)
|
55 |
+
(9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
56 |
+
"""
|
57 |
+
|
58 |
+
def __init__(self,
|
59 |
+
resolution,
|
60 |
+
image_channels=3,
|
61 |
+
label_size=0,
|
62 |
+
architecture='resnet',
|
63 |
+
use_wscale=True,
|
64 |
+
minibatch_std_group_size=4,
|
65 |
+
minibatch_std_channels=1,
|
66 |
+
fmaps_base=32 << 10,
|
67 |
+
fmaps_max=512):
|
68 |
+
"""Initializes with basic settings.
|
69 |
+
|
70 |
+
Raises:
|
71 |
+
ValueError: If the `resolution` is not supported, or `architecture`
|
72 |
+
is not supported.
|
73 |
+
"""
|
74 |
+
super().__init__()
|
75 |
+
|
76 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
77 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
78 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
79 |
+
if architecture not in _ARCHITECTURES_ALLOWED:
|
80 |
+
raise ValueError(f'Invalid architecture: `{architecture}`!\n'
|
81 |
+
f'Architectures allowed: '
|
82 |
+
f'{_ARCHITECTURES_ALLOWED}.')
|
83 |
+
|
84 |
+
self.init_res = _INIT_RES
|
85 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
86 |
+
self.resolution = resolution
|
87 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
88 |
+
self.image_channels = image_channels
|
89 |
+
self.label_size = label_size
|
90 |
+
self.architecture = architecture
|
91 |
+
self.use_wscale = use_wscale
|
92 |
+
self.minibatch_std_group_size = minibatch_std_group_size
|
93 |
+
self.minibatch_std_channels = minibatch_std_channels
|
94 |
+
self.fmaps_base = fmaps_base
|
95 |
+
self.fmaps_max = fmaps_max
|
96 |
+
|
97 |
+
self.pth_to_tf_var_mapping = {}
|
98 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
99 |
+
res = 2 ** res_log2
|
100 |
+
block_idx = self.final_res_log2 - res_log2
|
101 |
+
|
102 |
+
# Input convolution layer for each resolution (if needed).
|
103 |
+
if res_log2 == self.final_res_log2 or self.architecture == 'skip':
|
104 |
+
self.add_module(
|
105 |
+
f'input{block_idx}',
|
106 |
+
ConvBlock(in_channels=self.image_channels,
|
107 |
+
out_channels=self.get_nf(res),
|
108 |
+
kernel_size=1,
|
109 |
+
use_wscale=self.use_wscale))
|
110 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
|
111 |
+
f'{res}x{res}/FromRGB/weight')
|
112 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
|
113 |
+
f'{res}x{res}/FromRGB/bias')
|
114 |
+
|
115 |
+
# Convolution block for each resolution (except the last one).
|
116 |
+
if res != self.init_res:
|
117 |
+
self.add_module(
|
118 |
+
f'layer{2 * block_idx}',
|
119 |
+
ConvBlock(in_channels=self.get_nf(res),
|
120 |
+
out_channels=self.get_nf(res),
|
121 |
+
use_wscale=self.use_wscale))
|
122 |
+
tf_layer0_name = 'Conv0'
|
123 |
+
self.add_module(
|
124 |
+
f'layer{2 * block_idx + 1}',
|
125 |
+
ConvBlock(in_channels=self.get_nf(res),
|
126 |
+
out_channels=self.get_nf(res // 2),
|
127 |
+
scale_factor=2,
|
128 |
+
use_wscale=self.use_wscale))
|
129 |
+
tf_layer1_name = 'Conv1_down'
|
130 |
+
|
131 |
+
if self.architecture == 'resnet':
|
132 |
+
layer_name = f'skip_layer{block_idx}'
|
133 |
+
self.add_module(
|
134 |
+
layer_name,
|
135 |
+
ConvBlock(in_channels=self.get_nf(res),
|
136 |
+
out_channels=self.get_nf(res // 2),
|
137 |
+
kernel_size=1,
|
138 |
+
add_bias=False,
|
139 |
+
scale_factor=2,
|
140 |
+
use_wscale=self.use_wscale,
|
141 |
+
activation_type='linear'))
|
142 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
143 |
+
f'{res}x{res}/Skip/weight')
|
144 |
+
|
145 |
+
# Convolution block for last resolution.
|
146 |
+
else:
|
147 |
+
self.add_module(
|
148 |
+
f'layer{2 * block_idx}',
|
149 |
+
ConvBlock(in_channels=self.get_nf(res),
|
150 |
+
out_channels=self.get_nf(res),
|
151 |
+
use_wscale=self.use_wscale,
|
152 |
+
minibatch_std_group_size=minibatch_std_group_size,
|
153 |
+
minibatch_std_channels=minibatch_std_channels))
|
154 |
+
tf_layer0_name = 'Conv'
|
155 |
+
self.add_module(
|
156 |
+
f'layer{2 * block_idx + 1}',
|
157 |
+
DenseBlock(in_channels=self.get_nf(res) * res * res,
|
158 |
+
out_channels=self.get_nf(res // 2),
|
159 |
+
use_wscale=self.use_wscale))
|
160 |
+
tf_layer1_name = 'Dense0'
|
161 |
+
|
162 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
|
163 |
+
f'{res}x{res}/{tf_layer0_name}/weight')
|
164 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
|
165 |
+
f'{res}x{res}/{tf_layer0_name}/bias')
|
166 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
|
167 |
+
f'{res}x{res}/{tf_layer1_name}/weight')
|
168 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
|
169 |
+
f'{res}x{res}/{tf_layer1_name}/bias')
|
170 |
+
|
171 |
+
# Final dense block.
|
172 |
+
self.add_module(
|
173 |
+
f'layer{2 * block_idx + 2}',
|
174 |
+
DenseBlock(in_channels=self.get_nf(res // 2),
|
175 |
+
out_channels=max(self.label_size, 1),
|
176 |
+
use_wscale=self.use_wscale,
|
177 |
+
activation_type='linear'))
|
178 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
|
179 |
+
f'Output/weight')
|
180 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
|
181 |
+
f'Output/bias')
|
182 |
+
|
183 |
+
if self.architecture == 'skip':
|
184 |
+
self.downsample = DownsamplingLayer()
|
185 |
+
|
186 |
+
def get_nf(self, res):
|
187 |
+
"""Gets number of feature maps according to current resolution."""
|
188 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
189 |
+
|
190 |
+
def forward(self, image, label=None, **_unused_kwargs):
|
191 |
+
expected_shape = (self.image_channels, self.resolution, self.resolution)
|
192 |
+
if image.ndim != 4 or image.shape[1:] != expected_shape:
|
193 |
+
raise ValueError(f'The input tensor should be with shape '
|
194 |
+
f'[batch_size, channel, height, width], where '
|
195 |
+
f'`channel` equals to {self.image_channels}, '
|
196 |
+
f'`height`, `width` equal to {self.resolution}!\n'
|
197 |
+
f'But `{image.shape}` is received!')
|
198 |
+
if self.label_size:
|
199 |
+
if label is None:
|
200 |
+
raise ValueError(f'Model requires an additional label '
|
201 |
+
f'(with size {self.label_size}) as inputs, '
|
202 |
+
f'but no label is received!')
|
203 |
+
batch_size = image.shape[0]
|
204 |
+
if label.ndim != 2 or label.shape != (batch_size, self.label_size):
|
205 |
+
raise ValueError(f'Input label should be with shape '
|
206 |
+
f'[batch_size, label_size], where '
|
207 |
+
f'`batch_size` equals to that of '
|
208 |
+
f'images ({image.shape[0]}) and '
|
209 |
+
f'`label_size` equals to {self.label_size}!\n'
|
210 |
+
f'But `{label.shape}` is received!')
|
211 |
+
|
212 |
+
x = self.input0(image)
|
213 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
214 |
+
block_idx = self.final_res_log2 - res_log2
|
215 |
+
if self.architecture == 'skip' and block_idx > 0:
|
216 |
+
image = self.downsample(image)
|
217 |
+
x = x + self.__getattr__(f'input{block_idx}')(image)
|
218 |
+
if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
|
219 |
+
residual = self.__getattr__(f'skip_layer{block_idx}')(x)
|
220 |
+
x = self.__getattr__(f'layer{2 * block_idx}')(x)
|
221 |
+
x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
|
222 |
+
if self.architecture == 'resnet' and res_log2 != self.init_res_log2:
|
223 |
+
x = (x + residual) / np.sqrt(2.0)
|
224 |
+
x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
|
225 |
+
|
226 |
+
if self.label_size:
|
227 |
+
x = torch.sum(x * label, dim=1, keepdim=True)
|
228 |
+
return x
|
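A minimal usage sketch for the discriminator assembled above (illustrative only, not part of the committed file; the class name `StyleGAN2Discriminator` and its constructor defaults are assumed from the file name and the attributes used above, and the weights here are randomly initialized rather than pretrained):

import torch

D = StyleGAN2Discriminator(resolution=256)    # assumed constructor; see the settings documented earlier in this file
images = torch.randn(4, 3, 256, 256)          # NCHW, expected pixel range [-1, 1]
scores = D(images)                            # shape [4, 1] when label_size == 0
print(scores.shape)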
229 |
+
|
230 |
+
|
231 |
+
class MiniBatchSTDLayer(nn.Module):
|
232 |
+
"""Implements the minibatch standard deviation layer."""
|
233 |
+
|
234 |
+
def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
|
235 |
+
super().__init__()
|
236 |
+
self.group_size = group_size
|
237 |
+
self.new_channels = new_channels
|
238 |
+
self.epsilon = epsilon
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
if self.group_size <= 1:
|
242 |
+
return x
|
243 |
+
ng = min(self.group_size, x.shape[0])
|
244 |
+
nc = self.new_channels
|
245 |
+
temp_c = x.shape[1] // nc # [NCHW]
|
246 |
+
y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
|
247 |
+
y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
|
248 |
+
y = torch.mean(y ** 2, dim=0) # [MncHW]
|
249 |
+
y = torch.sqrt(y + self.epsilon) # [MncHW]
|
250 |
+
y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
|
251 |
+
y = torch.mean(y, dim=2) # [Mn11]
|
252 |
+
y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
|
253 |
+
return torch.cat([x, y], dim=1)
|
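A small sketch of what `MiniBatchSTDLayer` does at runtime (illustrative, not part of the committed file): it appends `new_channels` extra feature maps holding the average within-group standard deviation, so the discriminator can detect a collapse in sample diversity.

import torch

mbstd = MiniBatchSTDLayer(group_size=4, new_channels=1)
x = torch.randn(8, 16, 4, 4)     # batch of 8 samples, 16 channels
y = mbstd(x)
print(y.shape)                   # torch.Size([8, 17, 4, 4]): one std channel appended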
254 |
+
|
255 |
+
|
256 |
+
class DownsamplingLayer(nn.Module):
|
257 |
+
"""Implements the downsampling layer.
|
258 |
+
|
259 |
+
This layer can also be used as filtering by setting `scale_factor` as 1.
|
260 |
+
"""
|
261 |
+
|
262 |
+
def __init__(self, scale_factor=2, kernel=(1, 3, 3, 1), extra_padding=0):
|
263 |
+
super().__init__()
|
264 |
+
assert scale_factor >= 1
|
265 |
+
self.scale_factor = scale_factor
|
266 |
+
|
267 |
+
if extra_padding != 0:
|
268 |
+
assert scale_factor == 1
|
269 |
+
|
270 |
+
if kernel is None:
|
271 |
+
kernel = np.ones((scale_factor), dtype=np.float32)
|
272 |
+
else:
|
273 |
+
kernel = np.array(kernel, dtype=np.float32)
|
274 |
+
assert kernel.ndim == 1
|
275 |
+
kernel = np.outer(kernel, kernel)
|
276 |
+
kernel = kernel / np.sum(kernel)
|
277 |
+
assert kernel.ndim == 2
|
278 |
+
assert kernel.shape[0] == kernel.shape[1]
|
279 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
280 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
281 |
+
self.kernel = self.kernel.flip(0, 1)
|
282 |
+
padding = kernel.shape[2] - scale_factor + extra_padding
|
283 |
+
self.padding = ((padding + 1) // 2, padding // 2,
|
284 |
+
(padding + 1) // 2, padding // 2)
|
285 |
+
|
286 |
+
def forward(self, x):
|
287 |
+
assert x.ndim == 4
|
288 |
+
channels = x.shape[1]
|
289 |
+
x = x.view(-1, 1, x.shape[2], x.shape[3])
|
290 |
+
x = F.pad(x, self.padding, mode='constant', value=0)
|
291 |
+
x = F.conv2d(x, self.kernel, stride=self.scale_factor)
|
292 |
+
x = x.view(-1, channels, x.shape[2], x.shape[3])
|
293 |
+
return x
|
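For reference, a quick shape check of `DownsamplingLayer` with its default binomial filter (illustrative, not part of the committed file):

import torch

down = DownsamplingLayer(scale_factor=2)   # default (1, 3, 3, 1) filtering kernel
x = torch.randn(2, 8, 64, 64)
y = down(x)
print(y.shape)                             # torch.Size([2, 8, 32, 32])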
294 |
+
|
295 |
+
|
296 |
+
class ConvBlock(nn.Module):
|
297 |
+
"""Implements the convolutional block.
|
298 |
+
|
299 |
+
Basically, this block executes minibatch standard deviation layer (if
|
300 |
+
needed), filtering layer (if needed), convolutional layer, and activation
|
301 |
+
layer in sequence.
|
302 |
+
"""
|
303 |
+
|
304 |
+
def __init__(self,
|
305 |
+
in_channels,
|
306 |
+
out_channels,
|
307 |
+
kernel_size=3,
|
308 |
+
add_bias=True,
|
309 |
+
scale_factor=1,
|
310 |
+
filtering_kernel=(1, 3, 3, 1),
|
311 |
+
use_wscale=True,
|
312 |
+
wscale_gain=_WSCALE_GAIN,
|
313 |
+
lr_mul=1.0,
|
314 |
+
activation_type='lrelu',
|
315 |
+
minibatch_std_group_size=0,
|
316 |
+
minibatch_std_channels=1):
|
317 |
+
"""Initializes with block settings.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
in_channels: Number of channels of the input tensor.
|
321 |
+
out_channels: Number of channels of the output tensor.
|
322 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
323 |
+
add_bias: Whether to add bias onto the convolutional result.
|
324 |
+
(default: True)
|
325 |
+
scale_factor: Scale factor for downsampling. `1` means skip
|
326 |
+
downsampling. (default: 1)
|
327 |
+
filtering_kernel: Kernel used for filtering before downsampling.
|
328 |
+
(default: (1, 3, 3, 1))
|
329 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
330 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
331 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
332 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
333 |
+
(default: `lrelu`)
|
334 |
+
minibatch_std_group_size: Group size for the minibatch standard
|
335 |
+
deviation layer. 0 means disable. (default: 0)
|
336 |
+
minibatch_std_channels: Number of new channels after the minibatch
|
337 |
+
standard deviation layer. (default: 1)
|
338 |
+
|
339 |
+
Raises:
|
340 |
+
NotImplementedError: If the `activation_type` is not supported.
|
341 |
+
"""
|
342 |
+
super().__init__()
|
343 |
+
|
344 |
+
if minibatch_std_group_size > 1:
|
345 |
+
in_channels = in_channels + minibatch_std_channels
|
346 |
+
self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
|
347 |
+
new_channels=minibatch_std_channels)
|
348 |
+
else:
|
349 |
+
self.mbstd = nn.Identity()
|
350 |
+
|
351 |
+
if scale_factor > 1:
|
352 |
+
extra_padding = kernel_size - scale_factor
|
353 |
+
self.filter = DownsamplingLayer(scale_factor=1,
|
354 |
+
kernel=filtering_kernel,
|
355 |
+
extra_padding=extra_padding)
|
356 |
+
self.stride = scale_factor
|
357 |
+
self.padding = 0 # Padding is done in `DownsamplingLayer`.
|
358 |
+
else:
|
359 |
+
self.filter = nn.Identity()
|
360 |
+
assert kernel_size % 2 == 1
|
361 |
+
self.stride = 1
|
362 |
+
self.padding = kernel_size // 2
|
363 |
+
|
364 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
365 |
+
fan_in = kernel_size * kernel_size * in_channels
|
366 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
367 |
+
if use_wscale:
|
368 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
369 |
+
self.wscale = wscale * lr_mul
|
370 |
+
else:
|
371 |
+
self.weight = nn.Parameter(
|
372 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
373 |
+
self.wscale = lr_mul
|
374 |
+
|
375 |
+
if add_bias:
|
376 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
377 |
+
else:
|
378 |
+
self.bias = None
|
379 |
+
self.bscale = lr_mul
|
380 |
+
|
381 |
+
if activation_type == 'linear':
|
382 |
+
self.activate = nn.Identity()
|
383 |
+
self.activate_scale = 1.0
|
384 |
+
elif activation_type == 'lrelu':
|
385 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
386 |
+
self.activate_scale = np.sqrt(2.0)
|
387 |
+
else:
|
388 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
389 |
+
f'`{activation_type}`!')
|
390 |
+
|
391 |
+
def forward(self, x):
|
392 |
+
x = self.mbstd(x)
|
393 |
+
x = self.filter(x)
|
394 |
+
weight = self.weight * self.wscale
|
395 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
396 |
+
x = F.conv2d(x,
|
397 |
+
weight=weight,
|
398 |
+
bias=bias,
|
399 |
+
stride=self.stride,
|
400 |
+
padding=self.padding)
|
401 |
+
x = self.activate(x) * self.activate_scale
|
402 |
+
return x
|
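The equalized-learning-rate trick used by `ConvBlock` above (and by `DenseBlock` below) can be summarized in a few lines. This is only a sketch; the gain constant `_WSCALE_GAIN` is defined elsewhere in this file, so the value 1.0 below is a placeholder assumption.

import numpy as np
import torch

out_c, in_c, k = 64, 32, 3
weight = torch.randn(out_c, in_c, k, k)   # parameters are stored at unit variance
fan_in = k * k * in_c
wscale = 1.0 / np.sqrt(fan_in)            # gain / sqrt(fan_in), with the gain assumed to be 1.0 here
runtime_weight = weight * wscale          # the tensor that F.conv2d actually receives
print(runtime_weight.std())               # roughly 1 / sqrt(fan_in)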
403 |
+
|
404 |
+
|
405 |
+
class DenseBlock(nn.Module):
|
406 |
+
"""Implements the dense block.
|
407 |
+
|
408 |
+
Basically, this block executes fully-connected layer and activation layer.
|
409 |
+
"""
|
410 |
+
|
411 |
+
def __init__(self,
|
412 |
+
in_channels,
|
413 |
+
out_channels,
|
414 |
+
add_bias=True,
|
415 |
+
use_wscale=True,
|
416 |
+
wscale_gain=_WSCALE_GAIN,
|
417 |
+
lr_mul=1.0,
|
418 |
+
activation_type='lrelu'):
|
419 |
+
"""Initializes with block settings.
|
420 |
+
|
421 |
+
Args:
|
422 |
+
in_channels: Number of channels of the input tensor.
|
423 |
+
out_channels: Number of channels of the output tensor.
|
424 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
425 |
+
(default: True)
|
426 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
427 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
428 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
429 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
430 |
+
(default: `lrelu`)
|
431 |
+
|
432 |
+
Raises:
|
433 |
+
NotImplementedError: If the `activation_type` is not supported.
|
434 |
+
"""
|
435 |
+
super().__init__()
|
436 |
+
weight_shape = (out_channels, in_channels)
|
437 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
438 |
+
if use_wscale:
|
439 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
440 |
+
self.wscale = wscale * lr_mul
|
441 |
+
else:
|
442 |
+
self.weight = nn.Parameter(
|
443 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
444 |
+
self.wscale = lr_mul
|
445 |
+
|
446 |
+
if add_bias:
|
447 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
448 |
+
else:
|
449 |
+
self.bias = None
|
450 |
+
self.bscale = lr_mul
|
451 |
+
|
452 |
+
if activation_type == 'linear':
|
453 |
+
self.activate = nn.Identity()
|
454 |
+
self.activate_scale = 1.0
|
455 |
+
elif activation_type == 'lrelu':
|
456 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
457 |
+
self.activate_scale = np.sqrt(2.0)
|
458 |
+
else:
|
459 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
460 |
+
f'`{activation_type}`!')
|
461 |
+
|
462 |
+
def forward(self, x):
|
463 |
+
if x.ndim != 2:
|
464 |
+
x = x.view(x.shape[0], -1)
|
465 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
466 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
467 |
+
x = self.activate(x) * self.activate_scale
|
468 |
+
return x
|
models/stylegan2_generator.py
ADDED
@@ -0,0 +1,996 @@
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of generator described in StyleGAN2.
|
3 |
+
|
4 |
+
Compared to that of StyleGAN, the generator in StyleGAN2 mainly introduces style
|
5 |
+
demodulation, adds skip connections, increases model size, and disables
|
6 |
+
progressive growth. This script ONLY supports config F in the original paper.
|
7 |
+
|
8 |
+
Paper: https://arxiv.org/pdf/1912.04958.pdf
|
9 |
+
|
10 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan2
|
11 |
+
"""
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
from .sync_op import all_gather
|
20 |
+
|
21 |
+
__all__ = ['StyleGAN2Generator']
|
22 |
+
|
23 |
+
# Resolutions allowed.
|
24 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
25 |
+
|
26 |
+
# Initial resolution.
|
27 |
+
_INIT_RES = 4
|
28 |
+
|
29 |
+
# Architectures allowed.
|
30 |
+
_ARCHITECTURES_ALLOWED = ['resnet', 'skip', 'origin']
|
31 |
+
|
32 |
+
# Default gain factor for weight scaling.
|
33 |
+
_WSCALE_GAIN = 1.0
|
34 |
+
|
35 |
+
|
36 |
+
class StyleGAN2Generator(nn.Module):
|
37 |
+
"""Defines the generator network in StyleGAN2.
|
38 |
+
|
39 |
+
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
40 |
+
[-1, 1].
|
41 |
+
|
42 |
+
Settings for the mapping network:
|
43 |
+
|
44 |
+
(1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
|
45 |
+
(2) w_space_dim: Dimension of the output latent space, W. (default: 512)
|
46 |
+
(3) label_size: Size of the additional label for conditional generation.
|
47 |
+
(default: 0)
|
48 |
+
(4) mapping_layers: Number of layers of the mapping network. (default: 8)
|
49 |
+
(5) mapping_fmaps: Number of hidden channels of the mapping network.
|
50 |
+
(default: 512)
|
51 |
+
(6) mapping_lr_mul: Learning rate multiplier for the mapping network.
|
52 |
+
(default: 0.01)
|
53 |
+
(7) repeat_w: Repeat w-code for different layers.
|
54 |
+
|
55 |
+
Settings for the synthesis network:
|
56 |
+
|
57 |
+
(1) resolution: The resolution of the output image.
|
58 |
+
(2) image_channels: Number of channels of the output image. (default: 3)
|
59 |
+
(3) final_tanh: Whether to use `tanh` to control the final pixel range.
|
60 |
+
(default: False)
|
61 |
+
(4) const_input: Whether to use a constant in the first convolutional layer.
|
62 |
+
(default: True)
|
63 |
+
(5) architecture: Type of architecture. Support `origin`, `skip`, and
|
64 |
+
`resnet`. (default: `resnet`)
|
65 |
+
(6) fused_modulate: Whether to fuse `style_modulate` and `conv2d` together.
|
66 |
+
(default: True)
|
67 |
+
(7) demodulate: Whether to perform style demodulation. (default: True)
|
68 |
+
(8) use_wscale: Whether to use weight scaling. (default: True)
|
69 |
+
(9) fmaps_base: Factor to control number of feature maps for each layer.
|
70 |
+
(default: 32 << 10)
|
71 |
+
(10) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self,
|
75 |
+
resolution,
|
76 |
+
z_space_dim=512,
|
77 |
+
w_space_dim=512,
|
78 |
+
label_size=0,
|
79 |
+
mapping_layers=8,
|
80 |
+
mapping_fmaps=512,
|
81 |
+
mapping_lr_mul=0.01,
|
82 |
+
repeat_w=True,
|
83 |
+
image_channels=3,
|
84 |
+
final_tanh=False,
|
85 |
+
const_input=True,
|
86 |
+
architecture='skip',
|
87 |
+
fused_modulate=True,
|
88 |
+
demodulate=True,
|
89 |
+
use_wscale=True,
|
90 |
+
fmaps_base=32 << 10,
|
91 |
+
fmaps_max=512):
|
92 |
+
"""Initializes with basic settings.
|
93 |
+
|
94 |
+
Raises:
|
95 |
+
ValueError: If the `resolution` is not supported, or `architecture`
|
96 |
+
is not supported.
|
97 |
+
"""
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
101 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
102 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
103 |
+
if architecture not in _ARCHITECTURES_ALLOWED:
|
104 |
+
raise ValueError(f'Invalid architecture: `{architecture}`!\n'
|
105 |
+
f'Architectures allowed: '
|
106 |
+
f'{_ARCHITECTURES_ALLOWED}.')
|
107 |
+
|
108 |
+
self.init_res = _INIT_RES
|
109 |
+
self.resolution = resolution
|
110 |
+
self.z_space_dim = z_space_dim
|
111 |
+
self.w_space_dim = w_space_dim
|
112 |
+
self.label_size = label_size
|
113 |
+
self.mapping_layers = mapping_layers
|
114 |
+
self.mapping_fmaps = mapping_fmaps
|
115 |
+
self.mapping_lr_mul = mapping_lr_mul
|
116 |
+
self.repeat_w = repeat_w
|
117 |
+
self.image_channels = image_channels
|
118 |
+
self.final_tanh = final_tanh
|
119 |
+
self.const_input = const_input
|
120 |
+
self.architecture = architecture
|
121 |
+
self.fused_modulate = fused_modulate
|
122 |
+
self.demodulate = demodulate
|
123 |
+
self.use_wscale = use_wscale
|
124 |
+
self.fmaps_base = fmaps_base
|
125 |
+
self.fmaps_max = fmaps_max
|
126 |
+
|
127 |
+
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
128 |
+
|
129 |
+
if self.repeat_w:
|
130 |
+
self.mapping_space_dim = self.w_space_dim
|
131 |
+
else:
|
132 |
+
self.mapping_space_dim = self.w_space_dim * self.num_layers
|
133 |
+
self.mapping = MappingModule(input_space_dim=self.z_space_dim,
|
134 |
+
hidden_space_dim=self.mapping_fmaps,
|
135 |
+
final_space_dim=self.mapping_space_dim,
|
136 |
+
label_size=self.label_size,
|
137 |
+
num_layers=self.mapping_layers,
|
138 |
+
use_wscale=self.use_wscale,
|
139 |
+
lr_mul=self.mapping_lr_mul)
|
140 |
+
|
141 |
+
self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
|
142 |
+
num_layers=self.num_layers,
|
143 |
+
repeat_w=self.repeat_w)
|
144 |
+
|
145 |
+
self.synthesis = SynthesisModule(resolution=self.resolution,
|
146 |
+
init_resolution=self.init_res,
|
147 |
+
w_space_dim=self.w_space_dim,
|
148 |
+
image_channels=self.image_channels,
|
149 |
+
final_tanh=self.final_tanh,
|
150 |
+
const_input=self.const_input,
|
151 |
+
architecture=self.architecture,
|
152 |
+
fused_modulate=self.fused_modulate,
|
153 |
+
demodulate=self.demodulate,
|
154 |
+
use_wscale=self.use_wscale,
|
155 |
+
fmaps_base=self.fmaps_base,
|
156 |
+
fmaps_max=self.fmaps_max)
|
157 |
+
|
158 |
+
self.pth_to_tf_var_mapping = {}
|
159 |
+
for key, val in self.mapping.pth_to_tf_var_mapping.items():
|
160 |
+
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
|
161 |
+
for key, val in self.truncation.pth_to_tf_var_mapping.items():
|
162 |
+
self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
|
163 |
+
for key, val in self.synthesis.pth_to_tf_var_mapping.items():
|
164 |
+
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
|
165 |
+
|
166 |
+
def forward(self,
|
167 |
+
z,
|
168 |
+
label=None,
|
169 |
+
w_moving_decay=0.995,
|
170 |
+
style_mixing_prob=0.9,
|
171 |
+
trunc_psi=None,
|
172 |
+
trunc_layers=None,
|
173 |
+
randomize_noise=False,
|
174 |
+
**_unused_kwargs):
|
175 |
+
mapping_results = self.mapping(z, label)
|
176 |
+
w = mapping_results['w']
|
177 |
+
|
178 |
+
if self.training and w_moving_decay < 1:
|
179 |
+
batch_w_avg = all_gather(w).mean(dim=0)
|
180 |
+
self.truncation.w_avg.copy_(
|
181 |
+
self.truncation.w_avg * w_moving_decay +
|
182 |
+
batch_w_avg * (1 - w_moving_decay))
|
183 |
+
|
184 |
+
if self.training and style_mixing_prob > 0:
|
185 |
+
new_z = torch.randn_like(z)
|
186 |
+
new_w = self.mapping(new_z, label)['w']
|
187 |
+
if np.random.uniform() < style_mixing_prob:
|
188 |
+
mixing_cutoff = np.random.randint(1, self.num_layers)
|
189 |
+
w = self.truncation(w)
|
190 |
+
new_w = self.truncation(new_w)
|
191 |
+
w[:, :mixing_cutoff] = new_w[:, :mixing_cutoff]
|
192 |
+
|
193 |
+
wp = self.truncation(w, trunc_psi, trunc_layers)
|
194 |
+
synthesis_results = self.synthesis(wp, randomize_noise)
|
195 |
+
|
196 |
+
return {**mapping_results, **synthesis_results}
|
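An illustrative end-to-end call of the generator defined above (a sketch only, not part of the committed file; the weights are randomly initialized rather than pretrained):

import torch

G = StyleGAN2Generator(resolution=256)
G.eval()                                   # skip the w_avg update and style-mixing branches
z = torch.randn(4, G.z_space_dim)
with torch.no_grad():
    outputs = G(z, trunc_psi=0.7, trunc_layers=8)
print(outputs['image'].shape)              # torch.Size([4, 3, 256, 256])
print(outputs['w'].shape)                  # torch.Size([4, 512])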
197 |
+
|
198 |
+
|
199 |
+
class MappingModule(nn.Module):
|
200 |
+
"""Implements the latent space mapping module.
|
201 |
+
|
202 |
+
Basically, this module executes several dense layers in sequence.
|
203 |
+
"""
|
204 |
+
|
205 |
+
def __init__(self,
|
206 |
+
input_space_dim=512,
|
207 |
+
hidden_space_dim=512,
|
208 |
+
final_space_dim=512,
|
209 |
+
label_size=0,
|
210 |
+
num_layers=8,
|
211 |
+
normalize_input=True,
|
212 |
+
use_wscale=True,
|
213 |
+
lr_mul=0.01):
|
214 |
+
super().__init__()
|
215 |
+
|
216 |
+
self.input_space_dim = input_space_dim
|
217 |
+
self.hidden_space_dim = hidden_space_dim
|
218 |
+
self.final_space_dim = final_space_dim
|
219 |
+
self.label_size = label_size
|
220 |
+
self.num_layers = num_layers
|
221 |
+
self.normalize_input = normalize_input
|
222 |
+
self.use_wscale = use_wscale
|
223 |
+
self.lr_mul = lr_mul
|
224 |
+
|
225 |
+
self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
|
226 |
+
|
227 |
+
self.pth_to_tf_var_mapping = {}
|
228 |
+
for i in range(num_layers):
|
229 |
+
dim_mul = 2 if label_size else 1
|
230 |
+
in_channels = (input_space_dim * dim_mul if i == 0 else
|
231 |
+
hidden_space_dim)
|
232 |
+
out_channels = (final_space_dim if i == (num_layers - 1) else
|
233 |
+
hidden_space_dim)
|
234 |
+
self.add_module(f'dense{i}',
|
235 |
+
DenseBlock(in_channels=in_channels,
|
236 |
+
out_channels=out_channels,
|
237 |
+
use_wscale=self.use_wscale,
|
238 |
+
lr_mul=self.lr_mul))
|
239 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
240 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
241 |
+
if label_size:
|
242 |
+
self.label_weight = nn.Parameter(
|
243 |
+
torch.randn(label_size, input_space_dim))
|
244 |
+
self.pth_to_tf_var_mapping[f'label_weight'] = f'LabelConcat/weight'
|
245 |
+
|
246 |
+
def forward(self, z, label=None):
|
247 |
+
if z.ndim != 2 or z.shape[1] != self.input_space_dim:
|
248 |
+
raise ValueError(f'Input latent code should be with shape '
|
249 |
+
f'[batch_size, input_dim], where '
|
250 |
+
f'`input_dim` equals to {self.input_space_dim}!\n'
|
251 |
+
f'But `{z.shape}` is received!')
|
252 |
+
if self.label_size:
|
253 |
+
if label is None:
|
254 |
+
raise ValueError(f'Model requires an additional label '
|
255 |
+
f'(with size {self.label_size}) as input, '
|
256 |
+
f'but no label is received!')
|
257 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
|
258 |
+
raise ValueError(f'Input label should be with shape '
|
259 |
+
f'[batch_size, label_size], where '
|
260 |
+
f'`batch_size` equals to that of '
|
261 |
+
f'latent codes ({z.shape[0]}) and '
|
262 |
+
f'`label_size` equals to {self.label_size}!\n'
|
263 |
+
f'But `{label.shape}` is received!')
|
264 |
+
embedding = torch.matmul(label, self.label_weight)
|
265 |
+
z = torch.cat((z, embedding), dim=1)
|
266 |
+
|
267 |
+
z = self.norm(z)
|
268 |
+
w = z
|
269 |
+
for i in range(self.num_layers):
|
270 |
+
w = self.__getattr__(f'dense{i}')(w)
|
271 |
+
results = {
|
272 |
+
'z': z,
|
273 |
+
'label': label,
|
274 |
+
'w': w,
|
275 |
+
}
|
276 |
+
if self.label_size:
|
277 |
+
results['embedding'] = embedding
|
278 |
+
return results
|
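A stand-alone shape check for the mapping module (illustrative, not part of the committed file); it is normally constructed and called by `StyleGAN2Generator` above.

import torch

mapping = MappingModule(input_space_dim=512, final_space_dim=512, num_layers=8)
z = torch.randn(4, 512)
out = mapping(z)
print(out['z'].shape, out['w'].shape)      # both torch.Size([4, 512]); 'z' is the pixel-normalized input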
279 |
+
|
280 |
+
|
281 |
+
class TruncationModule(nn.Module):
|
282 |
+
"""Implements the truncation module.
|
283 |
+
|
284 |
+
Truncation is executed as follows:
|
285 |
+
|
286 |
+
For layers in range [0, truncation_layers), the truncated w-code is computed
|
287 |
+
as
|
288 |
+
|
289 |
+
w_new = w_avg + (w - w_avg) * truncation_psi
|
290 |
+
|
291 |
+
To disable truncation, please set
|
292 |
+
(1) truncation_psi = 1.0 (None) OR
|
293 |
+
(2) truncation_layers = 0 (None)
|
294 |
+
|
295 |
+
NOTE: The returned tensor is layer-wise style codes.
|
296 |
+
"""
|
297 |
+
|
298 |
+
def __init__(self, w_space_dim, num_layers, repeat_w=True):
|
299 |
+
super().__init__()
|
300 |
+
|
301 |
+
self.num_layers = num_layers
|
302 |
+
self.w_space_dim = w_space_dim
|
303 |
+
self.repeat_w = repeat_w
|
304 |
+
|
305 |
+
if self.repeat_w:
|
306 |
+
self.register_buffer('w_avg', torch.zeros(w_space_dim))
|
307 |
+
else:
|
308 |
+
self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
|
309 |
+
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
|
310 |
+
|
311 |
+
def forward(self, w, trunc_psi=None, trunc_layers=None):
|
312 |
+
if w.ndim == 2:
|
313 |
+
if self.repeat_w and w.shape[1] == self.w_space_dim:
|
314 |
+
w = w.view(-1, 1, self.w_space_dim)
|
315 |
+
wp = w.repeat(1, self.num_layers, 1)
|
316 |
+
else:
|
317 |
+
assert w.shape[1] == self.w_space_dim * self.num_layers
|
318 |
+
wp = w.view(-1, self.num_layers, self.w_space_dim)
|
319 |
+
else:
|
320 |
+
wp = w
|
321 |
+
assert wp.ndim == 3
|
322 |
+
assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
|
323 |
+
|
324 |
+
trunc_psi = 1.0 if trunc_psi is None else trunc_psi
|
325 |
+
trunc_layers = 0 if trunc_layers is None else trunc_layers
|
326 |
+
if trunc_psi < 1.0 and trunc_layers > 0:
|
327 |
+
layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
|
328 |
+
coefs = np.ones_like(layer_idx, dtype=np.float32)
|
329 |
+
coefs[layer_idx < trunc_layers] *= trunc_psi
|
330 |
+
coefs = torch.from_numpy(coefs).to(wp)
|
331 |
+
w_avg = self.w_avg.view(1, -1, self.w_space_dim)
|
332 |
+
wp = w_avg + (wp - w_avg) * coefs
|
333 |
+
return wp
|
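A small numeric sketch of the truncation rule documented above, w_new = w_avg + (w - w_avg) * trunc_psi, applied to the first `trunc_layers` layers only (illustrative, not part of the committed file):

import torch

trunc = TruncationModule(w_space_dim=512, num_layers=14)   # w_avg starts as zeros
w = torch.randn(2, 512)
wp = trunc(w, trunc_psi=0.5, trunc_layers=8)
print(wp.shape)                       # torch.Size([2, 14, 512])
print((wp[:, 0] / w).mean().item())   # 0.5: early layers are pulled towards w_avg (zero here)
print((wp[:, 10] / w).mean().item())  # 1.0: layers >= trunc_layers are left untouched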
334 |
+
|
335 |
+
|
336 |
+
class SynthesisModule(nn.Module):
|
337 |
+
"""Implements the image synthesis module.
|
338 |
+
|
339 |
+
Basically, this module executes several convolutional layers in sequence.
|
340 |
+
"""
|
341 |
+
|
342 |
+
def __init__(self,
|
343 |
+
resolution=1024,
|
344 |
+
init_resolution=4,
|
345 |
+
w_space_dim=512,
|
346 |
+
image_channels=3,
|
347 |
+
final_tanh=False,
|
348 |
+
const_input=True,
|
349 |
+
architecture='skip',
|
350 |
+
fused_modulate=True,
|
351 |
+
demodulate=True,
|
352 |
+
use_wscale=True,
|
353 |
+
fmaps_base=32 << 10,
|
354 |
+
fmaps_max=512):
|
355 |
+
super().__init__()
|
356 |
+
|
357 |
+
self.init_res = init_resolution
|
358 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
359 |
+
self.resolution = resolution
|
360 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
361 |
+
self.w_space_dim = w_space_dim
|
362 |
+
self.image_channels = image_channels
|
363 |
+
self.final_tanh = final_tanh
|
364 |
+
self.const_input = const_input
|
365 |
+
self.architecture = architecture
|
366 |
+
self.fused_modulate = fused_modulate
|
367 |
+
self.demodulate = demodulate
|
368 |
+
self.use_wscale = use_wscale
|
369 |
+
self.fmaps_base = fmaps_base
|
370 |
+
self.fmaps_max = fmaps_max
|
371 |
+
|
372 |
+
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
373 |
+
|
374 |
+
self.pth_to_tf_var_mapping = {}
|
375 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
376 |
+
res = 2 ** res_log2
|
377 |
+
block_idx = res_log2 - self.init_res_log2
|
378 |
+
|
379 |
+
# First convolution layer for each resolution.
|
380 |
+
if res == self.init_res:
|
381 |
+
if self.const_input:
|
382 |
+
self.add_module(f'early_layer',
|
383 |
+
InputBlock(init_resolution=self.init_res,
|
384 |
+
channels=self.get_nf(res)))
|
385 |
+
self.pth_to_tf_var_mapping[f'early_layer.const'] = (
|
386 |
+
f'{res}x{res}/Const/const')
|
387 |
+
else:
|
388 |
+
self.add_module(f'early_layer',
|
389 |
+
DenseBlock(in_channels=self.w_space_dim,
|
390 |
+
out_channels=self.get_nf(res),
|
391 |
+
use_wscale=self.use_wscale))
|
392 |
+
self.pth_to_tf_var_mapping[f'early_layer.weight'] = (
|
393 |
+
f'{res}x{res}/Dense/weight')
|
394 |
+
self.pth_to_tf_var_mapping[f'early_layer.bias'] = (
|
395 |
+
f'{res}x{res}/Dense/bias')
|
396 |
+
else:
|
397 |
+
layer_name = f'layer{2 * block_idx - 1}'
|
398 |
+
self.add_module(
|
399 |
+
layer_name,
|
400 |
+
ModulateConvBlock(in_channels=self.get_nf(res // 2),
|
401 |
+
out_channels=self.get_nf(res),
|
402 |
+
resolution=res,
|
403 |
+
w_space_dim=self.w_space_dim,
|
404 |
+
scale_factor=2,
|
405 |
+
fused_modulate=self.fused_modulate,
|
406 |
+
demodulate=self.demodulate,
|
407 |
+
use_wscale=self.use_wscale))
|
408 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
409 |
+
f'{res}x{res}/Conv0_up/weight')
|
410 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
411 |
+
f'{res}x{res}/Conv0_up/bias')
|
412 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
413 |
+
f'{res}x{res}/Conv0_up/mod_weight')
|
414 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
415 |
+
f'{res}x{res}/Conv0_up/mod_bias')
|
416 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
417 |
+
f'{res}x{res}/Conv0_up/noise_strength')
|
418 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
419 |
+
f'noise{2 * block_idx - 1}')
|
420 |
+
|
421 |
+
if self.architecture == 'resnet':
|
422 |
+
layer_name = f'layer{2 * block_idx - 1}'
|
423 |
+
self.add_module(
|
424 |
+
layer_name,
|
425 |
+
ConvBlock(in_channels=self.get_nf(res // 2),
|
426 |
+
out_channels=self.get_nf(res),
|
427 |
+
kernel_size=1,
|
428 |
+
add_bias=False,
|
429 |
+
scale_factor=2,
|
430 |
+
use_wscale=self.use_wscale,
|
431 |
+
activation_type='linear'))
|
432 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
433 |
+
f'{res}x{res}/Skip/weight')
|
434 |
+
|
435 |
+
# Second convolution layer for each resolution.
|
436 |
+
layer_name = f'layer{2 * block_idx}'
|
437 |
+
self.add_module(
|
438 |
+
layer_name,
|
439 |
+
ModulateConvBlock(in_channels=self.get_nf(res),
|
440 |
+
out_channels=self.get_nf(res),
|
441 |
+
resolution=res,
|
442 |
+
w_space_dim=self.w_space_dim,
|
443 |
+
fused_modulate=self.fused_modulate,
|
444 |
+
demodulate=self.demodulate,
|
445 |
+
use_wscale=self.use_wscale))
|
446 |
+
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
|
447 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
448 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
449 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
450 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
451 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
452 |
+
f'{res}x{res}/{tf_layer_name}/mod_weight')
|
453 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
454 |
+
f'{res}x{res}/{tf_layer_name}/mod_bias')
|
455 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise_strength'] = (
|
456 |
+
f'{res}x{res}/{tf_layer_name}/noise_strength')
|
457 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.noise'] = (
|
458 |
+
f'noise{2 * block_idx}')
|
459 |
+
|
460 |
+
# Output convolution layer for each resolution (if needed).
|
461 |
+
if res_log2 == self.final_res_log2 or self.architecture == 'skip':
|
462 |
+
layer_name = f'output{block_idx}'
|
463 |
+
self.add_module(
|
464 |
+
layer_name,
|
465 |
+
ModulateConvBlock(in_channels=self.get_nf(res),
|
466 |
+
out_channels=image_channels,
|
467 |
+
resolution=res,
|
468 |
+
w_space_dim=self.w_space_dim,
|
469 |
+
kernel_size=1,
|
470 |
+
fused_modulate=self.fused_modulate,
|
471 |
+
demodulate=False,
|
472 |
+
use_wscale=self.use_wscale,
|
473 |
+
add_noise=False,
|
474 |
+
activation_type='linear'))
|
475 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
476 |
+
f'{res}x{res}/ToRGB/weight')
|
477 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
478 |
+
f'{res}x{res}/ToRGB/bias')
|
479 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
480 |
+
f'{res}x{res}/ToRGB/mod_weight')
|
481 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
482 |
+
f'{res}x{res}/ToRGB/mod_bias')
|
483 |
+
|
484 |
+
if self.architecture == 'skip':
|
485 |
+
self.upsample = UpsamplingLayer()
|
486 |
+
self.final_activate = nn.Tanh() if final_tanh else nn.Identity()
|
487 |
+
|
488 |
+
def get_nf(self, res):
|
489 |
+
"""Gets number of feature maps according to current resolution."""
|
490 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
491 |
+
|
492 |
+
def forward(self, wp, randomize_noise=False):
|
493 |
+
if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
|
494 |
+
raise ValueError(f'Input tensor should be with shape '
|
495 |
+
f'[batch_size, num_layers, w_space_dim], where '
|
496 |
+
f'`num_layers` equals to {self.num_layers}, and '
|
497 |
+
f'`w_space_dim` equals to {self.w_space_dim}!\n'
|
498 |
+
f'But `{wp.shape}` is received!')
|
499 |
+
|
500 |
+
results = {'wp': wp}
|
501 |
+
x = self.early_layer(wp[:, 0])
|
502 |
+
if self.architecture == 'origin':
|
503 |
+
for layer_idx in range(self.num_layers - 1):
|
504 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
505 |
+
x, wp[:, layer_idx], randomize_noise)
|
506 |
+
results[f'style{layer_idx:02d}'] = style
|
507 |
+
image, style = self.__getattr__(f'output{layer_idx // 2}')(
|
508 |
+
x, wp[:, layer_idx + 1])
|
509 |
+
results[f'output_style{layer_idx // 2}'] = style
|
510 |
+
elif self.architecture == 'skip':
|
511 |
+
for layer_idx in range(self.num_layers - 1):
|
512 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
513 |
+
x, wp[:, layer_idx], randomize_noise)
|
514 |
+
results[f'style{layer_idx:02d}'] = style
|
515 |
+
if layer_idx % 2 == 0:
|
516 |
+
temp, style = self.__getattr__(f'output{layer_idx // 2}')(
|
517 |
+
x, wp[:, layer_idx + 1])
|
518 |
+
results[f'output_style{layer_idx // 2}'] = style
|
519 |
+
if layer_idx == 0:
|
520 |
+
image = temp
|
521 |
+
else:
|
522 |
+
image = temp + self.upsample(image)
|
523 |
+
elif self.architecture == 'resnet':
|
524 |
+
x, style = self.layer0(x, wp[:, 0], randomize_noise)
|
525 |
+
results[f'style00'] = style
|
526 |
+
for layer_idx in range(1, self.num_layers - 1, 2):
|
527 |
+
residual = self.__getattr__(f'skip_layer{layer_idx // 2}')(x)
|
528 |
+
x, style = self.__getattr__(f'layer{layer_idx}')(
|
529 |
+
x, wp[:, layer_idx], randomize_noise)
|
530 |
+
results[f'style{layer_idx:02d}'] = style
|
531 |
+
x, style = self.__getattr__(f'layer{layer_idx + 1}')(
|
532 |
+
x, wp[:, layer_idx + 1], randomize_noise)
|
533 |
+
results[f'style{layer_idx + 1:02d}'] = style
|
534 |
+
x = (x + residual) / np.sqrt(2.0)
|
535 |
+
image, style = self.__getattr__(f'output{layer_idx // 2 + 1}')(
|
536 |
+
x, wp[:, layer_idx + 2])
|
537 |
+
results[f'output_style{layer_idx // 2}'] = style
|
538 |
+
results['image'] = self.final_activate(image)
|
539 |
+
return results
|
540 |
+
|
541 |
+
|
542 |
+
class PixelNormLayer(nn.Module):
|
543 |
+
"""Implements pixel-wise feature vector normalization layer."""
|
544 |
+
|
545 |
+
def __init__(self, dim=1, epsilon=1e-8):
|
546 |
+
super().__init__()
|
547 |
+
self.dim = dim
|
548 |
+
self.eps = epsilon
|
549 |
+
|
550 |
+
def forward(self, x):
|
551 |
+
norm = torch.sqrt(
|
552 |
+
torch.mean(x ** 2, dim=self.dim, keepdim=True) + self.eps)
|
553 |
+
return x / norm
|
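A quick check of the effect of this normalization (illustrative, not part of the committed file): after the layer, every sample has unit root-mean-square along the normalized dimension.

import torch

norm = PixelNormLayer(dim=1)
x = torch.randn(4, 512) * 3.0
y = norm(x)
print((y ** 2).mean(dim=1))   # close to 1.0 for every row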
554 |
+
|
555 |
+
|
556 |
+
class UpsamplingLayer(nn.Module):
|
557 |
+
"""Implements the upsampling layer.
|
558 |
+
|
559 |
+
This layer can also be used as filtering by setting `scale_factor` as 1.
|
560 |
+
"""
|
561 |
+
|
562 |
+
def __init__(self,
|
563 |
+
scale_factor=2,
|
564 |
+
kernel=(1, 3, 3, 1),
|
565 |
+
extra_padding=0,
|
566 |
+
kernel_gain=None):
|
567 |
+
super().__init__()
|
568 |
+
assert scale_factor >= 1
|
569 |
+
self.scale_factor = scale_factor
|
570 |
+
|
571 |
+
if extra_padding != 0:
|
572 |
+
assert scale_factor == 1
|
573 |
+
|
574 |
+
if kernel is None:
|
575 |
+
kernel = np.ones((scale_factor), dtype=np.float32)
|
576 |
+
else:
|
577 |
+
kernel = np.array(kernel, dtype=np.float32)
|
578 |
+
assert kernel.ndim == 1
|
579 |
+
kernel = np.outer(kernel, kernel)
|
580 |
+
kernel = kernel / np.sum(kernel)
|
581 |
+
if kernel_gain is None:
|
582 |
+
kernel = kernel * (scale_factor ** 2)
|
583 |
+
else:
|
584 |
+
assert kernel_gain > 0
|
585 |
+
kernel = kernel * (kernel_gain ** 2)
|
586 |
+
assert kernel.ndim == 2
|
587 |
+
assert kernel.shape[0] == kernel.shape[1]
|
588 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
589 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
590 |
+
self.kernel = self.kernel.flip(0, 1)
|
591 |
+
|
592 |
+
self.upsample_padding = (0, scale_factor - 1, # Width padding.
|
593 |
+
0, 0, # Width.
|
594 |
+
0, scale_factor - 1, # Height padding.
|
595 |
+
0, 0, # Height.
|
596 |
+
0, 0, # Channel.
|
597 |
+
0, 0) # Batch size.
|
598 |
+
|
599 |
+
padding = kernel.shape[2] - scale_factor + extra_padding
|
600 |
+
self.padding = ((padding + 1) // 2 + scale_factor - 1, padding // 2,
|
601 |
+
(padding + 1) // 2 + scale_factor - 1, padding // 2)
|
602 |
+
|
603 |
+
def forward(self, x):
|
604 |
+
assert x.ndim == 4
|
605 |
+
channels = x.shape[1]
|
606 |
+
if self.scale_factor > 1:
|
607 |
+
x = x.view(-1, channels, x.shape[2], 1, x.shape[3], 1)
|
608 |
+
x = F.pad(x, self.upsample_padding, mode='constant', value=0)
|
609 |
+
x = x.view(-1, channels, x.shape[2] * self.scale_factor,
|
610 |
+
x.shape[4] * self.scale_factor)
|
611 |
+
x = x.view(-1, 1, x.shape[2], x.shape[3])
|
612 |
+
x = F.pad(x, self.padding, mode='constant', value=0)
|
613 |
+
x = F.conv2d(x, self.kernel, stride=1)
|
614 |
+
x = x.view(-1, channels, x.shape[2], x.shape[3])
|
615 |
+
return x
|
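And the matching shape check for `UpsamplingLayer` (illustrative, not part of the committed file):

import torch

up = UpsamplingLayer(scale_factor=2)   # default (1, 3, 3, 1) filter, gain of scale_factor ** 2
x = torch.randn(2, 8, 32, 32)
y = up(x)
print(y.shape)                         # torch.Size([2, 8, 64, 64])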
616 |
+
|
617 |
+
|
618 |
+
class InputBlock(nn.Module):
|
619 |
+
"""Implements the input block.
|
620 |
+
|
621 |
+
Basically, this block starts from a const input, which is with shape
|
622 |
+
`(channels, init_resolution, init_resolution)`.
|
623 |
+
"""
|
624 |
+
|
625 |
+
def __init__(self, init_resolution, channels):
|
626 |
+
super().__init__()
|
627 |
+
self.const = nn.Parameter(
|
628 |
+
torch.randn(1, channels, init_resolution, init_resolution))
|
629 |
+
|
630 |
+
def forward(self, w):
|
631 |
+
x = self.const.repeat(w.shape[0], 1, 1, 1)
|
632 |
+
return x
|
633 |
+
|
634 |
+
|
635 |
+
class ConvBlock(nn.Module):
|
636 |
+
"""Implements the convolutional block (no style modulation).
|
637 |
+
|
638 |
+
Basically, this block executes convolutional layer, filtering layer (if
|
639 |
+
needed), and activation layer in sequence.
|
640 |
+
|
641 |
+
NOTE: This block is particularly used for skip-connection branch in the
|
642 |
+
`resnet` structure.
|
643 |
+
"""
|
644 |
+
|
645 |
+
def __init__(self,
|
646 |
+
in_channels,
|
647 |
+
out_channels,
|
648 |
+
kernel_size=3,
|
649 |
+
add_bias=True,
|
650 |
+
scale_factor=1,
|
651 |
+
filtering_kernel=(1, 3, 3, 1),
|
652 |
+
use_wscale=True,
|
653 |
+
wscale_gain=_WSCALE_GAIN,
|
654 |
+
lr_mul=1.0,
|
655 |
+
activation_type='lrelu'):
|
656 |
+
"""Initializes with block settings.
|
657 |
+
|
658 |
+
Args:
|
659 |
+
in_channels: Number of channels of the input tensor.
|
660 |
+
out_channels: Number of channels of the output tensor.
|
661 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
662 |
+
add_bias: Whether to add bias onto the convolutional result.
|
663 |
+
(default: True)
|
664 |
+
scale_factor: Scale factor for upsampling. `1` means skip
|
665 |
+
upsampling. (default: 1)
|
666 |
+
filtering_kernel: Kernel used for filtering after upsampling.
|
667 |
+
(default: (1, 3, 3, 1))
|
668 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
669 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
670 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
671 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
672 |
+
(default: `lrelu`)
|
673 |
+
|
674 |
+
Raises:
|
675 |
+
NotImplementedError: If the `activation_type` is not supported.
|
676 |
+
"""
|
677 |
+
super().__init__()
|
678 |
+
|
679 |
+
if scale_factor > 1:
|
680 |
+
self.use_conv2d_transpose = True
|
681 |
+
extra_padding = scale_factor - kernel_size
|
682 |
+
self.filter = UpsamplingLayer(scale_factor=1,
|
683 |
+
kernel=filtering_kernel,
|
684 |
+
extra_padding=extra_padding,
|
685 |
+
kernel_gain=scale_factor)
|
686 |
+
self.stride = scale_factor
|
687 |
+
self.padding = 0 # Padding is done in `UpsamplingLayer`.
|
688 |
+
else:
|
689 |
+
self.use_conv2d_transpose = False
|
690 |
+
assert kernel_size % 2 == 1
|
691 |
+
self.stride = 1
|
692 |
+
self.padding = kernel_size // 2
|
693 |
+
|
694 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
695 |
+
fan_in = kernel_size * kernel_size * in_channels
|
696 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
697 |
+
if use_wscale:
|
698 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
699 |
+
self.wscale = wscale * lr_mul
|
700 |
+
else:
|
701 |
+
self.weight = nn.Parameter(
|
702 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
703 |
+
self.wscale = lr_mul
|
704 |
+
|
705 |
+
if add_bias:
|
706 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
707 |
+
else:
|
708 |
+
self.bias = None
|
709 |
+
self.bscale = lr_mul
|
710 |
+
|
711 |
+
if activation_type == 'linear':
|
712 |
+
self.activate = nn.Identity()
|
713 |
+
self.activate_scale = 1.0
|
714 |
+
elif activation_type == 'lrelu':
|
715 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
716 |
+
self.activate_scale = np.sqrt(2.0)
|
717 |
+
else:
|
718 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
719 |
+
f'`{activation_type}`!')
|
720 |
+
|
721 |
+
def forward(self, x):
|
722 |
+
weight = self.weight * self.wscale
|
723 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
724 |
+
if self.use_conv2d_transpose:
|
725 |
+
weight = weight.permute(1, 0, 2, 3).flip(2, 3)
|
726 |
+
x = F.conv_transpose2d(x,
|
727 |
+
weight=weight,
|
728 |
+
bias=bias,
|
729 |
+
stride=self.stride,
|
730 |
+
padding=self.padding)
|
731 |
+
x = self.filter(x)
|
732 |
+
else:
|
733 |
+
x = F.conv2d(x,
|
734 |
+
weight=weight,
|
735 |
+
bias=bias,
|
736 |
+
stride=self.stride,
|
737 |
+
padding=self.padding)
|
738 |
+
x = self.activate(x) * self.activate_scale
|
739 |
+
return x
|
740 |
+
|
741 |
+
|
742 |
+
class ModulateConvBlock(nn.Module):
|
743 |
+
"""Implements the convolutional block with style modulation."""
|
744 |
+
|
745 |
+
def __init__(self,
|
746 |
+
in_channels,
|
747 |
+
out_channels,
|
748 |
+
resolution,
|
749 |
+
w_space_dim,
|
750 |
+
kernel_size=3,
|
751 |
+
add_bias=True,
|
752 |
+
scale_factor=1,
|
753 |
+
filtering_kernel=(1, 3, 3, 1),
|
754 |
+
fused_modulate=True,
|
755 |
+
demodulate=True,
|
756 |
+
use_wscale=True,
|
757 |
+
wscale_gain=_WSCALE_GAIN,
|
758 |
+
lr_mul=1.0,
|
759 |
+
add_noise=True,
|
760 |
+
activation_type='lrelu',
|
761 |
+
epsilon=1e-8):
|
762 |
+
"""Initializes with block settings.
|
763 |
+
|
764 |
+
Args:
|
765 |
+
in_channels: Number of channels of the input tensor.
|
766 |
+
out_channels: Number of channels of the output tensor.
|
767 |
+
resolution: Resolution of the output tensor.
|
768 |
+
w_space_dim: Dimension of W space for style modulation.
|
769 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
770 |
+
add_bias: Whether to add bias onto the convolutional result.
|
771 |
+
(default: True)
|
772 |
+
scale_factor: Scale factor for upsampling. `1` means skip
|
773 |
+
upsampling. (default: 1)
|
774 |
+
filtering_kernel: Kernel used for filtering after upsampling.
|
775 |
+
(default: (1, 3, 3, 1))
|
776 |
+
fused_modulate: Whether to fuse `style_modulate` and `conv2d`
|
777 |
+
together. (default: True)
|
778 |
+
demodulate: Whether to perform style demodulation. (default: True)
|
779 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
780 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
781 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
782 |
+
add_noise: Whether to add noise onto the output tensor. (default:
|
783 |
+
True)
|
784 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
785 |
+
(default: `lrelu`)
|
786 |
+
epsilon: Small number to avoid `divide by zero`. (default: 1e-8)
|
787 |
+
|
788 |
+
Raises:
|
789 |
+
NotImplementedError: If the `activation_type` is not supported.
|
790 |
+
"""
|
791 |
+
super().__init__()
|
792 |
+
|
793 |
+
self.res = resolution
|
794 |
+
self.in_c = in_channels
|
795 |
+
self.out_c = out_channels
|
796 |
+
self.ksize = kernel_size
|
797 |
+
self.eps = epsilon
|
798 |
+
|
799 |
+
if scale_factor > 1:
|
800 |
+
self.use_conv2d_transpose = True
|
801 |
+
extra_padding = scale_factor - kernel_size
|
802 |
+
self.filter = UpsamplingLayer(scale_factor=1,
|
803 |
+
kernel=filtering_kernel,
|
804 |
+
extra_padding=extra_padding,
|
805 |
+
kernel_gain=scale_factor)
|
806 |
+
self.stride = scale_factor
|
807 |
+
self.padding = 0 # Padding is done in `UpsamplingLayer`.
|
808 |
+
else:
|
809 |
+
self.use_conv2d_transpose = False
|
810 |
+
assert kernel_size % 2 == 1
|
811 |
+
self.stride = 1
|
812 |
+
self.padding = kernel_size // 2
|
813 |
+
|
814 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
815 |
+
fan_in = kernel_size * kernel_size * in_channels
|
816 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
817 |
+
if use_wscale:
|
818 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
819 |
+
self.wscale = wscale * lr_mul
|
820 |
+
else:
|
821 |
+
self.weight = nn.Parameter(
|
822 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
823 |
+
self.wscale = lr_mul
|
824 |
+
|
825 |
+
self.style = DenseBlock(in_channels=w_space_dim,
|
826 |
+
out_channels=in_channels,
|
827 |
+
additional_bias=1.0,
|
828 |
+
use_wscale=use_wscale,
|
829 |
+
activation_type='linear')
|
830 |
+
|
831 |
+
self.fused_modulate = fused_modulate
|
832 |
+
self.demodulate = demodulate
|
833 |
+
|
834 |
+
if add_bias:
|
835 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
836 |
+
else:
|
837 |
+
self.bias = None
|
838 |
+
self.bscale = lr_mul
|
839 |
+
|
840 |
+
if activation_type == 'linear':
|
841 |
+
self.activate = nn.Identity()
|
842 |
+
self.activate_scale = 1.0
|
843 |
+
elif activation_type == 'lrelu':
|
844 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
845 |
+
self.activate_scale = np.sqrt(2.0)
|
846 |
+
else:
|
847 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
848 |
+
f'`{activation_type}`!')
|
849 |
+
|
850 |
+
self.add_noise = add_noise
|
851 |
+
if self.add_noise:
|
852 |
+
self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
|
853 |
+
self.noise_strength = nn.Parameter(torch.zeros(()))
|
854 |
+
|
855 |
+
def forward(self, x, w, randomize_noise=False):
|
856 |
+
batch = x.shape[0]
|
857 |
+
|
858 |
+
weight = self.weight * self.wscale
|
859 |
+
weight = weight.permute(2, 3, 1, 0)
|
860 |
+
|
861 |
+
# Style modulation.
|
862 |
+
style = self.style(w)
|
863 |
+
_weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
|
864 |
+
_weight = _weight * style.view(batch, 1, 1, self.in_c, 1)
|
865 |
+
|
866 |
+
# Style demodulation.
|
867 |
+
if self.demodulate:
|
868 |
+
_weight_norm = torch.sqrt(
|
869 |
+
torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
|
870 |
+
_weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
|
871 |
+
|
872 |
+
if self.fused_modulate:
|
873 |
+
x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
|
874 |
+
weight = _weight.permute(1, 2, 3, 0, 4).reshape(
|
875 |
+
self.ksize, self.ksize, self.in_c, batch * self.out_c)
|
876 |
+
else:
|
877 |
+
x = x * style.view(batch, self.in_c, 1, 1)
|
878 |
+
|
879 |
+
if self.use_conv2d_transpose:
|
880 |
+
weight = weight.flip(0, 1)
|
881 |
+
if self.fused_modulate:
|
882 |
+
weight = weight.view(
|
883 |
+
self.ksize, self.ksize, self.in_c, batch, self.out_c)
|
884 |
+
weight = weight.permute(0, 1, 4, 3, 2)
|
885 |
+
weight = weight.reshape(
|
886 |
+
self.ksize, self.ksize, self.out_c, batch * self.in_c)
|
887 |
+
weight = weight.permute(3, 2, 0, 1)
|
888 |
+
else:
|
889 |
+
weight = weight.permute(2, 3, 0, 1)
|
890 |
+
x = F.conv_transpose2d(x,
|
891 |
+
weight=weight,
|
892 |
+
bias=None,
|
893 |
+
stride=self.stride,
|
894 |
+
padding=self.padding,
|
895 |
+
groups=(batch if self.fused_modulate else 1))
|
896 |
+
x = self.filter(x)
|
897 |
+
else:
|
898 |
+
weight = weight.permute(3, 2, 0, 1)
|
899 |
+
x = F.conv2d(x,
|
900 |
+
weight=weight,
|
901 |
+
bias=None,
|
902 |
+
stride=self.stride,
|
903 |
+
padding=self.padding,
|
904 |
+
groups=(batch if self.fused_modulate else 1))
|
905 |
+
|
906 |
+
if self.fused_modulate:
|
907 |
+
x = x.view(batch, self.out_c, self.res, self.res)
|
908 |
+
elif self.demodulate:
|
909 |
+
x = x / _weight_norm.view(batch, self.out_c, 1, 1)
|
910 |
+
|
911 |
+
if self.add_noise:
|
912 |
+
if randomize_noise:
|
913 |
+
noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
|
914 |
+
else:
|
915 |
+
noise = self.noise
|
916 |
+
x = x + noise * self.noise_strength.view(1, 1, 1, 1)
|
917 |
+
|
918 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
919 |
+
if bias is not None:
|
920 |
+
x = x + bias.view(1, -1, 1, 1)
|
921 |
+
x = self.activate(x) * self.activate_scale
|
922 |
+
return x, style
|
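The fused modulate/demodulate path above is compact but hard to read; below is an unfused sketch of the same idea applied directly to a weight tensor (illustrative, written independently of the class above):

import torch

batch, in_c, out_c, k = 2, 8, 16, 3
weight = torch.randn(out_c, in_c, k, k)
style = torch.randn(batch, in_c) + 1.0                 # per-sample scales for the input channels

# Modulation: scale the kernel's input channels by the style code.
w_mod = weight[None] * style[:, None, :, None, None]   # [batch, out_c, in_c, k, k]

# Demodulation: renormalize each output filter to unit L2 norm.
demod = torch.rsqrt(w_mod.pow(2).sum(dim=[2, 3, 4]) + 1e-8)
w_demod = w_mod * demod[:, :, None, None, None]

print(w_demod.pow(2).sum(dim=[2, 3, 4]).sqrt())        # close to 1.0 for every (sample, filter) pair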
923 |
+
|
924 |
+
|
925 |
+
class DenseBlock(nn.Module):
|
926 |
+
"""Implements the dense block.
|
927 |
+
|
928 |
+
Basically, this block executes fully-connected layer and activation layer.
|
929 |
+
|
930 |
+
NOTE: This layer supports adding an additional bias beyond the trainable
|
931 |
+
bias parameter. This is specially used for the mapping from the w code to
|
932 |
+
the style code.
|
933 |
+
"""
|
934 |
+
|
935 |
+
def __init__(self,
|
936 |
+
in_channels,
|
937 |
+
out_channels,
|
938 |
+
add_bias=True,
|
939 |
+
additional_bias=0,
|
940 |
+
use_wscale=True,
|
941 |
+
wscale_gain=_WSCALE_GAIN,
|
942 |
+
lr_mul=1.0,
|
943 |
+
activation_type='lrelu'):
|
944 |
+
"""Initializes with block settings.
|
945 |
+
|
946 |
+
Args:
|
947 |
+
in_channels: Number of channels of the input tensor.
|
948 |
+
out_channels: Number of channels of the output tensor.
|
949 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
950 |
+
(default: True)
|
951 |
+
additional_bias: The additional bias, which is independent from the
|
952 |
+
bias parameter. (default: 0.0)
|
953 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
954 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
955 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
956 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
957 |
+
(default: `lrelu`)
|
958 |
+
|
959 |
+
Raises:
|
960 |
+
NotImplementedError: If the `activation_type` is not supported.
|
961 |
+
"""
|
962 |
+
super().__init__()
|
963 |
+
weight_shape = (out_channels, in_channels)
|
964 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
965 |
+
if use_wscale:
|
966 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
967 |
+
self.wscale = wscale * lr_mul
|
968 |
+
else:
|
969 |
+
self.weight = nn.Parameter(
|
970 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
971 |
+
self.wscale = lr_mul
|
972 |
+
|
973 |
+
if add_bias:
|
974 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
975 |
+
else:
|
976 |
+
self.bias = None
|
977 |
+
self.bscale = lr_mul
|
978 |
+
self.additional_bias = additional_bias
|
979 |
+
|
980 |
+
if activation_type == 'linear':
|
981 |
+
self.activate = nn.Identity()
|
982 |
+
self.activate_scale = 1.0
|
983 |
+
elif activation_type == 'lrelu':
|
984 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
985 |
+
self.activate_scale = np.sqrt(2.0)
|
986 |
+
else:
|
987 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
988 |
+
f'`{activation_type}`!')
|
989 |
+
|
990 |
+
def forward(self, x):
|
991 |
+
if x.ndim != 2:
|
992 |
+
x = x.view(x.shape[0], -1)
|
993 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
994 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
995 |
+
x = self.activate(x + self.additional_bias) * self.activate_scale
|
996 |
+
return x
|
models/stylegan_discriminator.py
ADDED
@@ -0,0 +1,530 @@
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of discriminator described in StyleGAN.
|
3 |
+
|
4 |
+
Paper: https://arxiv.org/pdf/1812.04948.pdf
|
5 |
+
|
6 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
__all__ = ['StyleGANDiscriminator']
|
16 |
+
|
17 |
+
# Resolutions allowed.
|
18 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
19 |
+
|
20 |
+
# Initial resolution.
|
21 |
+
_INIT_RES = 4
|
22 |
+
|
23 |
+
# Fused-scale options allowed.
|
24 |
+
_FUSED_SCALE_ALLOWED = [True, False, 'auto']
|
25 |
+
|
26 |
+
# Minimal resolution for `auto` fused-scale strategy.
|
27 |
+
_AUTO_FUSED_SCALE_MIN_RES = 128
|
28 |
+
|
29 |
+
# Default gain factor for weight scaling.
|
30 |
+
_WSCALE_GAIN = np.sqrt(2.0)
|
31 |
+
|
32 |
+
|
33 |
+
class StyleGANDiscriminator(nn.Module):
|
34 |
+
"""Defines the discriminator network in StyleGAN.
|
35 |
+
|
36 |
+
NOTE: The discriminator takes images with `RGB` channel order and pixel
|
37 |
+
range [-1, 1] as inputs.
|
38 |
+
|
39 |
+
Settings for the network:
|
40 |
+
|
41 |
+
(1) resolution: The resolution of the input image.
|
42 |
+
(2) image_channels: Number of channels of the input image. (default: 3)
|
43 |
+
(3) label_size: Size of the additional label for conditional generation.
|
44 |
+
(default: 0)
|
45 |
+
(4) fused_scale: Whether to fuse `conv2d` and `downsample` together,
|
46 |
+
resulting in `conv2d` with strides. (default: `auto`)
|
47 |
+
(5) use_wscale: Whether to use weight scaling. (default: True)
|
48 |
+
(6) minibatch_std_group_size: Group size for the minibatch standard
|
49 |
+
deviation layer. 0 means disable. (default: 4)
|
50 |
+
(7) minibatch_std_channels: Number of new channels after the minibatch
|
51 |
+
standard deviation layer. (default: 1)
|
52 |
+
(8) fmaps_base: Factor to control number of feature maps for each layer.
|
53 |
+
(default: 16 << 10)
|
54 |
+
(9) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self,
|
58 |
+
resolution,
|
59 |
+
image_channels=3,
|
60 |
+
label_size=0,
|
61 |
+
fused_scale='auto',
|
62 |
+
use_wscale=True,
|
63 |
+
minibatch_std_group_size=4,
|
64 |
+
minibatch_std_channels=1,
|
65 |
+
fmaps_base=16 << 10,
|
66 |
+
fmaps_max=512):
|
67 |
+
"""Initializes with basic settings.
|
68 |
+
|
69 |
+
Raises:
|
70 |
+
ValueError: If the `resolution` is not supported, or `fused_scale`
|
71 |
+
is not supported.
|
72 |
+
"""
|
73 |
+
super().__init__()
|
74 |
+
|
75 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
76 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
77 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
78 |
+
if fused_scale not in _FUSED_SCALE_ALLOWED:
|
79 |
+
raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
|
80 |
+
f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
|
81 |
+
|
82 |
+
self.init_res = _INIT_RES
|
83 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
84 |
+
self.resolution = resolution
|
85 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
86 |
+
self.image_channels = image_channels
|
87 |
+
self.label_size = label_size
|
88 |
+
self.fused_scale = fused_scale
|
89 |
+
self.use_wscale = use_wscale
|
90 |
+
self.minibatch_std_group_size = minibatch_std_group_size
|
91 |
+
self.minibatch_std_channels = minibatch_std_channels
|
92 |
+
self.fmaps_base = fmaps_base
|
93 |
+
self.fmaps_max = fmaps_max
|
94 |
+
|
95 |
+
# Level of detail (used for progressive training).
|
96 |
+
self.register_buffer('lod', torch.zeros(()))
|
97 |
+
self.pth_to_tf_var_mapping = {'lod': 'lod'}
|
98 |
+
|
99 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
100 |
+
res = 2 ** res_log2
|
101 |
+
block_idx = self.final_res_log2 - res_log2
|
102 |
+
|
103 |
+
# Input convolution layer for each resolution.
|
104 |
+
self.add_module(
|
105 |
+
f'input{block_idx}',
|
106 |
+
ConvBlock(in_channels=self.image_channels,
|
107 |
+
out_channels=self.get_nf(res),
|
108 |
+
kernel_size=1,
|
109 |
+
padding=0,
|
110 |
+
use_wscale=self.use_wscale))
|
111 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.weight'] = (
|
112 |
+
f'FromRGB_lod{block_idx}/weight')
|
113 |
+
self.pth_to_tf_var_mapping[f'input{block_idx}.bias'] = (
|
114 |
+
f'FromRGB_lod{block_idx}/bias')
|
115 |
+
|
116 |
+
# Convolution block for each resolution (except the last one).
|
117 |
+
if res != self.init_res:
|
118 |
+
if self.fused_scale == 'auto':
|
119 |
+
fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
|
120 |
+
else:
|
121 |
+
fused_scale = self.fused_scale
|
122 |
+
self.add_module(
|
123 |
+
f'layer{2 * block_idx}',
|
124 |
+
ConvBlock(in_channels=self.get_nf(res),
|
125 |
+
out_channels=self.get_nf(res),
|
126 |
+
use_wscale=self.use_wscale))
|
127 |
+
tf_layer0_name = 'Conv0'
|
128 |
+
self.add_module(
|
129 |
+
f'layer{2 * block_idx + 1}',
|
130 |
+
ConvBlock(in_channels=self.get_nf(res),
|
131 |
+
out_channels=self.get_nf(res // 2),
|
132 |
+
downsample=True,
|
133 |
+
fused_scale=fused_scale,
|
134 |
+
use_wscale=self.use_wscale))
|
135 |
+
tf_layer1_name = 'Conv1_down'
|
136 |
+
|
137 |
+
# Convolution block for last resolution.
|
138 |
+
else:
|
139 |
+
self.add_module(
|
140 |
+
f'layer{2 * block_idx}',
|
141 |
+
ConvBlock(in_channels=self.get_nf(res),
|
142 |
+
out_channels=self.get_nf(res),
|
143 |
+
use_wscale=self.use_wscale,
|
144 |
+
minibatch_std_group_size=minibatch_std_group_size,
|
145 |
+
minibatch_std_channels=minibatch_std_channels))
|
146 |
+
tf_layer0_name = 'Conv'
|
147 |
+
self.add_module(
|
148 |
+
f'layer{2 * block_idx + 1}',
|
149 |
+
DenseBlock(in_channels=self.get_nf(res) * res * res,
|
150 |
+
out_channels=self.get_nf(res // 2),
|
151 |
+
use_wscale=self.use_wscale))
|
152 |
+
tf_layer1_name = 'Dense0'
|
153 |
+
|
154 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.weight'] = (
|
155 |
+
f'{res}x{res}/{tf_layer0_name}/weight')
|
156 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx}.bias'] = (
|
157 |
+
f'{res}x{res}/{tf_layer0_name}/bias')
|
158 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.weight'] = (
|
159 |
+
f'{res}x{res}/{tf_layer1_name}/weight')
|
160 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 1}.bias'] = (
|
161 |
+
f'{res}x{res}/{tf_layer1_name}/bias')
|
162 |
+
|
163 |
+
# Final dense block.
|
164 |
+
self.add_module(
|
165 |
+
f'layer{2 * block_idx + 2}',
|
166 |
+
DenseBlock(in_channels=self.get_nf(res // 2),
|
167 |
+
out_channels=max(self.label_size, 1),
|
168 |
+
use_wscale=self.use_wscale,
|
169 |
+
wscale_gain=1.0,
|
170 |
+
activation_type='linear'))
|
171 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.weight'] = (
|
172 |
+
f'{res}x{res}/Dense1/weight')
|
173 |
+
self.pth_to_tf_var_mapping[f'layer{2 * block_idx + 2}.bias'] = (
|
174 |
+
f'{res}x{res}/Dense1/bias')
|
175 |
+
|
176 |
+
self.downsample = DownsamplingLayer()
|
177 |
+
|
178 |
+
def get_nf(self, res):
|
179 |
+
"""Gets number of feature maps according to current resolution."""
|
180 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
181 |
+
|
182 |
+
def forward(self, image, label=None, lod=None, **_unused_kwargs):
|
183 |
+
expected_shape = (self.image_channels, self.resolution, self.resolution)
|
184 |
+
if image.ndim != 4 or image.shape[1:] != expected_shape:
|
185 |
+
raise ValueError(f'The input tensor should be with shape '
|
186 |
+
f'[batch_size, channel, height, width], where '
|
187 |
+
f'`channel` equals to {self.image_channels}, '
|
188 |
+
f'`height`, `width` equal to {self.resolution}!\n'
|
189 |
+
f'But `{image.shape}` is received!')
|
190 |
+
|
191 |
+
lod = self.lod.cpu().tolist() if lod is None else lod
|
192 |
+
if lod + self.init_res_log2 > self.final_res_log2:
|
193 |
+
raise ValueError(f'Maximum level-of-detail (lod) is '
|
194 |
+
f'{self.final_res_log2 - self.init_res_log2}, '
|
195 |
+
f'but `{lod}` is received!')
|
196 |
+
|
197 |
+
if self.label_size:
|
198 |
+
if label is None:
|
199 |
+
raise ValueError(f'Model requires an additional label '
|
200 |
+
f'(with size {self.label_size}) as input, '
|
201 |
+
f'but no label is received!')
|
202 |
+
batch_size = image.shape[0]
|
203 |
+
if label.ndim != 2 or label.shape != (batch_size, self.label_size):
|
204 |
+
raise ValueError(f'Input label should be with shape '
|
205 |
+
f'[batch_size, label_size], where '
|
206 |
+
f'`batch_size` equals to that of '
|
207 |
+
f'images ({image.shape[0]}) and '
|
208 |
+
f'`label_size` equals to {self.label_size}!\n'
|
209 |
+
f'But `{label.shape}` is received!')
|
210 |
+
|
211 |
+
for res_log2 in range(self.final_res_log2, self.init_res_log2 - 1, -1):
|
212 |
+
block_idx = current_lod = self.final_res_log2 - res_log2
|
213 |
+
if current_lod <= lod < current_lod + 1:
|
214 |
+
x = self.__getattr__(f'input{block_idx}')(image)
|
215 |
+
elif current_lod - 1 < lod < current_lod:
|
216 |
+
alpha = lod - np.floor(lod)
|
217 |
+
x = (self.__getattr__(f'input{block_idx}')(image) * alpha +
|
218 |
+
x * (1 - alpha))
|
219 |
+
if lod < current_lod + 1:
|
220 |
+
x = self.__getattr__(f'layer{2 * block_idx}')(x)
|
221 |
+
x = self.__getattr__(f'layer{2 * block_idx + 1}')(x)
|
222 |
+
if lod > current_lod:
|
223 |
+
image = self.downsample(image)
|
224 |
+
x = self.__getattr__(f'layer{2 * block_idx + 2}')(x)
|
225 |
+
|
226 |
+
if self.label_size:
|
227 |
+
x = torch.sum(x * label, dim=1, keepdim=True)
|
228 |
+
|
229 |
+
return x
|
230 |
+
|
231 |
+
|
232 |
+
class MiniBatchSTDLayer(nn.Module):
|
233 |
+
"""Implements the minibatch standard deviation layer."""
|
234 |
+
|
235 |
+
def __init__(self, group_size=4, new_channels=1, epsilon=1e-8):
|
236 |
+
super().__init__()
|
237 |
+
self.group_size = group_size
|
238 |
+
self.new_channels = new_channels
|
239 |
+
self.epsilon = epsilon
|
240 |
+
|
241 |
+
def forward(self, x):
|
242 |
+
if self.group_size <= 1:
|
243 |
+
return x
|
244 |
+
ng = min(self.group_size, x.shape[0])
|
245 |
+
nc = self.new_channels
|
246 |
+
temp_c = x.shape[1] // nc # [NCHW]
|
247 |
+
y = x.view(ng, -1, nc, temp_c, x.shape[2], x.shape[3]) # [GMncHW]
|
248 |
+
y = y - torch.mean(y, dim=0, keepdim=True) # [GMncHW]
|
249 |
+
y = torch.mean(y ** 2, dim=0) # [MncHW]
|
250 |
+
y = torch.sqrt(y + self.epsilon) # [MncHW]
|
251 |
+
y = torch.mean(y, dim=[2, 3, 4], keepdim=True) # [Mn111]
|
252 |
+
y = torch.mean(y, dim=2) # [Mn11]
|
253 |
+
y = y.repeat(ng, 1, x.shape[2], x.shape[3]) # [NnHW]
|
254 |
+
return torch.cat([x, y], dim=1)
|
255 |
+
|
256 |
+
|
257 |
+
class DownsamplingLayer(nn.Module):
|
258 |
+
"""Implements the downsampling layer.
|
259 |
+
|
260 |
+
Basically, this layer can be used to downsample feature maps with average
|
261 |
+
pooling.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self, scale_factor=2):
|
265 |
+
super().__init__()
|
266 |
+
self.scale_factor = scale_factor
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
if self.scale_factor <= 1:
|
270 |
+
return x
|
271 |
+
return F.avg_pool2d(x,
|
272 |
+
kernel_size=self.scale_factor,
|
273 |
+
stride=self.scale_factor,
|
274 |
+
padding=0)
|
275 |
+
|
276 |
+
|
277 |
+
class Blur(torch.autograd.Function):
|
278 |
+
"""Defines blur operation with customized gradient computation."""
|
279 |
+
|
280 |
+
@staticmethod
|
281 |
+
def forward(ctx, x, kernel):
|
282 |
+
ctx.save_for_backward(kernel)
|
283 |
+
y = F.conv2d(input=x,
|
284 |
+
weight=kernel,
|
285 |
+
bias=None,
|
286 |
+
stride=1,
|
287 |
+
padding=1,
|
288 |
+
groups=x.shape[1])
|
289 |
+
return y
|
290 |
+
|
291 |
+
@staticmethod
|
292 |
+
def backward(ctx, dy):
|
293 |
+
kernel, = ctx.saved_tensors
|
294 |
+
dx = BlurBackPropagation.apply(dy, kernel)
|
295 |
+
return dx, None, None
|
296 |
+
|
297 |
+
|
298 |
+
class BlurBackPropagation(torch.autograd.Function):
|
299 |
+
"""Defines the back propagation of blur operation.
|
300 |
+
|
301 |
+
NOTE: This is used to speed up the backward of gradient penalty.
|
302 |
+
"""
|
303 |
+
|
304 |
+
@staticmethod
|
305 |
+
def forward(ctx, dy, kernel):
|
306 |
+
ctx.save_for_backward(kernel)
|
307 |
+
dx = F.conv2d(input=dy,
|
308 |
+
weight=kernel.flip((2, 3)),
|
309 |
+
bias=None,
|
310 |
+
stride=1,
|
311 |
+
padding=1,
|
312 |
+
groups=dy.shape[1])
|
313 |
+
return dx
|
314 |
+
|
315 |
+
@staticmethod
|
316 |
+
def backward(ctx, ddx):
|
317 |
+
kernel, = ctx.saved_tensors
|
318 |
+
ddy = F.conv2d(input=ddx,
|
319 |
+
weight=kernel,
|
320 |
+
bias=None,
|
321 |
+
stride=1,
|
322 |
+
padding=1,
|
323 |
+
groups=ddx.shape[1])
|
324 |
+
return ddy, None, None
|
325 |
+
|
326 |
+
|
327 |
+
class BlurLayer(nn.Module):
|
328 |
+
"""Implements the blur layer."""
|
329 |
+
|
330 |
+
def __init__(self,
|
331 |
+
channels,
|
332 |
+
kernel=(1, 2, 1),
|
333 |
+
normalize=True):
|
334 |
+
super().__init__()
|
335 |
+
kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
|
336 |
+
kernel = kernel.T.dot(kernel)
|
337 |
+
if normalize:
|
338 |
+
kernel = kernel / np.sum(kernel)
|
339 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
340 |
+
kernel = np.tile(kernel, [channels, 1, 1, 1])
|
341 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
342 |
+
|
343 |
+
def forward(self, x):
|
344 |
+
return Blur.apply(x, self.kernel)
|
345 |
+
|
346 |
+
|
347 |
+
class ConvBlock(nn.Module):
|
348 |
+
"""Implements the convolutional block.
|
349 |
+
|
350 |
+
Basically, this block executes minibatch standard deviation layer (if
|
351 |
+
needed), convolutional layer, activation layer, and downsampling layer (
|
352 |
+
if needed) in sequence.
|
353 |
+
"""
|
354 |
+
|
355 |
+
def __init__(self,
|
356 |
+
in_channels,
|
357 |
+
out_channels,
|
358 |
+
kernel_size=3,
|
359 |
+
stride=1,
|
360 |
+
padding=1,
|
361 |
+
add_bias=True,
|
362 |
+
downsample=False,
|
363 |
+
fused_scale=False,
|
364 |
+
use_wscale=True,
|
365 |
+
wscale_gain=_WSCALE_GAIN,
|
366 |
+
lr_mul=1.0,
|
367 |
+
activation_type='lrelu',
|
368 |
+
minibatch_std_group_size=0,
|
369 |
+
minibatch_std_channels=1):
|
370 |
+
"""Initializes with block settings.
|
371 |
+
|
372 |
+
Args:
|
373 |
+
in_channels: Number of channels of the input tensor.
|
374 |
+
out_channels: Number of channels of the output tensor.
|
375 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
376 |
+
stride: Stride parameter for convolution operation. (default: 1)
|
377 |
+
padding: Padding parameter for convolution operation. (default: 1)
|
378 |
+
add_bias: Whether to add bias onto the convolutional result.
|
379 |
+
(default: True)
|
380 |
+
downsample: Whether to downsample the result after convolution.
|
381 |
+
(default: False)
|
382 |
+
fused_scale: Whether to fuse `conv2d` and `downsample` together,
|
383 |
+
resulting in `conv2d` with strides. (default: False)
|
384 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
385 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
386 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
387 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
388 |
+
(default: `lrelu`)
|
389 |
+
minibatch_std_group_size: Group size for the minibatch standard
|
390 |
+
deviation layer. 0 means disable. (default: 0)
|
391 |
+
minibatch_std_channels: Number of new channels after the minibatch
|
392 |
+
standard deviation layer. (default: 1)
|
393 |
+
|
394 |
+
Raises:
|
395 |
+
NotImplementedError: If the `activation_type` is not supported.
|
396 |
+
"""
|
397 |
+
super().__init__()
|
398 |
+
|
399 |
+
if minibatch_std_group_size > 1:
|
400 |
+
in_channels = in_channels + minibatch_std_channels
|
401 |
+
self.mbstd = MiniBatchSTDLayer(group_size=minibatch_std_group_size,
|
402 |
+
new_channels=minibatch_std_channels)
|
403 |
+
else:
|
404 |
+
self.mbstd = nn.Identity()
|
405 |
+
|
406 |
+
if downsample:
|
407 |
+
self.blur = BlurLayer(channels=in_channels)
|
408 |
+
else:
|
409 |
+
self.blur = nn.Identity()
|
410 |
+
|
411 |
+
if downsample and not fused_scale:
|
412 |
+
self.downsample = DownsamplingLayer()
|
413 |
+
else:
|
414 |
+
self.downsample = nn.Identity()
|
415 |
+
|
416 |
+
if downsample and fused_scale:
|
417 |
+
self.use_stride = True
|
418 |
+
self.stride = 2
|
419 |
+
self.padding = 1
|
420 |
+
else:
|
421 |
+
self.use_stride = False
|
422 |
+
self.stride = stride
|
423 |
+
self.padding = padding
|
424 |
+
|
425 |
+
weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
|
426 |
+
fan_in = kernel_size * kernel_size * in_channels
|
427 |
+
wscale = wscale_gain / np.sqrt(fan_in)
|
428 |
+
if use_wscale:
|
429 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
430 |
+
self.wscale = wscale * lr_mul
|
431 |
+
else:
|
432 |
+
self.weight = nn.Parameter(
|
433 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
434 |
+
self.wscale = lr_mul
|
435 |
+
|
436 |
+
if add_bias:
|
437 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
438 |
+
self.bscale = lr_mul
|
439 |
+
else:
|
440 |
+
self.bias = None
|
441 |
+
|
442 |
+
if activation_type == 'linear':
|
443 |
+
self.activate = nn.Identity()
|
444 |
+
elif activation_type == 'lrelu':
|
445 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
446 |
+
else:
|
447 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
448 |
+
f'`{activation_type}`!')
|
449 |
+
|
450 |
+
def forward(self, x):
|
451 |
+
x = self.mbstd(x)
|
452 |
+
x = self.blur(x)
|
453 |
+
weight = self.weight * self.wscale
|
454 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
455 |
+
if self.use_stride:
|
456 |
+
weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0.0)
|
457 |
+
weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
|
458 |
+
weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1]) * 0.25
|
459 |
+
x = F.conv2d(x,
|
460 |
+
weight=weight,
|
461 |
+
bias=bias,
|
462 |
+
stride=self.stride,
|
463 |
+
padding=self.padding)
|
464 |
+
x = self.downsample(x)
|
465 |
+
x = self.activate(x)
|
466 |
+
return x
|
467 |
+
|
468 |
+
|
469 |
+
class DenseBlock(nn.Module):
|
470 |
+
"""Implements the dense block.
|
471 |
+
|
472 |
+
Basically, this block executes fully-connected layer and activation layer.
|
473 |
+
"""
|
474 |
+
|
475 |
+
def __init__(self,
|
476 |
+
in_channels,
|
477 |
+
out_channels,
|
478 |
+
add_bias=True,
|
479 |
+
use_wscale=True,
|
480 |
+
wscale_gain=_WSCALE_GAIN,
|
481 |
+
lr_mul=1.0,
|
482 |
+
activation_type='lrelu'):
|
483 |
+
"""Initializes with block settings.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
in_channels: Number of channels of the input tensor.
|
487 |
+
out_channels: Number of channels of the output tensor.
|
488 |
+
add_bias: Whether to add bias onto the fully-connected result.
|
489 |
+
(default: True)
|
490 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
491 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
492 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
493 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
494 |
+
(default: `lrelu`)
|
495 |
+
|
496 |
+
Raises:
|
497 |
+
NotImplementedError: If the `activation_type` is not supported.
|
498 |
+
"""
|
499 |
+
super().__init__()
|
500 |
+
weight_shape = (out_channels, in_channels)
|
501 |
+
wscale = wscale_gain / np.sqrt(in_channels)
|
502 |
+
if use_wscale:
|
503 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
|
504 |
+
self.wscale = wscale * lr_mul
|
505 |
+
else:
|
506 |
+
self.weight = nn.Parameter(
|
507 |
+
torch.randn(*weight_shape) * wscale / lr_mul)
|
508 |
+
self.wscale = lr_mul
|
509 |
+
|
510 |
+
if add_bias:
|
511 |
+
self.bias = nn.Parameter(torch.zeros(out_channels))
|
512 |
+
self.bscale = lr_mul
|
513 |
+
else:
|
514 |
+
self.bias = None
|
515 |
+
|
516 |
+
if activation_type == 'linear':
|
517 |
+
self.activate = nn.Identity()
|
518 |
+
elif activation_type == 'lrelu':
|
519 |
+
self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
520 |
+
else:
|
521 |
+
raise NotImplementedError(f'Not implemented activation function: '
|
522 |
+
f'`{activation_type}`!')
|
523 |
+
|
524 |
+
def forward(self, x):
|
525 |
+
if x.ndim != 2:
|
526 |
+
x = x.view(x.shape[0], -1)
|
527 |
+
bias = self.bias * self.bscale if self.bias is not None else None
|
528 |
+
x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
|
529 |
+
x = self.activate(x)
|
530 |
+
return x
|
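A hypothetical usage sketch for the discriminator above (not part of the commit; it assumes the repository root is on PYTHONPATH so that `models.stylegan_discriminator` is importable):

import torch
from models.stylegan_discriminator import StyleGANDiscriminator

D = StyleGANDiscriminator(resolution=256)
D.eval()
images = torch.randn(2, 3, 256, 256)    # RGB images expected in [-1, 1]
with torch.no_grad():
    scores = D(images)                   # lod defaults to the stored buffer (0)
print(scores.shape)                      # torch.Size([2, 1]) when label_size == 0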
models/stylegan_generator.py
ADDED
@@ -0,0 +1,869 @@
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of generator described in StyleGAN.
|
3 |
+
|
4 |
+
Paper: https://arxiv.org/pdf/1812.04948.pdf
|
5 |
+
|
6 |
+
Official TensorFlow implementation: https://github.com/NVlabs/stylegan
|
7 |
+
"""
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from .sync_op import all_gather
|
16 |
+
|
17 |
+
__all__ = ['StyleGANGenerator']
|
18 |
+
|
19 |
+
# Resolutions allowed.
|
20 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
21 |
+
|
22 |
+
# Initial resolution.
|
23 |
+
_INIT_RES = 4
|
24 |
+
|
25 |
+
# Fused-scale options allowed.
|
26 |
+
_FUSED_SCALE_ALLOWED = [True, False, 'auto']
|
27 |
+
|
28 |
+
# Minimal resolution for `auto` fused-scale strategy.
|
29 |
+
_AUTO_FUSED_SCALE_MIN_RES = 128
|
30 |
+
|
31 |
+
# Default gain factor for weight scaling.
|
32 |
+
_WSCALE_GAIN = np.sqrt(2.0)
|
33 |
+
_STYLEMOD_WSCALE_GAIN = 1.0
|
34 |
+
|
35 |
+
|
36 |
+
class StyleGANGenerator(nn.Module):
|
37 |
+
"""Defines the generator network in StyleGAN.
|
38 |
+
|
39 |
+
NOTE: The synthesized images are with `RGB` channel order and pixel range
|
40 |
+
[-1, 1].
|
41 |
+
|
42 |
+
Settings for the mapping network:
|
43 |
+
|
44 |
+
(1) z_space_dim: Dimension of the input latent space, Z. (default: 512)
|
45 |
+
(2) w_space_dim: Dimension of the output latent space, W. (default: 512)
|
46 |
+
(3) label_size: Size of the additional label for conditional generation.
|
47 |
+
(default: 0)
|
48 |
+
(4) mapping_layers: Number of layers of the mapping network. (default: 8)
|
49 |
+
(5) mapping_fmaps: Number of hidden channels of the mapping network.
|
50 |
+
(default: 512)
|
51 |
+
(6) mapping_lr_mul: Learning rate multiplier for the mapping network.
|
52 |
+
(default: 0.01)
|
53 |
+
(7) repeat_w: Repeat w-code for different layers.
|
54 |
+
|
55 |
+
Settings for the synthesis network:
|
56 |
+
|
57 |
+
(1) resolution: The resolution of the output image.
|
58 |
+
(2) image_channels: Number of channels of the output image. (default: 3)
|
59 |
+
(3) final_tanh: Whether to use `tanh` to control the final pixel range.
|
60 |
+
(default: False)
|
61 |
+
(4) const_input: Whether to use a constant in the first convolutional layer.
|
62 |
+
(default: True)
|
63 |
+
(5) fused_scale: Whether to fuse `upsample` and `conv2d` together,
|
64 |
+
resulting in `conv2d_transpose`. (default: `auto`)
|
65 |
+
(6) use_wscale: Whether to use weight scaling. (default: True)
|
66 |
+
(7) fmaps_base: Factor to control number of feature maps for each layer.
|
67 |
+
(default: 16 << 10)
|
68 |
+
(8) fmaps_max: Maximum number of feature maps in each layer. (default: 512)
|
69 |
+
"""
|
70 |
+
|
71 |
+
def __init__(self,
|
72 |
+
resolution,
|
73 |
+
z_space_dim=512,
|
74 |
+
w_space_dim=512,
|
75 |
+
label_size=0,
|
76 |
+
mapping_layers=8,
|
77 |
+
mapping_fmaps=512,
|
78 |
+
mapping_lr_mul=0.01,
|
79 |
+
repeat_w=True,
|
80 |
+
image_channels=3,
|
81 |
+
final_tanh=False,
|
82 |
+
const_input=True,
|
83 |
+
fused_scale='auto',
|
84 |
+
use_wscale=True,
|
85 |
+
fmaps_base=16 << 10,
|
86 |
+
fmaps_max=512):
|
87 |
+
"""Initializes with basic settings.
|
88 |
+
|
89 |
+
Raises:
|
90 |
+
ValueError: If the `resolution` is not supported, or `fused_scale`
|
91 |
+
is not supported.
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
|
95 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
96 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
97 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
98 |
+
if fused_scale not in _FUSED_SCALE_ALLOWED:
|
99 |
+
raise ValueError(f'Invalid fused-scale option: `{fused_scale}`!\n'
|
100 |
+
f'Options allowed: {_FUSED_SCALE_ALLOWED}.')
|
101 |
+
|
102 |
+
self.init_res = _INIT_RES
|
103 |
+
self.resolution = resolution
|
104 |
+
self.z_space_dim = z_space_dim
|
105 |
+
self.w_space_dim = w_space_dim
|
106 |
+
self.label_size = label_size
|
107 |
+
self.mapping_layers = mapping_layers
|
108 |
+
self.mapping_fmaps = mapping_fmaps
|
109 |
+
self.mapping_lr_mul = mapping_lr_mul
|
110 |
+
self.repeat_w = repeat_w
|
111 |
+
self.image_channels = image_channels
|
112 |
+
self.final_tanh = final_tanh
|
113 |
+
self.const_input = const_input
|
114 |
+
self.fused_scale = fused_scale
|
115 |
+
self.use_wscale = use_wscale
|
116 |
+
self.fmaps_base = fmaps_base
|
117 |
+
self.fmaps_max = fmaps_max
|
118 |
+
|
119 |
+
self.num_layers = int(np.log2(self.resolution // self.init_res * 2)) * 2
|
120 |
+
|
121 |
+
if self.repeat_w:
|
122 |
+
self.mapping_space_dim = self.w_space_dim
|
123 |
+
else:
|
124 |
+
self.mapping_space_dim = self.w_space_dim * self.num_layers
|
125 |
+
self.mapping = MappingModule(input_space_dim=self.z_space_dim,
|
126 |
+
hidden_space_dim=self.mapping_fmaps,
|
127 |
+
final_space_dim=self.mapping_space_dim,
|
128 |
+
label_size=self.label_size,
|
129 |
+
num_layers=self.mapping_layers,
|
130 |
+
use_wscale=self.use_wscale,
|
131 |
+
lr_mul=self.mapping_lr_mul)
|
132 |
+
|
133 |
+
self.truncation = TruncationModule(w_space_dim=self.w_space_dim,
|
134 |
+
num_layers=self.num_layers,
|
135 |
+
repeat_w=self.repeat_w)
|
136 |
+
|
137 |
+
self.synthesis = SynthesisModule(resolution=self.resolution,
|
138 |
+
init_resolution=self.init_res,
|
139 |
+
w_space_dim=self.w_space_dim,
|
140 |
+
image_channels=self.image_channels,
|
141 |
+
final_tanh=self.final_tanh,
|
142 |
+
const_input=self.const_input,
|
143 |
+
fused_scale=self.fused_scale,
|
144 |
+
use_wscale=self.use_wscale,
|
145 |
+
fmaps_base=self.fmaps_base,
|
146 |
+
fmaps_max=self.fmaps_max)
|
147 |
+
|
148 |
+
self.pth_to_tf_var_mapping = {}
|
149 |
+
for key, val in self.mapping.pth_to_tf_var_mapping.items():
|
150 |
+
self.pth_to_tf_var_mapping[f'mapping.{key}'] = val
|
151 |
+
for key, val in self.truncation.pth_to_tf_var_mapping.items():
|
152 |
+
self.pth_to_tf_var_mapping[f'truncation.{key}'] = val
|
153 |
+
for key, val in self.synthesis.pth_to_tf_var_mapping.items():
|
154 |
+
self.pth_to_tf_var_mapping[f'synthesis.{key}'] = val
|
155 |
+
|
156 |
+
def forward(self,
|
157 |
+
z,
|
158 |
+
label=None,
|
159 |
+
lod=None,
|
160 |
+
w_moving_decay=0.995,
|
161 |
+
style_mixing_prob=0.9,
|
162 |
+
trunc_psi=None,
|
163 |
+
trunc_layers=None,
|
164 |
+
randomize_noise=False,
|
165 |
+
**_unused_kwargs):
|
166 |
+
mapping_results = self.mapping(z, label)
|
167 |
+
w = mapping_results['w']
|
168 |
+
|
169 |
+
if self.training and w_moving_decay < 1:
|
170 |
+
batch_w_avg = all_gather(w).mean(dim=0)
|
171 |
+
self.truncation.w_avg.copy_(
|
172 |
+
self.truncation.w_avg * w_moving_decay +
|
173 |
+
batch_w_avg * (1 - w_moving_decay))
|
174 |
+
|
175 |
+
if self.training and style_mixing_prob > 0:
|
176 |
+
new_z = torch.randn_like(z)
|
177 |
+
new_w = self.mapping(new_z, label)['w']
|
178 |
+
lod = self.synthesis.lod.cpu().tolist() if lod is None else lod
|
179 |
+
current_layers = self.num_layers - int(lod) * 2
|
180 |
+
if np.random.uniform() < style_mixing_prob:
|
181 |
+
mixing_cutoff = np.random.randint(1, current_layers)
|
182 |
+
w = self.truncation(w)
|
183 |
+
new_w = self.truncation(new_w)
|
184 |
+
w[:, mixing_cutoff:] = new_w[:, mixing_cutoff:]
|
185 |
+
|
186 |
+
wp = self.truncation(w, trunc_psi, trunc_layers)
|
187 |
+
synthesis_results = self.synthesis(wp, lod, randomize_noise)
|
188 |
+
|
189 |
+
return {**mapping_results, **synthesis_results}
|
190 |
+
|
191 |
+
|
192 |
+
class MappingModule(nn.Module):
|
193 |
+
"""Implements the latent space mapping module.
|
194 |
+
|
195 |
+
Basically, this module executes several dense layers in sequence.
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(self,
|
199 |
+
input_space_dim=512,
|
200 |
+
hidden_space_dim=512,
|
201 |
+
final_space_dim=512,
|
202 |
+
label_size=0,
|
203 |
+
num_layers=8,
|
204 |
+
normalize_input=True,
|
205 |
+
use_wscale=True,
|
206 |
+
lr_mul=0.01):
|
207 |
+
super().__init__()
|
208 |
+
|
209 |
+
self.input_space_dim = input_space_dim
|
210 |
+
self.hidden_space_dim = hidden_space_dim
|
211 |
+
self.final_space_dim = final_space_dim
|
212 |
+
self.label_size = label_size
|
213 |
+
self.num_layers = num_layers
|
214 |
+
self.normalize_input = normalize_input
|
215 |
+
self.use_wscale = use_wscale
|
216 |
+
self.lr_mul = lr_mul
|
217 |
+
|
218 |
+
self.norm = PixelNormLayer() if self.normalize_input else nn.Identity()
|
219 |
+
|
220 |
+
self.pth_to_tf_var_mapping = {}
|
221 |
+
for i in range(num_layers):
|
222 |
+
dim_mul = 2 if label_size else 1
|
223 |
+
in_channels = (input_space_dim * dim_mul if i == 0 else
|
224 |
+
hidden_space_dim)
|
225 |
+
out_channels = (final_space_dim if i == (num_layers - 1) else
|
226 |
+
hidden_space_dim)
|
227 |
+
self.add_module(f'dense{i}',
|
228 |
+
DenseBlock(in_channels=in_channels,
|
229 |
+
out_channels=out_channels,
|
230 |
+
use_wscale=self.use_wscale,
|
231 |
+
lr_mul=self.lr_mul))
|
232 |
+
self.pth_to_tf_var_mapping[f'dense{i}.weight'] = f'Dense{i}/weight'
|
233 |
+
self.pth_to_tf_var_mapping[f'dense{i}.bias'] = f'Dense{i}/bias'
|
234 |
+
if label_size:
|
235 |
+
self.label_weight = nn.Parameter(
|
236 |
+
torch.randn(label_size, input_space_dim))
|
237 |
+
self.pth_to_tf_var_mapping[f'label_weight'] = f'LabelConcat/weight'
|
238 |
+
|
239 |
+
def forward(self, z, label=None):
|
240 |
+
if z.ndim != 2 or z.shape[1] != self.input_space_dim:
|
241 |
+
raise ValueError(f'Input latent code should be with shape '
|
242 |
+
f'[batch_size, input_dim], where '
|
243 |
+
f'`input_dim` equals to {self.input_space_dim}!\n'
|
244 |
+
f'But `{z.shape}` is received!')
|
245 |
+
if self.label_size:
|
246 |
+
if label is None:
|
247 |
+
raise ValueError(f'Model requires an additional label '
|
248 |
+
f'(with size {self.label_size}) as input, '
|
249 |
+
f'but no label is received!')
|
250 |
+
if label.ndim != 2 or label.shape != (z.shape[0], self.label_size):
|
251 |
+
raise ValueError(f'Input label should be with shape '
|
252 |
+
f'[batch_size, label_size], where '
|
253 |
+
f'`batch_size` equals to that of '
|
254 |
+
f'latent codes ({z.shape[0]}) and '
|
255 |
+
f'`label_size` equals to {self.label_size}!\n'
|
256 |
+
f'But `{label.shape}` is received!')
|
257 |
+
embedding = torch.matmul(label, self.label_weight)
|
258 |
+
z = torch.cat((z, embedding), dim=1)
|
259 |
+
|
260 |
+
z = self.norm(z)
|
261 |
+
w = z
|
262 |
+
for i in range(self.num_layers):
|
263 |
+
w = self.__getattr__(f'dense{i}')(w)
|
264 |
+
results = {
|
265 |
+
'z': z,
|
266 |
+
'label': label,
|
267 |
+
'w': w,
|
268 |
+
}
|
269 |
+
if self.label_size:
|
270 |
+
results['embedding'] = embedding
|
271 |
+
return results
|
272 |
+
|
273 |
+
|
274 |
+
class TruncationModule(nn.Module):
|
275 |
+
"""Implements the truncation module.
|
276 |
+
|
277 |
+
Truncation is executed as follows:
|
278 |
+
|
279 |
+
For layers in range [0, truncation_layers), the truncated w-code is computed
|
280 |
+
as
|
281 |
+
|
282 |
+
w_new = w_avg + (w - w_avg) * truncation_psi
|
283 |
+
|
284 |
+
To disable truncation, please set
|
285 |
+
(1) truncation_psi = 1.0 (None) OR
|
286 |
+
(2) truncation_layers = 0 (None)
|
287 |
+
|
288 |
+
NOTE: The returned tensor is layer-wise style codes.
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(self, w_space_dim, num_layers, repeat_w=True):
|
292 |
+
super().__init__()
|
293 |
+
|
294 |
+
self.num_layers = num_layers
|
295 |
+
self.w_space_dim = w_space_dim
|
296 |
+
self.repeat_w = repeat_w
|
297 |
+
|
298 |
+
if self.repeat_w:
|
299 |
+
self.register_buffer('w_avg', torch.zeros(w_space_dim))
|
300 |
+
else:
|
301 |
+
self.register_buffer('w_avg', torch.zeros(num_layers * w_space_dim))
|
302 |
+
self.pth_to_tf_var_mapping = {'w_avg': 'dlatent_avg'}
|
303 |
+
|
304 |
+
def forward(self, w, trunc_psi=None, trunc_layers=None):
|
305 |
+
if w.ndim == 2:
|
306 |
+
if self.repeat_w and w.shape[1] == self.w_space_dim:
|
307 |
+
w = w.view(-1, 1, self.w_space_dim)
|
308 |
+
wp = w.repeat(1, self.num_layers, 1)
|
309 |
+
else:
|
310 |
+
assert w.shape[1] == self.w_space_dim * self.num_layers
|
311 |
+
wp = w.view(-1, self.num_layers, self.w_space_dim)
|
312 |
+
else:
|
313 |
+
wp = w
|
314 |
+
assert wp.ndim == 3
|
315 |
+
assert wp.shape[1:] == (self.num_layers, self.w_space_dim)
|
316 |
+
|
317 |
+
trunc_psi = 1.0 if trunc_psi is None else trunc_psi
|
318 |
+
trunc_layers = 0 if trunc_layers is None else trunc_layers
|
319 |
+
if trunc_psi < 1.0 and trunc_layers > 0:
|
320 |
+
layer_idx = np.arange(self.num_layers).reshape(1, -1, 1)
|
321 |
+
coefs = np.ones_like(layer_idx, dtype=np.float32)
|
322 |
+
coefs[layer_idx < trunc_layers] *= trunc_psi
|
323 |
+
coefs = torch.from_numpy(coefs).to(wp)
|
324 |
+
w_avg = self.w_avg.view(1, -1, self.w_space_dim)
|
325 |
+
wp = w_avg + (wp - w_avg) * coefs
|
326 |
+
return wp
|
327 |
+
|
328 |
+
|
329 |
+
class SynthesisModule(nn.Module):
|
330 |
+
"""Implements the image synthesis module.
|
331 |
+
|
332 |
+
Basically, this module executes several convolutional layers in sequence.
|
333 |
+
"""
|
334 |
+
|
335 |
+
def __init__(self,
|
336 |
+
resolution=1024,
|
337 |
+
init_resolution=4,
|
338 |
+
w_space_dim=512,
|
339 |
+
image_channels=3,
|
340 |
+
final_tanh=False,
|
341 |
+
const_input=True,
|
342 |
+
fused_scale='auto',
|
343 |
+
use_wscale=True,
|
344 |
+
fmaps_base=16 << 10,
|
345 |
+
fmaps_max=512):
|
346 |
+
super().__init__()
|
347 |
+
|
348 |
+
self.init_res = init_resolution
|
349 |
+
self.init_res_log2 = int(np.log2(self.init_res))
|
350 |
+
self.resolution = resolution
|
351 |
+
self.final_res_log2 = int(np.log2(self.resolution))
|
352 |
+
self.w_space_dim = w_space_dim
|
353 |
+
self.image_channels = image_channels
|
354 |
+
self.final_tanh = final_tanh
|
355 |
+
self.const_input = const_input
|
356 |
+
self.fused_scale = fused_scale
|
357 |
+
self.use_wscale = use_wscale
|
358 |
+
self.fmaps_base = fmaps_base
|
359 |
+
self.fmaps_max = fmaps_max
|
360 |
+
|
361 |
+
self.num_layers = (self.final_res_log2 - self.init_res_log2 + 1) * 2
|
362 |
+
|
363 |
+
# Level of detail (used for progressive training).
|
364 |
+
self.register_buffer('lod', torch.zeros(()))
|
365 |
+
self.pth_to_tf_var_mapping = {'lod': 'lod'}
|
366 |
+
|
367 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
368 |
+
res = 2 ** res_log2
|
369 |
+
block_idx = res_log2 - self.init_res_log2
|
370 |
+
|
371 |
+
# First convolution layer for each resolution.
|
372 |
+
layer_name = f'layer{2 * block_idx}'
|
373 |
+
if res == self.init_res:
|
374 |
+
if self.const_input:
|
375 |
+
self.add_module(layer_name,
|
376 |
+
ConvBlock(in_channels=self.get_nf(res),
|
377 |
+
out_channels=self.get_nf(res),
|
378 |
+
resolution=self.init_res,
|
379 |
+
w_space_dim=self.w_space_dim,
|
380 |
+
position='const_init',
|
381 |
+
use_wscale=self.use_wscale))
|
382 |
+
tf_layer_name = 'Const'
|
383 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.const'] = (
|
384 |
+
f'{res}x{res}/{tf_layer_name}/const')
|
385 |
+
else:
|
386 |
+
self.add_module(layer_name,
|
387 |
+
ConvBlock(in_channels=self.w_space_dim,
|
388 |
+
out_channels=self.get_nf(res),
|
389 |
+
resolution=self.init_res,
|
390 |
+
w_space_dim=self.w_space_dim,
|
391 |
+
kernel_size=self.init_res,
|
392 |
+
padding=self.init_res - 1,
|
393 |
+
use_wscale=self.use_wscale))
|
394 |
+
tf_layer_name = 'Dense'
|
395 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
396 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
397 |
+
else:
|
398 |
+
if self.fused_scale == 'auto':
|
399 |
+
fused_scale = (res >= _AUTO_FUSED_SCALE_MIN_RES)
|
400 |
+
else:
|
401 |
+
fused_scale = self.fused_scale
|
402 |
+
self.add_module(layer_name,
|
403 |
+
ConvBlock(in_channels=self.get_nf(res // 2),
|
404 |
+
out_channels=self.get_nf(res),
|
405 |
+
resolution=res,
|
406 |
+
w_space_dim=self.w_space_dim,
|
407 |
+
upsample=True,
|
408 |
+
fused_scale=fused_scale,
|
409 |
+
use_wscale=self.use_wscale))
|
410 |
+
tf_layer_name = 'Conv0_up'
|
411 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
412 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
413 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
414 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
415 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
416 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
|
417 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
418 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
|
419 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
|
420 |
+
f'{res}x{res}/{tf_layer_name}/Noise/weight')
|
421 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
|
422 |
+
f'noise{2 * block_idx}')
|
423 |
+
|
424 |
+
# Second convolution layer for each resolution.
|
425 |
+
layer_name = f'layer{2 * block_idx + 1}'
|
426 |
+
self.add_module(layer_name,
|
427 |
+
ConvBlock(in_channels=self.get_nf(res),
|
428 |
+
out_channels=self.get_nf(res),
|
429 |
+
resolution=res,
|
430 |
+
w_space_dim=self.w_space_dim,
|
431 |
+
use_wscale=self.use_wscale))
|
432 |
+
tf_layer_name = 'Conv' if res == self.init_res else 'Conv1'
|
433 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.weight'] = (
|
434 |
+
f'{res}x{res}/{tf_layer_name}/weight')
|
435 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.bias'] = (
|
436 |
+
f'{res}x{res}/{tf_layer_name}/bias')
|
437 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.weight'] = (
|
438 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/weight')
|
439 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.style.bias'] = (
|
440 |
+
f'{res}x{res}/{tf_layer_name}/StyleMod/bias')
|
441 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.weight'] = (
|
442 |
+
f'{res}x{res}/{tf_layer_name}/Noise/weight')
|
443 |
+
self.pth_to_tf_var_mapping[f'{layer_name}.apply_noise.noise'] = (
|
444 |
+
f'noise{2 * block_idx + 1}')
|
445 |
+
|
446 |
+
# Output convolution layer for each resolution.
|
447 |
+
self.add_module(f'output{block_idx}',
|
448 |
+
ConvBlock(in_channels=self.get_nf(res),
|
449 |
+
out_channels=self.image_channels,
|
450 |
+
resolution=res,
|
451 |
+
w_space_dim=self.w_space_dim,
|
452 |
+
position='last',
|
453 |
+
kernel_size=1,
|
454 |
+
padding=0,
|
455 |
+
use_wscale=self.use_wscale,
|
456 |
+
wscale_gain=1.0,
|
457 |
+
activation_type='linear'))
|
458 |
+
self.pth_to_tf_var_mapping[f'output{block_idx}.weight'] = (
|
459 |
+
f'ToRGB_lod{self.final_res_log2 - res_log2}/weight')
|
460 |
+
self.pth_to_tf_var_mapping[f'output{block_idx}.bias'] = (
|
461 |
+
f'ToRGB_lod{self.final_res_log2 - res_log2}/bias')
|
462 |
+
|
463 |
+
self.upsample = UpsamplingLayer()
|
464 |
+
self.final_activate = nn.Tanh() if final_tanh else nn.Identity()
|
465 |
+
|
466 |
+
def get_nf(self, res):
|
467 |
+
"""Gets number of feature maps according to current resolution."""
|
468 |
+
return min(self.fmaps_base // res, self.fmaps_max)
|
469 |
+
|
470 |
+
def forward(self, wp, lod=None, randomize_noise=False):
|
471 |
+
if wp.ndim != 3 or wp.shape[1:] != (self.num_layers, self.w_space_dim):
|
472 |
+
raise ValueError(f'Input tensor should be with shape '
|
473 |
+
f'[batch_size, num_layers, w_space_dim], where '
|
474 |
+
f'`num_layers` equals to {self.num_layers}, and '
|
475 |
+
f'`w_space_dim` equals to {self.w_space_dim}!\n'
|
476 |
+
f'But `{wp.shape}` is received!')
|
477 |
+
|
478 |
+
lod = self.lod.cpu().tolist() if lod is None else lod
|
479 |
+
if lod + self.init_res_log2 > self.final_res_log2:
|
480 |
+
raise ValueError(f'Maximum level-of-detail (lod) is '
|
481 |
+
f'{self.final_res_log2 - self.init_res_log2}, '
|
482 |
+
f'but `{lod}` is received!')
|
483 |
+
|
484 |
+
results = {'wp': wp}
|
485 |
+
for res_log2 in range(self.init_res_log2, self.final_res_log2 + 1):
|
486 |
+
current_lod = self.final_res_log2 - res_log2
|
487 |
+
if lod < current_lod + 1:
|
488 |
+
block_idx = res_log2 - self.init_res_log2
|
489 |
+
if block_idx == 0:
|
490 |
+
if self.const_input:
|
491 |
+
x, style = self.layer0(None, wp[:, 0], randomize_noise)
|
492 |
+
else:
|
493 |
+
x = wp[:, 0].view(-1, self.w_space_dim, 1, 1)
|
494 |
+
x, style = self.layer0(x, wp[:, 0], randomize_noise)
|
495 |
+
else:
|
496 |
+
x, style = self.__getattr__(f'layer{2 * block_idx}')(
|
497 |
+
x, wp[:, 2 * block_idx])
|
498 |
+
results[f'style{2 * block_idx:02d}'] = style
|
499 |
+
x, style = self.__getattr__(f'layer{2 * block_idx + 1}')(
|
500 |
+
x, wp[:, 2 * block_idx + 1])
|
501 |
+
results[f'style{2 * block_idx + 1:02d}'] = style
|
502 |
+
if current_lod - 1 < lod <= current_lod:
|
503 |
+
image = self.__getattr__(f'output{block_idx}')(x, None)
|
504 |
+
elif current_lod < lod < current_lod + 1:
|
505 |
+
alpha = np.ceil(lod) - lod
|
506 |
+
image = (self.__getattr__(f'output{block_idx}')(x, None) * alpha
|
507 |
+
+ self.upsample(image) * (1 - alpha))
|
508 |
+
elif lod >= current_lod + 1:
|
509 |
+
image = self.upsample(image)
|
510 |
+
results['image'] = self.final_activate(image)
|
511 |
+
return results
|
512 |
+
|
513 |
+
|
514 |
+
class PixelNormLayer(nn.Module):
|
515 |
+
"""Implements pixel-wise feature vector normalization layer."""
|
516 |
+
|
517 |
+
def __init__(self, epsilon=1e-8):
|
518 |
+
super().__init__()
|
519 |
+
self.eps = epsilon
|
520 |
+
|
521 |
+
def forward(self, x):
|
522 |
+
norm = torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)
|
523 |
+
return x / norm
|
524 |
+
|
525 |
+
|
526 |
+
class InstanceNormLayer(nn.Module):
|
527 |
+
"""Implements instance normalization layer."""
|
528 |
+
|
529 |
+
def __init__(self, epsilon=1e-8):
|
530 |
+
super().__init__()
|
531 |
+
self.eps = epsilon
|
532 |
+
|
533 |
+
def forward(self, x):
|
534 |
+
if x.ndim != 4:
|
535 |
+
raise ValueError(f'The input tensor should be with shape '
|
536 |
+
f'[batch_size, channel, height, width], '
|
537 |
+
f'but `{x.shape}` is received!')
|
538 |
+
x = x - torch.mean(x, dim=[2, 3], keepdim=True)
|
539 |
+
norm = torch.sqrt(
|
540 |
+
torch.mean(x ** 2, dim=[2, 3], keepdim=True) + self.eps)
|
541 |
+
return x / norm
|
542 |
+
|
543 |
+
|
544 |
+
class UpsamplingLayer(nn.Module):
|
545 |
+
"""Implements the upsampling layer.
|
546 |
+
|
547 |
+
Basically, this layer can be used to upsample feature maps with nearest
|
548 |
+
neighbor interpolation.
|
549 |
+
"""
|
550 |
+
|
551 |
+
def __init__(self, scale_factor=2):
|
552 |
+
super().__init__()
|
553 |
+
self.scale_factor = scale_factor
|
554 |
+
|
555 |
+
def forward(self, x):
|
556 |
+
if self.scale_factor <= 1:
|
557 |
+
return x
|
558 |
+
return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
|
559 |
+
|
560 |
+
|
561 |
+
class Blur(torch.autograd.Function):
|
562 |
+
"""Defines blur operation with customized gradient computation."""
|
563 |
+
|
564 |
+
@staticmethod
|
565 |
+
def forward(ctx, x, kernel):
|
566 |
+
ctx.save_for_backward(kernel)
|
567 |
+
y = F.conv2d(input=x,
|
568 |
+
weight=kernel,
|
569 |
+
bias=None,
|
570 |
+
stride=1,
|
571 |
+
padding=1,
|
572 |
+
groups=x.shape[1])
|
573 |
+
return y
|
574 |
+
|
575 |
+
@staticmethod
|
576 |
+
def backward(ctx, dy):
|
577 |
+
kernel, = ctx.saved_tensors
|
578 |
+
dx = F.conv2d(input=dy,
|
579 |
+
weight=kernel.flip((2, 3)),
|
580 |
+
bias=None,
|
581 |
+
stride=1,
|
582 |
+
padding=1,
|
583 |
+
groups=dy.shape[1])
|
584 |
+
return dx, None, None
|
585 |
+
|
586 |
+
|
587 |
+
class BlurLayer(nn.Module):
|
588 |
+
"""Implements the blur layer."""
|
589 |
+
|
590 |
+
def __init__(self,
|
591 |
+
channels,
|
592 |
+
kernel=(1, 2, 1),
|
593 |
+
normalize=True):
|
594 |
+
super().__init__()
|
595 |
+
kernel = np.array(kernel, dtype=np.float32).reshape(1, -1)
|
596 |
+
kernel = kernel.T.dot(kernel)
|
597 |
+
if normalize:
|
598 |
+
kernel /= np.sum(kernel)
|
599 |
+
kernel = kernel[np.newaxis, np.newaxis]
|
600 |
+
kernel = np.tile(kernel, [channels, 1, 1, 1])
|
601 |
+
self.register_buffer('kernel', torch.from_numpy(kernel))
|
602 |
+
|
603 |
+
def forward(self, x):
|
604 |
+
return Blur.apply(x, self.kernel)
|
605 |
+
|
606 |
+
|
607 |
+
class NoiseApplyingLayer(nn.Module):
|
608 |
+
"""Implements the noise applying layer."""
|
609 |
+
|
610 |
+
def __init__(self, resolution, channels):
|
611 |
+
super().__init__()
|
612 |
+
self.res = resolution
|
613 |
+
self.register_buffer('noise', torch.randn(1, 1, self.res, self.res))
|
614 |
+
self.weight = nn.Parameter(torch.zeros(channels))
|
615 |
+
|
616 |
+
def forward(self, x, randomize_noise=False):
|
617 |
+
if x.ndim != 4:
|
618 |
+
raise ValueError(f'The input tensor should be with shape '
|
619 |
+
f'[batch_size, channel, height, width], '
|
620 |
+
f'but `{x.shape}` is received!')
|
621 |
+
if randomize_noise:
|
622 |
+
noise = torch.randn(x.shape[0], 1, self.res, self.res).to(x)
|
623 |
+
else:
|
624 |
+
noise = self.noise
|
625 |
+
return x + noise * self.weight.view(1, -1, 1, 1)
|
626 |
+
|
627 |
+
|
628 |
+
class StyleModLayer(nn.Module):
|
629 |
+
"""Implements the style modulation layer."""
|
630 |
+
|
631 |
+
def __init__(self,
|
632 |
+
w_space_dim,
|
633 |
+
out_channels,
|
634 |
+
use_wscale=True):
|
635 |
+
super().__init__()
|
636 |
+
self.w_space_dim = w_space_dim
|
637 |
+
self.out_channels = out_channels
|
638 |
+
|
639 |
+
weight_shape = (self.out_channels * 2, self.w_space_dim)
|
640 |
+
wscale = _STYLEMOD_WSCALE_GAIN / np.sqrt(self.w_space_dim)
|
641 |
+
if use_wscale:
|
642 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape))
|
643 |
+
self.wscale = wscale
|
644 |
+
else:
|
645 |
+
self.weight = nn.Parameter(torch.randn(*weight_shape) * wscale)
|
646 |
+
self.wscale = 1.0
|
647 |
+
|
648 |
+
self.bias = nn.Parameter(torch.zeros(self.out_channels * 2))
|
649 |
+
|
650 |
+
def forward(self, x, w):
|
651 |
+
if w.ndim != 2 or w.shape[1] != self.w_space_dim:
|
652 |
+
raise ValueError(f'The input tensor should be with shape '
|
653 |
+
f'[batch_size, w_space_dim], where '
|
654 |
+
f'`w_space_dim` equals to {self.w_space_dim}!\n'
|
655 |
+
f'But `{w.shape}` is received!')
|
656 |
+
style = F.linear(w, weight=self.weight * self.wscale, bias=self.bias)
|
657 |
+
style_split = style.view(-1, 2, self.out_channels, 1, 1)
|
658 |
+
x = x * (style_split[:, 0] + 1) + style_split[:, 1]
|
659 |
+
return x, style
|
660 |
+
|
661 |
+
|
662 |
+
class ConvBlock(nn.Module):
|
663 |
+
"""Implements the normal convolutional block.
|
664 |
+
|
665 |
+
Basically, this block executes upsampling layer (if needed), convolutional
|
666 |
+
layer, blurring layer, noise applying layer, activation layer, instance
|
667 |
+
normalization layer, and style modulation layer in sequence.
|
668 |
+
"""
|
669 |
+
|
670 |
+
def __init__(self,
|
671 |
+
in_channels,
|
672 |
+
out_channels,
|
673 |
+
resolution,
|
674 |
+
w_space_dim,
|
675 |
+
position=None,
|
676 |
+
kernel_size=3,
|
677 |
+
stride=1,
|
678 |
+
padding=1,
|
679 |
+
add_bias=True,
|
680 |
+
upsample=False,
|
681 |
+
fused_scale=False,
|
682 |
+
use_wscale=True,
|
683 |
+
wscale_gain=_WSCALE_GAIN,
|
684 |
+
lr_mul=1.0,
|
685 |
+
activation_type='lrelu'):
|
686 |
+
"""Initializes with block settings.
|
687 |
+
|
688 |
+
Args:
|
689 |
+
in_channels: Number of channels of the input tensor.
|
690 |
+
out_channels: Number of channels of the output tensor.
|
691 |
+
resolution: Resolution of the output tensor.
|
692 |
+
w_space_dim: Dimension of W space for style modulation.
|
693 |
+
position: Position of the layer. `const_init`, `last` would lead to
|
694 |
+
different behavior. (default: None)
|
695 |
+
kernel_size: Size of the convolutional kernels. (default: 3)
|
696 |
+
stride: Stride parameter for convolution operation. (default: 1)
|
697 |
+
padding: Padding parameter for convolution operation. (default: 1)
|
698 |
+
add_bias: Whether to add bias onto the convolutional result.
|
699 |
+
(default: True)
|
700 |
+
upsample: Whether to upsample the input tensor before convolution.
|
701 |
+
(default: False)
|
702 |
+
fused_scale: Whether to fuse `upsample` and `conv2d` together,
|
703 |
+
resulting in `conv2d_transpose`. (default: False)
|
704 |
+
use_wscale: Whether to use weight scaling. (default: True)
|
705 |
+
wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
|
706 |
+
lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
|
707 |
+
activation_type: Type of activation. Support `linear` and `lrelu`.
|
708 |
+
(default: `lrelu`)
|
709 |
+
|
710 |
+
Raises:
|
711 |
+
NotImplementedError: If the `activation_type` is not supported.
|
712 |
+
"""
|
713 |
+
super().__init__()
|
714 |
+
|
715 |
+
self.position = position
|
716 |
+
|
717 |
+
if add_bias:
|
718 |
+
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

        if self.position != 'last':
            self.apply_noise = NoiseApplyingLayer(resolution, out_channels)
            self.normalize = InstanceNormLayer()
            self.style = StyleModLayer(w_space_dim, out_channels, use_wscale)

        if self.position == 'const_init':
            self.const = nn.Parameter(
                torch.ones(1, in_channels, resolution, resolution))
            return

        self.blur = BlurLayer(out_channels) if upsample else nn.Identity()

        if upsample and not fused_scale:
            self.upsample = UpsamplingLayer()
        else:
            self.upsample = nn.Identity()

        if upsample and fused_scale:
            self.use_conv2d_transpose = True
            self.stride = 2
            self.padding = 1
        else:
            self.use_conv2d_transpose = False
            self.stride = stride
            self.padding = padding

        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        fan_in = kernel_size * kernel_size * in_channels
        wscale = wscale_gain / np.sqrt(fan_in)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

    def forward(self, x, w, randomize_noise=False):
        if self.position != 'const_init':
            x = self.upsample(x)
            weight = self.weight * self.wscale
            if self.use_conv2d_transpose:
                weight = F.pad(weight, (1, 1, 1, 1, 0, 0, 0, 0), 'constant', 0)
                weight = (weight[:, :, 1:, 1:] + weight[:, :, :-1, 1:] +
                          weight[:, :, 1:, :-1] + weight[:, :, :-1, :-1])
                weight = weight.permute(1, 0, 2, 3)
                x = F.conv_transpose2d(x,
                                       weight=weight,
                                       bias=None,
                                       stride=self.stride,
                                       padding=self.padding)
            else:
                x = F.conv2d(x,
                             weight=weight,
                             bias=None,
                             stride=self.stride,
                             padding=self.padding)
            x = self.blur(x)
        else:
            x = self.const.repeat(w.shape[0], 1, 1, 1)

        bias = self.bias * self.bscale if self.bias is not None else None

        if self.position == 'last':
            if bias is not None:
                x = x + bias.view(1, -1, 1, 1)
            return x

        x = self.apply_noise(x, randomize_noise)
        if bias is not None:
            x = x + bias.view(1, -1, 1, 1)
        x = self.activate(x)
        x = self.normalize(x)
        x, style = self.style(x, w)
        return x, style


class DenseBlock(nn.Module):
    """Implements the dense block.

    Basically, this block executes fully-connected layer and activation layer.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 add_bias=True,
                 use_wscale=True,
                 wscale_gain=_WSCALE_GAIN,
                 lr_mul=1.0,
                 activation_type='lrelu'):
        """Initializes with block settings.

        Args:
            in_channels: Number of channels of the input tensor.
            out_channels: Number of channels of the output tensor.
            add_bias: Whether to add bias onto the fully-connected result.
                (default: True)
            use_wscale: Whether to use weight scaling. (default: True)
            wscale_gain: Gain factor for weight scaling. (default: _WSCALE_GAIN)
            lr_mul: Learning multiplier for both weight and bias. (default: 1.0)
            activation_type: Type of activation. Support `linear` and `lrelu`.
                (default: `lrelu`)

        Raises:
            NotImplementedError: If the `activation_type` is not supported.
        """
        super().__init__()
        weight_shape = (out_channels, in_channels)
        wscale = wscale_gain / np.sqrt(in_channels)
        if use_wscale:
            self.weight = nn.Parameter(torch.randn(*weight_shape) / lr_mul)
            self.wscale = wscale * lr_mul
        else:
            self.weight = nn.Parameter(
                torch.randn(*weight_shape) * wscale / lr_mul)
            self.wscale = lr_mul

        if add_bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
            self.bscale = lr_mul
        else:
            self.bias = None

        if activation_type == 'linear':
            self.activate = nn.Identity()
        elif activation_type == 'lrelu':
            self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            raise NotImplementedError(f'Not implemented activation function: '
                                      f'`{activation_type}`!')

    def forward(self, x):
        if x.ndim != 2:
            x = x.view(x.shape[0], -1)
        bias = self.bias * self.bscale if self.bias is not None else None
        x = F.linear(x, weight=self.weight * self.wscale, bias=bias)
        x = self.activate(x)
        return x
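For a quick sanity check, a minimal sketch of how DenseBlock can be exercised on its own; the 512-unit sizes below are illustrative assumptions, not values read from any checkpoint config:

import torch

# Hypothetical smoke test: channel sizes are illustrative only.
dense = DenseBlock(in_channels=512, out_channels=512)
z = torch.randn(4, 512)   # a batch of 4 latent codes
h = dense(z)              # scaled fully-connected layer + leaky ReLU
print(h.shape)            # torch.Size([4, 512])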
models/sync_op.py
ADDED
@@ -0,0 +1,18 @@
# python3.7
"""Contains the synchronizing operator."""

import torch
import torch.distributed as dist

__all__ = ['all_gather']


def all_gather(tensor):
    """Gathers tensor from all devices and does averaging."""
    if not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    tensor_list = [torch.ones_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    return torch.mean(torch.stack(tensor_list, dim=0), dim=0)
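A minimal sketch of how `all_gather` behaves, assuming the module is imported as `models.sync_op`: outside an initialized `torch.distributed` process group it simply returns its input, while under a multi-process launch it returns the element-wise mean across ranks.

import torch
from models.sync_op import all_gather

local_stat = torch.tensor([0.5, 1.0])
# No process group is initialized here, so this is a pass-through.
print(all_gather(local_stat))   # tensor([0.5000, 1.0000])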
sefa.py
ADDED
@@ -0,0 +1,145 @@
"""SeFa."""

import os
import argparse
from tqdm import tqdm
import numpy as np

import torch

from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight
from utils import HtmlPageVisualizer


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        description='Discover semantics from the pre-trained weight.')
    parser.add_argument('model_name', type=str,
                        help='Name to the pre-trained model.')
    parser.add_argument('--save_dir', type=str, default='results',
                        help='Directory to save the visualization pages. '
                             '(default: %(default)s)')
    parser.add_argument('-L', '--layer_idx', type=str, default='all',
                        help='Indices of layers to interpret. '
                             '(default: %(default)s)')
    parser.add_argument('-N', '--num_samples', type=int, default=5,
                        help='Number of samples used for visualization. '
                             '(default: %(default)s)')
    parser.add_argument('-K', '--num_semantics', type=int, default=5,
                        help='Number of semantic boundaries corresponding to '
                             'the top-k eigen values. (default: %(default)s)')
    parser.add_argument('--start_distance', type=float, default=-3.0,
                        help='Start point for manipulation on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--end_distance', type=float, default=3.0,
                        help='Ending point for manipulation on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--step', type=int, default=11,
                        help='Manipulation step on each semantic. '
                             '(default: %(default)s)')
    parser.add_argument('--viz_size', type=int, default=256,
                        help='Size of images to visualize on the HTML page. '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_psi', type=float, default=0.7,
                        help='Psi factor used for truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--trunc_layers', type=int, default=8,
                        help='Number of layers to perform truncation. This is '
                             'particularly applicable to StyleGAN (v1/v2). '
                             '(default: %(default)s)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed for sampling. (default: %(default)s)')
    parser.add_argument('--gpu_id', type=str, default='0',
                        help='GPU(s) to use. (default: %(default)s)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.makedirs(args.save_dir, exist_ok=True)

    # Factorize weights.
    generator = load_generator(args.model_name)
    gan_type = parse_gan_type(generator)
    layers, boundaries, values = factorize_weight(generator, args.layer_idx)

    # Set random seed.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Prepare codes.
    codes = torch.randn(args.num_samples, generator.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = generator.layer0.pixel_norm(codes)
    elif gan_type in ['stylegan', 'stylegan2']:
        codes = generator.mapping(codes)['w']
        codes = generator.truncation(codes,
                                     trunc_psi=args.trunc_psi,
                                     trunc_layers=args.trunc_layers)
    codes = codes.detach().cpu().numpy()

    # Generate visualization pages.
    distances = np.linspace(args.start_distance, args.end_distance, args.step)
    num_sam = args.num_samples
    num_sem = args.num_semantics
    vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)

    headers = [''] + [f'Distance {d:.2f}' for d in distances]
    vizer_1.set_headers(headers)
    vizer_2.set_headers(headers)
    for sem_id in range(num_sem):
        value = values[sem_id]
        vizer_1.set_cell(sem_id * (num_sam + 1), 0,
                         text=f'Semantic {sem_id:03d}<br>({value:.3f})',
                         highlight=True)
        for sam_id in range(num_sam):
            vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
                             text=f'Sample {sam_id:03d}')
    for sam_id in range(num_sam):
        vizer_2.set_cell(sam_id * (num_sem + 1), 0,
                         text=f'Sample {sam_id:03d}',
                         highlight=True)
        for sem_id in range(num_sem):
            value = values[sem_id]
            vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
                             text=f'Semantic {sem_id:03d}<br>({value:.3f})')

    for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
        code = codes[sam_id:sam_id + 1]
        for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
            boundary = boundaries[sem_id:sem_id + 1]
            for col_id, d in enumerate(distances, start=1):
                temp_code = code.copy()
                if gan_type == 'pggan':
                    temp_code += boundary * d
                    image = generator(to_tensor(temp_code))['image']
                elif gan_type in ['stylegan', 'stylegan2']:
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))['image']
                image = postprocess(image)[0]
                vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
                                 image=image)
                vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
                                 image=image)

    prefix = (f'{args.model_name}_'
              f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
    vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
    vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))


if __name__ == '__main__':
    main()
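The same pipeline can also be driven interactively. A minimal sketch follows, assuming `stylegan_animeface512` is one of the keys in `models.MODEL_ZOO` and that a CUDA device is available; the distance value 2.0 is an arbitrary illustration:

import torch
from utils import load_generator, factorize_weight, to_tensor, postprocess

generator = load_generator('stylegan_animeface512')   # assumed MODEL_ZOO key
layers, boundaries, values = factorize_weight(generator, layer_idx='all')

# Sample one latent code, map it to W space, and apply truncation.
w = generator.mapping(torch.randn(1, generator.z_space_dim).cuda())['w']
w = generator.truncation(w, trunc_psi=0.7, trunc_layers=8)
w = w.detach().cpu().numpy()

# Push the code along the top-ranked semantic and synthesize.
w[:, layers, :] += boundaries[0:1] * 2.0
image = postprocess(generator.synthesis(to_tensor(w))['image'])[0]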
utils.py
ADDED
@@ -0,0 +1,509 @@
"""Utility functions."""

import base64
import os
import subprocess
import cv2
import numpy as np

import torch

from models import MODEL_ZOO
from models import build_generator
from models import parse_gan_type

__all__ = ['postprocess', 'load_generator', 'factorize_weight',
           'HtmlPageVisualizer']

CHECKPOINT_DIR = 'checkpoints'


def to_tensor(array):
    """Converts a `numpy.ndarray` to `torch.Tensor`.

    Args:
        array: The input array to convert.

    Returns:
        A `torch.Tensor` with dtype `torch.FloatTensor` on cuda device.
    """
    assert isinstance(array, np.ndarray)
    return torch.from_numpy(array).type(torch.FloatTensor).cuda()


def postprocess(images, min_val=-1.0, max_val=1.0):
    """Post-processes images from `torch.Tensor` to `numpy.ndarray`.

    Args:
        images: A `torch.Tensor` with shape `NCHW` to process.
        min_val: The minimum value of the input tensor. (default: -1.0)
        max_val: The maximum value of the input tensor. (default: 1.0)

    Returns:
        A `numpy.ndarray` with shape `NHWC` and pixel range [0, 255].
    """
    assert isinstance(images, torch.Tensor)
    images = images.detach().cpu().numpy()
    images = (images - min_val) * 255 / (max_val - min_val)
    images = np.clip(images + 0.5, 0, 255).astype(np.uint8)
    images = images.transpose(0, 2, 3, 1)
    return images


def load_generator(model_name):
    """Loads pre-trained generator.

    Args:
        model_name: Name of the model. Should be a key in `models.MODEL_ZOO`.

    Returns:
        A generator, which is a `torch.nn.Module`, with pre-trained weights
        loaded.

    Raises:
        KeyError: If the input `model_name` is not in `models.MODEL_ZOO`.
    """
    if model_name not in MODEL_ZOO:
        raise KeyError(f'Unknown model name `{model_name}`!')

    model_config = MODEL_ZOO[model_name].copy()
    url = model_config.pop('url')  # URL to download model if needed.

    # Build generator.
    print(f'Building generator for model `{model_name}` ...')
    generator = build_generator(**model_config)
    print(f'Finish building generator.')

    # Load pre-trained weights.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    checkpoint_path = os.path.join(CHECKPOINT_DIR, model_name + '.pth')
    print(f'Loading checkpoint from `{checkpoint_path}` ...')
    if not os.path.exists(checkpoint_path):
        print(f'  Downloading checkpoint from `{url}` ...')
        subprocess.call(['wget', '--quiet', '-O', checkpoint_path, url])
        print(f'  Finish downloading checkpoint.')
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if 'generator_smooth' in checkpoint:
        generator.load_state_dict(checkpoint['generator_smooth'])
    else:
        generator.load_state_dict(checkpoint['generator'])
    generator = generator.cuda()
    generator.eval()
    print(f'Finish loading checkpoint.')
    return generator


def parse_indices(obj, min_val=None, max_val=None):
    """Parses indices.

    The input can be a list or a tuple or a string, which is either a comma
    separated list of numbers 'a, b, c', or a dash separated range 'a - c'.
    Space in the string will be ignored.

    Args:
        obj: The input object to parse indices from.
        min_val: If not `None`, this function will check that all indices are
            equal to or larger than this value. (default: None)
        max_val: If not `None`, this function will check that all indices are
            equal to or smaller than this value. (default: None)

    Returns:
        A list of integers.

    Raises:
        If the input is invalid, i.e., neither a list or tuple, nor a string.
    """
    if obj is None or obj == '':
        indices = []
    elif isinstance(obj, int):
        indices = [obj]
    elif isinstance(obj, (list, tuple, np.ndarray)):
        indices = list(obj)
    elif isinstance(obj, str):
        indices = []
        splits = obj.replace(' ', '').split(',')
        for split in splits:
            numbers = list(map(int, split.split('-')))
            if len(numbers) == 1:
                indices.append(numbers[0])
            elif len(numbers) == 2:
                indices.extend(list(range(numbers[0], numbers[1] + 1)))
            else:
                raise ValueError(f'Unable to parse the input!')
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    assert isinstance(indices, list)
    indices = sorted(list(set(indices)))
    for idx in indices:
        assert isinstance(idx, int)
        if min_val is not None:
            assert idx >= min_val, f'{idx} is smaller than min val `{min_val}`!'
        if max_val is not None:
            assert idx <= max_val, f'{idx} is larger than max val `{max_val}`!'

    return indices


def factorize_weight(generator, layer_idx='all'):
    """Factorizes the generator weight to get semantics boundaries.

    Args:
        generator: Generator to factorize.
        layer_idx: Indices of layers to interpret, especially for StyleGAN and
            StyleGAN2. (default: `all`)

    Returns:
        A tuple of (layers_to_interpret, semantic_boundaries, eigen_values).

    Raises:
        ValueError: If the generator type is not supported.
    """
    # Get GAN type.
    gan_type = parse_gan_type(generator)

    # Get layers.
    if gan_type == 'pggan':
        layers = [0]
    elif gan_type in ['stylegan', 'stylegan2']:
        if layer_idx == 'all':
            layers = list(range(generator.num_layers))
        else:
            layers = parse_indices(layer_idx,
                                   min_val=0,
                                   max_val=generator.num_layers - 1)

    # Factorize semantics from weight.
    weights = []
    for idx in layers:
        layer_name = f'layer{idx}'
        if gan_type == 'stylegan2' and idx == generator.num_layers - 1:
            layer_name = f'output{idx // 2}'
        if gan_type == 'pggan':
            weight = generator.__getattr__(layer_name).weight
            weight = weight.flip(2, 3).permute(1, 0, 2, 3).flatten(1)
        elif gan_type in ['stylegan', 'stylegan2']:
            weight = generator.synthesis.__getattr__(layer_name).style.weight.T
        weights.append(weight.cpu().detach().numpy())
    weight = np.concatenate(weights, axis=1).astype(np.float32)
    weight = weight / np.linalg.norm(weight, axis=0, keepdims=True)
    eigen_values, eigen_vectors = np.linalg.eig(weight.dot(weight.T))

    return layers, eigen_vectors.T, eigen_values


def get_sortable_html_header(column_name_list, sort_by_ascending=False):
    """Gets header for sortable html page.

    Basically, the html page contains a sortable table, where user can sort the
    rows by a particular column by clicking the column head.

    Example:

    column_name_list = [name_1, name_2, name_3]
    header = get_sortable_html_header(column_name_list)
    footer = get_sortable_html_footer()
    sortable_table = ...
    html_page = header + sortable_table + footer

    Args:
        column_name_list: List of column header names.
        sort_by_ascending: Default sorting order. If set as `True`, the html
            page will be sorted by ascending order when the header is clicked
            for the first time.

    Returns:
        A string, which represents for the header for a sortable html page.
    """
    header = '\n'.join([
        '<script type="text/javascript">',
        'var column_idx;',
        'var sort_by_ascending = ' + str(sort_by_ascending).lower() + ';',
        '',
        'function sorting(tbody, column_idx){',
        '  this.column_idx = column_idx;',
        '  Array.from(tbody.rows)',
        '       .sort(compareCells)',
        '       .forEach(function(row) { tbody.appendChild(row); })',
        '  sort_by_ascending = !sort_by_ascending;',
        '}',
        '',
        'function compareCells(row_a, row_b) {',
        '  var val_a = row_a.cells[column_idx].innerText;',
        '  var val_b = row_b.cells[column_idx].innerText;',
        '  var flag = sort_by_ascending ? 1 : -1;',
        '  return flag * (val_a > val_b ? 1 : -1);',
        '}',
        '</script>',
        '',
        '<html>',
        '',
        '<head>',
        '<style>',
        '  table {',
        '    border-spacing: 0;',
        '    border: 1px solid black;',
        '  }',
        '  th {',
        '    cursor: pointer;',
        '  }',
        '  th, td {',
        '    text-align: left;',
        '    vertical-align: middle;',
        '    border-collapse: collapse;',
        '    border: 0.5px solid black;',
        '    padding: 8px;',
        '  }',
        '  tr:nth-child(even) {',
        '    background-color: #d2d2d2;',
        '  }',
        '</style>',
        '</head>',
        '',
        '<body>',
        '',
        '<table>',
        '<thead>',
        '<tr>',
        ''])
    for idx, name in enumerate(column_name_list):
        header += f'  <th onclick="sorting(tbody, {idx})">{name}</th>\n'
    header += '</tr>\n'
    header += '</thead>\n'
    header += '<tbody id="tbody">\n'

    return header


def get_sortable_html_footer():
    """Gets footer for sortable html page.

    Check function `get_sortable_html_header()` for more details.
    """
    return '</tbody>\n</table>\n\n</body>\n</html>\n'


def parse_image_size(obj):
    """Parses object to a pair of image size, i.e., (width, height).

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image width and height respectively.

    Raises:
        If the input is invalid, i.e., neither a list or tuple, nor a string.
    """
    if obj is None or obj == '':
        width = height = 0
    elif isinstance(obj, int):
        width = height = obj
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
        if len(numbers) == 0:
            width = height = 0
        elif len(numbers) == 1:
            width = height = numbers[0]
        elif len(numbers) == 2:
            width = numbers[0]
            height = numbers[1]
        else:
            raise ValueError(f'At most two elements for image size.')
    elif isinstance(obj, str):
        splits = obj.replace(' ', '').split(',')
        numbers = tuple(map(int, splits))
        if len(numbers) == 0:
            width = height = 0
        elif len(numbers) == 1:
            width = height = numbers[0]
        elif len(numbers) == 2:
            width = numbers[0]
            height = numbers[1]
        else:
            raise ValueError(f'At most two elements for image size.')
    else:
        raise ValueError(f'Invalid type of input: {type(obj)}!')

    return (max(0, width), max(0, height))


def encode_image_to_html_str(image, image_size=None):
    """Encodes an image to html language.

    NOTE: Input image is always assumed to be with `RGB` channel order.

    Args:
        image: The input image to encode. Should be with `RGB` channel order.
        image_size: This field is used to resize the image before encoding. `0`
            disables resizing. (default: None)

    Returns:
        A string which represents the encoded image.
    """
    if image is None:
        return ''

    assert image.ndim == 3 and image.shape[2] in [1, 3]

    # Change channel order to `BGR`, which is opencv-friendly.
    image = image[:, :, ::-1]

    # Resize the image if needed.
    width, height = parse_image_size(image_size)
    if height or width:
        height = height or image.shape[0]
        width = width or image.shape[1]
        image = cv2.resize(image, (width, height))

    # Encode the image to html-format string.
    encoded_image = cv2.imencode('.jpg', image)[1].tostring()
    encoded_image_base64 = base64.b64encode(encoded_image).decode('utf-8')
    html_str = f'<img src="data:image/jpeg;base64, {encoded_image_base64}"/>'

    return html_str


def get_grid_shape(size, row=0, col=0, is_portrait=False):
    """Gets the shape of a grid based on the size.

    This function makes greatest effort on making the output grid square if
    neither `row` nor `col` is set. If `is_portrait` is set as `False`, the
    height will always be equal to or smaller than the width. For example, if
    input `size = 16`, output shape will be `(4, 4)`; if input `size = 15`,
    output shape will be (3, 5). Otherwise, the height will always be equal to
    or larger than the width.

    Args:
        size: Size (height * width) of the target grid.
        is_portrait: Whether to return a portrait size of a landscape size.
            (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    assert isinstance(size, int)
    assert isinstance(row, int)
    assert isinstance(col, int)
    if size == 0:
        return (0, 0)

    if row > 0 and col > 0 and row * col != size:
        row = 0
        col = 0

    if row > 0 and size % row == 0:
        return (row, size // row)
    if col > 0 and size % col == 0:
        return (size // col, col)

    row = int(np.sqrt(size))
    while row > 0:
        if size % row == 0:
            col = size // row
            break
        row = row - 1

    return (col, row) if is_portrait else (row, col)


class HtmlPageVisualizer(object):
    """Defines the html page visualizer.

    This class can be used to visualize image results as html page. Basically,
    it is based on an html-format sorted table with helper functions
    `get_sortable_html_header()`, `get_sortable_html_footer()`, and
    `encode_image_to_html_str()`. To simplify the usage, specifying the
    following fields are enough to create a visualization page:

    (1) num_rows: Number of rows of the table (header-row exclusive).
    (2) num_cols: Number of columns of the table.
    (3) header contents (optional): Title of each column.

    NOTE: `grid_size` can be used to assign `num_rows` and `num_cols`
    automatically.

    Example:

    html = HtmlPageVisualizer(num_rows, num_cols)
    html.set_headers([...])
    for i in range(num_rows):
        for j in range(num_cols):
            html.set_cell(i, j, text=..., image=..., highlight=False)
    html.save('visualize.html')
    """

    def __init__(self,
                 num_rows=0,
                 num_cols=0,
                 grid_size=0,
                 is_portrait=True,
                 viz_size=None):
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(
                grid_size, row=num_rows, col=num_cols, is_portrait=is_portrait)
        assert num_rows > 0 and num_cols > 0

        self.num_rows = num_rows
        self.num_cols = num_cols
        self.viz_size = parse_image_size(viz_size)
        self.headers = ['' for _ in range(self.num_cols)]
        self.cells = [[{
            'text': '',
            'image': '',
            'highlight': False,
        } for _ in range(self.num_cols)] for _ in range(self.num_rows)]

    def set_header(self, col_idx, content):
        """Sets the content of a particular header by column index."""
        self.headers[col_idx] = content

    def set_headers(self, contents):
        """Sets the contents of all headers."""
        if isinstance(contents, str):
            contents = [contents]
        assert isinstance(contents, (list, tuple))
        assert len(contents) == self.num_cols
        for col_idx, content in enumerate(contents):
            self.set_header(col_idx, content)

    def set_cell(self, row_idx, col_idx, text='', image=None, highlight=False):
        """Sets the content of a particular cell.

        Basically, a cell contains some text as well as an image. Both text and
        image can be empty.

        Args:
            row_idx: Row index of the cell to edit.
            col_idx: Column index of the cell to edit.
            text: Text to add into the target cell. (default: None)
            image: Image to show in the target cell. Should be with `RGB`
                channel order. (default: None)
            highlight: Whether to highlight this cell. (default: False)
        """
        self.cells[row_idx][col_idx]['text'] = text
        self.cells[row_idx][col_idx]['image'] = encode_image_to_html_str(
            image, self.viz_size)
        self.cells[row_idx][col_idx]['highlight'] = bool(highlight)

    def save(self, save_path):
        """Saves the html page."""
        html = ''
        for i in range(self.num_rows):
            html += f'<tr>\n'
            for j in range(self.num_cols):
                text = self.cells[i][j]['text']
                image = self.cells[i][j]['image']
                if self.cells[i][j]['highlight']:
                    color = ' bgcolor="#FF8888"'
                else:
                    color = ''
                if text:
                    html += f'  <td{color}>{text}<br><br>{image}</td>\n'
                else:
                    html += f'  <td{color}>{image}</td>\n'
            html += f'</tr>\n'

        header = get_sortable_html_header(self.headers)
        footer = get_sortable_html_footer()

        with open(save_path, 'w') as f:
            f.write(header + html + footer)
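A few quick checks of the helpers above; this is a sketch assuming `utils` is importable from the working directory, and the printed values follow directly from the code:

import torch
from utils import parse_indices, parse_image_size, get_grid_shape, postprocess

print(parse_indices('0-3, 6'))       # [0, 1, 2, 3, 6]
print(parse_image_size('256, 128'))  # (256, 128)
print(get_grid_shape(15))            # (3, 5)

fake = torch.zeros(1, 3, 4, 4)       # NCHW tensor assumed to lie in [-1, 1]
print(postprocess(fake).shape)       # (1, 4, 4, 3), dtype uint8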