# RadRotator / io_utils.py
# Pouriarouzrokh's picture
# added pad_to_square to io_utils.py
# 3c5f9a0
# raw
# history blame
# 5.42 kB
################################################################################
# This file contains OSAIL utilities to read and write files.
################################################################################
import copy
import monai as mn
import numpy as np
import os
import skimage
################################################################################
# -F: pad_to_square
def pad_to_square(image):
    """A function to pad an image to square shape with zero pixels.

    The shorter of the first two axes is zero-padded so that both have the
    length of the longer one. The original content is centered; when the
    total padding is odd, the extra row/column ends up after the image
    (bottom/right). Any trailing axes beyond the first two are kept as-is.
    The padded array has the same dtype as the input. An image whose first
    two axes are already equal is returned unchanged (the same object).

    Args:
        image (np.ndarray): the input image array (at least 2 dimensions).
    Returns:
        np.ndarray: the padded image array.
    """
    height, width = image.shape[:2]
    if height == width:
        return image
    side = max(height, width)
    # Integer division puts the odd leftover pixel on the bottom/right side.
    top = (side - height) // 2
    left = (side - width) // 2
    # Keep the input dtype so the output dtype does not depend on whether
    # padding was needed, and carry over any trailing axes.
    padded_image = np.zeros((side, side) + image.shape[2:], dtype=image.dtype)
    padded_image[top:top + height, left:left + width] = image
    return padded_image
################################################################################
# -F: load_image
def _to_grayscale(image):
    """Collapse the channel axis of a 3D image array into a 2D array.

    Only 3D arrays are inspected, so a 2D image that merely happens to be
    1, 3 or 4 pixels wide or tall is never mistaken for a channel axis.
    The checks run in this order: 3 trailing channels, 3 leading channels,
    4 trailing channels, 4 leading channels, then a single trailing or
    leading channel. Three channels are averaged. With four channels the
    fourth one (e.g. the alpha of RGBA) is dropped and the first three are
    averaged. A single channel is squeezed away. Anything else is returned
    unchanged.
    """
    if image.ndim != 3:
        return image
    if image.shape[-1] == 3:
        return np.mean(image, axis=-1)
    if image.shape[0] == 3:
        return np.mean(image, axis=0)
    if image.shape[-1] == 4:
        return np.mean(image[..., :3], axis=-1)
    if image.shape[0] == 4:
        return np.mean(image[:3, ...], axis=0)
    if image.shape[-1] == 1:
        return image[..., 0]
    if image.shape[0] == 1:
        return image[0, ...]
    return image


