PushkarA07 committed on
Commit 8c51cde · 1 Parent(s): 8bf2435

Delete web_app.py

Files changed (1)
  1. web_app.py +0 -402
web_app.py DELETED
@@ -1,402 +0,0 @@
- from fastai.vision.models.unet import DynamicUnet
- from torchvision.models.resnet import resnet18
- from fastai.vision.learner import create_body
- import streamlit as st
- from PIL import Image
- import cv2 as cv
- import os
- import glob
- import time
- import numpy as np
- from pathlib import Path
- from tqdm.notebook import tqdm
- import matplotlib.pyplot as plt
- from skimage.color import rgb2lab, lab2rgb
-
- # pip install fastai==2.4
-
- import torch
- from torch import nn, optim
- from torchvision import transforms
- from torchvision.utils import make_grid
- from torch.utils.data import Dataset, DataLoader
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- use_colab = None
-
- SIZE = 256
-
-
- class ColorizationDataset(Dataset):
-     def __init__(self, paths, split='train'):
-         if split == 'train':
-             self.transforms = transforms.Compose([
-                 transforms.Resize((SIZE, SIZE), Image.BICUBIC),
-                 transforms.RandomHorizontalFlip(),  # A little data augmentation!
-             ])
-         elif split == 'val':
-             self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
-
-         self.split = split
-         self.size = SIZE
-         self.paths = paths
-
-     def __getitem__(self, idx):
-         img = Image.open(self.paths[idx]).convert("RGB")
-         img = self.transforms(img)
-         img = np.array(img)
-         img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
-         img_lab = transforms.ToTensor()(img_lab)
-         L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
-         ab = img_lab[[1, 2], ...] / 110.  # Between -1 and 1
-
-         return {'L': L, 'ab': ab}
-
-     def __len__(self):
-         return len(self.paths)
-
- def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
-     dataset = ColorizationDataset(**kwargs)
-     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
-                             pin_memory=pin_memory)
-     return dataloader
-
-
- class UnetBlock(nn.Module):
-     def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
-                  innermost=False, outermost=False):
-         super().__init__()
-         self.outermost = outermost
-         if input_c is None:
-             input_c = nf
-         downconv = nn.Conv2d(input_c, ni, kernel_size=4,
-                              stride=2, padding=1, bias=False)
-         downrelu = nn.LeakyReLU(0.2, True)
-         downnorm = nn.BatchNorm2d(ni)
-         uprelu = nn.ReLU(True)
-         upnorm = nn.BatchNorm2d(nf)
-
-         if outermost:
-             upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
-                                         stride=2, padding=1)
-             down = [downconv]
-             up = [uprelu, upconv, nn.Tanh()]
-             model = down + [submodule] + up
-         elif innermost:
-             upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
-                                         stride=2, padding=1, bias=False)
-             down = [downrelu, downconv]
-             up = [uprelu, upconv, upnorm]
-             model = down + up
-         else:
-             upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
-                                         stride=2, padding=1, bias=False)
-             down = [downrelu, downconv, downnorm]
-             up = [uprelu, upconv, upnorm]
-             if dropout:
-                 up += [nn.Dropout(0.5)]
-             model = down + [submodule] + up
-         self.model = nn.Sequential(*model)
-
-     def forward(self, x):
-         if self.outermost:
-             return self.model(x)
-         else:
-             return torch.cat([x, self.model(x)], 1)
-
-
- class Unet(nn.Module):
-     def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
-         super().__init__()
-         unet_block = UnetBlock(
-             num_filters * 8, num_filters * 8, innermost=True)
-         for _ in range(n_down - 5):
-             unet_block = UnetBlock(
-                 num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
-         out_filters = num_filters * 8
-         for _ in range(3):
-             unet_block = UnetBlock(
-                 out_filters // 2, out_filters, submodule=unet_block)
-             out_filters //= 2
-         self.model = UnetBlock(
-             output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)
-
-     def forward(self, x):
-         return self.model(x)
-
-
- class PatchDiscriminator(nn.Module):
-     def __init__(self, input_c, num_filters=64, n_down=3):
-         super().__init__()
-         model = [self.get_layers(input_c, num_filters, norm=False)]
-         # stride of 1 (instead of 2) for the last block in this loop
-         model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1),
-                                   s=1 if i == (n_down - 1) else 2)
-                   for i in range(n_down)]
-         # no normalization or activation for the last layer of the model
-         model += [self.get_layers(num_filters * 2 ** n_down, 1,
-                                   s=1, norm=False, act=False)]
-         self.model = nn.Sequential(*model)
-
-     # when we need repetitive blocks of layers, it helps to
-     # build them in a separate method
-     def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
-         layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
-         if norm:
-             layers += [nn.BatchNorm2d(nf)]
-         if act:
-             layers += [nn.LeakyReLU(0.2, True)]
-         return nn.Sequential(*layers)
-
-     def forward(self, x):
-         return self.model(x)
-
-
- class GANLoss(nn.Module):
-     def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
-         super().__init__()
-         self.register_buffer('real_label', torch.tensor(real_label))
-         self.register_buffer('fake_label', torch.tensor(fake_label))
-         if gan_mode == 'vanilla':
-             self.loss = nn.BCEWithLogitsLoss()
-         elif gan_mode == 'lsgan':
-             self.loss = nn.MSELoss()
-
-     def get_labels(self, preds, target_is_real):
-         if target_is_real:
-             labels = self.real_label
-         else:
-             labels = self.fake_label
-         return labels.expand_as(preds)
-
-     def __call__(self, preds, target_is_real):
-         labels = self.get_labels(preds, target_is_real)
-         loss = self.loss(preds, labels)
-         return loss
-
-
- def init_weights(net, init='norm', gain=0.02):
-
-     def init_func(m):
-         classname = m.__class__.__name__
-         if hasattr(m, 'weight') and 'Conv' in classname:
-             if init == 'norm':
-                 nn.init.normal_(m.weight.data, mean=0.0, std=gain)
-             elif init == 'xavier':
-                 nn.init.xavier_normal_(m.weight.data, gain=gain)
-             elif init == 'kaiming':
-                 nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
-
-             if hasattr(m, 'bias') and m.bias is not None:
-                 nn.init.constant_(m.bias.data, 0.0)
-         elif 'BatchNorm2d' in classname:
-             nn.init.normal_(m.weight.data, 1., gain)
-             nn.init.constant_(m.bias.data, 0.)
-
-     net.apply(init_func)
-     print(f"model initialized with {init} initialization")
-     return net
-
-
- def init_model(model, device):
-     model = model.to(device)
-     model = init_weights(model)
-     return model
-
-
- class MainModel(nn.Module):
-     def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
-                  beta1=0.5, beta2=0.999, lambda_L1=100.):
-         super().__init__()
-
-         self.device = torch.device(
-             "cuda" if torch.cuda.is_available() else "cpu")
-         self.lambda_L1 = lambda_L1
-
-         if net_G is None:
-             self.net_G = init_model(
-                 Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
-         else:
-             self.net_G = net_G.to(self.device)
-         self.net_D = init_model(PatchDiscriminator(
-             input_c=3, n_down=3, num_filters=64), self.device)
-         self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
-         self.L1criterion = nn.L1Loss()
-         self.opt_G = optim.Adam(self.net_G.parameters(),
-                                 lr=lr_G, betas=(beta1, beta2))
-         self.opt_D = optim.Adam(self.net_D.parameters(),
-                                 lr=lr_D, betas=(beta1, beta2))
-
-     def set_requires_grad(self, model, requires_grad=True):
-         for p in model.parameters():
-             p.requires_grad = requires_grad
-
-     def setup_input(self, data):
-         self.L = data['L'].to(self.device)
-         self.ab = data['ab'].to(self.device)
-
-     def forward(self):
-         self.fake_color = self.net_G(self.L)
-
-     def backward_D(self):
-         fake_image = torch.cat([self.L, self.fake_color], dim=1)
-         fake_preds = self.net_D(fake_image.detach())
-         self.loss_D_fake = self.GANcriterion(fake_preds, False)
-         real_image = torch.cat([self.L, self.ab], dim=1)
-         real_preds = self.net_D(real_image)
-         self.loss_D_real = self.GANcriterion(real_preds, True)
-         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
-         self.loss_D.backward()
-
-     def backward_G(self):
-         fake_image = torch.cat([self.L, self.fake_color], dim=1)
-         fake_preds = self.net_D(fake_image)
-         self.loss_G_GAN = self.GANcriterion(fake_preds, True)
-         self.loss_G_L1 = self.L1criterion(
-             self.fake_color, self.ab) * self.lambda_L1
-         self.loss_G = self.loss_G_GAN + self.loss_G_L1
-         self.loss_G.backward()
-
-     def optimize(self):
-         self.forward()
-         self.net_D.train()
-         self.set_requires_grad(self.net_D, True)
-         self.opt_D.zero_grad()
-         self.backward_D()
-         self.opt_D.step()
-
-         self.net_G.train()
-         self.set_requires_grad(self.net_D, False)
-         self.opt_G.zero_grad()
-         self.backward_G()
-         self.opt_G.step()
-
-
- class AverageMeter:
-     def __init__(self):
-         self.reset()
-
-     def reset(self):
-         self.count, self.avg, self.sum = [0.] * 3
-
-     def update(self, val, count=1):
-         self.count += count
-         self.sum += count * val
-         self.avg = self.sum / self.count
-
-
- def create_loss_meters():
-     loss_D_fake = AverageMeter()
-     loss_D_real = AverageMeter()
-     loss_D = AverageMeter()
-     loss_G_GAN = AverageMeter()
-     loss_G_L1 = AverageMeter()
-     loss_G = AverageMeter()
-
-     return {'loss_D_fake': loss_D_fake,
-             'loss_D_real': loss_D_real,
-             'loss_D': loss_D,
-             'loss_G_GAN': loss_G_GAN,
-             'loss_G_L1': loss_G_L1,
-             'loss_G': loss_G}
-
-
- def update_losses(model, loss_meter_dict, count):
-     for loss_name, loss_meter in loss_meter_dict.items():
-         loss = getattr(model, loss_name)
-         loss_meter.update(loss.item(), count=count)
-
-
- def lab_to_rgb(L, ab):
-     L = (L + 1.) * 50.
-     ab = ab * 110.
-     Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
-     rgb_imgs = []
-     for img in Lab:
-         img_rgb = lab2rgb(img)
-         rgb_imgs.append(img_rgb)
-     return np.stack(rgb_imgs, axis=0)
-
-
- def visualize(model, data, dims):
-     model.net_G.eval()
-     with torch.no_grad():
-         model.setup_input(data)
-         model.forward()
-     model.net_G.train()
-     fake_color = model.fake_color.detach()
-     real_color = model.ab
-     L = model.L
-     fake_imgs = lab_to_rgb(L, fake_color)
-     real_imgs = lab_to_rgb(L, real_color)
-     for i in range(1):
-         # resize back to the uploaded image's (height, width)
-         img = cv.resize(fake_imgs[i], dsize=(
-             dims[1], dims[0]), interpolation=cv.INTER_CUBIC)
-         st.image(img, caption="Output image",
-                  use_column_width='auto', clamp=True)
-
-
- def log_results(loss_meter_dict):
-     for loss_name, loss_meter in loss_meter_dict.items():
-         print(f"{loss_name}: {loss_meter.avg:.5f}")
-
- def build_res_unet(n_input=1, n_output=2, size=256):
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     # fastai 2.4's create_body expects the architecture callable, not an instance
-     body = create_body(resnet18, pretrained=True, n_in=n_input, cut=-2)
-     net_G = DynamicUnet(body, n_output, (size, size)).to(device)
-     return net_G
-
-
- net_G = build_res_unet(n_input=1, n_output=2, size=256)
- net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
- model = MainModel(net_G=net_G)
- model.load_state_dict(torch.load("main-model.pt", map_location=device))
-
-
- class MyDataset(torch.utils.data.Dataset):
-     def __init__(self, img_list):
-         super(MyDataset, self).__init__()
-         self.img_list = img_list
-         self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
-
-     def __len__(self):
-         return len(self.img_list)
-
-     def __getitem__(self, idx):
-         img = self.img_list[idx]
-         img = self.augmentations(img)
-         img = np.array(img)
-         img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
-         img_lab = transforms.ToTensor()(img_lab)
-         L = img_lab[[0], ...] / 50. - 1.  # Between -1 and 1
-         ab = img_lab[[1, 2], ...] / 110.
-         return {'L': L, 'ab': ab}
-
- def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
-     dataset = MyDataset(**kwargs)
-     dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
-                             pin_memory=pin_memory)
-     return dataloader
-
-
- # st.set_option('deprecation.showfileUploaderEncoding', False)
- # @st.cache(allow_output_mutation=True)
- st.write("""
- # Image Recolorisation
- """)
- file_up = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
-
- if file_up is not None:
-     im = Image.open(file_up).convert("RGB")  # force 3-channel RGB so rgb2lab works
-     a = np.array(im).shape  # (height, width, channels); PIL images have no .shape
-     st.text(body=f"Size of uploaded image {a}")
-     st.image(im, caption="Uploaded Image.", use_column_width='auto')
-     test_dl = make_dataloaders2(img_list=[im])
-     for data in test_dl:
-         model.setup_input(data)
-         model.optimize()  # one generator/discriminator step on the uploaded image
-         visualize(model, data, a)
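
For reference, the least obvious numerics in the deleted file are the L*a*b scaling constants shared by ColorizationDataset, MyDataset, and lab_to_rgb: L is mapped from [0, 100] to [-1, 1] by dividing by 50 and subtracting 1, and a/b are divided by 110. A minimal standalone sketch of that round trip (illustrative, not from the commit; it assumes only numpy, torch, and scikit-image):

import numpy as np
import torch
from skimage.color import rgb2lab, lab2rgb

# Stand-in for an uploaded image: a random 8x8 RGB array in [0, 1].
rgb = np.random.rand(8, 8, 3).astype("float32")

# Forward mapping, as in MyDataset.__getitem__.
lab = rgb2lab(rgb).astype("float32")            # L in [0, 100], a/b within about [-110, 110]
lab_t = torch.from_numpy(lab).permute(2, 0, 1)  # HWC -> CHW, like transforms.ToTensor()
L = lab_t[[0], ...] / 50. - 1.                  # L scaled to [-1, 1]
ab = lab_t[[1, 2], ...] / 110.                  # ab scaled to roughly [-1, 1]

# Inverse mapping, as in lab_to_rgb (single image, no batch dimension).
lab_back = torch.cat([(L + 1.) * 50., ab * 110.], dim=0).permute(1, 2, 0).numpy()
rgb_back = lab2rgb(lab_back)

print(np.abs(rgb - rgb_back).max())  # round-trip error should be near zero

Before this deletion, the app was presumably launched with streamlit run web_app.py, with res18-unet.pt and main-model.pt in the working directory.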