Spaces:
Runtime error
Runtime error
Update Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py
Browse files
Time_TravelRephotography/models/encoder4editing/models/stylegan2/op/upfirdn2d.py
CHANGED
@@ -81,44 +81,84 @@ class UpFirDn2dBackward(Function):
|
|
81 |
|
82 |
return gradgrad_out, None, None, None, None, None, None, None, None
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
@staticmethod
|
124 |
def backward(ctx, grad_output):
|
|
|
81 |
|
82 |
return gradgrad_out, None, None, None, None, None, None, None, None
|
83 |
|
84 |
+
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.

    Args:
        x: 4-D input tensor, unpacked as
            `[batch_size, num_channels, in_height, in_width]`.
        f: Float32 FIR filter tensor with 1 dimension (applied separably
            along both axes) or 2 dimensions, or `None` to use a `[1, 1]`
            filter of ones. Must not require gradients.
        up: Upsampling factor, interpreted by `_parse_scaling()` into
            `(upx, upy)`.
        down: Downsampling factor, interpreted by `_parse_scaling()` into
            `(downx, downy)`.
        padding: Padding, interpreted by `_parse_padding()` into
            `(padx0, padx1, pady0, pady1)`. Positive values zero-pad,
            negative values crop.
        flip_filter: When False the filter is flipped along all of its
            dimensions before being handed to the convolution op.
        gain: Scaling factor folded into the filter.

    Returns:
        The upsampled, padded/cropped, filtered and downsampled tensor.
    """
    # Validate arguments.
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    if f is None:
        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
    assert f.dtype == torch.float32 and not f.requires_grad
    batch_size, num_channels, in_height, in_width = x.shape
    upx, upy = _parse_scaling(up)
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)

    # Upsample by inserting zeros: give every pixel its own trailing unit
    # axis, zero-pad those axes to the upsampling factor, then fold them
    # back into height/width.
    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])

    # Pad or crop: positive amounts are applied by `pad`, negative amounts
    # by the slice that follows.
    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]

    # Setup filter. A 1-D filter is applied twice (once per axis), so each
    # pass only carries gain ** 0.5; a 2-D filter carries the full gain.
    f = f * (gain ** (f.ndim / 2))
    f = f.to(x.dtype)
    if not flip_filter:
        f = f.flip(list(range(f.ndim)))

    # Convolve with the filter, one group per channel so every channel is
    # filtered independently with the same taps.
    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
    if f.ndim == 4:
        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
    else:
        # Separable filter: one pass with the taps along the last axis,
        # one pass with the taps along the second-to-last axis.
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)

    # Downsample by throwing away pixels.
    x = x[:, :, ::downy, ::downx]
    # Return the result; a bare `return` here would hand callers None.
    return x
|
124 |
+
|
125 |
+
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Pad, upsample, filter, and downsample a batch of 2D images.

    For every channel the following steps are applied in order:

    1. Upsample by inserting N-1 zeros after each pixel (`up`).

    2. Pad with the requested number of zeros on each side (`padding`);
       a negative amount crops the image instead.

    3. Convolve with the 2D FIR filter `f`, shrinking the image so that the
       footprint of every output pixel lies within the input image.

    4. Downsample by keeping every Nth pixel (`down`).

    The sequence closely resembles scipy.signal.upfirdn(). The fused op is
    considerably more efficient than the same calculation expressed with
    standard PyTorch ops, and it supports gradients of arbitrary order.

    Args:
        x:           Float32/float64/float16 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or
                     `None` (identity).
        up:          Integer upsampling factor. Either a single int or a
                     list/tuple `[x, y]` (default: 1).
        down:        Integer downsampling factor. Either a single int or a
                     list/tuple `[x, y]` (default: 1).
        padding:     Padding with respect to the upsampled image. Either a
                     single number, a list/tuple `[x, y]`, or
                     `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use, `'ref'` or `'cuda'`
                     (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']

    # Both implementations take the same configuration keywords.
    options = dict(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)

    # The fused op is used only when it was requested, the input is on a
    # CUDA device, and `_init()` reports success; the short-circuit keeps
    # `_init()` from being called otherwise.
    use_fused_op = impl == 'cuda' and x.device.type == 'cuda' and _init()
    if not use_fused_op:
        return _upfirdn2d_ref(x, f, **options)
    return _upfirdn2d_cuda(**options).apply(x, f)
|
162 |
|
163 |
@staticmethod
|
164 |
def backward(ctx, grad_output):
|