|
import matplotlib.pyplot as plt |
|
import numpy |
|
|
|
def plot_tensor_images(data, **kwargs): |
|
data = ((data + 1) / 2 * 255).permute(0, 2, 3, 1).byte().cpu().numpy() |
|
width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) |
|
height = int(numpy.ceil(data.shape[0] / float(width))) |
|
kwargs = dict(kwargs) |
|
margin = 0.01 |
|
if 'figsize' not in kwargs: |
|
|
|
dpi = plt.rcParams['figure.dpi'] |
|
kwargs['figsize'] = ( |
|
(1 + margin) * (width * data.shape[2] / dpi), |
|
(1 + margin) * (height * data.shape[1] / dpi)) |
|
f, axarr = plt.subplots(height, width, **kwargs) |
|
if len(numpy.shape(axarr)) == 0: |
|
axarr = numpy.array([[axarr]]) |
|
if len(numpy.shape(axarr)) == 1: |
|
axarr = axarr[None,:] |
|
for i, im in enumerate(data): |
|
ax = axarr[i // width, i % width] |
|
ax.imshow(data[i]) |
|
ax.axis('off') |
|
for i in range(i, width * height): |
|
ax = axarr[i // width, i % width] |
|
ax.axis('off') |
|
plt.subplots_adjust(wspace=margin, hspace=margin, |
|
left=0, right=1, bottom=0, top=1) |
|
plt.show() |
|
|
|
def plot_max_heatmap(data, shape=None, **kwargs): |
|
if shape is None: |
|
shape = data.shape[2:] |
|
data = data.max(1)[0].cpu().numpy() |
|
vmin = data.min() |
|
vmax = data.max() |
|
width = int(numpy.ceil(numpy.sqrt(data.shape[0]))) |
|
height = int(numpy.ceil(data.shape[0] / float(width))) |
|
kwargs = dict(kwargs) |
|
margin = 0.01 |
|
if 'figsize' not in kwargs: |
|
|
|
dpi = plt.rcParams['figure.dpi'] |
|
kwargs['figsize'] = ( |
|
width * shape[1] / dpi, height * shape[0] / dpi) |
|
f, axarr = plt.subplots(height, width, **kwargs) |
|
if len(numpy.shape(axarr)) == 0: |
|
axarr = numpy.array([[axarr]]) |
|
if len(numpy.shape(axarr)) == 1: |
|
axarr = axarr[None,:] |
|
for i, im in enumerate(data): |
|
ax = axarr[i // width, i % width] |
|
img = ax.imshow(data[i], vmin=vmin, vmax=vmax, cmap='hot') |
|
ax.axis('off') |
|
for i in range(i, width * height): |
|
ax = axarr[i // width, i % width] |
|
ax.axis('off') |
|
plt.subplots_adjust(wspace=margin, hspace=margin, |
|
left=0, right=1, bottom=0, top=1) |
|
plt.show() |
|
|