JungminChung committed
Commit 3617b5f · 1 Parent(s): de37443

first commit

app.py ADDED
@@ -0,0 +1,60 @@
+ import utils
+ import torch
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ from network import ImageTransformNet_dpws
+ from torch.autograd import Variable
+ from torchvision import transforms
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         gr.HTML('<h1 style="text-align: center;">스타일 변환기</h1>')  # "Style Transfer"
+
+     with gr.Row():
+         with gr.Column():
+             style_radio = gr.Radio(['La muse', 'Mosaic', 'Starry Night Crop', 'Wave Crop'], label='원하는 스타일 선택!')  # "Choose a style!"
+             image_input = gr.Image(label='콘텐츠 이미지')  # "Content image"
+             convert_button = gr.Button('변환!')  # "Convert!"
+
+         with gr.Column():
+             result_image = gr.Image(label='결과 이미지')  # "Result image"
+
+     def transform_image(style, img):
+         dtype = torch.FloatTensor
+
+         # content image
+         img_transform_512 = transforms.Compose([
+             # transforms.Scale(512),             # scale shortest side to image_size
+             transforms.Resize(512),              # scale shortest side to image_size
+             # transforms.CenterCrop(512),        # crop center image_size out
+             transforms.ToTensor(),               # turn image from [0-255] to [0-1]
+             utils.normalize_tensor_transform()   # normalize with ImageNet values
+         ])
+
+         content = Image.fromarray(img)
+         content = img_transform_512(content)
+         content = content.unsqueeze(0)
+         # content = Variable(content).type(dtype)
+         content = Variable(content.repeat(1, 1, 1, 1), requires_grad=False).type(dtype)
+
+         # load style model (folder name derived from the radio label, e.g. 'La muse' -> models/la_muse)
+         model_folder_name = '_'.join(style.lower().split())
+         model_path = 'models/' + model_folder_name + '/compressed.model'
+         checkpoint_lw = torch.load(model_path)
+
+         style_model = ImageTransformNet_dpws().type(dtype)
+         style_model.load_state_dict(checkpoint_lw)
+
+         # process input image
+         stylized = style_model(content).cpu()
+         utils.save_image('results.jpg', stylized.data[0])
+         return 'results.jpg'
+
+     convert_button.click(
+         transform_image,
+         inputs=[style_radio, image_input],
+         outputs=[result_image],
+     )
+
+ demo.launch()
models/la_muse/compressed.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5b4d90ccfbf9679128254df891aec7a2a460f34df5490ad370bb1d3e6a32cc17
+ size 92080
models/mosaic/compressed.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:24ff3368b90ba4ce65e0efaf97876bca4148958dd506e84a568a9425142cd314
+ size 92080
models/starry_night_crop/compressed.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e5a3665567b07638dcc713c1d8ea30c1148824ebce3eade677b07181787a03d
+ size 92080
models/wave_crop/compressed.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3d0974a83e9d62be50d636e80c3be48a90893376c22da6dc1cd004697616eeed
+ size 92080
network.py ADDED
@@ -0,0 +1,356 @@
+ import torch
+ import torch.nn as nn
+
+ # Conv Layer
+ class ConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
+         super(ConvLayer, self).__init__()
+         paddings = kernel_size // 2
+         self.reflection_pad = nn.ReflectionPad2d(paddings)
+         self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, groups=groups)  # , padding)
+         # self.in_d = nn.InstanceNorm2d(out_channels, affine=True)
+
+     def forward(self, x):
+         out = self.reflection_pad(x)
+         out = self.conv2d(out)
+         return out
+
+
+ class ConvLayer_dpws(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride):
+         super(ConvLayer_dpws, self).__init__()
+         self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride=stride, groups=in_channels)
+         self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
+         self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
+         self.in_2d = nn.InstanceNorm2d(out_channels, affine=True)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         out = self.in_1d(self.conv1(x))
+         out = self.relu(self.in_2d(self.conv2(out)))
+         return out
+
+
+ class ConvLayer_dpws_last(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride):
+         super(ConvLayer_dpws_last, self).__init__()
+         self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride=stride, groups=in_channels)
+         self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
+         self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
+         # self.in_2d = nn.InstanceNorm2d(out_channels, affine=True)
+         # self.relu = nn.ReLU()
+
+     def forward(self, x):
+         out = self.in_1d(self.conv1(x))
+         # out = self.relu(self.in_2d(self.conv2(out)))
+         out = self.conv2(out)
+         return out
+
+ # Upsample Conv Layer
+ class UpsampleConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+         super(UpsampleConvLayer, self).__init__()
+         self.upsample = upsample
+         if upsample:
+             self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
+         reflection_padding = kernel_size // 2
+         self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
+         self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
+         # self.in_d = nn.InstanceNorm2d(out_channels, affine=True)
+
+     def forward(self, x):
+         if self.upsample:
+             x = self.upsample(x)
+         out = self.reflection_pad(x)
+         out = self.conv2d(out)
+         return out
+
+
+ class UpsampleConvLayer_dpws(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
+         super(UpsampleConvLayer_dpws, self).__init__()
+         self.upsample = upsample
+         if upsample:
+             self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
+         self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride, groups=in_channels)
+         self.in1 = nn.InstanceNorm2d(in_channels, affine=True)
+         self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
+         self.in2 = nn.InstanceNorm2d(out_channels, affine=True)
+         self.relu = nn.ReLU()
+
+     def forward(self, x):
+         if self.upsample:
+             x = self.upsample(x)
+         # out = self.reflection_pad(x)
+         # out = self.conv2d(out)
+         out = self.relu(self.in1(self.conv1(x)))
+         out = self.in2(self.conv2(out))
+         return out
+
+
+ class DeConvLayer(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride):
+         super(DeConvLayer, self).__init__()
+
+         # reflection_padding = kernel_size // 2
+         # self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
+         self.deconv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=1, output_padding=1)
+
+     def forward(self, x):
+         # out = self.reflection_pad(x)
+         out = self.deconv2d(x)
+         return out
+
+ # Residual Block
+ # adapted from pytorch tutorial
+ # https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-
+ # intermediate/deep_residual_network/main.py
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels):
+         super(ResidualBlock, self).__init__()
+
+         self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+         self.in1 = nn.InstanceNorm2d(channels, affine=True)
+         self.relu = nn.ReLU()
+         self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
+         self.in2 = nn.InstanceNorm2d(channels, affine=True)
+
+     def forward(self, x):
+         residual = x
+         out = self.relu(self.in1(self.conv1(x)))
+         # out = self.relu(self.in2(self.conv2(out)))
+         out = self.in2(self.conv2(out))
+         # out = self.relu(self.conv2(out))
+         # out = self.conv2(out)
+         out = out + residual
+         # out = self.relu(out)
+
+         return out
+
+
+ class ResidualBlock_depthwise(nn.Module):
+     def __init__(self, channels):
+         super(ResidualBlock_depthwise, self).__init__()
+
+         # ########################## depthwise ###########################################
+         self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
+         self.in1 = nn.InstanceNorm2d(channels, affine=True)
+         self.conv2 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
+         self.in2 = nn.InstanceNorm2d(channels, affine=True)
+
+         self.conv3 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
+         self.in3 = nn.InstanceNorm2d(channels, affine=True)
+         self.conv4 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
+         self.in4 = nn.InstanceNorm2d(channels, affine=True)
+
+         self.relu = nn.ReLU()
+         self.prelu = nn.PReLU()
+
+     def forward(self, x):
+
+         # ############### DEPTHWISE ###################
+         # residual = x
+         # out = self.relu(self.in1(self.conv1(x)))
+         # out = self.relu(self.in2(self.conv2(out)))
+         # out = self.relu(self.in3(self.conv3(out)))
+         # out = self.relu(self.in4(self.conv4(out)))
+         # out = out + residual
+
+         # # ################## v1 ####################
+         # residual = x
+         # out = self.in1(self.conv1(x))
+         # out = self.relu(self.in2(self.conv2(out)))
+         # out = self.in3(self.conv3(out))
+         # out = self.in4(self.conv4(out))
+         # out = out + residual
+         # out = self.relu(out)
+
+         # ################## v2 (active variant) #################### √
+         residual = x
+         out = self.in1(self.conv1(x))
+         out = self.relu(self.in2(self.conv2(out)))
+         out = self.in3(self.conv3(out))
+         out = self.in4(self.conv4(out))
+         out = out + residual
+
+         # ################## v3 ####################
+         # residual = x
+         # out = self.conv1(x)
+         # out = self.relu(self.in2(self.conv2(out)))
+         # out = self.conv3(out)
+         # out = self.in4(self.conv4(out))
+         # out = out + residual
+
+         # ################## v4 ####################
+         # residual = x
+         # out = self.in1(self.conv1(x))
+         # out = self.relu(self.in2(self.conv2(out)))
+         # out = self.in3(self.conv3(out))
+         # out = self.relu(self.in4(self.conv4(out)))
+         # out = out + residual
+
+         return out
+
+
+ # Image Transform Network
+ class ImageTransformNet(nn.Module):
+     def __init__(self):
+         super(ImageTransformNet, self).__init__()
+
+         # nonlinearity
+         self.relu = nn.ReLU()
+         self.tanh = nn.Tanh()
+
+         # encoding layers
+         self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
+         self.in1_e = nn.InstanceNorm2d(32, affine=True)
+
+         self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
+         self.in2_e = nn.InstanceNorm2d(64, affine=True)
+
+         self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
+         self.in3_e = nn.InstanceNorm2d(128, affine=True)
+
+         # residual layers
+         self.res1 = ResidualBlock(128)
+         self.res2 = ResidualBlock(128)
+         self.res3 = ResidualBlock(128)
+         self.res4 = ResidualBlock(128)
+         self.res5 = ResidualBlock(128)
+         # self.res6 = ResidualBlock(128)
+
+         # decoding layers
+         # TODO:
+         # self.deconv3 = DeConvLayer(128, 64, kernel_size=3, stride=2)
+         self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
+         self.in3_d = nn.InstanceNorm2d(64, affine=True)
+
+         # self.deconv2 = DeConvLayer(64, 32, kernel_size=3, stride=2)
+         self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
+         self.in2_d = nn.InstanceNorm2d(32, affine=True)
+
+         self.deconv1 = ConvLayer(32, 3, kernel_size=9, stride=1)
+         self.in1_d = nn.InstanceNorm2d(3, affine=True)
+
+     def forward(self, x):
+         # encode
+         y = self.relu(self.in1_e(self.conv1(x)))
+         y = self.relu(self.in2_e(self.conv2(y)))
+         y = self.relu(self.in3_e(self.conv3(y)))
+         # y = self.relu(self.conv1(x))
+         # y = self.relu(self.conv2(y))
+         # y = self.relu(self.conv3(y))
+         y_downsample = y
+
+         # residual layers
+         y = self.res1(y)
+         y = self.res2(y)
+         y = self.res3(y)
+         y = self.res4(y)
+         y = self.res5(y)
+
+         y_upsample = y
+         # decode
+         y = self.relu(self.in3_d(self.deconv3(y)))
+         y = self.relu(self.in2_d(self.deconv2(y)))
+         # y = self.relu(self.deconv3(y))
+         # y = self.relu(self.deconv2(y))
+         # y = self.tanh(self.in1_d(self.deconv1(y)))
+         y = self.deconv1(y)
+
+         # return y, y_downsample, y_upsample
+         return y
+
+
+ # width multipliers for the lightweight (depthwise separable) network
+ ALAPHA_1 = 0.25
+ ALAPHA_2 = 0.25
+ # ALAPHA_1 = 0.5
+ # ALAPHA_2 = 0.5
+ class ImageTransformNet_dpws(nn.Module):
+     def __init__(self):
+         super(ImageTransformNet_dpws, self).__init__()
+
+         # nonlinearity
+         self.relu = nn.ReLU()
+         self.tanh = nn.Tanh()
+
+         # encoding layers
+         self.conv1 = ConvLayer_dpws(3, int(32*ALAPHA_1), kernel_size=9, stride=1)
+         # self.in1_e = nn.InstanceNorm2d(int(32*ALAPHA_1), affine=True)
+
+         self.conv2 = ConvLayer_dpws(int(32*ALAPHA_1), int(64*ALAPHA_1), kernel_size=3, stride=2)
+         self.conv3 = ConvLayer_dpws(int(64*ALAPHA_1), int(128*ALAPHA_2), kernel_size=3, stride=2)
+
+         # residual layers
+         self.res1 = ResidualBlock_depthwise(int(128*ALAPHA_2))
+         self.res2 = ResidualBlock_depthwise(int(128*ALAPHA_2))
+         self.res3 = ResidualBlock_depthwise(int(128*ALAPHA_2))
+         self.res4 = ResidualBlock_depthwise(int(128*ALAPHA_2))
+         self.res5 = ResidualBlock_depthwise(int(128*ALAPHA_2))
+         # self.res6 = ResidualBlock_depthwise(128)
+
+         # decoding layers
+         # TODO:
+         # self.deconv3 = DeConvLayer(128, 64, kernel_size=3, stride=2)
+         self.deconv3 = UpsampleConvLayer_dpws(int(128*ALAPHA_2), int(64*ALAPHA_1), kernel_size=3, stride=1, upsample=2)
+         # self.in3_d = nn.InstanceNorm2d(int(64*ALAPHA_1), affine=True)
+
+         # self.deconv2 = DeConvLayer(64, 32, kernel_size=3, stride=2)
+         self.deconv2 = UpsampleConvLayer_dpws(int(64*ALAPHA_1), int(32*ALAPHA_1), kernel_size=3, stride=1, upsample=2)
+         # self.in2_d = nn.InstanceNorm2d(32, affine=True)
+
+         self.deconv1 = ConvLayer_dpws_last(int(32*ALAPHA_1), 3, kernel_size=9, stride=1)
+         # self.deconv1 = ConvLayer_dpws_last(int(32*ALAPHA_1), 3, kernel_size=9, stride=1)
+         # self.in1_d = nn.InstanceNorm2d(3, affine=True)
+
+     def forward(self, x):
+         # encode
+         # y = self.relu(self.in1_e(self.conv1(x)))
+         y = self.conv1(x)
+         y = self.conv2(y)
+         y = self.conv3(y)
+         y_downsample = y
+
+         # y = self.relu(self.in2_e(self.conv2(y)))
+         # y = self.relu(self.in3_e(self.conv3(y)))
+         # residual layers
+         y = self.res1(y)
+         y = self.res2(y)
+         y = self.res3(y)
+         y = self.res4(y)
+         y = self.res5(y)
+         # y = self.res6(y)
+         y_upsample = y
+
+         # decode
+         y = self.deconv3(y)
+         y = self.deconv2(y)
+         y = self.deconv1(y)
+
+         # return y, y_downsample, y_upsample
+         return y
+
+ class distiller_1(nn.Module):
+     def __init__(self):
+         super(distiller_1, self).__init__()
+
+         self.conv = nn.Conv2d(128, int(128*ALAPHA_2), kernel_size=1, stride=1)
+
+     def forward(self, x):
+         # encode
+         y = self.conv(x)
+
+         return y
+
+ class distiller_2(nn.Module):
+     def __init__(self):
+         super(distiller_2, self).__init__()
+
+         self.conv = nn.Conv2d(128, int(128*ALAPHA_2), kernel_size=1, stride=1)
+
+     def forward(self, x):
+         # encode
+         y = self.conv(x)
+
+         return y
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.12.1
+ torchvision==0.13.1
+ numpy==1.24.0
+ torchsummary
+ ptflops
style.py ADDED
@@ -0,0 +1,258 @@
+ import numpy as np
+ import torch
+ import os
+ import argparse
+ import time
+ import collections
+
+ from torch.autograd import Variable
+ from torch.optim import Adam
+ from torch.utils.data import DataLoader
+ from torchvision import datasets
+ from torchvision import transforms
+
+ import utils
+ from network import ImageTransformNet, ImageTransformNet_dpws
+ from vgg import Vgg16
+
+ # Global Variables
+ IMAGE_SIZE = 256
+ BATCH_SIZE = 4
+ LEARNING_RATE = 1e-3
+ EPOCHS = 2
+
+ # STYLE_WEIGHT = 9.2e3
+ # STYLE_WEIGHT = 8e4
+ STYLE_WEIGHT = 7e4
+ # STYLE_WEIGHT = 0
+ # CONTENT_WEIGHT = 1e-2
+ # CONTENT_WEIGHT = 0.15
+ CONTENT_WEIGHT = 0.1
+ L1_WEIGHT = 1
+
+ def train(args):
+     # GPU enabling
+     if (args.gpu != None):
+         use_cuda = True
+         dtype = torch.cuda.FloatTensor
+         torch.cuda.set_device(args.gpu)
+         print("Current device: %d" % torch.cuda.current_device())
+     else:
+         # fall back to CPU so use_cuda and dtype are always defined
+         use_cuda = False
+         dtype = torch.FloatTensor
+
+     # visualization of training controlled by flag
+     visualize = (args.visualize != None)
+     if (visualize):
+         img_transform_512 = transforms.Compose([
+             transforms.Scale(512),               # scale shortest side to image_size
+             # transforms.CenterCrop(512),        # crop center image_size out
+             transforms.ToTensor(),               # turn image from [0-255] to [0-1]
+             utils.normalize_tensor_transform()   # normalize with ImageNet values
+         ])
+
+         testImage_maine = utils.load_image(args.test_image)
+         testImage_maine = img_transform_512(testImage_maine)
+         testImage_maine = Variable(testImage_maine.repeat(1, 1, 1, 1), requires_grad=False).type(dtype)
+         test_name = os.path.split(args.test_image)[-1].split('.')[0]
+
+     # define network
+     image_transformer_dpws = ImageTransformNet_dpws().type(dtype)
+     # paras = [image_transformer_dpws.parameters()]
+     optimizer = Adam(image_transformer_dpws.parameters(), LEARNING_RATE)
+
+     loss_mse = torch.nn.MSELoss()
+     loss_l1 = torch.nn.L1Loss()
+
+     # VGG feature extractor for perceptual losses; full-size ImageTransformNet as distillation teacher
+     vgg = Vgg16().type(dtype)
+     image_transformer = ImageTransformNet().type(dtype)
+     image_transformer.load_state_dict(torch.load(args.load_path))
+
+     # get training dataset
+     dataset_transform = transforms.Compose([
+         transforms.Scale(IMAGE_SIZE),            # scale shortest side to image_size
+         transforms.CenterCrop(IMAGE_SIZE),       # crop center image_size out
+         transforms.ToTensor(),                   # turn image from [0-255] to [0-1]
+         utils.normalize_tensor_transform()       # normalize with ImageNet values
+     ])
+     train_dataset = datasets.ImageFolder(args.dataset, dataset_transform)
+     train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
+
+     # style image
+     style_transform = transforms.Compose([
+         transforms.ToTensor(),                   # turn image from [0-255] to [0-1]
+         utils.normalize_tensor_transform()       # normalize with ImageNet values
+     ])
+     style = utils.load_image(args.style_image)
+     style = style_transform(style)
+     style = Variable(style.repeat(BATCH_SIZE, 1, 1, 1)).type(dtype)
+     style_name = os.path.split(args.style_image)[-1].split('.')[0]
+
+     # calculate gram matrices for style feature layer maps we care about
+     style_features = vgg(style)
+     style_gram = [utils.gram(fmap) for fmap in style_features]
+
+     for e in range(EPOCHS):
+
+         # track values for...
+         img_count = 0
+         aggregate_style_loss = 0.0
+         aggregate_content_loss = 0.0
+         aggregate_l1_loss = 0.0
+         # aggregate_tv_loss = 0.0
+
+         # train network
+         image_transformer_dpws.train()
+         for batch_num, (x, label) in enumerate(train_loader):
+             img_batch_read = len(x)
+             img_count += img_batch_read
+
+             # zero out gradients
+             optimizer.zero_grad()
+
+             # input batch to transformer network
+             x = Variable(x).type(dtype)
+             y_hat = image_transformer_dpws(x)
+             y_label = image_transformer(x)
+
+             # get vgg features
+             y_c_features = vgg(x)
+             y_hat_features = vgg(y_hat)
+
+             # calculate style loss
+             y_hat_gram = [utils.gram(fmap) for fmap in y_hat_features]
+             style_loss = 0.0
+
+             for j in range(4):
+                 style_loss += loss_mse(y_hat_gram[j], style_gram[j][:img_batch_read])
+             style_loss = STYLE_WEIGHT*style_loss
+             aggregate_style_loss += style_loss.data.item()
+
+             # calculate content loss (h_relu_2_2)
+             recon = y_c_features[1]
+             recon_hat = y_hat_features[1]
+             content_loss = CONTENT_WEIGHT*loss_mse(recon_hat, recon)
+             aggregate_content_loss += content_loss.data.item()
+
+             # calculate l1 loss (distillation term; uses MSE against the teacher output despite the name)
+             l1_loss = L1_WEIGHT*loss_mse(y_hat, y_label)
+             aggregate_l1_loss += l1_loss.data.item()
+
+             # total loss
+             # total_loss = style_loss + content_loss + tv_loss + l1_loss + dis_loss
+             total_loss = style_loss + l1_loss + content_loss
+
+             # backprop
+             total_loss.backward()
+             optimizer.step()
+
+             # print out status message
+             if ((batch_num + 1) % 100 == 0):
+                 status = "{} Epoch {}: [{}/{}] Batch:[{}] agg_style: {:.6f} agg_l1: {:.6f} agg_content: {:.6f} ".format(
+                     time.ctime(), e + 1, img_count, len(train_dataset), batch_num+1,
+                     aggregate_style_loss/(batch_num+1.0), aggregate_l1_loss/(batch_num+1.0), aggregate_content_loss/(batch_num+1.0)
+                 )
+                 print(status)
+
+             if ((batch_num + 1) % 5000 == 0) and (visualize):
+                 image_transformer_dpws.eval()
+
+                 if not os.path.exists("visualization"):
+                     os.makedirs("visualization")
+
+                 outputTestImage_maine = image_transformer_dpws(testImage_maine)
+
+                 test_path = "visualization/%s/%s%d_%05d.jpg" % (style_name, test_name, e+1, batch_num+1)
+                 utils.save_image(test_path, outputTestImage_maine.data[0].cpu())
+
+                 print("images saved")
+                 image_transformer_dpws.train()
+
+     # save model
+     image_transformer_dpws.eval()
+
+     if use_cuda:
+         image_transformer_dpws.cpu()
+
+     if not os.path.exists("models"):
+         os.makedirs("models")
+     filename = "models/%s.model" % style_name
+     torch.save(image_transformer_dpws.state_dict(), filename)
+
+     if use_cuda:
+         image_transformer_dpws.cuda()
+
+ def style_transfer(args):
+     # GPU enabling
+     if (args.gpu != None):
+         use_cuda = True
+         dtype = torch.cuda.FloatTensor
+         torch.cuda.set_device(args.gpu)
+         print("Current device: %d" % torch.cuda.current_device())
+     else:
+         dtype = torch.FloatTensor
+
+     # content image
+     img_transform_512 = transforms.Compose([
+         # transforms.Scale(512),               # scale shortest side to image_size
+         transforms.Resize(512),                # scale shortest side to image_size
+         # transforms.CenterCrop(512),          # crop center image_size out
+         transforms.ToTensor(),                 # turn image from [0-255] to [0-1]
+         utils.normalize_tensor_transform()     # normalize with ImageNet values
+     ])
+
+     content = utils.load_image(args.source)
+     content = img_transform_512(content)
+     content = content.unsqueeze(0)
+     # content = Variable(content).type(dtype)
+     content = Variable(content.repeat(1, 1, 1, 1), requires_grad=False).type(dtype)
+
+     # load style model
+     checkpoint_lw = torch.load(args.model_path)
+
+     style_model = ImageTransformNet_dpws().type(dtype)
+     style_model.load_state_dict(checkpoint_lw)
+
+     # process input image
+     stylized = style_model(content).cpu()
+     utils.save_image(args.output, stylized.data[0])
+
+
+ def main():
+     parser = argparse.ArgumentParser(description='style transfer in pytorch')
+     subparsers = parser.add_subparsers(title="subcommands", dest="subcommand")
+
+     train_parser = subparsers.add_parser("train", help="train a model to do style transfer")
+     train_parser.add_argument("--style_image", type=str, required=True, help="path to a style image to train with")
+     train_parser.add_argument("--test_image", type=str, required=True, help="path to a test image to test with")
+     train_parser.add_argument("--dataset", type=str, required=True, help="path to a dataset")
+     train_parser.add_argument("--load_path", type=str, required=True, help="path to a pretrained ImageTransformNet checkpoint (read as args.load_path in train())")
+     train_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used")
+     train_parser.add_argument("--visualize", type=int, default=None, help="Set to 1 if you want to visualize training")
+
+     style_parser = subparsers.add_parser("transfer", help="do style transfer with a trained model")
+     style_parser.add_argument("--model_path", type=str, required=True, help="path to a pretrained model for a style image")
+     style_parser.add_argument("--source", type=str, required=True, help="path to source image")
+     style_parser.add_argument("--output", type=str, required=True, help="file name for stylized output image")
+     style_parser.add_argument("--gpu", type=int, default=None, help="ID of GPU to be used")
+
+     args = parser.parse_args()
+
+     # command
+     if (args.subcommand == "train"):
+         print("Training!")
+         train(args)
+     elif (args.subcommand == "transfer"):
+         print("Style transferring!")
+         style_transfer(args)
+     else:
+         print("invalid command")
+
+ if __name__ == '__main__':
+     main()
utils.py ADDED
@@ -0,0 +1,32 @@
+ import torch
+ from PIL import Image
+ from torch.autograd import Variable
+ from torchvision import transforms
+ import numpy as np
+
+ # opens and returns image file as a PIL image (0-255)
+ def load_image(filename):
+     img = Image.open(filename)
+     return img
+
+ # assumes data comes in batch form (ch, h, w)
+ def save_image(filename, data):
+     std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
+     mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
+     img = data.clone().numpy()
+     img = ((img * std + mean).transpose(1, 2, 0)*255.0).clip(0, 255).astype("uint8")
+     img = Image.fromarray(img)
+     img.save(filename)
+
+ # Calculate Gram matrix (G = FF^T)
+ def gram(x):
+     (bs, ch, h, w) = x.size()
+     f = x.view(bs, ch, w*h)
+     f_T = f.transpose(1, 2)
+     G = f.bmm(f_T) / (ch * h * w)
+     return G
+
+ # using ImageNet values
+ def normalize_tensor_transform():
+     return transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                 std=[0.229, 0.224, 0.225])