matikosowy committed on
Commit 2620eb0 · 1 Parent(s): 3e2af2c
Files changed (4):
  1. app.py +18 -86
  2. model.pth +2 -2
  3. model.py +94 -0
  4. utils.py +18 -0
app.py CHANGED
@@ -2,88 +2,14 @@ import gradio as gr
 import torch
 from PIL import Image
 from torchvision import transforms
-import torchvision.models as models
-import torch.nn as nn
+from utils import normalize_lab, denormalize_lab
+from model import Generator
+import kornia.color as color
 
 
-class ColorizingModel(nn.Module):
-    def __init__(self):
-        super(ColorizingModel, self).__init__()
-
-        self.encoder1 = nn.Sequential(
-            nn.Conv2d(1, 64, 3, 2, 1),  # 150x150 -> 75x75
-            nn.LeakyReLU()
-        )
-
-        self.encoder2 = nn.Sequential(
-            nn.Conv2d(64, 128, 3, 2, 1),  # 75x75 -> 38x38
-            nn.LeakyReLU()
-        )
-
-        self.encoder3 = nn.Sequential(
-            nn.Conv2d(128, 256, 3, 2, 1),  # 38x38 -> 19x19
-            nn.LeakyReLU()
-        )
-
-        self.encoder4 = nn.Sequential(
-            nn.Conv2d(256, 512, 3, 2, 1),  # 19x19 -> 10x10
-            nn.LeakyReLU()
-        )
-
-        # Bottleneck
-        self.bottleneck = nn.Sequential(
-            nn.Flatten(),
-            nn.Linear(512 * 10 * 10, 2048)
-        )
-
-        # Decoder
-        self.decoder_fc = nn.Sequential(
-            nn.Linear(2048, 512 * 10 * 10),
-            nn.Unflatten(1, (512, 10, 10))
-        )
-
-        self.decoder1 = nn.Sequential(
-            nn.ConvTranspose2d(512, 256, 3, 2, 1),  # 10x10 -> 19x19
-            nn.LeakyReLU()
-        )
-
-        self.decoder2 = nn.Sequential(
-            nn.ConvTranspose2d(256, 128, 3, 2, 1, output_padding=1),  # 19x19 -> 38x38
-            nn.LeakyReLU()
-        )
-
-        self.decoder3 = nn.Sequential(
-            nn.ConvTranspose2d(128, 64, 3, 2, 1),  # 38x38 -> 75x75
-            nn.LeakyReLU()
-        )
-
-        self.decoder4 = nn.Sequential(
-            nn.ConvTranspose2d(64, 3, 3, 2, 1, output_padding=1),  # 75x75 -> 150x150
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        # Encoder
-        enc1 = self.encoder1(x)     # 64 channels, 75x75
-        enc2 = self.encoder2(enc1)  # 128 channels, 38x38
-        enc3 = self.encoder3(enc2)  # 256 channels, 19x19
-        enc4 = self.encoder4(enc3)  # 512 channels, 10x10
-
-        # Bottleneck
-        bottleneck = self.bottleneck(enc4)
-
-        # Decoder (with skip connections)
-        dec_fc = self.decoder_fc(bottleneck)
-        dec1 = self.decoder1(dec_fc + enc4)  # Skip connection from encoder4
-        dec2 = self.decoder2(dec1 + enc3)    # Skip connection from encoder3
-        dec3 = self.decoder3(dec2 + enc2)    # Skip connection from encoder2
-        dec4 = self.decoder4(dec3 + enc1)    # Skip connection from encoder1
-
-        return dec4
-
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-model = ColorizingModel()
-model_weights = torch.load('model.pth', map_location=device)
+model = Generator()
+model_weights = torch.load('model.pth', map_location=device, weights_only=True)
 model.load_state_dict(model_weights)
 model = model.to(device)
 model.eval()
@@ -91,25 +17,31 @@ model.eval()
 
 # Define preprocessing transforms
 transform = transforms.Compose([
-    transforms.Resize((150, 150)),
+    transforms.Resize((256, 256), Image.BICUBIC),
     transforms.ToTensor(),
-    transforms.Normalize([0.5], [0.5])
 ])
 
 
 def preprocess(image):
-    image = image.convert('L')
+    image = image.convert('RGB')
     image = transform(image)
-    image = image.unsqueeze(0)
+    image = image.to(device)
+    image = color.rgb_to_lab(image)
+    L = image[[0], ...]
+    L, _ = normalize_lab(L, 0)
 
-    return image
+    print(L.shape)
+    return L.unsqueeze(0)
 
 
 def predict(image):
-    image = preprocess(image).to(device)
+    L = preprocess(image)
    with torch.no_grad():
-        output = model(image)
+        output = model(L)
 
+    L, ab = denormalize_lab(L, output)
+    output = torch.cat([L, ab], dim=1)
+    output = color.lab_to_rgb(output)
     image = transforms.ToPILImage()(output.squeeze().cpu())
 
     return image
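This rewrite changes the colorization strategy: instead of mapping a grayscale image straight to RGB, app.py now converts the input to Lab, feeds only the normalized L (lightness) channel to the new Generator, lets the model predict the two ab chrominance channels, and recombines them before converting back to RGB. The Gradio wiring itself sits outside the changed hunks; a minimal sketch of how predict() would typically be exposed (the interface arguments below are illustrative assumptions, not part of this commit):

import gradio as gr

# Hypothetical wiring (not shown in the hunks above): expose predict() as a Gradio app.
demo = gr.Interface(
    fn=predict,                   # predict() from app.py takes and returns a PIL image
    inputs=gr.Image(type='pil'),
    outputs=gr.Image(type='pil'),
)

if __name__ == '__main__':
    demo.launch()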
model.pth CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:bc47c60e171b66970021950e9af2fc8c9987fa6e029a7965a910063d4e618701
-size 851481002
+oid sha256:893644c8e4b35d8dde82b867753c33e364c76d51b10fffeeb1ddf600220f13e4
+size 217659569
model.py ADDED
@@ -0,0 +1,94 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+
+# Dropout layer that works even in evaluation mode
+class DropoutAlways(nn.Dropout2d):
+    def forward(self, x):
+        return F.dropout2d(x, self.p, training=True)
+
+
+class DownBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, normalize=True):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, 4, 2, 1, padding_mode='reflect', bias=False if normalize else True),
+            nn.InstanceNorm2d(out_channels, affine=True) if normalize else nn.Identity(),
+            # Note that nn.Identity() is just a placeholder layer that returns its input.
+            nn.LeakyReLU(0.2),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class UpBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, normalize=True, dropout=False, activation='relu'):
+        super().__init__()
+
+        self.block = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False if normalize else True),
+            nn.InstanceNorm2d(out_channels, affine=True) if normalize else nn.Identity(),
+            DropoutAlways(p=0.5) if dropout else nn.Identity(),
+            nn.ReLU() if activation == 'relu' else nn.Tanh(),
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class Generator(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # Encoder
+        self.encoder1 = DownBlock(1, 64, normalize=False)     # 256x256 -> 128x128
+        self.encoder2 = DownBlock(64, 128)                    # 128x128 -> 64x64
+        self.encoder3 = DownBlock(128, 256)                   # 64x64 -> 32x32
+        self.encoder4 = DownBlock(256, 512)                   # 32x32 -> 16x16
+        self.encoder5 = DownBlock(512, 512)                   # 16x16 -> 8x8
+        self.encoder6 = DownBlock(512, 512)                   # 8x8 -> 4x4
+        self.encoder7 = DownBlock(512, 512)                   # 4x4 -> 2x2
+        self.encoder8 = DownBlock(512, 512, normalize=False)  # 2x2 -> 1x1
+
+        # Decoder
+        self.decoder1 = UpBlock(512, 512, dropout=True)       # 1x1 -> 2x2
+        self.decoder2 = UpBlock(512 * 2, 512, dropout=True)   # 2x2 -> 4x4
+        self.decoder3 = UpBlock(512 * 2, 512, dropout=True)   # 4x4 -> 8x8
+        self.decoder4 = UpBlock(512 * 2, 512)                 # 8x8 -> 16x16
+        self.decoder5 = UpBlock(512 * 2, 256)                 # 16x16 -> 32x32
+        self.decoder6 = UpBlock(256 * 2, 128)                 # 32x32 -> 64x64
+        self.decoder7 = UpBlock(128 * 2, 64)                  # 64x64 -> 128x128
+        self.decoder8 = UpBlock(64 * 2, 2, normalize=False, activation='tanh')  # 128x128 -> 256x256
+
+    def forward(self, x):
+        # Encoder
+        ch256_down = x
+        ch128_down = self.encoder1(ch256_down)
+        ch64_down = self.encoder2(ch128_down)
+        ch32_down = self.encoder3(ch64_down)
+        ch16_down = self.encoder4(ch32_down)
+        ch8_down = self.encoder5(ch16_down)
+        ch4_down = self.encoder6(ch8_down)
+        ch2_down = self.encoder7(ch4_down)
+        ch1 = self.encoder8(ch2_down)
+
+        # Decoder (with skip connections)
+        ch2_up = self.decoder1(ch1)
+        ch2 = torch.cat([ch2_up, ch2_down], dim=1)
+        ch4_up = self.decoder2(ch2)
+        ch4 = torch.cat([ch4_up, ch4_down], dim=1)
+        ch8_up = self.decoder3(ch4)
+        ch8 = torch.cat([ch8_up, ch8_down], dim=1)
+        ch16_up = self.decoder4(ch8)
+        ch16 = torch.cat([ch16_up, ch16_down], dim=1)
+        ch32_up = self.decoder5(ch16)
+        ch32 = torch.cat([ch32_up, ch32_down], dim=1)
+        ch64_up = self.decoder6(ch32)
+        ch64 = torch.cat([ch64_up, ch64_down], dim=1)
+        ch128_up = self.decoder7(ch64)
+        ch128 = torch.cat([ch128_up, ch128_down], dim=1)
+        ch256_up = self.decoder8(ch128)
+
+        return ch256_up
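Generator is a pix2pix-style U-Net: eight stride-2 DownBlocks shrink the 256x256 single-channel input to a 1x1 bottleneck, and eight UpBlocks mirror the path back up, concatenating each encoder feature map onto the decoder input (hence the doubled in_channels). A quick shape sanity check, a sketch assuming the 256x256 L-channel input used in app.py:

import torch
from model import Generator

net = Generator().eval()                 # note: DropoutAlways stays stochastic even in eval()
L = torch.randn(1, 1, 256, 256)          # one normalized L channel, as produced by preprocess()
with torch.no_grad():
    ab = net(L)

print(ab.shape)                          # torch.Size([1, 2, 256, 256]): predicted ab channels
print(ab.min().item(), ab.max().item())  # bounded to (-1, 1) by the final Tanh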
utils.py ADDED
@@ -0,0 +1,18 @@
+def normalize_lab(L, ab):
+    """
+    Normalize the L and ab channels of an image in Lab color space.
+    (Even though the ab channels are in the [-128, 127] range, we divide them by 110 because higher values are very rare.
+    This makes the distribution closer to [-1, 1] in most cases.)
+    """
+    L = L / 50. - 1.
+    ab = ab / 110.
+    return L, ab
+
+def denormalize_lab(L, ab):
+    """
+    Denormalize the L and ab channels of an image in Lab color space.
+    (Reverse of the normalize_lab function.)
+    """
+    L = (L + 1) * 50.
+    ab = ab * 110.
+    return L, ab
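Since denormalize_lab is the exact inverse of normalize_lab, a round trip should reproduce the input up to float rounding; a small self-contained check (the sample value ranges below are illustrative):

import torch
from utils import normalize_lab, denormalize_lab

L = torch.rand(1, 1, 4, 4) * 100.             # L channel lives in [0, 100]
ab = (torch.rand(1, 2, 4, 4) - 0.5) * 220.    # ab roughly spans [-110, 110]

L_n, ab_n = normalize_lab(L, ab)              # L -> [-1, 1], ab -> about [-1, 1]
L_r, ab_r = denormalize_lab(L_n, ab_n)

assert torch.allclose(L, L_r) and torch.allclose(ab, ab_r)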