|
import torch |
|
|
|
|
|
def rgb_to_hsv(image):
    """Convert an RGB image tensor to HSV, with finite gradients everywhere.

    Args:
        image: Tensor whose third-from-last dimension holds the color
            channels in R, G, B order, e.g. ``(N, 3, H, W)`` or ``(3, H, W)``.
            Values are presumably in ``[0, 1]`` (TODO confirm with callers).
            Integer tensors are converted to the default float dtype first.
            Only the first three channels are read.

    Returns:
        Tensor of shape ``(..., 3, H, W)`` holding hue, saturation and value
        along the channel dimension:
        - hue is normalized to ``[0, 1)`` and set to 0 for achromatic pixels
          (``max == min``), where it is undefined;
        - saturation is ``(max - min) / max``, or 0 where ``max == 0``;
        - value is ``max(r, g, b)``.
    """
    if not image.is_floating_point():
        # uint8 subtraction wraps around and integer buffers would truncate
        # fractional hue values, so work in floating point.
        image = image.float()

    r = image[..., 0, :, :]
    g = image[..., 1, :, :]
    b = image[..., 2, :, :]

    maxc = torch.maximum(torch.maximum(r, g), b)
    minc = torch.minimum(torch.minimum(r, g), b)
    deltac = maxc - minc
    ones = torch.ones_like(maxc)
    zeros = torch.zeros_like(maxc)

    v = maxc

    # Divide only by non-zero denominators and pick results with torch.where.
    # Dividing by zero and masking afterwards yields correct forward values
    # but NaN gradients (0 * inf) in the backward pass.
    has_s = maxc > 0
    s = torch.where(has_s, deltac / torch.where(has_s, maxc, ones), zeros)

    has_hue = deltac > 0
    safe_delta = torch.where(has_hue, deltac, ones)
    h_r = torch.remainder((g - b) / safe_delta, 6)
    h_g = (b - r) / safe_delta + 2
    h_b = (r - g) / safe_delta + 4
    # On exact ties for the maximum the sector formulas agree, so the
    # precedence (blue over green over red) does not change the result.
    h = torch.where(maxc == b, h_b, torch.where(maxc == g, h_g, h_r))
    h = torch.where(has_hue, h / 6, zeros)

    return torch.stack([h, s, v], dim=-3)
|
|
|
|
|
def hue_loss(images, target_hue=0.5):
    """Mean circular distance between each pixel's hue and ``target_hue``.

    Args:
        images: RGB tensor accepted by ``rgb_to_hsv`` (color channels on the
            third-from-last dimension, e.g. ``(N, 3, H, W)``).
        target_hue: Desired hue in the normalized units produced by
            ``rgb_to_hsv`` (one full turn of the color wheel is 1.0).

    Returns:
        Scalar tensor: the mean over all pixels of the shortest distance
        around the hue circle, which lies in ``[0, 0.5]``.
    """
    hsv_images = rgb_to_hsv(images)
    hue = hsv_images[..., 0, :, :]

    # Hue wraps around: 0.99 and 0.01 are nearly the same color, so measure
    # the shorter way round the circle rather than the plain difference.
    diff = torch.remainder(hue - target_hue, 1.0)
    circular_dist = torch.minimum(diff, 1.0 - diff)

    # NOTE: rgb_to_hsv assigns hue 0 to achromatic pixels, so gray/black/white
    # pixels still contribute their distance from hue 0 to this loss.
    return circular_dist.mean()
|
|