sayanbanerjee32's picture
Upload folder using huggingface_hub
e3ba844 verified
raw
history blame contribute delete
990 Bytes
import torch
# hue loss
def rgb_to_hsv(image):
    """Convert a batch of RGB images to HSV, differentiably.

    Args:
        image: Tensor indexed as ``image[:, c, :, :]`` with the R, G, B
            channels at ``c = 0, 1, 2`` (``(N, C, H, W)`` layout). The
            per-pixel max/min are taken over all of dim 1, so exactly three
            channels are expected. Presumably values lie in ``[0, 1]`` --
            TODO confirm with callers.

    Returns:
        Tensor of shape ``(N, 3, H, W)`` holding ``(h, s, v)`` along dim 1:
        ``h`` is the hue scaled to ``[0, 1)`` (0 for achromatic pixels),
        ``s`` is ``(max - min) / max`` (with a tiny epsilon) and ``v`` is the
        per-pixel channel maximum.

    Unlike dividing directly by the chroma, this never divides by zero, so
    gray pixels (max == min) yield finite gradients instead of NaN.
    """
    r, g, b = image[:, 0, :, :], image[:, 1, :, :], image[:, 2, :, :]
    maxc = torch.max(image, dim=1)[0]
    minc = torch.min(image, dim=1)[0]
    deltac = maxc - minc  # chroma
    v = maxc
    s = deltac / (maxc + 1e-10)  # epsilon avoids 0/0 for black pixels

    # Hue is undefined where the chroma is zero. Divide by 1 there instead of
    # 0: masking a NaN result afterwards fixes the forward value but not the
    # backward pass (the division's gradient would still be 0/0 = NaN).
    gray = deltac == 0
    safe_delta = torch.where(gray, torch.ones_like(deltac), deltac)

    # Hue in sectors of 60 degrees for each possible maximal channel.
    h_r = ((g - b) / safe_delta) % 6
    h_g = (b - r) / safe_delta + 2
    h_b = (r - g) / safe_delta + 4

    # Later selections take precedence (b over g over r). On two-way ties the
    # candidate formulas give the same value, so precedence is harmless.
    h = torch.zeros_like(maxc)
    h = torch.where(maxc == r, h_r, h)
    h = torch.where(maxc == g, h_g, h)
    h = torch.where(maxc == b, h_b, h)
    h = h / 6  # sectors [0, 6) -> [0, 1)
    h = torch.where(gray, torch.zeros_like(h), h)  # achromatic -> hue 0
    return torch.stack([h, s, v], dim=1)
def hue_loss(images, target_hue=0.5):
    """Return the mean circular distance between pixel hues and ``target_hue``.

    Args:
        images: RGB batch in the layout accepted by ``rgb_to_hsv`` (channels
            on dim 1).
        target_hue: Desired hue on the ``[0, 1)`` scale produced by
            ``rgb_to_hsv``.

    Returns:
        Scalar tensor: the mean, over all pixels, of the wrap-around hue
        distance, which lies in ``[0, 0.5]``.

    Hue is an angle, so 0 and 1 denote the same colour. The distance
    therefore takes the shorter way around the hue circle. For the default
    target of 0.5 this equals the plain absolute difference.
    """
    # Convert to HSV and keep only the hue channel.
    hue = rgb_to_hsv(images)[:, 0, :, :]
    # Wrap the raw difference onto [0, 1), then take the shorter arc.
    diff = torch.abs(hue - target_hue) % 1.0
    error = torch.minimum(diff, 1.0 - diff)
    # NOTE(review): achromatic pixels get hue 0 from rgb_to_hsv and still
    # contribute to the loss; confirm that is intended.
    return error.mean()