VIVEK JAYARAM
committed on
Commit
•
64e8afb
1
Parent(s):
95aa1d5
Super resolution and gaussian blur
Browse files- cdim/diffusion/diffusion_pipeline.py +3 -2
- cdim/operators/__init__.py +2 -0
- cdim/operators/blur_kernel.py +31 -0
- cdim/operators/gaussian_blur_operator.py +16 -0
- cdim/operators/resizer.py +198 -0
- cdim/operators/super_resolution_operator.py +15 -0
- operator_configs/gaussian_blur_config.yaml +3 -0
- operator_configs/super_resolution_config.yaml +3 -0
cdim/diffusion/diffusion_pipeline.py
CHANGED
@@ -90,6 +90,7 @@ def run_diffusion(
|
|
90 |
else:
|
91 |
raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
|
92 |
|
93 |
-
image -=
|
|
|
94 |
|
95 |
-
return image
|
|
|
90 |
else:
|
91 |
raise ValueError(f"Unsupported combination: loss {loss_type} noise {noise_function.name}")
|
92 |
|
93 |
+
image -= 10 / torch.linalg.norm(image.grad) * image.grad
|
94 |
+
image = image.detach().requires_grad_()
|
95 |
|
96 |
+
return image
|
cdim/operators/__init__.py
CHANGED
@@ -23,3 +23,5 @@ def get_operator(name: str, **kwargs):
|
|
23 |
from .random_box_masker import RandomBoxMasker
|
24 |
from .random_pixel_masker import RandomPixelMasker
|
25 |
from .identity_operator import IdentityOperator
|
|
|
|
|
|
23 |
from .random_box_masker import RandomBoxMasker
|
24 |
from .random_pixel_masker import RandomPixelMasker
|
25 |
from .identity_operator import IdentityOperator
|
26 |
+
from .super_resolution_operator import SuperResolutionOperator
|
27 |
+
from .gaussian_blur_operator import GaussianBlurOperator
|
cdim/operators/blur_kernel.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import scipy
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class BlurKernel(nn.Module):
    """Depthwise 2-D blur applied independently to every channel of an image batch.

    The blur kernel is built once at construction time and stored as a
    (non-trainable) buffer named ``kernel`` of shape
    ``(channels, 1, kernel_size, kernel_size)``, so it follows the module
    through ``.to(device)``.

    Args:
        blur_type: Kind of blur; only ``'gaussian'`` is supported.
        kernel_size: Side length of the square kernel. With an odd size the
            output has the same spatial size as the input; with an even size
            the padding of ``kernel_size // 2`` makes the output one pixel
            larger along each spatial axis.
        std: Standard deviation of the Gaussian, in pixels.
        device: Stored on the instance only; the kernel itself is moved by
            ``.to(...)`` like any other buffer.
        channels: Number of image channels the kernel is replicated for
            (default 3, matching the previous hard-coded behaviour).

    Raises:
        ValueError: If ``blur_type`` is not ``'gaussian'``.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0, device=None, channels=3):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.channels = channels
        self.padding = self.kernel_size // 2

        if self.blur_type != "gaussian":
            raise ValueError(f"Unknown blur type {self.blur_type}")

        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.ndimage` is loaded on older SciPy releases.
        from scipy import ndimage

        # Blurring a unit impulse yields the (discretised) Gaussian kernel.
        impulse = np.zeros((self.kernel_size, self.kernel_size))
        impulse[self.kernel_size // 2, self.kernel_size // 2] = 1
        k = ndimage.gaussian_filter(impulse, sigma=self.std)
        k = torch.from_numpy(k).float()
        k = k.unsqueeze(0).unsqueeze(0)  # Shape (1, 1, kernel_size, kernel_size)
        k = k.repeat(self.channels, 1, 1, 1)  # Shape (channels, 1, kernel_size, kernel_size)
        self.register_buffer('kernel', k)

    def forward(self, x):
        """Blur ``x`` of shape ``(N, channels, H, W)`` using reflect padding at the borders."""
        x = F.pad(x, [self.padding] * 4, mode='reflect')
        # groups == number of kernel copies, so each channel is filtered on its own.
        return F.conv2d(x, self.kernel, groups=self.kernel.shape[0])
|
cdim/operators/gaussian_blur_operator.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from cdim.operators import register_operator
|
3 |
+
from cdim.operators.blur_kernel import BlurKernel
|
4 |
+
|
5 |
+
@register_operator(name='gaussian_blur')
class GaussianBlurOperator:
    """Measurement operator that blurs its input with a Gaussian kernel.

    Registered under the name ``'gaussian_blur'``. ``intensity`` is passed to
    ``BlurKernel`` as the Gaussian standard deviation.
    """

    def __init__(self, kernel_size, intensity, device='cpu'):
        self.device = device
        self.kernel_size = kernel_size
        # Build the blur module first, then move it to the requested device.
        blur = BlurKernel(
            blur_type='gaussian',
            kernel_size=kernel_size,
            std=intensity,
            device=device,
        )
        self.conv = blur.to(device)

    def __call__(self, data, **kwargs):
        """Return the blurred version of ``data``; extra keyword arguments are ignored."""
        blurred = self.conv(data)
        return blurred
|
cdim/operators/resizer.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code was taken from: https://github.com/assafshocher/resizer by Assaf Shocher
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from math import pi
|
5 |
+
from torch import nn
|
6 |
+
|
7 |
+
|
8 |
+
class Resizer(nn.Module):
    """Resize a tensor along one or more axes using precomputed interpolation weights.

    For every axis whose scale factor differs from 1, the constructor
    precomputes which input positions contribute to each output position
    (``field_of_view``) and with which weight (``weights``). ``forward`` then
    applies these axis by axis.

    Args:
        in_shape: Shape of the tensors that will be passed to ``forward``.
        scale_factor: Scalar or per-axis sequence of scale factors. A scalar is
            applied to the last two axes (or to the only axis for 1-D input).
            Unspecified leading axes get a factor of 1.
        output_shape: Desired output shape for the trailing axes; unspecified
            leading axes keep their input size.
        kernel: One of ``"cubic"``, ``"lanczos2"``, ``"lanczos3"``, ``"box"``,
            ``"linear"`` or ``None`` (defaults to cubic).
        antialiasing: If true, the kernel is widened when downscaling.

    Raises:
        ValueError: If ``kernel`` is not one of the supported names.
    """

    def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
        super(Resizer, self).__init__()

        # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or
        # vice versa
        scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)

        # Choose interpolation method, each method has the matching kernel size
        kernel_table = {
            "cubic": (cubic, 4.0),
            "lanczos2": (lanczos2, 4.0),
            "lanczos3": (lanczos3, 6.0),
            "box": (box, 1.0),
            "linear": (linear, 2.0),
            None: (cubic, 4.0),  # set default interpolation method as cubic
        }
        if kernel not in kernel_table:
            supported = sorted(name for name in kernel_table if name is not None)
            raise ValueError(f"Unknown interpolation kernel {kernel!r}; expected one of {supported} or None")
        method, kernel_width = kernel_table[kernel]

        # Antialiasing is only used when downscaling
        antialiasing = bool(antialiasing) and bool(np.any(np.array(scale_factor) < 1))

        # Sort indices of dimensions according to scale of each dimension. Since we are going dim by dim this is
        # efficient
        sorted_dims = np.argsort(np.array(scale_factor))
        self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]

        # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
        field_of_view_list = []
        weights_list = []
        for dim in self.sorted_dims:
            # For each coordinate (along 1 dim), calculate which coordinates in the input image affect its result
            # and the weights that multiply the values there to get its result.
            weights, field_of_view = self.contributions(in_shape[dim], output_shape[dim], scale_factor[dim], method,
                                                        kernel_width, antialiasing)

            # convert to torch tensor
            weights = torch.tensor(weights.T, dtype=torch.float32)

            # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
            # tmp_im[field_of_view.T], (bsxfun style)
            weights_list.append(
                nn.Parameter(torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
                             requires_grad=False))
            field_of_view_list.append(
                nn.Parameter(torch.tensor(field_of_view.T.astype(np.int64), dtype=torch.long), requires_grad=False))

        self.field_of_view = nn.ParameterList(field_of_view_list)
        self.weights = nn.ParameterList(weights_list)

    def forward(self, in_tensor):
        """Resize ``in_tensor`` along every axis selected at construction time."""
        x = in_tensor

        # Use the affecting position values and the set of weights to calculate the result of resizing along this
        # 1 dim
        for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
            # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
            x = torch.transpose(x, dim, 0)

            # This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
            # For each pixel in the output image it matches the positions that influence it from the input image
            # (along 1 dim only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its
            # set of positions with the matching set of weights. We do this by this big tensor element-wise
            # multiplication (MATLAB bsxfun style: matching dims are multiplied element-wise while singletons mean
            # that the matching dim is all multiplied by the same number).
            x = torch.sum(x[fov] * w, dim=0)

            # Finally we swap back the axes to the original order
            x = torch.transpose(x, dim, 0)

        return x

    def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
        """Return ``(scale_factor, output_shape)`` extended to one entry per input axis.

        Whichever of the two is missing is derived from the other.
        """
        # First fixing the scale-factor (if given) to the standardized form the function expects (a list of scale
        # factors with the same length as the number of input dimensions)
        if scale_factor is not None:
            # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. For 1-D input the
            # scalar applies to the single axis.
            if np.isscalar(scale_factor):
                scale_factor = [scale_factor, scale_factor] if len(input_shape) > 1 else [scale_factor]

            # We extend the size of the scale-factor list to the size of the input by assigning 1 to all the
            # unspecified (leading) scales
            scale_factor = list(scale_factor)
            scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor

        # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original
        # input-size to all the unspecified (leading) dimensions
        if output_shape is not None:
            leading = max(len(input_shape) - len(output_shape), 0)
            output_shape = list(input_shape[:leading]) + list(np.uint(np.array(output_shape)))

        # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this
        # is sub-optimal, because there can be different scales to the same output-shape.
        if scale_factor is None:
            scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)

        # Dealing with missing output-shape, calculating according to scale-factor
        if output_shape is None:
            output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))

        return scale_factor, output_shape

    def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
        """Compute interpolation weights and source indices for one axis.

        Returns ``(weights, field_of_view)``: for each output position, the
        input indices it reads and the normalized weight given to each.
        """
        # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
        # such that each position from the field_of_view will be multiplied with a matching filter from the
        # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel
        # centers around it. This is only done for one dimension of the image.

        # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to
        # size of 1/sf. This means filtering is more 'low-pass filter'.
        if antialiasing:
            fixed_kernel = lambda arg: scale * kernel(scale * arg)
            kernel_width *= 1.0 / scale
        else:
            fixed_kernel = kernel

        # These are the coordinates of the output image
        out_coordinates = np.arange(1, out_length + 1)

        # Since both scale-factor and output size can be provided simultaneously, preserving the center of the image
        # requires shifting the output coordinates. The deviation is because out_length doesn't necessarily equal
        # in_length*scale. To keep the center we need to subtract half of this deviation so that we get equal
        # margins for both sides and the center is preserved.
        shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2

        # These are the matching positions of the output-coordinates on the input image coordinates.
        # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
        # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
        # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is
        # transformed to the right boundary of pixel 2. Pixel 1 in the small image matches the boundary between
        # pixels 1 and 2 in the big one and not to pixel 2. This means the position is not just multiplication of
        # the old pos by scale-factor).
        # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1
        # and 2 is at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf) which means:
        # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
        match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)

        # This is the left boundary to start multiplying the filter from, it depends on the size of the filter
        left_boundary = np.floor(match_coordinates - kernel_width / 2)

        # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel
        # centers of the pixels it only covered a part from. So we add one pixel at each side to consider (weights
        # can zeroize them)
        expanded_kernel_width = np.ceil(kernel_width) + 2

        # Determine a set of field_of_view for each output position, these are the pixels in the input image
        # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big)
        # and the vertical dim is the pixels it 'sees' (kernel_size + 2).
        # 64-bit indices: a 16-bit cast would overflow for axes longer than 32767.
        field_of_view = np.squeeze(
            (np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1).astype(np.int64))

        # Assign weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and
        # the vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
        # 'field_of_view')
        weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)

        # Normalize weights to sum up to 1. Be careful not to divide by 0
        sum_weights = np.sum(weights, axis=1)
        sum_weights[sum_weights == 0] = 1.0
        weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)

        # We use this mirror structure as a trick for reflection padding at the boundaries
        mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
        field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]

        # Get rid of weights and pixel positions that are of zero weight
        non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))[0]
        weights = np.squeeze(weights[:, non_zero_out_pixels])
        field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])

        # Final products are the relative positions and the matching weights, both are output_size X
        # fixed_kernel_size
        return weights, field_of_view
|
168 |
+
|
169 |
+
|
170 |
+
# These next functions are all interpolation methods. x is the distance from the left pixel center
|
171 |
+
|
172 |
+
|
173 |
+
def cubic(x):
    """Piecewise-cubic interpolation kernel; zero for ``|x| > 2``."""
    dist = np.abs(x)
    dist_sq = dist ** 2
    dist_cu = dist ** 3
    # Inner lobe covers |x| <= 1, outer lobe covers 1 < |x| <= 2.
    inner = (1.5 * dist_cu - 2.5 * dist_sq + 1) * (dist <= 1)
    outer = (-0.5 * dist_cu + 2.5 * dist_sq - 4 * dist + 2) * ((1 < dist) & (dist <= 2))
    return inner + outer
|
179 |
+
|
180 |
+
|
181 |
+
def lanczos2(x):
    """Lanczos kernel with two lobes; zero for ``|x| >= 2``."""
    # A float32 epsilon on both sides keeps the ratio finite (and equal to 1) at x == 0.
    eps = np.finfo(np.float32).eps
    numerator = np.sin(pi * x) * np.sin(pi * x / 2) + eps
    denominator = (pi ** 2 * x ** 2 / 2) + eps
    support = abs(x) < 2
    return (numerator / denominator) * support
|
185 |
+
|
186 |
+
|
187 |
+
def box(x):
    """Box (nearest-neighbour) kernel: 1.0 on ``[-0.5, 0.5)`` and 0.0 elsewhere."""
    inside = (x >= -0.5) & (x < 0.5)
    return inside * 1.0
|
189 |
+
|
190 |
+
|
191 |
+
def lanczos3(x):
    """Lanczos kernel with three lobes; zero for ``|x| >= 3``."""
    # A float32 epsilon on both sides keeps the ratio finite (and equal to 1) at x == 0.
    eps = np.finfo(np.float32).eps
    numerator = np.sin(pi * x) * np.sin(pi * x / 3) + eps
    denominator = (pi ** 2 * x ** 2 / 3) + eps
    support = abs(x) < 3
    return (numerator / denominator) * support
|
195 |
+
|
196 |
+
|
197 |
+
def linear(x):
    """Triangle (linear interpolation) kernel: ``1 - |x|`` on ``[-1, 1]`` and 0 elsewhere."""
    rising = (x + 1) * ((x >= -1) & (x < 0))
    falling = (1 - x) * ((x >= 0) & (x <= 1))
    return rising + falling
|
cdim/operators/super_resolution_operator.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from cdim.operators import register_operator
|
4 |
+
from cdim.operators.resizer import Resizer
|
5 |
+
|
6 |
+
|
7 |
+
@register_operator(name='super_resolution')
class SuperResolutionOperator:
    """Measurement operator that downsamples its input by a factor of ``scale``.

    Registered under the name ``'super_resolution'``. The downsampling is done
    by a ``Resizer`` built for ``in_shape`` with a scale factor of ``1 / scale``.
    """

    def __init__(self, scale=4, in_shape=(1, 3, 256, 256), device='cpu'):
        self.device = device
        downscale_factor = 1 / scale
        resizer = Resizer(in_shape, downscale_factor)
        self.down_sample = resizer.to(device)

    def __call__(self, data, **kwargs):
        """Return the downsampled version of ``data``; extra keyword arguments are ignored."""
        low_res = self.down_sample(data)
        return low_res
|
15 |
+
|
operator_configs/gaussian_blur_config.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
name: gaussian_blur
|
2 |
+
kernel_size: 61
|
3 |
+
intensity: 3.0
|
operator_configs/super_resolution_config.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
name: super_resolution
|
2 |
+
scale: 8
|
3 |
+
in_shape: !!python/tuple [1, 3, 256, 256]
|