import torchvision.transforms as transforms
from PIL import Image
import numpy as np


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Build the preprocessing pipeline selected by ``opt.preprocess``.

    Each step is enabled by an ``in`` test against ``opt.preprocess`` (so a
    value such as ``'resize_and_crop'`` enables both the 'resize' and the
    'crop' step), except 'yarflam_auto', which must match exactly.

    Args:
        opt: Options object. Attributes read here: ``preprocess``,
            ``load_size``, ``crop_size``, ``dataroot``, ``no_flip`` and, for
            'yarflam_auto', ``yarflam_img_wh``.
        params: Optional dict of pre-drawn values so that several images can
            share the same augmentation. Keys read here: ``"size"``
            (required by 'fixsize'), ``"scale_factor"`` ('zoom'),
            ``"crop_pos"`` ('crop'), ``"patch_index"`` (required by 'patch')
            and ``"flip"``.
        grayscale: If true, convert to a single channel first and normalize
            with one-channel statistics.
        method: Resampling filter passed to every resize step.
        convert: If true, append ``ToTensor`` and a ``Normalize`` with mean
            0.5 and std 0.5 per channel.

    Returns:
        A ``transforms.Compose`` wrapping the selected steps in order.
    """
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if 'fixsize' in opt.preprocess:
        # Evaluated eagerly: params must be a dict containing "size".
        transform_list.append(transforms.Resize(params["size"], method))
    if 'resize' in opt.preprocess:
        osize = [opt.load_size, opt.load_size]
        if "gta2cityscapes" in opt.dataroot:
            # First entry of the Resize size is the height; halve it for this dataset.
            osize[0] = opt.load_size // 2
        transform_list.append(transforms.Resize(osize, method))
    elif 'scale_width' in opt.preprocess:
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
    elif 'scale_shortside' in opt.preprocess:
        transform_list.append(transforms.Lambda(
            lambda img: __scale_shortside(img, opt.load_size, opt.crop_size, method)))

    if opt.preprocess == 'yarflam_auto':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_yarflam(img, opt.yarflam_img_wh, method)))

    if 'zoom' in opt.preprocess:
        if params is None:
            transform_list.append(transforms.Lambda(
                lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method)))
        else:
            transform_list.append(transforms.Lambda(
                lambda img: __random_zoom(img, opt.load_size, opt.crop_size, method,
                                          factor=params["scale_factor"])))

    if 'crop' in opt.preprocess:
        if params is None or 'crop_pos' not in params:
            transform_list.append(transforms.RandomCrop(opt.crop_size))
        else:
            transform_list.append(transforms.Lambda(
                lambda img: __crop(img, params['crop_pos'], opt.crop_size)))

    if 'patch' in opt.preprocess:
        # params['patch_index'] is looked up when the transform is applied.
        transform_list.append(transforms.Lambda(
            lambda img: __patch(img, params['patch_index'], opt.crop_size)))

    if 'trim' in opt.preprocess:
        transform_list.append(transforms.Lambda(lambda img: __trim(img, opt.crop_size)))

    # Applied unconditionally (not only for preprocess == 'none'): round both
    # sides to a multiple of 4.
    transform_list.append(transforms.Lambda(
        lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None or 'flip' not in params:
            transform_list.append(transforms.RandomHorizontalFlip())
        elif 'flip' in params:
            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are the nearest multiple of ``base``.

    Returns ``img`` itself when both sides already are multiples of ``base``.
    """
    ow, oh = img.size
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if h == oh and w == ow:
        return img

    return img.resize((w, h), method)


def __random_zoom(img, target_width, crop_width, method=Image.BICUBIC, factor=None):
    """Rescale width and height independently by a zoom factor.

    Args:
        img: Image exposing ``size`` and ``resize``.
        target_width: Not used by this function; kept for the call signature.
        crop_width: Lower bound for each output side.
        method: Resampling filter.
        factor: Optional ``(width_factor, height_factor)`` pair. When None,
            both factors are drawn uniformly from [0.8, 1.0).

    Returns:
        The resized image; each side is at least ``crop_width``.
    """
    if factor is None:
        zoom_level = np.random.uniform(0.8, 1.0, size=[2])
    else:
        zoom_level = (factor[0], factor[1])
    iw, ih = img.size
    zoomw = max(crop_width, iw * zoom_level[0])
    zoomh = max(crop_width, ih * zoom_level[1])
    img = img.resize((int(round(zoomw)), int(round(zoomh))), method)
    return img


