# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import random
import torch
import torchvision.transforms as transforms
from vidar.utils.data import keys_in
from vidar.utils.decorators import iterate1
def random_colorjitter(parameters):
"""
Creates a reusable color jitter transformation
Parameters
----------
parameters : Tuple
        Color jittering parameters (brightness, contrast, saturation, hue)
Returns
-------
    transform : torchvision.transforms.Compose
Color jitter transformation with fixed parameters
"""
    # Unpack jitter amplitudes and convert them to sampling ranges
brightness, contrast, saturation, hue = parameters
brightness = [max(0, 1 - brightness), 1 + brightness]
contrast = [max(0, 1 - contrast), 1 + contrast]
saturation = [max(0, 1 - saturation), 1 + saturation]
hue = [-hue, hue]
# Initialize transformation list
all_transforms = []
    # Brightness transformation (factor sampled once, reused for every image)
    brightness_factor = random.uniform(brightness[0], brightness[1])
    all_transforms.append(transforms.Lambda(
        lambda img: transforms.functional.adjust_brightness(img, brightness_factor)))
    # Contrast transformation
    contrast_factor = random.uniform(contrast[0], contrast[1])
    all_transforms.append(transforms.Lambda(
        lambda img: transforms.functional.adjust_contrast(img, contrast_factor)))
    # Saturation transformation
    saturation_factor = random.uniform(saturation[0], saturation[1])
    all_transforms.append(transforms.Lambda(
        lambda img: transforms.functional.adjust_saturation(img, saturation_factor)))
    # Hue transformation
    hue_factor = random.uniform(hue[0], hue[1])
    all_transforms.append(transforms.Lambda(
        lambda img: transforms.functional.adjust_hue(img, hue_factor)))
# Shuffle transformation order
random.shuffle(all_transforms)
# Return composed transformation
return transforms.Compose(all_transforms)
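

# Illustrative usage sketch (not part of the original module): builds one jitter
# transform and applies it to a PIL image. Because the factors are sampled when
# the transform is created, calling it on several images applies identical
# jittering to all of them. The image and parameter values are assumptions.
def _example_random_colorjitter():
    from PIL import Image
    transform = random_colorjitter((0.2, 0.2, 0.2, 0.05))
    image = Image.new('RGB', (64, 64), (128, 128, 128))
    return transform(image)
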
def colorjitter_sample(samples, parameters, background=None, prob=1.0):
"""
Jitters input images as data augmentation.
Parameters
----------
    samples : list of Dict
        Input samples
    parameters : Tuple
        Color jittering parameters (brightness, contrast, saturation, hue, color)
    background : String or None
        Background color ('white' or 'black') to be preserved during jittering
prob : Float
Jittering probability
Returns
-------
    samples : list of Dict
        Jittered samples
"""
if random.random() < prob:
# Prepare jitter transformation
colorjitter_transform = random_colorjitter(parameters[:4])
        # Prepare per-channel color scaling if requested (fifth parameter)
        if len(parameters) > 4 and parameters[4] > 0:
            # 12-element matrix for PIL's Image.convert: a random gain on the
            # diagonal of each RGB channel, zero cross-channel terms and offsets
            matrix = (random.uniform(1. - parameters[4], 1 + parameters[4]), 0, 0, 0,
                      0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0, 0,
                      0, 0, random.uniform(1. - parameters[4], 1 + parameters[4]), 0)
        else:
            matrix = None
for sample in samples:
# Jitter sample keys
for key in keys_in(sample, ['rgb']):
for ctx in sample[key].keys():
                    bkg, color = [], {'white': (255, 255, 255), 'black': (0, 0, 0)}
                    # Record background pixel coordinates so they can be restored later
                    if background is not None:
                        for i in range(sample[key][ctx].size[0]):
                            for j in range(sample[key][ctx].size[1]):
                                if sample[key][ctx].getpixel((i, j)) == color[background]:
                                    bkg.append((i, j))
                    # Apply color jittering
                    sample[key][ctx] = colorjitter_transform(sample[key][ctx])
                    # Apply per-channel color scaling
                    if matrix is not None:
                        sample[key][ctx] = sample[key][ctx].convert('RGB', matrix)
                    # Restore original background pixels
                    if background is not None:
                        for ij in bkg:
                            sample[key][ctx].putpixel(ij, color[background])
    # Return samples (jittered only when the probability check passes)
    return samples
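

# Illustrative usage sketch (assumed input layout, not from the original repo):
# colorjitter_sample expects a list of sample dicts whose 'rgb' entry maps
# context keys to PIL images. With background='white', pure-white pixels are
# restored after jittering; the fifth parameter enables per-channel scaling.
def _example_colorjitter_sample():
    from PIL import Image
    samples = [{'rgb': {0: Image.new('RGB', (64, 64), (200, 100, 50))}}]
    return colorjitter_sample(samples, (0.2, 0.2, 0.2, 0.05, 0.1),
                              background='white', prob=1.0)
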
@iterate1
def normalize_sample(sample, mean, std):
"""
Normalize sample
Parameters
----------
sample : Dict
Input sample dictionary
    mean : list or torch.Tensor
        Per-channel normalization mean [3]
    std : list or torch.Tensor
        Per-channel normalization standard deviation [3]
Returns
-------
sample : Dict
Normalized sample
"""
# Get mean and std values in the right shape
mean = torch.tensor(mean).reshape(3, 1, 1)
std = torch.tensor(std).reshape(3, 1, 1)
# Apply mean and std to every image
for key_sample in keys_in(sample, ['rgb']):
        sample[key_sample] = {key: (val - mean) / std
                              for key, val in sample[key_sample].items()}
return sample
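

# Illustrative usage sketch: normalize_sample expects three per-channel values
# for mean and std; the ImageNet statistics below are a common choice, assumed
# here purely for demonstration. The call passes a list of sample dicts, on the
# assumption that @iterate1 maps the function over a list given as first argument.
def _example_normalize_sample():
    samples = [{'rgb': {0: torch.rand(3, 32, 32)}}]
    return normalize_sample(samples,
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])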