sk0032 committed on
Commit
dc9b87a
1 Parent(s): 3051ba3

Upload transform.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. transform.py +202 -0
transform.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapted from https://github.com/bayesiains/nflows
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.nn import functional as F
6
+
7
# Floors applied in the spline functions below: normalized bin widths and
# heights are rescaled so no bin shrinks under these minima, and bin-edge
# derivatives are offset by DEFAULT_MIN_DERIVATIVE after the softplus,
# keeping all of them strictly positive.
DEFAULT_MIN_BIN_WIDTH = 1e-3
DEFAULT_MIN_BIN_HEIGHT = 1e-3
DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply a piecewise rational-quadratic spline transform to ``inputs``.

    Dispatches to ``rational_quadratic_spline`` when ``tails`` is None,
    otherwise to ``unconstrained_rational_quadratic_spline`` with the given
    tail configuration.

    Returns:
        A ``(outputs, logabsdet)`` pair of tensors shaped like ``inputs``.
    """
    if tails is None:
        # No tails: the spline covers the whole domain directly.
        return rational_quadratic_spline(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )
    # Tail-extended variant: behavior outside [-tail_bound, tail_bound]
    # is governed by the ``tails`` mode.
    return unconstrained_rational_quadratic_spline(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        tails=tails,
        tail_bound=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
43
+
44
+
45
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input value, the index of the bin containing it.

    Args:
        bin_locations: Tensor of monotonically increasing bin edges along
            the last dimension.
        inputs: Values to locate; shape matches ``bin_locations.shape[:-1]``.
        eps: Slack added to the last edge so an input lying exactly on the
            upper boundary falls into the final bin instead of overflowing.

    Returns:
        Integer tensor of bin indices, same shape as ``inputs``.
    """
    # Bug fix: the original widened the caller's tensor in place
    # (``bin_locations[..., -1] += eps``), silently mutating the
    # cumwidths/cumheights tensors handed in by the caller.  Work on a
    # copy instead.
    edges = bin_locations.clone()
    edges[..., -1] += eps
    # Count how many edges each input has passed; subtract one so the
    # first bin gets index 0.
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
48
+
49
+
50
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline that is the identity outside the interval.

    Inside ``[-tail_bound, tail_bound]`` the transform is the monotone
    rational-quadratic spline from ``rational_quadratic_spline``; outside,
    "linear" tails make it the identity with zero log-determinant.

    Returns:
        A ``(outputs, logabsdet)`` pair of tensors shaped like ``inputs``.

    Raises:
        RuntimeError: If ``tails`` is anything other than ``"linear"``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        # Pad so the two boundary derivatives become explicit slots, then
        # pin them to the inverse-softplus of (1 - min_derivative): after
        # the spline applies ``min_derivative + softplus(...)`` the
        # boundary slope is exactly 1, matching the identity tails.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    # Bug fix: when every input lies outside the interval the original code
    # still called rational_quadratic_spline on empty tensors, and
    # torch.min/torch.max on an empty tensor raise a RuntimeError.  Skip
    # the spline entirely in that case.
    if inside_interval_mask.any():
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )

    return outputs, logabsdet
95
+
96
+
97
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotone piecewise rational-quadratic spline mapping
    ``[left, right] -> [bottom, top]`` (adapted from nflows; see the
    "Neural Spline Flows" construction).

    Args:
        inputs: Values in ``[left, right]`` (or ``[bottom, top]`` when
            ``inverse`` is True).
        unnormalized_widths: Per-element logits over the bin widths;
            last dimension is the number of bins.
        unnormalized_heights: Per-element logits over the bin heights.
        unnormalized_derivatives: Pre-softplus derivatives at the interior
            bin edges (one more entry than bins along the last dim —
            TODO confirm against callers; ``unconstrained_...`` pads it).
        inverse: If True, evaluate the inverse transform.
        left, right, bottom, top: Domain and codomain boundaries.
        min_bin_width, min_bin_height, min_derivative: Positivity floors.

    Returns:
        ``(outputs, logabsdet)`` — transformed values and the elementwise
        log |d outputs / d inputs|.

    Raises:
        ValueError: If any input is outside the domain, or the floors are
            too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    # The floors must leave room for num_bins bins inside a unit interval.
    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Softmax the width logits, then rescale so every bin keeps at least
    # min_bin_width while the widths still sum to 1.
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    # Cumulative edges in [0, 1], prepended with 0, then affinely mapped
    # to [left, right]; endpoints are pinned exactly to kill rounding.
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Strictly positive edge derivatives.
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Same normalization for the heights / y-edges over [bottom, top].
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate each input's bin: search the y-edges for the inverse map,
    # the x-edges for the forward map.
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    # Gather the per-input bin quantities (the [..., 0] drops the
    # singleton dim added to bin_idx above).
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    # delta is each bin's average slope height/width.
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    # Derivative at the bin's left edge, and at its right edge (shifted view).
    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Invert the rational-quadratic map within the bin by solving the
        # quadratic a*theta^2 + b*theta + c = 0 for the normalized
        # position theta in [0, 1].
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        # Numerically stable quadratic root: 2c / (-b - sqrt(disc)).
        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        # log-derivative of the forward map at the recovered theta ...
        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        # ... negated, since this is the inverse transform's Jacobian.
        return outputs, -logabsdet
    else:
        # theta is the input's normalized position within its bin.
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        # Forward rational-quadratic interpolant within the bin.
        numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        # Closed-form derivative of the interpolant, in log space.
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet