|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
import matplotlib.pyplot as plt |
|
import matplotlib.patches as mpatches |
|
|
|
def features_to_RGB(*Fs, masks=None, skip=1):
    """Project a list of d-dimensional feature maps to RGB colors using PCA.

    All feature maps are flattened, L2-normalized per pixel and projected
    jointly onto their first three principal components, so the resulting
    colors are comparable across the returned images.

    Args:
        *Fs: feature maps, each of shape (C, H, W) and sharing the same C.
        masks: optional sequence with one entry per feature map. Each entry
            is either None (use every pixel) or an (H, W) array whose truthy
            entries select the pixels to project.
        skip: when > 1, the PCA is fitted on every `skip`-th selected pixel
            only; all selected pixels are still transformed.

    Returns:
        A list with one float array per feature map: (H, W, 3) with values in
        [0, 1] for unmasked maps, or (H, W, 4) for masked maps, where the
        last channel is the mask and unselected pixels are zero.

    Raises:
        ValueError: if no feature map is given, or if `masks` does not match
            the feature maps in number or in spatial shape.
    """
    eps = 1e-12

    def normalize(x):
        # Clamp the norm so all-zero vectors map to zero instead of NaN
        # (NaN rows would make the PCA fit fail).
        norm = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.maximum(norm, eps)

    if not Fs:
        raise ValueError("features_to_RGB requires at least one feature map")
    if masks is None:
        masks = [None] * len(Fs)
    elif len(masks) != len(Fs):
        raise ValueError(
            f"got {len(Fs)} feature maps but {len(masks)} masks"
        )

    # Flatten every map to (num_selected_pixels, C), keeping the boolean
    # masks so the projected colors can be scattered back afterwards.
    bool_masks = []
    flat_features = []
    for F, mask in zip(Fs, masks):
        F = np.asarray(F)
        c, h, w = F.shape
        F_flat = np.moveaxis(F, 0, -1).reshape(-1, c)
        if mask is not None:
            # Cast to bool so 0/1 integer masks select pixels rather than
            # being interpreted as row indices.
            mask = np.asarray(mask).astype(bool)
            if mask.shape != (h, w):
                raise ValueError(
                    f"mask shape {mask.shape} does not match feature map "
                    f"spatial shape {(h, w)}"
                )
            F_flat = F_flat[mask.reshape(-1)]
        bool_masks.append(mask)
        flat_features.append(F_flat)
    projected = normalize(np.concatenate(flat_features, axis=0))

    # Imported lazily so scikit-learn is only required when a projection
    # is actually computed.
    from sklearn.decomposition import PCA

    pca = PCA(n_components=3)
    if skip > 1:
        pca.fit(projected[::skip])
        projected = pca.transform(projected)
    else:
        projected = pca.fit_transform(projected)
    # Unit vectors have components in [-1, 1]; shift them into [0, 1].
    projected = (normalize(projected) + 1) / 2

    # Hand the rows back to their images in the order they were flattened.
    Fs_rgb = []
    offset = 0
    for F, mask in zip(Fs, bool_masks):
        h, w = F.shape[-2:]
        if mask is None:
            count = h * w
            F_rgb = projected[offset:offset + count].reshape((h, w, 3))
        else:
            count = int(mask.sum())
            F_rgb = np.zeros((h, w, 4))
            F_rgb[mask, :3] = projected[offset:offset + count]
            F_rgb[..., 3] = mask  # mask as the fourth (alpha) channel
        offset += count
        Fs_rgb.append(F_rgb)
    # Internal invariant: every projected row must have been consumed.
    assert offset == projected.shape[0], (offset, projected.shape)
    return Fs_rgb
