feng2022 commited on
Commit
b9898de
·
1 Parent(s): 3583f81

Update Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py

Browse files
Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py CHANGED
@@ -81,44 +81,84 @@ class UpFirDn2dBackward(Function):
81
 
82
  return gradgrad_out, None, None, None, None, None, None, None, None
83
 
84
-
85
- class UpFirDn2d(Function):
86
- @staticmethod
87
- def forward(ctx, input, kernel, up, down, pad):
88
- up_x, up_y = up
89
- down_x, down_y = down
90
- pad_x0, pad_x1, pad_y0, pad_y1 = pad
91
-
92
- kernel_h, kernel_w = kernel.shape
93
- batch, channel, in_h, in_w = input.shape
94
- ctx.in_size = input.shape
95
-
96
- input = input.reshape(-1, in_h, in_w, 1)
97
-
98
- ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99
-
100
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102
- ctx.out_size = (out_h, out_w)
103
-
104
- ctx.up = (up_x, up_y)
105
- ctx.down = (down_x, down_y)
106
- ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107
-
108
- g_pad_x0 = kernel_w - pad_x0 - 1
109
- g_pad_y0 = kernel_h - pad_y0 - 1
110
- g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111
- g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112
-
113
- ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114
-
115
- out = upfirdn2d_op.upfirdn2d(
116
- input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117
- )
118
- # out = out.view(major, out_h, out_w, minor)
119
- out = out.view(-1, channel, out_h, out_w)
120
-
121
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  @staticmethod
124
  def backward(ctx, grad_output):
 
81
 
82
  return gradgrad_out, None, None, None, None, None, None, None, None
83
 
84
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
85
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
86
+ """
87
+ # Validate arguments.
88
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
89
+ if f is None:
90
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
91
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
92
+ assert f.dtype == torch.float32 and not f.requires_grad
93
+ batch_size, num_channels, in_height, in_width = x.shape
94
+ upx, upy = _parse_scaling(up)
95
+ downx, downy = _parse_scaling(down)
96
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
97
+
98
+ # Upsample by inserting zeros.
99
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
100
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
101
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
102
+
103
+ # Pad or crop.
104
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
105
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
106
+
107
+ # Setup filter.
108
+ f = f * (gain ** (f.ndim / 2))
109
+ f = f.to(x.dtype)
110
+ if not flip_filter:
111
+ f = f.flip(list(range(f.ndim)))
112
+
113
+ # Convolve with the filter.
114
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
115
+ if f.ndim == 4:
116
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
117
+ else:
118
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
119
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
120
+
121
+ # Downsample by throwing away pixels.
122
+ x = x[:, :, ::downy, ::downx]
123
+ return
124
+
125
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
126
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
127
+ Performs the following sequence of operations for each channel:
128
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
129
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
130
+ Negative padding corresponds to cropping the image.
131
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
132
+ so that the footprint of all output pixels lies within the input image.
133
+ 4. Downsample the image by keeping every Nth pixel (`down`).
134
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
135
+ The fused op is considerably more efficient than performing the same calculation
136
+ using standard PyTorch ops. It supports gradients of arbitrary order.
137
+ Args:
138
+ x: Float32/float64/float16 input tensor of the shape
139
+ `[batch_size, num_channels, in_height, in_width]`.
140
+ f: Float32 FIR filter of the shape
141
+ `[filter_height, filter_width]` (non-separable),
142
+ `[filter_taps]` (separable), or
143
+ `None` (identity).
144
+ up: Integer upsampling factor. Can be a single int or a list/tuple
145
+ `[x, y]` (default: 1).
146
+ down: Integer downsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ padding: Padding with respect to the upsampled image. Can be a single number
149
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
150
+ (default: 0).
151
+ flip_filter: False = convolution, True = correlation (default: False).
152
+ gain: Overall scaling factor for signal magnitude (default: 1).
153
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
154
+ Returns:
155
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
156
+ """
157
+ assert isinstance(x, torch.Tensor)
158
+ assert impl in ['ref', 'cuda']
159
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
160
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
161
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
162
 
163
  @staticmethod
164
  def backward(ctx, grad_output):