NamedCurves / models /color_naming.py
davidserra9's picture
First commit from github repo
117183e verified
"""
color_naming.py - Contains the Joost van de Weijer et al. (2009) color naming model.
David Serrano (dserrano@cvc.uab.cat)
May 2024
"""
import os
import pathlib
from scipy.io import loadmat
import torch
from torch import tensor as to_tensor
from torchvision.transforms.functional import pil_to_tensor
class ColorNaming():
    """Van de Weijer et al. (2009) color naming model, implemented as a table lookup.

    Every RGB pixel is quantized into 32 bins per channel (bin width 8 on the
    [0, 255] scale). The resulting bin index selects one row of the ``w2c``
    lookup table, whose columns hold the probabilities of the color names.

    Category order of the returned maps:
        * ``num_categories == 11``: black, blue, brown, gray, green, orange,
          pink, purple, red, white, yellow.
        * ``num_categories == 6``: orange-brown-yellow, achromatic, pink-purple,
          red, green, blue (each one is the sum of the listed 11-way maps).
    """

    # Column indices of the 11-way table that are summed to build each of the
    # 6 merged categories (same order as in the class docstring).
    _SIX_CATEGORY_GROUPS = ([2, 5, 10], [0, 3, 9], [6, 7], [8], [4], [1])
    _SUPPORTED_NUM_CATEGORIES = (6, 11)

    def __init__(self, matrix_path=os.path.join(str(pathlib.Path(__file__).parent.resolve()), "joost_color_naming.mat"),
                 num_categories=6,
                 device='cuda'):
        """ Van de Weijer et al. (2009) Color Naming model python implementation.

        Van De Weijer, J. et al. Learning color names for real-world applications. IEEE Transactions on Image Processing
        The class is based on the MATLAB implementation by Van de Weijer et al. (2009). The lookup table is read
        from the ``w2c`` variable of the .mat file at ``matrix_path``. The input RGB image is converted to a set
        (6 or 11) of color naming probability maps.

        Args:
            matrix_path: path to the .mat file containing the ``w2c`` lookup table.
            num_categories: number of output categories, either 6 or 11.
            device: torch device on which the lookup table (and the outputs) live.

        Raises:
            ValueError: if ``num_categories`` is neither 6 nor 11.
        """
        # Fail early: an unsupported value would otherwise only surface later,
        # as a silent ``None`` returned from ``__call__``.
        if num_categories not in self._SUPPORTED_NUM_CATEGORIES:
            raise ValueError(f"num_categories must be 6 or 11, got {num_categories!r}")

        self.matrix = to_tensor(loadmat(matrix_path)['w2c']).to(device)
        self.num_categories = num_categories
        self.device = device

        if num_categories == 6:
            self.color_categories = [torch.tensor(group).to(device) for group in self._SIX_CATEGORY_GROUPS]

    def __call__(self, input_tensor):
        """Converts a batch of RGB images to color naming probability maps.

        Args:
            input_tensor: batch of RGB images (B x 3 x H x W). Values are clamped to [0, 1] before
                being rescaled to [0, 255].

        Returns:
            torch.tensor: color naming maps of shape (num_categories x B x H x W), i.e. the category
                axis comes first. The result lives on the device of the lookup table.

        Raises:
            ValueError: if ``self.num_categories`` is neither 6 nor 11.
        """
        if self.num_categories not in self._SUPPORTED_NUM_CATEGORIES:
            raise ValueError(f"num_categories must be 6 or 11, got {self.num_categories!r}")

        # Reconvert image to the [0-255] integer range (values are truncated, not rounded).
        img = (torch.clamp(input_tensor, 0, 1) * 255).int()

        # Quantize each channel into 32 bins of width 8 and flatten the (R, G, B) bin triplet
        # into a single row index of the lookup table: R + 32 * G + 32 * 32 * B.
        bins = torch.div(img, 8, rounding_mode='floor').long()
        index_tensor = bins[:, 0] + 32 * bins[:, 1] + 32 * 32 * bins[:, 2]  # (B, H, W)

        # The index must live on the same device as the table it indexes.
        index_tensor = index_tensor.to(self.matrix.device)

        # One gather for all categories: (B, H, W, C) -> (C, B, H, W).
        prob_maps = self.matrix[index_tensor].permute(3, 0, 1, 2).contiguous()

        if self.num_categories == 11:
            return prob_maps

        # 6 categories: sum the 11-way maps belonging to each merged category.
        category_probs = [
            torch.index_select(prob_maps, 0, category).sum(dim=0)
            for category in self.color_categories
        ]
        return torch.stack(category_probs, dim=0)