Yarflam committed on
Commit
a15cce2
·
1 Parent(s): c954f09

Load Image

Browse files
Files changed (2) hide show
  1. ModelLoader.py +20 -2
  2. util/get_transform.py +142 -0
ModelLoader.py CHANGED
@@ -1,4 +1,6 @@
1
  from models import create_model
 
 
2
  import os
3
 
4
  ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')
@@ -14,6 +16,7 @@ class Options(object):
14
  class ModelLoader:
15
  def __init__(self) -> None:
16
  self.opt = Options({
 
17
  'name': 'original',
18
  'checkpoints_dir': ckp_path,
19
  'gpu_ids': [],
@@ -28,7 +31,8 @@ class ModelLoader:
28
  'ndf': 64,
29
  'netD': 'basic',
30
  'netG': 'resnet_9blocks',
31
- 'netF': 'reshape',
 
32
  'ngf': 64,
33
  'no_antialias_up': None,
34
  'no_antialias': None,
@@ -41,12 +45,26 @@ class ModelLoader:
41
  'serial_batches': True, # disable data shuffling; comment this line if results on randomly chosen images are needed.
42
  'no_flip': True, # no flip; comment this line if results on flipped images are needed.
43
  'display_id': -1, # no visdom display; the test code saves the results to a HTML file.
 
 
 
 
44
  })
 
 
45
  def load(self) -> None:
46
  self.model = create_model(self.opt)
47
  self.model.load_networks('latest')
48
  def inference(self, src=''):
 
49
  if not os.path.isfile(src):
50
  raise Exception('The image %s is not found!' % src)
51
- # if exist_file()
52
  print('Loading the image %s' % src)
 
 
 
 
 
 
 
 
1
  from models import create_model
2
+ from util.get_transform import get_transform
3
+ from PIL import Image
4
  import os
5
 
6
  ckp_path = os.path.join(os.path.dirname(__file__), 'checkpoints')
 
16
  class ModelLoader:
17
  def __init__(self) -> None:
18
  self.opt = Options({
19
+ 'isGradio': True, # Custom
20
  'name': 'original',
21
  'checkpoints_dir': ckp_path,
22
  'gpu_ids': [],
 
31
  'ndf': 64,
32
  'netD': 'basic',
33
  'netG': 'resnet_9blocks',
34
+ 'netF': 'mlp_sample',
35
+ 'netF_nc': 256,
36
  'ngf': 64,
37
  'no_antialias_up': None,
38
  'no_antialias': None,
 
45
  'serial_batches': True, # disable data shuffling; comment this line if results on randomly chosen images are needed.
46
  'no_flip': True, # no flip; comment this line if results on flipped images are needed.
47
  'display_id': -1, # no visdom display; the test code saves the results to a HTML file.
48
+ 'direction': 'AtoB', # inference
49
+ 'flip_equivariance': False,
50
+ 'load_size': 1680,
51
+ 'crop_size': 512,
52
  })
53
+ self.transform = get_transform(self.opt, grayscale=False)
54
+ self.model = None
55
  def load(self) -> None:
56
  self.model = create_model(self.opt)
57
  self.model.load_networks('latest')
58
  def inference(self, src=''):
59
+ if self.model == None: self.load()
60
  if not os.path.isfile(src):
61
  raise Exception('The image %s is not found!' % src)
62
+ # Loading
63
  print('Loading the image %s' % src)
