RamAnanth1 commited on
Commit
9788c55
1 Parent(s): 09808dd

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +49 -0
utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import numpy as np
3
+ import torch
4
+ import torchvision.transforms as T
5
+
6
+ totensor = T.ToTensor()
7
+ topil = T.ToPILImage()
8
+
9
+ def recover_image(image, init_image, mask, background=False):
10
+ image = totensor(image)
11
+ mask = totensor(mask)
12
+ init_image = totensor(init_image)
13
+ if background:
14
+ result = mask * init_image + (1 - mask) * image
15
+ else:
16
+ result = mask * image + (1 - mask) * init_image
17
+ return topil(result)
18
+
19
+ def preprocess(image):
20
+ w, h = image.size
21
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
22
+ image = image.resize((w, h), resample=Image.LANCZOS)
23
+ image = np.array(image).astype(np.float32) / 255.0
24
+ image = image[None].transpose(0, 3, 1, 2)
25
+ image = torch.from_numpy(image)
26
+ return 2.0 * image - 1.0
27
+
28
+ def prepare_mask_and_masked_image(image, mask):
29
+ image = np.array(image.convert("RGB"))
30
+ image = image[None].transpose(0, 3, 1, 2)
31
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
32
+
33
+ mask = np.array(mask.convert("L"))
34
+ mask = mask.astype(np.float32) / 255.0
35
+ mask = mask[None, None]
36
+ mask[mask < 0.5] = 0
37
+ mask[mask >= 0.5] = 1
38
+ mask = torch.from_numpy(mask)
39
+
40
+ masked_image = image * (mask < 0.5)
41
+
42
+ return mask, masked_image
43
+
44
+ def prepare_image(image):
45
+ image = np.array(image.convert("RGB"))
46
+ image = image[None].transpose(0, 3, 1, 2)
47
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
48
+
49
+ return image[0]