File size: 5,416 Bytes
b8c299e
 
 
 
 
 
 
 
 
 
3c5f9a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8c299e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
################################################################################
# This file contains OSAIL utilities to read and write files.
################################################################################

import copy
import monai as mn
import numpy as np
import os
import skimage

################################################################################
# -F: pad_to_square

def pad_to_square(image):
    """Zero-pad a 2D image so that its height equals its width.

    The shorter axis is padded on both sides so the original content is centred.
    When the total padding is odd, the extra row/column goes after the content
    (bottom/right). The result keeps the input's dtype. A square input is
    returned unchanged.

    Args:
        image (np.ndarray): the input 2D image array (anything ``np.asarray``
            accepts).

    Returns:
        np.ndarray: the square, zero-padded image array.

    Raises:
        ValueError: if the image is not 2D.
    """
    image = np.asarray(image)
    if image.ndim != 2:
        raise ValueError(f"Image must be 2D: {image.shape}")

    height, width = image.shape
    if height == width:
        return image

    side = max(height, width)
    # Padding placed before the content; the remainder (possibly one larger) goes after.
    top = (side - height) // 2
    left = (side - width) // 2
    pad_width = ((top, side - height - top), (left, side - width - left))
    # np.pad keeps the input dtype, unlike a freshly allocated float64 np.zeros buffer.
    return np.pad(image, pad_width, mode="constant", constant_values=0)

################################################################################
# -F: load_image

def load_image(input_object, pad=False, normalize=True, standardize=False, 
               dtype=np.float32, percentile_clip=None, target_shape=None, 
               transpose=False, ensure_grayscale=True, LoadImage_args=[], LoadImage_kwargs={}):
    """A helper function to load different input types.

    Args:
        input_object (Union[np.ndarray, str]): 
            a 2D NumPy array of X-ray an image, a DICOM file of an X-ray image, 
            or a string path to a .npy, any regular image file format 
            saved on disk that skimage.io can load.
        pad (bool, optional): whether to pad the image to square shape. 
            Defaults to True.
        normalize (bool, optional): whether to normalize the image. 
            Defaults to True.
        standardize (bool, optional): whether to standardize the image.
            Defaults to False.
        dtype (np.dtype, optional): the data type of the output image. 
            Defaults to np.float32.
        percentile_clip (float, optional): the percentile to clip the image. 
            Defaults to 2.5.
        target_shape (tuple, optional): the target shape of the output image. 
            Defaults to None, which means no resizing.
        transpose (bool, optional): whether to transpose the image.
            Defaults to False.
        ensure_grayscale (bool, optional): whether to make the image grayscale.
            Defaults to True.
        LoadImg_args: a list of keyword arguments to pass to  mn.transforms.LoadImage.
        LoadImg_kwargs: a dictionary of keyword arguments to pass to  mn.transforms.LoadImage.
            
    Returns:
        the loaded image array.
    """
    # Load the image.
    if isinstance(input_object, np.ndarray):
        image = input_object
    elif isinstance(input_object, str):
        assert os.path.exists(input_object), f"File not found: {input_object}"
        reader = mn.transforms.LoadImage(image_only=True, *LoadImage_args, **LoadImage_kwargs)
        image = reader(input_object)

    # Make the image 2D.
    if ensure_grayscale:
        if image.shape[-1] == 3:
            image = np.mean(image, axis=-1)  
        elif image.shape[0] == 3:
            image = np.mean(image, axis=0)
        elif image.shape[-1] == 4:
            image = np.mean(image[...,:3], axis=-1)  
        elif image.shape[0] == 4:
            image = np.mean(image[:3,...], axis=0)  
        assert len(image.shape) == 2, f"Image must be 2D: {image.shape}"
    
    # Transpose the image.
    if transpose:
        image = np.transpose(image, axes=(1,0))
    
    # Clip the image.
    if percentile_clip is not None:
        percentile_low = np.percentile(image, percentile_clip)
        percentile_high = np.percentile(image, 100-percentile_clip)
        image = np.clip(image, percentile_low, percentile_high)
        
    # Standardize the image.
    if standardize:
        image = image.astype(np.float32)
        image -= image.mean()
        image /= (image.std() + 1e-8)
        
    # Normalize the image.
    if normalize:
        image = image.astype(np.float32)
        image -= image.min()
        image /= (image.max() + 1e-8)
    
    # Pad the image to square shape.
    if pad:
        image = pad_to_square(image)   
    
    # Resize the image.
    if target_shape is not None:
        image = skimage.transform.resize(image, target_shape, preserve_range=True)
        
    # Cast the image to the target data type.
    if dtype is np.uint8:
        image = (image * 255).astype(np.uint8)
    else:
        image = image.astype(dtype)  
    
    return image

################################################################################
# -C: LoadImageD

class LoadImageD(mn.transforms.Transform):
    """A dictionary-based MONAI transform that loads entries with ``load_image``.

    For each key in ``keys``, the value stored under that key in the input
    dictionary is passed to ``load_image``. The returned array replaces the
    value in a deep copy of the input, so the caller's dictionary is left
    unmodified.

    Args:
        keys (Union[str, Iterable[str]]): the key, or keys, whose values are
            loaded. A single string is treated as one key.
        *to_pass_keys: extra positional arguments forwarded to ``load_image``
            after the input object (despite the name, these are arguments,
            not keys).
        **to_pass_kwargs: keyword arguments forwarded to ``load_image``.
    """
    def __init__(self, keys, *to_pass_keys, **to_pass_kwargs) -> None:
        super().__init__()
        # Wrap a lone string key; iterating it directly would yield its characters.
        self.keys = (keys,) if isinstance(keys, str) else keys
        self.to_pass_keys = to_pass_keys
        self.to_pass_kwargs = to_pass_kwargs

    def __call__(self, data):
        """Return a deep copy of ``data`` with every configured key loaded."""
        data_copy = copy.deepcopy(data)
        for key in self.keys:
            data_copy[key] = load_image(data[key], *self.to_pass_keys, **self.to_pass_kwargs)
        return data_copy