64
+ source = Image.open(src).convert('RGB')
65
+ img = self.transform(source)
66
+ print(img.shape)
67
+ # Inference
68
+ self.model.set_input({ 'A': img, 'B': img, 'A_paths': src })
69
+ self.model.forward()
70
+ print(self.model)
util/get_transform.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision.transforms as transforms
2
+ from PIL import Image
3
+ import numpy as np
4
+
5
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Assemble the image preprocessing pipeline described by *opt*.

    Steps are appended in a fixed order: optional grayscale, resize/scale,
    zoom, crop, patch, trim, rounding both sides to a multiple of 4,
    horizontal flip, and finally tensor conversion with normalization to
    [-1, 1].  *params*, when given, pins the otherwise-random choices
    (scale factor, crop position, flip) so a transform can be replayed.
    """
    steps = []
    add = steps.append

    if grayscale:
        add(transforms.Grayscale(1))

    if 'fixsize' in opt.preprocess:
        # NOTE(review): reads params["size"] — crashes if params is None.
        add(transforms.Resize(params["size"], method))

    if 'resize' in opt.preprocess:
        target = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            target[0] = opt.load_size // 2
        add(transforms.Resize(target, method))
    elif 'scale_width' in opt.preprocess:
        add(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        add(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            add(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            add(transforms.Lambda(lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method, factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            add(transforms.RandomCrop(opt.crop_size))
        else:
            add(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        add(transforms.Lambda(lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        add(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # Applied unconditionally: the upstream guard (preprocess == 'none')
    # is commented out in this version, so every image is rounded to a
    # multiple of 4 on both sides.
    add(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            add(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            add(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        add(transforms.ToTensor())
        channels = 1 if grayscale else 3
        add(transforms.Normalize((0.5,) * channels, (0.5,) * channels))
    return transforms.Compose(steps)
55
+
56
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize *img* so both sides round to multiples of *base* (no-op when already aligned)."""
    width, height = img.size
    new_w = int(round(width / base) * base)
    new_h = int(round(height / base) * base)
    if (new_w, new_h) == (width, height):
        return img
    return img.resize((new_w, new_h), method)
64
+
65
+
66
def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    """Rescale each axis by a factor in [0.8, 1.0) (or by *factor*), never below *crop_width*."""
    if factor is None:
        fx, fy = np.random.uniform(0.8, 1.0, size=[2])
    else:
        fx, fy = factor[0], factor[1]
    width, height = img.size
    new_w = int(round(max(crop_width, width * fx)))
    new_h = int(round(max(crop_width, height * fy)))
    return img.resize((new_w, new_h), method)
76
+
77
+
78
def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    """Upscale so the shorter side reaches *target_width*; images already large enough pass through."""
    width, height = img.size
    short_side = min(width, height)
    if short_side >= target_width:
        return img
    ratio = target_width / short_side
    return img.resize((round(width * ratio), round(height * ratio)), method)
86
+
87
+
88
def __trim(img, trim_width):
    """Randomly crop each axis down to at most *trim_width* pixels; smaller axes are kept whole."""
    width, height = img.size
    if width > trim_width:
        left = np.random.randint(width - trim_width)
        right = left + trim_width
    else:
        left, right = 0, width
    if height > trim_width:
        top = np.random.randint(height - trim_width)
        bottom = top + trim_width
    else:
        top, bottom = 0, height
    return img.crop((left, top, right, bottom))
103
+
104
+
105
def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    """Resize so the width equals *target_width*, keeping aspect ratio but flooring height at *crop_width*."""
    width, height = img.size
    if width == target_width and height >= crop_width:
        return img
    new_h = int(max(target_width * height / width, crop_width))
    return img.resize((target_width, new_h), method)
112
+
113
+
114
def __crop(img, pos, size):
    """Crop a size x size window with its top-left corner at *pos*; tiny images pass through unchanged."""
    width, height = img.size
    left, top = pos
    if width > size or height > size:
        return img.crop((left, top, left + size, top + size))
    return img
121
+
122
+
123
def __patch(img, index, size):
    """Return the *index*-th size x size tile from a randomly offset grid over *img*."""
    width, height = img.size
    cols, rows = width // size, height // size
    # Random shift inside the leftover margin so the grid is not always anchored at (0, 0).
    offset_x = np.random.randint(int(width - cols * size) + 1)
    offset_y = np.random.randint(int(height - rows * size) + 1)

    index = index % (cols * rows)
    col = index // rows
    row = index % rows
    x = offset_x + col * size
    y = offset_y + row * size
    return img.crop((x, y, x + size, y + size))
137
+
138
+
139
def __flip(img, flip):
    """Mirror *img* left-to-right when *flip* is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img