VIVEK JAYARAM committed on
Commit
64e8afb
1 Parent(s): 95aa1d5

Super resolution and gaussian blur

cdim/diffusion/diffusion_pipeline.py CHANGED
@@ -90,6 +90,7 @@ def run_diffusion(
         else:
             raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
 
-        image -= 5 / torch.linalg.norm(image.grad) * image.grad
+        image -= 10 / torch.linalg.norm(image.grad) * image.grad
+        image = image.detach().requires_grad_()
 
         return image
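For readers following along: the edit doubles the step size of a gradient update that is normalized by the gradient's magnitude, and the added detach()/requires_grad_() call cuts the autograd history so each step starts from a fresh leaf tensor. A minimal standalone sketch of the same pattern (the quadratic loss is a placeholder, not the pipeline's objective):

    import torch

    image = torch.randn(1, 3, 8, 8, requires_grad=True)
    for _ in range(3):
        loss = (image ** 2).sum()  # placeholder objective
        loss.backward()
        with torch.no_grad():
            # fixed step length of 10, independent of the gradient's scale
            image -= 10 / torch.linalg.norm(image.grad) * image.grad
        # drop the autograd history so the next backward() starts fresh
        image = image.detach().requires_grad_()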
cdim/operators/__init__.py CHANGED
@@ -23,3 +23,5 @@ def get_operator(name: str, **kwargs):
 from .random_box_masker import RandomBoxMasker
 from .random_pixel_masker import RandomPixelMasker
 from .identity_operator import IdentityOperator
+from .super_resolution_operator import SuperResolutionOperator
+from .gaussian_blur_operator import GaussianBlurOperator
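register_operator and get_operator are referenced throughout this commit but their bodies are not shown in the diff. A minimal sketch of the name-to-class registry such a decorator implies (an assumption, not the repository's actual cdim/operators/__init__.py):

    __OPERATORS__ = {}

    def register_operator(name: str):
        def wrapper(cls):
            __OPERATORS__[name] = cls
            return cls
        return wrapper

    def get_operator(name: str, **kwargs):
        if name not in __OPERATORS__:
            raise ValueError(f"Unknown operator: {name}")
        return __OPERATORS__[name](**kwargs)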
cdim/operators/blur_kernel.py ADDED
@@ -0,0 +1,31 @@
+import numpy as np
+import torch
+import scipy.ndimage
+import torch.nn.functional as F
+from torch import nn
+
+
+class BlurKernel(nn.Module):
+    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None):
+        super().__init__()
+        self.blur_type = blur_type
+        self.kernel_size = kernel_size
+        self.std = std
+        self.device = device
+        self.padding = self.kernel_size // 2
+
+        if self.blur_type == "gaussian":
+            n = np.zeros((self.kernel_size, self.kernel_size))
+            n[self.kernel_size // 2, self.kernel_size // 2] = 1  # unit impulse at the center
+            k = scipy.ndimage.gaussian_filter(n, sigma=self.std)  # impulse response = Gaussian kernel
+            k = torch.from_numpy(k).float()
+            k = k.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, kernel_size, kernel_size)
+            k = k.repeat(3, 1, 1, 1)  # Shape (3, 1, kernel_size, kernel_size)
+            self.register_buffer('kernel', k)
+        else:
+            raise ValueError(f"Unknown blur type {self.blur_type}")
+
+    def forward(self, x):
+        x = F.pad(x, [self.padding] * 4, mode='reflect')
+        x = F.conv2d(x, self.kernel, groups=3)  # depthwise: one Gaussian per RGB channel
+        return x
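A quick smoke test of the new module (input shape chosen arbitrarily); because it reflect-pads by kernel_size // 2 on each side, the depthwise convolution preserves spatial dimensions:

    import torch
    from cdim.operators.blur_kernel import BlurKernel

    blur = BlurKernel(blur_type='gaussian', kernel_size=31, std=3.0)
    y = blur(torch.randn(1, 3, 256, 256))
    print(y.shape)  # torch.Size([1, 3, 256, 256])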
cdim/operators/gaussian_blur_operator.py ADDED
@@ -0,0 +1,16 @@
+import torch
+from cdim.operators import register_operator
+from cdim.operators.blur_kernel import BlurKernel
+
+@register_operator(name='gaussian_blur')
+class GaussianBlurOperator:
+    def __init__(self, kernel_size, intensity, device='cpu'):
+        self.device = device
+        self.kernel_size = kernel_size
+        self.conv = BlurKernel(blur_type='gaussian',
+                               kernel_size=kernel_size,
+                               std=intensity,
+                               device=device).to(device)
+
+    def __call__(self, data, **kwargs):
+        return self.conv(data)
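With the registration above, the operator can be built from the values in the new gaussian_blur config (a usage sketch; get_operator's signature is taken from the __init__.py hunk header):

    import torch
    from cdim.operators import get_operator

    op = get_operator('gaussian_blur', kernel_size=61, intensity=3.0, device='cpu')
    blurred = op(torch.randn(1, 3, 256, 256))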
cdim/operators/resizer.py ADDED
@@ -0,0 +1,198 @@
+# This code was taken from: https://github.com/assafshocher/resizer by Assaf Shocher
+import numpy as np
+import torch
+from math import pi
+from torch import nn
+
+
+class Resizer(nn.Module):
+    def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
+        super(Resizer, self).__init__()
+
+        # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
+        scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
+
+        # Choose interpolation method; each method has a matching kernel size
+        method, kernel_width = {
+            "cubic": (cubic, 4.0),
+            "lanczos2": (lanczos2, 4.0),
+            "lanczos3": (lanczos3, 6.0),
+            "box": (box, 1.0),
+            "linear": (linear, 2.0),
+            None: (cubic, 4.0)  # set default interpolation method as cubic
+        }.get(kernel)
+
+        # Antialiasing is only used when downscaling
+        antialiasing *= (np.any(np.array(scale_factor) < 1))
+
+        # Sort indices of dimensions according to scale of each dimension. Since we go dim by dim, this is efficient
+        sorted_dims = np.argsort(np.array(scale_factor))
+        self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]
+
+        # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
+        field_of_view_list = []
+        weights_list = []
+        for dim in self.sorted_dims:
+            # For each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
+            # weights that multiply the values there to get its result.
+            weights, field_of_view = self.contributions(in_shape[dim], output_shape[dim], scale_factor[dim], method,
+                                                        kernel_width, antialiasing)
+
+            # convert to torch tensor
+            weights = torch.tensor(weights.T, dtype=torch.float32)
+
+            # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
+            # tmp_im[field_of_view.T], (bsxfun style)
+            weights_list.append(
+                nn.Parameter(torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
+                             requires_grad=False))
+            field_of_view_list.append(
+                nn.Parameter(torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False))
+
+        self.field_of_view = nn.ParameterList(field_of_view_list)
+        self.weights = nn.ParameterList(weights_list)
+
+    def forward(self, in_tensor):
+        x = in_tensor
+
+        # Use the affecting position values and the set of weights to calculate the result of resizing along one dim at a time
+        for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
+            # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
+            x = torch.transpose(x, dim, 0)
+
+            # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
+            # For each pixel in the output image it matches the positions that influence it from the input image (along 1 dim
+            # only, which is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
+            # the matching set of weights. We do this with a big element-wise tensor multiplication (MATLAB bsxfun style:
+            # matching dims are multiplied element-wise, while a singleton means the matching dim is all multiplied by the
+            # same number)
+            x = torch.sum(x[fov] * w, dim=0)
+
+            # Finally we swap back the axes to the original order
+            x = torch.transpose(x, dim, 0)
+
+        return x
+
+    def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
+        # First fix the scale-factor (if given) to the standardized form the function expects: a list of scale factors of the
+        # same size as the number of input dimensions
+        if scale_factor is not None:
+            # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
+            if np.isscalar(scale_factor) and len(input_shape) > 1:
+                scale_factor = [scale_factor, scale_factor]
+
+            # We extend the size of the scale-factor list to the size of the input by assigning 1 to all the unspecified scales
+            scale_factor = list(scale_factor)
+            scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
+
+        # Fix the output-shape (if given): extend it to the size of the input-shape by assigning the original input size
+        # to all the unspecified dimensions
+        if output_shape is not None:
+            output_shape = list(input_shape[len(output_shape):]) + list(np.uint(np.array(output_shape)))
+
+        # Deal with the case of a non-given scale-factor, calculating according to the output-shape. Note that this is
+        # sub-optimal, because there can be different scales for the same output-shape.
+        if scale_factor is None:
+            scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
+
+        # Deal with a missing output-shape, calculating according to the scale-factor
+        if output_shape is None:
+            output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
+
+        return scale_factor, output_shape
+
+    def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
+        # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
+        # such that each position from the field_of_view will be multiplied with a matching filter from the
+        # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
+        # around it. This is only done for one dimension of the image.
+
+        # When anti-aliasing is activated (default, and only for downscaling) the receptive field is stretched to size
+        # 1/sf. This means the filtering is more of a 'low-pass filter'.
+        fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
+        kernel_width *= 1.0 / scale if antialiasing else 1.0
+
+        # These are the coordinates of the output image
+        out_coordinates = np.arange(1, out_length + 1)
+
+        # Since both the scale-factor and the output size can be provided simultaneously, preserving the center of the
+        # image requires shifting the output coordinates. The deviation arises because out_length doesn't necessarily
+        # equal in_length*scale. To keep the center we subtract half of this deviation so that both sides get equal margins.
+        shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
+
+        # These are the matching positions of the output coordinates on the input image coordinates.
+        # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
+        # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
+        # The scaling is done between the distances and not the pixel numbers (the right boundary of pixel 4 is transformed
+        # to the right boundary of pixel 2. Pixel 1 in the small image matches the boundary between pixels 1 and 2 in the
+        # big one, not pixel 2. This means the position is not just a multiplication of the old position by the scale-factor).
+        # So if we measure distance from the left border, the middle of pixel 1 is at d=0.5 and the border between pixels 1
+        # and 2 is at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf), which means:
+        # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
+        match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
+
+        # This is the left boundary to start multiplying the filter from; it depends on the size of the filter
+        left_boundary = np.floor(match_coordinates - kernel_width / 2)
+
+        # The kernel width needs to be enlarged because when the coverage has sub-pixel borders, it must 'see' the pixel
+        # centers of the pixels it only partially covers. So we add one pixel at each side (weights can zeroize them)
+        expanded_kernel_width = np.ceil(kernel_width) + 2
+
+        # Determine a field_of_view for each output position: these are the pixels in the input image
+        # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and
+        # whose vertical dim is the pixels it 'sees' (kernel_size + 2)
+        field_of_view = np.squeeze(
+            np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
+
+        # Assign a weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and
+        # whose vertical dim is a list of weights matching the pixels in the field of view (as specified in
+        # 'field_of_view')
+        weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
+
+        # Normalize the weights to sum up to 1. Be careful not to divide by 0
+        sum_weights = np.sum(weights, axis=1)
+        sum_weights[sum_weights == 0] = 1.0
+        weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
+
+        # We use this mirror structure as a trick for reflection padding at the boundaries
+        mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
+        field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
+
+        # Get rid of weights and pixel positions that are of zero weight
+        non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
+        weights = np.squeeze(weights[:, non_zero_out_pixels])
+        field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
+
+        # Final products are the relative positions and the matching weights, both of shape output_size x fixed_kernel_size
+        return weights, field_of_view
+
+
+# These next functions are all interpolation methods. x is the distance from the left pixel center
+
+
+def cubic(x):
+    absx = np.abs(x)
+    absx2 = absx ** 2
+    absx3 = absx ** 3
+    return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) +
+            (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((1 < absx) & (absx <= 2)))
+
+
+def lanczos2(x):
+    return (((np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) /
+             ((pi ** 2 * x ** 2 / 2) + np.finfo(np.float32).eps))
+            * (abs(x) < 2))
+
+
+def box(x):
+    return ((-0.5 <= x) & (x < 0.5)) * 1.0
+
+
+def lanczos3(x):
+    return (((np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) /
+             ((pi ** 2 * x ** 2 / 3) + np.finfo(np.float32).eps))
+            * (abs(x) < 3))
+
+
+def linear(x):
+    return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
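A brief standalone check of the vendored Resizer (cubic is the default kernel when none is given, and a scalar scale factor is applied to the last two dimensions):

    import torch
    from cdim.operators.resizer import Resizer

    down = Resizer((1, 3, 256, 256), scale_factor=1 / 4)
    print(down(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 64, 64])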
cdim/operators/super_resolution_operator.py ADDED
@@ -0,0 +1,15 @@
+import torch
+
+from cdim.operators import register_operator
+from cdim.operators.resizer import Resizer
+
+
+@register_operator(name='super_resolution')
+class SuperResolutionOperator:
+    def __init__(self, scale=4, in_shape=(1, 3, 256, 256), device='cpu'):
+        self.device = device
+        self.down_sample = Resizer(in_shape, 1 / scale).to(device)
+
+    def __call__(self, data, **kwargs):
+        return self.down_sample(data)
+
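The same flow works through the registry; with the values from the new super_resolution config below, a 256x256 input comes back at 32x32:

    import torch
    from cdim.operators import get_operator

    op = get_operator('super_resolution', scale=8, in_shape=(1, 3, 256, 256), device='cpu')
    print(op(torch.rand(1, 3, 256, 256)).shape)  # torch.Size([1, 3, 32, 32])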
operator_configs/gaussian_blur_config.yaml ADDED
@@ -0,0 +1,3 @@
+name: gaussian_blur
+kernel_size: 61
+intensity: 3.0
operator_configs/super_resolution_config.yaml ADDED
@@ -0,0 +1,3 @@
+name: super_resolution
+scale: 8
+in_shape: !!python/tuple [1, 3, 256, 256]
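One loading caveat: !!python/tuple is a Python-specific tag outside the safe YAML subset, so yaml.safe_load will reject this file. A hedged sketch of how such a config might be read and fed to the registry (the pop-the-name pattern is an assumption, not code from this commit):

    import yaml
    from cdim.operators import get_operator

    with open('operator_configs/super_resolution_config.yaml') as f:
        # safe_load cannot construct python/tuple; PyYAML's FullLoader can
        config = yaml.load(f, Loader=yaml.FullLoader)

    operator = get_operator(config.pop('name'), **config)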