|
|
|
"""Contains utility functions for image processing. |
|
|
|
The module is primarily built on `cv2`. Unlike `cv2`, which uses `BGR` order,
we assume all color images use `RGB` channel order by default. Also, we assume
all gray-scale images have shape [height, width, 1].
|
""" |
|
|
|
import os |
|
import cv2 |
|
import numpy as np |
|
|
|
from .misc import IMAGE_EXTENSIONS |
|
from .misc import check_file_ext |
|
|
|
# Public API of this module (what `from ... import *` exposes).
__all__ = [
    'get_blank_image',
    'load_image',
    'save_image',
    'resize_image',
    'add_text_to_image',
    'preprocess_image',
    'postprocess_image',
    'parse_image_size',
    'get_grid_shape',
    'list_images_from_dir',
]
|
|
|
|
|
def _check_2d_image(image):
    """Validates that `image` is a single 2D image array.

    The image must be a `np.ndarray` of dtype `uint8` with exactly three
    dimensions, where the last one holds the channels:

        (height, width, 1)  # gray-scale image.
        (height, width, 3)  # colorful image (RGB).
        (height, width, 4)  # colorful image with transparency (RGBA).

    Violations raise `AssertionError`. These are `assert` statements, so the
    checks are skipped when Python runs with `-O`.
    """
    allowed_channels = [1, 3, 4]
    assert isinstance(image, np.ndarray)
    assert image.dtype == np.uint8
    assert image.ndim == 3
    assert image.shape[2] in allowed_channels
|
|
|
|
|
def get_blank_image(height, width, channels=3, use_black=True):
    """Gets a blank image, either white or black.

    NOTE: This function will always return an image with `RGB` channel order for
    color image and pixel range [0, 255].

    Args:
        height: Height of the returned image.
        width: Width of the returned image.
        channels: Number of channels. (default: 3)
        use_black: Whether to return a black image. (default: True)

    Returns:
        A `uint8` array with shape [height, width, channels], filled with 0 if
        `use_black` is set and with 255 otherwise.
    """
    fill_value = 0 if use_black else 255
    # `np.full` fills directly, avoiding the temporary `np.ones(...) * 255`.
    return np.full((height, width, channels), fill_value, dtype=np.uint8)
|
|
|
|
|
def load_image(path):
    """Loads an image from disk.

    NOTE: This function will always return an image with `RGB` channel order for
    color image and pixel range [0, 255]. Images are read with
    `cv2.IMREAD_UNCHANGED`, so a file that does not decode to `uint8` data
    (e.g., a 16-bit PNG) fails the `_check_2d_image` assertion.

    Args:
        path: Path to load the image from. Any path-like object (e.g.,
            `pathlib.Path`) is accepted.

    Returns:
        The image as a `np.ndarray` with shape [height, width, channels], or
        `None` if `cv2.imread()` cannot read it (e.g., the file does not
        exist or cannot be decoded).
    """
    # `os.fspath` lets callers pass `pathlib.Path` objects as well as strings.
    image = cv2.imread(os.fspath(path), cv2.IMREAD_UNCHANGED)
    if image is None:
        return None

    # Gray-scale files come back as 2D arrays; restore the channel axis so the
    # result matches this module's [height, width, 1] convention.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    _check_2d_image(image)
    # cv2 decodes color data in BGR(A) order; convert to RGB(A).
    if image.shape[2] == 3:
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
    return image
|
|
|
|
|
def save_image(path, image):
    """Saves an image to disk.

    NOTE: The input image (if colorful) is assumed to be with `RGB` channel
    order and pixel range [0, 255].

    Args:
        path: Path to save the image to. Any path-like object is accepted.
        image: Image to save. If `None`, nothing is written.

    Returns:
        `True` if `cv2.imwrite()` reports success, otherwise `False`
        (including when `image` is `None`).
    """
    if image is None:
        return False

    _check_2d_image(image)
    # cv2 writes BGR(A) data, so convert from this module's RGB(A) convention.
    # Gray-scale images are written as-is.
    if image.shape[2] == 3:
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    elif image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGRA)
    # `cv2.imwrite` reports failure (e.g., a missing parent directory) through
    # its return value rather than raising, so surface it to the caller.
    return bool(cv2.imwrite(os.fspath(path), image))
