Donmill commited on
Commit
6c3f309
·
verified ·
1 Parent(s): 58cc91e
Files changed (1) hide show
  1. test.py +63 -0
test.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from label_colors import colorMap
3
+ from PIL import Image
4
+ from spade.model import Pix2PixModel
5
+ from spade.dataset import get_transform
6
+ from torchvision.transforms import ToPILImage
7
+
8
+ '''colors = np.array([[56, 79, 131], [239, 239, 239],
9
+ [93, 110, 50], [183, 210, 78],
10
+ [60, 59, 75], [250, 250, 250]])'''
11
+ colors = [key['color'] for key in colorMap]
12
+ id_list = [key['id'] for key in colorMap]
13
+
14
+
15
+ def semantic(img):
16
+ print("semantic", type(img))
17
+ h, w = img.size
18
+ imrgb = img.convert("RGB")
19
+ pix = list(imrgb.getdata())
20
+ mask = [id_list[colors.index(i)] if i in colors else 156 for i in pix]
21
+ return np.array(mask).reshape(h, w)
22
+
23
+
24
+ def evaluate(labelmap):
25
+ opt = {
26
+ 'label_nc': 182, # num classes in coco model
27
+ 'crop_size': 512,
28
+ 'load_size': 512,
29
+ 'aspect_ratio': 1.0,
30
+ 'isTrain': False,
31
+ 'checkpoints_dir': 'app',
32
+ 'which_epoch': 'latest',
33
+ 'use_gpu': False
34
+ }
35
+ model = Pix2PixModel(opt)
36
+ model.eval()
37
+ image = Image.fromarray(np.array(labelmap).astype(np.uint8))
38
+ transform_label = get_transform(opt, method=Image.NEAREST, normalize=False)
39
+ # transforms.ToTensor in transform_label rescales image from [0,255] to [0.0,1.0]
40
+ # lets rescale it back to [0,255] to match our label ids
41
+ label_tensor = transform_label(image) * 255.0
42
+ label_tensor[label_tensor == 255] = opt['label_nc'] # 'unknown' is opt.label_nc
43
+ print("label_tensor:", label_tensor.shape)
44
+
45
+ # not using encoder, so creating a blank image...
46
+ transform_image = get_transform(opt)
47
+ image_tensor = transform_image(Image.new('RGB', (500, 500)))
48
+
49
+ data = {
50
+ 'label': label_tensor.unsqueeze(0),
51
+ 'instance': label_tensor.unsqueeze(0),
52
+ 'image': image_tensor.unsqueeze(0)
53
+ }
54
+ generated = model(data, mode='inference')
55
+ print("generated_image:", generated.shape)
56
+
57
+ return generated
58
+
59
+
60
+ def to_image(generated):
61
+ to_img = ToPILImage()
62
+ normalized_img = ((generated.reshape([3, 512, 512]) + 1) / 2.0) * 255.0
63
+ return to_img(normalized_img.byte().cpu())