Yarflam committed c954f09 (1 parent: c3559b1)

Main code + Checkpoints_DFG

Source: https://github.com/JunlinHan/CWR (adaptation)

.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
ModelLoader.py ADDED
@@ -0,0 +1,52 @@
+ from models import create_model
+ import os
+
+ ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')
+
+ class Options(object):
+     def __init__(self, *initial_data, **kwargs):
+         for dictionary in initial_data:
+             for key in dictionary:
+                 setattr(self, key, dictionary[key])
+         for key in kwargs:
+             setattr(self, key, kwargs[key])
+
+ class ModelLoader:
+     def __init__(self) -> None:
+         self.opt = Options({
+             'name': 'original',
+             'checkpoints_dir': ckp_path,
+             'gpu_ids': [],
+             'init_gain': 0.02,
+             'init_type': 'xavier',
+             'input_nc': 3,
+             'output_nc': 3,
+             'isTrain': False,
+             'model': 'cwr',
+             'nce_idt': False,
+             'nce_layers': '0',
+             'ndf': 64,
+             'netD': 'basic',
+             'netG': 'resnet_9blocks',
+             'netF': 'reshape',
+             'ngf': 64,
+             'no_antialias_up': None,
+             'no_antialias': None,
+             'no_dropout': True,
+             'normD': 'instance',
+             'normG': 'instance',
+             'preprocess': 'scale_width',
+             'num_threads': 0,   # test code only supports num_threads = 0
+             'batch_size': 1,    # test code only supports batch_size = 1
+             'serial_batches': True,  # disable data shuffling; comment this line out if results on randomly chosen images are needed.
+             'no_flip': True,    # no flip; comment this line out if results on flipped images are needed.
+             'display_id': -1,   # no visdom display; the test code saves the results to an HTML file.
+         })
+     def load(self) -> None:
+         self.model = create_model(self.opt)
+         self.model.load_networks('latest')
+     def inference(self, src=''):
+         if not os.path.isfile(src):
+             raise Exception('The image %s was not found!' % src)
+         # TODO: run the actual restoration; for now this only validates the input
+         print('Loading the image %s' % src)
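A minimal usage sketch for ModelLoader (assuming it is run from the repository root so the `models` package and the `checkpoints/original` weights below are found; note that `inference` is still a stub that only validates and logs the input path):

```python
from ModelLoader import ModelLoader

loader = ModelLoader()
loader.load()                            # create_model(opt) + load the 'latest' checkpoints
loader.inference('examples/rawimg.png')  # currently just checks the file and prints
```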
app.py ADDED
@@ -0,0 +1,11 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ demo = gr.Interface(
+     fn=greet,
+     inputs="text",
+     outputs="text"
+ )
+ demo.launch()
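For now app.py is the stock Gradio hello-world. A hedged sketch of how it might eventually call into ModelLoader (the `restore` helper below is hypothetical and not part of this commit; the committed `inference` does not yet return a restored image):

```python
import gradio as gr
from ModelLoader import ModelLoader

loader = ModelLoader()
loader.load()

def restore(img_path):          # hypothetical helper, not in this commit
    loader.inference(img_path)  # today this only validates and logs the file
    return img_path             # placeholder: echo the input back

demo = gr.Interface(fn=restore, inputs=gr.Image(type="filepath"), outputs=gr.Image())
demo.launch()
```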
checkpoints/original/latest_net_D.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:de3e70ce1b90004153f94a5943457d6557ced899e2518101b357bd575cde842a
+ size 11147372
checkpoints/original/latest_net_F.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ab783f87f5e4df9df8790195d731a9e4bb894e4367ebeb40a5878f5cd5fcdf1a
+ size 2248355
checkpoints/original/latest_net_G.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6e9a0c8bdf643c69fab732ed8764f958d46a41b3c6a3ba1a2f3fe2606e3650eb
+ size 45670757
examples/.gitattributes ADDED
@@ -0,0 +1 @@
+ rawimg.png filter=lfs diff=lfs merge=lfs -text
examples/rawimg.png ADDED

Git LFS Details

  • SHA256: 3f71967890232c1c7343de3db3b3dece7faaad73ca52ce26ce309c25a7c97e24
  • Pointer size: 132 Bytes
  • Size of remote file: 5.45 MB
models/__init__.py ADDED
@@ -0,0 +1,67 @@
+ """This package contains modules related to objective functions, optimizations, and network architectures.
+
+ To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+ You need to implement the following five functions:
+     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+     -- <set_input>: unpack data from dataset and apply preprocessing.
+     -- <forward>: produce intermediate results.
+     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+ In the function <__init__>, you need to define four lists:
+     -- self.loss_names (str list): specify the training losses that you want to plot and save.
+     -- self.model_names (str list): define networks used in our training.
+     -- self.visual_names (str list): specify the images that you want to display and save.
+     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
+
+ Now you can use the model class by specifying flag '--model dummy'.
+ See our template model class 'template_model.py' for more details.
+ """
+
+ import importlib
+ from models.base_model import BaseModel
+
+
+ def find_model_using_name(model_name):
+     """Import the module "models/[model_name]_model.py".
+
+     In the file, the class called DatasetNameModel() will
+     be instantiated. It has to be a subclass of BaseModel,
+     and it is case-insensitive.
+     """
+     model_filename = "models." + model_name + "_model"
+     modellib = importlib.import_module(model_filename)
+     model = None
+     target_model_name = model_name.replace('_', '') + 'model'
+     for name, cls in modellib.__dict__.items():
+         if name.lower() == target_model_name.lower() \
+            and issubclass(cls, BaseModel):
+             model = cls
+
+     if model is None:
+         print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
+         exit(0)
+
+     return model
+
+
+ def get_option_setter(model_name):
+     """Return the static method <modify_commandline_options> of the model class."""
+     model_class = find_model_using_name(model_name)
+     return model_class.modify_commandline_options
+
+
+ def create_model(opt):
+     """Create a model given the option.
+
+     This function wraps the model class.
+     This is the main interface between this package and 'train.py'/'test.py'
+
+     Example:
+         >>> from models import create_model
+         >>> model = create_model(opt)
+     """
+     model = find_model_using_name(opt.model)
+     instance = model(opt)
+     print("model [%s] was created" % type(instance).__name__)
+     return instance
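Following the package docstring above, a minimal sketch of what a custom 'dummy' model file could look like (illustrative only; no dummy_model.py ships with this commit, and the `define_G` call mirrors the one in cwr_model.py below):

```python
# models/dummy_model.py -- illustrative sketch, not part of this commit
import torch
from .base_model import BaseModel
from . import networks


class DummyModel(BaseModel):
    """Bare-bones model: one generator trained with an L1 loss."""

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G']                   # plotted/saved as self.loss_G
        self.model_names = ['G']                  # checkpointed as '<epoch>_net_G.pth'
        self.visual_names = ['real_A', 'fake_B']  # displayed/saved images
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain,
                                      opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
        if self.isTrain:
            self.criterion = torch.nn.L1Loss()
            self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer_G]

    def set_input(self, input):
        self.real_A = input['A'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake_B = self.netG(self.real_A)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_G.zero_grad()
        self.loss_G = self.criterion(self.fake_B, self.real_B)
        self.loss_G.backward()
        self.optimizer_G.step()
```

It would then be selected with `--model dummy` (or `'model': 'dummy'` in the Options dict of ModelLoader.py).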
models/base_model.py ADDED
@@ -0,0 +1,258 @@
+ import os
+ import torch
+ from collections import OrderedDict
+ from abc import ABC, abstractmethod
+ from . import networks
+
+
+ class BaseModel(ABC):
+     """This class is an abstract base class (ABC) for models.
+     To create a subclass, you need to implement the following five functions:
+         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+         -- <set_input>: unpack data from dataset and apply preprocessing.
+         -- <forward>: produce intermediate results.
+         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+     """
+
+     def __init__(self, opt):
+         """Initialize the BaseModel class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+         When creating your custom class, you need to implement your own initialization.
+         In this function, you should first call <BaseModel.__init__(self, opt)>
+         Then, you need to define four lists:
+             -- self.loss_names (str list): specify the training losses that you want to plot and save.
+             -- self.model_names (str list): define networks used in our training.
+             -- self.visual_names (str list): specify the images that you want to display and save.
+             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+         """
+         self.opt = opt
+         self.gpu_ids = opt.gpu_ids
+         self.isTrain = opt.isTrain
+         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+             torch.backends.cudnn.benchmark = True
+         self.loss_names = []
+         self.model_names = []
+         self.visual_names = []
+         self.optimizers = []
+         self.image_paths = []
+         self.metric = 0  # used for learning rate policy 'plateau'
+
+     @staticmethod
+     def dict_grad_hook_factory(add_func=lambda x: x):
+         saved_dict = dict()
+
+         def hook_gen(name):
+             def grad_hook(grad):
+                 saved_vals = add_func(grad)
+                 saved_dict[name] = saved_vals
+             return grad_hook
+         return hook_gen, saved_dict
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         """Add new model-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+         """
+         return parser
+
+     @abstractmethod
+     def set_input(self, input):
+         """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+         Parameters:
+             input (dict): includes the data itself and its metadata information.
+         """
+         pass
+
+     @abstractmethod
+     def forward(self):
+         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+         pass
+
+     @abstractmethod
+     def optimize_parameters(self):
+         """Calculate losses, gradients, and update network weights; called in every training iteration."""
+         pass
+
+     def setup(self, opt):
+         """Load and print networks; create schedulers
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         if self.isTrain:
+             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+         if not self.isTrain or opt.continue_train:
+             load_suffix = opt.epoch
+             self.load_networks(load_suffix)
+
+         self.print_networks(opt.verbose)
+
+     def parallelize(self):
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net' + name)
+                 setattr(self, 'net' + name, torch.nn.DataParallel(net, self.opt.gpu_ids))
+
+     def data_dependent_initialize(self, data):
+         pass
+
+     def eval(self):
+         """Make models eval mode during test time"""
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net' + name)
+                 net.eval()
+
+     def test(self):
+         """Forward function used in test time.
+
+         This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
+         It also calls <compute_visuals> to produce additional visualization results.
+         """
+         with torch.no_grad():
+             self.forward()
+             self.compute_visuals()
+
+     def compute_visuals(self):
+         """Calculate additional output images for visdom and HTML visualization"""
+         pass
+
+     def get_image_paths(self):
+         """Return image paths that are used to load current data"""
+         return self.image_paths
+
+     def update_learning_rate(self):
+         """Update learning rates for all the networks; called at the end of every epoch"""
+         for scheduler in self.schedulers:
+             if self.opt.lr_policy == 'plateau':
+                 scheduler.step(self.metric)
+             else:
+                 scheduler.step()
+
+         lr = self.optimizers[0].param_groups[0]['lr']
+         print('learning rate = %.7f' % lr)
+
+     def get_current_visuals(self):
+         """Return visualization images. train.py will display these images with visdom, and save the images to an HTML file"""
+         visual_ret = OrderedDict()
+         for name in self.visual_names:
+             if isinstance(name, str):
+                 visual_ret[name] = getattr(self, name)
+         return visual_ret
+
+     def get_current_losses(self):
+         """Return training losses / errors. train.py will print out these errors on the console, and save them to a file"""
+         errors_ret = OrderedDict()
+         for name in self.loss_names:
+             if isinstance(name, str):
+                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+         return errors_ret
+
+     def save_networks(self, epoch):
+         """Save all the networks to the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 save_filename = '%s_net_%s.pth' % (epoch, name)
+                 save_path = os.path.join(self.save_dir, save_filename)
+                 net = getattr(self, 'net' + name)
+
+                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                     torch.save(net.module.cpu().state_dict(), save_path)
+                     net.cuda(self.gpu_ids[0])
+                 else:
+                     torch.save(net.cpu().state_dict(), save_path)
+
+     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+         """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+         key = keys[i]
+         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'running_mean' or key == 'running_var'):
+                 if getattr(module, key) is None:
+                     state_dict.pop('.'.join(keys))
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'num_batches_tracked'):
+                 state_dict.pop('.'.join(keys))
+         else:
+             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+     def load_networks(self, epoch):
+         """Load all the networks from the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 load_filename = '%s_net_%s.pth' % (epoch, name)
+                 if self.opt.isTrain and self.opt.pretrained_name is not None:
+                     load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+                 else:
+                     load_dir = self.save_dir
+
+                 load_path = os.path.join(load_dir, load_filename)
+                 net = getattr(self, 'net' + name)
+                 if isinstance(net, torch.nn.DataParallel):
+                     net = net.module
+                 print('loading the model from %s' % load_path)
+                 # if you are using PyTorch newer than 0.4 (e.g., built from
+                 # GitHub source), you can remove str() on self.device
+                 state_dict = torch.load(load_path, map_location=str(self.device), weights_only=True)
+                 if hasattr(state_dict, '_metadata'):
+                     del state_dict._metadata
+
+                 # patch InstanceNorm checkpoints prior to 0.4
+                 # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                 #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                 net.load_state_dict(state_dict)
+
+     def print_networks(self, verbose):
+         """Print the total number of parameters in the network and (if verbose) network architecture
+
+         Parameters:
+             verbose (bool) -- if verbose: print the network architecture
+         """
+         print('---------- Networks initialized -------------')
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net' + name)
+                 num_params = 0
+                 for param in net.parameters():
+                     num_params += param.numel()
+                 if verbose:
+                     print(net)
+                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+         print('-----------------------------------------------')
+
+     def set_requires_grad(self, nets, requires_grad=False):
+         """Set requires_grad=False for all the networks to avoid unnecessary computations
+         Parameters:
+             nets (network list) -- a list of networks
+             requires_grad (bool) -- whether the networks require gradients or not
+         """
+         if not isinstance(nets, list):
+             nets = [nets]
+         for net in nets:
+             if net is not None:
+                 for param in net.parameters():
+                     param.requires_grad = requires_grad
+
+     def generate_visuals_for_evaluation(self, data, mode):
+         return {}
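For orientation, the typical test-time flow through this API, mirroring what the upstream CUT/CycleGAN test script does (names are illustrative; `opt` needs a few extra flags such as `epoch` and `verbose` that the minimal dict in ModelLoader.py omits, and `dataset` is assumed to yield dicts with 'A'/'B' tensors and their paths):

```python
from models import create_model

model = create_model(opt)   # picks the model class from opt.model, e.g. 'cwr'
model.setup(opt)            # loads '<epoch>_net_*.pth' and prints the networks
model.eval()
for data in dataset:
    model.set_input(data)
    model.test()            # forward() under torch.no_grad() + compute_visuals()
    visuals = model.get_current_visuals()   # OrderedDict: name -> image tensor
```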
models/cwr_model.py ADDED
@@ -0,0 +1,207 @@
+ import numpy as np
+ import torch
+ from .base_model import BaseModel
+ from . import networks
+ from .patchnce import PatchNCELoss
+ import util.util as util
+
+
+ class CWRModel(BaseModel):
+     """ This class implements the CWR model, described in the paper
+     Single Underwater Image Restoration by Contrastive Learning
+     Junlin Han, Mehrdad Shoeiby, Tim Malthus, Elizabeth Botha, Janet Anstee, Saeed Anwar, Ran Wei, Lars Petersson, Mohammad Ali Armin
+     International Geoscience and Remote Sensing Symposium (IGARSS), 2021
+
+     The code borrows heavily from the PyTorch implementation of CycleGAN and CUT
+     https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
+     https://github.com/taesungp/contrastive-unpaired-translation
+     """
+     @staticmethod
+     def modify_commandline_options(parser, is_train=True):
+         """Configures options specific to the CUT/CWR model"""
+         parser.add_argument('--lambda_GAN', type=float, default=1.0, help='weight for GAN loss: GAN(G(X))')
+         parser.add_argument('--lambda_NCE', type=float, default=1.0, help='weight for NCE loss: NCE(G(X), X)')
+         parser.add_argument('--lambda_IDT', type=float, default=10.0, help='weight for the identity (L1) loss')
+         parser.add_argument('--nce_idt', type=util.str2bool, nargs='?', const=True, default=False, help='use NCE loss for identity mapping: NCE(G(Y), Y))')
+         parser.add_argument('--nce_layers', type=str, default='0,4,8,12,16', help='compute NCE loss on which layers')
+         parser.add_argument('--nce_includes_all_negatives_from_minibatch',
+                             type=util.str2bool, nargs='?', const=True, default=False,
+                             help='(used for single image translation) If True, include the negatives from the other samples of the minibatch when computing the contrastive loss. Please see models/patchnce.py for more details.')
+         parser.add_argument('--netF', type=str, default='mlp_sample', choices=['sample', 'reshape', 'mlp_sample'], help='how to downsample the feature map')
+         parser.add_argument('--netF_nc', type=int, default=256)
+         parser.add_argument('--nce_T', type=float, default=0.07, help='temperature for NCE loss')
+         parser.add_argument('--num_patches', type=int, default=256, help='number of patches per layer')
+         parser.add_argument('--flip_equivariance',
+                             type=util.str2bool, nargs='?', const=True, default=False,
+                             help="Enforce flip-equivariance as additional regularization. It's used by FastCUT, but not CWR")
+
+         parser.set_defaults(pool_size=0)  # no image pooling
+
+         opt, _ = parser.parse_known_args()
+
+         parser.set_defaults(nce_idt=True, lambda_NCE=1.0)
+
+         return parser
+
+     def __init__(self, opt):
+         BaseModel.__init__(self, opt)
+
+         # specify the training losses you want to print out.
+         # The training/test scripts will call <BaseModel.get_current_losses>
+         self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE', 'idt']
+         self.visual_names = ['real_A', 'fake_B', 'real_B']
+         self.nce_layers = [int(i) for i in self.opt.nce_layers.split(',')]
+
+         if opt.nce_idt and self.isTrain:
+             self.loss_names += ['NCE_Y']
+             self.visual_names += ['idt_B']
+
+         if self.isTrain:
+             self.model_names = ['G', 'F', 'D']
+         else:  # during test time, only load G
+             self.model_names = ['G']
+
+         # define networks (both generator and discriminator)
+         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, opt.no_antialias_up, self.gpu_ids, opt)
+         self.netF = networks.define_F(opt.input_nc, opt.netF, opt.normG, not opt.no_dropout, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+
+         if self.isTrain:
+             self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.normD, opt.init_type, opt.init_gain, opt.no_antialias, self.gpu_ids, opt)
+
+             # define loss functions
+             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+             self.criterionNCE = []
+
+             for nce_layer in self.nce_layers:
+                 self.criterionNCE.append(PatchNCELoss(opt).to(self.device))
+
+             self.criterionIdt = torch.nn.L1Loss().to(self.device)
+             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2))
+             self.optimizers.append(self.optimizer_G)
+             self.optimizers.append(self.optimizer_D)
+
+     def data_dependent_initialize(self, data):
+         """
+         The feature network netF is defined in terms of the shape of the intermediate, extracted
+         features of the encoder portion of netG. Because of this, the weights of netF are
+         initialized at the first feedforward pass with some input images.
+         Please also see PatchSampleF.create_mlp(), which is called at the first forward() call.
+         """
+         self.set_input(data)
+         bs_per_gpu = self.real_A.size(0) // max(len(self.opt.gpu_ids), 1)
+         self.real_A = self.real_A[:bs_per_gpu]
+         self.real_B = self.real_B[:bs_per_gpu]
+         self.forward()  # compute fake images: G(A)
+         if self.opt.isTrain:
+             self.compute_D_loss().backward()  # calculate gradients for D
+             self.compute_G_loss().backward()  # calculate gradients for G
+             if self.opt.lambda_NCE > 0.0:
+                 self.optimizer_F = torch.optim.Adam(self.netF.parameters(), lr=self.opt.lr, betas=(self.opt.beta1, self.opt.beta2))
+                 self.optimizers.append(self.optimizer_F)
+
+     def optimize_parameters(self):
+         # forward
+         self.forward()
+
+         # update D
+         self.set_requires_grad(self.netD, True)
+         self.optimizer_D.zero_grad()
+         self.loss_D = self.compute_D_loss()
+         self.loss_D.backward()
+         self.optimizer_D.step()
+
+         # update G
+         self.set_requires_grad(self.netD, False)
+         self.optimizer_G.zero_grad()
+         if self.opt.netF == 'mlp_sample':
+             self.optimizer_F.zero_grad()
+         self.loss_G = self.compute_G_loss()
+         self.loss_G.backward()
+         self.optimizer_G.step()
+         if self.opt.netF == 'mlp_sample':
+             self.optimizer_F.step()
+
+     def set_input(self, input):
+         """Unpack input data from the dataloader and perform necessary pre-processing steps.
+         Parameters:
+             input (dict): include the data itself and its metadata information.
+         The option 'direction' can be used to swap domain A and domain B.
+         """
+         AtoB = self.opt.direction == 'AtoB'
+         self.real_A = input['A' if AtoB else 'B'].to(self.device)
+         self.real_B = input['B' if AtoB else 'A'].to(self.device)
+         self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+     def forward(self):
+         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+         self.real = torch.cat((self.real_A, self.real_B), dim=0) if self.opt.nce_idt and self.opt.isTrain else self.real_A
+         if self.opt.flip_equivariance:
+             self.flipped_for_equivariance = self.opt.isTrain and (np.random.random() < 0.5)
+             if self.flipped_for_equivariance:
+                 self.real = torch.flip(self.real, [3])
+
+         self.fake = self.netG(self.real)
+         self.fake_B = self.fake[:self.real_A.size(0)]
+         if self.opt.nce_idt:
+             self.idt_B = self.fake[self.real_A.size(0):]
+
+     def compute_D_loss(self):
+         """Calculate GAN loss for the discriminator"""
+         fake = self.fake_B.detach()
+         # Fake; stop backprop to the generator by detaching fake_B
+         pred_fake = self.netD(fake)
+         self.loss_D_fake = self.criterionGAN(pred_fake, False).mean()
+         # Real
+         self.pred_real = self.netD(self.real_B)
+         loss_D_real = self.criterionGAN(self.pred_real, True)
+         self.loss_D_real = loss_D_real.mean()
+
+         # combine loss and calculate gradients
+         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+         return self.loss_D
+
+     def compute_G_loss(self):
+         """Calculate GAN, NCE, and identity loss for the generator"""
+         fake = self.fake_B
+         # First, G(A) should fool the discriminator
+         if self.opt.lambda_GAN > 0.0:
+             pred_fake = self.netD(fake)
+             self.loss_G_GAN = self.criterionGAN(pred_fake, True).mean() * self.opt.lambda_GAN
+         else:
+             self.loss_G_GAN = 0.0
+
+         if self.opt.lambda_NCE > 0.0:
+             self.loss_NCE = self.calculate_NCE_loss(self.real_A, self.fake_B)
+         else:
+             self.loss_NCE, self.loss_NCE_bd = 0.0, 0.0
+
+         if self.opt.nce_idt and self.opt.lambda_NCE > 0.0:
+             self.loss_NCE_Y = 0
+             loss_NCE_both = (self.loss_NCE + self.loss_NCE_Y)
+             self.loss_idt = self.criterionIdt(self.idt_B, self.real_B) * self.opt.lambda_IDT
+         else:
+             loss_NCE_both = self.loss_NCE
+
+         self.loss_G = self.loss_G_GAN + loss_NCE_both + self.loss_idt
+         return self.loss_G
+
+     def calculate_NCE_loss(self, src, tgt):
+         n_layers = len(self.nce_layers)
+         feat_q = self.netG(tgt, self.nce_layers, encode_only=True)
+
+         if self.opt.flip_equivariance and self.flipped_for_equivariance:
+             feat_q = [torch.flip(fq, [3]) for fq in feat_q]
+
+         feat_k = self.netG(src, self.nce_layers, encode_only=True)
+         feat_k_pool, sample_ids = self.netF(feat_k, self.opt.num_patches, None)
+         feat_q_pool, _ = self.netF(feat_q, self.opt.num_patches, sample_ids)
+
+         total_nce_loss = 0.0
+         for f_q, f_k, crit, nce_layer in zip(feat_q_pool, feat_k_pool, self.criterionNCE, self.nce_layers):
+             loss = crit(f_q, f_k) * self.opt.lambda_NCE
+             total_nce_loss += loss.mean()
+
+         return total_nce_loss / n_layers
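For reference, the per-layer criterion that `calculate_NCE_loss` averages is the PatchNCE objective from the CUT paper (patchnce.py itself is not included in this commit; formula reproduced from the paper). With a query patch feature q from the output, its positive k+ (the same spatial location in the input), N-1 negatives k-_n (the other sampled patches), and temperature tau = `nce_T` (0.07 above):

```latex
\ell(q, k^{+}, k^{-}) = -\log
  \frac{\exp(q \cdot k^{+} / \tau)}
       {\exp(q \cdot k^{+} / \tau) + \sum_{n=1}^{N-1} \exp(q \cdot k^{-}_{n} / \tau)}
```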
models/networks.py ADDED
@@ -0,0 +1,1401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ import functools
6
+ from torch.optim import lr_scheduler
7
+ import numpy as np
8
+ from .spectralNormalization import SpectralNorm
9
+ from .stylegan_networks import StyleGAN2Discriminator, StyleGAN2Generator, TileStyleGAN2Discriminator
10
+
11
+ ###############################################################################
12
+ # Helper Functions
13
+ ###############################################################################
14
+
15
+
16
+ def get_filter(filt_size=3):
17
+ if(filt_size == 1):
18
+ a = np.array([1., ])
19
+ elif(filt_size == 2):
20
+ a = np.array([1., 1.])
21
+ elif(filt_size == 3):
22
+ a = np.array([1., 2., 1.])
23
+ elif(filt_size == 4):
24
+ a = np.array([1., 3., 3., 1.])
25
+ elif(filt_size == 5):
26
+ a = np.array([1., 4., 6., 4., 1.])
27
+ elif(filt_size == 6):
28
+ a = np.array([1., 5., 10., 10., 5., 1.])
29
+ elif(filt_size == 7):
30
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
31
+
32
+ filt = torch.Tensor(a[:, None] * a[None, :])
33
+ filt = filt / torch.sum(filt)
34
+
35
+ return filt
36
+
37
+
38
+ class Downsample(nn.Module):
39
+ def __init__(self, channels, pad_type='reflect', filt_size=3, stride=2, pad_off=0):
40
+ super(Downsample, self).__init__()
41
+ self.filt_size = filt_size
42
+ self.pad_off = pad_off
43
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)), int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
44
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
45
+ self.stride = stride
46
+ self.off = int((self.stride - 1) / 2.)
47
+ self.channels = channels
48
+
49
+ filt = get_filter(filt_size=self.filt_size)
50
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
51
+
52
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
53
+
54
+ def forward(self, inp):
55
+ if(self.filt_size == 1):
56
+ if(self.pad_off == 0):
57
+ return inp[:, :, ::self.stride, ::self.stride]
58
+ else:
59
+ return self.pad(inp)[:, :, ::self.stride, ::self.stride]
60
+ else:
61
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
62
+
63
+
64
+ class Upsample2(nn.Module):
65
+ def __init__(self, scale_factor, mode='nearest'):
66
+ super().__init__()
67
+ self.factor = scale_factor
68
+ self.mode = mode
69
+
70
+ def forward(self, x):
71
+ return torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
72
+
73
+
74
+ class Upsample(nn.Module):
75
+ def __init__(self, channels, pad_type='repl', filt_size=4, stride=2):
76
+ super(Upsample, self).__init__()
77
+ self.filt_size = filt_size
78
+ self.filt_odd = np.mod(filt_size, 2) == 1
79
+ self.pad_size = int((filt_size - 1) / 2)
80
+ self.stride = stride
81
+ self.off = int((self.stride - 1) / 2.)
82
+ self.channels = channels
83
+
84
+ filt = get_filter(filt_size=self.filt_size) * (stride**2)
85
+ self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))
86
+
87
+ self.pad = get_pad_layer(pad_type)([1, 1, 1, 1])
88
+
89
+ def forward(self, inp):
90
+ ret_val = F.conv_transpose2d(self.pad(inp), self.filt, stride=self.stride, padding=1 + self.pad_size, groups=inp.shape[1])[:, :, 1:, 1:]
91
+ if(self.filt_odd):
92
+ return ret_val
93
+ else:
94
+ return ret_val[:, :, :-1, :-1]
95
+
96
+
97
+ def get_pad_layer(pad_type):
98
+ if(pad_type in ['refl', 'reflect']):
99
+ PadLayer = nn.ReflectionPad2d
100
+ elif(pad_type in ['repl', 'replicate']):
101
+ PadLayer = nn.ReplicationPad2d
102
+ elif(pad_type == 'zero'):
103
+ PadLayer = nn.ZeroPad2d
104
+ else:
105
+ print('Pad type [%s] not recognized' % pad_type)
106
+ return PadLayer
107
+
108
+
109
+ class Identity(nn.Module):
110
+ def forward(self, x):
111
+ return x
112
+
113
+
114
+ def get_norm_layer(norm_type='instance'):
115
+ """Return a normalization layer
116
+
117
+ Parameters:
118
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
119
+
120
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
121
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
122
+ """
123
+ if norm_type == 'batch':
124
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
125
+ elif norm_type == 'instance':
126
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
127
+ elif norm_type == 'none':
128
+ def norm_layer(x):
129
+ return Identity()
130
+ else:
131
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
132
+ return norm_layer
133
+
134
+
135
+ def get_scheduler(optimizer, opt):
136
+ """Return a learning rate scheduler
137
+
138
+ Parameters:
139
+ optimizer -- the optimizer of the network
140
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
141
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
142
+
143
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
144
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
145
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
146
+ See https://pytorch.org/docs/stable/optim.html for more details.
147
+ """
148
+ if opt.lr_policy == 'linear':
149
+ def lambda_rule(epoch):
150
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
151
+ return lr_l
152
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
153
+ elif opt.lr_policy == 'step':
154
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
155
+ elif opt.lr_policy == 'plateau':
156
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
157
+ elif opt.lr_policy == 'cosine':
158
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
159
+ else:
160
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
161
+ return scheduler
162
+
163
+
164
+ def init_weights(net, init_type='normal', init_gain=0.02, debug=False):
165
+ """Initialize network weights.
166
+
167
+ Parameters:
168
+ net (network) -- network to be initialized
169
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
170
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
171
+
172
+ We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
173
+ work better for some applications. Feel free to try yourself.
174
+ """
175
+ def init_func(m): # define the initialization function
176
+ classname = m.__class__.__name__
177
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
178
+ if debug:
179
+ print(classname)
180
+ if init_type == 'normal':
181
+ init.normal_(m.weight.data, 0.0, init_gain)
182
+ elif init_type == 'xavier':
183
+ init.xavier_normal_(m.weight.data, gain=init_gain)
184
+ elif init_type == 'kaiming':
185
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
186
+ elif init_type == 'orthogonal':
187
+ init.orthogonal_(m.weight.data, gain=init_gain)
188
+ else:
189
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
190
+ if hasattr(m, 'bias') and m.bias is not None:
191
+ init.constant_(m.bias.data, 0.0)
192
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
193
+ init.normal_(m.weight.data, 1.0, init_gain)
194
+ init.constant_(m.bias.data, 0.0)
195
+
196
+ net.apply(init_func) # apply the initialization function <init_func>
197
+
198
+
199
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], debug=False, initialize_weights=True):
200
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
201
+ Parameters:
202
+ net (network) -- the network to be initialized
203
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
204
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
205
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
206
+
207
+ Return an initialized network.
208
+ """
209
+ if len(gpu_ids) > 0:
210
+ assert(torch.cuda.is_available())
211
+ net.to(gpu_ids[0])
212
+ # if not amp:
213
+ # net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs for non-AMP training
214
+ if initialize_weights:
215
+ init_weights(net, init_type, init_gain=init_gain, debug=debug)
216
+ return net
217
+
218
+
219
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal',
220
+ init_gain=0.02, no_antialias=False, no_antialias_up=False, gpu_ids=[], opt=None):
221
+ """Create a generator
222
+
223
+ Parameters:
224
+ input_nc (int) -- the number of channels in input images
225
+ output_nc (int) -- the number of channels in output images
226
+ ngf (int) -- the number of filters in the last conv layer
227
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
228
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
229
+ use_dropout (bool) -- if use dropout layers.
230
+ init_type (str) -- the name of our initialization method.
231
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
232
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
233
+
234
+ Returns a generator
235
+
236
+ Our current implementation provides two types of generators:
237
+ U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
238
+ The original U-Net paper: https://arxiv.org/abs/1505.04597
239
+
240
+ Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
241
+ Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
242
+ We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
243
+
244
+
245
+ The generator has been initialized by <init_net>. It uses RELU for non-linearity.
246
+ """
247
+ net = None
248
+ norm_layer = get_norm_layer(norm_type=norm)
249
+
250
+ if netG == 'resnet_9blocks':
251
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=9, opt=opt)
252
+ elif netG == 'resnet_6blocks':
253
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=6, opt=opt)
254
+ elif netG == 'resnet_4blocks':
255
+ net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, no_antialias=no_antialias, no_antialias_up=no_antialias_up, n_blocks=4, opt=opt)
256
+ elif netG == 'unet_128':
257
+ net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
258
+ elif netG == 'unet_256':
259
+ net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
260
+ elif netG == 'stylegan2':
261
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, opt=opt)
262
+ elif netG == 'smallstylegan2':
263
+ net = StyleGAN2Generator(input_nc, output_nc, ngf, use_dropout=use_dropout, n_blocks=2, opt=opt)
264
+ elif netG == 'resnet_cat':
265
+ n_blocks = 8
266
+ net = G_Resnet(input_nc, output_nc, opt.nz, num_downs=2, n_res=n_blocks - 4, ngf=ngf, norm='inst', nl_layer='relu')
267
+ else:
268
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
269
+ return init_net(net, init_type, init_gain, gpu_ids, initialize_weights=('stylegan2' not in netG))
270
+
271
+
272
+ def define_F(input_nc, netF, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
273
+ if netF == 'global_pool':
274
+ net = PoolingF()
275
+ elif netF == 'reshape':
276
+ net = ReshapeF()
277
+ elif netF == 'sample':
278
+ net = PatchSampleF(use_mlp=False, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
279
+ elif netF == 'mlp_sample':
280
+ net = PatchSampleF(use_mlp=True, init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids, nc=opt.netF_nc)
281
+ elif netF == 'strided_conv':
282
+ net = StridedConvF(init_type=init_type, init_gain=init_gain, gpu_ids=gpu_ids)
283
+ else:
284
+ raise NotImplementedError('projection model name [%s] is not recognized' % netF)
285
+ return init_net(net, init_type, init_gain, gpu_ids)
286
+
287
+
288
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, no_antialias=False, gpu_ids=[], opt=None):
289
+ """Create a discriminator
290
+
291
+ Parameters:
292
+ input_nc (int) -- the number of channels in input images
293
+ ndf (int) -- the number of filters in the first conv layer
294
+ netD (str) -- the architecture's name: basic | n_layers | pixel
295
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
296
+ norm (str) -- the type of normalization layers used in the network.
297
+ init_type (str) -- the name of the initialization method.
298
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
299
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
300
+
301
+ Returns a discriminator
302
+
303
+ Our current implementation provides three types of discriminators:
304
+ [basic]: 'PatchGAN' classifier described in the original pix2pix paper.
305
+ It can classify whether 70×70 overlapping patches are real or fake.
306
+ Such a patch-level discriminator architecture has fewer parameters
307
+ than a full-image discriminator and can work on arbitrarily-sized images
308
+ in a fully convolutional fashion.
309
+
310
+ [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator
311
+ with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)
312
+
313
+ [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
314
+ It encourages greater color diversity but has no effect on spatial statistics.
315
+
316
+ The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
317
+ """
318
+ net = None
319
+ norm_layer = get_norm_layer(norm_type=norm)
320
+
321
+ if netD == 'basic': # default PatchGAN classifier
322
+ net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, no_antialias=no_antialias,)
323
+ elif netD == 'n_layers': # more options
324
+ net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, no_antialias=no_antialias,)
325
+ elif netD == 'pixel': # classify if each pixel is real or fake
326
+ net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
327
+ elif 'stylegan2' in netD:
328
+ net = StyleGAN2Discriminator(input_nc, ndf, n_layers_D, no_antialias=no_antialias, opt=opt)
329
+ else:
330
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
331
+ return init_net(net, init_type, init_gain, gpu_ids,
332
+ initialize_weights=('stylegan2' not in netD))
333
+
334
+
335
+ ##############################################################################
336
+ # Classes
337
+ ##############################################################################
338
+ class GANLoss(nn.Module):
339
+ """Define different GAN objectives.
340
+
341
+ The GANLoss class abstracts away the need to create the target label tensor
342
+ that has the same size as the input.
343
+ """
344
+
345
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
346
+ """ Initialize the GANLoss class.
347
+
348
+ Parameters:
349
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
350
+ target_real_label (bool) - - label for a real image
351
+ target_fake_label (bool) - - label of a fake image
352
+
353
+ Note: Do not use sigmoid as the last layer of Discriminator.
354
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
355
+ """
356
+ super(GANLoss, self).__init__()
357
+ self.register_buffer('real_label', torch.tensor(target_real_label))
358
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
359
+ self.gan_mode = gan_mode
360
+ if gan_mode == 'lsgan':
361
+ self.loss = nn.MSELoss()
362
+ elif gan_mode == 'vanilla':
363
+ self.loss = nn.BCEWithLogitsLoss()
364
+ elif gan_mode in ['wgangp', 'nonsaturating']:
365
+ self.loss = None
366
+ elif gan_mode == "hinge":
367
+ self.loss = None
368
+ else:
369
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
370
+
371
+ def get_target_tensor(self, prediction, target_is_real):
372
+ """Create label tensors with the same size as the input.
373
+
374
+ Parameters:
375
+ prediction (tensor) - - tpyically the prediction from a discriminator
376
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
377
+
378
+ Returns:
379
+ A label tensor filled with ground truth label, and with the size of the input
380
+ """
381
+
382
+ if target_is_real:
383
+ target_tensor = self.real_label
384
+ else:
385
+ target_tensor = self.fake_label
386
+ return target_tensor.expand_as(prediction)
387
+
388
+ def __call__(self, prediction, target_is_real):
389
+ """Calculate loss given Discriminator's output and grount truth labels.
390
+
391
+ Parameters:
392
+ prediction (tensor) - - tpyically the prediction output from a discriminator
393
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
394
+
395
+ Returns:
396
+ the calculated loss.
397
+ """
398
+ bs = prediction.size(0)
399
+ if self.gan_mode in ['lsgan', 'vanilla']:
400
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
401
+ loss = self.loss(prediction, target_tensor)
402
+ elif self.gan_mode == 'wgangp':
403
+ if target_is_real:
404
+ loss = -prediction.mean()
405
+ else:
406
+ loss = prediction.mean()
407
+ elif self.gan_mode == 'nonsaturating':
408
+ if target_is_real:
409
+ loss = F.softplus(-prediction).view(bs, -1).mean(dim=1)
410
+ else:
411
+ loss = F.softplus(prediction).view(bs, -1).mean(dim=1)
412
+ elif self.gan_mode == 'hinge':
413
+ if target_is_real:
414
+ minvalue = torch.min(prediction - 1, torch.zeros(prediction.shape).to(prediction.device))
415
+ loss = -torch.mean(minvalue)
416
+ else:
417
+ minvalue = torch.min(-prediction - 1,torch.zeros(prediction.shape).to(prediction.device))
418
+ loss = -torch.mean(minvalue)
419
+ return loss
420
+
421
+
422
+ def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
423
+ """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
424
+
425
+ Arguments:
426
+ netD (network) -- discriminator network
427
+ real_data (tensor array) -- real images
428
+ fake_data (tensor array) -- generated images from the generator
429
+ device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
430
+ type (str) -- if we mix real and fake data or not [real | fake | mixed].
431
+ constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2
432
+ lambda_gp (float) -- weight for this loss
433
+
434
+ Returns the gradient penalty loss
435
+ """
436
+ if lambda_gp > 0.0:
437
+ if type == 'real': # either use real images, fake images, or a linear interpolation of two.
438
+ interpolatesv = real_data
439
+ elif type == 'fake':
440
+ interpolatesv = fake_data
441
+ elif type == 'mixed':
442
+ alpha = torch.rand(real_data.shape[0], 1, device=device)
443
+ alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
444
+ interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
445
+ else:
446
+ raise NotImplementedError('{} not implemented'.format(type))
447
+ interpolatesv.requires_grad_(True)
448
+ disc_interpolates = netD(interpolatesv)
449
+ gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
450
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
451
+ create_graph=True, retain_graph=True, only_inputs=True)
452
+ gradients = gradients[0].view(real_data.size(0), -1) # flat the data
453
+ gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
454
+ return gradient_penalty, gradients
455
+ else:
456
+ return 0.0, None
457
+
458
+
459
+ class Normalize(nn.Module):
460
+
461
+ def __init__(self, power=2):
462
+ super(Normalize, self).__init__()
463
+ self.power = power
464
+
465
+ def forward(self, x):
466
+ norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
467
+ out = x.div(norm + 1e-7)
468
+ return out
469
+
470
+
471
+ class PoolingF(nn.Module):
472
+ def __init__(self):
473
+ super(PoolingF, self).__init__()
474
+ model = [nn.AdaptiveMaxPool2d(1)]
475
+ self.model = nn.Sequential(*model)
476
+ self.l2norm = Normalize(2)
477
+
478
+ def forward(self, x):
479
+ return self.l2norm(self.model(x))
480
+
481
+
482
+ class ReshapeF(nn.Module):
483
+ def __init__(self):
484
+ super(ReshapeF, self).__init__()
485
+ model = [nn.AdaptiveAvgPool2d(4)]
486
+ self.model = nn.Sequential(*model)
487
+ self.l2norm = Normalize(2)
488
+
489
+ def forward(self, x):
490
+ x = self.model(x)
491
+ x_reshape = x.permute(0, 2, 3, 1).flatten(0, 2)
492
+ return self.l2norm(x_reshape)
493
+
494
+ class StridedConvF(nn.Module):
495
+ def __init__(self, init_type='normal', init_gain=0.02, gpu_ids=[]):
496
+ super().__init__()
497
+ # self.conv1 = nn.Conv2d(256, 128, 3, stride=2)
498
+ # self.conv2 = nn.Conv2d(128, 64, 3, stride=1)
499
+ self.l2_norm = Normalize(2)
500
+ self.mlps = {}
501
+ self.moving_averages = {}
502
+ self.init_type = init_type
503
+ self.init_gain = init_gain
504
+ self.gpu_ids = gpu_ids
505
+
506
+ def create_mlp(self, x):
507
+ C, H = x.shape[1], x.shape[2]
508
+ n_down = int(np.rint(np.log2(H / 32)))
509
+ mlp = []
510
+ for i in range(n_down):
511
+ mlp.append(nn.Conv2d(C, max(C // 2, 64), 3, stride=2))
512
+ mlp.append(nn.ReLU())
513
+ C = max(C // 2, 64)
514
+ mlp.append(nn.Conv2d(C, 64, 3))
515
+ mlp = nn.Sequential(*mlp)
516
+ init_net(mlp, self.init_type, self.init_gain, self.gpu_ids)
517
+ return mlp
518
+
519
+ def update_moving_average(self, key, x):
520
+ if key not in self.moving_averages:
521
+ self.moving_averages[key] = x.detach()
522
+
523
+ self.moving_averages[key] = self.moving_averages[key] * 0.999 + x.detach() * 0.001
524
+
525
+ def forward(self, x, use_instance_norm=False):
526
+ C, H = x.shape[1], x.shape[2]
527
+ key = '%d_%d' % (C, H)
528
+ if key not in self.mlps:
529
+ self.mlps[key] = self.create_mlp(x)
530
+ self.add_module("child_%s" % key, self.mlps[key])
531
+ mlp = self.mlps[key]
532
+ x = mlp(x)
533
+ self.update_moving_average(key, x)
534
+ x = x - self.moving_averages[key]
535
+ if use_instance_norm:
536
+ x = F.instance_norm(x)
537
+ return self.l2_norm(x)
538
+
539
+
540
+ class PatchSampleF(nn.Module):
541
+ def __init__(self, use_mlp=False, init_type='normal', init_gain=0.02, nc=256, gpu_ids=[]):
542
+ # potential issues: currently, we use the same patch_ids for multiple images in the batch
543
+ super(PatchSampleF, self).__init__()
544
+ self.l2norm = Normalize(2)
545
+ self.use_mlp = use_mlp
546
+ self.nc = nc # hard-coded
547
+ self.mlp_init = False
548
+ self.init_type = init_type
549
+ self.init_gain = init_gain
550
+ self.gpu_ids = gpu_ids
551
+
552
+ def create_mlp(self, feats):
553
+ for mlp_id, feat in enumerate(feats):
554
+ input_nc = feat.shape[1]
555
+ mlp = nn.Sequential(*[nn.Linear(input_nc, self.nc), nn.ReLU(), nn.Linear(self.nc, self.nc)])
556
+ if len(self.gpu_ids) > 0:
557
+ mlp.cuda()
558
+ setattr(self, 'mlp_%d' % mlp_id, mlp)
559
+ init_net(self, self.init_type, self.init_gain, self.gpu_ids)
560
+ self.mlp_init = True
561
+
562
+ def forward(self, feats, num_patches=64, patch_ids=None):
563
+ return_ids = []
564
+ return_feats = []
565
+ if self.use_mlp and not self.mlp_init:
566
+ self.create_mlp(feats)
567
+ for feat_id, feat in enumerate(feats):
568
+ B, H, W = feat.shape[0], feat.shape[2], feat.shape[3]
569
+ feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
570
+ if num_patches > 0:
571
+ if patch_ids is not None:
572
+ patch_id = patch_ids[feat_id]
573
+ else:
574
+ patch_id = torch.randperm(feat_reshape.shape[1], device=feats[0].device)
575
+ patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))] # .to(patch_ids.device)
576
+ x_sample = feat_reshape[:, patch_id, :].flatten(0, 1) # reshape(-1, x.shape[1])
577
+ else:
578
+ x_sample = feat_reshape
579
+ patch_id = []
580
+ if self.use_mlp:
581
+ mlp = getattr(self, 'mlp_%d' % feat_id)
582
+ x_sample = mlp(x_sample)
583
+ return_ids.append(patch_id)
584
+ x_sample = self.l2norm(x_sample)
585
+
586
+ if num_patches == 0:
587
+ x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
588
+ return_feats.append(x_sample)
589
+ return return_feats, return_ids
590
+
591
+
592
+ class G_Resnet(nn.Module):
593
+ def __init__(self, input_nc, output_nc, nz, num_downs, n_res, ngf=64,
594
+ norm=None, nl_layer=None):
595
+ super(G_Resnet, self).__init__()
596
+ n_downsample = num_downs
597
+ pad_type = 'reflect'
598
+ self.enc_content = ContentEncoder(n_downsample, n_res, input_nc, ngf, norm, nl_layer, pad_type=pad_type)
599
+ if nz == 0:
600
+ self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
601
+ else:
602
+ self.dec = Decoder_all(n_downsample, n_res, self.enc_content.output_dim, output_nc, norm=norm, activ=nl_layer, pad_type=pad_type, nz=nz)
603
+
604
+ def decode(self, content, style=None):
605
+ return self.dec(content, style)
606
+
607
+ def forward(self, image, style=None, nce_layers=[], encode_only=False):
608
+ content, feats = self.enc_content(image, nce_layers=nce_layers, encode_only=encode_only)
609
+ if encode_only:
610
+ return feats
611
+ else:
612
+ images_recon = self.decode(content, style)
613
+ if len(nce_layers) > 0:
614
+ return images_recon, feats
615
+ else:
616
+ return images_recon
617
+
618
+ ##################################################################################
619
+ # Encoder and Decoders
620
+ ##################################################################################
621
+
622
+
623
+ class E_adaIN(nn.Module):
624
+ def __init__(self, input_nc, output_nc=1, nef=64, n_layers=4,
625
+ norm=None, nl_layer=None, vae=False):
626
+ # style encoder
627
+ super(E_adaIN, self).__init__()
628
+ self.enc_style = StyleEncoder(n_layers, input_nc, nef, output_nc, norm='none', activ='relu', vae=vae)
629
+
630
+ def forward(self, image):
631
+ style = self.enc_style(image)
632
+ return style
633
+
634
+
635
+ class StyleEncoder(nn.Module):
636
+ def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, vae=False):
637
+ super(StyleEncoder, self).__init__()
638
+ self.vae = vae
639
+ self.model = []
640
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
641
+ for i in range(2):
642
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
643
+ dim *= 2
644
+ for i in range(n_downsample - 2):
645
+ self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
646
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
647
+ if self.vae:
648
+ self.fc_mean = nn.Linear(dim, style_dim) # , 1, 1, 0)
649
+ self.fc_var = nn.Linear(dim, style_dim) # , 1, 1, 0)
650
+ else:
651
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
652
+
653
+ self.model = nn.Sequential(*self.model)
654
+ self.output_dim = dim
655
+
656
+ def forward(self, x):
657
+ if self.vae:
658
+ output = self.model(x)
659
+ output = output.view(x.size(0), -1)
660
+ output_mean = self.fc_mean(output)
661
+ output_var = self.fc_var(output)
662
+ return output_mean, output_var
663
+ else:
664
+ return self.model(x).view(x.size(0), -1)
665
+
666
+
667
+ class ContentEncoder(nn.Module):
668
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type='zero'):
669
+ super(ContentEncoder, self).__init__()
670
+ self.model = []
671
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type='reflect')]
672
+ # downsampling blocks
673
+ for i in range(n_downsample):
674
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type='reflect')]
675
+ dim *= 2
676
+ # residual blocks
677
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
678
+ self.model = nn.Sequential(*self.model)
679
+ self.output_dim = dim
680
+
681
+ def forward(self, x, nce_layers=[], encode_only=False):
682
+ if len(nce_layers) > 0:
683
+ feat = x
684
+ feats = []
685
+ for layer_id, layer in enumerate(self.model):
686
+ feat = layer(feat)
687
+ if layer_id in nce_layers:
688
+ feats.append(feat)
689
+ if layer_id == nce_layers[-1] and encode_only:
690
+ return None, feats
691
+ return feat, feats
692
+ else:
693
+ return self.model(x), None
694
+
697
+
698
+
699
+ class Decoder_all(nn.Module):
700
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
701
+ super(Decoder_all, self).__init__()
702
+ # AdaIN residual blocks
703
+ self.resnet_block = ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)
704
+ self.n_blocks = 0
705
+ # upsampling blocks
706
+ for i in range(n_upsample):
707
+ block = [Upsample2(scale_factor=2), Conv2dBlock(dim + nz, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
708
+ setattr(self, 'block_{:d}'.format(self.n_blocks), nn.Sequential(*block))
709
+ self.n_blocks += 1
710
+ dim //= 2
711
+ # use reflection padding in the last conv layer
712
+ setattr(self, 'block_{:d}'.format(self.n_blocks), Conv2dBlock(dim + nz, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect'))
713
+ self.n_blocks += 1
714
+
715
+ def forward(self, x, y=None):
716
+ if y is not None:
717
+ output = self.resnet_block(cat_feature(x, y))
718
+ for n in range(self.n_blocks):
719
+ block = getattr(self, 'block_{:d}'.format(n))
720
+ if n > 0:
721
+ output = block(cat_feature(output, y))
722
+ else:
723
+ output = block(output)
724
+ return output
725
+
726
+
727
+ class Decoder(nn.Module):
728
+ def __init__(self, n_upsample, n_res, dim, output_dim, norm='batch', activ='relu', pad_type='zero', nz=0):
729
+ super(Decoder, self).__init__()
730
+
731
+ self.model = []
732
+ # AdaIN residual blocks
733
+ self.model += [ResBlocks(n_res, dim, norm, activ, pad_type=pad_type, nz=nz)]
734
+ # upsampling blocks
735
+ for i in range(n_upsample):
736
+ if i == 0:
737
+ input_dim = dim + nz
738
+ else:
739
+ input_dim = dim
740
+ self.model += [Upsample2(scale_factor=2), Conv2dBlock(input_dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type='reflect')]
741
+ dim //= 2
742
+ # use reflection padding in the last conv layer
743
+ self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type='reflect')]
744
+ self.model = nn.Sequential(*self.model)
745
+
746
+ def forward(self, x, y=None):
747
+ if y is not None:
748
+ return self.model(cat_feature(x, y))
749
+ else:
750
+ return self.model(x)
751
+
752
+ ##################################################################################
753
+ # Sequential Models
754
+ ##################################################################################
755
+
756
+
757
+ class ResBlocks(nn.Module):
758
+ def __init__(self, num_blocks, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
759
+ super(ResBlocks, self).__init__()
760
+ self.model = []
761
+ for i in range(num_blocks):
762
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type, nz=nz)]
763
+ self.model = nn.Sequential(*self.model)
764
+
765
+ def forward(self, x):
766
+ return self.model(x)
767
+
768
+
769
+ ##################################################################################
770
+ # Basic Blocks
771
+ ##################################################################################
772
+ def cat_feature(x, y):
773
+ y_expand = y.view(y.size(0), y.size(1), 1, 1).expand(
774
+ y.size(0), y.size(1), x.size(2), x.size(3))
775
+ x_cat = torch.cat([x, y_expand], 1)
776
+ return x_cat
777
+
778
+
779
+ class ResBlock(nn.Module):
780
+ def __init__(self, dim, norm='inst', activation='relu', pad_type='zero', nz=0):
781
+ super(ResBlock, self).__init__()
782
+
783
+ model = []
784
+ model += [Conv2dBlock(dim + nz, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
785
+ model += [Conv2dBlock(dim, dim + nz, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
786
+ self.model = nn.Sequential(*model)
787
+
788
+ def forward(self, x):
789
+ residual = x
790
+ out = self.model(x)
791
+ out += residual
792
+ return out
793
+
794
+
795
+ class Conv2dBlock(nn.Module):
796
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
797
+ padding=0, norm='none', activation='relu', pad_type='zero'):
798
+ super(Conv2dBlock, self).__init__()
799
+ self.use_bias = True
800
+ # initialize padding
801
+ if pad_type == 'reflect':
802
+ self.pad = nn.ReflectionPad2d(padding)
803
+ elif pad_type == 'zero':
804
+ self.pad = nn.ZeroPad2d(padding)
805
+ else:
806
+ assert 0, "Unsupported padding type: {}".format(pad_type)
807
+
808
+ # initialize normalization
809
+ norm_dim = output_dim
810
+ if norm == 'batch':
811
+ self.norm = nn.BatchNorm2d(norm_dim)
812
+ elif norm == 'inst':
813
+ self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=False)
814
+ elif norm == 'ln':
815
+ self.norm = LayerNorm(norm_dim)
816
+ elif norm == 'none':
817
+ self.norm = None
818
+ else:
819
+ assert 0, "Unsupported normalization: {}".format(norm)
820
+
821
+ # initialize activation
822
+ if activation == 'relu':
823
+ self.activation = nn.ReLU(inplace=True)
824
+ elif activation == 'lrelu':
825
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
826
+ elif activation == 'prelu':
827
+ self.activation = nn.PReLU()
828
+ elif activation == 'selu':
829
+ self.activation = nn.SELU(inplace=True)
830
+ elif activation == 'tanh':
831
+ self.activation = nn.Tanh()
832
+ elif activation == 'none':
833
+ self.activation = None
834
+ else:
835
+ assert 0, "Unsupported activation: {}".format(activation)
836
+
837
+ # initialize convolution
838
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
839
+
840
+ def forward(self, x):
841
+ x = self.conv(self.pad(x))
842
+ if self.norm:
843
+ x = self.norm(x)
844
+ if self.activation:
845
+ x = self.activation(x)
846
+ return x
847
+
848
+
849
+ class LinearBlock(nn.Module):
850
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
851
+ super(LinearBlock, self).__init__()
852
+ use_bias = True
853
+ # initialize fully connected layer
854
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
855
+
856
+ # initialize normalization
857
+ norm_dim = output_dim
858
+ if norm == 'batch':
859
+ self.norm = nn.BatchNorm1d(norm_dim)
860
+ elif norm == 'inst':
861
+ self.norm = nn.InstanceNorm1d(norm_dim)
862
+ elif norm == 'ln':
863
+ self.norm = LayerNorm(norm_dim)
864
+ elif norm == 'none':
865
+ self.norm = None
866
+ else:
867
+ assert 0, "Unsupported normalization: {}".format(norm)
868
+
869
+ # initialize activation
870
+ if activation == 'relu':
871
+ self.activation = nn.ReLU(inplace=True)
872
+ elif activation == 'lrelu':
873
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
874
+ elif activation == 'prelu':
875
+ self.activation = nn.PReLU()
876
+ elif activation == 'selu':
877
+ self.activation = nn.SELU(inplace=True)
878
+ elif activation == 'tanh':
879
+ self.activation = nn.Tanh()
880
+ elif activation == 'none':
881
+ self.activation = None
882
+ else:
883
+ assert 0, "Unsupported activation: {}".format(activation)
884
+
885
+ def forward(self, x):
886
+ out = self.fc(x)
887
+ if self.norm:
888
+ out = self.norm(out)
889
+ if self.activation:
890
+ out = self.activation(out)
891
+ return out
892
+
893
+ ##################################################################################
894
+ # Normalization layers
895
+ ##################################################################################
896
+
897
+
898
+ class LayerNorm(nn.Module):
899
+ def __init__(self, num_features, eps=1e-5, affine=True):
900
+ super(LayerNorm, self).__init__()
901
+ self.num_features = num_features
902
+ self.affine = affine
903
+ self.eps = eps
904
+
905
+ if self.affine:
906
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
907
+ self.beta = nn.Parameter(torch.zeros(num_features))
908
+
909
+ def forward(self, x):
910
+ shape = [-1] + [1] * (x.dim() - 1)
911
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
912
+ std = x.view(x.size(0), -1).std(1).view(*shape)
913
+ x = (x - mean) / (std + self.eps)
914
+
915
+ if self.affine:
916
+ shape = [1, -1] + [1] * (x.dim() - 2)
917
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
918
+ return x
919
+
920
+
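Note that this LayerNorm differs from torch.nn.LayerNorm: it computes a single mean and std per sample over all remaining dimensions, then applies a per-channel affine. A small shape sketch (illustrative only):

    ln = LayerNorm(num_features=8)
    x = torch.randn(2, 8, 4, 4)
    y = ln(x)        # one mean/std per sample over all 8*4*4 values
    print(y.shape)   # torch.Size([2, 8, 4, 4])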
921
+ class ResnetGenerator(nn.Module):
922
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
923
+
924
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
925
+ """
926
+
927
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, no_antialias_up=False, opt=None):
928
+ """Construct a Resnet-based generator
929
+
930
+ Parameters:
931
+ input_nc (int) -- the number of channels in input images
932
+ output_nc (int) -- the number of channels in output images
933
+ ngf (int) -- the number of filters in the last conv layer
934
+ norm_layer -- normalization layer
935
+ use_dropout (bool) -- if use dropout layers
936
+ n_blocks (int) -- the number of ResNet blocks
937
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
938
+ """
939
+ assert(n_blocks >= 0)
940
+ super(ResnetGenerator, self).__init__()
941
+ self.opt = opt
942
+ if type(norm_layer) == functools.partial:
943
+ use_bias = norm_layer.func == nn.InstanceNorm2d
944
+ else:
945
+ use_bias = norm_layer == nn.InstanceNorm2d
946
+
947
+ model = [nn.ReflectionPad2d(3),
948
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
949
+ norm_layer(ngf),
950
+ nn.ReLU(True)]
951
+
952
+ n_downsampling = 2
953
+ for i in range(n_downsampling): # add downsampling layers
954
+ mult = 2 ** i
955
+ if(no_antialias):
956
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
957
+ norm_layer(ngf * mult * 2),
958
+ nn.ReLU(True)]
959
+ else:
960
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias),
961
+ norm_layer(ngf * mult * 2),
962
+ nn.ReLU(True),
963
+ Downsample(ngf * mult * 2)]
964
+
965
+ mult = 2 ** n_downsampling
966
+ for i in range(n_blocks): # add ResNet blocks
967
+
968
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
969
+
970
+ for i in range(n_downsampling): # add upsampling layers
971
+ mult = 2 ** (n_downsampling - i)
972
+ if no_antialias_up:
973
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
974
+ kernel_size=3, stride=2,
975
+ padding=1, output_padding=1,
976
+ bias=use_bias),
977
+ norm_layer(int(ngf * mult / 2)),
978
+ nn.ReLU(True)]
979
+ else:
980
+ model += [Upsample(ngf * mult),
981
+ nn.Conv2d(ngf * mult, int(ngf * mult / 2),
982
+ kernel_size=3, stride=1,
983
+ padding=1, # output_padding=1,
984
+ bias=use_bias),
985
+ norm_layer(int(ngf * mult / 2)),
986
+ nn.ReLU(True)]
987
+ model += [nn.ReflectionPad2d(3)]
988
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
989
+ model += [nn.Tanh()]
990
+
991
+ self.model = nn.Sequential(*model)
992
+
993
+ def forward(self, input, layers=[], encode_only=False):
994
+ if -1 in layers:
995
+ layers.append(len(self.model))
996
+ if len(layers) > 0:
997
+ feat = input
998
+ feats = []
999
+ for layer_id, layer in enumerate(self.model):
1000
+ # print(layer_id, layer)
1001
+ feat = layer(feat)
1002
+ if layer_id in layers:
1003
+ # print("%d: adding the output of %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
1004
+ feats.append(feat)
1005
+ else:
1006
+ # print("%d: skipping %s %d" % (layer_id, layer.__class__.__name__, feat.size(1)))
1007
+ pass
1008
+ if layer_id == layers[-1] and encode_only:
1009
+ # print('encoder only return features')
1010
+ return feats # return intermediate features alone; stop in the last layers
1011
+
1012
+ return feat, feats # return both output and intermediate features
1013
+ else:
1014
+ """Standard forward"""
1015
+ fake = self.model(input)
1016
+ return fake
1017
+
1018
+
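A minimal sketch of the two ways ResnetGenerator.forward is used (illustrative, not part of the commit): a plain translation pass, and an encode_only pass that returns intermediate features for the patch-based losses:

    net = ResnetGenerator(input_nc=3, output_nc=3, ngf=64, n_blocks=9)
    x = torch.randn(1, 3, 256, 256)
    fake = net(x)                                       # standard forward
    feats = net(x, layers=[0, 4, 8], encode_only=True)  # stops after layer 8
    print(fake.shape, len(feats))                       # torch.Size([1, 3, 256, 256]) 3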
1019
+ class ResnetDecoder(nn.Module):
1020
+ """Resnet-based decoder that consists of a few Resnet blocks + a few upsampling operations.
1021
+ """
1022
+
1023
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1024
+ """Construct a Resnet-based decoder
1025
+
1026
+ Parameters:
1027
+ input_nc (int) -- the number of channels in input images
1028
+ output_nc (int) -- the number of channels in output images
1029
+ ngf (int) -- the number of filters in the last conv layer
1030
+ norm_layer -- normalization layer
1031
+ use_dropout (bool) -- if use dropout layers
1032
+ n_blocks (int) -- the number of ResNet blocks
1033
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1034
+ """
1035
+ assert(n_blocks >= 0)
1036
+ super(ResnetDecoder, self).__init__()
1037
+ if type(norm_layer) == functools.partial:
1038
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1039
+ else:
1040
+ use_bias = norm_layer == nn.InstanceNorm2d
1041
+ model = []
1042
+ n_downsampling = 2
1043
+ mult = 2 ** n_downsampling
1044
+ for i in range(n_blocks): # add ResNet blocks
1045
+
1046
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1047
+
1048
+ for i in range(n_downsampling): # add upsampling layers
1049
+ mult = 2 ** (n_downsampling - i)
1050
+ if(no_antialias):
1051
+ model += [SpectralNorm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
1052
+ kernel_size=3, stride=2,
1053
+ padding=1, output_padding=1,
1054
+ bias=use_bias)),
1055
+ nn.ReLU(True)]
1056
+ else:
1057
+ model += [Upsample(ngf * mult),
1058
+ SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
1059
+ kernel_size=3, stride=1,
1060
+ padding=1,
1061
+ bias=use_bias)),
1062
+ norm_layer(int(ngf * mult / 2)),
1063
+ nn.ReLU(True)]
1064
+ model += [nn.ReflectionPad2d(3)]
1065
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
1066
+ model += [nn.Tanh()]
1067
+
1068
+ self.model = nn.Sequential(*model)
1069
+
1070
+ def forward(self, input):
1071
+ """Standard forward"""
1072
+ return self.model(input)
1073
+
1074
+
1075
+ class ResnetEncoder(nn.Module):
1076
+ """Resnet-based encoder that consists of a few downsampling + several Resnet blocks
1077
+ """
1078
+
1079
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False):
1080
+ """Construct a Resnet-based encoder
1081
+
1082
+ Parameters:
1083
+ input_nc (int) -- the number of channels in input images
1084
+ output_nc (int) -- the number of channels in output images
1085
+ ngf (int) -- the number of filters in the last conv layer
1086
+ norm_layer -- normalization layer
1087
+ use_dropout (bool) -- if use dropout layers
1088
+ n_blocks (int) -- the number of ResNet blocks
1089
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
1090
+ """
1091
+ assert(n_blocks >= 0)
1092
+ super(ResnetEncoder, self).__init__()
1093
+ if type(norm_layer) == functools.partial:
1094
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1095
+ else:
1096
+ use_bias = norm_layer == nn.InstanceNorm2d
1097
+
1098
+ model = [nn.ReflectionPad2d(3),
1099
+ SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias)),
1100
+ nn.ReLU(True)]
1101
+
1102
+ n_downsampling = 2
1103
+ for i in range(n_downsampling): # add downsampling layers
1104
+ mult = 2 ** i
1105
+ if(no_antialias):
1106
+ model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias)),
1107
+ nn.ReLU(True)]
1108
+ else:
1109
+ model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=1, padding=1, bias=use_bias)),
1110
+ nn.ReLU(True),
1111
+ Downsample(ngf * mult * 2)]
1112
+
1113
+ mult = 2 ** n_downsampling
1114
+ for i in range(n_blocks): # add ResNet blocks
1115
+
1116
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
1117
+
1118
+ self.model = nn.Sequential(*model)
1119
+
1120
+ def forward(self, input):
1121
+ """Standard forward"""
1122
+ return self.model(input)
1123
+
1124
+
1125
+ class ResnetBlock(nn.Module):
1126
+ """Define a Resnet block"""
1127
+
1128
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
1129
+ """Initialize the Resnet block
1130
+
1131
+ A resnet block is a conv block with skip connections
1132
+ We construct a conv block with build_conv_block function,
1133
+ and implement skip connections in <forward> function.
1134
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
1135
+ """
1136
+ super(ResnetBlock, self).__init__()
1137
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
1138
+
1139
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
1140
+ """Construct a convolutional block.
1141
+
1142
+ Parameters:
1143
+ dim (int) -- the number of channels in the conv layer.
1144
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
1145
+ norm_layer -- normalization layer
1146
+ use_dropout (bool) -- if use dropout layers.
1147
+ use_bias (bool) -- if the conv layer uses bias or not
1148
+
1149
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
1150
+ """
1151
+ conv_block = []
1152
+ p = 0
1153
+ if padding_type == 'reflect':
1154
+ conv_block += [nn.ReflectionPad2d(1)]
1155
+ elif padding_type == 'replicate':
1156
+ conv_block += [nn.ReplicationPad2d(1)]
1157
+ elif padding_type == 'zero':
1158
+ p = 1
1159
+ else:
1160
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1161
+
1162
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
1163
+ if use_dropout:
1164
+ conv_block += [nn.Dropout(0.5)]
1165
+
1166
+ p = 0
1167
+ if padding_type == 'reflect':
1168
+ conv_block += [nn.ReflectionPad2d(1)]
1169
+ elif padding_type == 'replicate':
1170
+ conv_block += [nn.ReplicationPad2d(1)]
1171
+ elif padding_type == 'zero':
1172
+ p = 1
1173
+ else:
1174
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
1175
+ conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias))]
1176
+
1177
+ return nn.Sequential(*conv_block)
1178
+
1179
+ def forward(self, x):
1180
+ """Forward function (with skip connections)"""
1181
+ out = x + self.conv_block(x) # add skip connections
1182
+ return out
1183
+
1184
+
1185
+ class UnetGenerator(nn.Module):
1186
+ """Create a Unet-based generator"""
1187
+
1188
+ def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
1189
+ """Construct a Unet generator
1190
+ Parameters:
1191
+ input_nc (int) -- the number of channels in input images
1192
+ output_nc (int) -- the number of channels in output images
1193
+ num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7,
1194
+ image of size 128x128 will become of size 1x1 # at the bottleneck
1195
+ ngf (int) -- the number of filters in the last conv layer
1196
+ norm_layer -- normalization layer
1197
+
1198
+ We construct the U-Net from the innermost layer to the outermost layer.
1199
+ It is a recursive process.
1200
+ """
1201
+ super(UnetGenerator, self).__init__()
1202
+ # construct unet structure
1203
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
1204
+ for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
1205
+ unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
1206
+ # gradually reduce the number of filters from ngf * 8 to ngf
1207
+ unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1208
+ unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1209
+ unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
1210
+ self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
1211
+
1212
+ def forward(self, input):
1213
+ """Standard forward"""
1214
+ return self.model(input)
1215
+
1216
+
1217
+ class UnetSkipConnectionBlock(nn.Module):
1218
+ """Defines the Unet submodule with skip connection.
1219
+ X -------------------identity----------------------
1220
+ |-- downsampling -- |submodule| -- upsampling --|
1221
+ """
1222
+
1223
+ def __init__(self, outer_nc, inner_nc, input_nc=None,
1224
+ submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
1225
+ """Construct a Unet submodule with skip connections.
1226
+
1227
+ Parameters:
1228
+ outer_nc (int) -- the number of filters in the outer conv layer
1229
+ inner_nc (int) -- the number of filters in the inner conv layer
1230
+ input_nc (int) -- the number of channels in input images/features
1231
+ submodule (UnetSkipConnectionBlock) -- previously defined submodules
1232
+ outermost (bool) -- if this module is the outermost module
1233
+ innermost (bool) -- if this module is the innermost module
1234
+ norm_layer -- normalization layer
1235
+ use_dropout (bool) -- if use dropout layers.
1236
+ """
1237
+ super(UnetSkipConnectionBlock, self).__init__()
1238
+ self.outermost = outermost
1239
+ if type(norm_layer) == functools.partial:
1240
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1241
+ else:
1242
+ use_bias = norm_layer == nn.InstanceNorm2d
1243
+ if input_nc is None:
1244
+ input_nc = outer_nc
1245
+ downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
1246
+ stride=2, padding=1, bias=use_bias)
1247
+ downrelu = nn.LeakyReLU(0.2, True)
1248
+ downnorm = norm_layer(inner_nc)
1249
+ uprelu = nn.ReLU(True)
1250
+ upnorm = norm_layer(outer_nc)
1251
+
1252
+ if outermost:
1253
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1254
+ kernel_size=4, stride=2,
1255
+ padding=1)
1256
+ down = [downconv]
1257
+ up = [uprelu, upconv, nn.Tanh()]
1258
+ model = down + [submodule] + up
1259
+ elif innermost:
1260
+ upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
1261
+ kernel_size=4, stride=2,
1262
+ padding=1, bias=use_bias)
1263
+ down = [downrelu, downconv]
1264
+ up = [uprelu, upconv, upnorm]
1265
+ model = down + up
1266
+ else:
1267
+ upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
1268
+ kernel_size=4, stride=2,
1269
+ padding=1, bias=use_bias)
1270
+ down = [downrelu, downconv, downnorm]
1271
+ up = [uprelu, upconv, upnorm]
1272
+
1273
+ if use_dropout:
1274
+ model = down + [submodule] + up + [nn.Dropout(0.5)]
1275
+ else:
1276
+ model = down + [submodule] + up
1277
+
1278
+ self.model = nn.Sequential(*model)
1279
+
1280
+ def forward(self, x):
1281
+ if self.outermost:
1282
+ return self.model(x)
1283
+ else: # add skip connections
1284
+ return torch.cat([x, self.model(x)], 1)
1285
+
1286
+
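A shape sketch for the recursive U-Net construction above (illustrative; eval() avoids BatchNorm's single-value-per-channel error at the 1x1 bottleneck when the batch size is 1):

    unet = UnetGenerator(input_nc=3, output_nc=3, num_downs=7).eval()
    with torch.no_grad():
        out = unet(torch.randn(1, 3, 128, 128))  # 7 downsamplings: 128 -> 1 at the bottleneck
    print(out.shape)                             # torch.Size([1, 3, 128, 128])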
1287
+ class NLayerDiscriminator(nn.Module):
1288
+ """Defines a PatchGAN discriminator"""
1289
+
1290
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
1291
+ """Construct a PatchGAN discriminator
1292
+
1293
+ Parameters:
1294
+ input_nc (int) -- the number of channels in input images
1295
+ ndf (int) -- the number of filters in the last conv layer
1296
+ n_layers (int) -- the number of conv layers in the discriminator
1297
+ norm_layer -- normalization layer
1298
+ """
1299
+ super(NLayerDiscriminator, self).__init__()
1300
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1301
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1302
+ else:
1303
+ use_bias = norm_layer == nn.InstanceNorm2d
1304
+
1305
+ kw = 4
1306
+ padw = 1
1307
+ if(no_antialias):
1308
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
1309
+ else:
1310
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True), Downsample(ndf)]
1311
+ nf_mult = 1
1312
+ nf_mult_prev = 1
1313
+ for n in range(1, n_layers): # gradually increase the number of filters
1314
+ nf_mult_prev = nf_mult
1315
+ nf_mult = min(2 ** n, 8)
1316
+ if(no_antialias):
1317
+ sequence += [
1318
+ SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
1319
+ nn.LeakyReLU(0.2, True)
1320
+ ]
1321
+ else:
1322
+ sequence += [
1323
+ SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
1324
+ nn.LeakyReLU(0.2, True),
1325
+ Downsample(ndf * nf_mult)]
1326
+
1327
+ nf_mult_prev = nf_mult
1328
+ nf_mult = min(2 ** n_layers, 8)
1329
+ sequence += [
1330
+ SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
1331
+ nn.LeakyReLU(0.2, True)
1332
+ ]
1333
+ sequence += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw))] # output 1 channel prediction map
1334
+ self.model = nn.Sequential(*sequence)
1335
+
1336
+ def forward(self, input):
1337
+ """Standard forward."""
1338
+ return self.model(input)
1339
+
1340
+
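A minimal sketch (not part of the commit; it assumes the Downsample and SpectralNorm helpers defined elsewhere in this repo): the PatchGAN discriminator returns a spatial map of logits, one per overlapping input patch, rather than a single scalar per image:

    netD = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
    score_map = netD(torch.randn(1, 3, 256, 256))
    print(score_map.shape)  # [1, 1, H', W'] -- each logit scores one receptive field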
1341
+ class PixelDiscriminator(nn.Module):
1342
+ """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
1343
+
1344
+ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
1345
+ """Construct a 1x1 PatchGAN discriminator
1346
+
1347
+ Parameters:
1348
+ input_nc (int) -- the number of channels in input images
1349
+ ndf (int) -- the number of filters in the last conv layer
1350
+ norm_layer -- normalization layer
1351
+ """
1352
+ super(PixelDiscriminator, self).__init__()
1353
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
1354
+ use_bias = norm_layer.func == nn.InstanceNorm2d
1355
+ else:
1356
+ use_bias = norm_layer == nn.InstanceNorm2d
1357
+
1358
+ self.net = [
1359
+ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
1360
+ nn.LeakyReLU(0.2, True),
1361
+ nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
1362
+ norm_layer(ndf * 2),
1363
+ nn.LeakyReLU(0.2, True),
1364
+ nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
1365
+
1366
+ self.net = nn.Sequential(*self.net)
1367
+
1368
+ def forward(self, input):
1369
+ """Standard forward."""
1370
+ return self.net(input)
1371
+
1372
+
1373
+ class PatchDiscriminator(NLayerDiscriminator):
1374
+ """Defines a PatchGAN discriminator"""
1375
+
1376
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, no_antialias=False):
1377
+ super().__init__(input_nc, ndf, 2, norm_layer, no_antialias)
1378
+
1379
+ def forward(self, input):
1380
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
1381
+ size = 16
1382
+ Y = H // size
1383
+ X = W // size
1384
+ input = input.view(B, C, Y, size, X, size)
1385
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
1386
+ return super().forward(input)
1387
+
1388
+
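The view/permute pair above chops the image into non-overlapping 16x16 tiles and scores each tile independently. A standalone trace of the reshape (illustrative numbers):

    x = torch.randn(2, 3, 64, 64)   # B=2; 64/16 = 4 tiles per side
    t = x.view(2, 3, 4, 16, 4, 16)
    t = t.permute(0, 2, 4, 1, 3, 5).contiguous().view(2 * 4 * 4, 3, 16, 16)
    print(t.shape)                  # torch.Size([32, 3, 16, 16])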
1389
+ class GroupedChannelNorm(nn.Module):
1390
+ def __init__(self, num_groups):
1391
+ super().__init__()
1392
+ self.num_groups = num_groups
1393
+
1394
+ def forward(self, x):
1395
+ shape = list(x.shape)
1396
+ new_shape = [shape[0], self.num_groups, shape[1] // self.num_groups] + shape[2:]
1397
+ x = x.view(*new_shape)
1398
+ mean = x.mean(dim=2, keepdim=True)
1399
+ std = x.std(dim=2, keepdim=True)
1400
+ x_norm = (x - mean) / (std + 1e-7)
1401
+ return x_norm.view(*shape)
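GroupedChannelNorm standardizes across the channels within each group, separately at every spatial position. A quick shape check (illustrative only):

    norm = GroupedChannelNorm(num_groups=4)
    x = torch.randn(2, 16, 8, 8)   # 16 channels -> 4 groups of 4
    y = norm(x)                    # zero mean / unit std across each group's channels
    print(y.shape)                 # torch.Size([2, 16, 8, 8])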
models/patchnce.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from packaging import version
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ class PatchNCELoss(nn.Module):
7
+ def __init__(self, opt):
8
+ super().__init__()
9
+ self.opt = opt
10
+ self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
11
+ self.mask_dtype = torch.uint8 if version.parse(torch.__version__) < version.parse('1.2.0') else torch.bool
12
+
13
+ def forward(self, feat_q, feat_k):
14
+ batchSize = feat_q.shape[0]
15
+ dim = feat_q.shape[1]
16
+ feat_k = feat_k.detach()
17
+
18
+ # pos logit
19
+ l_pos = torch.bmm(feat_q.view(batchSize, 1, -1), feat_k.view(batchSize, -1, 1))
20
+ l_pos = l_pos.view(batchSize, 1)
21
+
22
+ # neg logit
23
+
24
+ # Should the negatives from the other samples of a minibatch be utilized?
25
+ # In CUT and FastCUT, we found that it's best to only include negatives
26
+ # from the same image. Therefore, we set
27
+ # --nce_includes_all_negatives_from_minibatch as False
28
+ # However, for single-image translation, the minibatch consists of
29
+ # crops from the "same" high-resolution image.
30
+ # Therefore, we will include the negatives from the entire minibatch.
31
+ if self.opt.nce_includes_all_negatives_from_minibatch:
32
+ # reshape features as if they are all negatives of minibatch of size 1.
33
+ batch_dim_for_bmm = 1
34
+ else:
35
+ batch_dim_for_bmm = self.opt.batch_size
36
+
37
+ # reshape features to batch size
38
+ feat_q = feat_q.view(batch_dim_for_bmm, -1, dim)
39
+ feat_k = feat_k.view(batch_dim_for_bmm, -1, dim)
40
+ npatches = feat_q.size(1)
41
+ l_neg_curbatch = torch.bmm(feat_q, feat_k.transpose(2, 1))
42
+
43
+ # diagonal entries are similarity between same features, and hence meaningless.
44
+ # just fill the diagonal with very small number, which is exp(-10) and almost zero
45
+ diagonal = torch.eye(npatches, device=feat_q.device, dtype=self.mask_dtype)[None, :, :]
46
+ l_neg_curbatch.masked_fill_(diagonal, -10.0)
47
+ l_neg = l_neg_curbatch.view(-1, npatches)
48
+
49
+ out = torch.cat((l_pos, l_neg), dim=1) / self.opt.nce_T
50
+
51
+ loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long,
52
+ device=feat_q.device))
53
+
54
+ return loss
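A toy invocation of PatchNCELoss (illustrative, not part of the commit; the Opt stub is hypothetical and only carries the three fields forward() reads). Each of the 64 query patches is classified against its own key (the positive, index 0 after the cat) versus the other 63 patches:

    opt = type('Opt', (), {'nce_includes_all_negatives_from_minibatch': False,
                           'batch_size': 1, 'nce_T': 0.07})()
    criterion = PatchNCELoss(opt)
    feat_q = torch.randn(64, 256)     # 64 query patch embeddings
    feat_k = torch.randn(64, 256)     # matching key embeddings (detached inside)
    loss = criterion(feat_q, feat_k)  # per-patch losses, shape [64]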
models/spectralNormalization.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import Tensor
4
+ from torch.nn import Parameter
5
+ from torch import nn
6
+ from torch.autograd import Variable
7
+ def l2normalize(vector, eps=1e-15):
8
+ return vector / (vector.norm() + eps)
9
+
10
+ class SpectralNorm(nn.Module):
11
+ def __init__(self, module, name='weight', power_iterations=1):
12
+ super(SpectralNorm, self).__init__()
13
+ self.module = module
14
+ self.name = name
15
+ self.power_iterations = power_iterations
16
+ if not self._made_params():
17
+ self._make_params()
18
+
19
+ def _update_u_v(self):
20
+ u = getattr(self.module, self.name + "_u")
21
+ v = getattr(self.module, self.name + "_v")
22
+ w = getattr(self.module, self.name + "_bar")
23
+
24
+ height = w.data.shape[0]
25
+ for _ in range(self.power_iterations):
26
+ v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
27
+ u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
28
+
29
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
30
+ sigma = u.dot(w.view(height, -1).mv(v))
31
+ setattr(self.module, self.name, w / sigma.expand_as(w))
32
+
33
+ def _made_params(self):
34
+ try:
35
+ u = getattr(self.module, self.name + "_u")
36
+ v = getattr(self.module, self.name + "_v")
37
+ w = getattr(self.module, self.name + "_bar")
38
+ return True
39
+ except AttributeError:
40
+ return False
41
+
42
+
43
+ def _make_params(self):
44
+ w = getattr(self.module, self.name)
45
+
46
+ height = w.data.shape[0]
47
+ width = w.view(height, -1).data.shape[1]
48
+
49
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
50
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
51
+ u.data = l2normalize(u.data)
52
+ v.data = l2normalize(v.data)
53
+ w_bar = Parameter(w.data)
54
+
55
+ del self.module._parameters[self.name]
56
+
57
+ self.module.register_parameter(self.name + "_u", u)
58
+ self.module.register_parameter(self.name + "_v", v)
59
+ self.module.register_parameter(self.name + "_bar", w_bar)
60
+
61
+
62
+ def forward(self, *args):
63
+ self._update_u_v()
64
+ return self.module.forward(*args)
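A minimal usage sketch (illustrative): wrapping any module that owns a weight parameter re-registers it as weight_bar plus the power-iteration vectors u and v, and each forward pass divides the weight by its estimated largest singular value:

    conv = SpectralNorm(nn.Conv2d(64, 128, 3, padding=1))
    y = conv(torch.randn(2, 64, 32, 32))
    print(y.shape)  # torch.Size([2, 128, 32, 32])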
models/stylegan_networks.py ADDED
@@ -0,0 +1,914 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ The network architectures are based on the PyTorch implementation of StyleGAN2.
3
+ Original PyTorch repo: https://github.com/rosinality/style-based-gan-pytorch
4
+ Original StyleGAN2 paper: https://github.com/NVlabs/stylegan2
5
+ We use this network architecture for our single-image training setting.
6
+ """
7
+
8
+ import math
9
+ import numpy as np
10
+ import random
11
+
12
+ import torch
13
+ from torch import nn
14
+ from torch.nn import functional as F
15
+
16
+
17
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
18
+ return F.leaky_relu(input + bias, negative_slope) * scale
19
+
20
+
21
+ class FusedLeakyReLU(nn.Module):
22
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
23
+ super().__init__()
24
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
25
+ self.negative_slope = negative_slope
26
+ self.scale = scale
27
+
28
+ def forward(self, input):
29
+ # print("FusedLeakyReLU: ", input.abs().mean())
30
+ out = fused_leaky_relu(input, self.bias,
31
+ self.negative_slope,
32
+ self.scale)
33
+ # print("FusedLeakyReLU: ", out.abs().mean())
34
+ return out
35
+
36
+
37
+ def upfirdn2d_native(
38
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
39
+ ):
40
+ _, minor, in_h, in_w = input.shape
41
+ kernel_h, kernel_w = kernel.shape
42
+
43
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
44
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
45
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
46
+
47
+ out = F.pad(
48
+ out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
49
+ )
50
+ out = out[
51
+ :,
52
+ :,
53
+ max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
54
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0),
55
+ ]
56
+
57
+ # out = out.permute(0, 3, 1, 2)
58
+ out = out.reshape(
59
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
60
+ )
61
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
62
+ out = F.conv2d(out, w)
63
+ out = out.reshape(
64
+ -1,
65
+ minor,
66
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
67
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
68
+ )
69
+ # out = out.permute(0, 2, 3, 1)
70
+
71
+ return out[:, :, ::down_y, ::down_x]
72
+
73
+
74
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
75
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
76
+
77
+
78
+ class PixelNorm(nn.Module):
79
+ def __init__(self):
80
+ super().__init__()
81
+
82
+ def forward(self, input):
83
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
84
+
85
+
86
+ def make_kernel(k):
87
+ k = torch.tensor(k, dtype=torch.float32)
88
+
89
+ if len(k.shape) == 1:
90
+ k = k[None, :] * k[:, None]
91
+
92
+ k /= k.sum()
93
+
94
+ return k
95
+
96
+
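make_kernel turns a 1-D tap list into a normalized 2-D separable filter; [1, 3, 3, 1] is the binomial blur used throughout this file. A worked check (illustrative):

    k = make_kernel([1, 3, 3, 1])
    print(k.shape)         # torch.Size([4, 4]) -- outer product of the taps with themselves
    print(float(k.sum()))  # 1.0 -- normalized by (1+3+3+1)**2 = 64
    print(k[0])            # tensor([0.0156, 0.0469, 0.0469, 0.0156]) == [1, 3, 3, 1] / 64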
97
+ class Upsample(nn.Module):
98
+ def __init__(self, kernel, factor=2):
99
+ super().__init__()
100
+
101
+ self.factor = factor
102
+ kernel = make_kernel(kernel) * (factor ** 2)
103
+ self.register_buffer('kernel', kernel)
104
+
105
+ p = kernel.shape[0] - factor
106
+
107
+ pad0 = (p + 1) // 2 + factor - 1
108
+ pad1 = p // 2
109
+
110
+ self.pad = (pad0, pad1)
111
+
112
+ def forward(self, input):
113
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
114
+
115
+ return out
116
+
117
+
118
+ class Downsample(nn.Module):
119
+ def __init__(self, kernel, factor=2):
120
+ super().__init__()
121
+
122
+ self.factor = factor
123
+ kernel = make_kernel(kernel)
124
+ self.register_buffer('kernel', kernel)
125
+
126
+ p = kernel.shape[0] - factor
127
+
128
+ pad0 = (p + 1) // 2
129
+ pad1 = p // 2
130
+
131
+ self.pad = (pad0, pad1)
132
+
133
+ def forward(self, input):
134
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
135
+
136
+ return out
137
+
138
+
139
+ class Blur(nn.Module):
140
+ def __init__(self, kernel, pad, upsample_factor=1):
141
+ super().__init__()
142
+
143
+ kernel = make_kernel(kernel)
144
+
145
+ if upsample_factor > 1:
146
+ kernel = kernel * (upsample_factor ** 2)
147
+
148
+ self.register_buffer('kernel', kernel)
149
+
150
+ self.pad = pad
151
+
152
+ def forward(self, input):
153
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
154
+
155
+ return out
156
+
157
+
158
+ class EqualConv2d(nn.Module):
159
+ def __init__(
160
+ self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
161
+ ):
162
+ super().__init__()
163
+
164
+ self.weight = nn.Parameter(
165
+ torch.randn(out_channel, in_channel, kernel_size, kernel_size)
166
+ )
167
+ self.scale = math.sqrt(1) / math.sqrt(in_channel * (kernel_size ** 2))
168
+
169
+ self.stride = stride
170
+ self.padding = padding
171
+
172
+ if bias:
173
+ self.bias = nn.Parameter(torch.zeros(out_channel))
174
+
175
+ else:
176
+ self.bias = None
177
+
178
+ def forward(self, input):
179
+ # print("Before EqualConv2d: ", input.abs().mean())
180
+ out = F.conv2d(
181
+ input,
182
+ self.weight * self.scale,
183
+ bias=self.bias,
184
+ stride=self.stride,
185
+ padding=self.padding,
186
+ )
187
+ # print("After EqualConv2d: ", out.abs().mean(), (self.weight * self.scale).abs().mean())
188
+
189
+ return out
190
+
191
+ def __repr__(self):
192
+ return (
193
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
194
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
195
+ )
196
+
197
+
198
+ class EqualLinear(nn.Module):
199
+ def __init__(
200
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
201
+ ):
202
+ super().__init__()
203
+
204
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
205
+
206
+ if bias:
207
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
208
+
209
+ else:
210
+ self.bias = None
211
+
212
+ self.activation = activation
213
+
214
+ self.scale = (math.sqrt(1) / math.sqrt(in_dim)) * lr_mul
215
+ self.lr_mul = lr_mul
216
+
217
+ def forward(self, input):
218
+ if self.activation:
219
+ out = F.linear(input, self.weight * self.scale)
220
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
221
+
222
+ else:
223
+ out = F.linear(
224
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
225
+ )
226
+
227
+ return out
228
+
229
+ def __repr__(self):
230
+ return (
231
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
232
+ )
233
+
234
+
235
+ class ScaledLeakyReLU(nn.Module):
236
+ def __init__(self, negative_slope=0.2):
237
+ super().__init__()
238
+
239
+ self.negative_slope = negative_slope
240
+
241
+ def forward(self, input):
242
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
243
+
244
+ return out * math.sqrt(2)
245
+
246
+
247
+ class ModulatedConv2d(nn.Module):
248
+ def __init__(
249
+ self,
250
+ in_channel,
251
+ out_channel,
252
+ kernel_size,
253
+ style_dim,
254
+ demodulate=True,
255
+ upsample=False,
256
+ downsample=False,
257
+ blur_kernel=[1, 3, 3, 1],
258
+ ):
259
+ super().__init__()
260
+
261
+ self.eps = 1e-8
262
+ self.kernel_size = kernel_size
263
+ self.in_channel = in_channel
264
+ self.out_channel = out_channel
265
+ self.upsample = upsample
266
+ self.downsample = downsample
267
+
268
+ if upsample:
269
+ factor = 2
270
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
271
+ pad0 = (p + 1) // 2 + factor - 1
272
+ pad1 = p // 2 + 1
273
+
274
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
275
+
276
+ if downsample:
277
+ factor = 2
278
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
279
+ pad0 = (p + 1) // 2
280
+ pad1 = p // 2
281
+
282
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
283
+
284
+ fan_in = in_channel * kernel_size ** 2
285
+ self.scale = math.sqrt(1) / math.sqrt(fan_in)
286
+ self.padding = kernel_size // 2
287
+
288
+ self.weight = nn.Parameter(
289
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
290
+ )
291
+
292
+ if style_dim is not None and style_dim > 0:
293
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
294
+
295
+ self.demodulate = demodulate
296
+
297
+ def __repr__(self):
298
+ return (
299
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
300
+ f'upsample={self.upsample}, downsample={self.downsample})'
301
+ )
302
+
303
+ def forward(self, input, style):
304
+ batch, in_channel, height, width = input.shape
305
+
306
+ if style is not None:
307
+ style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
308
+ else:
309
+ style = torch.ones(batch, 1, in_channel, 1, 1, device=input.device)  # match the input's device instead of assuming CUDA
310
+ weight = self.scale * self.weight * style
311
+
312
+ if self.demodulate:
313
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
314
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
315
+
316
+ weight = weight.view(
317
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
318
+ )
319
+
320
+ if self.upsample:
321
+ input = input.view(1, batch * in_channel, height, width)
322
+ weight = weight.view(
323
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
324
+ )
325
+ weight = weight.transpose(1, 2).reshape(
326
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
327
+ )
328
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
329
+ _, _, height, width = out.shape
330
+ out = out.view(batch, self.out_channel, height, width)
331
+ out = self.blur(out)
332
+
333
+ elif self.downsample:
334
+ input = self.blur(input)
335
+ _, _, height, width = input.shape
336
+ input = input.view(1, batch * in_channel, height, width)
337
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
338
+ _, _, height, width = out.shape
339
+ out = out.view(batch, self.out_channel, height, width)
340
+
341
+ else:
342
+ input = input.view(1, batch * in_channel, height, width)
343
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
344
+ _, _, height, width = out.shape
345
+ out = out.view(batch, self.out_channel, height, width)
346
+
347
+ return out
348
+
349
+
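A shape sketch of the modulated convolution (illustrative, not part of the commit): one style vector per sample scales the input channels of the kernel, demodulation rescales per output channel, and the batch is folded into a grouped conv so every sample gets its own kernel:

    conv = ModulatedConv2d(in_channel=64, out_channel=128, kernel_size=3, style_dim=512)
    x = torch.randn(2, 64, 16, 16)
    w = torch.randn(2, 512)   # one style vector per sample
    print(conv(x, w).shape)   # torch.Size([2, 128, 16, 16])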
350
+ class NoiseInjection(nn.Module):
351
+ def __init__(self):
352
+ super().__init__()
353
+
354
+ self.weight = nn.Parameter(torch.zeros(1))
355
+
356
+ def forward(self, image, noise=None):
357
+ if noise is None:
358
+ batch, _, height, width = image.shape
359
+ noise = image.new_empty(batch, 1, height, width).normal_()
360
+
361
+ return image + self.weight * noise
362
+
363
+
364
+ class ConstantInput(nn.Module):
365
+ def __init__(self, channel, size=4):
366
+ super().__init__()
367
+
368
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
369
+
370
+ def forward(self, input):
371
+ batch = input.shape[0]
372
+ out = self.input.repeat(batch, 1, 1, 1)
373
+
374
+ return out
375
+
376
+
377
+ class StyledConv(nn.Module):
378
+ def __init__(
379
+ self,
380
+ in_channel,
381
+ out_channel,
382
+ kernel_size,
383
+ style_dim=None,
384
+ upsample=False,
385
+ blur_kernel=[1, 3, 3, 1],
386
+ demodulate=True,
387
+ inject_noise=True,
388
+ ):
389
+ super().__init__()
390
+
391
+ self.inject_noise = inject_noise
392
+ self.conv = ModulatedConv2d(
393
+ in_channel,
394
+ out_channel,
395
+ kernel_size,
396
+ style_dim,
397
+ upsample=upsample,
398
+ blur_kernel=blur_kernel,
399
+ demodulate=demodulate,
400
+ )
401
+
402
+ self.noise = NoiseInjection()
403
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
404
+ # self.activate = ScaledLeakyReLU(0.2)
405
+ self.activate = FusedLeakyReLU(out_channel)
406
+
407
+ def forward(self, input, style=None, noise=None):
408
+ out = self.conv(input, style)
409
+ if self.inject_noise:
410
+ out = self.noise(out, noise=noise)
411
+ # out = out + self.bias
412
+ out = self.activate(out)
413
+
414
+ return out
415
+
416
+
417
+ class ToRGB(nn.Module):
418
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
419
+ super().__init__()
420
+
421
+ if upsample:
422
+ self.upsample = Upsample(blur_kernel)
423
+
424
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
425
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
426
+
427
+ def forward(self, input, style, skip=None):
428
+ out = self.conv(input, style)
429
+ out = out + self.bias
430
+
431
+ if skip is not None:
432
+ skip = self.upsample(skip)
433
+
434
+ out = out + skip
435
+
436
+ return out
437
+
438
+
439
+ class Generator(nn.Module):
440
+ def __init__(
441
+ self,
442
+ size,
443
+ style_dim,
444
+ n_mlp,
445
+ channel_multiplier=2,
446
+ blur_kernel=[1, 3, 3, 1],
447
+ lr_mlp=0.01,
448
+ ):
449
+ super().__init__()
450
+
451
+ self.size = size
452
+
453
+ self.style_dim = style_dim
454
+
455
+ layers = [PixelNorm()]
456
+
457
+ for i in range(n_mlp):
458
+ layers.append(
459
+ EqualLinear(
460
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
461
+ )
462
+ )
463
+
464
+ self.style = nn.Sequential(*layers)
465
+
466
+ self.channels = {
467
+ 4: 512,
468
+ 8: 512,
469
+ 16: 512,
470
+ 32: 512,
471
+ 64: 256 * channel_multiplier,
472
+ 128: 128 * channel_multiplier,
473
+ 256: 64 * channel_multiplier,
474
+ 512: 32 * channel_multiplier,
475
+ 1024: 16 * channel_multiplier,
476
+ }
477
+
478
+ self.input = ConstantInput(self.channels[4])
479
+ self.conv1 = StyledConv(
480
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
481
+ )
482
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
483
+
484
+ self.log_size = int(math.log(size, 2))
485
+ self.num_layers = (self.log_size - 2) * 2 + 1
486
+
487
+ self.convs = nn.ModuleList()
488
+ self.upsamples = nn.ModuleList()
489
+ self.to_rgbs = nn.ModuleList()
490
+ self.noises = nn.Module()
491
+
492
+ in_channel = self.channels[4]
493
+
494
+ for layer_idx in range(self.num_layers):
495
+ res = (layer_idx + 5) // 2
496
+ shape = [1, 1, 2 ** res, 2 ** res]
497
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
498
+
499
+ for i in range(3, self.log_size + 1):
500
+ out_channel = self.channels[2 ** i]
501
+
502
+ self.convs.append(
503
+ StyledConv(
504
+ in_channel,
505
+ out_channel,
506
+ 3,
507
+ style_dim,
508
+ upsample=True,
509
+ blur_kernel=blur_kernel,
510
+ )
511
+ )
512
+
513
+ self.convs.append(
514
+ StyledConv(
515
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
516
+ )
517
+ )
518
+
519
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
520
+
521
+ in_channel = out_channel
522
+
523
+ self.n_latent = self.log_size * 2 - 2
524
+
525
+ def make_noise(self):
526
+ device = self.input.input.device
527
+
528
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
529
+
530
+ for i in range(3, self.log_size + 1):
531
+ for _ in range(2):
532
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
533
+
534
+ return noises
535
+
536
+ def mean_latent(self, n_latent):
537
+ latent_in = torch.randn(
538
+ n_latent, self.style_dim, device=self.input.input.device
539
+ )
540
+ latent = self.style(latent_in).mean(0, keepdim=True)
541
+
542
+ return latent
543
+
544
+ def get_latent(self, input):
545
+ return self.style(input)
546
+
547
+ def forward(
548
+ self,
549
+ styles,
550
+ return_latents=False,
551
+ inject_index=None,
552
+ truncation=1,
553
+ truncation_latent=None,
554
+ input_is_latent=False,
555
+ noise=None,
556
+ randomize_noise=True,
557
+ ):
558
+ if not input_is_latent:
559
+ styles = [self.style(s) for s in styles]
560
+
561
+ if noise is None:
562
+ if randomize_noise:
563
+ noise = [None] * self.num_layers
564
+ else:
565
+ noise = [
566
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
567
+ ]
568
+
569
+ if truncation < 1:
570
+ style_t = []
571
+
572
+ for style in styles:
573
+ style_t.append(
574
+ truncation_latent + truncation * (style - truncation_latent)
575
+ )
576
+
577
+ styles = style_t
578
+
579
+ if len(styles) < 2:
580
+ inject_index = self.n_latent
581
+
582
+ if len(styles[0].shape) < 3:
583
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
584
+
585
+ else:
586
+ latent = styles[0]
587
+
588
+ else:
589
+ if inject_index is None:
590
+ inject_index = random.randint(1, self.n_latent - 1)
591
+
592
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
593
+ latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
594
+
595
+ latent = torch.cat([latent, latent2], 1)
596
+
597
+ out = self.input(latent)
598
+ out = self.conv1(out, latent[:, 0], noise=noise[0])
599
+
600
+ skip = self.to_rgb1(out, latent[:, 1])
601
+
602
+ i = 1
603
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
604
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
605
+ ):
606
+ out = conv1(out, latent[:, i], noise=noise1)
607
+ out = conv2(out, latent[:, i + 1], noise=noise2)
608
+ skip = to_rgb(out, latent[:, i + 2], skip)
609
+
610
+ i += 2
611
+
612
+ image = skip
613
+
614
+ if return_latents:
615
+ return image, latent
616
+
617
+ else:
618
+ return image, None
619
+
620
+
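A minimal end-to-end sketch of the Generator (illustrative; size must be a power of two present in the channels table). Passing a single z disables style mixing and broadcasts it to all n_latent layers:

    g = Generator(size=256, style_dim=512, n_mlp=8)
    z = torch.randn(4, 512)
    img, _ = g([z])   # styles are passed as a list; here one latent, no mixing
    print(img.shape)  # torch.Size([4, 3, 256, 256])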
621
+ class ConvLayer(nn.Sequential):
622
+ def __init__(
623
+ self,
624
+ in_channel,
625
+ out_channel,
626
+ kernel_size,
627
+ downsample=False,
628
+ blur_kernel=[1, 3, 3, 1],
629
+ bias=True,
630
+ activate=True,
631
+ ):
632
+ layers = []
633
+
634
+ if downsample:
635
+ factor = 2
636
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
637
+ pad0 = (p + 1) // 2
638
+ pad1 = p // 2
639
+
640
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
641
+
642
+ stride = 2
643
+ self.padding = 0
644
+
645
+ else:
646
+ stride = 1
647
+ self.padding = kernel_size // 2
648
+
649
+ layers.append(
650
+ EqualConv2d(
651
+ in_channel,
652
+ out_channel,
653
+ kernel_size,
654
+ padding=self.padding,
655
+ stride=stride,
656
+ bias=bias and not activate,
657
+ )
658
+ )
659
+
660
+ if activate:
661
+ if bias:
662
+ layers.append(FusedLeakyReLU(out_channel))
663
+
664
+ else:
665
+ layers.append(ScaledLeakyReLU(0.2))
666
+
667
+ super().__init__(*layers)
668
+
669
+
670
+ class ResBlock(nn.Module):
671
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1], downsample=True, skip_gain=1.0):
672
+ super().__init__()
673
+
674
+ self.skip_gain = skip_gain
675
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
676
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=downsample, blur_kernel=blur_kernel)
677
+
678
+ if in_channel != out_channel or downsample:
679
+ self.skip = ConvLayer(
680
+ in_channel, out_channel, 1, downsample=downsample, activate=False, bias=False
681
+ )
682
+ else:
683
+ self.skip = nn.Identity()
684
+
685
+ def forward(self, input):
686
+ out = self.conv1(input)
687
+ out = self.conv2(out)
688
+
689
+ skip = self.skip(input)
690
+ out = (out * self.skip_gain + skip) / math.sqrt(self.skip_gain ** 2 + 1.0)
691
+
692
+ return out
693
+
694
+
695
+ class StyleGAN2Discriminator(nn.Module):
696
+ def __init__(self, input_nc, ndf=64, n_layers=3, no_antialias=False, size=None, opt=None):
697
+ super().__init__()
698
+ self.opt = opt
699
+ self.stddev_group = 16
700
+ if size is None:
701
+ size = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
702
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
703
+ size = 2 ** int(np.log2(self.opt.D_patch_size))
704
+
705
+ blur_kernel = [1, 3, 3, 1]
706
+ channel_multiplier = ndf / 64
707
+ channels = {
708
+ 4: min(384, int(4096 * channel_multiplier)),
709
+ 8: min(384, int(2048 * channel_multiplier)),
710
+ 16: min(384, int(1024 * channel_multiplier)),
711
+ 32: min(384, int(512 * channel_multiplier)),
712
+ 64: int(256 * channel_multiplier),
713
+ 128: int(128 * channel_multiplier),
714
+ 256: int(64 * channel_multiplier),
715
+ 512: int(32 * channel_multiplier),
716
+ 1024: int(16 * channel_multiplier),
717
+ }
718
+
719
+ convs = [ConvLayer(3, channels[size], 1)]
720
+
721
+ log_size = int(math.log(size, 2))
722
+
723
+ in_channel = channels[size]
724
+
725
+ if "smallpatch" in self.opt.netD:
726
+ final_res_log2 = 4
727
+ elif "patch" in self.opt.netD:
728
+ final_res_log2 = 3
729
+ else:
730
+ final_res_log2 = 2
731
+
732
+ for i in range(log_size, final_res_log2, -1):
733
+ out_channel = channels[2 ** (i - 1)]
734
+
735
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
736
+
737
+ in_channel = out_channel
738
+
739
+ self.convs = nn.Sequential(*convs)
740
+
741
+ if False and "tile" in self.opt.netD:
742
+ in_channel += 1
743
+ self.final_conv = ConvLayer(in_channel, channels[4], 3)
744
+ if "patch" in self.opt.netD:
745
+ self.final_linear = ConvLayer(channels[4], 1, 3, bias=False, activate=False)
746
+ else:
747
+ self.final_linear = nn.Sequential(
748
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
749
+ EqualLinear(channels[4], 1),
750
+ )
751
+
752
+ def forward(self, input, get_minibatch_features=False):
753
+ if "patch" in self.opt.netD and self.opt.D_patch_size is not None:
754
+ h, w = input.size(2), input.size(3)
755
+ y = torch.randint(h - self.opt.D_patch_size, ())
756
+ x = torch.randint(w - self.opt.D_patch_size, ())
757
+ input = input[:, :, y:y + self.opt.D_patch_size, x:x + self.opt.D_patch_size]
758
+ out = input
759
+ for i, conv in enumerate(self.convs):
760
+ out = conv(out)
761
+ # print(i, out.abs().mean())
762
+ # out = self.convs(input)
763
+
764
+ batch, channel, height, width = out.shape
765
+
766
+ if False and "tile" in self.opt.netD:
767
+ group = min(batch, self.stddev_group)
768
+ stddev = out.view(
769
+ group, -1, 1, channel // 1, height, width
770
+ )
771
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
772
+ stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
773
+ stddev = stddev.repeat(group, 1, height, width)
774
+ out = torch.cat([out, stddev], 1)
775
+
776
+ out = self.final_conv(out)
777
+ # print(out.abs().mean())
778
+
779
+ if "patch" not in self.opt.netD:
780
+ out = out.view(batch, -1)
781
+ out = self.final_linear(out)
782
+
783
+ return out
784
+
785
+
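A minimal instantiation sketch for the discriminator above, assuming this module and its StyleGAN2 ops (EqualConv2d, FusedLeakyReLU, EqualLinear, Blur) are importable; the Namespace is a hypothetical stand-in for the real options object, which carries many more fields. With netD='stylegan2' the patch branches are skipped and the network emits one logit per image:

import torch
from argparse import Namespace

opt = Namespace(netD='stylegan2', D_patch_size=None, load_size=286, crop_size=256)
netD = StyleGAN2Discriminator(input_nc=3, ndf=64, opt=opt)  # size rounds to 2**8 = 256
logits = netD(torch.randn(2, 3, 256, 256))
assert logits.shape == (2, 1)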
786
+ class TileStyleGAN2Discriminator(StyleGAN2Discriminator):
787
+ def forward(self, input):
788
+ B, C, H, W = input.size(0), input.size(1), input.size(2), input.size(3)
789
+ size = self.opt.D_patch_size
790
+ Y = H // size
791
+ X = W // size
792
+ input = input.view(B, C, Y, size, X, size)
793
+ input = input.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
794
+ return super().forward(input)
795
+
796
+
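The view/permute pair above cuts each image into non-overlapping D_patch_size tiles and folds them into the batch dimension, so the parent forward scores every tile independently. A standalone check of the same arithmetic:

import torch

x = torch.randn(1, 3, 8, 8)
size = 4                                             # stand-in for opt.D_patch_size
B, C, H, W = x.shape
Y, X = H // size, W // size
tiles = x.view(B, C, Y, size, X, size)
tiles = tiles.permute(0, 2, 4, 1, 3, 5).contiguous().view(B * Y * X, C, size, size)
assert tiles.shape == (4, 3, 4, 4)
assert torch.equal(tiles[0], x[0, :, :4, :4])        # first tile = top-left crop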
797
+ class StyleGAN2Encoder(nn.Module):
798
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
799
+ super().__init__()
800
+ assert opt is not None
801
+ self.opt = opt
802
+ channel_multiplier = ngf / 32
803
+ channels = {
804
+ 4: min(512, int(round(4096 * channel_multiplier))),
805
+ 8: min(512, int(round(2048 * channel_multiplier))),
806
+ 16: min(512, int(round(1024 * channel_multiplier))),
807
+ 32: min(512, int(round(512 * channel_multiplier))),
808
+ 64: int(round(256 * channel_multiplier)),
809
+ 128: int(round(128 * channel_multiplier)),
810
+ 256: int(round(64 * channel_multiplier)),
811
+ 512: int(round(32 * channel_multiplier)),
812
+ 1024: int(round(16 * channel_multiplier)),
813
+ }
814
+
815
+ blur_kernel = [1, 3, 3, 1]
816
+
817
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size)))))
818
+ convs = [nn.Identity(),
819
+ ConvLayer(3, channels[cur_res], 1)]
820
+
821
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
822
+ for i in range(num_downsampling):
823
+ in_channel = channels[cur_res]
824
+ out_channel = channels[cur_res // 2]
825
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel, downsample=True))
826
+ cur_res = cur_res // 2
827
+
828
+ for i in range(n_blocks // 2):
829
+ n_channel = channels[cur_res]
830
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
831
+
832
+ self.convs = nn.Sequential(*convs)
833
+
834
+ def forward(self, input, layers=[], get_features=False):
835
+ feat = input
836
+ feats = []
837
+ if -1 in layers:
838
+ layers.append(len(self.convs) - 1)
839
+ for layer_id, layer in enumerate(self.convs):
840
+ feat = layer(feat)
841
+ # print(layer_id, " features ", feat.abs().mean())
842
+ if layer_id in layers:
843
+ feats.append(feat)
844
+
845
+ if get_features:
846
+ return feat, feats
847
+ else:
848
+ return feat
849
+
850
+
851
+ class StyleGAN2Decoder(nn.Module):
852
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
853
+ super().__init__()
854
+ assert opt is not None
855
+ self.opt = opt
856
+
857
+ blur_kernel = [1, 3, 3, 1]
858
+
859
+ channel_multiplier = ngf / 32
860
+ channels = {
861
+ 4: min(512, int(round(4096 * channel_multiplier))),
862
+ 8: min(512, int(round(2048 * channel_multiplier))),
863
+ 16: min(512, int(round(1024 * channel_multiplier))),
864
+ 32: min(512, int(round(512 * channel_multiplier))),
865
+ 64: int(round(256 * channel_multiplier)),
866
+ 128: int(round(128 * channel_multiplier)),
867
+ 256: int(round(64 * channel_multiplier)),
868
+ 512: int(round(32 * channel_multiplier)),
869
+ 1024: int(round(16 * channel_multiplier)),
870
+ }
871
+
872
+ num_downsampling = self.opt.stylegan2_G_num_downsampling
873
+ cur_res = 2 ** int((np.rint(np.log2(min(opt.load_size, opt.crop_size))))) // (2 ** num_downsampling)
874
+ convs = []
875
+
876
+ for i in range(n_blocks // 2):
877
+ n_channel = channels[cur_res]
878
+ convs.append(ResBlock(n_channel, n_channel, downsample=False))
879
+
880
+ for i in range(num_downsampling):
881
+ in_channel = channels[cur_res]
882
+ out_channel = channels[cur_res * 2]
883
+ inject_noise = "small" not in self.opt.netG
884
+ convs.append(
885
+ StyledConv(in_channel, out_channel, 3, upsample=True, blur_kernel=blur_kernel, inject_noise=inject_noise)
886
+ )
887
+ cur_res = cur_res * 2
888
+
889
+ convs.append(ConvLayer(channels[cur_res], 3, 1))
890
+
891
+ self.convs = nn.Sequential(*convs)
892
+
893
+ def forward(self, input):
894
+ return self.convs(input)
895
+
896
+
897
+ class StyleGAN2Generator(nn.Module):
898
+ def __init__(self, input_nc, output_nc, ngf=64, use_dropout=False, n_blocks=6, padding_type='reflect', no_antialias=False, opt=None):
899
+ super().__init__()
900
+ self.opt = opt
901
+ self.encoder = StyleGAN2Encoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
902
+ self.decoder = StyleGAN2Decoder(input_nc, output_nc, ngf, use_dropout, n_blocks, padding_type, no_antialias, opt)
903
+
904
+ def forward(self, input, layers=[], encode_only=False):
905
+ feat, feats = self.encoder(input, layers, True)
906
+ if encode_only:
907
+ return feats
908
+ else:
909
+ fake = self.decoder(feat)
910
+
911
+ if len(layers) > 0:
912
+ return fake, feats
913
+ else:
914
+ return fake
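A usage sketch for the encoder/decoder generator above, assuming the module is importable; the Namespace is a hypothetical stand-in for the options object built by the repo's parsers. With encode_only=True the call returns only the requested encoder activations, and -1 aliases the last encoder layer:

import torch
from argparse import Namespace

opt = Namespace(load_size=256, crop_size=256, stylegan2_G_num_downsampling=1, netG='stylegan2')
netG = StyleGAN2Generator(input_nc=3, output_nc=3, ngf=64, n_blocks=6, opt=opt)
x = torch.randn(1, 3, 256, 256)
fake = netG(x)                                        # full encode + decode
feats = netG(x, layers=[0, 1, -1], encode_only=True)  # encoder features only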
models/template_model.py ADDED
@@ -0,0 +1,99 @@
1
+ """Model class template
2
+
3
+ This module provides a template for users to implement custom models.
4
+ You can specify '--model template' to use this model.
5
+ The class name should be consistent with both the filename and its model option.
6
+ The filename should be <model>_model.py
7
+ The class name should be <Model>Model
8
+ It implements a simple image-to-image translation baseline based on regression loss.
9
+ Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
10
+ min_<netG> ||netG(data_A) - data_B||_1
11
+ You need to implement the following functions:
12
+ <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
13
+ <__init__>: Initialize this model class.
14
+ <set_input>: Unpack input data and perform data pre-processing.
15
+ <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
16
+ <optimize_parameters>: Update network weights; it will be called in every training iteration.
17
+ """
18
+ import torch
19
+ from .base_model import BaseModel
20
+ from . import networks
21
+
22
+
23
+ class TemplateModel(BaseModel):
24
+ @staticmethod
25
+ def modify_commandline_options(parser, is_train=True):
26
+ """Add new model-specific options and rewrite default values for existing options.
27
+
28
+ Parameters:
29
+ parser -- the option parser
30
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
31
+
32
+ Returns:
33
+ the modified parser.
34
+ """
35
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
36
+ if is_train:
37
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
38
+
39
+ return parser
40
+
41
+ def __init__(self, opt):
42
+ """Initialize this model class.
43
+
44
+ Parameters:
45
+ opt -- training/test options
46
+
47
+ A few things can be done here.
48
+ - (required) call the initialization function of BaseModel
49
+ - define loss function, visualization images, model names, and optimizers
50
+ """
51
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
52
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
53
+ self.loss_names = ['loss_G']
54
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
55
+ self.visual_names = ['data_A', 'data_B', 'output']
56
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
57
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
58
+ self.model_names = ['G']
59
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
60
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
61
+ if self.isTrain: # only defined during training time
62
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
63
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
64
+ self.criterionLoss = torch.nn.L1Loss()
65
+ # define and initialize optimizers. You can define one optimizer for each network.
66
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
67
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
68
+ self.optimizers = [self.optimizer]
69
+
70
+ # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
71
+
72
+ def set_input(self, input):
73
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
74
+
75
+ Parameters:
76
+ input: a dictionary that contains the data itself and its metadata information.
77
+ """
78
+ AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
79
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
80
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
81
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
82
+
83
+ def forward(self):
84
+ """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
85
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
86
+
87
+ def backward(self):
88
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
89
+ # calculate the intermediate results if necessary; here self.output has been computed during function <forward>
90
+ # calculate loss given the input and intermediate results
91
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
92
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
93
+
94
+ def optimize_parameters(self):
95
+ """Update network weights; it will be called in every training iteration."""
96
+ self.forward() # first call forward to calculate intermediate results
97
+ self.optimizer.zero_grad() # clear network G's existing gradients
98
+ self.backward() # calculate gradients for network G
99
+ self.optimizer.step() # update network G's weights
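A minimal training-loop sketch for the template (illustrative only: opt and dataset are assumed to come from the repo's option parser and data loader, and setup() is BaseModel's usual hook for schedulers, loading, and printing; the real entry point is the repo's train.py):

model = TemplateModel(opt)            # opt must provide isTrain, lr, beta1, direction, ...
model.setup(opt)                      # schedulers / checkpoint loading / printing
for i, data in enumerate(dataset):    # data: {'A': ..., 'B': ..., 'A_paths': ...}
    model.set_input(data)             # unpack and move tensors to the device
    model.optimize_parameters()       # forward, backward, optimizer step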
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch>=1.6.0
2
+ torchvision>=0.7.0
3
+ dominate>=2.4.0
4
+ visdom>=0.1.8.8
5
+ packaging
6
+ GPUtil>=1.4.0
7
+ scipy
8
+ Pillow>=6.1.0
9
+ numpy>=1.16.4
test.py ADDED
@@ -0,0 +1,11 @@
1
+ from ModelLoader import ModelLoader
2
+ import os
3
+
4
+ def main():
5
+ model = ModelLoader()
6
+ model.load()
7
+ # Test
8
+ sample = os.path.join(os.path.dirname(__file__), 'examples', 'rawimg.png')
9
+ model.inference(src=sample)
10
+
11
+ main()
util/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ """This package includes a miscellaneous collection of useful helper functions."""
2
+ from .util import *  # re-export the helpers defined in util/util.py
util/get_data.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import tarfile
4
+ import requests
5
+ from warnings import warn
6
+ from zipfile import ZipFile
7
+ from bs4 import BeautifulSoup
8
+ from os.path import abspath, isdir, join, basename
9
+
10
+
11
+ class GetData(object):
12
+ """A Python script for downloading CycleGAN or pix2pix datasets.
13
+
14
+ Parameters:
15
+ technique (str) -- One of: 'cyclegan' or 'pix2pix'.
16
+ verbose (bool) -- If True, print additional information.
17
+
18
+ Examples:
19
+ >>> from util.get_data import GetData
20
+ >>> gd = GetData(technique='cyclegan')
21
+ >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
22
+
23
+ Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
24
+ and 'scripts/download_cyclegan_model.sh'.
25
+ """
26
+
27
+ def __init__(self, technique='cyclegan', verbose=True):
28
+ url_dict = {
29
+ 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
30
+ 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
31
+ }
32
+ self.url = url_dict.get(technique.lower())
33
+ self._verbose = verbose
34
+
35
+ def _print(self, text):
36
+ if self._verbose:
37
+ print(text)
38
+
39
+ @staticmethod
40
+ def _get_options(r):
41
+ soup = BeautifulSoup(r.text, 'lxml')
42
+ options = [h.text for h in soup.find_all('a', href=True)
43
+ if h.text.endswith(('.zip', 'tar.gz'))]
44
+ return options
45
+
46
+ def _present_options(self):
47
+ r = requests.get(self.url)
48
+ options = self._get_options(r)
49
+ print('Options:\n')
50
+ for i, o in enumerate(options):
51
+ print("{0}: {1}".format(i, o))
52
+ choice = input("\nPlease enter the number of the "
53
+ "dataset above you wish to download:")
54
+ return options[int(choice)]
55
+
56
+ def _download_data(self, dataset_url, save_path):
57
+ if not isdir(save_path):
58
+ os.makedirs(save_path)
59
+
60
+ base = basename(dataset_url)
61
+ temp_save_path = join(save_path, base)
62
+
63
+ with open(temp_save_path, "wb") as f:
64
+ r = requests.get(dataset_url)
65
+ f.write(r.content)
66
+
67
+ if base.endswith('.tar.gz'):
68
+ obj = tarfile.open(temp_save_path)
69
+ elif base.endswith('.zip'):
70
+ obj = ZipFile(temp_save_path, 'r')
71
+ else:
72
+ raise ValueError("Unknown File Type: {0}.".format(base))
73
+
74
+ self._print("Unpacking Data...")
75
+ obj.extractall(save_path)
76
+ obj.close()
77
+ os.remove(temp_save_path)
78
+
79
+ def get(self, save_path, dataset=None):
80
+ """
81
+
82
+ Download a dataset.
83
+
84
+ Parameters:
85
+ save_path (str) -- A directory to save the data to.
86
+ dataset (str) -- (optional). A specific dataset to download.
87
+ Note: this must include the file extension.
88
+ If None, options will be presented for you
89
+ to choose from.
90
+
91
+ Returns:
92
+ save_path_full (str) -- the absolute path to the downloaded data.
93
+
94
+ """
95
+ if dataset is None:
96
+ selected_dataset = self._present_options()
97
+ else:
98
+ selected_dataset = dataset
99
+
100
+ save_path_full = join(save_path, selected_dataset.split('.')[0])
101
+
102
+ if isdir(save_path_full):
103
+ warn("\n'{0}' already exists. Voiding Download.".format(
104
+ save_path_full))
105
+ else:
106
+ self._print('Downloading Data...')
107
+ url = "{0}/{1}".format(self.url, selected_dataset)
108
+ self._download_data(url, save_path=save_path)
109
+
110
+ return abspath(save_path_full)
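Passing dataset explicitly (file extension included) skips the interactive menu; a sketch, assuming 'horse2zebra.zip' is among the archives hosted at the CycleGAN URL above:

from util.get_data import GetData

gd = GetData(technique='cyclegan', verbose=True)
path = gd.get(save_path='./datasets', dataset='horse2zebra.zip')
print(path)                           # absolute path to ./datasets/horse2zebra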
util/html.py ADDED
@@ -0,0 +1,86 @@
1
+ import dominate
2
+ from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3
+ import os
4
+
5
+
6
+ class HTML:
7
+ """This HTML class allows us to save images and write texts into a single HTML file.
8
+
9
+ It consists of functions such as <add_header> (add a text header to the HTML file),
10
+ <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12
+ """
13
+
14
+ def __init__(self, web_dir, title, refresh=0):
15
+ """Initialize the HTML classes
16
+
17
+ Parameters:
18
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
19
+ title (str) -- the webpage name
20
+ refresh (int) -- how often the website refresh itself; if 0; no refreshing
21
+ """
22
+ self.title = title
23
+ self.web_dir = web_dir
24
+ self.img_dir = os.path.join(self.web_dir, 'images')
25
+ if not os.path.exists(self.web_dir):
26
+ os.makedirs(self.web_dir)
27
+ if not os.path.exists(self.img_dir):
28
+ os.makedirs(self.img_dir)
29
+
30
+ self.doc = dominate.document(title=title)
31
+ if refresh > 0:
32
+ with self.doc.head:
33
+ meta(http_equiv="refresh", content=str(refresh))
34
+
35
+ def get_image_dir(self):
36
+ """Return the directory that stores images"""
37
+ return self.img_dir
38
+
39
+ def add_header(self, text):
40
+ """Insert a header to the HTML file
41
+
42
+ Parameters:
43
+ text (str) -- the header text
44
+ """
45
+ with self.doc:
46
+ h3(text)
47
+
48
+ def add_images(self, ims, txts, links, width=400):
49
+ """add images to the HTML file
50
+
51
+ Parameters:
52
+ ims (str list) -- a list of image paths
53
+ txts (str list) -- a list of image names shown on the website
54
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55
+ """
56
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57
+ self.doc.add(self.t)
58
+ with self.t:
59
+ with tr():
60
+ for im, txt, link in zip(ims, txts, links):
61
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
62
+ with p():
63
+ with a(href=os.path.join('images', link)):
64
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
65
+ br()
66
+ p(txt)
67
+
68
+ def save(self):
69
+ """save the current content to the HMTL file"""
70
+ html_file = '%s/index.html' % self.web_dir
71
+ f = open(html_file, 'wt')
72
+ f.write(self.doc.render())
73
+ f.close()
74
+
75
+
76
+ if __name__ == '__main__': # we show an example usage here.
77
+ html = HTML('web/', 'test_html')
78
+ html.add_header('hello world')
79
+
80
+ ims, txts, links = [], [], []
81
+ for n in range(4):
82
+ ims.append('image_%d.png' % n)
83
+ txts.append('text_%d' % n)
84
+ links.append('image_%d.png' % n)
85
+ html.add_images(ims, txts, links)
86
+ html.save()
util/image_pool.py ADDED
@@ -0,0 +1,54 @@
1
+ import random
2
+ import torch
3
+
4
+
5
+ class ImagePool():
6
+ """This class implements an image buffer that stores previously generated images.
7
+
8
+ This buffer enables us to update discriminators using a history of generated images
9
+ rather than the ones produced by the latest generators.
10
+ """
11
+
12
+ def __init__(self, pool_size):
13
+ """Initialize the ImagePool class
14
+
15
+ Parameters:
16
+ pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
17
+ """
18
+ self.pool_size = pool_size
19
+ if self.pool_size > 0: # create an empty pool
20
+ self.num_imgs = 0
21
+ self.images = []
22
+
23
+ def query(self, images):
24
+ """Return an image from the pool.
25
+
26
+ Parameters:
27
+ images: the latest generated images from the generator
28
+
29
+ Returns images from the buffer.
30
+
31
+ With probability 0.5, the buffer returns the current input images.
32
+ With probability 0.5, the buffer returns previously stored images
33
+ and inserts the current images into the buffer.
34
+ """
35
+ if self.pool_size == 0: # if the buffer size is 0, do nothing
36
+ return images
37
+ return_images = []
38
+ for image in images:
39
+ image = torch.unsqueeze(image.data, 0)
40
+ if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
41
+ self.num_imgs = self.num_imgs + 1
42
+ self.images.append(image)
43
+ return_images.append(image)
44
+ else:
45
+ p = random.uniform(0, 1)
46
+ if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
47
+ random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
48
+ tmp = self.images[random_id].clone()
49
+ self.images[random_id] = image
50
+ return_images.append(tmp)
51
+ else: # by another 50% chance, the buffer will return the current image
52
+ return_images.append(image)
53
+ return_images = torch.cat(return_images, 0) # collect all the images and return
54
+ return return_images
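A standalone sketch of how the pool is typically fed during discriminator updates: query each batch of freshly generated images and receive a same-shaped batch that, once the buffer fills, mixes new and historical fakes 50/50:

import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
for step in range(3):
    fake_B = torch.randn(4, 3, 64, 64)   # stand-in for netG(real_A).detach()
    mixed = pool.query(fake_B)           # same shape, possibly older samples
    assert mixed.shape == fake_B.shape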
util/util.py ADDED
@@ -0,0 +1,166 @@
1
+ """This module contains simple helper functions """
2
+ from __future__ import print_function
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import os
7
+ import importlib
8
+ import argparse
9
+ from argparse import Namespace
10
+ import torchvision
11
+
12
+
13
+ def str2bool(v):
14
+ if isinstance(v, bool):
15
+ return v
16
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
17
+ return True
18
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
19
+ return False
20
+ else:
21
+ raise argparse.ArgumentTypeError('Boolean value expected.')
22
+
23
+
24
+ def copyconf(default_opt, **kwargs):
25
+ conf = Namespace(**vars(default_opt))
26
+ for key in kwargs:
27
+ setattr(conf, key, kwargs[key])
28
+ return conf
29
+
30
+
31
+ def find_class_in_module(target_cls_name, module):
32
+ target_cls_name = target_cls_name.replace('_', '').lower()
33
+ clslib = importlib.import_module(module)
34
+ cls = None
35
+ for name, clsobj in clslib.__dict__.items():
36
+ if name.lower() == target_cls_name:
37
+ cls = clsobj
38
+
39
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
40
+
41
+ return cls
42
+
43
+
44
+ def tensor2im(input_image, imtype=np.uint8):
45
+ """"Converts a Tensor array into a numpy image array.
46
+
47
+ Parameters:
48
+ input_image (tensor) -- the input image tensor array
49
+ imtype (type) -- the desired type of the converted numpy array
50
+ """
51
+ if not isinstance(input_image, np.ndarray):
52
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
53
+ image_tensor = input_image.data
54
+ else:
55
+ return input_image
56
+ image_numpy = image_tensor[0].clamp(-1.0, 1.0).cpu().float().numpy() # convert it into a numpy array
57
+ if image_numpy.shape[0] == 1: # grayscale to RGB
58
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
59
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
60
+ else: # if it is a numpy array, do nothing
61
+ image_numpy = input_image
62
+ return image_numpy.astype(imtype)
63
+
64
+
65
+ def diagnose_network(net, name='network'):
66
+ """Calculate and print the mean of average absolute(gradients)
67
+
68
+ Parameters:
69
+ net (torch network) -- Torch network
70
+ name (str) -- the name of the network
71
+ """
72
+ mean = 0.0
73
+ count = 0
74
+ for param in net.parameters():
75
+ if param.grad is not None:
76
+ mean += torch.mean(torch.abs(param.grad.data))
77
+ count += 1
78
+ if count > 0:
79
+ mean = mean / count
80
+ print(name)
81
+ print(mean)
82
+
83
+
84
+ def save_image(image_numpy, image_path, aspect_ratio=1.0):
85
+ """Save a numpy image to the disk
86
+
87
+ Parameters:
88
+ image_numpy (numpy array) -- input numpy array
89
+ image_path (str) -- the path of the image
90
+ """
91
+
92
+ image_pil = Image.fromarray(image_numpy)
93
+ h, w, _ = image_numpy.shape
94
+
95
+ if aspect_ratio is None:
96
+ pass
97
+ elif aspect_ratio > 1.0:
98
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
99
+ elif aspect_ratio < 1.0:
100
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
101
+ image_pil.save(image_path)
102
+
103
+
104
+ def print_numpy(x, val=True, shp=False):
105
+ """Print the mean, min, max, median, std, and size of a numpy array
106
+
107
+ Parameters:
108
+ val (bool) -- whether to print the values of the numpy array
109
+ shp (bool) -- whether to print the shape of the numpy array
110
+ """
111
+ x = x.astype(np.float64)
112
+ if shp:
113
+ print('shape,', x.shape)
114
+ if val:
115
+ x = x.flatten()
116
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
117
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
118
+
119
+
120
+ def mkdirs(paths):
121
+ """create empty directories if they don't exist
122
+
123
+ Parameters:
124
+ paths (str list) -- a list of directory paths
125
+ """
126
+ if isinstance(paths, list) and not isinstance(paths, str):
127
+ for path in paths:
128
+ mkdir(path)
129
+ else:
130
+ mkdir(paths)
131
+
132
+
133
+ def mkdir(path):
134
+ """create a single empty directory if it didn't exist
135
+
136
+ Parameters:
137
+ path (str) -- a single directory path
138
+ """
139
+ if not os.path.exists(path):
140
+ os.makedirs(path)
141
+
142
+
143
+ def correct_resize_label(t, size):
144
+ device = t.device
145
+ t = t.detach().cpu()
146
+ resized = []
147
+ for i in range(t.size(0)):
148
+ one_t = t[i, :1]
149
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
150
+ one_np = one_np[:, :, 0]
151
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
152
+ resized_t = torch.from_numpy(np.array(one_image)).long()
153
+ resized.append(resized_t)
154
+ return torch.stack(resized, dim=0).to(device)
155
+
156
+
157
+ def correct_resize(t, size, mode=Image.BICUBIC):
158
+ device = t.device
159
+ t = t.detach().cpu()
160
+ resized = []
161
+ for i in range(t.size(0)):
162
+ one_t = t[i:i + 1]
163
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
164
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
165
+ resized.append(resized_t)
166
+ return torch.stack(resized, dim=0).to(device)
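A round-trip sketch for the tensor2im/save_image helpers above: map a generator-style output in [-1, 1] to a uint8 array and write it to disk (the filename is arbitrary):

import torch
from util.util import tensor2im, save_image

fake = torch.rand(1, 3, 128, 128) * 2 - 1   # stand-in for a netG output
img = tensor2im(fake)                        # (128, 128, 3) uint8 in [0, 255]
save_image(img, 'fake_sample.png')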
util/visualizer.py ADDED
@@ -0,0 +1,242 @@
1
+ import numpy as np
2
+ import os
3
+ import sys
4
+ import ntpath
5
+ import time
6
+ from . import util, html
7
+ from subprocess import Popen, PIPE
8
+
9
+ if sys.version_info[0] == 2:
10
+ VisdomExceptionBase = Exception
11
+ else:
12
+ VisdomExceptionBase = ConnectionError
13
+
14
+
15
+ def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
16
+ """Save images to the disk.
17
+
18
+ Parameters:
19
+ webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
20
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
21
+ image_path (str) -- the string is used to create image paths
22
+ aspect_ratio (float) -- the aspect ratio of saved images
23
+ width (int) -- the images will be resized to width x width
24
+
25
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
26
+ """
27
+ image_dir = webpage.get_image_dir()
28
+ short_path = ntpath.basename(image_path[0])
29
+ name = os.path.splitext(short_path)[0]
30
+
31
+ webpage.add_header(name)
32
+ ims, txts, links = [], [], []
33
+
34
+ for label, im_data in visuals.items():
35
+ im = util.tensor2im(im_data)
36
+ image_name = '%s/%s.png' % (label, name)
37
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
38
+ save_path = os.path.join(image_dir, image_name)
39
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
40
+ ims.append(image_name)
41
+ txts.append(label)
42
+ links.append(image_name)
43
+ webpage.add_images(ims, txts, links, width=width)
44
+
45
+
46
+ class Visualizer():
47
+ """This class includes several functions that can display/save images and print/save logging information.
48
+
49
+ It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
50
+ """
51
+
52
+ def __init__(self, opt):
53
+ """Initialize the Visualizer class
54
+
55
+ Parameters:
56
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
57
+ Step 1: Cache the training/test options
58
+ Step 2: connect to a visdom server
59
+ Step 3: create an HTML object for saving HTML files
60
+ Step 4: create a logging file to store training losses
61
+ """
62
+ self.opt = opt # cache the option
63
+ if opt.display_id is None:
64
+ self.display_id = np.random.randint(100000) * 10 # just a random display id
65
+ else:
66
+ self.display_id = opt.display_id
67
+ self.use_html = opt.isTrain and not opt.no_html
68
+ self.win_size = opt.display_winsize
69
+ self.name = opt.name
70
+ self.port = opt.display_port
71
+ self.saved = False
72
+ if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
73
+ import visdom
74
+ self.plot_data = {}
75
+ self.ncols = opt.display_ncols
76
+ if "tensorboard_base_url" not in os.environ:
77
+ self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
78
+ else:
79
+ self.vis = visdom.Visdom(port=2004,
80
+ base_url=os.environ['tensorboard_base_url'] + '/visdom')
81
+ if not self.vis.check_connection():
82
+ self.create_visdom_connections()
83
+
84
+ if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
85
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
86
+ self.img_dir = os.path.join(self.web_dir, 'images')
87
+ print('create web directory %s...' % self.web_dir)
88
+ util.mkdirs([self.web_dir, self.img_dir])
89
+ # create a logging file to store training losses
90
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
91
+ with open(self.log_name, "a") as log_file:
92
+ now = time.strftime("%c")
93
+ log_file.write('================ Training Loss (%s) ================\n' % now)
94
+
95
+ def reset(self):
96
+ """Reset the self.saved status"""
97
+ self.saved = False
98
+
99
+ def create_visdom_connections(self):
100
+ """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
101
+ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
102
+ print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
103
+ print('Command: %s' % cmd)
104
+ Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
105
+
106
+ def display_current_results(self, visuals, epoch, save_result):
107
+ """Display current results on visdom; save current results to an HTML file.
108
+
109
+ Parameters:
110
+ visuals (OrderedDict) - - dictionary of images to display or save
111
+ epoch (int) - - the current epoch
112
+ save_result (bool) - - whether to save the current results to an HTML file
113
+ """
114
+ if self.display_id > 0: # show images in the browser using visdom
115
+ ncols = self.ncols
116
+ if ncols > 0: # show all the images in one visdom panel
117
+ ncols = min(ncols, len(visuals))
118
+ h, w = next(iter(visuals.values())).shape[:2]
119
+ table_css = """<style>
120
+ table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
121
+ table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
122
+ </style>""" % (w, h) # create a table css
123
+ # create a table of images.
124
+ title = self.name
125
+ label_html = ''
126
+ label_html_row = ''
127
+ images = []
128
+ idx = 0
129
+ for label, image in visuals.items():
130
+ image_numpy = util.tensor2im(image)
131
+ label_html_row += '<td>%s</td>' % label
132
+ images.append(image_numpy.transpose([2, 0, 1]))
133
+ idx += 1
134
+ if idx % ncols == 0:
135
+ label_html += '<tr>%s</tr>' % label_html_row
136
+ label_html_row = ''
137
+ white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
138
+ while idx % ncols != 0:
139
+ images.append(white_image)
140
+ label_html_row += '<td></td>'
141
+ idx += 1
142
+ if label_html_row != '':
143
+ label_html += '<tr>%s</tr>' % label_html_row
144
+ try:
145
+ self.vis.images(images, ncols, 2, self.display_id + 1,
146
+ None, dict(title=title + ' images'))
147
+ label_html = '<table>%s</table>' % label_html
148
+ self.vis.text(table_css + label_html, win=self.display_id + 2,
149
+ opts=dict(title=title + ' labels'))
150
+ except VisdomExceptionBase:
151
+ self.create_visdom_connections()
152
+
153
+ else: # show each image in a separate visdom panel;
154
+ idx = 1
155
+ try:
156
+ for label, image in visuals.items():
157
+ image_numpy = util.tensor2im(image)
158
+ self.vis.image(
159
+ image_numpy.transpose([2, 0, 1]),
160
+ self.display_id + idx,
161
+ None,
162
+ dict(title=label)
163
+ )
164
+ idx += 1
165
+ except VisdomExceptionBase:
166
+ self.create_visdom_connections()
167
+
168
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
169
+ self.saved = True
170
+ # save images to the disk
171
+ for label, image in visuals.items():
172
+ image_numpy = util.tensor2im(image)
173
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
174
+ util.save_image(image_numpy, img_path)
175
+
176
+ # update website
177
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
178
+ for n in range(epoch, 0, -1):
179
+ webpage.add_header('epoch [%d]' % n)
180
+ ims, txts, links = [], [], []
181
+
182
+ for label, image_numpy in visuals.items():
183
+ image_numpy = util.tensor2im(image_numpy)
184
+ img_path = 'epoch%.3d_%s.png' % (n, label)
185
+ ims.append(img_path)
186
+ txts.append(label)
187
+ links.append(img_path)
188
+ webpage.add_images(ims, txts, links, width=self.win_size)
189
+ webpage.save()
190
+
191
+ def plot_current_losses(self, epoch, counter_ratio, losses):
192
+ """display the current losses on visdom display: dictionary of error labels and values
193
+
194
+ Parameters:
195
+ epoch (int) -- current epoch
196
+ counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
197
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
198
+ """
199
+ if len(losses) == 0:
200
+ return
201
+
202
+ plot_name = '_'.join(list(losses.keys()))
203
+
204
+ if plot_name not in self.plot_data:
205
+ self.plot_data[plot_name] = {'X': [], 'Y': [], 'legend': list(losses.keys())}
206
+
207
+ plot_data = self.plot_data[plot_name]
208
+ plot_id = list(self.plot_data.keys()).index(plot_name)
209
+
210
+ plot_data['X'].append(epoch + counter_ratio)
211
+ plot_data['Y'].append([losses[k] for k in plot_data['legend']])
212
+ try:
213
+ self.vis.line(
214
+ X=np.stack([np.array(plot_data['X'])] * len(plot_data['legend']), 1),
215
+ Y=np.array(plot_data['Y']),
216
+ opts={
217
+ 'title': self.name,
218
+ 'legend': plot_data['legend'],
219
+ 'xlabel': 'epoch',
220
+ 'ylabel': 'loss'},
221
+ win=self.display_id - plot_id)
222
+ except VisdomExceptionBase:
223
+ self.create_visdom_connections()
224
+
225
+ # losses: same format as |losses| of plot_current_losses
226
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
227
+ """print current losses on console; also save the losses to the disk
228
+
229
+ Parameters:
230
+ epoch (int) -- current epoch
231
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
232
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
233
+ t_comp (float) -- computational time per data point (normalized by batch_size)
234
+ t_data (float) -- data loading time per data point (normalized by batch_size)
235
+ """
236
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
237
+ for k, v in losses.items():
238
+ message += '%s: %.3f ' % (k, v)
239
+
240
+ print(message) # print the message
241
+ with open(self.log_name, "a") as log_file:
242
+ log_file.write('%s\n' % message) # save the message
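For reference, the log-line format assembled above, re-created standalone (the loss names and values are made up):

from collections import OrderedDict

losses = OrderedDict([('G_GAN', 0.812), ('D_real', 0.441)])
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (3, 400, 0.091, 0.002)
for k, v in losses.items():
    message += '%s: %.3f ' % (k, v)
print(message)   # (epoch: 3, iters: 400, time: 0.091, data: 0.002) G_GAN: 0.812 D_real: 0.441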