|
|
|
|
|
def resize_image(image, *args, **kwargs):
    """Resizes an image with `cv2.resize()`, keeping the channel axis.

    All extra positional and keyword arguments are forwarded unchanged to
    `cv2.resize()`. The channel order of the input image is preserved.

    Args:
        image: Image to resize, or `None`.
        *args: Additional positional arguments for `cv2.resize()`.
        **kwargs: Additional keyword arguments for `cv2.resize()`.

    Returns:
        The resized image as a `np.ndarray`, or `None` if `image` is `None`.
    """
    if image is not None:
        _check_2d_image(image)
        resized = cv2.resize(image, *args, **kwargs)
        # Single-channel input loses its trailing axis in `cv2.resize()`;
        # restore the [height, width, 1] gray-scale layout.
        if image.shape[2] == 1:
            resized = resized[:, :, np.newaxis]
        return resized
    return None
|
|
|
|
|
def add_text_to_image(image,
                      text='',
                      position=None,
                      font=cv2.FONT_HERSHEY_TRIPLEX,
                      font_size=1.0,
                      line_type=cv2.LINE_8,
                      line_width=1,
                      color=(255, 255, 255)):
    """Overlays text on given image.

    NOTE: The input image is assumed to be with `RGB` channel order. The text
    is drawn in place, i.e., `image` itself is modified and returned.

    Args:
        image: The image to overlay text on.
        text: Text content to overlay on the image. (default: empty)
        position: Target position (bottom-left corner of the text) as an
            `(x, y)` pair of integers. If not set, the text is centered on the
            image. (default: None)
        font: Font of the text added. (default: cv2.FONT_HERSHEY_TRIPLEX)
        font_size: Font size of the text added. (default: 1.0)
        line_type: Line type used to depict the text. (default: cv2.LINE_8)
        line_width: Line width used to depict the text. (default: 1)
        color: Color of the text added in `RGB` channel order. On 4-channel
            (RGBA) images, a 3-element color is drawn fully opaque.
            (default: (255, 255, 255))

    Returns:
        An image with target text overlaid on.
    """
    if image is None or not text:
        return image

    _check_2d_image(image)

    if position is None:
        # `cv2.putText` requires an explicit anchor, so compute the
        # bottom-left corner that centers the text box on the image.
        (text_width, text_height), _ = cv2.getTextSize(
            text, font, font_size, line_width)
        image_height, image_width = image.shape[:2]
        position = ((image_width - text_width) // 2,
                    (image_height + text_height) // 2)

    if (image.shape[2] == 4 and isinstance(color, (list, tuple))
            and len(color) == 3):
        # cv2 pads a 3-element color with 0 for the fourth channel, which
        # would write alpha = 0 and make the text invisible on RGBA images.
        color = (*color, 255)

    cv2.putText(img=image,
                text=text,
                org=position,
                fontFace=font,
                fontScale=font_size,
                color=color,
                thickness=line_width,
                lineType=line_type,
                bottomLeftOrigin=False)
    return image
|
|
|
|
|
def preprocess_image(image, min_val=-1.0, max_val=1.0):
    """Pre-processes image by adjusting the pixel range and to dtype `float32`.

    This function is particularly used to convert an image or a batch of images
    to `NCHW` format with dtype `float32`, which is the layout and data type
    commonly used by deep models.

    NOTE: The input image is assumed to be with pixel range [0, 255] and with
    format `HWC` or `NHWC`. The returned image will always be with format
    `NCHW`. It is a transposed view, so it is not C-contiguous.

    Args:
        image: The input image for pre-processing.
        min_val: Minimum value of the output image. (default: -1.0)
        max_val: Maximum value of the output image. (default: 1.0)

    Returns:
        The pre-processed `float32` image with format `NCHW`. Pixel values are
        linearly mapped from [0, 255] to [min_val, max_val].
    """
    assert isinstance(image, np.ndarray)

    # Promote a single `HWC` image to a batch of one, and validate the shape
    # before doing any arithmetic.
    if image.ndim == 3:
        image = image[np.newaxis]
    assert image.ndim == 4 and image.shape[3] in [1, 3, 4]

    # Rescale in float64 for accuracy, then cast to the documented float32.
    image = image.astype(np.float64)
    image = image / 255.0 * (max_val - min_val) + min_val
    return image.astype(np.float32).transpose(0, 3, 1, 2)
|
|
|
|
|
def postprocess_image(image, min_val=-1.0, max_val=1.0):
    """Post-processes image to pixel range [0, 255] with dtype `uint8`.

    This function is particularly used to handle the results produced by deep
    models.

    NOTE: The input image is assumed to be with format `NCHW`. A single `CHW`
    image is also accepted and treated as a batch of one. The returned image
    will always be with format `NHWC`.

    Args:
        image: The input image for post-processing.
        min_val: Expected minimum value of the input image. (default: -1.0)
        max_val: Expected maximum value of the input image. (default: 1.0)

    Returns:
        The post-processed image as a C-contiguous `uint8` array with format
        `NHWC`. Values are linearly mapped from [min_val, max_val] to [0, 255],
        rounded to the nearest integer, and clipped.
    """
    assert isinstance(image, np.ndarray)

    # Mirror `preprocess_image`, which accepts a single un-batched image.
    if image.ndim == 3:
        image = image[np.newaxis]
    assert image.ndim == 4 and image.shape[1] in [1, 3, 4]

    image = image.astype(np.float64)
    image = (image - min_val) / (max_val - min_val) * 255
    # Adding 0.5 before the truncating `uint8` cast rounds to nearest.
    image = np.clip(image + 0.5, 0, 255).astype(np.uint8)
    # A plain transpose is a strided view. cv2 functions that draw in place
    # (e.g., `cv2.putText` via `add_text_to_image`) reject such layouts, so
    # return a contiguous copy.
    return np.ascontiguousarray(image.transpose(0, 2, 3, 1))
|
|
|
|
|
def parse_image_size(obj):
    """Parses an object to a pair of image size, i.e., (height, width).

    Supported inputs:
        - `None`, or a string that is empty or only spaces: (0, 0).
        - An integer (including NumPy integer scalars): a square size.
        - A list, tuple, or `np.ndarray` with zero, one, or two elements.
        - A comma-separated string, e.g., '256' or '256, 128'.

    Negative values are clamped to 0.

    Args:
        obj: The input object to parse image size from.

    Returns:
        A two-element tuple, indicating image height and width respectively.

    Raises:
        ValueError: If the input has an unsupported type, holds more than two
            elements, or is a string with non-integer entries.
    """
    # Handle strings first so an `np.ndarray` is never compared with '':
    # that comparison is element-wise and its truth value is ambiguous.
    if isinstance(obj, str):
        compact = obj.replace(' ', '')
        numbers = tuple(map(int, compact.split(','))) if compact else ()
    elif obj is None:
        numbers = ()
    elif isinstance(obj, (int, np.integer)):
        numbers = (obj,)
    elif isinstance(obj, (list, tuple, np.ndarray)):
        numbers = tuple(obj)
    else:
        raise ValueError(f'Invalid type of input: `{type(obj)}`!')

    if len(numbers) == 0:
        height = width = 0
    elif len(numbers) == 1:
        height = width = int(numbers[0])
    elif len(numbers) == 2:
        height, width = int(numbers[0]), int(numbers[1])
    else:
        raise ValueError('At most two elements for image size.')

    return (max(0, height), max(0, width))
|
|
|
|
|
def get_grid_shape(size, height=0, width=0, is_portrait=False):
    """Gets the shape of a grid that holds exactly `size` cells.

    If neither `height` nor `width` is usable, the grid is made as close to
    square as possible: the row count is the largest divisor of `size` that
    does not exceed its square root. With `is_portrait=False` this yields
    height <= width (e.g., `size = 16` gives `(4, 4)` and `size = 15` gives
    `(3, 5)`). With `is_portrait=True` the two values are swapped, so
    height >= width.

    Args:
        size: Size (height * width) of the target grid. Non-positive sizes
            give `(0, 0)`.
        height: Expected height. It is used only if it divides `size`. If
            `width` is also set, the pair is used only when
            `height * width == size`; otherwise both are ignored.
            (default: 0)
        width: Expected width. It is used only if it divides `size` (see
            `height` for the case where both are set). (default: 0)
        is_portrait: Whether to return a portrait size or a landscape size.
            Only affects the automatically chosen shape. (default: False)

    Returns:
        A two-element tuple, representing height and width respectively.
    """
    assert isinstance(size, int)
    assert isinstance(height, int)
    assert isinstance(width, int)
    if size <= 0:
        return (0, 0)

    # Both dimensions requested: honor them only if they tile `size` exactly,
    # otherwise discard both and fall back to the automatic search.
    if height > 0 and width > 0:
        if height * width == size:
            return (height, width)
        height = width = 0

    if height > 0 and size % height == 0:
        return (height, size // height)
    if width > 0 and size % width == 0:
        return (size // width, width)

    # Walk down from floor(sqrt(size)) to the first divisor; 1 always divides.
    rows = int(np.sqrt(size))
    while size % rows:
        rows -= 1
    cols = size // rows
    return (cols, rows) if is_portrait else (rows, cols)
|
|
|
|
|
def list_images_from_dir(directory):
    """Collects the image files located directly inside `directory`.

    Only the top level of `directory` is scanned; sub-directories are not
    searched recursively. An entry counts as an image when `check_file_ext`
    accepts its name against `IMAGE_EXTENSIONS`.

    Args:
        directory: The directory to find images from.

    Returns:
        A sorted list of paths, each formed by joining `directory` with an
        image filename.
    """
    return sorted(
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if check_file_ext(entry, *IMAGE_EXTENSIONS))
|
|