# python3.7
"""Contains the visualizer to visualize images by composing them as a grid."""

from ..image_utils import get_blank_image
from ..image_utils import get_grid_shape
from ..image_utils import parse_image_size
from ..image_utils import load_image
from ..image_utils import save_image
from ..image_utils import resize_image
from ..image_utils import list_images_from_dir

__all__ = ['GridVisualizer']


class GridVisualizer(object):
    """Defines the visualizer that visualizes images as a grid.

    Basically, given a collection of images, this visualizer stitches them one
    by one. Notably, this class also supports adding spaces between images,
    adding borders around images, and using white/black background.

    Example:

    grid = GridVisualizer(num_rows=num_rows, num_cols=num_cols)
    for i in range(num_rows):
        for j in range(num_cols):
            grid.add(i, j, image)
    grid.save('visualize.jpg')
    """

    def __init__(self,
                 grid_size=0,
                 num_rows=0,
                 num_cols=0,
                 is_portrait=False,
                 image_size=None,
                 image_channels=0,
                 row_spacing=0,
                 col_spacing=0,
                 border_left=0,
                 border_right=0,
                 border_top=0,
                 border_bottom=0,
                 use_black_background=True):
        """Initializes the grid visualizer.

        Args:
            grid_size: Total number of cells, i.e., height * width.
                (default: 0)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Whether the grid should be portrait or landscape.
                This is only used when it requires to compute `num_rows` and
                `num_cols` automatically. See function `get_grid_shape()` in
                the `image_utils` module for details. (default: False)
            image_size: Size to visualize each image. If left unset, the size
                of the first added image is used (see `add()`).
                (default: None)
            image_channels: Number of image channels. If 0, the channel count
                of the first added image is used (see `add()`). (default: 0)
            row_spacing: Spacing between rows. (default: 0)
            col_spacing: Spacing between columns. (default: 0)
            border_left: Width of left border. (default: 0)
            border_right: Width of right border. (default: 0)
            border_top: Width of top border. (default: 0)
            border_bottom: Width of bottom border. (default: 0)
            use_black_background: Whether to use black background.
                (default: True)
        """
        self.reset(grid_size, num_rows, num_cols, is_portrait)
        self.set_image_size(image_size)
        self.set_image_channels(image_channels)
        self.set_row_spacing(row_spacing)
        self.set_col_spacing(col_spacing)
        self.set_border_left(border_left)
        self.set_border_right(border_right)
        self.set_border_top(border_top)
        self.set_border_bottom(border_bottom)
        self.set_background(use_black_background)
        # The canvas is allocated lazily, either by `init_grid()` or by the
        # first call to `add()`.
        self.grid = None

    def reset(self,
              grid_size=0,
              num_rows=0,
              num_cols=0,
              is_portrait=False):
        """Resets the grid shape, i.e., number of rows/columns.

        Any previously allocated canvas is discarded.

        Args:
            grid_size: Total number of cells. When positive, the final shape
                is computed by `get_grid_shape()`, with `num_rows` and
                `num_cols` forwarded as `height` and `width`. (default: 0)
            num_rows: Number of rows. (default: 0)
            num_cols: Number of columns. (default: 0)
            is_portrait: Forwarded to `get_grid_shape()`. (default: False)
        """
        if grid_size > 0:
            num_rows, num_cols = get_grid_shape(grid_size,
                                                height=num_rows,
                                                width=num_cols,
                                                is_portrait=is_portrait)
        # `grid_size` is always re-derived from the final shape, so it may
        # differ from the requested value.
        self.grid_size = num_rows * num_cols
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.grid = None

    def set_image_size(self, image_size=None):
        """Sets the image size of each cell in the grid.

        Args:
            image_size: Size specification, parsed into `(height, width)` by
                `parse_image_size()`. (default: None)
        """
        height, width = parse_image_size(image_size)
        self.image_height = height
        self.image_width = width

    def set_image_channels(self, image_channels=0):
        """Sets the number of channels of the grid."""
        self.image_channels = image_channels

    def set_row_spacing(self, row_spacing=0):
        """Sets the spacing between grid rows."""
        self.row_spacing = row_spacing

    def set_col_spacing(self, col_spacing=0):
        """Sets the spacing between grid columns."""
        self.col_spacing = col_spacing

    def set_border_left(self, border_left=0):
        """Sets the width of the left border of the grid."""
        self.border_left = border_left

    def set_border_right(self, border_right=0):
        """Sets the width of the right border of the grid."""
        self.border_right = border_right

    def set_border_top(self, border_top=0):
        """Sets the width of the top border of the grid."""
        self.border_top = border_top

    def set_border_bottom(self, border_bottom=0):
        """Sets the width of the bottom border of the grid."""
        self.border_bottom = border_bottom

    def set_background(self, use_black=True):
        """Sets the grid background."""
        self.use_black_background = use_black

    def init_grid(self):
        """Initializes the grid with a blank image.

        The canvas covers all cells plus the spacing between them and the
        four borders.

        Raises:
            AssertionError: If the grid shape, the image size, or the number
                of channels has not been set to a positive value.
        """
        assert self.num_rows > 0, 'Number of rows should be positive!'
        assert self.num_cols > 0, 'Number of columns should be positive!'
        assert self.image_height > 0, 'Image height should be positive!'
        assert self.image_width > 0, 'Image width should be positive!'
        assert self.image_channels > 0, 'Image channels should be positive!'
        # Spacing only appears between neighbouring cells, hence `n - 1`.
        grid_height = (self.image_height * self.num_rows +
                       self.row_spacing * (self.num_rows - 1) +
                       self.border_top + self.border_bottom)
        grid_width = (self.image_width * self.num_cols +
                      self.col_spacing * (self.num_cols - 1) +
                      self.border_left + self.border_right)
        self.grid = get_blank_image(grid_height,
                                    grid_width,
                                    channels=self.image_channels,
                                    use_black=self.use_black_background)

    def add(self, i, j, image):
        """Adds an image into the grid.

        If the canvas has not been allocated yet, it is created here, with
        any unset cell size / channel count taken from `image`.

        NOTE: The input image is assumed to be with `RGB` channel order.

        Args:
            i: Row index of the target cell.
            j: Column index of the target cell.
            image: Image array with shape `(H, W)` or `(H, W, C)`. It is
                resized if its spatial size differs from the cell size. The
                image is written into the first `C` channels of the cell
                (`C = 1` for a 2-D image).
        """
        # Channel count of the incoming image. This is kept separate from the
        # channel count of the grid, such that the first image is pasted in
        # exactly the same way as all following ones.
        channels = 1 if image.ndim == 2 else image.shape[2]
        if self.grid is None:
            height, width = image.shape[0:2]
            # Explicitly configured values take priority over the image's.
            self.set_image_size((self.image_height or height,
                                 self.image_width or width))
            self.set_image_channels(self.image_channels or channels)
            self.init_grid()
        if image.shape[0:2] != (self.image_height, self.image_width):
            image = resize_image(image, (self.image_width, self.image_height))
        # A 2-D array cannot be broadcast into an `(H, W, 1)` slice, so the
        # channel axis is made explicit. This is done after resizing since
        # the resizing may drop a trailing singleton axis (presumably
        # OpenCV-like behavior, not verified here).
        if image.ndim == 2:
            image = image[:, :, None]
        y = self.border_top + i * (self.image_height + self.row_spacing)
        x = self.border_left + j * (self.image_width + self.col_spacing)
        self.grid[y:y + self.image_height,
                  x:x + self.image_width,
                  :channels] = image

    def _get_cell_index(self, idx, is_row_major=True):
        """Converts a flat index to a `(row_idx, col_idx)` grid position."""
        if is_row_major:
            return divmod(idx, self.num_cols)
        col_idx, row_idx = divmod(idx, self.num_rows)
        return row_idx, col_idx

    def visualize_collection(self,
                             images,
                             save_path=None,
                             num_rows=0,
                             num_cols=0,
                             is_portrait=False,
                             is_row_major=True):
        """Visualizes a collection of images one by one.

        Args:
            images: A sized collection of image arrays (see `add()`).
            save_path: If not empty, the grid is saved to this path.
                (default: None)
            num_rows: Number of rows, passed to `reset()`. (default: 0)
            num_cols: Number of columns, passed to `reset()`. (default: 0)
            is_portrait: Passed to `reset()`. (default: False)
            is_row_major: Whether to fill the grid row by row (otherwise,
                column by column). (default: True)
        """
        # `reset()` also discards the canvas from any previous visualization.
        self.reset(grid_size=len(images),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        for idx, image in enumerate(images):
            row_idx, col_idx = self._get_cell_index(idx, is_row_major)
            self.add(row_idx, col_idx, image)
        if save_path:
            self.save(save_path)

    def visualize_list(self,
                       image_list,
                       save_path=None,
                       num_rows=0,
                       num_cols=0,
                       is_portrait=False,
                       is_row_major=True):
        """Visualizes a list of image files.

        Args:
            image_list: A list of image file names, each of which is loaded
                with `load_image()`.
            save_path: If not empty, the grid is saved to this path.
                (default: None)
            num_rows: Number of rows, passed to `reset()`. (default: 0)
            num_cols: Number of columns, passed to `reset()`. (default: 0)
            is_portrait: Passed to `reset()`. (default: False)
            is_row_major: Whether to fill the grid row by row (otherwise,
                column by column). (default: True)
        """
        # `reset()` also discards the canvas from any previous visualization.
        self.reset(grid_size=len(image_list),
                   num_rows=num_rows,
                   num_cols=num_cols,
                   is_portrait=is_portrait)
        # Files are loaded one at a time, so that only a single source image
        # is held besides the canvas.
        for idx, filename in enumerate(image_list):
            image = load_image(filename)
            row_idx, col_idx = self._get_cell_index(idx, is_row_major)
            self.add(row_idx, col_idx, image)
        if save_path:
            self.save(save_path)

    def visualize_directory(self,
                            directory,
                            save_path=None,
                            num_rows=0,
                            num_cols=0,
                            is_portrait=False,
                            is_row_major=True):
        """Visualizes all images under a directory.

        The files are collected with `list_images_from_dir()` and then
        handled by `visualize_list()`, which takes the same remaining
        arguments.
        """
        image_list = list_images_from_dir(directory)
        self.visualize_list(image_list=image_list,
                            save_path=save_path,
                            num_rows=num_rows,
                            num_cols=num_cols,
                            is_portrait=is_portrait,
                            is_row_major=is_row_major)

    def save(self, path):
        """Saves the grid to `path` with `save_image()`."""
        save_image(path, self.grid)