wzhouxiff committed on
Commit
92e76e8
1 Parent(s): df4a3e9

Create img_utils.py

Browse files
Files changed (1) hide show
  1. img_utils.py +172 -0
img_utils.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image array(s) from HWC layout to CHW torch tensors.

    Args:
        imgs (list[ndarray] | ndarray): Input image(s). Each image has shape
            (H, W, C); a 2D (H, W) array is treated as a single-channel
            image.
        bgr2rgb (bool): Whether to change bgr to rgb. Only applied to images
            with exactly 3 channels.
        float32 (bool): Whether to cast the resulting tensor to float32.
            Values are not rescaled.

    Returns:
        list[tensor] | tensor: Tensor image(s) of shape (C, H, W). A list is
            returned if and only if ``imgs`` is a list.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.ndim == 2:
            # Grayscale arrays (e.g. decoded with IMREAD_GRAYSCALE) carry no
            # channel axis; add one so the HWC -> CHW transpose below works.
            img = img[:, :, None]
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                # cv2.cvtColor does not accept float64 input.
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC -> CHW
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)
36
+
37
+
38
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].
    The input tensor(s) are never modified.

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
            ndarray of shape (H x W). The channel order is BGR when
            ``rgb2bgr`` is True. A bare ndarray is returned when exactly one
            image was produced, otherwise a list.

    Raises:
        TypeError: If ``tensor`` is neither a Tensor nor a list of Tensors,
            or if a tensor is not 2D, 3D or 4D after removing a leading
            dimension of size 1.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Use the out-of-place clamp: squeeze/float/detach/cpu may all return
        # views sharing storage with the caller's tensor, so an in-place
        # clamp_ would silently overwrite the input.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Tile the batch into one roughly square grid image.
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result
95
+
96
+
97
def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
    """This implementation is slightly faster than tensor2img.
    It now only supports torch tensor with shape (1, c, h, w).

    Values are clamped to ``min_max``, rescaled to [0, 255] and converted to
    uint8 by a plain dtype cast (no explicit rounding step, unlike
    ``tensor2img``). The input tensor is not modified.

    Args:
        tensor (Tensor): Now only support torch tensor with (1, c, h, w).
        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        ndarray: uint8 image of shape (h, w, c).
    """
    # Out-of-place clamp: squeeze() and detach() share storage with the
    # caller's tensor, so clamp_ would overwrite the input in place.
    output = tensor.squeeze(0).detach().clamp(*min_max).permute(1, 2, 0)
    output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
    output = output.type(torch.uint8).cpu().numpy()
    if rgb2bgr:
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    return output
112
+
113
+
114
def imfrombytes(content, flag='color', float32=False):
    """Read an image from bytes.

    Args:
        content (bytes): Image bytes got from files or other streams.
        flag (str): Flags specifying the color type of a loaded image,
            candidates are `color`, `grayscale` and `unchanged`.
        float32 (bool): Whether to change to float32. If True, the decoded
            values are also divided by 255 (i.e. normalized to [0, 1] for
            8-bit images). Default: False.

    Returns:
        ndarray | None: Loaded image array, or None when OpenCV cannot
            decode ``content``.

    Raises:
        KeyError: If ``flag`` is not one of the supported candidates.
    """
    img_np = np.frombuffer(content, np.uint8)
    imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
    img = cv2.imdecode(img_np, imread_flags[flag])
    # cv2.imdecode yields None for invalid data; skip the conversion in that
    # case so both modes report the failure the same way instead of raising
    # AttributeError on None.astype.
    if float32 and img is not None:
        img = img.astype(np.float32) / 255.
    return img
133
+
134
+
135
def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: True when the image was written. Failure is reported by
            raising instead of returning False.

    Raises:
        IOError: If OpenCV reports that the image could not be written.
    """
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    ok = cv2.imwrite(file_path, img, params)
    if not ok:
        raise IOError('Failed in writing images.')
    # Return the success flag so the documented bool return value is honored.
    return ok
154
+
155
+
156
def crop_border(imgs, crop_border):
    """Crop borders of images.

    Args:
        imgs (list[ndarray] | tuple[ndarray] | ndarray): Images with shape
            (h, w, c) or (h, w).
        crop_border (int): Number of pixels removed from each end of the
            height and width axes.

    Returns:
        list[ndarray] | ndarray: Cropped images. A list is returned for a
            list/tuple input and a single array for a single array input.
            When ``crop_border`` is 0 the input is returned unchanged.
    """
    if crop_border == 0:
        return imgs
    if isinstance(imgs, (list, tuple)):
        # Tuples are treated like lists; indexing a tuple with slices would
        # otherwise raise TypeError.
        return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
    return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]