|
|
|
|
|
def one_hot_argmax_to_rgb(y, num_class, threshold=0.25):
    """Colorize thresholded per-class score maps as an RGB segmentation image.

    Each pixel takes the color of the lowest-index class whose score exceeds
    `threshold`. Pixels where no class exceeds it are painted with the void
    color (white).

    Args:
        y: tensor of per-class scores with shape (B, C, H, W).
        num_class: number of classes to colorize (at most 6). Class indices
            map to colors in this order:
                0: road
                1: crossing
                2: explicit_pedestrian
                3: building
                4: terrain
                5: parking
            NOTE(review): the previous docstring listed the ids 0, 1, 2, 4,
            6, 7 for these classes; presumably those are label ids of the
            source dataset -- confirm.
        threshold: a class counts as present at a pixel when its score is
            strictly greater than this value.

    Returns:
        Tensor of shape (B, 3, H, W) in the default floating dtype, holding
        0-255 color values, on the same device as `y`.

    Raises:
        ValueError: if `num_class` exceeds the number of available colors.
    """
    class_colors = (
        (68, 68, 68),      # road
        (244, 162, 97),    # crossing
        (233, 196, 106),   # explicit_pedestrian
        (231, 111, 81),    # building
        (42, 157, 143),    # terrain
        (204, 204, 204),   # parking
    )
    void_color = 255.0  # predicted_void: white on all three channels

    if num_class > len(class_colors):
        raise ValueError(
            f"num_class={num_class} exceeds the {len(class_colors)} "
            "available class colors"
        )

    # argmax over a 0/1 tensor returns the first class above the threshold.
    labels = torch.argmax((y > threshold).float(), dim=1)
    # Pixels with no class above the threshold get the extra "void" label.
    labels[torch.all(y <= threshold, dim=1)] = num_class

    # One palette row per possible label. Rows start as void/white, so the
    # void label (index num_class) is always white regardless of num_class,
    # and any channel index beyond num_class stays white as well.
    num_rows = max(num_class + 1, y.shape[1])
    palette = torch.full(
        (num_rows, 3),
        void_color,
        dtype=torch.get_default_dtype(),
        device=y.device,
    )
    if num_class > 0:
        palette[:num_class] = torch.tensor(
            class_colors[:num_class],
            dtype=palette.dtype,
            device=palette.device,
        )

    # (B, H, W) labels -> (B, H, W, 3) colors -> channel-first (B, 3, H, W).
    seg_rgb = palette[labels].permute(0, 3, 1, 2).contiguous()
    return seg_rgb
|
|
|
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
    """Show several images side by side in a single-row matplotlib figure.

    After the images are drawn, a fixed six-entry class legend is added
    through `plt.legend` and the figure layout is tightened.

    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: a colormap name applied to every image, or a list/tuple with
            one colormap per image (used for monochrome images).
        dpi: resolution passed to `plt.subplots`.
        pad: padding passed to `fig.tight_layout`.
        adaptive: whether the figure size should fit the image aspect ratios;
            otherwise every panel is given a 4:3 ratio.
    """
    count = len(imgs)
    if isinstance(cmaps, (list, tuple)):
        per_image_cmaps = cmaps
    else:
        per_image_cmaps = [cmaps] * count

    # Relative panel widths: true aspect ratios, or a uniform 4:3.
    width_ratios = (
        [img.shape[1] / img.shape[0] for img in imgs]
        if adaptive
        else [4 / 3] * count
    )

    fig, ax = plt.subplots(
        1,
        count,
        figsize=[sum(width_ratios) * 4.5, 4.5],
        dpi=dpi,
        gridspec_kw={"width_ratios": width_ratios},
    )
    # With a single panel, subplots returns a bare Axes rather than an array.
    axes = [ax] if count == 1 else ax

    for idx in range(count):
        panel = axes[idx]
        panel.imshow(imgs[idx], cmap=plt.get_cmap(per_image_cmaps[idx]))
        panel.get_yaxis().set_ticks([])
        panel.get_xaxis().set_ticks([])
        panel.set_axis_off()
        for spine in panel.spines.values():
            spine.set_visible(False)
        if titles:
            panel.set_title(titles[idx])

    # Legend entries: label -> 0-255 RGB color, scaled to 0-1 for matplotlib.
    legend_entries = (
        ('Road', (68, 68, 68)),
        ('Crossing', (244, 162, 97)),
        ('Sidewalk', (233, 196, 106)),
        ('Building', (231, 111, 81)),
        ('Terrain', (42, 157, 143)),
        ('Parking', (204, 204, 204)),
    )
    handles = []
    for name, rgb in legend_entries:
        scaled = [channel / 255.0 for channel in rgb]
        handles.append(mpatches.Patch(color=scaled, label=name))
    plt.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=3)

    fig.tight_layout(pad=pad)