Main code + Checkpoints_DFG
Source: https://github.com/JunlinHan/CWR (adaptation)
- .gitignore +1 -0
- ModelLoader.py +52 -0
- app.py +11 -0
- checkpoints/original/latest_net_D.pth +3 -0
- checkpoints/original/latest_net_F.pth +3 -0
- checkpoints/original/latest_net_G.pth +3 -0
- examples/.gitattributes +1 -0
- examples/rawimg.png +3 -0
- models/__init__.py +67 -0
- models/base_model.py +258 -0
- models/cwr_model.py +207 -0
- models/networks.py +1401 -0
- models/patchnce.py +54 -0
- models/spectralNormalization.py +64 -0
- models/stylegan_networks.py +914 -0
- models/template_model.py +99 -0
- requirements.txt +9 -0
- test.py +11 -0
- util/__init__.py +2 -0
- util/get_data.py +110 -0
- util/html.py +86 -0
- util/image_pool.py +54 -0
- util/util.py +166 -0
- util/visualizer.py +242 -0
.gitignore
ADDED
@@ -0,0 +1 @@
__pycache__

ModelLoader.py
ADDED
@@ -0,0 +1,52 @@
from models import create_model
import os

ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')

class Options(object):
    def __init__(self, *initial_data, **kwargs):
        for dictionary in initial_data:
            for key in dictionary:
                setattr(self, key, dictionary[key])
        for key in kwargs:
            setattr(self, key, kwargs[key])

class ModelLoader:
    def __init__(self) -> None:
        self.opt = Options({
            'name': 'original',
            'checkpoints_dir': ckp_path,
            'gpu_ids': [],
            'init_gain': 0.02,
            'init_type': 'xavier',
            'input_nc': 3,
            'output_nc': 3,
            'isTrain': False,
            'model': 'cwr',
            'nce_idt': False,
            'nce_layers': '0',
            'ndf': 64,
            'netD': 'basic',
            'netG': 'resnet_9blocks',
            'netF': 'reshape',
            'ngf': 64,
            'no_antialias_up': None,
            'no_antialias': None,
            'no_dropout': True,
            'normD': 'instance',
            'normG': 'instance',
            'preprocess': 'scale_width',
            'num_threads': 0,        # test code only supports num_threads = 1
            'batch_size': 1,         # test code only supports batch_size = 1
            'serial_batches': True,  # disable data shuffling; comment this line if results on randomly chosen images are needed.
            'no_flip': True,         # no flip; comment this line if results on flipped images are needed.
            'display_id': -1,        # no visdom display; the test code saves the results to a HTML file.
        })

    def load(self) -> None:
        self.model = create_model(self.opt)
        self.model.load_networks('latest')

    def inference(self, src=''):
        if not os.path.isfile(src):
            raise Exception('The image %s is not found!' % src)
        # if exist_file()
        print('Loading the image %s' % src)

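inference() above stops after validating the path and printing a message. A minimal sketch of how it could be completed is given below, under the assumptions that Pillow and torchvision are installed and that the generator expects a 1x3xHxW tensor normalized to [-1, 1] (CycleGAN-style preprocessing); the helper name run_inference and these preprocessing choices are illustrative, not part of the repository:

# Illustrative sketch only -- not part of this Space.
# Assumes PIL + torchvision are available and that loader.model.netG maps a
# 1x3xHxW tensor in [-1, 1] to a restored image in the same range.
from PIL import Image
import torch
import torchvision.transforms as transforms

def run_inference(loader, src):
    img = Image.open(src).convert('RGB')
    to_tensor = transforms.Compose([
        transforms.ToTensor(),                                   # scale to [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # shift to [-1, 1]
    ])
    inp = to_tensor(img).unsqueeze(0)                            # 1 x 3 x H x W on CPU (gpu_ids is empty)
    with torch.no_grad():
        fake = loader.model.netG(inp)                            # restored image, still in [-1, 1]
    out = (fake[0].cpu() * 0.5 + 0.5).clamp(0, 1)                # back to [0, 1]
    return transforms.ToPILImage()(out)
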
app.py
ADDED
@@ -0,0 +1,11 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

demo = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text"
)
demo.launch()

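app.py is still the Gradio text "greet" starter and does not yet call the model. Below is a hedged sketch of how ModelLoader could be wired into an image-to-image interface; it assumes a recent Gradio version with gr.Image components, that torchvision is installed, and that CycleGAN-style [-1, 1] normalization suits the pretrained generator -- none of which this commit confirms:

# Sketch only: wiring the CWR generator into Gradio (assumptions noted above).
import gradio as gr
import torch
import torchvision.transforms as transforms
from ModelLoader import ModelLoader

loader = ModelLoader()
loader.load()            # builds the CWR model and loads checkpoints/original/latest_net_G.pth
loader.model.eval()      # put netG into eval mode

def restore(img):
    # img arrives as a PIL image from the Gradio input component
    t = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5,) * 3, (0.5,) * 3)])
    with torch.no_grad():
        fake = loader.model.netG(t(img).unsqueeze(0))
    return transforms.ToPILImage()((fake[0].cpu() * 0.5 + 0.5).clamp(0, 1))

demo = gr.Interface(fn=restore,
                    inputs=gr.Image(type="pil"),
                    outputs=gr.Image(type="pil"),
                    examples=[["examples/rawimg.png"]])
demo.launch()
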
checkpoints/original/latest_net_D.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:de3e70ce1b90004153f94a5943457d6557ced899e2518101b357bd575cde842a
size 11147372

checkpoints/original/latest_net_F.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ab783f87f5e4df9df8790195d731a9e4bb894e4367ebeb40a5878f5cd5fcdf1a
size 2248355

checkpoints/original/latest_net_G.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6e9a0c8bdf643c69fab732ed8764f958d46a41b3c6a3ba1a2f3fe2606e3650eb
size 45670757

examples/.gitattributes
ADDED
@@ -0,0 +1 @@
rawimg.png filter=lfs diff=lfs merge=lfs -text

examples/rawimg.png
ADDED
(binary PNG example image, stored via Git LFS)

models/__init__.py
ADDED
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

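The package docstring above describes the registration convention that find_model_using_name relies on. As a purely illustrative sketch (no dummy_model.py ships with this Space), a models/dummy_model.py that create_model could pick up with '--model dummy' might look roughly like this:

# models/dummy_model.py -- hypothetical example, not part of this repository.
import torch
from .base_model import BaseModel


class DummyModel(BaseModel):                      # class name lower-cases to 'dummymodel'
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.add_argument('--lambda_dummy', type=float, default=1.0, help='example option')
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['L1']                  # printed/saved as self.loss_L1
        self.model_names = ['G']                  # saved/loaded as <epoch>_net_G.pth
        self.visual_names = ['real', 'fake']
        self.netG = torch.nn.Conv2d(opt.input_nc, opt.output_nc, kernel_size=1).to(self.device)
        if self.isTrain:
            self.optimizer = torch.optim.Adam(self.netG.parameters())
            self.optimizers = [self.optimizer]

    def set_input(self, input):
        self.real = input['A'].to(self.device)

    def forward(self):
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        self.forward()
        self.loss_L1 = torch.nn.functional.l1_loss(self.fake, self.real)
        self.optimizer.zero_grad()
        self.loss_L1.backward()
        self.optimizer.step()
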
models/base_model.py
ADDED
@@ -0,0 +1,258 @@
import os
import torch
from collections import OrderedDict
from abc import ABC, abstractmethod
from . import networks


class BaseModel(ABC):
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def dict_grad_hook_factory(add_func=lambda x: x):
        saved_dict = dict()

        def hook_gen(name):
            def grad_hook(grad):
                saved_vals = add_func(grad)
                saved_dict[name] = saved_vals
            return grad_hook
        return hook_gen, saved_dict

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = opt.epoch
            self.load_networks(load_suffix)

        self.print_networks(opt.verbose)

    def parallelize(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))

    def data_dependent_initialize(self, data):
        pass

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                if self.opt.isTrain and self.opt.pretrained_name is not None:
                    load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
                else:
                    load_dir = self.save_dir

                load_path = os.path.join(load_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)  -- a list of networks
            requires_grad (bool) -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def generate_visuals_for_evaluation(self, data, mode):
        return {}

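The '%s_net_%s.pth' convention in save_networks/load_networks is what produces the checkpoints/original/latest_net_G.pth, latest_net_F.pth and latest_net_D.pth files added in this commit: each entry of self.model_names maps to a network attribute 'net' + name and to one checkpoint file per epoch tag. A small sketch of the mapping (the paths are the ones in this repo; the loop itself is only illustrative):

# Naming convention used by BaseModel.save_networks / load_networks:
#   checkpoint file   = <checkpoints_dir>/<name>/<epoch>_net_<model_name>.pth
#   network attribute = 'net' + model_name  (e.g. self.netG)
import os

checkpoints_dir, name, epoch = 'checkpoints', 'original', 'latest'
for model_name in ['G', 'F', 'D']:
    print(os.path.join(checkpoints_dir, name, '%s_net_%s.pth' % (epoch, model_name)))
# -> checkpoints/original/latest_net_G.pth, latest_net_F.pth, latest_net_D.pth
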
models/cwr_model.py
ADDED
@@ -0,0 +1,207 @@
import numpy as np
import torch
from .base_model import BaseModel
from . import networks
from .patchnce import PatchNCELoss
import util.util as util


class CWRModel(BaseModel):
    """ This class implements the CWR model, described in the paper
    Single Underwater Image Restoration by Contrastive Learning
    Junlin Han, Mehrdad Shoeiby, Tim Malthus, Elizabeth Botha, Janet Anstee, Saeed Anwar, Ran Wei, Lars Petersson, Mohammad Ali Armin
    International Geoscience and Remote Sensing Symposium (IGARSS), 2021

    The code borrows heavily from the PyTorch implementation of CycleGAN and CUT
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
    https://github.com/taesungp/contrastive-unpaired-translation
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """ Configures options specific for CUT model
        """
        parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
        parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
        parser.add_argument('--lambda_IDT', type=float, default=10.0, help='weight for identity (L1) loss: |G(Y) - Y|')
        parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
        parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
        parser.add_argument('--nce_includes_all_negatives_from_minibatch',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
        parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
        parser.add_argument('--netF_nc', type=int, default=256)
        parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
        parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
        parser.add_argument('--flip_equivariance',
                            type=util.str2bool, nargs='?', const=True, default=False,
                            help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CWR")

        parser.set_defaults(pool_size=0)  # no image pooling

        opt, _ = parser.parse_known_args()

        parser.set_defaults(nce_idt=True, lambda_NCE=1.0)

        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE', 'idt']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
        self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]

        if opt.nce_idt and self.isTrain:
            self.loss_names += ['NCE_Y']
            self.visual_names += ['idt_B']

        if self.isTrain:
            self.model_names = ['G', 'F', 'D']
        else:  # during test time, only load G
            self.model_names = ['G']

        # define networks (both generator and discriminator)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

        if self.isTrain:
            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)

            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
            self.criterionNCE = []

            for nce_layer in self.nce_layers:
                self.criterionNCE.append(PatchNCELoss(opt).to(self.device))

            self.criterionIdt = torch.nn.L1Loss().to(self.device)
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def data_dependent_initialize(self, data):
        """
        The feature network netF is defined in terms of the shape of the intermediate, extracted
        features of the encoder portion of netG. Because of this, the weights of netF are
        initialized at the first feedforward pass with some input images.
        Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
        """
        self.set_input(data)
        bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
        self.real_A = self.real_A[:bs_per_gpu]
        self.real_B = self.real_B[:bs_per_gpu]
        self.forward()  # compute fake images: G(A)
        if self.opt.isTrain:
            self.compute_D_loss().backward()  # calculate gradients for D
            self.compute_G_loss().backward()  # calculate gradients for G
            if self.opt.lambda_NCE > 0.0:
                self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
                self.optimizers.append(self.optimizer_F)

    def optimize_parameters(self):
        # forward
        self.forward()

        # update D
        self.set_requires_grad(self.netD, True)
        self.optimizer_D.zero_grad()
        self.loss_D = self.compute_D_loss()
        self.loss_D.backward()
        self.optimizer_D.step()

        # update G
        self.set_requires_grad(self.netD, False)
        self.optimizer_G.zero_grad()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.zero_grad()
        self.loss_G = self.compute_G_loss()
        self.loss_G.backward()
        self.optimizer_G.step()
        if self.opt.netF == 'mlp_sample':
            self.optimizer_F.step()

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.
        Parameters:
            input (dict): include the data itself and its metadata information.
        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
        if self.opt.flip_equivariance:
            self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
            if self.flipped_for_equivariance:
                self.real = torch.flip(self.real, [3])

        self.fake = self.netG(self.real)
        self.fake_B = self.fake[:self.real_A.size(0)]
        if self.opt.nce_idt:
            self.idt_B = self.fake[self.real_A.size(0):]

    def compute_D_loss(self):
        """Calculate GAN loss for the discriminator"""
        fake = self.fake_B.detach()
        # Fake; stop backprop to the generator by detaching fake_B
        pred_fake = self.netD(fake)
        self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
        # Real
        self.pred_real = self.netD(self.real_B)
        loss_D_real = self.criterionGAN(self.pred_real, True)
        self.loss_D_real = loss_D_real.mean()

        # combine loss and calculate gradients
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        return self.loss_D

    def compute_G_loss(self):
        """Calculate GAN and NCE loss for the generator"""
        fake = self.fake_B
        # First, G(A) should fake the discriminator
        if self.opt.lambda_GAN > 0.0:
            pred_fake = self.netD(fake)
            self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
        else:
            self.loss_G_GAN = 0.0

        if self.opt.lambda_NCE > 0.0:
            self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
        else:
            self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0

        if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
            self.loss_NCE_Y = 0
            loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y)
            self.loss_idt = self.criterionIdt(self.idt_B, self.real_B) * self.opt.lambda_IDT
        else:
            loss_NCE_both = self.loss_NCE

        self.loss_G = self.loss_G_GAN + loss_NCE_both + self.loss_idt
        return self.loss_G

    def calculate_NCE_loss(self, src, tgt):
        n_layers = len(self.nce_layers)
        feat_q = self.netG(tgt, self.nce_layers, encode_only=True)

        if self.opt.flip_equivariance and self.flipped_for_equivariance:
            feat_q = [torch.flip(fq, [3]) for fq in feat_q]

        feat_k = self.netG(src, self.nce_layers, encode_only=True)
        feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
        feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)

        total_nce_loss = 0.0
        for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
            loss = crit(f_q, f_k) * self.opt.lambda_NCE
            total_nce_loss += loss.mean()

        return total_nce_loss / n_layers

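With the defaults registered above (lambda_GAN = 1.0, lambda_NCE = 1.0, lambda_IDT = 10.0, nce_idt = True), compute_G_loss combines an adversarial term on fake_B, a PatchNCE term averaged over nce_layers, and an L1 identity term on G(Y) versus Y (loss_NCE_Y stays fixed at 0 in this implementation). At test time the Space only needs the generator, and CWRModel.__init__ builds it via networks.define_G. The sketch below shows how the options defined in ModelLoader.py map onto that call; it is illustrative only, since the real construction happens inside CWRModel.__init__, and the comment about the resulting architecture is an assumption:

# Sketch: how the Space's test-time options flow into networks.define_G (illustrative only).
from models import networks
from ModelLoader import ModelLoader

opt = ModelLoader().opt
netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,  # 3, 3, 64, 'resnet_9blocks'
                         opt.normG, not opt.no_dropout,                   # instance norm, no dropout
                         opt.init_type, opt.init_gain,                    # xavier, 0.02
                         opt.no_antialias, opt.no_antialias_up,
                         gpu_ids=opt.gpu_ids, opt=opt)
print(netG)  # expected: a 9-block ResnetGenerator, ready for load_state_dict on latest_net_G.pth
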
models/networks.py
ADDED
@@ -0,0 +1,1401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn import init
|
5 |
+
import functools
|
6 |
+
from torch.optim import lr_scheduler
|
7 |
+
import numpy as np
|
8 |
+
from .spectralNormalization import SpectralNorm
|
9 |
+
from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
|
10 |
+
|
11 |
+
###############################################################################
|
12 |
+
# Helper Functions
|
13 |
+
###############################################################################
|
14 |
+
|
15 |
+
|
16 |
+
def get_filter(filt_size=3):
|
17 |
+
if(filt_size == 1):
|
18 |
+
a = np.array([1., ])
|
19 |
+
elif(filt_size == 2):
|
20 |
+
a = np.array([1., 1.])
|
21 |
+
elif(filt_size == 3):
|
22 |
+
a = np.array([1., 2., 1.])
|
23 |
+
elif(filt_size == 4):
|
24 |
+
a = np.array([1., 3., 3., 1.])
|
25 |
+
elif(filt_size == 5):
|
26 |
+
a = np.array([1., 4., 6., 4., 1.])
|
27 |
+
elif(filt_size == 6):
|
28 |
+
a = np.array([1., 5., 10., 10., 5., 1.])
|
29 |
+
elif(filt_size == 7):
|
30 |
+
a = np.array([1., 6., 15., 20., 15., 6., 1.])
|
31 |
+
|
32 |
+
filt = torch.Tensor(a[:, None] * a[None, :])
|
33 |
+
filt = filt / torch.sum(filt)
|
34 |
+
|
35 |
+
return filt
|
36 |
+
|
37 |
+
|
38 |
+
class Downsample(nn.Module):
|
39 |
+
def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
|
40 |
+
super(Downsample, self).__init__()
|
41 |
+
self.filt_size = filt_size
|
42 |
+
self.pad_off = pad_off
|
43 |
+
self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
|
44 |
+
self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
|
45 |
+
self.stride = stride
|
46 |
+
self.off = int((self.stride - 1) / 2.)
|
47 |
+
self.channels = channels
|
48 |
+
|
49 |
+
filt = get_filter(filt_size=self.filt_size)
|
50 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
51 |
+
|
52 |
+
self.pad = get_pad_layer(pad_type)(self.pad_sizes)
|
53 |
+
|
54 |
+
def forward(self, inp):
|
55 |
+
if(self.filt_size == 1):
|
56 |
+
if(self.pad_off == 0):
|
57 |
+
return inp[:, :, ::self.stride, ::self.stride]
|
58 |
+
else:
|
59 |
+
return self.pad(inp)[:, :, ::self.stride, ::self.stride]
|
60 |
+
else:
|
61 |
+
return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
|
62 |
+
|
63 |
+
|
64 |
+
class Upsample2(nn.Module):
|
65 |
+
def __init__(self, scale_factor, mode='nearest'):
|
66 |
+
super().__init__()
|
67 |
+
self.factor = scale_factor
|
68 |
+
self.mode = mode
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
|
72 |
+
|
73 |
+
|
74 |
+
class Upsample(nn.Module):
|
75 |
+
def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
|
76 |
+
super(Upsample, self).__init__()
|
77 |
+
self.filt_size = filt_size
|
78 |
+
self.filt_odd = np.mod(filt_size, 2) == 1
|
79 |
+
self.pad_size = int((filt_size - 1) / 2)
|
80 |
+
self.stride = stride
|
81 |
+
self.off = int((self.stride - 1) / 2.)
|
82 |
+
self.channels = channels
|
83 |
+
|
84 |
+
filt = get_filter(filt_size=self.filt_size) * (stride**2)
|
85 |
+
self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
|
86 |
+
|
87 |
+
self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
|
88 |
+
|
89 |
+
def forward(self, inp):
|
90 |
+
ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
|
91 |
+
if(self.filt_odd):
|
92 |
+
return ret_val
|
93 |
+
else:
|
94 |
+
return ret_val[:, :, :-1, :-1]
|
95 |
+
|
96 |
+
|
97 |
+
def get_pad_layer(pad_type):
|
98 |
+
if(pad_type in ['refl', 'reflect']):
|
99 |
+
PadLayer = nn.ReflectionPad2d
|
100 |
+
elif(pad_type in ['repl', 'replicate']):
|
101 |
+
PadLayer = nn.ReplicationPad2d
|
102 |
+
elif(pad_type == 'zero'):
|
103 |
+
PadLayer = nn.ZeroPad2d
|
104 |
+
else:
|
105 |
+
print('Pad type [%s] not recognized' % pad_type)
|
106 |
+
return PadLayer
|
107 |
+
|
108 |
+
|
109 |
+
class Identity(nn.Module):
|
110 |
+
def forward(self, x):
|
111 |
+
return x
|
112 |
+
|
113 |
+
|
114 |
+
def get_norm_layer(norm_type='instance'):
|
115 |
+
"""Return a normalization layer
|
116 |
+
|
117 |
+
Parameters:
|
118 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
119 |
+
|
120 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
121 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
122 |
+
"""
|
123 |
+
if norm_type == 'batch':
|
124 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
125 |
+
elif norm_type == 'instance':
|
126 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
127 |
+
elif norm_type == 'none':
|
128 |
+
def norm_layer(x):
|
129 |
+
return Identity()
|
130 |
+
else:
|
131 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
132 |
+
return norm_layer
|
133 |
+
|
134 |
+
|
135 |
+
def get_scheduler(optimizer, opt):
|
136 |
+
"""Return a learning rate scheduler
|
137 |
+
|
138 |
+
Parameters:
|
139 |
+
optimizer -- the optimizer of the network
|
140 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
141 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
142 |
+
|
143 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
144 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
145 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
146 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
147 |
+
"""
|
148 |
+
if opt.lr_policy == 'linear':
|
149 |
+
def lambda_rule(epoch):
|
150 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
151 |
+
return lr_l
|
152 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
153 |
+
elif opt.lr_policy == 'step':
|
154 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
155 |
+
elif opt.lr_policy == 'plateau':
|
156 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
157 |
+
elif opt.lr_policy == 'cosine':
|
158 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
159 |
+
else:
|
160 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
161 |
+
return scheduler
|
162 |
+
|
163 |
+
|
164 |
+
def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
|
165 |
+
"""Initialize network weights.
|
166 |
+
|
167 |
+
Parameters:
|
168 |
+
net (network) -- network to be initialized
|
169 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
170 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
171 |
+
|
172 |
+
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
|
173 |
+
work better for some applications. Feel free to try yourself.
|
174 |
+
"""
|
175 |
+
def init_func(m): # define the initialization function
|
176 |
+
classname = m.__class__.__name__
|
177 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
178 |
+
if debug:
|
179 |
+
print(classname)
|
180 |
+
if init_type == 'normal':
|
181 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
182 |
+
elif init_type == 'xavier':
|
183 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
184 |
+
elif init_type == 'kaiming':
|
185 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
186 |
+
elif init_type == 'orthogonal':
|
187 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
188 |
+
else:
|
189 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
190 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
191 |
+
init.constant_(m.bias.data, 0.0)
|
192 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
193 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
194 |
+
init.constant_(m.bias.data, 0.0)
|
195 |
+
|
196 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
197 |
+
|
198 |
+
|
199 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
|
200 |
+
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
|
201 |
+
Parameters:
|
202 |
+
net (network) -- the network to be initialized
|
203 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
204 |
+
gain (float) -- scaling factor for normal, xavier and orthogonal.
|
205 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
206 |
+
|
207 |
+
Return an initialized network.
|
208 |
+
"""
|
209 |
+
if len(gpu_ids) > 0:
|
210 |
+
assert(torch.cuda.is_available())
|
211 |
+
net.to(gpu_ids[0])
|
212 |
+
# if not amp:
|
213 |
+
# net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
|
214 |
+
if initialize_weights:
|
215 |
+
init_weights(net, init_type, init_gain=init_gain, debug=debug)
|
216 |
+
return net
|
217 |
+
|
218 |
+
|
219 |
+
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
|
220 |
+
init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
|
221 |
+
"""Create a generator
|
222 |
+
|
223 |
+
Parameters:
|
224 |
+
input_nc (int) -- the number of channels in input images
|
225 |
+
output_nc (int) -- the number of channels in output images
|
226 |
+
ngf (int) -- the number of filters in the last conv layer
|
227 |
+
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
|
228 |
+
norm (str) -- the name of normalization layers used in the network: batch | instance | none
|
229 |
+
use_dropout (bool) -- if use dropout layers.
|
230 |
+
init_type (str) -- the name of our initialization method.
|
231 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
232 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
233 |
+
|
234 |
+
Returns a generator
|
235 |
+
|
236 |
+
Our current implementation provides two types of generators:
|
237 |
+
U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
|
238 |
+
The original U-Net paper: https://arxiv.org/abs/1505.04597
|
239 |
+
|
240 |
+
Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
|
241 |
+
Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
|
242 |
+
We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
|
243 |
+
|
244 |
+
|
245 |
+
The generator has been initialized by <init_net>. It uses RELU for non-linearity.
|
246 |
+
"""
|
247 |
+
net = None
|
248 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
249 |
+
|
250 |
+
if netG == 'resnet_9blocks':
|
251 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
|
252 |
+
elif netG == 'resnet_6blocks':
|
253 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
|
254 |
+
elif netG == 'resnet_4blocks':
|
255 |
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
|
256 |
+
elif netG == 'unet_128':
|
257 |
+
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
258 |
+
elif netG == 'unet_256':
|
259 |
+
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
260 |
+
elif netG == 'stylegan2':
|
261 |
+
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
|
262 |
+
elif netG == 'smallstylegan2':
|
263 |
+
net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
|
264 |
+
elif netG == 'resnet_cat':
|
265 |
+
n_blocks = 8
|
266 |
+
net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
|
267 |
+
else:
|
268 |
+
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
269 |
+
return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
|
270 |
+
|
271 |
+
|
272 |
+
def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
273 |
+
if netF == 'global_pool':
|
274 |
+
net = PoolingF()
|
275 |
+
elif netF == 'reshape':
|
276 |
+
net = ReshapeF()
|
277 |
+
elif netF == 'sample':
|
278 |
+
net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
|
279 |
+
elif netF == 'mlp_sample':
|
280 |
+
net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
|
281 |
+
elif netF == 'strided_conv':
|
282 |
+
net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
|
283 |
+
else:
|
284 |
+
raise NotImplementedError('projection model name [%s] is not recognized' % netF)
|
285 |
+
return init_net(net, init_type, init_gain, gpu_ids)
|
286 |
+
|
287 |
+
|
288 |
+
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
|
289 |
+
"""Create a discriminator
|
290 |
+
|
291 |
+
Parameters:
|
292 |
+
input_nc (int) -- the number of channels in input images
|
293 |
+
ndf (int) -- the number of filters in the first conv layer
|
294 |
+
netD (str) -- the architecture's name: basic | n_layers | pixel
|
295 |
+
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
|
296 |
+
norm (str) -- the type of normalization layers used in the network.
|
297 |
+
init_type (str) -- the name of the initialization method.
|
298 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
299 |
+
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
|
300 |
+
|
301 |
+
Returns a discriminator
|
302 |
+
|
303 |
+
Our current implementation provides three types of discriminators:
|
304 |
+
[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
|
305 |
+
It can classify whether 70×70 overlapping patches are real or fake.
|
306 |
+
Such a patch-level discriminator architecture has fewer parameters
|
307 |
+
than a full-image discriminator and can work on arbitrarily-sized images
|
308 |
+
in a fully convolutional fashion.
|
309 |
+
|
310 |
+
[n_layers]: With this mode, you cna specify the number of conv layers in the discriminator
|
311 |
+
with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
|
312 |
+
|
313 |
+
[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
|
314 |
+
It encourages greater color diversity but has no effect on spatial statistics.
|
315 |
+
|
316 |
+
The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
|
317 |
+
"""
|
318 |
+
net = None
|
319 |
+
norm_layer = get_norm_layer(norm_type=norm)
|
320 |
+
|
321 |
+
if netD == 'basic': # default PatchGAN classifier
|
322 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,)
|
323 |
+
elif netD == 'n_layers': # more options
|
324 |
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
|
325 |
+
elif netD == 'pixel': # classify if each pixel is real or fake
|
326 |
+
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
|
327 |
+
elif 'stylegan2' in netD:
|
328 |
+
net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
|
329 |
+
else:
|
330 |
+
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
|
331 |
+
return init_net(net, init_type, init_gain, gpu_ids,
|
332 |
+
initialize_weights=('stylegan2' not in netD))
|
333 |
+
|
334 |
+
|
335 |
+
##############################################################################
|
336 |
+
# Classes
|
337 |
+
##############################################################################
|
338 |
+
class GANLoss(nn.Module):
|
339 |
+
"""Define different GAN objectives.
|
340 |
+
|
341 |
+
The GANLoss class abstracts away the need to create the target label tensor
|
342 |
+
that has the same size as the input.
|
343 |
+
"""
|
344 |
+
|
345 |
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
|
346 |
+
""" Initialize the GANLoss class.
|
347 |
+
|
348 |
+
Parameters:
|
349 |
+
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
|
350 |
+
target_real_label (bool) - - label for a real image
|
351 |
+
target_fake_label (bool) - - label of a fake image
|
352 |
+
|
353 |
+
Note: Do not use sigmoid as the last layer of Discriminator.
|
354 |
+
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
|
355 |
+
"""
|
356 |
+
super(GANLoss, self).__init__()
|
357 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
358 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
359 |
+
self.gan_mode = gan_mode
|
360 |
+
if gan_mode == 'lsgan':
|
361 |
+
self.loss = nn.MSELoss()
|
362 |
+
elif gan_mode == 'vanilla':
|
363 |
+
self.loss = nn.BCEWithLogitsLoss()
|
364 |
+
elif gan_mode in ['wgangp', 'nonsaturating']:
|
365 |
+
self.loss = None
|
366 |
+
elif gan_mode == "hinge":
|
367 |
+
self.loss = None
|
368 |
+
else:
|
369 |
+
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
|
370 |
+
|
371 |
+
def get_target_tensor(self, prediction, target_is_real):
|
372 |
+
"""Create label tensors with the same size as the input.
|
373 |
+
|
374 |
+
Parameters:
|
375 |
+
prediction (tensor) - - tpyically the prediction from a discriminator
|
376 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
377 |
+
|
378 |
+
Returns:
|
379 |
+
A label tensor filled with ground truth label, and with the size of the input
|
380 |
+
"""
|
381 |
+
|
382 |
+
if target_is_real:
|
383 |
+
target_tensor = self.real_label
|
384 |
+
else:
|
385 |
+
target_tensor = self.fake_label
|
386 |
+
return target_tensor.expand_as(prediction)
|
387 |
+
|
388 |
+
def __call__(self, prediction, target_is_real):
|
389 |
+
"""Calculate loss given Discriminator's output and grount truth labels.
|
390 |
+
|
391 |
+
Parameters:
|
392 |
+
prediction (tensor) - - tpyically the prediction output from a discriminator
|
393 |
+
target_is_real (bool) - - if the ground truth label is for real images or fake images
|
394 |
+
|
395 |
+
Returns:
|
396 |
+
the calculated loss.
|
397 |
+
"""
|
398 |
+
bs = prediction.size(0)
|
399 |
+
if self.gan_mode in ['lsgan', 'vanilla']:
|
400 |
+
target_tensor = self.get_target_tensor(prediction, target_is_real)
|
401 |
+
loss = self.loss(prediction, target_tensor)
|
402 |
+
elif self.gan_mode == 'wgangp':
|
403 |
+
if target_is_real:
|
404 |
+
loss = -prediction.mean()
|
405 |
+
else:
|
406 |
+
loss = prediction.mean()
|
407 |
+
elif self.gan_mode == 'nonsaturating':
|
408 |
+
if target_is_real:
|
409 |
+
loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
|
410 |
+
else:
|
411 |
+
loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
|
412 |
+
elif self.gan_mode == 'hinge':
|
413 |
+
if target_is_real:
|
414 |
+
minvalue = torch.min(prediction - 1, torch.zeros(prediction.shape).to(prediction.device))
|
415 |
+
loss = -torch.mean(minvalue)
|
416 |
+
else:
|
417 |
+
minvalue = torch.min(-prediction - 1,torch.zeros(prediction.shape).to(prediction.device))
|
418 |
+
loss = -torch.mean(minvalue)
|
419 |
+
return loss
|
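# --- Usage sketch (added note, not part of the original file) ---
# How GANLoss is typically driven for the discriminator and generator updates.
# `netD`, `real` and `fake` are assumed to exist; the 'lsgan' mode is illustrative.
def _demo_gan_loss(netD, real, fake):
    criterion = GANLoss('lsgan')
    # discriminator: real -> 1, fake -> 0 (fake detached so G gets no gradient here)
    loss_D = 0.5 * (criterion(netD(real), True).mean() +
                    criterion(netD(fake.detach()), False).mean())
    # generator: tries to make D label its output as real
    loss_G = criterion(netD(fake), True).mean()
    return loss_D, loss_G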
420 |
+
|
421 |
+
|
422 |
+
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
|
423 |
+
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
|
424 |
+
|
425 |
+
Arguments:
|
426 |
+
netD (network) -- discriminator network
|
427 |
+
real_data (tensor array) -- real images
|
428 |
+
fake_data (tensor array) -- generated images from the generator
|
429 |
+
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
|
430 |
+
type (str) -- if we mix real and fake data or not [real | fake | mixed].
|
431 |
+
constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
|
432 |
+
lambda_gp (float) -- weight for this loss
|
433 |
+
|
434 |
+
Returns the gradient penalty loss
|
435 |
+
"""
|
436 |
+
if lambda_gp > 0.0:
|
437 |
+
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
|
438 |
+
interpolatesv = real_data
|
439 |
+
elif type == 'fake':
|
440 |
+
interpolatesv = fake_data
|
441 |
+
elif type == 'mixed':
|
442 |
+
alpha = torch.rand(real_data.shape[0], 1, device=device)
|
443 |
+
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
|
444 |
+
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
|
445 |
+
else:
|
446 |
+
raise NotImplementedError('{} not implemented'.format(type))
|
447 |
+
interpolatesv.requires_grad_(True)
|
448 |
+
disc_interpolates = netD(interpolatesv)
|
449 |
+
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
|
450 |
+
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
|
451 |
+
create_graph=True, retain_graph=True, only_inputs=True)
|
452 |
+
gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
|
453 |
+
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
|
454 |
+
return gradient_penalty, gradients
|
455 |
+
else:
|
456 |
+
return 0.0, None
|
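# --- Usage sketch (added note, not part of the original file) ---
# The gradient penalty is normally added to a WGAN-GP critic loss; `netD`, `real`,
# `fake` and `device` are assumed to exist and the weights shown are the defaults above.
def _demo_gradient_penalty(netD, real, fake, device):
    gp, _ = cal_gradient_penalty(netD, real, fake, device, type='mixed',
                                 constant=1.0, lambda_gp=10.0)
    loss_D = netD(fake.detach()).mean() - netD(real).mean() + gp  # critic loss + penalty
    return loss_D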
457 |
+
|
458 |
+
|
459 |
+
class Normalize(nn.Module):
|
460 |
+
|
461 |
+
def __init__(self, power=2):
|
462 |
+
super(Normalize, self).__init__()
|
463 |
+
self.power = power
|
464 |
+
|
465 |
+
def forward(self, x):
|
466 |
+
norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
|
467 |
+
out = x.div(norm + 1e-7)
|
468 |
+
return out
|
469 |
+
|
470 |
+
|
471 |
+
class PoolingF(nn.Module):
|
472 |
+
def __init__(self):
|
473 |
+
super(PoolingF, self).__init__()
|
474 |
+
model = [nn.AdaptiveMaxPool2d(1)]
|
475 |
+
self.model = nn.Sequential(*model)
|
476 |
+
self.l2norm = Normalize(2)
|
477 |
+
|
478 |
+
def forward(self, x):
|
479 |
+
return self.l2norm(self.model(x))
|
480 |
+
|
481 |
+
|
482 |
+
class ReshapeF(nn.Module):
|
483 |
+
def __init__(self):
|
484 |
+
super(ReshapeF, self).__init__()
|
485 |
+
model = [nn.AdaptiveAvgPool2d(4)]
|
486 |
+
self.model = nn.Sequential(*model)
|
487 |
+
self.l2norm = Normalize(2)
|
488 |
+
|
489 |
+
def forward(self, x):
|
490 |
+
x = self.model(x)
|
491 |
+
x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
|
492 |
+
return self.l2norm(x_reshape)
|
493 |
+
|
494 |
+
class StridedConvF(nn.Module):
|
495 |
+
def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
496 |
+
super().__init__()
|
497 |
+
# self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
|
498 |
+
# self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
|
499 |
+
self.l2_norm = Normalize(2)
|
500 |
+
self.mlps = {}
|
501 |
+
self.moving_averages = {}
|
502 |
+
self.init_type = init_type
|
503 |
+
self.init_gain = init_gain
|
504 |
+
self.gpu_ids = gpu_ids
|
505 |
+
|
506 |
+
def create_mlp(self, x):
|
507 |
+
C, H = x.shape[1], x.shape[2]
|
508 |
+
n_down = int(np.rint(np.log2(H / 32)))
|
509 |
+
mlp = []
|
510 |
+
for i in range(n_down):
|
511 |
+
mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
|
512 |
+
mlp.append(nn.ReLU())
|
513 |
+
C = max(C // 2, 64)
|
514 |
+
mlp.append(nn.Conv2d(C, 64, 3))
|
515 |
+
mlp = nn.Sequential(*mlp)
|
516 |
+
init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
|
517 |
+
return mlp
|
518 |
+
|
519 |
+
def update_moving_average(self, key, x):
|
520 |
+
if key not in self.moving_averages:
|
521 |
+
self.moving_averages[key] = x.detach()
|
522 |
+
|
523 |
+
self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
|
524 |
+
|
525 |
+
def forward(self, x, use_instance_norm=False):
|
526 |
+
C, H = x.shape[1], x.shape[2]
|
527 |
+
key = '%d_%d' % (C, H)
|
528 |
+
if key not in self.mlps:
|
529 |
+
self.mlps[key] = self.create_mlp(x)
|
530 |
+
self.add_module("child_%s" % key, self.mlps[key])
|
531 |
+
mlp = self.mlps[key]
|
532 |
+
x = mlp(x)
|
533 |
+
self.update_moving_average(key, x)
|
534 |
+
x = x - self.moving_averages[key]
|
535 |
+
if use_instance_norm:
|
536 |
+
x = F.instance_norm(x)
|
537 |
+
return self.l2_norm(x)
|
538 |
+
|
539 |
+
|
540 |
+
class PatchSampleF(nn.Module):
|
541 |
+
def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
|
542 |
+
# potential issues: currently, we use the same patch_ids for multiple images in the batch
|
543 |
+
super(PatchSampleF, self).__init__()
|
544 |
+
self.l2norm = Normalize(2)
|
545 |
+
self.use_mlp = use_mlp
|
546 |
+
self.nc = nc # hard-coded
|
547 |
+
self.mlp_init = False
|
548 |
+
self.init_type = init_type
|
549 |
+
self.init_gain = init_gain
|
550 |
+
self.gpu_ids = gpu_ids
|
551 |
+
|
552 |
+
def create_mlp(self, feats):
|
553 |
+
for mlp_id, feat in enumerate(feats):
|
554 |
+
input_nc = feat.shape[1]
|
555 |
+
mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
|
556 |
+
if len(self.gpu_ids) > 0:
|
557 |
+
mlp.cuda()
|
558 |
+
setattr(self, 'mlp_%d' % mlp_id, mlp)
|
559 |
+
init_net(self, self.init_type, self.init_gain, self.gpu_ids)
|
560 |
+
self.mlp_init = True
|
561 |
+
|
562 |
+
def forward(self, feats, num_patches=64, patch_ids=None):
|
563 |
+
return_ids = []
|
564 |
+
return_feats = []
|
565 |
+
if self.use_mlp and not self.mlp_init:
|
566 |
+
self.create_mlp(feats)
|
567 |
+
for feat_id, feat in enumerate(feats):
|
568 |
+
B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
|
569 |
+
feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
|
570 |
+
if num_patches > 0:
|
571 |
+
if patch_ids is not None:
|
572 |
+
patch_id = patch_ids[feat_id]
|
573 |
+
else:
|
574 |
+
patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
|
575 |
+
patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
|
576 |
+
x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
|
577 |
+
else:
|
578 |
+
x_sample = feat_reshape
|
579 |
+
patch_id = []
|
580 |
+
if self.use_mlp:
|
581 |
+
mlp = getattr(self, 'mlp_%d' % feat_id)
|
582 |
+
x_sample = mlp(x_sample)
|
583 |
+
return_ids.append(patch_id)
|
584 |
+
x_sample = self.l2norm(x_sample)
|
585 |
+
|
586 |
+
if num_patches == 0:
|
587 |
+
x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
|
588 |
+
return_feats.append(x_sample)
|
589 |
+
return return_feats, return_ids
|
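# --- Usage sketch (added note, not part of the original file) ---
# In CUT-style training the patch ids returned for the source features are reused for
# the translated features, so positive pairs stay spatially aligned. `netG` (a generator
# with encode_only support), `netF` (a PatchSampleF) and images `src`, `fake` are assumed.
def _demo_patch_sampling(netG, netF, src, fake, nce_layers=(0, 4, 8)):
    feat_k = netG(src, layers=list(nce_layers), encode_only=True)    # keys: input features
    feat_q = netG(fake, layers=list(nce_layers), encode_only=True)   # queries: output features
    feat_k_pool, sample_ids = netF(feat_k, num_patches=256, patch_ids=None)
    feat_q_pool, _ = netF(feat_q, num_patches=256, patch_ids=sample_ids)  # same patches
    return feat_q_pool, feat_k_pool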
590 |
+
|
591 |
+
|
592 |
+
class G_Resnet(nn.Module):
|
593 |
+
def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
|
594 |
+
norm=None, nl_layer=None):
|
595 |
+
super(G_Resnet, self).__init__()
|
596 |
+
n_downsample = num_downs
|
597 |
+
pad_type = 'reflect'
|
598 |
+
self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
|
599 |
+
if nz == 0:
|
600 |
+
self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
601 |
+
else:
|
602 |
+
self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
|
603 |
+
|
604 |
+
def decode(self, content, style=None):
|
605 |
+
return self.dec(content, style)
|
606 |
+
|
607 |
+
def forward(self, image, style=None, nce_layers=[], encode_only=False):
|
608 |
+
content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
|
609 |
+
if encode_only:
|
610 |
+
return feats
|
611 |
+
else:
|
612 |
+
images_recon = self.decode(content, style)
|
613 |
+
if len(nce_layers) > 0:
|
614 |
+
return images_recon, feats
|
615 |
+
else:
|
616 |
+
return images_recon
|
617 |
+
|
618 |
+
##################################################################################
|
619 |
+
# Encoder and Decoders
|
620 |
+
##################################################################################
|
621 |
+
|
622 |
+
|
623 |
+
class E_adaIN(nn.Module):
|
624 |
+
def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
|
625 |
+
norm=None, nl_layer=None, vae=False):
|
626 |
+
# style encoder
|
627 |
+
super(E_adaIN, self).__init__()
|
628 |
+
self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
|
629 |
+
|
630 |
+
def forward(self, image):
|
631 |
+
style = self.enc_style(image)
|
632 |
+
return style
|
633 |
+
|
634 |
+
|
635 |
+
class StyleEncoder(nn.Module):
|
636 |
+
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
|
637 |
+
super(StyleEncoder, self).__init__()
|
638 |
+
self.vae = vae
|
639 |
+
self.model = []
|
640 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
641 |
+
for i in range(2):
|
642 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
643 |
+
dim *= 2
|
644 |
+
for i in range(n_downsample - 2):
|
645 |
+
self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
646 |
+
self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
|
647 |
+
if self.vae:
|
648 |
+
self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
649 |
+
self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
|
650 |
+
else:
|
651 |
+
self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
|
652 |
+
|
653 |
+
self.model = nn.Sequential(*self.model)
|
654 |
+
self.output_dim = dim
|
655 |
+
|
656 |
+
def forward(self, x):
|
657 |
+
if self.vae:
|
658 |
+
output = self.model(x)
|
659 |
+
output = output.view(x.size(0), -1)
|
660 |
+
output_mean = self.fc_mean(output)
|
661 |
+
output_var = self.fc_var(output)
|
662 |
+
return output_mean, output_var
|
663 |
+
else:
|
664 |
+
return self.model(x).view(x.size(0), -1)
|
665 |
+
|
666 |
+
|
667 |
+
class ContentEncoder(nn.Module):
|
668 |
+
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
|
669 |
+
super(ContentEncoder, self).__init__()
|
670 |
+
self.model = []
|
671 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
|
672 |
+
# downsampling blocks
|
673 |
+
for i in range(n_downsample):
|
674 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
|
675 |
+
dim *= 2
|
676 |
+
# residual blocks
|
677 |
+
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
|
678 |
+
self.model = nn.Sequential(*self.model)
|
679 |
+
self.output_dim = dim
|
680 |
+
|
681 |
+
def forward(self, x, nce_layers=[], encode_only=False):
|
682 |
+
if len(nce_layers) > 0:
|
683 |
+
feat = x
|
684 |
+
feats = []
|
685 |
+
for layer_id, layer in enumerate(self.model):
|
686 |
+
feat = layer(feat)
|
687 |
+
if layer_id in nce_layers:
|
688 |
+
feats.append(feat)
|
689 |
+
if layer_id == nce_layers[-1] and encode_only:
|
690 |
+
return None, feats
|
691 |
+
return feat, feats
|
692 |
+
else:
|
693 |
+
return self.model(x), None
|
694 |
+
|
695 |
+
for layer_id, layer in enumerate(self.model):
|
696 |
+
print(layer_id, layer)
|
697 |
+
|
698 |
+
|
699 |
+
class Decoder_all(nn.Module):
|
700 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
701 |
+
super(Decoder_all, self).__init__()
|
702 |
+
# AdaIN residual blocks
|
703 |
+
self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
|
704 |
+
self.n_blocks = 0
|
705 |
+
# upsampling blocks
|
706 |
+
for i in range(n_upsample):
|
707 |
+
block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
708 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
|
709 |
+
self.n_blocks += 1
|
710 |
+
dim //= 2
|
711 |
+
# use reflection padding in the last conv layer
|
712 |
+
setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
|
713 |
+
self.n_blocks += 1
|
714 |
+
|
715 |
+
def forward(self, x, y=None):
|
716 |
+
if y is not None:
|
717 |
+
output = self.resnet_block(cat_feature(x, y))
|
718 |
+
for n in range(self.n_blocks):
|
719 |
+
block = getattr(self, 'block_{:d}'.format(n))
|
720 |
+
if n > 0:
|
721 |
+
output = block(cat_feature(output, y))
|
722 |
+
else:
|
723 |
+
output = block(output)
|
724 |
+
return output
|
725 |
+
|
726 |
+
|
727 |
+
class Decoder(nn.Module):
|
728 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
|
729 |
+
super(Decoder, self).__init__()
|
730 |
+
|
731 |
+
self.model = []
|
732 |
+
# AdaIN residual blocks
|
733 |
+
self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
|
734 |
+
# upsampling blocks
|
735 |
+
for i in range(n_upsample):
|
736 |
+
if i == 0:
|
737 |
+
input_dim = dim + nz
|
738 |
+
else:
|
739 |
+
input_dim = dim
|
740 |
+
self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
|
741 |
+
dim //= 2
|
742 |
+
# use reflection padding in the last conv layer
|
743 |
+
self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
|
744 |
+
self.model = nn.Sequential(*self.model)
|
745 |
+
|
746 |
+
def forward(self, x, y=None):
|
747 |
+
if y is not None:
|
748 |
+
return self.model(cat_feature(x, y))
|
749 |
+
else:
|
750 |
+
return self.model(x)
|
751 |
+
|
752 |
+
##################################################################################
|
753 |
+
# Sequential Models
|
754 |
+
##################################################################################
|
755 |
+
|
756 |
+
|
757 |
+
class ResBlocks(nn.Module):
|
758 |
+
def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
759 |
+
super(ResBlocks, self).__init__()
|
760 |
+
self.model = []
|
761 |
+
for i in range(num_blocks):
|
762 |
+
self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
|
763 |
+
self.model = nn.Sequential(*self.model)
|
764 |
+
|
765 |
+
def forward(self, x):
|
766 |
+
return self.model(x)
|
767 |
+
|
768 |
+
|
769 |
+
##################################################################################
|
770 |
+
# Basic Blocks
|
771 |
+
##################################################################################
|
772 |
+
def cat_feature(x, y):
|
773 |
+
y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
|
774 |
+
y.size(0), y.size(1), x.size(2), x.size(3))
|
775 |
+
x_cat = torch.cat([x, y_expand], 1)
|
776 |
+
return x_cat
|
777 |
+
|
778 |
+
|
779 |
+
class ResBlock(nn.Module):
|
780 |
+
def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
|
781 |
+
super(ResBlock, self).__init__()
|
782 |
+
|
783 |
+
model = []
|
784 |
+
model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
|
785 |
+
model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
|
786 |
+
self.model = nn.Sequential(*model)
|
787 |
+
|
788 |
+
def forward(self, x):
|
789 |
+
residual = x
|
790 |
+
out = self.model(x)
|
791 |
+
out += residual
|
792 |
+
return out
|
793 |
+
|
794 |
+
|
795 |
+
class Conv2dBlock(nn.Module):
|
796 |
+
def __init__(self, input_dim, output_dim, kernel_size, stride,
|
797 |
+
padding=0, norm='none', activation='relu', pad_type='zero'):
|
798 |
+
super(Conv2dBlock, self).__init__()
|
799 |
+
self.use_bias = True
|
800 |
+
# initialize padding
|
801 |
+
if pad_type == 'reflect':
|
802 |
+
self.pad = nn.ReflectionPad2d(padding)
|
803 |
+
elif pad_type == 'zero':
|
804 |
+
self.pad = nn.ZeroPad2d(padding)
|
805 |
+
else:
|
806 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
807 |
+
|
808 |
+
# initialize normalization
|
809 |
+
norm_dim = output_dim
|
810 |
+
if norm == 'batch':
|
811 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
812 |
+
elif norm == 'inst':
|
813 |
+
self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
|
814 |
+
elif norm == 'ln':
|
815 |
+
self.norm = LayerNorm(norm_dim)
|
816 |
+
elif norm == 'none':
|
817 |
+
self.norm = None
|
818 |
+
else:
|
819 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
820 |
+
|
821 |
+
# initialize activation
|
822 |
+
if activation == 'relu':
|
823 |
+
self.activation = nn.ReLU(inplace=True)
|
824 |
+
elif activation == 'lrelu':
|
825 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
826 |
+
elif activation == 'prelu':
|
827 |
+
self.activation = nn.PReLU()
|
828 |
+
elif activation == 'selu':
|
829 |
+
self.activation = nn.SELU(inplace=True)
|
830 |
+
elif activation == 'tanh':
|
831 |
+
self.activation = nn.Tanh()
|
832 |
+
elif activation == 'none':
|
833 |
+
self.activation = None
|
834 |
+
else:
|
835 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
836 |
+
|
837 |
+
# initialize convolution
|
838 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
|
839 |
+
|
840 |
+
def forward(self, x):
|
841 |
+
x = self.conv(self.pad(x))
|
842 |
+
if self.norm:
|
843 |
+
x = self.norm(x)
|
844 |
+
if self.activation:
|
845 |
+
x = self.activation(x)
|
846 |
+
return x
|
847 |
+
|
848 |
+
|
849 |
+
class LinearBlock(nn.Module):
|
850 |
+
def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
|
851 |
+
super(LinearBlock, self).__init__()
|
852 |
+
use_bias = True
|
853 |
+
# initialize fully connected layer
|
854 |
+
self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
|
855 |
+
|
856 |
+
# initialize normalization
|
857 |
+
norm_dim = output_dim
|
858 |
+
if norm == 'batch':
|
859 |
+
self.norm = nn.BatchNorm1d(norm_dim)
|
860 |
+
elif norm == 'inst':
|
861 |
+
self.norm = nn.InstanceNorm1d(norm_dim)
|
862 |
+
elif norm == 'ln':
|
863 |
+
self.norm = LayerNorm(norm_dim)
|
864 |
+
elif norm == 'none':
|
865 |
+
self.norm = None
|
866 |
+
else:
|
867 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
868 |
+
|
869 |
+
# initialize activation
|
870 |
+
if activation == 'relu':
|
871 |
+
self.activation = nn.ReLU(inplace=True)
|
872 |
+
elif activation == 'lrelu':
|
873 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
874 |
+
elif activation == 'prelu':
|
875 |
+
self.activation = nn.PReLU()
|
876 |
+
elif activation == 'selu':
|
877 |
+
self.activation = nn.SELU(inplace=True)
|
878 |
+
elif activation == 'tanh':
|
879 |
+
self.activation = nn.Tanh()
|
880 |
+
elif activation == 'none':
|
881 |
+
self.activation = None
|
882 |
+
else:
|
883 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
884 |
+
|
885 |
+
def forward(self, x):
|
886 |
+
out = self.fc(x)
|
887 |
+
if self.norm:
|
888 |
+
out = self.norm(out)
|
889 |
+
if self.activation:
|
890 |
+
out = self.activation(out)
|
891 |
+
return out
|
892 |
+
|
893 |
+
##################################################################################
|
894 |
+
# Normalization layers
|
895 |
+
##################################################################################
|
896 |
+
|
897 |
+
|
898 |
+
class LayerNorm(nn.Module):
|
899 |
+
def __init__(self, num_features, eps=1e-5, affine=True):
|
900 |
+
super(LayerNorm, self).__init__()
|
901 |
+
self.num_features = num_features
|
902 |
+
self.affine = affine
|
903 |
+
self.eps = eps
|
904 |
+
|
905 |
+
if self.affine:
|
906 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
|
907 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
908 |
+
|
909 |
+
def forward(self, x):
|
910 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
911 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
912 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
913 |
+
x = (x - mean) / (std + self.eps)
|
914 |
+
|
915 |
+
if self.affine:
|
916 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
917 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
918 |
+
return x
|
919 |
+
|
920 |
+
|
921 |
+
class ResnetGenerator(nn.Module):
|
922 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
923 |
+
|
924 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
925 |
+
"""
|
926 |
+
|
927 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
|
928 |
+
"""Construct a Resnet-based generator
|
929 |
+
|
930 |
+
Parameters:
|
931 |
+
input_nc (int) -- the number of channels in input images
|
932 |
+
output_nc (int) -- the number of channels in output images
|
933 |
+
ngf (int) -- the number of filters in the last conv layer
|
934 |
+
norm_layer -- normalization layer
|
935 |
+
use_dropout (bool) -- if use dropout layers
|
936 |
+
n_blocks (int) -- the number of ResNet blocks
|
937 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
938 |
+
"""
|
939 |
+
assert(n_blocks >= 0)
|
940 |
+
super(ResnetGenerator, self).__init__()
|
941 |
+
self.opt = opt
|
942 |
+
if type(norm_layer) == functools.partial:
|
943 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
944 |
+
else:
|
945 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
946 |
+
|
947 |
+
model = [nn.ReflectionPad2d(3),
|
948 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
949 |
+
norm_layer(ngf),
|
950 |
+
nn.ReLU(True)]
|
951 |
+
|
952 |
+
n_downsampling = 2
|
953 |
+
for i in range(n_downsampling): # add downsampling layers
|
954 |
+
mult = 2 ** i
|
955 |
+
if(no_antialias):
|
956 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
957 |
+
norm_layer(ngf * mult * 2),
|
958 |
+
nn.ReLU(True)]
|
959 |
+
else:
|
960 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
|
961 |
+
norm_layer(ngf * mult * 2),
|
962 |
+
nn.ReLU(True),
|
963 |
+
Downsample(ngf * mult * 2)]
|
964 |
+
|
965 |
+
mult = 2 ** n_downsampling
|
966 |
+
for i in range(n_blocks): # add ResNet blocks
|
967 |
+
|
968 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
969 |
+
|
970 |
+
for i in range(n_downsampling): # add upsampling layers
|
971 |
+
mult = 2 ** (n_downsampling - i)
|
972 |
+
if no_antialias_up:
|
973 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
974 |
+
kernel_size=3, stride=2,
|
975 |
+
padding=1, output_padding=1,
|
976 |
+
bias=use_bias),
|
977 |
+
norm_layer(int(ngf * mult / 2)),
|
978 |
+
nn.ReLU(True)]
|
979 |
+
else:
|
980 |
+
model += [Upsample(ngf * mult),
|
981 |
+
nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
982 |
+
kernel_size=3, stride=1,
|
983 |
+
padding=1, # output_padding=1,
|
984 |
+
bias=use_bias),
|
985 |
+
norm_layer(int(ngf * mult / 2)),
|
986 |
+
nn.ReLU(True)]
|
987 |
+
model += [nn.ReflectionPad2d(3)]
|
988 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
989 |
+
model += [nn.Tanh()]
|
990 |
+
|
991 |
+
self.model = nn.Sequential(*model)
|
992 |
+
|
993 |
+
def forward(self, input, layers=[], encode_only=False):
|
994 |
+
if -1 in layers:
|
995 |
+
layers.append(len(self.model))
|
996 |
+
if len(layers) > 0:
|
997 |
+
feat = input
|
998 |
+
feats = []
|
999 |
+
for layer_id, layer in enumerate(self.model):
|
1000 |
+
# print(layer_id, layer)
|
1001 |
+
feat = layer(feat)
|
1002 |
+
if layer_id in layers:
|
1003 |
+
# print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
1004 |
+
feats.append(feat)
|
1005 |
+
else:
|
1006 |
+
# print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
|
1007 |
+
pass
|
1008 |
+
if layer_id == layers[-1] and encode_only:
|
1009 |
+
# print('encoder only return features')
|
1010 |
+
return feats # return intermediate features alone; stop in the last layers
|
1011 |
+
|
1012 |
+
return feat, feats # return both output and intermediate features
|
1013 |
+
else:
|
1014 |
+
"""Standard forward"""
|
1015 |
+
fake = self.model(input)
|
1016 |
+
return fake
|
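# --- Usage sketch (added note, not part of the original file) ---
# Building the generator directly and extracting intermediate features; the layer
# indices are illustrative (the project selects them via its nce_layers option).
def _demo_resnet_generator():
    netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)
    x = torch.randn(1, 3, 256, 256)
    fake = netG(x)                                        # standard image-to-image forward
    feats = netG(x, layers=[0, 4, 8], encode_only=True)   # intermediate features only
    return fake, feats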
1017 |
+
|
1018 |
+
|
1019 |
+
class ResnetDecoder(nn.Module):
|
1020 |
+
"""Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
|
1021 |
+
"""
|
1022 |
+
|
1023 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1024 |
+
"""Construct a Resnet-based decoder
|
1025 |
+
|
1026 |
+
Parameters:
|
1027 |
+
input_nc (int) -- the number of channels in input images
|
1028 |
+
output_nc (int) -- the number of channels in output images
|
1029 |
+
ngf (int) -- the number of filters in the last conv layer
|
1030 |
+
norm_layer -- normalization layer
|
1031 |
+
use_dropout (bool) -- if use dropout layers
|
1032 |
+
n_blocks (int) -- the number of ResNet blocks
|
1033 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1034 |
+
"""
|
1035 |
+
assert(n_blocks >= 0)
|
1036 |
+
super(ResnetDecoder, self).__init__()
|
1037 |
+
if type(norm_layer) == functools.partial:
|
1038 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1039 |
+
else:
|
1040 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1041 |
+
model = []
|
1042 |
+
n_downsampling = 2
|
1043 |
+
mult = 2 ** n_downsampling
|
1044 |
+
for i in range(n_blocks): # add ResNet blocks
|
1045 |
+
|
1046 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1047 |
+
|
1048 |
+
for i in range(n_downsampling): # add upsampling layers
|
1049 |
+
mult = 2 ** (n_downsampling - i)
|
1050 |
+
if(no_antialias):
|
1051 |
+
model += [SpectralNorm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
1052 |
+
kernel_size=3, stride=2,
|
1053 |
+
padding=1, output_padding=1,
|
1054 |
+
bias=use_bias)),
|
1055 |
+
nn.ReLU(True)]
|
1056 |
+
else:
|
1057 |
+
model += [Upsample(ngf * mult),
|
1058 |
+
SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
1059 |
+
kernel_size=3, stride=1,
|
1060 |
+
padding=1,
|
1061 |
+
bias=use_bias)),
|
1062 |
+
norm_layer(int(ngf * mult / 2)),
|
1063 |
+
nn.ReLU(True)]
|
1064 |
+
model += [nn.ReflectionPad2d(3)]
|
1065 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
1066 |
+
model += [nn.Tanh()]
|
1067 |
+
|
1068 |
+
self.model = nn.Sequential(*model)
|
1069 |
+
|
1070 |
+
def forward(self, input):
|
1071 |
+
"""Standard forward"""
|
1072 |
+
return self.model(input)
|
1073 |
+
|
1074 |
+
|
1075 |
+
class ResnetEncoder(nn.Module):
|
1076 |
+
"""Resnet-based encoder that consists of a few downsampling + several Resnet blocks
|
1077 |
+
"""
|
1078 |
+
|
1079 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
|
1080 |
+
"""Construct a Resnet-based encoder
|
1081 |
+
|
1082 |
+
Parameters:
|
1083 |
+
input_nc (int) -- the number of channels in input images
|
1084 |
+
output_nc (int) -- the number of channels in output images
|
1085 |
+
ngf (int) -- the number of filters in the last conv layer
|
1086 |
+
norm_layer -- normalization layer
|
1087 |
+
use_dropout (bool) -- if use dropout layers
|
1088 |
+
n_blocks (int) -- the number of ResNet blocks
|
1089 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
1090 |
+
"""
|
1091 |
+
assert(n_blocks >= 0)
|
1092 |
+
super(ResnetEncoder, self).__init__()
|
1093 |
+
if type(norm_layer) == functools.partial:
|
1094 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1095 |
+
else:
|
1096 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1097 |
+
|
1098 |
+
model = [nn.ReflectionPad2d(3),
|
1099 |
+
SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)),
|
1100 |
+
nn.ReLU(True)]
|
1101 |
+
|
1102 |
+
n_downsampling = 2
|
1103 |
+
for i in range(n_downsampling): # add downsampling layers
|
1104 |
+
mult = 2 ** i
|
1105 |
+
if(no_antialias):
|
1106 |
+
model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)),
|
1107 |
+
nn.ReLU(True)]
|
1108 |
+
else:
|
1109 |
+
model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)),
|
1110 |
+
nn.ReLU(True),
|
1111 |
+
Downsample(ngf * mult * 2)]
|
1112 |
+
|
1113 |
+
mult = 2 ** n_downsampling
|
1114 |
+
for i in range(n_blocks): # add ResNet blocks
|
1115 |
+
|
1116 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
1117 |
+
|
1118 |
+
self.model = nn.Sequential(*model)
|
1119 |
+
|
1120 |
+
def forward(self, input):
|
1121 |
+
"""Standard forward"""
|
1122 |
+
return self.model(input)
|
1123 |
+
|
1124 |
+
|
1125 |
+
class ResnetBlock(nn.Module):
|
1126 |
+
"""Define a Resnet block"""
|
1127 |
+
|
1128 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
1129 |
+
"""Initialize the Resnet block
|
1130 |
+
|
1131 |
+
A resnet block is a conv block with skip connections
|
1132 |
+
We construct a conv block with build_conv_block function,
|
1133 |
+
and implement skip connections in <forward> function.
|
1134 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
1135 |
+
"""
|
1136 |
+
super(ResnetBlock, self).__init__()
|
1137 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
1138 |
+
|
1139 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
1140 |
+
"""Construct a convolutional block.
|
1141 |
+
|
1142 |
+
Parameters:
|
1143 |
+
dim (int) -- the number of channels in the conv layer.
|
1144 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
1145 |
+
norm_layer -- normalization layer
|
1146 |
+
use_dropout (bool) -- if use dropout layers.
|
1147 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
1148 |
+
|
1149 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
1150 |
+
"""
|
1151 |
+
conv_block = []
|
1152 |
+
p = 0
|
1153 |
+
if padding_type == 'reflect':
|
1154 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1155 |
+
elif padding_type == 'replicate':
|
1156 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1157 |
+
elif padding_type == 'zero':
|
1158 |
+
p = 1
|
1159 |
+
else:
|
1160 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1161 |
+
|
1162 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
1163 |
+
if use_dropout:
|
1164 |
+
conv_block += [nn.Dropout(0.5)]
|
1165 |
+
|
1166 |
+
p = 0
|
1167 |
+
if padding_type == 'reflect':
|
1168 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
1169 |
+
elif padding_type == 'replicate':
|
1170 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
1171 |
+
elif padding_type == 'zero':
|
1172 |
+
p = 1
|
1173 |
+
else:
|
1174 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
1175 |
+
conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias))]
|
1176 |
+
|
1177 |
+
return nn.Sequential(*conv_block)
|
1178 |
+
|
1179 |
+
def forward(self, x):
|
1180 |
+
"""Forward function (with skip connections)"""
|
1181 |
+
out = x + self.conv_block(x) # add skip connections
|
1182 |
+
return out
|
1183 |
+
|
1184 |
+
|
1185 |
+
class UnetGenerator(nn.Module):
|
1186 |
+
"""Create a Unet-based generator"""
|
1187 |
+
|
1188 |
+
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1189 |
+
"""Construct a Unet generator
|
1190 |
+
Parameters:
|
1191 |
+
input_nc (int) -- the number of channels in input images
|
1192 |
+
output_nc (int) -- the number of channels in output images
|
1193 |
+
num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
|
1194 |
+
an image of size 128x128 will become of size 1x1 at the bottleneck
|
1195 |
+
ngf (int) -- the number of filters in the last conv layer
|
1196 |
+
norm_layer -- normalization layer
|
1197 |
+
|
1198 |
+
We construct the U-Net from the innermost layer to the outermost layer.
|
1199 |
+
It is a recursive process.
|
1200 |
+
"""
|
1201 |
+
super(UnetGenerator, self).__init__()
|
1202 |
+
# construct unet structure
|
1203 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
|
1204 |
+
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
|
1205 |
+
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
|
1206 |
+
# gradually reduce the number of filters from ngf * 8 to ngf
|
1207 |
+
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1208 |
+
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1209 |
+
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
|
1210 |
+
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
|
1211 |
+
|
1212 |
+
def forward(self, input):
|
1213 |
+
"""Standard forward"""
|
1214 |
+
return self.model(input)
|
1215 |
+
|
1216 |
+
|
1217 |
+
class UnetSkipConnectionBlock(nn.Module):
|
1218 |
+
"""Defines the Unet submodule with skip connection.
|
1219 |
+
X -------------------identity----------------------
|
1220 |
+
|-- downsampling -- |submodule| -- upsampling --|
|
1221 |
+
"""
|
1222 |
+
|
1223 |
+
def __init__(self, outer_nc, inner_nc, input_nc=None,
|
1224 |
+
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
|
1225 |
+
"""Construct a Unet submodule with skip connections.
|
1226 |
+
|
1227 |
+
Parameters:
|
1228 |
+
outer_nc (int) -- the number of filters in the outer conv layer
|
1229 |
+
inner_nc (int) -- the number of filters in the inner conv layer
|
1230 |
+
input_nc (int) -- the number of channels in input images/features
|
1231 |
+
submodule (UnetSkipConnectionBlock) -- previously defined submodules
|
1232 |
+
outermost (bool) -- if this module is the outermost module
|
1233 |
+
innermost (bool) -- if this module is the innermost module
|
1234 |
+
norm_layer -- normalization layer
|
1235 |
+
use_dropout (bool) -- if use dropout layers.
|
1236 |
+
"""
|
1237 |
+
super(UnetSkipConnectionBlock, self).__init__()
|
1238 |
+
self.outermost = outermost
|
1239 |
+
if type(norm_layer) == functools.partial:
|
1240 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1241 |
+
else:
|
1242 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1243 |
+
if input_nc is None:
|
1244 |
+
input_nc = outer_nc
|
1245 |
+
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
|
1246 |
+
stride=2, padding=1, bias=use_bias)
|
1247 |
+
downrelu = nn.LeakyReLU(0.2, True)
|
1248 |
+
downnorm = norm_layer(inner_nc)
|
1249 |
+
uprelu = nn.ReLU(True)
|
1250 |
+
upnorm = norm_layer(outer_nc)
|
1251 |
+
|
1252 |
+
if outermost:
|
1253 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1254 |
+
kernel_size=4, stride=2,
|
1255 |
+
padding=1)
|
1256 |
+
down = [downconv]
|
1257 |
+
up = [uprelu, upconv, nn.Tanh()]
|
1258 |
+
model = down + [submodule] + up
|
1259 |
+
elif innermost:
|
1260 |
+
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
|
1261 |
+
kernel_size=4, stride=2,
|
1262 |
+
padding=1, bias=use_bias)
|
1263 |
+
down = [downrelu, downconv]
|
1264 |
+
up = [uprelu, upconv, upnorm]
|
1265 |
+
model = down + up
|
1266 |
+
else:
|
1267 |
+
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
|
1268 |
+
kernel_size=4, stride=2,
|
1269 |
+
padding=1, bias=use_bias)
|
1270 |
+
down = [downrelu, downconv, downnorm]
|
1271 |
+
up = [uprelu, upconv, upnorm]
|
1272 |
+
|
1273 |
+
if use_dropout:
|
1274 |
+
model = down + [submodule] + up + [nn.Dropout(0.5)]
|
1275 |
+
else:
|
1276 |
+
model = down + [submodule] + up
|
1277 |
+
|
1278 |
+
self.model = nn.Sequential(*model)
|
1279 |
+
|
1280 |
+
def forward(self, x):
|
1281 |
+
if self.outermost:
|
1282 |
+
return self.model(x)
|
1283 |
+
else: # add skip connections
|
1284 |
+
return torch.cat([x, self.model(x)], 1)
|
1285 |
+
|
1286 |
+
|
1287 |
+
class NLayerDiscriminator(nn.Module):
|
1288 |
+
"""Defines a PatchGAN discriminator"""
|
1289 |
+
|
1290 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
|
1291 |
+
"""Construct a PatchGAN discriminator
|
1292 |
+
|
1293 |
+
Parameters:
|
1294 |
+
input_nc (int) -- the number of channels in input images
|
1295 |
+
ndf (int) -- the number of filters in the last conv layer
|
1296 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
1297 |
+
norm_layer -- normalization layer
|
1298 |
+
"""
|
1299 |
+
super(NLayerDiscriminator, self).__init__()
|
1300 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1301 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1302 |
+
else:
|
1303 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1304 |
+
|
1305 |
+
kw = 4
|
1306 |
+
padw = 1
|
1307 |
+
if(no_antialias):
|
1308 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
1309 |
+
else:
|
1310 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
|
1311 |
+
nf_mult = 1
|
1312 |
+
nf_mult_prev = 1
|
1313 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
1314 |
+
nf_mult_prev = nf_mult
|
1315 |
+
nf_mult = min(2 ** n, 8)
|
1316 |
+
if(no_antialias):
|
1317 |
+
sequence += [
|
1318 |
+
SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
|
1319 |
+
nn.LeakyReLU(0.2, True)
|
1320 |
+
]
|
1321 |
+
else:
|
1322 |
+
sequence += [
|
1323 |
+
SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
|
1324 |
+
nn.LeakyReLU(0.2, True),
|
1325 |
+
Downsample(ndf * nf_mult)]
|
1326 |
+
|
1327 |
+
nf_mult_prev = nf_mult
|
1328 |
+
nf_mult = min(2 ** n_layers, 8)
|
1329 |
+
sequence += [
|
1330 |
+
SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
|
1331 |
+
nn.LeakyReLU(0.2, True)
|
1332 |
+
]
|
1333 |
+
sequence += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))] # output 1 channel prediction map
|
1334 |
+
self.model = nn.Sequential(*sequence)
|
1335 |
+
|
1336 |
+
def forward(self, input):
|
1337 |
+
"""Standard forward."""
|
1338 |
+
return self.model(input)
|
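# --- Shape sketch (added note, not part of the original file) ---
# The PatchGAN discriminator outputs a map of per-patch logits rather than a single
# score; GANLoss broadcasts its target label over this map.
def _demo_patchgan_output():
    netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    logits = netD(torch.randn(1, 3, 256, 256))
    return logits.shape   # roughly (1, 1, 30, 30) for a 256x256 input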
1339 |
+
|
1340 |
+
|
1341 |
+
class PixelDiscriminator(nn.Module):
|
1342 |
+
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
|
1343 |
+
|
1344 |
+
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
|
1345 |
+
"""Construct a 1x1 PatchGAN discriminator
|
1346 |
+
|
1347 |
+
Parameters:
|
1348 |
+
input_nc (int) -- the number of channels in input images
|
1349 |
+
ndf (int) -- the number of filters in the last conv layer
|
1350 |
+
norm_layer -- normalization layer
|
1351 |
+
"""
|
1352 |
+
super(PixelDiscriminator, self).__init__()
|
1353 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
1354 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
1355 |
+
else:
|
1356 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
1357 |
+
|
1358 |
+
self.net = [
|
1359 |
+
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
|
1360 |
+
nn.LeakyReLU(0.2, True),
|
1361 |
+
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
|
1362 |
+
norm_layer(ndf * 2),
|
1363 |
+
nn.LeakyReLU(0.2, True),
|
1364 |
+
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
|
1365 |
+
|
1366 |
+
self.net = nn.Sequential(*self.net)
|
1367 |
+
|
1368 |
+
def forward(self, input):
|
1369 |
+
"""Standard forward."""
|
1370 |
+
return self.net(input)
|
1371 |
+
|
1372 |
+
|
1373 |
+
class PatchDiscriminator(NLayerDiscriminator):
|
1374 |
+
"""Defines a PatchGAN discriminator"""
|
1375 |
+
|
1376 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
|
1377 |
+
super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
|
1378 |
+
|
1379 |
+
def forward(self, input):
|
1380 |
+
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
1381 |
+
size = 16
|
1382 |
+
Y = H // size
|
1383 |
+
X = W // size
|
1384 |
+
input = input.view(B, C, Y, size, X, size)
|
1385 |
+
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
1386 |
+
return super().forward(input)
|
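# --- Shape sketch (added note, not part of the original file) ---
# PatchDiscriminator tiles the input into non-overlapping 16x16 crops and scores each
# crop independently, so the effective batch grows by a factor of (H // 16) * (W // 16).
def _demo_patch_discriminator():
    netD = PatchDiscriminator(input_nc=3, ndf=64)
    out = netD(torch.randn(2, 3, 64, 64))   # 2 images -> 2 * 4 * 4 = 32 crops of 16x16
    return out.shape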
1387 |
+
|
1388 |
+
|
1389 |
+
class GroupedChannelNorm(nn.Module):
|
1390 |
+
def __init__(self, num_groups):
|
1391 |
+
super().__init__()
|
1392 |
+
self.num_groups = num_groups
|
1393 |
+
|
1394 |
+
def forward(self, x):
|
1395 |
+
shape = list(x.shape)
|
1396 |
+
new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
|
1397 |
+
x = x.view(*new_shape)
|
1398 |
+
mean = x.mean(dim=2, keepdim=True)
|
1399 |
+
std = x.std(dim=2, keepdim=True)
|
1400 |
+
x_norm = (x - mean) / (std + 1e-7)
|
1401 |
+
return x_norm.view(*shape)
|
models/patchnce.py
ADDED
@@ -0,0 +1,54 @@
|
1 |
+
from packaging import version
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class PatchNCELoss(nn.Module):
|
7 |
+
def __init__(self, opt):
|
8 |
+
super().__init__()
|
9 |
+
self.opt = opt
|
10 |
+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
|
11 |
+
self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
|
12 |
+
|
13 |
+
def forward(self, feat_q, feat_k):
|
14 |
+
batchSize = feat_q.shape[0]
|
15 |
+
dim = feat_q.shape[1]
|
16 |
+
feat_k = feat_k.detach()
|
17 |
+
|
18 |
+
# pos logit
|
19 |
+
l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
|
20 |
+
l_pos = l_pos.view(batchSize, 1)
|
21 |
+
|
22 |
+
# neg logit
|
23 |
+
|
24 |
+
# Should the negatives from the other samples of a minibatch be utilized?
|
25 |
+
# In CUT and FastCUT, we found that it's best to only include negatives
|
26 |
+
# from the same image. Therefore, we set
|
27 |
+
# --nce_includes_all_negatives_from_minibatch as False
|
28 |
+
# However, for single-image translation, the minibatch consists of
|
29 |
+
# crops from the "same" high-resolution image.
|
30 |
+
# Therefore, we will include the negatives from the entire minibatch.
|
31 |
+
if self.opt.nce_includes_all_negatives_from_minibatch:
|
32 |
+
# reshape features as if they are all negatives of minibatch of size 1.
|
33 |
+
batch_dim_for_bmm = 1
|
34 |
+
else:
|
35 |
+
batch_dim_for_bmm = self.opt.batch_size
|
36 |
+
|
37 |
+
# reshape features to batch size
|
38 |
+
feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
|
39 |
+
feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
|
40 |
+
npatches = feat_q.size(1)
|
41 |
+
l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
|
42 |
+
|
43 |
+
# diagonal entries are similarity between same features, and hence meaningless.
|
44 |
+
# just fill the diagonal with a very small number, exp(-10), which is almost zero
|
45 |
+
diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
|
46 |
+
l_neg_curbatch.masked_fill_(diagonal, -10.0)
|
47 |
+
l_neg = l_neg_curbatch.view(-1, npatches)
|
48 |
+
|
49 |
+
out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
|
50 |
+
|
51 |
+
loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
|
52 |
+
device=feat_q.device))
|
53 |
+
|
54 |
+
return loss
|
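# --- Usage sketch (added note, not part of the original file) ---
# PatchNCELoss expects L2-normalized query/key features of shape (total_patches, dim).
# The option fields below are the ones read by the loss above; the values are illustrative.
def _demo_patch_nce_loss():
    from argparse import Namespace
    opt = Namespace(nce_includes_all_negatives_from_minibatch=False,
                    batch_size=1, nce_T=0.07)
    criterion = PatchNCELoss(opt)
    feat_q = torch.nn.functional.normalize(torch.randn(256, 128), dim=1)
    feat_k = torch.nn.functional.normalize(torch.randn(256, 128), dim=1)
    return criterion(feat_q, feat_k).mean()   # per-patch cross-entropy, averaged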
models/spectralNormalization.py
ADDED
@@ -0,0 +1,64 @@
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import Tensor
|
4 |
+
from torch.nn import Parameter
|
5 |
+
from torch import nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
def l2normalize(vector, eps = 1e-15):
|
8 |
+
return vector/(vector.norm()+eps)
|
9 |
+
|
10 |
+
class SpectralNorm(nn.Module):
|
11 |
+
def __init__(self, module, name='weight', power_iterations=1):
|
12 |
+
super(SpectralNorm, self).__init__()
|
13 |
+
self.module = module
|
14 |
+
self.name = name
|
15 |
+
self.power_iterations = power_iterations
|
16 |
+
if not self._made_params():
|
17 |
+
self._make_params()
|
18 |
+
|
19 |
+
def _update_u_v(self):
|
20 |
+
u = getattr(self.module, self.name + "_u")
|
21 |
+
v = getattr(self.module, self.name + "_v")
|
22 |
+
w = getattr(self.module, self.name + "_bar")
|
23 |
+
|
24 |
+
height = w.data.shape[0]
|
25 |
+
for _ in range(self.power_iterations):
|
26 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
|
27 |
+
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
|
28 |
+
|
29 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
30 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
31 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
32 |
+
|
33 |
+
def _made_params(self):
|
34 |
+
try:
|
35 |
+
u = getattr(self.module, self.name + "_u")
|
36 |
+
v = getattr(self.module, self.name + "_v")
|
37 |
+
w = getattr(self.module, self.name + "_bar")
|
38 |
+
return True
|
39 |
+
except AttributeError:
|
40 |
+
return False
|
41 |
+
|
42 |
+
|
43 |
+
def _make_params(self):
|
44 |
+
w = getattr(self.module, self.name)
|
45 |
+
|
46 |
+
height = w.data.shape[0]
|
47 |
+
width = w.view(height, -1).data.shape[1]
|
48 |
+
|
49 |
+
u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
50 |
+
v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
51 |
+
u.data = l2normalize(u.data)
|
52 |
+
v.data = l2normalize(v.data)
|
53 |
+
w_bar = Parameter(w.data)
|
54 |
+
|
55 |
+
del self.module._parameters[self.name]
|
56 |
+
|
57 |
+
self.module.register_parameter(self.name + "_u", u)
|
58 |
+
self.module.register_parameter(self.name + "_v", v)
|
59 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
60 |
+
|
61 |
+
|
62 |
+
def forward(self, *args):
|
63 |
+
self._update_u_v()
|
64 |
+
return self.module.forward(*args)
|
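# --- Usage sketch (added note, not part of the original file) ---
# SpectralNorm wraps an existing module and refreshes the estimate of its weight's
# spectral norm with one power iteration per forward call, which is how the
# discriminator layers in models/networks.py use it.
def _demo_spectral_norm():
    conv = SpectralNorm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1))
    x = torch.randn(1, 64, 32, 32)
    return conv(x).shape   # (1, 128, 16, 16): the wrapped conv keeps its usual interface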
models/stylegan_networks.py
ADDED
@@ -0,0 +1,914 @@
|
1 |
+
"""
|
2 |
+
The network architectures are based on a PyTorch implementation of StyleGAN2.
|
3 |
+
Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
|
4 |
+
Original StyleGAN2 repo: https://github.com/NVlabs/stylegan2
|
5 |
+
We use this network architecture for our single-image training setting.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import math
|
9 |
+
import numpy as np
|
10 |
+
import random
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
from torch.nn import functional as F
|
15 |
+
|
16 |
+
|
17 |
+
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
|
18 |
+
return F.leaky_relu(input + bias, negative_slope) * scale
|
19 |
+
|
20 |
+
|
21 |
+
class FusedLeakyReLU(nn.Module):
|
22 |
+
def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
|
23 |
+
super().__init__()
|
24 |
+
self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
|
25 |
+
self.negative_slope = negative_slope
|
26 |
+
self.scale = scale
|
27 |
+
|
28 |
+
def forward(self, input):
|
29 |
+
# print("FusedLeakyReLU: ", input.abs().mean())
|
30 |
+
out = fused_leaky_relu(input, self.bias,
|
31 |
+
self.negative_slope,
|
32 |
+
self.scale)
|
33 |
+
# print("FusedLeakyReLU: ", out.abs().mean())
|
34 |
+
return out
|
35 |
+
|
36 |
+
|
37 |
+
def upfirdn2d_native(
|
38 |
+
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
|
39 |
+
):
|
40 |
+
_, minor, in_h, in_w = input.shape
|
41 |
+
kernel_h, kernel_w = kernel.shape
|
42 |
+
|
43 |
+
out = input.view(-1, minor, in_h, 1, in_w, 1)
|
44 |
+
out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
|
45 |
+
out = out.view(-1, minor, in_h * up_y, in_w * up_x)
|
46 |
+
|
47 |
+
out = F.pad(
|
48 |
+
out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
|
49 |
+
)
|
50 |
+
out = out[
|
51 |
+
:,
|
52 |
+
:,
|
53 |
+
max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
|
54 |
+
max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
|
55 |
+
]
|
56 |
+
|
57 |
+
# out = out.permute(0, 3, 1, 2)
|
58 |
+
out = out.reshape(
|
59 |
+
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
|
60 |
+
)
|
61 |
+
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
62 |
+
out = F.conv2d(out, w)
|
63 |
+
out = out.reshape(
|
64 |
+
-1,
|
65 |
+
minor,
|
66 |
+
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
67 |
+
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
68 |
+
)
|
69 |
+
# out = out.permute(0, 2, 3, 1)
|
70 |
+
|
71 |
+
return out[:, :, ::down_y, ::down_x]
|
72 |
+
|
73 |
+
|
74 |
+
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
|
75 |
+
return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
|
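# --- Usage sketch (added note, not part of the original file) ---
# upfirdn2d fuses upsample -> FIR filter -> downsample; with up=down=1 it reduces to a
# plain blur. The [1, 3, 3, 1] kernel is the separable binomial filter used throughout
# this file; the (2, 1) padding shown here keeps a 32x32 input at 32x32.
def _demo_upfirdn2d_blur():
    kernel = make_kernel([1, 3, 3, 1])                   # normalized separable blur kernel
    x = torch.randn(1, 8, 32, 32)
    y = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))   # output stays 32x32
    return y.shape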
76 |
+
|
77 |
+
|
78 |
+
class PixelNorm(nn.Module):
|
79 |
+
def __init__(self):
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
def forward(self, input):
|
83 |
+
return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
|
84 |
+
|
85 |
+
|
86 |
+
def make_kernel(k):
|
87 |
+
k = torch.tensor(k, dtype=torch.float32)
|
88 |
+
|
89 |
+
if len(k.shape) == 1:
|
90 |
+
k = k[None, :] * k[:, None]
|
91 |
+
|
92 |
+
k /= k.sum()
|
93 |
+
|
94 |
+
return k
|
95 |
+
|
96 |
+
|
97 |
+
class Upsample(nn.Module):
|
98 |
+
def __init__(self, kernel, factor=2):
|
99 |
+
super().__init__()
|
100 |
+
|
101 |
+
self.factor = factor
|
102 |
+
kernel = make_kernel(kernel) * (factor ** 2)
|
103 |
+
self.register_buffer('kernel', kernel)
|
104 |
+
|
105 |
+
p = kernel.shape[0] - factor
|
106 |
+
|
107 |
+
pad0 = (p + 1) // 2 + factor - 1
|
108 |
+
pad1 = p // 2
|
109 |
+
|
110 |
+
self.pad = (pad0, pad1)
|
111 |
+
|
112 |
+
def forward(self, input):
|
113 |
+
out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
|
114 |
+
|
115 |
+
return out
|
116 |
+
|
117 |
+
|
118 |
+
class Downsample(nn.Module):
|
119 |
+
def __init__(self, kernel, factor=2):
|
120 |
+
super().__init__()
|
121 |
+
|
122 |
+
self.factor = factor
|
123 |
+
kernel = make_kernel(kernel)
|
124 |
+
self.register_buffer('kernel', kernel)
|
125 |
+
|
126 |
+
p = kernel.shape[0] - factor
|
127 |
+
|
128 |
+
pad0 = (p + 1) // 2
|
129 |
+
pad1 = p // 2
|
130 |
+
|
131 |
+
self.pad = (pad0, pad1)
|
132 |
+
|
133 |
+
def forward(self, input):
|
134 |
+
out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
|
135 |
+
|
136 |
+
return out
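A minimal shape check for the two resampling modules above, under the same import assumption (models.stylegan_networks on the path):
import torch
from models.stylegan_networks import Upsample, Downsample

x = torch.randn(2, 16, 64, 64)
print(Upsample([1, 3, 3, 1])(x).shape)     # torch.Size([2, 16, 128, 128])
print(Downsample([1, 3, 3, 1])(x).shape)   # torch.Size([2, 16, 32, 32])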
|
137 |
+
|
138 |
+
|
139 |
+
class Blur(nn.Module):
|
140 |
+
def __init__(self, kernel, pad, upsample_factor=1):
|
141 |
+
super().__init__()
|
142 |
+
|
143 |
+
kernel = make_kernel(kernel)
|
144 |
+
|
145 |
+
if upsample_factor > 1:
|
146 |
+
kernel = kernel * (upsample_factor ** 2)
|
147 |
+
|
148 |
+
self.register_buffer('kernel', kernel)
|
149 |
+
|
150 |
+
self.pad = pad
|
151 |
+
|
152 |
+
def forward(self, input):
|
153 |
+
out = upfirdn2d(input, self.kernel, pad=self.pad)
|
154 |
+
|
155 |
+
return out
|
156 |
+
|
157 |
+
|
158 |
+
class EqualConv2d(nn.Module):
|
159 |
+
def __init__(
|
160 |
+
self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
|
161 |
+
):
|
162 |
+
super().__init__()
|
163 |
+
|
164 |
+
self.weight = nn.Parameter(
|
165 |
+
torch.randn(out_channel, in_channel, kernel_size, kernel_size)
|
166 |
+
)
|
167 |
+
self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
|
168 |
+
|
169 |
+
self.stride = stride
|
170 |
+
self.padding = padding
|
171 |
+
|
172 |
+
if bias:
|
173 |
+
self.bias = nn.Parameter(torch.zeros(out_channel))
|
174 |
+
|
175 |
+
else:
|
176 |
+
self.bias = None
|
177 |
+
|
178 |
+
def forward(self, input):
|
179 |
+
# print("Before EqualConv2d: ", input.abs().mean())
|
180 |
+
out = F.conv2d(
|
181 |
+
input,
|
182 |
+
self.weight * self.scale,
|
183 |
+
bias=self.bias,
|
184 |
+
stride=self.stride,
|
185 |
+
padding=self.padding,
|
186 |
+
)
|
187 |
+
# print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
|
188 |
+
|
189 |
+
return out
|
190 |
+
|
191 |
+
def __repr__(self):
|
192 |
+
return (
|
193 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
|
194 |
+
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
|
195 |
+
)
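EqualConv2d keeps its weights at unit variance and applies the He-style constant at forward time (the equalized learning-rate trick). A brief sketch, again assuming the module imports as models.stylegan_networks:
import math
import torch
from models.stylegan_networks import EqualConv2d

conv = EqualConv2d(64, 128, kernel_size=3, padding=1)
print(conv.scale == 1.0 / math.sqrt(64 * 3 ** 2))   # True: the init constant is folded in at runtime
print(conv(torch.randn(1, 64, 32, 32)).shape)       # torch.Size([1, 128, 32, 32])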
|
196 |
+
|
197 |
+
|
198 |
+
class EqualLinear(nn.Module):
|
199 |
+
def __init__(
|
200 |
+
self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
|
204 |
+
self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
|
205 |
+
|
206 |
+
if bias:
|
207 |
+
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
|
208 |
+
|
209 |
+
else:
|
210 |
+
self.bias = None
|
211 |
+
|
212 |
+
self.activation = activation
|
213 |
+
|
214 |
+
self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
|
215 |
+
self.lr_mul = lr_mul
|
216 |
+
|
217 |
+
def forward(self, input):
|
218 |
+
if self.activation:
|
219 |
+
out = F.linear(input, self.weight * self.scale)
|
220 |
+
out = fused_leaky_relu(out, self.bias * self.lr_mul)
|
221 |
+
|
222 |
+
else:
|
223 |
+
out = F.linear(
|
224 |
+
input, self.weight * self.scale, bias=self.bias * self.lr_mul
|
225 |
+
)
|
226 |
+
|
227 |
+
return out
|
228 |
+
|
229 |
+
def __repr__(self):
|
230 |
+
return (
|
231 |
+
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
|
232 |
+
)
|
233 |
+
|
234 |
+
|
235 |
+
class ScaledLeakyReLU(nn.Module):
|
236 |
+
def __init__(self, negative_slope=0.2):
|
237 |
+
super().__init__()
|
238 |
+
|
239 |
+
self.negative_slope = negative_slope
|
240 |
+
|
241 |
+
def forward(self, input):
|
242 |
+
out = F.leaky_relu(input, negative_slope=self.negative_slope)
|
243 |
+
|
244 |
+
return out * math.sqrt(2)
|
245 |
+
|
246 |
+
|
247 |
+
class ModulatedConv2d(nn.Module):
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
in_channel,
|
251 |
+
out_channel,
|
252 |
+
kernel_size,
|
253 |
+
style_dim,
|
254 |
+
demodulate=True,
|
255 |
+
upsample=False,
|
256 |
+
downsample=False,
|
257 |
+
blur_kernel=[1, 3, 3, 1],
|
258 |
+
):
|
259 |
+
super().__init__()
|
260 |
+
|
261 |
+
self.eps = 1e-8
|
262 |
+
self.kernel_size = kernel_size
|
263 |
+
self.in_channel = in_channel
|
264 |
+
self.out_channel = out_channel
|
265 |
+
self.upsample = upsample
|
266 |
+
self.downsample = downsample
|
267 |
+
|
268 |
+
if upsample:
|
269 |
+
factor = 2
|
270 |
+
p = (len(blur_kernel) - factor) - (kernel_size - 1)
|
271 |
+
pad0 = (p + 1) // 2 + factor - 1
|
272 |
+
pad1 = p // 2 + 1
|
273 |
+
|
274 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
|
275 |
+
|
276 |
+
if downsample:
|
277 |
+
factor = 2
|
278 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
279 |
+
pad0 = (p + 1) // 2
|
280 |
+
pad1 = p // 2
|
281 |
+
|
282 |
+
self.blur = Blur(blur_kernel, pad=(pad0, pad1))
|
283 |
+
|
284 |
+
fan_in = in_channel * kernel_size ** 2
|
285 |
+
self.scale = math.sqrt(1) / math.sqrt(fan_in)
|
286 |
+
self.padding = kernel_size // 2
|
287 |
+
|
288 |
+
self.weight = nn.Parameter(
|
289 |
+
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
|
290 |
+
)
|
291 |
+
|
292 |
+
if style_dim is not None and style_dim > 0:
|
293 |
+
self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
|
294 |
+
|
295 |
+
self.demodulate = demodulate
|
296 |
+
|
297 |
+
def __repr__(self):
|
298 |
+
return (
|
299 |
+
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
|
300 |
+
f'upsample={self.upsample}, downsample={self.downsample})'
|
301 |
+
)
|
302 |
+
|
303 |
+
def forward(self, input, style):
|
304 |
+
batch, in_channel, height, width = input.shape
|
305 |
+
|
306 |
+
if style is not None:
|
307 |
+
style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
|
308 |
+
else:
|
309 |
+
style = torch.ones(batch, 1, in_channel, 1, 1).cuda()
|
310 |
+
weight = self.scale * self.weight * style
|
311 |
+
|
312 |
+
if self.demodulate:
|
313 |
+
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
|
314 |
+
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
|
315 |
+
|
316 |
+
weight = weight.view(
|
317 |
+
batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
318 |
+
)
|
319 |
+
|
320 |
+
if self.upsample:
|
321 |
+
input = input.view(1, batch * in_channel, height, width)
|
322 |
+
weight = weight.view(
|
323 |
+
batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
|
324 |
+
)
|
325 |
+
weight = weight.transpose(1, 2).reshape(
|
326 |
+
batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
|
327 |
+
)
|
328 |
+
out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
|
329 |
+
_, _, height, width = out.shape
|
330 |
+
out = out.view(batch, self.out_channel, height, width)
|
331 |
+
out = self.blur(out)
|
332 |
+
|
333 |
+
elif self.downsample:
|
334 |
+
input = self.blur(input)
|
335 |
+
_, _, height, width = input.shape
|
336 |
+
input = input.view(1, batch * in_channel, height, width)
|
337 |
+
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
|
338 |
+
_, _, height, width = out.shape
|
339 |
+
out = out.view(batch, self.out_channel, height, width)
|
340 |
+
|
341 |
+
else:
|
342 |
+
input = input.view(1, batch * in_channel, height, width)
|
343 |
+
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
|
344 |
+
_, _, height, width = out.shape
|
345 |
+
out = out.view(batch, self.out_channel, height, width)
|
346 |
+
|
347 |
+
return out
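A usage sketch of the style-modulated convolution above: weights are scaled per sample by the style vector and then demodulated back to unit norm. This assumes the models.stylegan_networks import path and passes an explicit style, since the style=None branch falls back to .cuda().
import torch
from models.stylegan_networks import ModulatedConv2d

conv = ModulatedConv2d(in_channel=32, out_channel=64, kernel_size=3, style_dim=512)
x = torch.randn(2, 32, 16, 16)        # feature maps
w = torch.randn(2, 512)               # per-sample style vectors
print(conv(x, w).shape)               # torch.Size([2, 64, 16, 16])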
|
348 |
+
|
349 |
+
|
350 |
+
class NoiseInjection(nn.Module):
|
351 |
+
def __init__(self):
|
352 |
+
super().__init__()
|
353 |
+
|
354 |
+
self.weight = nn.Parameter(torch.zeros(1))
|
355 |
+
|
356 |
+
def forward(self, image, noise=None):
|
357 |
+
if noise is None:
|
358 |
+
batch, _, height, width = image.shape
|
359 |
+
noise = image.new_empty(batch, 1, height, width).normal_()
|
360 |
+
|
361 |
+
return image + self.weight * noise
|
362 |
+
|
363 |
+
|
364 |
+
class ConstantInput(nn.Module):
|
365 |
+
def __init__(self, channel, size=4):
|
366 |
+
super().__init__()
|
367 |
+
|
368 |
+
self.input = nn.Parameter(torch.randn(1, channel, size, size))
|
369 |
+
|
370 |
+
def forward(self, input):
|
371 |
+
batch = input.shape[0]
|
372 |
+
out = self.input.repeat(batch, 1, 1, 1)
|
373 |
+
|
374 |
+
return out
|
375 |
+
|
376 |
+
|
377 |
+
class StyledConv(nn.Module):
|
378 |
+
def __init__(
|
379 |
+
self,
|
380 |
+
in_channel,
|
381 |
+
out_channel,
|
382 |
+
kernel_size,
|
383 |
+
style_dim=None,
|
384 |
+
upsample=False,
|
385 |
+
blur_kernel=[1, 3, 3, 1],
|
386 |
+
demodulate=True,
|
387 |
+
inject_noise=True,
|
388 |
+
):
|
389 |
+
super().__init__()
|
390 |
+
|
391 |
+
self.inject_noise = inject_noise
|
392 |
+
self.conv = ModulatedConv2d(
|
393 |
+
in_channel,
|
394 |
+
out_channel,
|
395 |
+
kernel_size,
|
396 |
+
style_dim,
|
397 |
+
upsample=upsample,
|
398 |
+
blur_kernel=blur_kernel,
|
399 |
+
demodulate=demodulate,
|
400 |
+
)
|
401 |
+
|
402 |
+
self.noise = NoiseInjection()
|
403 |
+
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
|
404 |
+
# self.activate = ScaledLeakyReLU(0.2)
|
405 |
+
self.activate = FusedLeakyReLU(out_channel)
|
406 |
+
|
407 |
+
def forward(self, input, style=None, noise=None):
|
408 |
+
out = self.conv(input, style)
|
409 |
+
if self.inject_noise:
|
410 |
+
out = self.noise(out, noise=noise)
|
411 |
+
# out = out + self.bias
|
412 |
+
out = self.activate(out)
|
413 |
+
|
414 |
+
return out
|
415 |
+
|
416 |
+
|
417 |
+
class ToRGB(nn.Module):
|
418 |
+
def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
|
419 |
+
super().__init__()
|
420 |
+
|
421 |
+
if upsample:
|
422 |
+
self.upsample = Upsample(blur_kernel)
|
423 |
+
|
424 |
+
self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
|
425 |
+
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
|
426 |
+
|
427 |
+
def forward(self, input, style, skip=None):
|
428 |
+
out = self.conv(input, style)
|
429 |
+
out = out + self.bias
|
430 |
+
|
431 |
+
if skip is not None:
|
432 |
+
skip = self.upsample(skip)
|
433 |
+
|
434 |
+
out = out + skip
|
435 |
+
|
436 |
+
return out
|
437 |
+
|
438 |
+
|
439 |
+
class Generator(nn.Module):
|
440 |
+
def __init__(
|
441 |
+
self,
|
442 |
+
size,
|
443 |
+
style_dim,
|
444 |
+
n_mlp,
|
445 |
+
channel_multiplier=2,
|
446 |
+
blur_kernel=[1, 3, 3, 1],
|
447 |
+
lr_mlp=0.01,
|
448 |
+
):
|
449 |
+
super().__init__()
|
450 |
+
|
451 |
+
self.size = size
|
452 |
+
|
453 |
+
self.style_dim = style_dim
|
454 |
+
|
455 |
+
layers = [PixelNorm()]
|
456 |
+
|
457 |
+
for i in range(n_mlp):
|
458 |
+
layers.append(
|
459 |
+
EqualLinear(
|
460 |
+
style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
|
461 |
+
)
|
462 |
+
)
|
463 |
+
|
464 |
+
self.style = nn.Sequential(*layers)
|
465 |
+
|
466 |
+
self.channels = {
|
467 |
+
4: 512,
|
468 |
+
8: 512,
|
469 |
+
16: 512,
|
470 |
+
32: 512,
|
471 |
+
64: 256 * channel_multiplier,
|
472 |
+
128: 128 * channel_multiplier,
|
473 |
+
256: 64 * channel_multiplier,
|
474 |
+
512: 32 * channel_multiplier,
|
475 |
+
1024: 16 * channel_multiplier,
|
476 |
+
}
|
477 |
+
|
478 |
+
self.input = ConstantInput(self.channels[4])
|
479 |
+
self.conv1 = StyledConv(
|
480 |
+
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
|
481 |
+
)
|
482 |
+
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
|
483 |
+
|
484 |
+
self.log_size = int(math.log(size, 2))
|
485 |
+
self.num_layers = (self.log_size - 2) * 2 + 1
|
486 |
+
|
487 |
+
self.convs = nn.ModuleList()
|
488 |
+
self.upsamples = nn.ModuleList()
|
489 |
+
self.to_rgbs = nn.ModuleList()
|
490 |
+
self.noises = nn.Module()
|
491 |
+
|
492 |
+
in_channel = self.channels[4]
|
493 |
+
|
494 |
+
for layer_idx in range(self.num_layers):
|
495 |
+
res = (layer_idx + 5) // 2
|
496 |
+
shape = [1, 1, 2 ** res, 2 ** res]
|
497 |
+
self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
|
498 |
+
|
499 |
+
for i in range(3, self.log_size + 1):
|
500 |
+
out_channel = self.channels[2 ** i]
|
501 |
+
|
502 |
+
self.convs.append(
|
503 |
+
StyledConv(
|
504 |
+
in_channel,
|
505 |
+
out_channel,
|
506 |
+
3,
|
507 |
+
style_dim,
|
508 |
+
upsample=True,
|
509 |
+
blur_kernel=blur_kernel,
|
510 |
+
)
|
511 |
+
)
|
512 |
+
|
513 |
+
self.convs.append(
|
514 |
+
StyledConv(
|
515 |
+
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
|
516 |
+
)
|
517 |
+
)
|
518 |
+
|
519 |
+
self.to_rgbs.append(ToRGB(out_channel, style_dim))
|
520 |
+
|
521 |
+
in_channel = out_channel
|
522 |
+
|
523 |
+
self.n_latent = self.log_size * 2 - 2
|
524 |
+
|
525 |
+
def make_noise(self):
|
526 |
+
device = self.input.input.device
|
527 |
+
|
528 |
+
noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
|
529 |
+
|
530 |
+
for i in range(3, self.log_size + 1):
|
531 |
+
for _ in range(2):
|
532 |
+
noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
|
533 |
+
|
534 |
+
return noises
|
535 |
+
|
536 |
+
def mean_latent(self, n_latent):
|
537 |
+
latent_in = torch.randn(
|
538 |
+
n_latent, self.style_dim, device=self.input.input.device
|
539 |
+
)
|
540 |
+
latent = self.style(latent_in).mean(0, keepdim=True)
|
541 |
+
|
542 |
+
return latent
|
543 |
+
|
544 |
+
def get_latent(self, input):
|
545 |
+
return self.style(input)
|
546 |
+
|
547 |
+
def forward(
|
548 |
+
self,
|
549 |
+
styles,
|
550 |
+
return_latents=False,
|
551 |
+
inject_index=None,
|
552 |
+
truncation=1,
|
553 |
+
truncation_latent=None,
|
554 |
+
input_is_latent=False,
|
555 |
+
noise=None,
|
556 |
+
randomize_noise=True,
|
557 |
+
):
|
558 |
+
if not input_is_latent:
|
559 |
+
styles = [self.style(s) for s in styles]
|
560 |
+
|
561 |
+
if noise is None:
|
562 |
+
if randomize_noise:
|
563 |
+
noise = [None] * self.num_layers
|
564 |
+
else:
|
565 |
+
noise = [
|
566 |
+
getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
|
567 |
+
]
|
568 |
+
|
569 |
+
if truncation < 1:
|
570 |
+
style_t = []
|
571 |
+
|
572 |
+
for style in styles:
|
573 |
+
style_t.append(
|
574 |
+
truncation_latent + truncation * (style - truncation_latent)
|
575 |
+
)
|
576 |
+
|
577 |
+
styles = style_t
|
578 |
+
|
579 |
+
if len(styles) < 2:
|
580 |
+
inject_index = self.n_latent
|
581 |
+
|
582 |
+
if len(styles[0].shape) < 3:
|
583 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
584 |
+
|
585 |
+
else:
|
586 |
+
latent = styles[0]
|
587 |
+
|
588 |
+
else:
|
589 |
+
if inject_index is None:
|
590 |
+
inject_index = random.randint(1, self.n_latent - 1)
|
591 |
+
|
592 |
+
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
|
593 |
+
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
|
594 |
+
|
595 |
+
latent = torch.cat([latent, latent2], 1)
|
596 |
+
|
597 |
+
out = self.input(latent)
|
598 |
+
out = self.conv1(out, latent[:, 0], noise=noise[0])
|
599 |
+
|
600 |
+
skip = self.to_rgb1(out, latent[:, 1])
|
601 |
+
|
602 |
+
i = 1
|
603 |
+
for conv1, conv2, noise1, noise2, to_rgb in zip(
|
604 |
+
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
|
605 |
+
):
|
606 |
+
out = conv1(out, latent[:, i], noise=noise1)
|
607 |
+
out = conv2(out, latent[:, i + 1], noise=noise2)
|
608 |
+
skip = to_rgb(out, latent[:, i + 2], skip)
|
609 |
+
|
610 |
+
i += 2
|
611 |
+
|
612 |
+
image = skip
|
613 |
+
|
614 |
+
if return_latents:
|
615 |
+
return image, latent
|
616 |
+
|
617 |
+
else:
|
618 |
+
return image, None
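A minimal sketch of sampling from the unconditional StyleGAN2 Generator defined above (this is the standalone StyleGAN2 generator, not the CWR translation generator); sizes are deliberately small, and the repository root is assumed to be on PYTHONPATH:
import torch
from models.stylegan_networks import Generator

g = Generator(size=64, style_dim=512, n_mlp=4)
z = torch.randn(2, 512)               # latent codes; styles are passed as a list
img, _ = g([z])                       # returns (image, latents or None)
print(img.shape)                      # torch.Size([2, 3, 64, 64])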
|
619 |
+
|
620 |
+
|
621 |
+
class ConvLayer(nn.Sequential):
|
622 |
+
def __init__(
|
623 |
+
self,
|
624 |
+
in_channel,
|
625 |
+
out_channel,
|
626 |
+
kernel_size,
|
627 |
+
downsample=False,
|
628 |
+
blur_kernel=[1, 3, 3, 1],
|
629 |
+
bias=True,
|
630 |
+
activate=True,
|
631 |
+
):
|
632 |
+
layers = []
|
633 |
+
|
634 |
+
if downsample:
|
635 |
+
factor = 2
|
636 |
+
p = (len(blur_kernel) - factor) + (kernel_size - 1)
|
637 |
+
pad0 = (p + 1) // 2
|
638 |
+
pad1 = p // 2
|
639 |
+
|
640 |
+
layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
|
641 |
+
|
642 |
+
stride = 2
|
643 |
+
self.padding = 0
|
644 |
+
|
645 |
+
else:
|
646 |
+
stride = 1
|
647 |
+
self.padding = kernel_size // 2
|
648 |
+
|
649 |
+
layers.append(
|
650 |
+
EqualConv2d(
|
651 |
+
in_channel,
|
652 |
+
out_channel,
|
653 |
+
kernel_size,
|
654 |
+
padding=self.padding,
|
655 |
+
stride=stride,
|
656 |
+
bias=bias and not activate,
|
657 |
+
)
|
658 |
+
)
|
659 |
+
|
660 |
+
if activate:
|
661 |
+
if bias:
|
662 |
+
layers.append(FusedLeakyReLU(out_channel))
|
663 |
+
|
664 |
+
else:
|
665 |
+
layers.append(ScaledLeakyReLU(0.2))
|
666 |
+
|
667 |
+
super().__init__(*layers)
|
668 |
+
|
669 |
+
|
670 |
+
class ResBlock(nn.Module):
|
671 |
+
def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
|
672 |
+
super().__init__()
|
673 |
+
|
674 |
+
self.skip_gain = skip_gain
|
675 |
+
self.conv1 = ConvLayer(in_channel, in_channel, 3)
|
676 |
+
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
|
677 |
+
|
678 |
+
if in_channel != out_channel or downsample:
|
679 |
+
self.skip = ConvLayer(
|
680 |
+
in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
|
681 |
+
)
|
682 |
+
else:
|
683 |
+
self.skip = nn.Identity()
|
684 |
+
|
685 |
+
def forward(self, input):
|
686 |
+
out = self.conv1(input)
|
687 |
+
out = self.conv2(out)
|
688 |
+
|
689 |
+
skip = self.skip(input)
|
690 |
+
out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
|
691 |
+
|
692 |
+
return out
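In the residual block above, the sum of the two branches is divided by sqrt(skip_gain^2 + 1) so that adding them keeps activation variance roughly constant. A quick shape sketch under the same import assumption:
import torch
from models.stylegan_networks import ResBlock

block = ResBlock(in_channel=64, out_channel=128, downsample=True)
print(block(torch.randn(2, 64, 32, 32)).shape)   # torch.Size([2, 128, 16, 16])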
|
693 |
+
|
694 |
+
|
695 |
+
class StyleGAN2Discriminator(nn.Module):
|
696 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
|
697 |
+
super().__init__()
|
698 |
+
self.opt = opt
|
699 |
+
self.stddev_group = 16
|
700 |
+
if size is None:
|
701 |
+
size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
702 |
+
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
703 |
+
size = 2 ** int(np.log2(self.opt.D_patch_size))
|
704 |
+
|
705 |
+
blur_kernel = [1, 3, 3, 1]
|
706 |
+
channel_multiplier = ndf / 64
|
707 |
+
channels = {
|
708 |
+
4: min(384, int(4096 * channel_multiplier)),
|
709 |
+
8: min(384, int(2048 * channel_multiplier)),
|
710 |
+
16: min(384, int(1024 * channel_multiplier)),
|
711 |
+
32: min(384, int(512 * channel_multiplier)),
|
712 |
+
64: int(256 * channel_multiplier),
|
713 |
+
128: int(128 * channel_multiplier),
|
714 |
+
256: int(64 * channel_multiplier),
|
715 |
+
512: int(32 * channel_multiplier),
|
716 |
+
1024: int(16 * channel_multiplier),
|
717 |
+
}
|
718 |
+
|
719 |
+
convs = [ConvLayer(3, channels[size], 1)]
|
720 |
+
|
721 |
+
log_size = int(math.log(size, 2))
|
722 |
+
|
723 |
+
in_channel = channels[size]
|
724 |
+
|
725 |
+
if "smallpatch" in self.opt.netD:
|
726 |
+
final_res_log2 = 4
|
727 |
+
elif "patch" in self.opt.netD:
|
728 |
+
final_res_log2 = 3
|
729 |
+
else:
|
730 |
+
final_res_log2 = 2
|
731 |
+
|
732 |
+
for i in range(log_size, final_res_log2, -1):
|
733 |
+
out_channel = channels[2 ** (i - 1)]
|
734 |
+
|
735 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel))
|
736 |
+
|
737 |
+
in_channel = out_channel
|
738 |
+
|
739 |
+
self.convs = nn.Sequential(*convs)
|
740 |
+
|
741 |
+
if False and "tile" in self.opt.netD:
|
742 |
+
in_channel += 1
|
743 |
+
self.final_conv = ConvLayer(in_channel, channels[4], 3)
|
744 |
+
if "patch" in self.opt.netD:
|
745 |
+
self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
|
746 |
+
else:
|
747 |
+
self.final_linear = nn.Sequential(
|
748 |
+
EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
|
749 |
+
EqualLinear(channels[4], 1),
|
750 |
+
)
|
751 |
+
|
752 |
+
def forward(self, input, get_minibatch_features=False):
|
753 |
+
if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
|
754 |
+
h, w = input.size(2), input.size(3)
|
755 |
+
y = torch.randint(h - self.opt.D_patch_size, ())
|
756 |
+
x = torch.randint(w - self.opt.D_patch_size, ())
|
757 |
+
input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
|
758 |
+
out = input
|
759 |
+
for i, conv in enumerate(self.convs):
|
760 |
+
out = conv(out)
|
761 |
+
# print(i, out.abs().mean())
|
762 |
+
# out = self.convs(input)
|
763 |
+
|
764 |
+
batch, channel, height, width = out.shape
|
765 |
+
|
766 |
+
if False and "tile" in self.opt.netD:
|
767 |
+
group = min(batch, self.stddev_group)
|
768 |
+
stddev = out.view(
|
769 |
+
group, -1, 1, channel // 1, height, width
|
770 |
+
)
|
771 |
+
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
|
772 |
+
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
|
773 |
+
stddev = stddev.repeat(group, 1, height, width)
|
774 |
+
out = torch.cat([out, stddev], 1)
|
775 |
+
|
776 |
+
out = self.final_conv(out)
|
777 |
+
# print(out.abs().mean())
|
778 |
+
|
779 |
+
if "patch" not in self.opt.netD:
|
780 |
+
out = out.view(batch, -1)
|
781 |
+
out = self.final_linear(out)
|
782 |
+
|
783 |
+
return out
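A sketch of instantiating the discriminator above; the opt namespace only needs the fields this class actually reads (netD, load_size, crop_size, D_patch_size), filled here with illustrative values:
import torch
from argparse import Namespace
from models.stylegan_networks import StyleGAN2Discriminator

opt = Namespace(netD='stylegan2', load_size=64, crop_size=64, D_patch_size=None)
netD = StyleGAN2Discriminator(input_nc=3, ndf=64, opt=opt)
print(netD(torch.randn(2, 3, 64, 64)).shape)   # torch.Size([2, 1]): one realness score per image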
|
784 |
+
|
785 |
+
|
786 |
+
class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
|
787 |
+
def forward(self, input):
|
788 |
+
B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
|
789 |
+
size = self.opt.D_patch_size
|
790 |
+
Y = H // size
|
791 |
+
X = W // size
|
792 |
+
input = input.view(B, C, Y, size, X, size)
|
793 |
+
input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
|
794 |
+
return super().forward(input)
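The tile discriminator simply folds non-overlapping D_patch_size crops into the batch dimension before scoring them; the reshape can be checked in isolation with plain tensors (sizes below are illustrative):
import torch

B, C, H, W, size = 2, 3, 8, 8, 4
x = torch.randn(B, C, H, W)
Y, X = H // size, W // size
patches = x.view(B, C, Y, size, X, size)
patches = patches.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
print(patches.shape)   # torch.Size([8, 3, 4, 4]): B*Y*X patches of 4x4 pixels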
|
795 |
+
|
796 |
+
|
797 |
+
class StyleGAN2Encoder(nn.Module):
|
798 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
799 |
+
super().__init__()
|
800 |
+
assert opt is not None
|
801 |
+
self.opt = opt
|
802 |
+
channel_multiplier = ngf / 32
|
803 |
+
channels = {
|
804 |
+
4: min(512, int(round(4096 * channel_multiplier))),
|
805 |
+
8: min(512, int(round(2048 * channel_multiplier))),
|
806 |
+
16: min(512, int(round(1024 * channel_multiplier))),
|
807 |
+
32: min(512, int(round(512 * channel_multiplier))),
|
808 |
+
64: int(round(256 * channel_multiplier)),
|
809 |
+
128: int(round(128 * channel_multiplier)),
|
810 |
+
256: int(round(64 * channel_multiplier)),
|
811 |
+
512: int(round(32 * channel_multiplier)),
|
812 |
+
1024: int(round(16 * channel_multiplier)),
|
813 |
+
}
|
814 |
+
|
815 |
+
blur_kernel = [1, 3, 3, 1]
|
816 |
+
|
817 |
+
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
|
818 |
+
convs = [nn.Identity(),
|
819 |
+
ConvLayer(3, channels[cur_res], 1)]
|
820 |
+
|
821 |
+
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
822 |
+
for i in range(num_downsampling):
|
823 |
+
in_channel = channels[cur_res]
|
824 |
+
out_channel = channels[cur_res // 2]
|
825 |
+
convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
|
826 |
+
cur_res = cur_res // 2
|
827 |
+
|
828 |
+
for i in range(n_blocks // 2):
|
829 |
+
n_channel = channels[cur_res]
|
830 |
+
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
831 |
+
|
832 |
+
self.convs = nn.Sequential(*convs)
|
833 |
+
|
834 |
+
def forward(self, input, layers=[], get_features=False):
|
835 |
+
feat = input
|
836 |
+
feats = []
|
837 |
+
if -1 in layers:
|
838 |
+
layers.append(len(self.convs) - 1)
|
839 |
+
for layer_id, layer in enumerate(self.convs):
|
840 |
+
feat = layer(feat)
|
841 |
+
# print(layer_id, " features ", feat.abs().mean())
|
842 |
+
if layer_id in layers:
|
843 |
+
feats.append(feat)
|
844 |
+
|
845 |
+
if get_features:
|
846 |
+
return feat, feats
|
847 |
+
else:
|
848 |
+
return feat
|
849 |
+
|
850 |
+
|
851 |
+
class StyleGAN2Decoder(nn.Module):
|
852 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
853 |
+
super().__init__()
|
854 |
+
assert opt is not None
|
855 |
+
self.opt = opt
|
856 |
+
|
857 |
+
blur_kernel = [1, 3, 3, 1]
|
858 |
+
|
859 |
+
channel_multiplier = ngf / 32
|
860 |
+
channels = {
|
861 |
+
4: min(512, int(round(4096 * channel_multiplier))),
|
862 |
+
8: min(512, int(round(2048 * channel_multiplier))),
|
863 |
+
16: min(512, int(round(1024 * channel_multiplier))),
|
864 |
+
32: min(512, int(round(512 * channel_multiplier))),
|
865 |
+
64: int(round(256 * channel_multiplier)),
|
866 |
+
128: int(round(128 * channel_multiplier)),
|
867 |
+
256: int(round(64 * channel_multiplier)),
|
868 |
+
512: int(round(32 * channel_multiplier)),
|
869 |
+
1024: int(round(16 * channel_multiplier)),
|
870 |
+
}
|
871 |
+
|
872 |
+
num_downsampling = self.opt.stylegan2_G_num_downsampling
|
873 |
+
cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
|
874 |
+
convs = []
|
875 |
+
|
876 |
+
for i in range(n_blocks // 2):
|
877 |
+
n_channel = channels[cur_res]
|
878 |
+
convs.append(ResBlock(n_channel, n_channel, downsample=False))
|
879 |
+
|
880 |
+
for i in range(num_downsampling):
|
881 |
+
in_channel = channels[cur_res]
|
882 |
+
out_channel = channels[cur_res * 2]
|
883 |
+
inject_noise = "small" not in self.opt.netG
|
884 |
+
convs.append(
|
885 |
+
StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
|
886 |
+
)
|
887 |
+
cur_res = cur_res * 2
|
888 |
+
|
889 |
+
convs.append(ConvLayer(channels[cur_res], 3, 1))
|
890 |
+
|
891 |
+
self.convs = nn.Sequential(*convs)
|
892 |
+
|
893 |
+
def forward(self, input):
|
894 |
+
return self.convs(input)
|
895 |
+
|
896 |
+
|
897 |
+
class StyleGAN2Generator(nn.Module):
|
898 |
+
def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
|
899 |
+
super().__init__()
|
900 |
+
self.opt = opt
|
901 |
+
self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
902 |
+
self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
|
903 |
+
|
904 |
+
def forward(self, input, layers=[], encode_only=False):
|
905 |
+
feat, feats = self.encoder(input, layers, True)
|
906 |
+
if encode_only:
|
907 |
+
return feats
|
908 |
+
else:
|
909 |
+
fake = self.decoder(feat)
|
910 |
+
|
911 |
+
if len(layers) > 0:
|
912 |
+
return fake, feats
|
913 |
+
else:
|
914 |
+
return fake
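A sketch of the encoder/decoder generator above with a minimal opt namespace (only the fields read here, with illustrative values). The encode_only path is shown because it runs on CPU; a full decode goes through ModulatedConv2d with style=None, which substitutes a constant style via .cuda() and therefore assumes a GPU.
import torch
from argparse import Namespace
from models.stylegan_networks import StyleGAN2Generator

opt = Namespace(load_size=64, crop_size=64, stylegan2_G_num_downsampling=1, netG='stylegan2')
net = StyleGAN2Generator(input_nc=3, output_nc=3, ngf=64, n_blocks=2, opt=opt)
feats = net(torch.randn(1, 3, 64, 64), layers=[0, 1, 2], encode_only=True)
print([tuple(f.shape) for f in feats])   # [(1, 3, 64, 64), (1, 512, 64, 64), (1, 512, 32, 32)]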
|
models/template_model.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
"""Model class template
|
2 |
+
|
3 |
+
This module provides a template for users to implement custom models.
|
4 |
+
You can specify '--model template' to use this model.
|
5 |
+
The class name should be consistent with both the filename and its model option.
|
6 |
+
The filename should be <model>_model.py
|
7 |
+
The class name should be <Model>Model
|
8 |
+
It implements a simple image-to-image translation baseline based on regression loss.
|
9 |
+
Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
|
10 |
+
min_<netG> ||netG(data_A) - data_B||_1
|
11 |
+
You need to implement the following functions:
|
12 |
+
<modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
|
13 |
+
<__init__>: Initialize this model class.
|
14 |
+
<set_input>: Unpack input data and perform data pre-processing.
|
15 |
+
<forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
|
16 |
+
<optimize_parameters>: Update network weights; it will be called in every training iteration.
|
17 |
+
"""
|
18 |
+
import torch
|
19 |
+
from .base_model import BaseModel
|
20 |
+
from . import networks
|
21 |
+
|
22 |
+
|
23 |
+
class TemplateModel(BaseModel):
|
24 |
+
@staticmethod
|
25 |
+
def modify_commandline_options(parser, is_train=True):
|
26 |
+
"""Add new model-specific options and rewrite default values for existing options.
|
27 |
+
|
28 |
+
Parameters:
|
29 |
+
parser -- the option parser
|
30 |
+
is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
the modified parser.
|
34 |
+
"""
|
35 |
+
parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
|
36 |
+
if is_train:
|
37 |
+
parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
|
38 |
+
|
39 |
+
return parser
|
40 |
+
|
41 |
+
def __init__(self, opt):
|
42 |
+
"""Initialize this model class.
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
opt -- training/test options
|
46 |
+
|
47 |
+
A few things can be done here.
|
48 |
+
- (required) call the initialization function of BaseModel
|
49 |
+
- define loss function, visualization images, model names, and optimizers
|
50 |
+
"""
|
51 |
+
BaseModel.__init__(self, opt) # call the initialization method of BaseModel
|
52 |
+
# specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
|
53 |
+
self.loss_names = ['loss_G']
|
54 |
+
# specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
|
55 |
+
self.visual_names = ['data_A', 'data_B', 'output']
|
56 |
+
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
|
57 |
+
# you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
|
58 |
+
self.model_names = ['G']
|
59 |
+
# define networks; you can use opt.isTrain to specify different behaviors for training and test.
|
60 |
+
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
|
61 |
+
if self.isTrain: # only defined during training time
|
62 |
+
# define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
|
63 |
+
# We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
|
64 |
+
self.criterionLoss = torch.nn.L1Loss()
|
65 |
+
# define and initialize optimizers. You can define one optimizer for each network.
|
66 |
+
# If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
67 |
+
self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
68 |
+
self.optimizers = [self.optimizer]
|
69 |
+
|
70 |
+
# Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
|
71 |
+
|
72 |
+
def set_input(self, input):
|
73 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
input: a dictionary that contains the data itself and its metadata information.
|
77 |
+
"""
|
78 |
+
AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
|
79 |
+
self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
|
80 |
+
self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
|
81 |
+
self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
|
82 |
+
|
83 |
+
def forward(self):
|
84 |
+
"""Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
|
85 |
+
self.output = self.netG(self.data_A) # generate output image given the input data_A
|
86 |
+
|
87 |
+
def backward(self):
|
88 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
89 |
+
# calculate the intermediate results if necessary; here self.output has been computed during function <forward>
|
90 |
+
# calculate loss given the input and intermediate results
|
91 |
+
self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
|
92 |
+
self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
|
93 |
+
|
94 |
+
def optimize_parameters(self):
|
95 |
+
"""Update network weights; it will be called in every training iteration."""
|
96 |
+
self.forward() # first call forward to calculate intermediate results
|
97 |
+
self.optimizer.zero_grad() # clear network G's existing gradients
|
98 |
+
self.backward() # calculate gradients for network G
|
99 |
+
self.optimizer.step() # update gradients for network G
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
1 |
+
torch>=1.6.0
|
2 |
+
torchvision>=0.7.0
|
3 |
+
dominate>=2.4.0
|
4 |
+
visdom>=0.1.8.8
|
5 |
+
packaging
|
6 |
+
GPUtil>=1.4.0
|
7 |
+
scipy
|
8 |
+
Pillow>=6.1.0
|
9 |
+
numpy>=1.16.4
|
test.py
ADDED
@@ -0,0 +1,11 @@
1 |
+
from ModelLoader import ModelLoader
|
2 |
+
import os
|
3 |
+
|
4 |
+
def main():
|
5 |
+
model = ModelLoader()
|
6 |
+
model.load()
|
7 |
+
# Test
|
8 |
+
sample = os.path.join(os.path.dirname(__file__), 'examples', 'rawimg.png')
|
9 |
+
model.inference(src=sample)
|
10 |
+
|
11 |
+
main()
|
util/__init__.py
ADDED
@@ -0,0 +1,2 @@
1 |
+
"""This package includes a miscellaneous collection of useful helper functions."""
|
2 |
+
from util import *
|
util/get_data.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
import requests
|
5 |
+
from warnings import warn
|
6 |
+
from zipfile import ZipFile
|
7 |
+
from bs4 import BeautifulSoup
|
8 |
+
from os.path import abspath, isdir, join, basename
|
9 |
+
|
10 |
+
|
11 |
+
class GetData(object):
|
12 |
+
"""A Python script for downloading CycleGAN or pix2pix datasets.
|
13 |
+
|
14 |
+
Parameters:
|
15 |
+
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
|
16 |
+
verbose (bool) -- If True, print additional information.
|
17 |
+
|
18 |
+
Examples:
|
19 |
+
>>> from util.get_data import GetData
|
20 |
+
>>> gd = GetData(technique='cyclegan')
|
21 |
+
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
|
22 |
+
|
23 |
+
Alternatively, you can use the bash scripts: 'scripts/download_pix2pix_model.sh'
|
24 |
+
and 'scripts/download_cyclegan_model.sh'.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(self, technique='cyclegan', verbose=True):
|
28 |
+
url_dict = {
|
29 |
+
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
|
30 |
+
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
|
31 |
+
}
|
32 |
+
self.url = url_dict.get(technique.lower())
|
33 |
+
self._verbose = verbose
|
34 |
+
|
35 |
+
def _print(self, text):
|
36 |
+
if self._verbose:
|
37 |
+
print(text)
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def _get_options(r):
|
41 |
+
soup = BeautifulSoup(r.text, 'lxml')
|
42 |
+
options = [h.text for h in soup.find_all('a', href=True)
|
43 |
+
if h.text.endswith(('.zip', 'tar.gz'))]
|
44 |
+
return options
|
45 |
+
|
46 |
+
def _present_options(self):
|
47 |
+
r = requests.get(self.url)
|
48 |
+
options = self._get_options(r)
|
49 |
+
print('Options:\n')
|
50 |
+
for i, o in enumerate(options):
|
51 |
+
print("{0}: {1}".format(i, o))
|
52 |
+
choice = input("\nPlease enter the number of the "
|
53 |
+
"dataset above you wish to download:")
|
54 |
+
return options[int(choice)]
|
55 |
+
|
56 |
+
def _download_data(self, dataset_url, save_path):
|
57 |
+
if not isdir(save_path):
|
58 |
+
os.makedirs(save_path)
|
59 |
+
|
60 |
+
base = basename(dataset_url)
|
61 |
+
temp_save_path = join(save_path, base)
|
62 |
+
|
63 |
+
with open(temp_save_path, "wb") as f:
|
64 |
+
r = requests.get(dataset_url)
|
65 |
+
f.write(r.content)
|
66 |
+
|
67 |
+
if base.endswith('.tar.gz'):
|
68 |
+
obj = tarfile.open(temp_save_path)
|
69 |
+
elif base.endswith('.zip'):
|
70 |
+
obj = ZipFile(temp_save_path, 'r')
|
71 |
+
else:
|
72 |
+
raise ValueError("Unknown File Type: {0}.".format(base))
|
73 |
+
|
74 |
+
self._print("Unpacking Data...")
|
75 |
+
obj.extractall(save_path)
|
76 |
+
obj.close()
|
77 |
+
os.remove(temp_save_path)
|
78 |
+
|
79 |
+
def get(self, save_path, dataset=None):
|
80 |
+
"""
|
81 |
+
|
82 |
+
Download a dataset.
|
83 |
+
|
84 |
+
Parameters:
|
85 |
+
save_path (str) -- A directory to save the data to.
|
86 |
+
dataset (str) -- (optional). A specific dataset to download.
|
87 |
+
Note: this must include the file extension.
|
88 |
+
If None, options will be presented for you
|
89 |
+
to choose from.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
save_path_full (str) -- the absolute path to the downloaded data.
|
93 |
+
|
94 |
+
"""
|
95 |
+
if dataset is None:
|
96 |
+
selected_dataset = self._present_options()
|
97 |
+
else:
|
98 |
+
selected_dataset = dataset
|
99 |
+
|
100 |
+
save_path_full = join(save_path, selected_dataset.split('.')[0])
|
101 |
+
|
102 |
+
if isdir(save_path_full):
|
103 |
+
warn("\n'{0}' already exists. Voiding Download.".format(
|
104 |
+
save_path_full))
|
105 |
+
else:
|
106 |
+
self._print('Downloading Data...')
|
107 |
+
url = "{0}/{1}".format(self.url, selected_dataset)
|
108 |
+
self._download_data(url, save_path=save_path)
|
109 |
+
|
110 |
+
return abspath(save_path_full)
|
util/html.py
ADDED
@@ -0,0 +1,86 @@
1 |
+
import dominate
|
2 |
+
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class HTML:
|
7 |
+
"""This HTML class allows us to save images and write texts into a single HTML file.
|
8 |
+
|
9 |
+
It consists of functions such as <add_header> (add a text header to the HTML file),
|
10 |
+
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
|
11 |
+
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, web_dir, title, refresh=0):
|
15 |
+
"""Initialize the HTML classes
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
|
19 |
+
title (str) -- the webpage name
|
20 |
+
refresh (int) -- how often the website refreshes itself; if 0, no refreshing
|
21 |
+
"""
|
22 |
+
self.title = title
|
23 |
+
self.web_dir = web_dir
|
24 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
25 |
+
if not os.path.exists(self.web_dir):
|
26 |
+
os.makedirs(self.web_dir)
|
27 |
+
if not os.path.exists(self.img_dir):
|
28 |
+
os.makedirs(self.img_dir)
|
29 |
+
|
30 |
+
self.doc = dominate.document(title=title)
|
31 |
+
if refresh > 0:
|
32 |
+
with self.doc.head:
|
33 |
+
meta(http_equiv="refresh", content=str(refresh))
|
34 |
+
|
35 |
+
def get_image_dir(self):
|
36 |
+
"""Return the directory that stores images"""
|
37 |
+
return self.img_dir
|
38 |
+
|
39 |
+
def add_header(self, text):
|
40 |
+
"""Insert a header to the HTML file
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
text (str) -- the header text
|
44 |
+
"""
|
45 |
+
with self.doc:
|
46 |
+
h3(text)
|
47 |
+
|
48 |
+
def add_images(self, ims, txts, links, width=400):
|
49 |
+
"""add images to the HTML file
|
50 |
+
|
51 |
+
Parameters:
|
52 |
+
ims (str list) -- a list of image paths
|
53 |
+
txts (str list) -- a list of image names shown on the website
|
54 |
+
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
|
55 |
+
"""
|
56 |
+
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
|
57 |
+
self.doc.add(self.t)
|
58 |
+
with self.t:
|
59 |
+
with tr():
|
60 |
+
for im, txt, link in zip(ims, txts, links):
|
61 |
+
with td(style="word-wrap: break-word;", halign="center", valign="top"):
|
62 |
+
with p():
|
63 |
+
with a(href=os.path.join('images', link)):
|
64 |
+
img(style="width:%dpx" % width, src=os.path.join('images', im))
|
65 |
+
br()
|
66 |
+
p(txt)
|
67 |
+
|
68 |
+
def save(self):
|
69 |
+
"""save the current content to the HMTL file"""
|
70 |
+
html_file = '%s/index.html' % self.web_dir
|
71 |
+
f = open(html_file, 'wt')
|
72 |
+
f.write(self.doc.render())
|
73 |
+
f.close()
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__': # we show an example usage here.
|
77 |
+
html = HTML('web/', 'test_html')
|
78 |
+
html.add_header('hello world')
|
79 |
+
|
80 |
+
ims, txts, links = [], [], []
|
81 |
+
for n in range(4):
|
82 |
+
ims.append('image_%d.png' % n)
|
83 |
+
txts.append('text_%d' % n)
|
84 |
+
links.append('image_%d.png' % n)
|
85 |
+
html.add_images(ims, txts, links)
|
86 |
+
html.save()
|
util/image_pool.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ImagePool():
|
6 |
+
"""This class implements an image buffer that stores previously generated images.
|
7 |
+
|
8 |
+
This buffer enables us to update discriminators using a history of generated images
|
9 |
+
rather than the ones produced by the latest generators.
|
10 |
+
"""
|
11 |
+
|
12 |
+
def __init__(self, pool_size):
|
13 |
+
"""Initialize the ImagePool class
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
|
17 |
+
"""
|
18 |
+
self.pool_size = pool_size
|
19 |
+
if self.pool_size > 0: # create an empty pool
|
20 |
+
self.num_imgs = 0
|
21 |
+
self.images = []
|
22 |
+
|
23 |
+
def query(self, images):
|
24 |
+
"""Return an image from the pool.
|
25 |
+
|
26 |
+
Parameters:
|
27 |
+
images: the latest generated images from the generator
|
28 |
+
|
29 |
+
Returns images from the buffer.
|
30 |
+
|
31 |
+
With probability 0.5, the buffer returns the current input images.
|
32 |
+
With probability 0.5, the buffer returns images previously stored in the buffer,
|
33 |
+
and inserts the current images into the buffer.
|
34 |
+
"""
|
35 |
+
if self.pool_size == 0: # if the buffer size is 0, do nothing
|
36 |
+
return images
|
37 |
+
return_images = []
|
38 |
+
for image in images:
|
39 |
+
image = torch.unsqueeze(image.data, 0)
|
40 |
+
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
|
41 |
+
self.num_imgs = self.num_imgs + 1
|
42 |
+
self.images.append(image)
|
43 |
+
return_images.append(image)
|
44 |
+
else:
|
45 |
+
p = random.uniform(0, 1)
|
46 |
+
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
|
47 |
+
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
|
48 |
+
tmp = self.images[random_id].clone()
|
49 |
+
self.images[random_id] = image
|
50 |
+
return_images.append(tmp)
|
51 |
+
else: # by another 50% chance, the buffer will return the current image
|
52 |
+
return_images.append(image)
|
53 |
+
return_images = torch.cat(return_images, 0) # collect all the images and return
|
54 |
+
return return_images
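A small usage sketch of the buffer above (shapes are illustrative); once the pool is full, roughly half of each returned batch is replayed history:
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
for _ in range(3):
    fake = torch.randn(4, 3, 64, 64)   # freshly generated batch
    mixed = pool.query(fake)           # mix of current and previously stored images
    print(mixed.shape)                 # torch.Size([4, 3, 64, 64])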
|
util/util.py
ADDED
@@ -0,0 +1,166 @@
1 |
+
"""This module contains simple helper functions """
|
2 |
+
from __future__ import print_function
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
import importlib
|
8 |
+
import argparse
|
9 |
+
from argparse import Namespace
|
10 |
+
import torchvision
|
11 |
+
|
12 |
+
|
13 |
+
def str2bool(v):
|
14 |
+
if isinstance(v, bool):
|
15 |
+
return v
|
16 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
17 |
+
return True
|
18 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
19 |
+
return False
|
20 |
+
else:
|
21 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
22 |
+
|
23 |
+
|
24 |
+
def copyconf(default_opt, **kwargs):
|
25 |
+
conf = Namespace(**vars(default_opt))
|
26 |
+
for key in kwargs:
|
27 |
+
setattr(conf, key, kwargs[key])
|
28 |
+
return conf
|
29 |
+
|
30 |
+
|
31 |
+
def find_class_in_module(target_cls_name, module):
|
32 |
+
target_cls_name = target_cls_name.replace('_', '').lower()
|
33 |
+
clslib = importlib.import_module(module)
|
34 |
+
cls = None
|
35 |
+
for name, clsobj in clslib.__dict__.items():
|
36 |
+
if name.lower() == target_cls_name:
|
37 |
+
cls = clsobj
|
38 |
+
|
39 |
+
assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
|
40 |
+
|
41 |
+
return cls
|
42 |
+
|
43 |
+
|
44 |
+
def tensor2im(input_image, imtype=np.uint8):
|
45 |
+
""""Converts a Tensor array into a numpy image array.
|
46 |
+
|
47 |
+
Parameters:
|
48 |
+
input_image (tensor) -- the input image tensor array
|
49 |
+
imtype (type) -- the desired type of the converted numpy array
|
50 |
+
"""
|
51 |
+
if not isinstance(input_image, np.ndarray):
|
52 |
+
if isinstance(input_image, torch.Tensor): # get the data from a variable
|
53 |
+
image_tensor = input_image.data
|
54 |
+
else:
|
55 |
+
return input_image
|
56 |
+
image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
|
57 |
+
if image_numpy.shape[0] == 1: # grayscale to RGB
|
58 |
+
image_numpy = np.tile(image_numpy, (3, 1, 1))
|
59 |
+
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
|
60 |
+
else: # if it is a numpy array, do nothing
|
61 |
+
image_numpy = input_image
|
62 |
+
return image_numpy.astype(imtype)
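A quick sketch of the conversion above: a CHW tensor in [-1, 1] (batch dimension first) comes back as an HWC uint8 array ready for PIL:
import torch
from util.util import tensor2im

t = torch.rand(1, 3, 8, 8) * 2 - 1     # fake network output in [-1, 1]
img = tensor2im(t)
print(img.shape, img.dtype)            # (8, 8, 3) uint8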
|
63 |
+
|
64 |
+
|
65 |
+
def diagnose_network(net, name='network'):
|
66 |
+
"""Calculate and print the mean of average absolute(gradients)
|
67 |
+
|
68 |
+
Parameters:
|
69 |
+
net (torch network) -- Torch network
|
70 |
+
name (str) -- the name of the network
|
71 |
+
"""
|
72 |
+
mean = 0.0
|
73 |
+
count = 0
|
74 |
+
for param in net.parameters():
|
75 |
+
if param.grad is not None:
|
76 |
+
mean += torch.mean(torch.abs(param.grad.data))
|
77 |
+
count += 1
|
78 |
+
if count > 0:
|
79 |
+
mean = mean / count
|
80 |
+
print(name)
|
81 |
+
print(mean)
|
82 |
+
|
83 |
+
|
84 |
+
def save_image(image_numpy, image_path, aspect_ratio=1.0):
|
85 |
+
"""Save a numpy image to the disk
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
image_numpy (numpy array) -- input numpy array
|
89 |
+
image_path (str) -- the path of the image
|
90 |
+
"""
|
91 |
+
|
92 |
+
image_pil = Image.fromarray(image_numpy)
|
93 |
+
h, w, _ = image_numpy.shape
|
94 |
+
|
95 |
+
if aspect_ratio is None:
|
96 |
+
pass
|
97 |
+
elif aspect_ratio > 1.0:
|
98 |
+
image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
|
99 |
+
elif aspect_ratio < 1.0:
|
100 |
+
image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
|
101 |
+
image_pil.save(image_path)
|
102 |
+
|
103 |
+
|
104 |
+
def print_numpy(x, val=True, shp=False):
|
105 |
+
"""Print the mean, min, max, median, std, and size of a numpy array
|
106 |
+
|
107 |
+
Parameters:
|
108 |
+
val (bool) -- if print the values of the numpy array
|
109 |
+
shp (bool) -- if print the shape of the numpy array
|
110 |
+
"""
|
111 |
+
x = x.astype(np.float64)
|
112 |
+
if shp:
|
113 |
+
print('shape,', x.shape)
|
114 |
+
if val:
|
115 |
+
x = x.flatten()
|
116 |
+
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
|
117 |
+
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
|
118 |
+
|
119 |
+
|
120 |
+
def mkdirs(paths):
|
121 |
+
"""create empty directories if they don't exist
|
122 |
+
|
123 |
+
Parameters:
|
124 |
+
paths (str list) -- a list of directory paths
|
125 |
+
"""
|
126 |
+
if isinstance(paths, list) and not isinstance(paths, str):
|
127 |
+
for path in paths:
|
128 |
+
mkdir(path)
|
129 |
+
else:
|
130 |
+
mkdir(paths)
|
131 |
+
|
132 |
+
|
133 |
+
def mkdir(path):
|
134 |
+
"""create a single empty directory if it didn't exist
|
135 |
+
|
136 |
+
Parameters:
|
137 |
+
path (str) -- a single directory path
|
138 |
+
"""
|
139 |
+
if not os.path.exists(path):
|
140 |
+
os.makedirs(path)
|
141 |
+
|
142 |
+
|
143 |
+
def correct_resize_label(t, size):
|
144 |
+
device = t.device
|
145 |
+
t = t.detach().cpu()
|
146 |
+
resized = []
|
147 |
+
for i in range(t.size(0)):
|
148 |
+
one_t = t[i, :1]
|
149 |
+
one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
|
150 |
+
one_np = one_np[:, :, 0]
|
151 |
+
one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
|
152 |
+
resized_t = torch.from_numpy(np.array(one_image)).long()
|
153 |
+
resized.append(resized_t)
|
154 |
+
return torch.stack(resized, dim=0).to(device)
|
155 |
+
|
156 |
+
|
157 |
+
def correct_resize(t, size, mode=Image.BICUBIC):
|
158 |
+
device = t.device
|
159 |
+
t = t.detach().cpu()
|
160 |
+
resized = []
|
161 |
+
for i in range(t.size(0)):
|
162 |
+
one_t = t[i:i + 1]
|
163 |
+
one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
|
164 |
+
resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
|
165 |
+
resized.append(resized_t)
|
166 |
+
return torch.stack(resized, dim=0).to(device)
|
util/visualizer.py
ADDED
@@ -0,0 +1,242 @@
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import ntpath
|
5 |
+
import time
|
6 |
+
from . import util, html
|
7 |
+
from subprocess import Popen, PIPE
|
8 |
+
|
9 |
+
if sys.version_info[0] == 2:
|
10 |
+
VisdomExceptionBase = Exception
|
11 |
+
else:
|
12 |
+
VisdomExceptionBase = ConnectionError
|
13 |
+
|
14 |
+
|
15 |
+
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
|
16 |
+
"""Save images to the disk.
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
|
20 |
+
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
|
21 |
+
image_path (str) -- the string is used to create image paths
|
22 |
+
aspect_ratio (float) -- the aspect ratio of saved images
|
23 |
+
width (int) -- the images will be resized to width x width
|
24 |
+
|
25 |
+
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
|
26 |
+
"""
|
27 |
+
image_dir = webpage.get_image_dir()
|
28 |
+
short_path = ntpath.basename(image_path[0])
|
29 |
+
name = os.path.splitext(short_path)[0]
|
30 |
+
|
31 |
+
webpage.add_header(name)
|
32 |
+
ims, txts, links = [], [], []
|
33 |
+
|
34 |
+
for label, im_data in visuals.items():
|
35 |
+
im = util.tensor2im(im_data)
|
36 |
+
image_name = '%s/%s.png' % (label, name)
|
37 |
+
os.makedirs(os.path.join(image_dir, label), exist_ok=True)
|
38 |
+
save_path = os.path.join(image_dir, image_name)
|
39 |
+
util.save_image(im, save_path, aspect_ratio=aspect_ratio)
|
40 |
+
ims.append(image_name)
|
41 |
+
txts.append(label)
|
42 |
+
links.append(image_name)
|
43 |
+
webpage.add_images(ims, txts, links, width=width)
|
44 |
+
|
45 |
+
|
46 |
+
class Visualizer():
|
47 |
+
"""This class includes several functions that can display/save images and print/save logging information.
|
48 |
+
|
49 |
+
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
|
50 |
+
"""
|
51 |
+
|
52 |
+
def __init__(self, opt):
|
53 |
+
"""Initialize the Visualizer class
|
54 |
+
|
55 |
+
Parameters:
|
56 |
+
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
57 |
+
Step 1: Cache the training/test options
|
58 |
+
Step 2: connect to a visdom server
|
59 |
+
Step 3: create an HTML object for saving HTML files
|
60 |
+
Step 4: create a logging file to store training losses
|
61 |
+
"""
|
62 |
+
self.opt = opt # cache the option
|
63 |
+
if opt.display_id is None:
|
64 |
+
self.display_id = np.random.randint(100000) * 10 # just a random display id
|
65 |
+
else:
|
66 |
+
self.display_id = opt.display_id
|
67 |
+
self.use_html = opt.isTrain and not opt.no_html
|
68 |
+
self.win_size = opt.display_winsize
|
69 |
+
self.name = opt.name
|
70 |
+
self.port = opt.display_port
|
71 |
+
self.saved = False
|
72 |
+
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
|
73 |
+
import visdom
|
74 |
+
self.plot_data = {}
|
75 |
+
self.ncols = opt.display_ncols
|
76 |
+
if "tensorboard_base_url" not in os.environ:
|
77 |
+
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
|
78 |
+
else:
|
79 |
+
self.vis = visdom.Visdom(port=2004,
|
80 |
+
base_url=os.environ['tensorboard_base_url'] + '/visdom')
|
81 |
+
if not self.vis.check_connection():
|
82 |
+
self.create_visdom_connections()
|
83 |
+
|
84 |
+
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
|
85 |
+
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
|
86 |
+
self.img_dir = os.path.join(self.web_dir, 'images')
|
87 |
+
print('create web directory %s...' % self.web_dir)
|
88 |
+
util.mkdirs([self.web_dir, self.img_dir])
|
89 |
+
# create a logging file to store training losses
|
90 |
+
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
|
91 |
+
with open(self.log_name, "a") as log_file:
|
92 |
+
now = time.strftime("%c")
|
93 |
+
log_file.write('================ Training Loss (%s) ================\n' % now)
|
94 |
+
|
95 |
+
    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to Visdom server, this function will start a new server at port <self.port>"""
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

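    # Note: the fallback above relies on POSIX shell syntax ('&>/dev/null &'), so it will
    # not work as-is on Windows. The equivalent manual command is roughly:
    #
    #     python -m visdom.server -port 8097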
    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) -- dictionary of images to display or save
            epoch (int) -- the current epoch
            save_result (bool) -- whether to save the current results to an HTML file
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:        # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, ncols, 2, self.display_id + 1,
                                    None, dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:  # show each image in a separate visdom panel
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(
                            image_numpy.transpose([2, 0, 1]),
                            self.display_id + idx,
                            None,
                            dict(title=label)
                        )
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                for label, image_numpy in visuals.items():
                    image_numpy = util.tensor2im(image_numpy)  # convert the tensor from visuals, not a leftover loop variable
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

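    # Typical call site (a sketch, assuming a training loop that is not part of this file):
    # every <display_freq> iterations one would call something like
    #
    #     visualizer.reset()
    #     visualizer.display_current_results(model.get_current_visuals(), epoch,
    #                                        save_result=(total_iters % opt.update_html_freq == 0))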
    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs
        """
        if len(losses) == 0:
            return

        plot_name = '_'.join(list(losses.keys()))

        if plot_name not in self.plot_data:
            self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}

        plot_data = self.plot_data[plot_name]
        plot_id = list(self.plot_data.keys()).index(plot_name)

        plot_data['X'].append(epoch + counter_ratio)
        plot_data['Y'].append([losses[k] for k in plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
                Y=np.array(plot_data['Y']),
                opts={
                    'title': self.name,
                    'legend': plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id - plot_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

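    # Example (values illustrative): with losses = OrderedDict([('G_GAN', 0.35), ('D_real', 0.42)]),
    # plot_current_losses(epoch=3, counter_ratio=0.25, losses=losses) appends the point
    # x = 3.25 to a single line plot whose legend is ['G_GAN', 'D_real'].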
    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)

        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message
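    # With the format strings above, a console/log line looks like (values illustrative):
    #     (epoch: 1, iters: 400, time: 0.052, data: 0.003) G_GAN: 0.350 D_real: 0.421 D_fake: 0.388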