# BerfScene/utils/visualizers/grid_visualizer.py
# Author: 3v324v23 (commit 2f85de4: "init")
# python3.7
"""Contains the visualizer to visualize images by composing them as a grid."""
from ..image_utils import get_blank_image
from ..image_utils import get_grid_shape
from ..image_utils import parse_image_size
from ..image_utils import load_image
from ..image_utils import save_image
from ..image_utils import resize_image
from ..image_utils import list_images_from_dir
# Public names exported by `from <this module> import *`.
__all__ = [
    "GridVisualizer",
]
class GridVisualizer(object):
    """Defines the visualizer that visualizes images as a grid.

    Basically, given a collection of images, this visualizer stitches them one
    by one. Notably, this class also supports adding spaces between images,
    adding borders around images, and using white/black background.

    Example:

    grid = GridVisualizer(num_rows=num_rows, num_cols=num_cols)
    for i in range(num_rows):
        for j in range(num_cols):
            grid.add(i, j, image)
    grid.save('visualize.jpg')
    """

    def __init__(self,
                 grid_size=0,
                 num_rows=0,
                 num_cols=0,
                 is_portrait=False,
                 image_size=None,
                 image_channels=0,
                 row_spacing=0,
                 col_spacing=0,
                 border_left=0,
                 border_right=0,
                 border_top=0,
                 border_bottom=0,
                 use_black_background=True):
        """Initializes the grid visualizer.

        Args:
            grid_size: Total number of cells, i.e., height * width. If positive,
                `num_rows` and `num_cols` are derived from it through
                `get_grid_shape()`. (default: 0)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Whether the grid should be portrait or landscape.
                This is only used when it requires to compute `num_rows` and
                `num_cols` automatically. See function `get_grid_shape()` in
                file `./image_utils.py` for details. (default: False)
            image_size: Size to visualize each image; parsed with
                `parse_image_size()`. If the parsed height/width is 0 (falsy),
                the size of the first added image is used. (default: None)
            image_channels: Number of image channels. If 0, the channel count
                of the first added image is used. (default: 0)
            row_spacing: Spacing between rows. (default: 0)
            col_spacing: Spacing between columns. (default: 0)
            border_left: Width of left border. (default: 0)
            border_right: Width of right border. (default: 0)
            border_top: Width of top border. (default: 0)
            border_bottom: Width of bottom border. (default: 0)
            use_black_background: Whether to use black background.
                (default: True)
        """
        # `reset()` also sets `self.grid = None`; the grid is created lazily
        # by the first call to `add()`.
        self.reset(grid_size, num_rows, num_cols, is_portrait)
        self.set_image_size(image_size)
        self.set_image_channels(image_channels)
        self.set_row_spacing(row_spacing)
        self.set_col_spacing(col_spacing)
        self.set_border_left(border_left)
        self.set_border_right(border_right)
        self.set_border_top(border_top)
        self.set_border_bottom(border_bottom)
        self.set_background(use_black_background)

    def reset(self,
              grid_size=0,
              num_rows=0,
              num_cols=0,
              is_portrait=False):
        """Resets the grid shape, i.e., number of rows/columns.

        When `grid_size` is positive, the shape is computed by
        `get_grid_shape()` using `num_rows`/`num_cols` as hints. The existing
        grid canvas is discarded.
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(grid_size,
                                                height=num_rows,
                                                width=num_cols,
                                                is_portrait=is_portrait)
        self.grid_size = num_rows * num_cols
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.grid = None

    def set_image_size(self, image_size=None):
        """Sets the image size of each cell in the grid."""
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_image_channels(self, image_channels=0):
        """Sets the number of channels of the grid."""
        self.image_channels = image_channels

    def set_row_spacing(self, row_spacing=0):
        """Sets the spacing between grid rows."""
        self.row_spacing = row_spacing

    def set_col_spacing(self, col_spacing=0):
        """Sets the spacing between grid columns."""
        self.col_spacing = col_spacing

    def set_border_left(self, border_left=0):
        """Sets the width of the left border of the grid."""
        self.border_left = border_left

    def set_border_right(self, border_right=0):
        """Sets the width of the right border of the grid."""
        self.border_right = border_right

    def set_border_top(self, border_top=0):
        """Sets the width of the top border of the grid."""
        self.border_top = border_top

    def set_border_bottom(self, border_bottom=0):
        """Sets the width of the bottom border of the grid."""
        self.border_bottom = border_bottom

    def set_background(self, use_black=True):
        """Sets the grid background."""
        self.use_black_background = use_black

    def init_grid(self):
        """Initializes the grid with a blank image.

        The canvas holds `num_rows x num_cols` cells plus the configured
        spacing between cells and the borders around them.
        """
        assert self.num_rows > 0, 'Number of rows must be positive!'
        assert self.num_cols > 0, 'Number of columns must be positive!'
        assert self.image_height > 0, 'Image height must be positive!'
        assert self.image_width > 0, 'Image width must be positive!'
        assert self.image_channels > 0, 'Image channels must be positive!'
        grid_height = (self.image_height * self.num_rows +
                       self.row_spacing * (self.num_rows - 1) +
                       self.border_top + self.border_bottom)
        grid_width = (self.image_width * self.num_cols +
                      self.col_spacing * (self.num_cols - 1) +
                      self.border_left + self.border_right)
        self.grid = get_blank_image(grid_height, grid_width,
                                    channels=self.image_channels,
                                    use_black=self.use_black_background)

    def add(self, i, j, image):
        """Adds an image into the grid cell at row `i`, column `j`.

        On the first call (grid not yet created), any unset cell size or
        channel count is taken from `image`, and the blank grid is created.
        Images whose spatial size differs from the cell size are resized.
        A 2D image, or one with a single channel, is broadcast over every
        grid channel. Otherwise, the image fills the leading grid channels.

        NOTE: The input image is assumed to be with `RGB` channel order.

        Args:
            i: Row index of the cell.
            j: Column index of the cell.
            image: Image array of shape (height, width) or
                (height, width, channels).
        """
        if self.grid is None:
            height, width = image.shape[0:2]
            first_channels = 1 if image.ndim == 2 else image.shape[2]
            self.set_image_size((self.image_height or height,
                                 self.image_width or width))
            self.set_image_channels(self.image_channels or first_channels)
            self.init_grid()

        if image.shape[0:2] != (self.image_height, self.image_width):
            image = resize_image(image, (self.image_width, self.image_height))
        # Give 2D images an explicit channel axis so that the slice assignment
        # below lines up with the 3D grid. This is done after resizing because
        # a resize helper may drop a singleton channel axis.
        if image.ndim == 2:
            image = image[:, :, None]
        # Single-channel images are broadcast over all grid channels, so
        # grayscale renders as gray instead of tinting only the first channel.
        if image.shape[2] == 1:
            channels = self.grid.shape[2]
        else:
            channels = image.shape[2]

        y = self.border_top + i * (self.image_height + self.row_spacing)
        x = self.border_left + j * (self.image_width + self.col_spacing)
        self.grid[y:y + self.image_height,
                  x:x + self.image_width,
                  :channels] = image

    def _cell_index(self, idx, is_row_major=True):
        """Maps a flat image index to the `(row, col)` cell it occupies."""
        if is_row_major:
            return divmod(idx, self.num_cols)
        col_idx, row_idx = divmod(idx, self.num_rows)
        return row_idx, col_idx

    def visualize_collection(self,
                             images,
                             save_path=None,
                             num_rows=0,
                             num_cols=0,
                             is_portrait=False,
                             is_row_major=True):
        """Visualizes a collection of images one by one.

        The grid is reset to hold `len(images)` cells and the images are
        placed in row-major (default) or column-major order. The grid is
        saved when `save_path` is given.
        """
        self.reset(grid_size=len(images),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        for idx, image in enumerate(images):
            row_idx, col_idx = self._cell_index(idx, is_row_major)
            self.add(row_idx, col_idx, image)
        if save_path:
            self.save(save_path)

    def visualize_list(self,
                       image_list,
                       save_path=None,
                       num_rows=0,
                       num_cols=0,
                       is_portrait=False,
                       is_row_major=True):
        """Visualizes a list of image files.

        Each file is loaded with `load_image()` right before it is placed, so
        only one source image is held in memory at a time.
        """
        self.reset(grid_size=len(image_list),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        for idx, filename in enumerate(image_list):
            image = load_image(filename)
            row_idx, col_idx = self._cell_index(idx, is_row_major)
            self.add(row_idx, col_idx, image)
        if save_path:
            self.save(save_path)

    def visualize_directory(self,
                            directory,
                            save_path=None,
                            num_rows=0,
                            num_cols=0,
                            is_portrait=False,
                            is_row_major=True):
        """Visualizes all images under a directory."""
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list=image_list,
                            save_path=save_path,
                            num_rows=num_rows,
                            num_cols=num_cols,
                            is_portrait=is_portrait,
                            is_row_major=is_row_major)

    def save(self, path):
        """Saves the grid to `path` via `save_image()`."""
        save_image(path, self.grid)