Spaces:
Running
on
Zero
Running
on
Zero
""" | |
color_naming.py - Contains the Joost van de Weijer et al. (2009) color naming model.
David Serrano (dserrano@cvc.uab.cat)
May 2024
""" | |
import os | |
import pathlib | |
from scipy.io import loadmat | |
import torch | |
from torch import tensor as to_tensor | |
from torchvision.transforms.functional import pil_to_tensor | |
class ColorNaming:
    """Van de Weijer et al. (2009) Color Naming model, PyTorch implementation.

    Van De Weijer, J. et al. Learning color names for real-world applications.
    IEEE Transactions on Image Processing.

    The class is based on the MATLAB implementation by Van de Weijer et al.
    (2009). It loads a lookup table from a ``.mat`` file (the variable named
    ``'w2c'`` is read) and uses it to convert an RGB image into a set of
    (6 or 11) color naming probability maps.

    If num_categories is 6: orange-brown-yellow, achromatic, pink-purple, red,
    green, blue.
    If num_categories is 11: black, blue, brown, gray, green, orange, pink,
    purple, red, white, yellow.
    """

    # Each 8-bit channel is quantized into 32 bins that are 8 levels wide.
    _BIN_WIDTH = 8
    _BINS_PER_CHANNEL = 32
    # Indices into the 11 base categories that are summed to form each of the
    # 6 merged categories (same order as listed in the class docstring).
    _SIX_CATEGORY_GROUPS = ([2, 5, 10], [0, 3, 9], [6, 7], [8], [4], [1])
    _SUPPORTED_NUM_CATEGORIES = (6, 11)

    def __init__(self, matrix_path=os.path.join(str(pathlib.Path(__file__).parent.resolve()), "joost_color_naming.mat"),
                 num_categories=6,
                 device='cuda'):
        """Load the lookup table and prepare the category grouping.

        Args:
            matrix_path: path to the ``.mat`` file holding the ``'w2c'`` table.
                NOTE(review): the indexing in ``__call__`` implies a table of
                shape (32768, 11); this is not checked here.
            num_categories: 6 or 11 (see class docstring).
            device: torch device on which the table and group indices live.

        Raises:
            ValueError: if ``num_categories`` is neither 6 nor 11.
        """
        # Fail early (before touching the file) instead of letting __call__
        # silently return None for an unsupported setting.
        if num_categories not in self._SUPPORTED_NUM_CATEGORIES:
            raise ValueError(
                f"num_categories must be 6 or 11, got {num_categories!r}")

        self.matrix = torch.tensor(loadmat(matrix_path)['w2c']).to(device)
        self.num_categories = num_categories
        self.device = device
        if num_categories == 6:
            self.color_categories = [torch.tensor(group).to(device)
                                     for group in self._SIX_CATEGORY_GROUPS]

    def __call__(self, input_tensor):
        """Converts an RGB image to a color naming image.

        Args:
            input_tensor: batch of RGB images (B x 3 x H x W). Values are
                clamped to [0, 1] before being rescaled to [0, 255].

        Returns:
            torch.tensor: Color naming probability maps of shape
            (num_categories x B x H x W), with the dtype of the lookup table.

        Raises:
            ValueError: if ``self.num_categories`` is neither 6 nor 11.
        """
        batch = input_tensor.size(0)
        height = input_tensor.size(2)
        width = input_tensor.size(3)

        # Reconvert image to [0-255] range (fractional part is truncated).
        img = (torch.clamp(input_tensor, 0, 1) * 255).int()

        # Quantize every channel into bins. reshape (not view) so that
        # non-contiguous inputs, e.g. transposed batches, are accepted too.
        def channel_bins(channel):
            flat = img[:, channel, ...].reshape(batch, -1)
            return torch.floor(flat / self._BIN_WIDTH).long()

        # Flat row index into the table: R + 32 * G + 32 * 32 * B.
        n_bins = self._BINS_PER_CHANNEL
        index_tensor = (channel_bins(0)
                        + n_bins * channel_bins(1)
                        + n_bins * n_bins * channel_bins(2))

        # Single gather over all base categories at once:
        # (categories, B, H*W) -> (categories, B, H, W).
        prob_maps = self.matrix.t()[:, index_tensor].reshape(
            -1, batch, height, width)

        if self.num_categories == 11:
            return prob_maps
        if self.num_categories == 6:
            # Sum the base-category maps belonging to each merged category.
            category_probs = [
                torch.index_select(prob_maps, 0, category).sum(dim=0)
                for category in self.color_categories
            ]
            return torch.stack(category_probs, dim=0)
        raise ValueError(
            f"num_categories must be 6 or 11, got {self.num_categories!r}")