File size: 990 Bytes
e3ba844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch

# hue loss
def rgb_to_hsv(image):
    """Convert a batch of RGB images to HSV in a gradient-safe way.

    Args:
        image: Tensor indexed as (batch, channel, height, width). Channels
            0, 1 and 2 are read as R, G and B. The max/min are taken over
            the whole channel dimension, so exactly three channels are
            expected. Values are presumably in [0, 1] -- confirm with caller.

    Returns:
        Tensor of shape (batch, 3, height, width) holding hue, saturation
        and value stacked along dim 1. Hue is normalized by dividing the
        sextant position by 6, and is 0 for achromatic pixels (R == G == B).
    """
    r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
    maxc = torch.max(image, dim=1)[0]
    minc = torch.min(image, dim=1)[0]
    deltac = maxc - minc

    v = maxc
    # Small epsilon keeps saturation finite for pure-black pixels.
    s = deltac / (maxc + 1e-10)

    # Achromatic pixels have deltac == 0. Dividing by it yields NaN, and even
    # if the forward value is overwritten afterwards the NaN still leaks into
    # the backward pass. Divide by a harmless 1 there instead.
    achromatic = deltac == 0
    safe_delta = torch.where(achromatic, torch.ones_like(deltac), deltac)

    hue_from_r = ((g - b) / safe_delta) % 6
    hue_from_g = (b - r) / safe_delta + 2
    hue_from_b = (r - g) / safe_delta + 4

    # Apply the candidates in R, G, B order so that a later match overrides
    # an earlier one on ties (the tied formulas agree in value anyway).
    h = torch.zeros_like(maxc)
    h = torch.where(maxc == r, hue_from_r, h)
    h = torch.where(maxc == g, hue_from_g, h)
    h = torch.where(maxc == b, hue_from_b, h)

    h = h / 6  # Normalize to [0, 1]
    # If there is no color difference, hue is undefined; report it as 0.
    h = torch.where(achromatic, torch.zeros_like(h), h)

    return torch.stack([h, s, v], dim=1)


def hue_loss(images, target_hue=0.5):
    """Return the mean circular distance between image hue and a target hue.

    Args:
        images: RGB batch passed straight to ``rgb_to_hsv`` (indexed as
            batch, channel, height, width).
        target_hue: Desired hue on the normalized hue circle, where 0 and 1
            denote the same color. Anything that broadcasts against the hue
            map works (a Python float by default).

    Returns:
        Scalar tensor: the per-pixel hue distance averaged over all pixels.
        Each per-pixel distance lies in [0, 0.5].
    """
    # Convert the images to HSV color space and keep only the hue channel
    hue = rgb_to_hsv(images)[:, 0, :, :]

    # Hue is an angle, so measure distance around the circle: a hue of 0.98
    # is 0.02 away from a target of 0.0, not 0.98. For the default target of
    # 0.5 this equals the plain absolute difference.
    diff = torch.remainder(torch.abs(hue - target_hue), 1.0)
    error = torch.minimum(diff, 1.0 - diff).mean()

    return error