matikosowy commited on
Commit
4a10914
·
1 Parent(s): 2620eb0

different sizes compatibility

Browse files
Files changed (2) hide show
  1. app.py +12 -8
  2. utils.py +32 -1
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
5
- from utils import normalize_lab, denormalize_lab
6
  from model import Generator
7
  import kornia.color as color
8
 
@@ -15,15 +15,13 @@ model = model.to(device)
15
  model.eval()
16
 
17
 
18
- # Define preprocessing transforms
19
- transform = transforms.Compose([
20
- transforms.Resize((256, 256), Image.BICUBIC),
21
- transforms.ToTensor(),
22
- ])
23
-
24
-
25
  def preprocess(image):
26
  image = image.convert('RGB')
 
 
 
 
 
27
  image = transform(image)
28
  image = image.to(device)
29
  image = color.rgb_to_lab(image)
@@ -33,8 +31,13 @@ def preprocess(image):
33
  print(L.shape)
34
  return L.unsqueeze(0)
35
 
 
 
 
 
36
 
37
  def predict(image):
 
38
  L = preprocess(image)
39
  with torch.no_grad():
40
  output = model(L)
@@ -42,6 +45,7 @@ def predict(image):
42
  L, ab = denormalize_lab(L, output)
43
  output = torch.cat([L, ab], dim=1)
44
  output = color.lab_to_rgb(output)
 
45
  image = transforms.ToPILImage()(output.squeeze().cpu())
46
 
47
  return image
 
2
  import torch
3
  from PIL import Image
4
  from torchvision import transforms
5
+ from utils import normalize_lab, denormalize_lab, pad_image
6
  from model import Generator
7
  import kornia.color as color
8
 
 
15
  model.eval()
16
 
17
 
 
 
 
 
 
 
 
18
  def preprocess(image):
19
  image = image.convert('RGB')
20
+ image = pad_image(image)
21
+ transform = transforms.Compose([
22
+ #transforms.Resize((height, width), Image.BICUBIC),
23
+ transforms.ToTensor(),
24
+ ])
25
  image = transform(image)
26
  image = image.to(device)
27
  image = color.rgb_to_lab(image)
 
31
  print(L.shape)
32
  return L.unsqueeze(0)
33
 
34
+ def crop_to_original_size(image, original_size):
35
+ width, height = original_size
36
+ return transforms.functional.crop(image, top=0, left=0, height=height, width=width)
37
+
38
 
39
  def predict(image):
40
+ original_size = image.size
41
  L = preprocess(image)
42
  with torch.no_grad():
43
  output = model(L)
 
45
  L, ab = denormalize_lab(L, output)
46
  output = torch.cat([L, ab], dim=1)
47
  output = color.lab_to_rgb(output)
48
+ output = crop_to_original_size(output, original_size)
49
  image = transforms.ToPILImage()(output.squeeze().cpu())
50
 
51
  return image
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  def normalize_lab(L, ab):
2
  """
3
  Normalize the L and ab channels of an image in Lab color space.
@@ -15,4 +17,33 @@ def denormalize_lab(L, ab):
15
  """
16
  L = (L + 1) * 50.
17
  ab = ab * 110.
18
- return L, ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchvision import transforms
2
+
3
  def normalize_lab(L, ab):
4
  """
5
  Normalize the L and ab channels of an image in Lab color space.
 
17
  """
18
  L = (L + 1) * 50.
19
  ab = ab * 110.
20
+ return L, ab
21
+
22
+ def decide_size(image):
23
+ height = image.size[1]
24
+ width = image.size[0]
25
+
26
+ new_height = 2
27
+ new_width = 2
28
+
29
+ while new_height < height:
30
+ new_height *= 2
31
+ while new_width < width:
32
+ new_width *= 2
33
+
34
+ return new_height, new_width
35
+
36
+ def pad_image(image):
37
+ height = image.size[1]
38
+ width = image.size[0]
39
+
40
+ new_height, new_width = decide_size(image)
41
+
42
+ pad_height = new_height - height
43
+ pad_width = new_width - width
44
+
45
+ padding = (0, 0, pad_width, pad_height)
46
+
47
+ image = transforms.Pad(padding, padding_mode='reflect')(image)
48
+
49
+ return image