def load_image(input_object, pad=False, normalize=True, standardize=False,
               dtype=np.float32, percentile_clip=None, target_shape=None,
               transpose=False, ensure_grayscale=True, LoadImage_args=None,
               LoadImage_kwargs=None):
    """A helper function to load different input types.

    Processing order: load -> grayscale -> transpose -> percentile clip ->
    standardize -> normalize -> pad -> resize -> dtype cast.

    Args:
        input_object (Union[np.ndarray, str, os.PathLike]):
            a NumPy array of an X-ray image, or a path to an image file on
            disk that mn.transforms.LoadImage can read (e.g. a DICOM file,
            a .npy file, or a regular image file format).
        pad (bool, optional): whether to zero-pad the image to square shape.
            Defaults to False.
        normalize (bool, optional): whether to min-max scale the image to
            [0, 1] (the divisor gets +1e-8 to avoid division by zero).
            Defaults to True.
        standardize (bool, optional): whether to shift the image to zero
            mean and scale it to unit standard deviation. Runs before
            `normalize`, so when both are set the normalized result is
            returned. Defaults to False.
        dtype (np.dtype, optional): the data type of the output image. Any
            spelling of uint8 (np.uint8, "uint8", np.dtype("uint8")) first
            multiplies the image by 255, which assumes a [0, 1] image.
            Defaults to np.float32.
        percentile_clip (float, optional): clip the image to its
            [percentile_clip, 100 - percentile_clip] percentile range.
            Defaults to None, which means no clipping.
        target_shape (tuple, optional): the target shape of the output image,
            passed to skimage.transform.resize with preserve_range=True.
            Defaults to None, which means no resizing.
        transpose (bool, optional): whether to swap the two image axes.
            Defaults to False.
        ensure_grayscale (bool, optional): whether to collapse a 3D image
            with 1, 3 or 4 channels (first or last axis) into a 2D image.
            Defaults to True.
        LoadImage_args (Sequence, optional): positional arguments to pass to
            mn.transforms.LoadImage. Defaults to None (no extra arguments).
        LoadImage_kwargs (dict, optional): keyword arguments to pass to
            mn.transforms.LoadImage. `image_only` is always set to True and
            must not be repeated here. Defaults to None (no extra arguments).
    Returns:
        np.ndarray: the loaded 2D image array.
    Raises:
        AssertionError: if a path does not exist or the image is not 2D
            after the grayscale step.
        TypeError: if `input_object` is neither an array nor a path.
    """
    # Load the image.
    if isinstance(input_object, np.ndarray):
        image = input_object
    elif isinstance(input_object, (str, os.PathLike)):
        path = os.fspath(input_object)
        assert os.path.exists(path), f"File not found: {path}"
        reader = mn.transforms.LoadImage(
            *(LoadImage_args or ()), image_only=True, **(LoadImage_kwargs or {}))
        # MONAI's reader may return a MetaTensor (a torch.Tensor subclass);
        # convert to NumPy so the NumPy-only steps below behave consistently.
        image = np.asarray(reader(path))
    else:
        raise TypeError(
            "input_object must be a np.ndarray or a file path, got "
            f"{type(input_object).__name__}")
    # Make the image 2D.
    if ensure_grayscale:
        image = _to_grayscale(image)
    assert image.ndim == 2, f"Image must be 2D: {image.shape}"
    # Transpose the image.
    if transpose:
        image = np.transpose(image, axes=(1, 0))
    # Clip the image.
    if percentile_clip is not None:
        percentile_low = np.percentile(image, percentile_clip)
        percentile_high = np.percentile(image, 100 - percentile_clip)
        image = np.clip(image, percentile_low, percentile_high)
    # Standardize the image (astype copies, so the caller's array is untouched).
    if standardize:
        image = image.astype(np.float32)
        image -= image.mean()
        image /= (image.std() + 1e-8)
    # Normalize the image.
    if normalize:
        image = image.astype(np.float32)
        image -= image.min()
        image /= (image.max() + 1e-8)
    # Pad the image to square shape.
    if pad:
        image = pad_to_square(image)
    # Resize the image.
    if target_shape is not None:
        # Import the submodule explicitly: a bare `import skimage` exposes
        # `skimage.transform` only on releases that lazy-load submodules.
        from skimage.transform import resize
        image = resize(image, target_shape, preserve_range=True)
    # Cast the image to the target data type. Compare normalized dtypes so
    # "uint8" and np.dtype("uint8") take the same path as np.uint8.
    if np.dtype(dtype) == np.uint8:
        image = (image * 255).astype(np.uint8)
    else:
        image = image.astype(dtype)
    return image
################################################################################
# -C: LoadImageD
class LoadImageD(mn.transforms.Transform):
    """A MONAI transform to load input images of a dictionary with load_image.

    For every configured key, the value stored under that key is passed to
    `load_image` together with the extra positional and keyword arguments
    given at construction time. The result is stored under the same key in
    a deep copy of the input dictionary; the input dictionary object itself
    is not modified.

    Args:
        keys (Union[str, Sequence[str]]): the dictionary key(s) to load. A
            single string is treated as one key.
        *to_pass_keys: extra positional arguments forwarded to `load_image`.
        **to_pass_kwargs: keyword arguments forwarded to `load_image`.
    """

    def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
        super().__init__()
        # Wrap a bare string: iterating it directly would treat each of its
        # characters as a separate key.
        if isinstance(keys, str):
            keys = (keys,)
        self.keys = keys
        self.to_pass_keys = to_pass_keys
        self.to_pass_kwargs = to_pass_kwargs

    def __call__(self, data):
        """Return a deep copy of `data` with each configured key loaded."""
        data_copy = copy.deepcopy(data)
        for key in self.keys:
            data_copy[key] = load_image(
                data[key], *self.to_pass_keys, **self.to_pass_kwargs)
        return data_copy