# Header text captured from the web file viewer this file was copied from:
# vitaliykinakh's picture
# Initial
# 8d6cd57
# raw
# history blame
# 1.94 kB
import numpy as np
import torch
# Question groups of the Galaxy Zoo label vector: group number -> column
# positions of that group's answers. Positions are 1-based because column 0 is
# assumed to hold the id; group 0 is a placeholder for it and has no answers.
class_groups = {
    # group : indices (assuming 0th position is id)
    0: (),
    1: (1, 2, 3),
    2: (4, 5),
    3: (6, 7),
    4: (8, 9),
    5: (10, 11, 12, 13),
    6: (14, 15),
    7: (16, 17, 18),
    8: (19, 20, 21, 22, 23, 24, 25),
    9: (26, 27, 28),
    10: (29, 30, 31),
    11: (32, 33, 34, 35, 36, 37),
}
# The same groups as 0-based column indices for a label tensor without the id
# column. dtype=int is explicit so the empty placeholder group also yields an
# integer array; np.array(()) on its own infers float64, which is not usable as
# an index.
class_groups_indices = {
    g: np.array(ixs, dtype=int) - 1 for g, ixs in class_groups.items()
}
# Parent link of every follow-up question group:
#     child group -> (parent group, 0-based answer position inside the parent group)
# Groups 1 and 6 have no entry here; make_galaxy_labels_hierarchical only
# normalizes them and does not rescale them by a parent.
hierarchy = {
    child: (parent_group, parent_answer)
    for child, parent_group, parent_answer in (
        # child, parent group, parent answer
        (2, 1, 1),
        (3, 2, 1),
        (4, 2, 1),
        (5, 2, 1),
        (7, 1, 0),
        (8, 6, 0),
        (9, 2, 0),
        (10, 4, 0),
        (11, 4, 0),
    )
}
def make_galaxy_labels_hierarchical(labels: torch.Tensor) -> torch.Tensor:
    """Rescale groups of galaxy label probabilities to follow the Galaxy Zoo decision tree.

    More info here: https://www.kaggle.com/c/galaxy-zoo-the-galaxy-challenge/overview/the-galaxy-zoo-decision-tree

    Every non-empty group listed in ``class_groups_indices`` is first
    normalized so its values sum to 1 for each sample. Groups that appear in
    ``hierarchy`` are then multiplied by the already-rescaled value of their
    parent answer. As a result, the answers of a follow-up question sum to the
    weight of the answer that leads to it.

    Parameters
    ----------
    labels : torch.Tensor
        N x L floating-point tensor (N = batch size, L = number of labels).
        L is 37, or larger when an id column sits at position 0, in which case
        every label index is shifted by one. Values are expected to be
        non-negative scores. For a group whose values sum to 0, the divisor is
        replaced by 1e-4 to avoid NaNs.

    Returns
    -------
    hierarchical_labels : torch.Tensor
        ``labels`` itself, modified in place (N x L). Pass a clone if the
        original values are still needed.
    """
    # A leading id column shifts every label column by one position.
    shift = int(labels.shape[1] > 37)
    # Process groups in ascending order so each parent group is already final
    # before its children read it. Every parent number in ``hierarchy`` is
    # smaller than its child's number.
    for group in sorted(class_groups_indices):
        columns = class_groups_indices[group] + shift
        if columns.size == 0:
            continue  # placeholder group (the id slot) has no labels
        # normalize the group's probabilities to 1
        norm = torch.sum(labels[:, columns], dim=1, keepdim=True)
        # add a small number to prevent NaNs from dividing by zero; the
        # in-place update keeps ``norm`` in the autograd graph
        norm[norm == 0] += 1e-4
        labels[:, columns] /= norm
        # rescale follow-up questions by the probability of their parent answer
        if group in hierarchy:
            parent_group, parent_answer = hierarchy[group]
            # Advanced (array) indexing returns a copy, so the in-place update
            # of ``labels`` below does not alias the parent values it reads.
            parent_probs = labels[:, class_groups_indices[parent_group] + shift]
            labels[:, columns] *= parent_probs[:, parent_answer].unsqueeze(-1)
    return labels