File size: 2,931 Bytes
117183e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
color_naming.py - Contains the Joost van de Weijer et al. (2009) color naming model.

David Serrano (dserrano@cvc.uab.cat)
May 2024
"""

import os
import pathlib
from scipy.io import loadmat
import torch
from torch import tensor as to_tensor
from torchvision.transforms.functional import pil_to_tensor

class ColorNaming():
    """Van de Weijer et al. (2009) color naming model, applied as a lookup table.

    Van De Weijer, J. et al. Learning color names for real-world applications.
    IEEE Transactions on Image Processing.

    Each RGB channel is quantized into 32 bins of width 8 (on a 0-255 scale) and
    the resulting bin triplet selects one row of the ``w2c`` matrix, which holds
    the probabilities of the 11 basic color names for that RGB cell.

    The 11 color names are indexed as:
        0 black, 1 blue, 2 brown, 3 gray, 4 green, 5 orange,
        6 pink, 7 purple, 8 red, 9 white, 10 yellow.

    With ``num_categories=6`` those maps are summed into the groups:
        orange-brown-yellow, achromatic, pink-purple, red, green, blue.
    """

    def __init__(self, matrix_path=os.path.join(str(pathlib.Path(__file__).parent.resolve()), "joost_color_naming.mat"),
                 num_categories=6,
                 device='cuda'):
        """Load the color naming lookup matrix and configure the output grouping.

        Args:
            matrix_path: path to a MATLAB ``.mat`` file containing the variable
                ``w2c``. It is expected to be a 2-D array with one row per
                quantized RGB cell (32 * 32 * 32 rows) and one column per
                color name. Defaults to ``joost_color_naming.mat`` located
                next to this module.
            num_categories: 6 (grouped categories) or 11 (original color names).
            device: device on which the lookup matrix and the group indices
                are stored.

        Raises:
            ValueError: if ``num_categories`` is neither 6 nor 11.
        """
        # Fail fast (before any file I/O): any other value would otherwise make
        # __call__ fall through every branch and silently return None.
        if num_categories not in (6, 11):
            raise ValueError(
                f"num_categories must be 6 or 11, got {num_categories!r}")

        self.matrix = to_tensor(loadmat(matrix_path)['w2c']).to(device)
        self.num_categories = num_categories
        self.device = device

        if num_categories == 6:
            # Indices of the 11 color names that are merged into each of the
            # 6 categories (same order as listed in the class docstring).
            groups = [[2, 5, 10], [0, 3, 9], [6, 7], [8], [4], [1]]
            self.color_categories = [torch.tensor(group).to(device) for group in groups]

    def __call__(self, input_tensor):
        """Convert a batch of RGB images into color naming probability maps.

        Args:
            input_tensor: batch of RGB images (B x 3 x H x W). Values are
                clamped to [0, 1] before being rescaled to [0, 255]. The tensor
                does not need to be contiguous.

        Returns:
            torch.Tensor: contiguous tensor of shape (C x B x H x W), where C is
            ``num_categories`` (category axis first). Its dtype is the dtype of
            the lookup matrix.

        Raises:
            ValueError: if ``self.num_categories`` is neither 6 nor 11.
        """
        # Reconvert image to the integer [0, 255] range.
        img = (torch.clamp(input_tensor, 0, 1) * 255).int()

        # 32 bins of width 8 per channel; integer floor division avoids the
        # float round-trip of floor(img / 8).
        bins = torch.div(img, 8, rounding_mode='floor').long()

        # Row of the lookup matrix for every pixel: R + 32 * G + 32 * 32 * B.
        # Computed on the (B, H, W) grid directly, so no flattening `.view()` is
        # required and strided (e.g. transposed) inputs are handled as well.
        index = bins[:, 0] + 32 * bins[:, 1] + 32 * 32 * bins[:, 2]

        # One gather for all color names: (B, H, W, C) -> (C, B, H, W).
        prob_maps = self.matrix[index].permute(3, 0, 1, 2).contiguous()

        if self.num_categories == 11:
            return prob_maps

        if self.num_categories == 6:
            # Probability of a category = sum of the maps of its color names.
            return torch.stack(
                [torch.index_select(prob_maps, 0, group).sum(dim=0)
                 for group in self.color_categories],
                dim=0)

        # Only reachable if num_categories was modified after construction.
        raise ValueError(
            f"num_categories must be 6 or 11, got {self.num_categories!r}")