Commit 4a10914
1 Parent(s): 2620eb0

different sizes compatibility
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
 import torch
 from PIL import Image
 from torchvision import transforms
-from utils import normalize_lab, denormalize_lab
+from utils import normalize_lab, denormalize_lab, pad_image
 from model import Generator
 import kornia.color as color
 
@@ -15,15 +15,13 @@ model = model.to(device)
 model.eval()
 
 
-# Define preprocessing transforms
-transform = transforms.Compose([
-    transforms.Resize((256, 256), Image.BICUBIC),
-    transforms.ToTensor(),
-])
-
-
 def preprocess(image):
     image = image.convert('RGB')
+    image = pad_image(image)
+    transform = transforms.Compose([
+        #transforms.Resize((height, width), Image.BICUBIC),
+        transforms.ToTensor(),
+    ])
     image = transform(image)
     image = image.to(device)
     image = color.rgb_to_lab(image)
@@ -33,8 +31,13 @@ def preprocess(image):
     print(L.shape)
     return L.unsqueeze(0)
 
+def crop_to_original_size(image, original_size):
+    width, height = original_size
+    return transforms.functional.crop(image, top=0, left=0, height=height, width=width)
+
 
 def predict(image):
+    original_size = image.size
     L = preprocess(image)
     with torch.no_grad():
         output = model(L)
@@ -42,6 +45,7 @@ def predict(image):
     L, ab = denormalize_lab(L, output)
     output = torch.cat([L, ab], dim=1)
     output = color.lab_to_rgb(output)
+    output = crop_to_original_size(output, original_size)
     image = transforms.ToPILImage()(output.squeeze().cpu())
 
     return image
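Net effect in app.py: the fixed 256×256 resize is dropped; inputs are instead reflect-padded up to the next power-of-two size before inference, and the output is cropped back to the original dimensions. A minimal round-trip sketch of that behavior (only pad_image from this commit's utils.py is assumed importable; the 300×200 test image is an arbitrary example, not from the repo):

# Round-trip sanity check: a sketch assuming this commit's utils.py is importable;
# the 300x200 test image is arbitrary and not part of the repo.
from PIL import Image
from torchvision import transforms
from utils import pad_image

img = Image.new('RGB', (300, 200))          # PIL size is (width, height)
padded = pad_image(img)                     # reflect-pads the right and bottom edges
print(padded.size)                          # (512, 256): next powers of two

tensor = transforms.ToTensor()(padded)      # shape [3, 256, 512]
restored = transforms.functional.crop(tensor, top=0, left=0, height=200, width=300)
print(restored.shape)                       # torch.Size([3, 200, 300])

Power-of-two dimensions keep the downsampling and upsampling stages shape-aligned (assuming the Generator is an encoder-decoder, which app.py does not show here), which is presumably why the fixed resize could be removed.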
utils.py CHANGED
@@ -1,3 +1,5 @@
+from torchvision import transforms
+
 def normalize_lab(L, ab):
     """
     Normalize the L and ab channels of an image in Lab color space.
@@ -15,4 +17,33 @@ def denormalize_lab(L, ab):
     """
     L = (L + 1) * 50.
     ab = ab * 110.
-    return L, ab
+    return L, ab
+
+def decide_size(image):
+    height = image.size[1]
+    width = image.size[0]
+
+    new_height = 2
+    new_width = 2
+
+    while new_height < height:
+        new_height *= 2
+    while new_width < width:
+        new_width *= 2
+
+    return new_height, new_width
+
+def pad_image(image):
+    height = image.size[1]
+    width = image.size[0]
+
+    new_height, new_width = decide_size(image)
+
+    pad_height = new_height - height
+    pad_width = new_width - width
+
+    padding = (0, 0, pad_width, pad_height)
+
+    image = transforms.Pad(padding, padding_mode='reflect')(image)
+
+    return image
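decide_size rounds each dimension up to the next power of two by repeated doubling, with a floor of 2. An equivalent closed-form version using bit arithmetic, shown only for reference (not part of the repo):

# Equivalent to decide_size's doubling loop for n >= 1 (a sketch, not from the repo).
def next_power_of_two(n: int) -> int:
    # (n - 1).bit_length() is the exponent of the smallest power of two >= n
    return max(2, 1 << (n - 1).bit_length())

assert next_power_of_two(200) == 256
assert next_power_of_two(300) == 512
assert next_power_of_two(256) == 256   # already a power of two: unchanged

Because the next power of two is always less than double the original dimension, the computed pad is strictly smaller than the dimension itself for sizes of at least 2 pixels, which the 'reflect' padding mode requires.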