Spaces:
Runtime error
Runtime error
################################################################################ | |
# This file contains OSAIL utils to read and write files.
################################################################################ | |
import copy | |
import monai as mn | |
import numpy as np | |
import os | |
import skimage | |
################################################################################ | |
# -F: pad_to_square | |
def pad_to_square(image):
    """Zero-pad a 2D image along its shorter axis so that it becomes square.

    The original content is centered on the padded axis; when the amount of
    padding is odd, the extra row/column goes to the bottom/right side.

    Args:
        image (np.ndarray): a 2D image array.

    Returns:
        np.ndarray: the input itself when it is already square; otherwise a
            new array created with ``np.zeros`` (default float dtype) that
            holds the image surrounded by zeros.
    """
    rows, cols = image.shape
    # Already square: hand back the very same object, untouched.
    if rows == cols:
        return image
    side = max(rows, cols)
    canvas = np.zeros((side, side))
    # Exactly one of the two offsets is non-zero: the one on the shorter axis.
    row_start = (side - rows) // 2
    col_start = (side - cols) // 2
    canvas[row_start:row_start + rows, col_start:col_start + cols] = image
    return canvas
################################################################################ | |
# -F: load_image | |
def load_image(input_object, pad=False, normalize=True, standardize=False,
               dtype=np.float32, percentile_clip=None, target_shape=None,
               transpose=False, ensure_grayscale=True,
               LoadImage_args=None, LoadImage_kwargs=None):
    """Load an image from memory or disk and apply optional preprocessing.

    The steps run in a fixed order: load, grayscale reduction, transpose,
    percentile clipping, standardization, normalization, padding, resizing,
    and finally the dtype cast.

    Args:
        input_object (Union[np.ndarray, str, os.PathLike]):
            an image array that is used as-is, or a path to an image file
            on disk that is read with ``mn.transforms.LoadImage``.
        pad (bool, optional): whether to zero-pad the image to a square shape
            with ``pad_to_square``. Defaults to False.
        normalize (bool, optional): whether to shift the minimum to 0 and
            divide by the resulting maximum (plus 1e-8). Defaults to True.
        standardize (bool, optional): whether to subtract the mean and divide
            by the standard deviation (plus 1e-8). Runs before ``normalize``.
            Defaults to False.
        dtype (np.dtype, optional): the data type of the output image. When
            it denotes uint8 the values are multiplied by 255 before casting.
            Defaults to np.float32.
        percentile_clip (float, optional): if given, intensities are clipped
            to the [p, 100 - p] percentile range. Defaults to None, which
            means no clipping.
        target_shape (tuple, optional): the target shape of the output image.
            Defaults to None, which means no resizing.
        transpose (bool, optional): whether to swap the two image axes.
            Defaults to False.
        ensure_grayscale (bool, optional): whether to average 3-channel
            images (and the first three channels of 4-channel images),
            channel-first or channel-last, into one plane and then require
            the result to be 2D. Defaults to True.
        LoadImage_args: positional arguments to pass to
            ``mn.transforms.LoadImage``. Defaults to None (no extra arguments).
        LoadImage_kwargs: a dictionary of keyword arguments to pass to
            ``mn.transforms.LoadImage``. Defaults to None (no extra arguments).

    Returns:
        the loaded and preprocessed image, cast to ``dtype``.

    Raises:
        TypeError: if ``input_object`` is neither an array nor a path.
        AssertionError: if the path does not exist, or if ``ensure_grayscale``
            is set and the image is not 2D after channel reduction.
    """
    # Load the image.
    if isinstance(input_object, np.ndarray):
        image = input_object
    elif isinstance(input_object, (str, os.PathLike)):
        path = os.fspath(input_object)
        assert os.path.exists(path), f"File not found: {input_object}"
        reader = mn.transforms.LoadImage(
            *(LoadImage_args or ()), image_only=True, **(LoadImage_kwargs or {}))
        image = reader(path)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise TypeError(
            "input_object must be a NumPy array or a file path, "
            f"got {type(input_object).__name__}")

    # Make the image 2D.
    if ensure_grayscale:
        # Only a 3D array can carry a channel axis. Checking the rank first
        # keeps a genuine 2D image whose height or width happens to be 3 or 4
        # from being mistaken for RGB(A) and collapsed to 1D.
        if len(image.shape) == 3:
            if image.shape[-1] == 3:
                image = np.mean(image, axis=-1)
            elif image.shape[0] == 3:
                image = np.mean(image, axis=0)
            elif image.shape[-1] == 4:
                # Drop the fourth (presumably alpha) channel before averaging.
                image = np.mean(image[..., :3], axis=-1)
            elif image.shape[0] == 4:
                image = np.mean(image[:3, ...], axis=0)
        assert len(image.shape) == 2, f"Image must be 2D: {image.shape}"

    # Transpose the image.
    if transpose:
        image = np.transpose(image, axes=(1, 0))

    # Clip the image.
    if percentile_clip is not None:
        percentile_low = np.percentile(image, percentile_clip)
        percentile_high = np.percentile(image, 100 - percentile_clip)
        image = np.clip(image, percentile_low, percentile_high)

    # Standardize the image.
    if standardize:
        # astype returns a copy, so the in-place ops never touch the input.
        image = image.astype(np.float32)
        image -= image.mean()
        image /= (image.std() + 1e-8)

    # Normalize the image.
    if normalize:
        image = image.astype(np.float32)
        image -= image.min()
        # After the shift, max() equals the intensity range; the epsilon
        # guards against division by zero for constant images.
        image /= (image.max() + 1e-8)

    # Pad the image to square shape.
    if pad:
        image = pad_to_square(image)

    # Resize the image.
    if target_shape is not None:
        # Import the submodule explicitly: a bare `import skimage` does not
        # expose `skimage.transform` on every scikit-image release.
        from skimage.transform import resize
        image = resize(image, target_shape, preserve_range=True)

    # Cast the image to the target data type.
    try:
        # Compare by dtype equality so "uint8" and np.dtype("uint8") behave
        # like np.uint8 instead of silently skipping the rescale.
        wants_uint8 = np.dtype(dtype) == np.uint8
    except TypeError:
        # Not something NumPy can interpret as a dtype; let astype decide.
        wants_uint8 = False
    if wants_uint8:
        # NOTE: the x255 scaling presumes intensities in [0, 1], i.e.
        # normalize=True; it is applied regardless, as before.
        image = (image * 255).astype(np.uint8)
    else:
        image = image.astype(dtype)
    return image
