# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import random

import torch
import torchvision.transforms as transforms

from vidar.utils.data import keys_in
from vidar.utils.decorators import iterate1


def _fixed_adjustment(adjust, bounds):
    """Sample one factor inside ``bounds`` and wrap ``adjust(img, factor)`` as a transform."""
    factor = random.uniform(bounds[0], bounds[1])
    return transforms.Lambda(lambda img: adjust(img, factor))


def random_colorjitter(parameters):
    """
    Creates a reusable color jitter transformation

    The jitter factors are sampled once, when this function is called, so the
    returned transformation applies the same adjustment to every image.

    Parameters
    ----------
    parameters : Tuple
        Color jittering parameters (brightness, contrast, saturation, hue).
        Only the first four entries are read, so a longer tuple
        (e.g. with a trailing color entry) is accepted.
        An entry set to None disables the corresponding adjustment.

    Returns
    -------
    transform : torchvision.Transform
        Color jitter transformation with fixed parameters
    """
    # Get and unpack values (extra trailing entries are ignored)
    brightness, contrast, saturation, hue = tuple(parameters)[:4]

    def around_one(strength):
        # Multiplicative factors are sampled around 1 and never allowed below 0
        if strength is None:
            return None
        return [max(0, 1 - strength), 1 + strength]

    # Candidate adjustments, listed in the order their factors are sampled
    candidates = (
        (transforms.functional.adjust_brightness, around_one(brightness)),
        (transforms.functional.adjust_contrast, around_one(contrast)),
        (transforms.functional.adjust_saturation, around_one(saturation)),
        (transforms.functional.adjust_hue, None if hue is None else [-hue, hue]),
    )

    # Build one fixed transformation per enabled adjustment
    all_transforms = [
        _fixed_adjustment(adjust, bounds)
        for adjust, bounds in candidates
        if bounds is not None
    ]

    # Shuffle transformation order
    random.shuffle(all_transforms)
    # Return composed transformation
    return transforms.Compose(all_transforms)


def colorjitter_sample(samples, parameters, background=None, prob=1.0):
    """
    Jitters input images as data augmentation.

    The same jitter (and optional per-channel color scaling) is applied to
    every image of every sample; the input dictionaries are updated in place.

    Parameters
    ----------
    samples : list[Dict]
        Input samples (iterated one dictionary at a time). Entries selected by
        keys_in(sample, ['rgb']) are expected to be dictionaries of images
        exposing size / getpixel / putpixel / convert (PIL-style).
    parameters : tuple (brightness, contrast, saturation, hue, color)
        Color jittering parameters. The fifth entry is optional; when present
        and positive, each RGB channel is additionally scaled by its own
        random gain in [1 - color, 1 + color].
    background: None or String
        Which background color should be preserved ('white' or 'black').
        Pixels that exactly match it before jittering are restored afterwards.
    prob : Float
        Jittering probability

    Returns
    -------
    sample : list[Dict]
        Jittered samples (same object as the input)
    """
    # Leave the samples untouched with probability (1 - prob)
    if not random.random() < prob:
        return samples

    # Prepare jitter transformation
    colorjitter_transform = random_colorjitter(parameters[:4])

    # Prepare color transformation if requested
    if len(parameters) > 4 and parameters[4] > 0:
        strength = parameters[4]
        gains = [random.uniform(1. - strength, 1 + strength) for _ in range(3)]
        # 3x4 matrix (row-major) that scales R, G and B independently
        matrix = (gains[0], 0, 0, 0,
                  0, gains[1], 0, 0,
                  0, 0, gains[2], 0)
    else:
        matrix = None

    # Supported background colors; an unknown name raises KeyError on lookup
    colors = {'white': (255, 255, 255), 'black': (0, 0, 0)}

    for sample in samples:
        # Jitter sample keys
        for key in keys_in(sample, ['rgb']):
            images = sample[key]
            for ctx in images.keys():
                image = images[ctx]
                # Remember which pixels are background before jittering.
                # NOTE: the comparison is against a 3-tuple, so only pixels
                # reported by getpixel as exactly that RGB triple match.
                kept, protected = [], None
                if background is not None:
                    protected = colors[background]
                    width, height = image.size
                    kept = [
                        (i, j)
                        for i in range(width)
                        for j in range(height)
                        if image.getpixel((i, j)) == protected
                    ]
                # Apply jitter, then the optional per-channel scaling
                image = colorjitter_transform(image)
                if matrix is not None:
                    image = image.convert('RGB', matrix)
                # Restore background pixels to their original color
                for ij in kept:
                    image.putpixel(ij, protected)
                images[ctx] = image

    # Return jittered samples
    return samples


@iterate1
def normalize_sample(sample, mean, std):
    """
    Normalize sample

    Parameters
    ----------
    sample : Dict
        Input sample dictionary
    mean : torch.Tensor or sequence
        Normalization mean (exactly 3 values, one per channel)
    std : torch.Tensor or sequence
        Normalization standard deviation (exactly 3 values, one per channel)

    Returns
    -------
    sample : Dict
        Normalized sample
    """
    # Get mean and std values in the right shape.
    # as_tensor avoids the copy-construct warning torch.tensor emits when the
    # input is already a tensor; detach keeps the statistics out of autograd.
    mean = torch.as_tensor(mean).detach().reshape(3, 1, 1)
    std = torch.as_tensor(std).detach().reshape(3, 1, 1)
    # Apply mean and std to every image
    for key_sample in keys_in(sample, ['rgb']):
        sample[key_sample] = {
            key: (val - mean) / std
            for key, val in sample[key_sample].items()
        }
    return sample