def __scale_shortside(img, target_width, crop_width, method=Image.BICUBIC):
    """Upscale ``img`` so its shorter side reaches ``target_width``.

    Images whose shorter side is already at least ``target_width`` are
    returned unchanged. ``crop_width`` is not used by this function.
    """
    ow, oh = img.size
    shortside = min(ow, oh)
    if shortside >= target_width:
        return img
    else:
        scale = target_width / shortside
        return img.resize((round(ow * scale), round(oh * scale)), method)


def __trim(img, trim_width):
    """Randomly crop ``img`` so neither side exceeds ``trim_width``.

    A side that is already at most ``trim_width`` is kept in full.
    """
    ow, oh = img.size
    if ow > trim_width:
        xstart = np.random.randint(ow - trim_width)
        xend = xstart + trim_width
    else:
        xstart = 0
        xend = ow
    if oh > trim_width:
        ystart = np.random.randint(oh - trim_width)
        yend = ystart + trim_width
    else:
        ystart = 0
        yend = oh
    return img.crop((xstart, ystart, xend, yend))


def __scale_width(img, target_width, crop_width, method=Image.BICUBIC):
    """Resize ``img`` to width ``target_width``, keeping height >= ``crop_width``.

    The height follows the aspect ratio unless that would fall below
    ``crop_width``. Returns ``img`` unchanged when it already has the target
    width and a height of at least ``crop_width``.
    """
    ow, oh = img.size
    if ow == target_width and oh >= crop_width:
        return img
    w = target_width
    h = int(max(target_width * oh / ow, crop_width))
    return img.resize((w, h), method)


def __scale_yarflam(img, target_wh, method=Image.BICUBIC):
    """Shrink ``img`` so its longer side equals ``target_wh``, keeping aspect ratio.

    Images whose longer side is already at most ``target_wh`` are returned
    unchanged (no upscaling).

    The longer side is set to ``target_wh`` directly and the shorter side is
    derived with integer floor division. Going through a float ratio and
    ``int()`` can truncate an exact result to one pixel less (for example
    ``int(392 * (256 / 392)) == 255``). The shorter side is also clamped to at
    least 1 so that very elongated images do not produce a zero-sized resize.
    """
    ow, oh = img.size
    if max(ow, oh) <= target_wh:
        return img
    if ow >= oh:
        w = target_wh
        h = max(1, int(oh * target_wh // ow))
    else:
        w = max(1, int(ow * target_wh // oh))
        h = target_wh
    return img.resize((w, h), method)


def __crop(img, pos, size):
    """Crop a ``size`` x ``size`` window whose top-left corner is ``pos``.

    The image is returned unchanged when neither side is larger than ``size``.
    """
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __patch(img, index, size):
    """Crop the ``index``-th ``size`` x ``size`` patch from a grid laid over ``img``.

    The grid holds ``(ow // size) * (oh // size)`` patches and is shifted by a
    random offset within the leftover margin on each axis. ``index`` wraps
    around modulo the number of patches and advances down a column first.

    Raises:
        ZeroDivisionError: If the image is smaller than ``size`` on either
            axis, because the grid then contains no patches.
    """
    ow, oh = img.size
    nw, nh = ow // size, oh // size
    roomx = ow - nw * size
    roomy = oh - nh * size
    startx = np.random.randint(int(roomx) + 1)
    starty = np.random.randint(int(roomy) + 1)

    index = index % (nw * nh)
    ix = index // nh
    iy = index % nh
    gridx = startx + ix * size
    gridy = starty + iy * size
    return img.crop((gridx, gridy, gridx + size, gridy + size))


def __flip(img, flip):
    """Return ``img`` mirrored left-to-right when ``flip`` is truthy, else unchanged."""
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img