|
import numpy as np |
|
import cv2 |
|
from matplotlib import pyplot as plt |
|
|
|
|
|
def get_mpl_colormap(cmap_name):
    """Build an OpenCV-style colour lookup table from a Matplotlib colormap.

    The colormap named ``cmap_name`` is sampled at 256 evenly spaced points
    in [0, 1]. The resulting RGBA bytes are reordered to BGR (alpha dropped)
    and shaped as a 256x1 three-channel table, the layout OpenCV's
    ``applyColorMap`` accepts as a user-supplied colormap.

    Args:
        cmap_name: Name of a Matplotlib colormap (e.g. ``"Reds"``).

    Returns:
        ``np.ndarray`` of dtype uint8 and shape ``(256, 1, 3)``, BGR order.
    """
    n_levels = 256
    mappable = plt.cm.ScalarMappable(cmap=plt.get_cmap(cmap_name))
    samples = np.linspace(0, 1, n_levels)
    rgba_bytes = mappable.to_rgba(samples, bytes=True)

    # OpenCV works in BGR: take channels B, G, R (indices 2, 1, 0) and
    # discard alpha (index 3).
    bgr_bytes = rgba_bytes[:, [2, 1, 0]]
    return bgr_bytes.reshape(n_levels, 1, 3)
|
|
|
|
|
def show_cam_on_image(img, mask, neg_saliency=False):
    """Overlay a saliency mask on an image using OpenCV's JET colormap.

    Args:
        img: Image the heatmap is added to. Presumably a float array in
            [0, 1] with shape (H, W, 3) in BGR order, matching the 3-channel
            heatmap from ``cv2.applyColorMap`` -- TODO confirm with callers.
        mask: Saliency map, expected in [0, 1] with the same (H, W) as
            ``img``. Values outside [0, 1] are clipped before the 8-bit
            conversion.
        neg_saliency: Not used by this function; accepted so existing
            callers that pass it keep working.

    Returns:
        float32 array equal to ``heatmap + img`` divided by its maximum value.
    """
    # Clip before the uint8 cast: out-of-range floats would otherwise wrap
    # around (or hit an undefined float->uint8 conversion) and pick the
    # wrong colours. In-range masks produce exactly the same values as before.
    mask_u8 = np.uint8(255 * np.clip(mask, 0.0, 1.0))

    heatmap = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255

    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam
|
|
|
|
|
def show_overlapped_cam(img, neg_mask, pos_mask):
    """Overlay negative and positive saliency maps on an image.

    The negative mask is coloured with Matplotlib's "Blues" colormap and the
    positive mask with "Reds". Both heatmaps are added to ``img``, and the
    sum is scaled by its maximum.

    Args:
        img: Image the heatmaps are added to. Presumably a float array in
            [0, 1] with shape (H, W, 3) in BGR order -- TODO confirm.
        neg_mask: Negative-saliency map, expected in [0, 1], same (H, W) as
            ``img``. Values outside [0, 1] are clipped.
        pos_mask: Positive-saliency map, expected in [0, 1], same (H, W) as
            ``img``. Values outside [0, 1] are clipped.

    Returns:
        float32 array equal to ``neg_heatmap + pos_heatmap + img`` divided by
        its maximum value.
    """
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    # In-range masks produce exactly the same values as before.
    neg_u8 = np.uint8(255 * np.clip(neg_mask, 0.0, 1.0))
    pos_u8 = np.uint8(255 * np.clip(pos_mask, 0.0, 1.0))

    neg_heatmap = cv2.applyColorMap(neg_u8, get_mpl_colormap("Blues"))
    pos_heatmap = cv2.applyColorMap(pos_u8, get_mpl_colormap("Reds"))

    neg_heatmap = np.float32(neg_heatmap) / 255
    pos_heatmap = np.float32(pos_heatmap) / 255

    # NOTE(review): "Blues" and "Reds" appear to map 0 to a light colour, so
    # summing both heatmaps may brighten low-saliency regions and wash out
    # the result after max-normalisation. Kept as-is to preserve output.
    heatmap = neg_heatmap + pos_heatmap

    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return cam
|
|