SpyroSigma commited on
Commit
b42bedb
·
verified ·
1 Parent(s): 4e1a459

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +393 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import streamlit as st
3
+
4
+ # -------------------- base color ------------------
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ class BaseColor(nn.Module):
10
+ def __init__(self):
11
+ super(BaseColor, self).__init__()
12
+
13
+ self.l_cent = 50.
14
+ self.l_norm = 100.
15
+ self.ab_norm = 110.
16
+
17
+ def normalize_l(self, in_l):
18
+ return (in_l-self.l_cent)/self.l_norm
19
+
20
+ def unnormalize_l(self, in_l):
21
+ return in_l*self.l_norm + self.l_cent
22
+
23
+ def normalize_ab(self, in_ab):
24
+ return in_ab/self.ab_norm
25
+
26
+ def unnormalize_ab(self, in_ab):
27
+ return in_ab*self.ab_norm
28
+
29
+ # ------------------ eccv16 ---------------------
30
+
31
+ import numpy as np
32
+
33
+
34
+ class ECCVGenerator(BaseColor):
35
+ def __init__(self, norm_layer=nn.BatchNorm2d):
36
+ super(ECCVGenerator, self).__init__()
37
+
38
+ model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
39
+ model1+=[nn.ReLU(True),]
40
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
41
+ model1+=[nn.ReLU(True),]
42
+ model1+=[norm_layer(64),]
43
+
44
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
45
+ model2+=[nn.ReLU(True),]
46
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
47
+ model2+=[nn.ReLU(True),]
48
+ model2+=[norm_layer(128),]
49
+
50
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
51
+ model3+=[nn.ReLU(True),]
52
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
53
+ model3+=[nn.ReLU(True),]
54
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
55
+ model3+=[nn.ReLU(True),]
56
+ model3+=[norm_layer(256),]
57
+
58
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
59
+ model4+=[nn.ReLU(True),]
60
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
61
+ model4+=[nn.ReLU(True),]
62
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
63
+ model4+=[nn.ReLU(True),]
64
+ model4+=[norm_layer(512),]
65
+
66
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
67
+ model5+=[nn.ReLU(True),]
68
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
69
+ model5+=[nn.ReLU(True),]
70
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
71
+ model5+=[nn.ReLU(True),]
72
+ model5+=[norm_layer(512),]
73
+
74
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
75
+ model6+=[nn.ReLU(True),]
76
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
77
+ model6+=[nn.ReLU(True),]
78
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
79
+ model6+=[nn.ReLU(True),]
80
+ model6+=[norm_layer(512),]
81
+
82
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
83
+ model7+=[nn.ReLU(True),]
84
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
85
+ model7+=[nn.ReLU(True),]
86
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
87
+ model7+=[nn.ReLU(True),]
88
+ model7+=[norm_layer(512),]
89
+
90
+ model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
91
+ model8+=[nn.ReLU(True),]
92
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
93
+ model8+=[nn.ReLU(True),]
94
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
95
+ model8+=[nn.ReLU(True),]
96
+
97
+ model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
98
+
99
+ self.model1 = nn.Sequential(*model1)
100
+ self.model2 = nn.Sequential(*model2)
101
+ self.model3 = nn.Sequential(*model3)
102
+ self.model4 = nn.Sequential(*model4)
103
+ self.model5 = nn.Sequential(*model5)
104
+ self.model6 = nn.Sequential(*model6)
105
+ self.model7 = nn.Sequential(*model7)
106
+ self.model8 = nn.Sequential(*model8)
107
+
108
+ self.softmax = nn.Softmax(dim=1)
109
+ self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
110
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
111
+
112
+ def forward(self, input_l):
113
+ conv1_2 = self.model1(self.normalize_l(input_l))
114
+ conv2_2 = self.model2(conv1_2)
115
+ conv3_3 = self.model3(conv2_2)
116
+ conv4_3 = self.model4(conv3_3)
117
+ conv5_3 = self.model5(conv4_3)
118
+ conv6_3 = self.model6(conv5_3)
119
+ conv7_3 = self.model7(conv6_3)
120
+ conv8_3 = self.model8(conv7_3)
121
+ out_reg = self.model_out(self.softmax(conv8_3))
122
+
123
+ return self.unnormalize_ab(self.upsample4(out_reg))
124
+
125
+ def eccv16(pretrained=True):
126
+ model = ECCVGenerator()
127
+ if(pretrained):
128
+ import torch.utils.model_zoo as model_zoo
129
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
130
+ return model
131
+
132
+ # ------------------ siggraph17 ---------------------
133
+
134
+
135
+ class SIGGRAPHGenerator(BaseColor):
136
+ def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
137
+ super(SIGGRAPHGenerator, self).__init__()
138
+
139
+ # Conv1
140
+ model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
141
+ model1+=[nn.ReLU(True),]
142
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
143
+ model1+=[nn.ReLU(True),]
144
+ model1+=[norm_layer(64),]
145
+ # add a subsampling operation
146
+
147
+ # Conv2
148
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
149
+ model2+=[nn.ReLU(True),]
150
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
151
+ model2+=[nn.ReLU(True),]
152
+ model2+=[norm_layer(128),]
153
+ # add a subsampling layer operation
154
+
155
+ # Conv3
156
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
157
+ model3+=[nn.ReLU(True),]
158
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
159
+ model3+=[nn.ReLU(True),]
160
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
161
+ model3+=[nn.ReLU(True),]
162
+ model3+=[norm_layer(256),]
163
+ # add a subsampling layer operation
164
+
165
+ # Conv4
166
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
167
+ model4+=[nn.ReLU(True),]
168
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
169
+ model4+=[nn.ReLU(True),]
170
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
171
+ model4+=[nn.ReLU(True),]
172
+ model4+=[norm_layer(512),]
173
+
174
+ # Conv5
175
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
176
+ model5+=[nn.ReLU(True),]
177
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
178
+ model5+=[nn.ReLU(True),]
179
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
180
+ model5+=[nn.ReLU(True),]
181
+ model5+=[norm_layer(512),]
182
+
183
+ # Conv6
184
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
185
+ model6+=[nn.ReLU(True),]
186
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
187
+ model6+=[nn.ReLU(True),]
188
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
189
+ model6+=[nn.ReLU(True),]
190
+ model6+=[norm_layer(512),]
191
+
192
+ # Conv7
193
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
194
+ model7+=[nn.ReLU(True),]
195
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
196
+ model7+=[nn.ReLU(True),]
197
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
198
+ model7+=[nn.ReLU(True),]
199
+ model7+=[norm_layer(512),]
200
+
201
+ # Conv7
202
+ model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
203
+ model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
204
+
205
+ model8=[nn.ReLU(True),]
206
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
207
+ model8+=[nn.ReLU(True),]
208
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
209
+ model8+=[nn.ReLU(True),]
210
+ model8+=[norm_layer(256),]
211
+
212
+ # Conv9
213
+ model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
214
+ model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
215
+ # add the two feature maps above
216
+
217
+ model9=[nn.ReLU(True),]
218
+ model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
219
+ model9+=[nn.ReLU(True),]
220
+ model9+=[norm_layer(128),]
221
+
222
+ # Conv10
223
+ model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
224
+ model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
225
+ # add the two feature maps above
226
+
227
+ model10=[nn.ReLU(True),]
228
+ model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
229
+ model10+=[nn.LeakyReLU(negative_slope=.2),]
230
+
231
+ # classification output
232
+ model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
233
+
234
+ # regression output
235
+ model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
236
+ model_out+=[nn.Tanh()]
237
+
238
+ self.model1 = nn.Sequential(*model1)
239
+ self.model2 = nn.Sequential(*model2)
240
+ self.model3 = nn.Sequential(*model3)
241
+ self.model4 = nn.Sequential(*model4)
242
+ self.model5 = nn.Sequential(*model5)
243
+ self.model6 = nn.Sequential(*model6)
244
+ self.model7 = nn.Sequential(*model7)
245
+ self.model8up = nn.Sequential(*model8up)
246
+ self.model8 = nn.Sequential(*model8)
247
+ self.model9up = nn.Sequential(*model9up)
248
+ self.model9 = nn.Sequential(*model9)
249
+ self.model10up = nn.Sequential(*model10up)
250
+ self.model10 = nn.Sequential(*model10)
251
+ self.model3short8 = nn.Sequential(*model3short8)
252
+ self.model2short9 = nn.Sequential(*model2short9)
253
+ self.model1short10 = nn.Sequential(*model1short10)
254
+
255
+ self.model_class = nn.Sequential(*model_class)
256
+ self.model_out = nn.Sequential(*model_out)
257
+
258
+ self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
259
+ self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
260
+
261
+ def forward(self, input_A, input_B=None, mask_B=None):
262
+ if(input_B is None):
263
+ input_B = torch.cat((input_A*0, input_A*0), dim=1)
264
+ if(mask_B is None):
265
+ mask_B = input_A*0
266
+
267
+ conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
268
+ conv2_2 = self.model2(conv1_2[:,:,::2,::2])
269
+ conv3_3 = self.model3(conv2_2[:,:,::2,::2])
270
+ conv4_3 = self.model4(conv3_3[:,:,::2,::2])
271
+ conv5_3 = self.model5(conv4_3)
272
+ conv6_3 = self.model6(conv5_3)
273
+ conv7_3 = self.model7(conv6_3)
274
+
275
+ conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
276
+ conv8_3 = self.model8(conv8_up)
277
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
278
+ conv9_3 = self.model9(conv9_up)
279
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
280
+ conv10_2 = self.model10(conv10_up)
281
+ out_reg = self.model_out(conv10_2)
282
+
283
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
284
+ conv9_3 = self.model9(conv9_up)
285
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
286
+ conv10_2 = self.model10(conv10_up)
287
+ out_reg = self.model_out(conv10_2)
288
+
289
+ return self.unnormalize_ab(out_reg)
290
+
291
+ def siggraph17(pretrained=True):
292
+ model = SIGGRAPHGenerator()
293
+ if(pretrained):
294
+ import torch.utils.model_zoo as model_zoo
295
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
296
+ return model
297
+
298
+ # ------------------ utils ---------------------
299
+
300
+
301
+ from PIL import Image
302
+ import numpy as np
303
+ from skimage import color
304
+ import torch.nn.functional as F
305
+
306
+ def load_img(img_path):
307
+ out_np = np.asarray(Image.open(img_path))
308
+ if(out_np.ndim==2):
309
+ out_np = np.tile(out_np[:,:,None],3)
310
+ return out_np
311
+
312
+ def resize_img(img, HW=(256,256), resample=3):
313
+ return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
314
+
315
+ def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
316
+ # return original size L and resized L as torch Tensors
317
+ img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
318
+
319
+ img_lab_orig = color.rgb2lab(img_rgb_orig)
320
+ img_lab_rs = color.rgb2lab(img_rgb_rs)
321
+
322
+ img_l_orig = img_lab_orig[:,:,0]
323
+ img_l_rs = img_lab_rs[:,:,0]
324
+
325
+ tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
326
+ tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
327
+
328
+ return (tens_orig_l, tens_rs_l)
329
+
330
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
331
+ # tens_orig_l 1 x 1 x H_orig x W_orig
332
+ # out_ab 1 x 2 x H x W
333
+
334
+ HW_orig = tens_orig_l.shape[2:]
335
+ HW = out_ab.shape[2:]
336
+
337
+ # call resize function if needed
338
+ if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
339
+ out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
340
+ else:
341
+ out_ab_orig = out_ab
342
+
343
+ out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
344
+ return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
345
+
346
+
347
+ # parser = argparse.ArgumentParser()
348
+ # parser.add_argument('-i','--img_path', type=str, default='imgs/test.jpg')
349
+ # # parser.add_argument('--use_gpu', action='store_true', help='whether to use GPU')
350
+ # parser.add_argument('-o','--save_prefix', type=str, default='saved', help='will save into this file with {eccv16.png, siggraph17.png} suffixes')
351
+ # opt = parser.parse_args()
352
+
353
+ colorizer_eccv16 = eccv16(pretrained=True).eval()
354
+ colorizer_siggraph17 = siggraph17(pretrained=True).eval()
355
+
356
+ # if(opt.use_gpu):
357
+ # colorizer_eccv16.cuda()
358
+ # colorizer_siggraph17.cuda()
359
+
360
+ input_image = st.file_uploader("Upload Image : ", type=["jpg", "jpeg", "png"])
361
+
362
+ if input_image is not None:
363
+ img = load_img(input_image)
364
+ (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))
365
+
366
+ img_bw = postprocess_tens(tens_l_orig, torch.cat((0*tens_l_orig,0*tens_l_orig),dim=1))
367
+ out_img_eccv16 = postprocess_tens(tens_l_orig, colorizer_eccv16(tens_l_rs).cpu())
368
+ out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())
369
+
370
+ plt.imsave(f'eccv16.png{input_image.name}', out_img_eccv16)
371
+ plt.imsave(f'siggraph17.png{input_image.name}', out_img_siggraph17)
372
+
373
+ plt.figure(figsize=(12,8))
374
+ plt.subplot(2,2,1)
375
+ plt.imshow(img)
376
+ plt.title('Original')
377
+ plt.axis('off')
378
+
379
+ plt.subplot(2,2,2)
380
+ plt.imshow(img_bw)
381
+ plt.title('Input')
382
+ plt.axis('off')
383
+
384
+ plt.subplot(2,2,3)
385
+ plt.imshow(out_img_eccv16)
386
+ plt.title('Output (ECCV 16)')
387
+ plt.axis('off')
388
+
389
+ plt.subplot(2,2,4)
390
+ plt.imshow(out_img_siggraph17)
391
+ plt.title('Output (SIGGRAPH 17)')
392
+ plt.axis('off')
393
+ plt.show()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ scikit-image
3
+ numpy
4
+ matplotlib
5
+ pillow