Upload transform.py with huggingface_hub
Browse files- transform.py +202 -0
transform.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adopted from https://github.com/bayesiains/nflows
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
DEFAULT_MIN_BIN_WIDTH = 1e-3
|
8 |
+
DEFAULT_MIN_BIN_HEIGHT = 1e-3
|
9 |
+
DEFAULT_MIN_DERIVATIVE = 1e-3
|
10 |
+
|
11 |
+
|
12 |
+
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply a monotonic rational-quadratic spline transform elementwise.

    Dispatches to the tail-handling variant when ``tails`` is given (with
    ``tail_bound`` delimiting the spline interval), otherwise to the plain
    spline on its default domain. Returns ``(outputs, logabsdet)`` where
    ``logabsdet`` is the elementwise log |det Jacobian| of the transform.
    """
    if tails is None:
        transform = rational_quadratic_spline
        extra_kwargs = {}
    else:
        transform = unconstrained_rational_quadratic_spline
        extra_kwargs = {"tails": tails, "tail_bound": tail_bound}

    return transform(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
        **extra_kwargs,
    )
|
43 |
+
|
44 |
+
|
45 |
+
def searchsorted(bin_locations, inputs, eps=1e-6):
|
46 |
+
bin_locations[..., -1] += eps
|
47 |
+
return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
|
48 |
+
|
49 |
+
|
50 |
+
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline with identity ("linear") tails.

    Inputs inside ``[-tail_bound, tail_bound]`` go through the spline;
    inputs outside pass through unchanged with zero log-det contribution.
    Only ``tails == "linear"`` is implemented.
    """
    inside = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside = ~inside

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails != "linear":
        raise RuntimeError("{} tails are not implemented.".format(tails))

    # Pad the knot derivatives with a boundary value chosen so that
    # softplus(value) + min_derivative == 1, i.e. unit slope at the edges,
    # matching the identity tails. F.pad returns a fresh tensor, so the
    # in-place writes below do not touch the caller's tensor.
    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    boundary_value = np.log(np.exp(1 - min_derivative) - 1)
    unnormalized_derivatives[..., 0] = boundary_value
    unnormalized_derivatives[..., -1] = boundary_value

    # Identity map (log|det| = 0) outside the spline interval.
    outputs[outside] = inputs[outside]
    logabsdet[outside] = 0

    outputs[inside], logabsdet[inside] = rational_quadratic_spline(
        inputs=inputs[inside],
        unnormalized_widths=unnormalized_widths[inside, :],
        unnormalized_heights=unnormalized_heights[inside, :],
        unnormalized_derivatives=unnormalized_derivatives[inside, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet
|
95 |
+
|
96 |
+
|
97 |
+
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotonic rational-quadratic spline transform (from nflows).

    Maps ``inputs`` in ``[left, right]`` to ``[bottom, top]`` — or the
    inverse map when ``inverse=True`` — and returns ``(outputs, logabsdet)``
    with the elementwise log-absolute-determinant of the Jacobian.

    Raises:
        ValueError: if any input lies outside ``[left, right]``, or if
            ``num_bins`` bins cannot fit given the minimum width/height.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Normalize widths to sum to 1 with each floored at min_bin_width, then
    # build cumulative bin edges rescaled to [left, right].
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    # Pin the outermost edges exactly to the interval boundaries.
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Strictly positive knot derivatives via softplus + floor.
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Same edge construction for heights on [bottom, top].
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # NOTE: searchsorted nudges the last edge of the tensor it receives
    # in place (by eps) before the gathers below — keep this call order.
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    # Gather the per-input bin quantities.
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths  # average slope of each bin
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Invert the rational-quadratic map by solving the quadratic
        # a*theta^2 + b*theta + c = 0 for theta (position within the bin).
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Numerically stable form of the quadratic root.
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        # Inverse transform: negate the forward log-det.
        return outputs, -logabsdet
    else:
        # Forward map: theta is the normalized position within the bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
|