################################################################################ | |
# -C: LoadImageD | |
class LoadImageD(mn.transforms.Transform):
    """A MONAI dictionary transform that loads images with ``load_image``.

    For every configured key, the value stored under that key is passed to
    ``load_image`` and replaced by the loaded image in a copy of the input.

    Args:
        keys: the key, or an iterable of keys, whose values should be loaded.
        *to_pass_keys: extra positional arguments forwarded to ``load_image``
            after the input object.
        **to_pass_kwargs: keyword arguments forwarded to ``load_image``.
    """

    def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
        super().__init__()
        # A bare string is one key; iterating it would yield its characters.
        self.keys = (keys,) if isinstance(keys, str) else keys
        self.to_pass_keys = to_pass_keys
        self.to_pass_kwargs = to_pass_kwargs

    def __call__(self, data):
        """Return a deep copy of ``data`` with each configured key loaded.

        Args:
            data: a mapping from keys to ``load_image`` inputs (plus any
                other entries, which are copied through unchanged).

        Returns:
            the deep-copied mapping with the loaded images in place.
        """
        # Work on a deep copy so the caller's mapping is left untouched.
        data_copy = copy.deepcopy(data)
        for key in self.keys:
            data_copy[key] = load_image(
                data[key], *self.to_pass_keys, **self.to_pass_kwargs)
        return data_copy