# line-segment-matching / plot_utils.py
# Last change: commit 25a8011 ("update") by Johannes.
import copy
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
    """Plot a set of images side by side in a single new figure.

    Args:
        imgs: a list of images (NumPy arrays or array-likes accepted by
            ``imshow``), RGB (H, W, 3) or mono (H, W).
        titles: optional list of strings, one title per image.
        cmaps: a colormap name applied to every image, or a list/tuple with
            one colormap per image. Colormaps only affect monochrome images.
        dpi: resolution of the created figure, in dots per inch.
        size: width in inches of each image panel; the figure height is
            ``size * 3 / 4``. Pass None to use matplotlib's default figsize.
        pad: padding forwarded to ``Figure.tight_layout`` (a fraction of the
            font size).
    Returns:
        The created figure; ``fig.axes[i]`` shows ``imgs[i]``.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 3 / 4) if size is not None else None
    # squeeze=False always returns a (1, n) array of axes, so a single
    # image needs no special-case wrapping.
    fig, axes = plt.subplots(1, n, figsize=figsize, dpi=dpi, squeeze=False)
    for i, a in enumerate(axes[0]):
        a.imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        # Hide ticks, axis decorations and the frame so only the image shows.
        a.get_yaxis().set_ticks([])
        a.get_xaxis().set_ticks([])
        a.set_axis_off()
        for spine in a.spines.values():
            spine.set_visible(False)
        if titles:
            a.set_title(titles[i])
    fig.tight_layout(pad=pad)
    return fig
def plot_lines(
    lines, fig, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
):
    """Plot line segments and their endpoints on existing image axes.

    Args:
        lines: list of arrays (or nested lists) of shape (N, 2, 2), one per
            target axes. Each segment holds two endpoints stored as
            (row, col), i.e. (y, x) image coordinates.
        fig: figure whose axes receive the drawing (e.g. from plot_images).
        line_colors: one matplotlib color used for every array in ``lines``,
            or a list with one color per array.
        point_colors: a single value used for every array, or a list with one
            entry per array; each entry is passed to ``Axes.scatter`` as
            ``c`` (e.g. one color, or one color per endpoint).
        ps: endpoint marker size, passed to ``Axes.scatter`` as ``s``
            (marker area in points**2).
        lw: line width in points.
        indices: indices into ``fig.axes`` of the images to draw on; the k-th
            array of ``lines`` is drawn on ``fig.axes[indices[k]]``.
    Returns:
        The same ``fig``.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    # NOTE(review): forces a render before artists are added; the reason is
    # not evident from this file — confirm whether it is still needed.
    fig.canvas.draw()
    # Plot the lines and junctions.
    for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
        # Accept nested lists / array-likes as well as ndarrays.
        segs = np.asarray(l)
        for seg in segs:
            # Endpoints are (y, x); Line2D takes the x values, then the y values.
            a.add_line(
                matplotlib.lines.Line2D(
                    (seg[1, 1], seg[0, 1]),
                    (seg[1, 0], seg[0, 0]),
                    zorder=1,
                    c=lc,
                    linewidth=lw,
                )
            )
        pts = segs.reshape(-1, 2)
        a.scatter(pts[:, 1], pts[:, 0], c=pc, s=ps, linewidths=0, zorder=2)
    return fig
def plot_color_line_matches(lines, fig, lw=2, indices=(0, 1), seed=None):
    """Plot matched line segments on existing image axes, one color per match.

    Segment ``i`` of every array in ``lines`` is drawn with the same color,
    so segments sharing an index across images can be identified visually.

    Args:
        lines: list of arrays (or nested lists) of shape (N, 2, 2), one per
            target axes, with endpoints stored as (row, col), i.e. (y, x).
            Segment i of each array is assumed to match segment i of the
            others. The palette has ``len(lines[0])`` colors, so no array may
            be longer than the first (an IndexError is raised otherwise).
        fig: figure whose axes receive the drawing (e.g. from plot_images).
        lw: line width in points.
        indices: indices into ``fig.axes`` of the images to draw on; the k-th
            array of ``lines`` is drawn on ``fig.axes[indices[k]]``.
        seed: optional seed (int or ``np.random.Generator``) used to shuffle
            the palette reproducibly. None (default) shuffles with NumPy's
            global RNG, as before.
    Returns:
        The same ``fig``.
    """
    n_lines = len(lines[0])
    # Sample n_lines evenly spaced colors from the colormap as hex strings.
    cmap = plt.get_cmap("nipy_spectral", lut=n_lines)
    colors = np.array([mcolors.rgb2hex(cmap(i)) for i in range(cmap.N)])
    # Shuffle, presumably so consecutive matches don't get near-identical hues.
    if seed is None:
        np.random.shuffle(colors)
    else:
        np.random.default_rng(seed).shuffle(colors)
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    # NOTE(review): forces a render before artists are added; the reason is
    # not evident from this file — confirm whether it is still needed.
    fig.canvas.draw()
    # Plot the lines.
    for a, l in zip(axes, lines):
        for i, seg in enumerate(np.asarray(l)):
            # Endpoints are (y, x); Line2D takes the x values, then the y values.
            a.add_line(
                matplotlib.lines.Line2D(
                    (seg[1, 1], seg[0, 1]),
                    (seg[1, 0], seg[0, 0]),
                    zorder=1,
                    c=colors[i],
                    linewidth=lw,
                )
            )
    return fig