BartPoint committed on
Commit
b3d2371
1 Parent(s): fb431f8

Delete infer_pack

Browse files
infer_pack/attentions.py DELETED
@@ -1,417 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from infer_pack import commons
9
- from infer_pack import modules
10
- from infer_pack.modules import LayerNorm
11
-
12
-
13
class Encoder(nn.Module):
    """Stack of ``n_layers`` self-attention + feed-forward layer pairs.

    Every sub-layer output goes through dropout, is added back to its input
    and is then normalised with ``LayerNorm``.

    Args:
        hidden_channels: channel width of the input, the output and each layer.
        filter_channels: inner channel width of each ``FFN``.
        n_heads: number of attention heads.
        n_layers: number of (attention, FFN) pairs.
        kernel_size: kernel size handed to each ``FFN``.
        p_dropout: dropout probability used by every dropout module.
        window_size: relative-position window handed to the attention layers.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=10,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Modules are created layer by layer (attention, norm, FFN, norm) so
        # that random parameter initialisation happens in a fixed order.
        for _ in range(self.n_layers):
            attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                window_size=window_size,
            )
            self.attn_layers.append(attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        """Run all layers over ``x`` and return the masked result.

        ``x_mask`` is multiplied into ``x`` before the first layer and after
        the last one.  NOTE(review): presumably ``x`` is ``[b, h, t]`` and
        ``x_mask`` is ``[b, 1, t]`` -- confirm against the callers.
        """
        # Outer product of the mask with itself: pairwise attention mask.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer_idx in range(self.n_layers):
            attended = self.drop(self.attn_layers[layer_idx](x, x, attn_mask))
            x = self.norm_layers_1[layer_idx](x + attended)

            transformed = self.drop(self.ffn_layers[layer_idx](x, x_mask))
            x = self.norm_layers_2[layer_idx](x + transformed)
        return x * x_mask
74
-
75
-
76
class Decoder(nn.Module):
    """Stack of ``n_layers`` decoder layers.

    Each layer runs self-attention (masked with ``commons.subsequent_mask``),
    attention over the encoder output ``h`` and a causal ``FFN``; each
    sub-layer output goes through dropout, a residual add and ``LayerNorm``.

    Args:
        hidden_channels: channel width of the input, the output and each layer.
        filter_channels: inner channel width of each ``FFN``.
        n_heads: number of attention heads.
        n_layers: number of decoder layers.
        kernel_size: kernel size handed to each ``FFN``.
        p_dropout: dropout probability used by every dropout module.
        proximal_bias: forwarded to the self-attention layers.
        proximal_init: forwarded to the self-attention layers.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        # Creation order is kept per layer so random initialisation is
        # consumed in a fixed sequence.
        for _ in range(self.n_layers):
            self_attention = MultiHeadAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(self_attention)
            self.norm_layers_0.append(LayerNorm(hidden_channels))

            cross_attention = MultiHeadAttention(
                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
            )
            self.encdec_attn_layers.append(cross_attention)
            self.norm_layers_1.append(LayerNorm(hidden_channels))

            feed_forward = FFN(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        seq_len = x_mask.size(2)
        self_attn_mask = commons.subsequent_mask(seq_len).to(
            device=x.device, dtype=x.dtype
        )
        # Pairwise mask between decoder positions and encoder positions.
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for layer_idx in range(self.n_layers):
            update = self.drop(self.self_attn_layers[layer_idx](x, x, self_attn_mask))
            x = self.norm_layers_0[layer_idx](x + update)

            update = self.drop(
                self.encdec_attn_layers[layer_idx](x, h, encdec_attn_mask)
            )
            x = self.norm_layers_1[layer_idx](x + update)

            update = self.drop(self.ffn_layers[layer_idx](x, x_mask))
            x = self.norm_layers_2[layer_idx](x + update)
        return x * x_mask
160
-
161
-
162
class MultiHeadAttention(nn.Module):
    """Multi-head attention over ``[b, d, t]`` tensors using 1x1 convolutions.

    Optional extras, each only valid when query and key lengths match:
    learned relative-position embeddings (``window_size``), a log-distance
    proximal bias (``proximal_bias``) and a band mask (``block_length``).

    Args:
        channels: input channel count; must be divisible by ``n_heads``.
        out_channels: output channel count of the final projection.
        n_heads: number of attention heads.
        p_dropout: dropout probability applied to the attention weights.
        window_size: if not None, half-width of the relative-position window.
        heads_share: share one relative embedding table across all heads.
        block_length: if not None, keys further than this from the query
            position are masked out.
        proximal_bias: add ``-log1p(|i - j|)`` to the attention scores.
        proximal_init: initialise the key projection as a copy of the query
            projection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Attention weights of the most recent forward pass.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys/values); store weights in ``self.attn``."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention.

        Returns ``(output, p_attn)`` with ``output`` shaped ``[b, d, t_t]`` and
        ``p_attn`` shaped ``[b, n_h, t_t, t_s]``.
        """
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s = key.size()
        t_t = query.size(2)
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        # Scale once; both the content and the relative-position logits use it.
        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                scaled_query, key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                # Band of ones around the diagonal, zeros elsewhere.
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the ``2 * length - 1`` embedding rows used for ``length`` positions.

        The table is zero-padded when ``length`` exceeds the window and sliced
        around its centre when it is shorter.
        """
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            # F.pad specs run from the last dim backwards: leave the channel
            # dim alone and pad the position dim on both sides.
            padded_relative_embeddings = F.pad(
                relative_embeddings, [0, 0, pad_length, pad_length]
            )
        else:
            padded_relative_embeddings = relative_embeddings
        return padded_relative_embeddings[:, slice_start_position:slice_end_position]

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, [0, 1])

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(x_flat, [0, length - 1])

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(x, [0, length - 1])
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, [length, 0])
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
-
361
-
362
class FFN(nn.Module):
    """Two-convolution feed-forward block with masking and dropout.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of the output.
        filter_channels: channel count between the two convolutions.
        kernel_size: kernel size of both convolutions.
        p_dropout: dropout probability applied after the activation.
        activation: ``"gelu"`` selects ``x * sigmoid(1.702 * x)``; anything
            else selects ReLU.
        causal: pad only on the left so outputs never see later time steps.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Pick the padding strategy once instead of branching in forward().
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply conv -> activation -> dropout -> conv, masking before each conv and at the end."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        projected = self.conv_2(self.padding(hidden * x_mask))
        return projected * x_mask

    def _causal_padding(self, x):
        """Left-pad the last dim by ``kernel_size - 1`` (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        # A two-element F.pad spec pads only the last dimension.
        return F.pad(x, [self.kernel_size - 1, 0])

    def _same_padding(self, x):
        """Pad the last dim so the convolution keeps the input length (no-op for kernel size 1)."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, [left, right])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/commons.py DELETED
@@ -1,166 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
-
8
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) when the class name contains "Conv".

    Modules whose class name does not contain "Conv" are left untouched.
    Intended to be passed to ``Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
12
-
13
-
14
def get_padding(kernel_size, dilation=1):
    """Return ``dilation * (kernel_size - 1) / 2`` truncated to an int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
16
-
17
-
18
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs into ``F.pad`` order.

    The input lists dimensions first-to-last; the result lists the pairs
    last-dimension-first, flattened into a single list.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
22
-
23
-
24
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)

    Element-wise; the formula matches the closed-form KL divergence between
    Gaussians when ``logs_*`` are log standard deviations.
    """
    var_p = torch.exp(2.0 * logs_p)
    mean_gap_sq = (m_p - m_q) ** 2
    inv_var_q = torch.exp(-2.0 * logs_q)
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * (var_p + mean_gap_sq) * inv_var_q
    return kl
31
-
32
-
33
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Shrink U[0, 1) to [1e-5, 0.99999) so neither logarithm sees 0.
    u = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(u)
    return -torch.log(inner)
37
-
38
-
39
def rand_gumbel_like(x):
    """Return Gumbel noise with the size, dtype and device of ``x``."""
    noise = rand_gumbel(x.size())
    return noise.to(dtype=x.dtype, device=x.device)
42
-
43
-
44
def slice_segments(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` along the last dim from each batch item.

    ``ids_str[i]`` is the start index of the window for ``x[i]``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        stop = start + segment_size
        segments[batch_idx] = x[batch_idx, :, start:stop]
    return segments
51
-
52
-
53
def slice_segments2(x, ids_str, segment_size=4):
    """Cut one window of ``segment_size`` along dim 1 from each batch item.

    ``ids_str[i]`` is the start index of the window for ``x[i]``.
    """
    segments = torch.zeros_like(x[:, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        stop = start + segment_size
        segments[batch_idx] = x[batch_idx, start:stop]
    return segments
60
-
61
-
62
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly placed window of ``segment_size`` from each batch item.

    Start indices are drawn uniformly so the window ends within
    ``x_lengths`` (the full time length when ``x_lengths`` is None).

    Returns:
        ``(segments, ids_str)`` -- the slices and the start indices used.
    """
    batch, _, time_steps = x.size()
    if x_lengths is None:
        x_lengths = time_steps
    ids_str_max = x_lengths - segment_size + 1
    uniform = torch.rand([batch]).to(device=x.device)
    ids_str = (uniform * ids_str_max).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
70
-
71
-
72
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal timing signal of shape ``[1, channels, length]``.

    The first ``channels // 2`` channels are sines and the next
    ``channels // 2`` are cosines, with timescales spaced geometrically from
    ``min_timescale`` to ``max_timescale``.  An odd ``channels`` gets one
    trailing all-zero channel.

    Args:
        length: number of positions.
        channels: number of output channels.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) there is no spacing to
    # compute; clamp the divisor so this no longer raises ZeroDivisionError.
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)
    ) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Add a zero row when channels is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
86
-
87
-
88
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal timing signal to ``x`` (``[b, channels, length]``)."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    signal = signal.to(dtype=x.dtype, device=x.device)
    return x + signal
92
-
93
-
94
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal timing signal to ``x`` along ``axis``."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    signal = signal.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, signal], axis)
98
-
99
-
100
def subsequent_mask(length):
    """Return a ``[1, 1, length, length]`` lower-triangular mask of ones."""
    lower_triangle = torch.tril(torch.ones(length, length))
    return lower_triangle.unsqueeze(0).unsqueeze(0)
103
-
104
-
105
# Gated activation: tanh of the first n_channels[0] channels of
# (input_a + input_b) times sigmoid of the remaining channels.
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    split = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split, :])
    sigmoid_part = torch.sigmoid(summed[:, split:, :])
    return tanh_part * sigmoid_part
113
-
114
-
115
def convert_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs into ``F.pad`` order.

    NOTE(review): this function is defined twice in this module with the same
    body; this later definition is the one in effect after import.
    """
    reversed_pairs = pad_shape[::-1]
    return [amount for pair in reversed_pairs for amount in pair]
119
-
120
-
121
def shift_1d(x):
    """Shift ``x`` one step to the right along the last dim, zero-filling the front."""
    pad_spec = convert_pad_shape([[0, 0], [0, 0], [1, 0]])
    padded = F.pad(x, pad_spec)
    # Drop the last step so the length is unchanged.
    return padded[:, :, :-1]
124
-
125
-
126
def sequence_mask(length, max_length=None):
    """Return a boolean mask that is True where ``position < length[i]``.

    Args:
        length: 1-D tensor of sequence lengths.
        max_length: mask width; defaults to ``length.max()``.
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
131
-
132
-
133
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Returns ``mask`` restricted to the span each x position covers on the
    y axis: the span ends at the cumulative duration and starts where the
    previous position's span ends.
    """
    b, _, t_y, t_x = mask.shape
    cum_duration_flat = torch.cumsum(duration, -1).view(b * t_x)

    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtract the previous row so each row keeps only its own span.
    shifted = F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path - shifted
    return path.unsqueeze(1).transpose(2, 3) * mask
149
-
150
-
151
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise and return their pre-clamp total norm.

    Args:
        parameters: a tensor or an iterable of tensors; those without a
            gradient are skipped.
        clip_value: gradients are clamped in place to
            ``[-clip_value, clip_value]``; None disables clamping.
        norm_type: order of the norm that is reported.

    Returns:
        The ``norm_type``-norm over all gradients, measured before clamping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for param in with_grad:
        grad = param.grad.data
        accumulated += grad.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            grad.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1.0 / norm_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models.py DELETED
@@ -1,1116 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
class TextEncoder256(nn.Module):
    """Encode 256-dimensional phone features (plus optional pitch ids).

    Features are linearly projected, optionally summed with a pitch
    embedding, passed through ``attentions.Encoder`` and projected to
    ``2 * out_channels`` channels that are split into ``m`` and ``logs``.

    Args:
        out_channels: channel count of each of ``m`` and ``logs``.
        hidden_channels: width of the embeddings and the encoder.
        filter_channels: FFN width inside the encoder.
        n_heads: number of attention heads.
        n_layers: number of encoder layers.
        kernel_size: FFN kernel size inside the encoder.
        p_dropout: dropout probability inside the encoder.
        f0: when true, create the pitch embedding table (256 entries).
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)``.

        Args:
            phone: phone features whose last dim is 256.
            pitch: pitch ids for ``emb_pitch``, or None to skip the pitch
                embedding.
            lengths: valid length of each sequence, used to build ``x_mask``.
        """
        x = self.emb_phone(phone)
        # Identity test: ``== None`` is not a reliable None check for tensors.
        if pitch is not None:
            x = x + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
62
class TextEncoder768(nn.Module):
    """Encode 768-dimensional phone features (plus optional pitch ids).

    Features are linearly projected, optionally summed with a pitch
    embedding, passed through ``attentions.Encoder`` and projected to
    ``2 * out_channels`` channels that are split into ``m`` and ``logs``.

    Args:
        out_channels: channel count of each of ``m`` and ``logs``.
        hidden_channels: width of the embeddings and the encoder.
        filter_channels: FFN width inside the encoder.
        n_heads: number of attention heads.
        n_layers: number of encoder layers.
        kernel_size: FFN kernel size inside the encoder.
        p_dropout: dropout probability inside the encoder.
        f0: when true, create the pitch embedding table (256 entries).
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(768, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)``.

        Args:
            phone: phone features whose last dim is 768.
            pitch: pitch ids for ``emb_pitch``, or None to skip the pitch
                embedding.
            lengths: valid length of each sequence, used to build ``x_mask``.
        """
        x = self.emb_phone(phone)
        # Identity test: ``== None`` is not a reliable None check for tensors.
        if pitch is not None:
            x = x + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
107
-
108
class ResidualCouplingBlock(nn.Module):
    """Chain of ``n_flows`` (``ResidualCouplingLayer``, ``Flip``) pairs.

    Args:
        channels: channel count of the tensor being transformed.
        hidden_channels: hidden width of each coupling layer.
        kernel_size: kernel size of each coupling layer.
        dilation_rate: dilation rate of each coupling layer.
        n_layers: layer count inside each coupling layer.
        n_flows: number of (coupling, flip) pairs.
        gin_channels: channel count of the conditioning input ``g``.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reverse order when ``reverse`` is true."""
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x
        for flow in self.flows:
            # The second value returned by each forward flow is discarded.
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Call ``remove_weight_norm`` on every coupling layer (even indices of ``flows``)."""
        for flow_idx in range(self.n_flows):
            coupling = self.flows[flow_idx * 2]
            coupling.remove_weight_norm()
155
-
156
-
157
class PosteriorEncoder(nn.Module):
    """Encode ``x`` into ``m`` and ``logs`` and draw a sample ``z`` from them.

    Args:
        in_channels: channel count of the input.
        out_channels: channel count of each of ``m``, ``logs`` and ``z``.
        hidden_channels: width of the ``modules.WN`` encoder.
        kernel_size: kernel size of the ``WN`` encoder.
        dilation_rate: dilation rate of the ``WN`` encoder.
        n_layers: layer count of the ``WN`` encoder.
        gin_channels: channel count of the conditioning input ``g``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` where ``z = m + noise * exp(logs)``, masked."""
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)
        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Remove weight normalisation from the ``WN`` encoder."""
        self.enc.remove_weight_norm()
200
-
201
-
202
class Generator(torch.nn.Module):
    """Upsampling generator: transposed convolutions interleaved with resblocks.

    Args:
        initial_channel: channel count of the input.
        resblock: ``"1"`` selects ``modules.ResBlock1``; anything else
            selects ``modules.ResBlock2``.
        resblock_kernel_sizes: kernel size of each resblock in a stage.
        resblock_dilation_sizes: dilations of each resblock in a stage.
        upsample_rates: stride of each transposed convolution.
        upsample_initial_channel: channel count after ``conv_pre``; halved at
            every upsampling stage.
        upsample_kernel_sizes: kernel size of each transposed convolution.
        gin_channels: channel count of the conditioning input; 0 disables
            the conditioning convolution.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2**i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            upsample = ConvTranspose1d(
                in_ch, out_ch, kernel, rate, padding=(kernel - rate) // 2
            )
            self.ups.append(weight_norm(upsample))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Upsample ``x`` stage by stage and return the tanh-squashed single-channel output."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's resblocks.
            first_block = i * self.num_kernels
            xs = None
            for j in range(self.num_kernels):
                block_out = self.resblocks[first_block + j](x)
                if xs is None:
                    xs = block_out
                else:
                    xs += block_out
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Remove weight normalisation from all upsampling layers and resblocks."""
        for upsample in self.ups:
            remove_weight_norm(upsample)
        for block in self.resblocks:
            block.remove_weight_norm()
276
-
277
-
278
class SineGen(torch.nn.Module):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: accepted but not used by this class (default False)
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # One channel for the fundamental plus one per overtone.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a tensor shaped like ``f0``: 1 where ``f0`` exceeds the voiced threshold, else 0."""
        voiced = f0 > self.voiced_threshold
        return torch.ones_like(f0) * voiced

    def forward(self, f0, upp):
        """sine_tensor, uv, noise = forward(f0, upp)
        input F0: tensor(batchsize, length); f0 for unvoiced steps should be 0
        upp: integer upsampling factor applied along the length axis
        output sine_tensor: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: tensor with the shape of sine_tensor
        """
        with torch.no_grad():
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
            # Per-frame phase increment in cycles. Taking % 1 here means the
            # harmonic multiplication cannot be optimised in post-processing.
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase for every overtone; the fundamental gets 0.
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # No % 1 here: applying it would prevent optimising the later cumsum.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=upp,
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            tmp_over_one %= 1
            # True where the wrapped phase decreased, i.e. a wrap happened.
            wrapped = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = wrapped * -1.0
            phase = torch.cumsum(rad_values + cumsum_shift, dim=1)
            sine_waves = torch.sin(phase * 2 * np.pi)
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(
                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            # Voiced frames get noise_std; unvoiced frames get sine_amp / 3.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
368
-
369
-
370
class SourceModuleHnNSF(torch.nn.Module):
    """SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    is_half: cast the sine waves to half precision before merging
        (default: True)

    forward() returns (sine_merge, None, None): only the merged sine source
    is produced; the noise and uv slots are always None.
    """

    def __init__(
        self,
        sampling_rate,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        is_half=True,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.is_half = is_half
        # to produce sine waveforms
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp=None):
        """Generate sine waves from ``x`` and merge the harmonics into one channel."""
        sine_wavs, _uv, _noise = self.l_sin_gen(x, upp)
        if self.is_half:
            sine_wavs = sine_wavs.half()
        merged = self.l_linear(sine_wavs)
        sine_merge = self.l_tanh(merged)
        return sine_merge, None, None  # noise, uv
417
-
418
-
419
class GeneratorNSF(torch.nn.Module):
    """Upsampling generator that mixes an F0-driven source into every stage.

    Each stage applies a weight-normalised transposed convolution, adds a
    strided projection of the harmonic source, and averages the outputs of
    ``num_kernels`` residual blocks.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels,
        sr,
        is_half=False,
    ):
        """Build the source module, upsampling stack and residual blocks.

        ``resblock`` selects ``modules.ResBlock1`` when it equals ``"1"`` and
        ``modules.ResBlock2`` otherwise. A conditioning convolution
        (``self.cond``) is only created when ``gin_channels`` is non-zero.
        """
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sr, harmonic_num=0, is_half=is_half
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        stage_count = len(upsample_rates)
        for stage, (rate, width) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            in_ch = upsample_initial_channel // (2**stage)
            out_ch = upsample_initial_channel // (2 ** (stage + 1))
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        in_ch, out_ch, width, rate, padding=(width - rate) // 2
                    )
                )
            )
            if stage + 1 < stage_count:
                # Product of the upsampling rates that still follow this stage;
                # used as the stride of the source projection.
                remaining = np.prod(upsample_rates[stage + 1 :])
                source_proj = Conv1d(
                    1,
                    out_ch,
                    kernel_size=remaining * 2,
                    stride=remaining,
                    padding=remaining // 2,
                )
            else:
                # Last stage: the source projection is a pointwise convolution.
                source_proj = Conv1d(1, out_ch, kernel_size=1)
            self.noise_convs.append(source_proj)

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for width, dilation in zip(
                resblock_kernel_sizes, resblock_dilation_sizes
            ):
                self.resblocks.append(block_cls(ch, width, dilation))

        # ``ch`` still holds the channel count of the final stage here.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        self.upp = np.prod(upsample_rates)

    def forward(self, x, f0, g=None):
        """Synthesise a waveform from latent ``x`` and pitch ``f0``.

        Args:
            x: latent input fed to ``conv_pre``.
            f0: pitch input handed to the source module together with
                ``self.upp``.
            g: optional conditioning; when given, ``self.cond(g)`` is added
                after ``conv_pre``.

        Returns:
            ``tanh`` of the final convolution output.
        """
        har_source, _, _ = self.m_source(f0, self.upp)
        # The source projections are Conv1d layers, so swap axes 1 and 2 first.
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage, (upsample, source_proj) in enumerate(
            zip(self.ups, self.noise_convs)
        ):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = upsample(x)
            x = x + source_proj(har_source)

            first = stage * self.num_kernels
            mixed = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](x)
                if mixed is None:
                    mixed = branch
                else:
                    mixed += branch
            x = mixed / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from the upsamplers and residual blocks."""
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
520
-
521
-
522
# Maps a sample-rate label such as "40k" to the rate in Hz.
sr2sr = {f"{khz}k": khz * 1000 for khz in (32, 40, 48)}
527
-
528
-
529
class SynthesizerTrnMs256NSFsid(nn.Module):
    """Synthesizer wiring TextEncoder256, GeneratorNSF, a posterior encoder and a flow.

    The constructor stores its hyper-parameters as attributes and builds:

    * ``enc_p`` - ``TextEncoder256`` prior encoder,
    * ``dec``   - ``GeneratorNSF`` decoder,
    * ``enc_q`` - ``PosteriorEncoder``,
    * ``flow``  - ``ResidualCouplingBlock``,
    * ``emb_g`` - speaker embedding with ``spk_embed_dim`` entries.

    ``sr`` may be an integer or a string key of ``sr2sr`` (e.g. ``"40k"``).
    ``kwargs`` must contain ``"is_half"``; a missing key raises ``KeyError``.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        # isinstance instead of an exact type comparison, so that str
        # subclasses are resolved through the lookup table as well.
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs["is_half"],
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Remove weight normalisation from the decoder, flow and posterior encoder."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(
        self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
    ):  # ds is the speaker id; upstream note gives its shape as [bs, 1]
        """Training pass: encode, run the flow, decode a random latent slice.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
        """
        # unsqueeze(-1) appends a singleton axis (upstream note: it is the
        # time axis and gets broadcast).
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        # Slice the pitch track with the same segment ids as the latent.
        pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
        o = self.dec(z_slice, pitchf, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
        """Inference pass: sample the prior, invert the flow, decode.

        The prior sample is ``m_p + exp(logs_p) * noise * 0.66666`` masked by
        ``x_mask``; the latent is truncated to ``max_len`` on its last axis
        before decoding (``None`` keeps everything).

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
638
class SynthesizerTrnMs768NSFsid(nn.Module):
    """Synthesizer wiring TextEncoder768, GeneratorNSF, a posterior encoder and a flow.

    The constructor stores its hyper-parameters as attributes and builds:

    * ``enc_p`` - ``TextEncoder768`` prior encoder,
    * ``dec``   - ``GeneratorNSF`` decoder,
    * ``enc_q`` - ``PosteriorEncoder``,
    * ``flow``  - ``ResidualCouplingBlock``,
    * ``emb_g`` - speaker embedding with ``spk_embed_dim`` entries.

    ``sr`` may be an integer or a string key of ``sr2sr`` (e.g. ``"40k"``).
    ``kwargs`` must contain ``"is_half"``; a missing key raises ``KeyError``.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        # isinstance instead of an exact type comparison, so that str
        # subclasses are resolved through the lookup table as well.
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder768(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs["is_half"],
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Remove weight normalisation from the decoder, flow and posterior encoder."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(
        self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
    ):  # ds is the speaker id; upstream note gives its shape as [bs, 1]
        """Training pass: encode, run the flow, decode a random latent slice.

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
        """
        # unsqueeze(-1) appends a singleton axis (upstream note: it is the
        # time axis and gets broadcast).
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        # Slice the pitch track with the same segment ids as the latent.
        pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
        o = self.dec(z_slice, pitchf, g=g)
        return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)

    def infer(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
        """Inference pass: sample the prior, invert the flow, decode.

        The prior sample is ``m_p + exp(logs_p) * noise * 0.66666`` masked by
        ``x_mask``; the latent is truncated to ``max_len`` on its last axis
        before decoding (``None`` keeps everything).

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
        return o, x_mask, (z, z_p, m_p, logs_p)
747
-
748
-
749
class SynthesizerTrnMs256NSFsid_nono(nn.Module):
    """Pitch-free synthesizer: TextEncoder256 (f0=False) with a plain Generator.

    ``sr`` and any extra keyword arguments are accepted but not used by the
    constructor.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr=None,
        **kwargs
    ):
        super().__init__()

        # Keep every hyper-parameter available as an attribute of the same name.
        hyper_params = {
            "spec_channels": spec_channels,
            "inter_channels": inter_channels,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "n_heads": n_heads,
            "n_layers": n_layers,
            "kernel_size": kernel_size,
            "p_dropout": p_dropout,
            "resblock": resblock,
            "resblock_kernel_sizes": resblock_kernel_sizes,
            "resblock_dilation_sizes": resblock_dilation_sizes,
            "upsample_rates": upsample_rates,
            "upsample_initial_channel": upsample_initial_channel,
            "upsample_kernel_sizes": upsample_kernel_sizes,
            "segment_size": segment_size,
            "gin_channels": gin_channels,
            "spk_embed_dim": spk_embed_dim,
        }
        for attr_name, attr_value in hyper_params.items():
            setattr(self, attr_name, attr_value)

        prior_args = (
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.enc_p = TextEncoder256(*prior_args, f0=False)

        decoder_args = (
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
        )
        self.dec = Generator(*decoder_args, gin_channels=gin_channels)

        self.enc_q = PosteriorEncoder(
            spec_channels, inter_channels, hidden_channels, 5, 1, 16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Remove weight normalisation from the decoder, flow and posterior encoder."""
        for part in (self.dec, self.flow, self.enc_q):
            part.remove_weight_norm()

    def forward(self, phone, phone_lengths, y, y_lengths, ds):
        """Training pass without pitch input.

        ``ds`` is the speaker id (upstream note gives its shape as [bs, 1]).

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
        """
        # Trailing singleton axis is appended to the speaker embedding.
        speaker = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=speaker)
        z_p = self.flow(z, y_mask, g=speaker)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=speaker)
        latents = (z, z_p, m_p, logs_p, m_q, logs_q)
        return o, ids_slice, x_mask, y_mask, latents

    def infer(self, phone, phone_lengths, sid, max_len=None):
        """Inference pass: sample the prior, invert the flow, decode.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``
        """
        speaker = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        scaled_noise = torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666
        z_p = (m_p + scaled_noise) * x_mask
        z = self.flow(z_p, x_mask, g=speaker, reverse=True)
        masked = z * x_mask
        o = self.dec(masked[:, :, :max_len], g=speaker)
        return o, x_mask, (z, z_p, m_p, logs_p)
849
class SynthesizerTrnMs768NSFsid_nono(nn.Module):
    """Pitch-free synthesizer: TextEncoder768 (f0=False) with a plain Generator.

    ``sr`` and any extra keyword arguments are accepted but not used by the
    constructor.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr=None,
        **kwargs
    ):
        super().__init__()

        # Keep every hyper-parameter available as an attribute of the same name.
        hyper_params = {
            "spec_channels": spec_channels,
            "inter_channels": inter_channels,
            "hidden_channels": hidden_channels,
            "filter_channels": filter_channels,
            "n_heads": n_heads,
            "n_layers": n_layers,
            "kernel_size": kernel_size,
            "p_dropout": p_dropout,
            "resblock": resblock,
            "resblock_kernel_sizes": resblock_kernel_sizes,
            "resblock_dilation_sizes": resblock_dilation_sizes,
            "upsample_rates": upsample_rates,
            "upsample_initial_channel": upsample_initial_channel,
            "upsample_kernel_sizes": upsample_kernel_sizes,
            "segment_size": segment_size,
            "gin_channels": gin_channels,
            "spk_embed_dim": spk_embed_dim,
        }
        for attr_name, attr_value in hyper_params.items():
            setattr(self, attr_name, attr_value)

        prior_args = (
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.enc_p = TextEncoder768(*prior_args, f0=False)

        decoder_args = (
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
        )
        self.dec = Generator(*decoder_args, gin_channels=gin_channels)

        self.enc_q = PosteriorEncoder(
            spec_channels, inter_channels, hidden_channels, 5, 1, 16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Remove weight normalisation from the decoder, flow and posterior encoder."""
        for part in (self.dec, self.flow, self.enc_q):
            part.remove_weight_norm()

    def forward(self, phone, phone_lengths, y, y_lengths, ds):
        """Training pass without pitch input.

        ``ds`` is the speaker id (upstream note gives its shape as [bs, 1]).

        Returns:
            ``(o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q))``
        """
        # Trailing singleton axis is appended to the speaker embedding.
        speaker = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=speaker)
        z_p = self.flow(z, y_mask, g=speaker)
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=speaker)
        latents = (z, z_p, m_p, logs_p, m_q, logs_q)
        return o, ids_slice, x_mask, y_mask, latents

    def infer(self, phone, phone_lengths, sid, max_len=None):
        """Inference pass: sample the prior, invert the flow, decode.

        Returns:
            ``(o, x_mask, (z, z_p, m_p, logs_p))``
        """
        speaker = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
        scaled_noise = torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666
        z_p = (m_p + scaled_noise) * x_mask
        z = self.flow(z_p, x_mask, g=speaker, reverse=True)
        masked = z * x_mask
        o = self.dec(masked[:, :, :max_len], g=speaker)
        return o, x_mask, (z, z_p, m_p, logs_p)
949
-
950
-
951
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of one DiscriminatorS plus one DiscriminatorP per period."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11, 17]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members.extend(
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in periods
        )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Returns:
            Four lists ordered like ``self.discriminators``: outputs for
            ``y``, outputs for ``y_hat``, feature maps for ``y``, feature
            maps for ``y_hat``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            out_r, feats_r = disc(y)
            out_g, feats_g = disc(y_hat)
            y_d_rs.append(out_r)
            y_d_gs.append(out_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
979
-
980
class MultiPeriodDiscriminatorV2(torch.nn.Module):
    """Like MultiPeriodDiscriminator, with the extra periods 23 and 37."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11, 17, 23, 37]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        members.extend(
            DiscriminatorP(period, use_spectral_norm=use_spectral_norm)
            for period in periods
        )
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Returns:
            Four lists ordered like ``self.discriminators``: outputs for
            ``y``, outputs for ``y_hat``, feature maps for ``y``, feature
            maps for ``y_hat``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            out_r, feats_r = disc(y)
            out_g, feats_g = disc(y_hat)
            y_d_rs.append(out_r)
            y_d_gs.append(out_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1008
-
1009
-
1010
class DiscriminatorS(torch.nn.Module):
    """Stack of 1-D convolutions that scores a signal and exposes feature maps."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Equality (not truthiness) is kept on purpose so that non-bool
        # arguments select the same normalisation as before.
        if use_spectral_norm == False:  # noqa: E712
            wrap = weight_norm
        else:
            wrap = spectral_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                wrap(Conv1d(c_in, c_out, width, step, groups=groups, padding=pad))
                for c_in, c_out, width, step, groups, pad in layer_specs
            ]
        )
        self.conv_post = wrap(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_output, fmap)``.

        ``fmap`` collects the activation after every leaky-ReLU'd convolution
        followed by the raw output of ``conv_post``.
        """
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
1038
-
1039
-
1040
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the signal by ``period`` and applies 2-D convs."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Equality (not truthiness) is kept on purpose so that non-bool
        # arguments select the same normalisation as before.
        if use_spectral_norm == False:  # noqa: E712
            wrap = weight_norm
        else:
            wrap = spectral_norm

        # (in_channels, out_channels, stride); only the last layer uses the
        # plain stride 1, the others stride along the first spatial axis.
        layer_specs = [
            (1, 32, (stride, 1)),
            (32, 128, (stride, 1)),
            (128, 512, (stride, 1)),
            (512, 1024, (stride, 1)),
            (1024, 1024, 1),
        ]
        self.convs = nn.ModuleList(
            [
                wrap(
                    Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        step,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
                for c_in, c_out, step in layer_specs
            ]
        )
        self.conv_post = wrap(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_output, fmap)`` for a 3-D input ``x``.

        The last axis is reflect-padded up to a multiple of ``self.period``
        and then reshaped to ``(b, c, t // period, period)``.
        """
        fmap = []

        # 1d to 2d
        batch, channels, length = x.shape
        leftover = length % self.period
        if leftover != 0:  # pad first
            n_pad = self.period - leftover
            x = F.pad(x, (0, n_pad), "reflect")
            length = length + n_pad
        x = x.view(batch, channels, length // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)

        return torch.flatten(x, 1, -1), fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models_onnx.py DELETED
@@ -1,760 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
class TextEncoder256(nn.Module):
    """Encode 256-dimensional phone features (and optional pitch ids) into prior statistics.

    The phone features go through a linear layer, the pitch ids (when given)
    through an embedding table; their sum is scaled, passed through a leaky
    ReLU and an attention encoder, and finally projected to ``2 * out_channels``
    channels that are split into a mean and a log-scale tensor.

    Args:
        out_channels: number of channels of each of the two returned statistics.
        hidden_channels: width of the embeddings and of the attention encoder.
        filter_channels, n_heads, n_layers, kernel_size, p_dropout: forwarded
            unchanged to ``attentions.Encoder``.
        f0: when truthy, a pitch embedding table with 256 entries is created.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            # Only models that use pitch carry this table; forward() must then
            # be called with pitch=None on models built with f0=False.
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)`` for a batch of phone features.

        Args:
            phone: features whose last dimension is 256 (input of ``emb_phone``).
            pitch: pitch ids for ``emb_pitch``, or ``None`` to skip the pitch term.
            lengths: per-item valid lengths used to build the sequence mask.
        """
        # Identity test: ``pitch == None`` would go through the tensor's own
        # ``__eq__`` instead of simply checking whether pitch was supplied.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        # First out_channels channels are the mean, the rest the log-scale.
        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
62
-
63
-
64
class TextEncoder256Sim(nn.Module):
    """Variant of the 256-dimensional phone encoder that returns one projected tensor.

    The pipeline (linear phone embedding, optional pitch embedding, scaling,
    leaky ReLU, attention encoder, 1x1 convolution) is the same as in the
    two-statistics encoder, but the projection has ``out_channels`` outputs
    and is returned as-is together with the mask.

    Args:
        out_channels: number of channels of the projected output.
        hidden_channels: width of the embeddings and of the attention encoder.
        filter_channels, n_heads, n_layers, kernel_size, p_dropout: forwarded
            unchanged to ``attentions.Encoder``.
        f0: when truthy, a pitch embedding table with 256 entries is created.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            # Only models that use pitch carry this table; forward() must then
            # be called with pitch=None on models built with f0=False.
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(x, x_mask)``: the masked projection and the sequence mask.

        Args:
            phone: features whose last dimension is 256 (input of ``emb_phone``).
            pitch: pitch ids for ``emb_pitch``, or ``None`` to skip the pitch term.
            lengths: per-item valid lengths used to build the sequence mask.
        """
        # Identity test: ``pitch == None`` would go through the tensor's own
        # ``__eq__`` instead of simply checking whether pitch was supplied.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        x = self.proj(x) * x_mask
        return x, x_mask
107
-
108
-
109
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` (coupling layer, flip) pairs.

    Each pair consists of a ``modules.ResidualCouplingLayer`` built with
    ``mean_only=True`` followed by a ``modules.Flip``. ``forward`` walks the
    stack front to back, or back to front when ``reverse`` is set.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run ``x`` through every flow; in reverse mode the order is inverted."""
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x
        for layer in self.flows:
            # In the forward direction each flow returns a pair; only the
            # transformed tensor is kept.
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from the coupling layers (the even positions)."""
        for position in range(0, 2 * self.n_flows, 2):
            self.flows[position].remove_weight_norm()
156
-
157
-
158
class PosteriorEncoder(nn.Module):
    """Map an input tensor to a sampled latent plus its mean and log-scale.

    A 1x1 convolution lifts the input to ``hidden_channels``, a ``modules.WN``
    stack processes it, and a second 1x1 convolution produces
    ``2 * out_channels`` channels that are split into mean and log-scale.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` where ``z`` is a masked random sample."""
        valid = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(valid, 1).to(x.dtype)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Sample around the mean: standard-normal noise scaled by exp(logs).
        eps = torch.randn_like(m)
        z = (m + eps * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Delegate weight-norm removal to the WN stack."""
        self.enc.remove_weight_norm()
201
-
202
-
203
class Generator(torch.nn.Module):
    """Upsampling decoder built from transposed convolutions and residual blocks.

    Every upsampling stage halves the channel count and is followed by
    ``len(resblock_kernel_sizes)`` residual blocks whose outputs are averaged.
    A final convolution reduces to one channel and ``tanh`` bounds the output.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # "1" selects ResBlock1; any other value selects ResBlock2.
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for stage, (rate, width) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            wide = upsample_initial_channel // (2**stage)
            narrow = upsample_initial_channel // (2 ** (stage + 1))
            transposed = ConvTranspose1d(
                wide, narrow, width, rate, padding=(width - rate) // 2
            )
            self.ups.append(weight_norm(transposed))

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x``; ``g`` (if given) is projected by ``cond`` and added first."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            offset = stage * self.num_kernels
            total = None
            for j in range(self.num_kernels):
                branch = self.resblocks[offset + j](x)
                if total is None:
                    total = branch
                else:
                    total += branch
            # Average the parallel residual branches of this stage.
            x = total / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from the upsampling layers and the residual blocks."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
277
-
278
-
279
class SineGen(torch.nn.Module):
    """Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    NOTE(review): flag_for_pulse is accepted by __init__ but is neither stored
        nor read anywhere in this class, so the note above does not describe
        anything this implementation does.
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # one channel for the fundamental plus one per overtone
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal: 1 where f0 is above the voiced threshold, else 0
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def forward(self, f0, upp):
        """sine_waves, uv, noise = forward(f0, upp)
        input f0: F0 values; f0 for unvoiced steps should be 0.
            A trailing axis is appended internally, so the input appears to
            be expected as (batchsize, length) - TODO confirm against callers.
        input upp: upsampling factor applied along the length axis
        output sine_waves: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: the additive noise term, same shape as sine_waves
        Everything is computed under torch.no_grad().
        """
        with torch.no_grad():
            # append a trailing axis: (batch, length) -> (batch, length, 1)
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in np.arange(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
                    idx + 2
                )  # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
            # per-frame phase increment in cycles; the "% 1" here means the
            # multiplication by the harmonic number cannot be optimised in
            # post-processing
            rad_values = (f0_buf / self.sampling_rate) % 1
            # random initial phase per channel; the fundamental (channel 0) gets none
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # accumulated phase at frame rate; applying "% 1" here would mean
            # the later cumsum could no longer be optimised
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            # upsample the accumulated phase (linear) and the increments (nearest)
            # by a factor of upp along the length axis
            tmp_over_one = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=upp,
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(
                2, 1
            )
            tmp_over_one %= 1
            # True where the wrapped phase decreased, i.e. it crossed a whole cycle
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            # subtract one full cycle at every wrap position before the final cumsum
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
            )
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(
                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            # voiced steps get noise_std, unvoiced steps get sine_amp / 3
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            # unvoiced steps keep only the noise
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
369
-
370
-
371
class SourceModuleHnNSF(torch.nn.Module):
    """Source module that merges generated sine harmonics into one excitation.

    SourceModuleHnNSF(sampling_rate, harmonic_num=0, sine_amp=0.1,
                      add_noise_std=0.003, voiced_threshod=0, is_half=True)
    sampling_rate: sampling rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of the sine source signal (default: 0.1)
    add_noise_std: std of the additive Gaussian noise (default: 0.003)
    voiced_threshod: threshold passed to the sine generator for the U/V
        decision (default: 0)
    is_half: when true, the sine waves are cast to half precision before the
        linear merge (default: True)

    forward() returns ``(sine_merge, None, None)``; the two ``None`` entries
    stand in for the noise source and the uv signal, which are not returned.
    """

    def __init__(
        self,
        sampling_rate,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        is_half=True,
    ):
        super().__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.is_half = is_half

        # produces the sine waveforms
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # collapses the (harmonic_num + 1) harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp=None):
        """Generate the sine waves for ``x`` and merge them with linear + tanh."""
        harmonics, _uv, _noise = self.l_sin_gen(x, upp)
        if self.is_half:
            harmonics = harmonics.half()
        merged = self.l_linear(harmonics)
        return self.l_tanh(merged), None, None  # noise, uv
418
-
419
-
420
class GeneratorNSF(torch.nn.Module):
    """Upsampling decoder that injects a sine-based source signal at every stage.

    The harmonic source produced by ``SourceModuleHnNSF`` from ``f0`` is brought
    to the resolution of each upsampling stage by a dedicated convolution and
    added to the stage output before the residual blocks.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels,
        sr,
        is_half=False,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sr, harmonic_num=0, is_half=is_half
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # "1" selects ResBlock1; any other value selects ResBlock2.
        block_cls = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for stage, (rate, width) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            wide = upsample_initial_channel // (2**stage)
            c_cur = upsample_initial_channel // (2 ** (stage + 1))
            transposed = ConvTranspose1d(
                wide, c_cur, width, rate, padding=(width - rate) // 2
            )
            self.ups.append(weight_norm(transposed))

            if stage + 1 >= len(upsample_rates):
                # Last stage: the source is already at output resolution.
                source_conv = Conv1d(1, c_cur, kernel_size=1)
            else:
                # Stride equals the product of the upsampling rates still to come.
                stride_f0 = np.prod(upsample_rates[stage + 1 :])
                source_conv = Conv1d(
                    1,
                    c_cur,
                    kernel_size=stride_f0 * 2,
                    stride=stride_f0,
                    padding=stride_f0 // 2,
                )
            self.noise_convs.append(source_conv)

        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # ``ch`` is the channel count of the last upsampling stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        # Total upsampling factor, handed to the source module in forward().
        self.upp = np.prod(upsample_rates)

    def forward(self, x, f0, g=None):
        """Decode ``x`` using the source built from ``f0``; ``g`` is optional conditioning."""
        har_source, _noi_source, _uv = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            x = x + self.noise_convs[stage](har_source)
            offset = stage * self.num_kernels
            total = None
            for j in range(self.num_kernels):
                branch = self.resblocks[offset + j](x)
                if total is None:
                    total = branch
                else:
                    total += branch
            # Average the parallel residual branches of this stage.
            x = total / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip weight norm from the upsampling layers and the residual blocks."""
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
521
-
522
-
523
# Maps the textual sample-rate tags ("32k", "40k", "48k") to their integer values.
sr2sr = {f"{khz}k": khz * 1000 for khz in (32, 40, 48)}
528
-
529
-
530
class SynthesizerTrnMs256NSFsidO(nn.Module):
    """Synthesizer wiring a phone/pitch encoder, a flow and an NSF decoder.

    ``forward`` embeds the speaker id, encodes phone and pitch into prior
    statistics, samples a latent, runs the flow in reverse and decodes the
    result together with the F0 track.

    Args:
        spec_channels: input channels of the posterior encoder.
        segment_size: stored on the instance; not used by this class itself.
        inter_channels: latent width shared by encoder, flow and decoder.
        hidden_channels, filter_channels, n_heads, n_layers, kernel_size,
            p_dropout: forwarded to ``TextEncoder256``.
        resblock, resblock_kernel_sizes, resblock_dilation_sizes,
            upsample_rates, upsample_initial_channel, upsample_kernel_sizes:
            forwarded to ``GeneratorNSF``.
        spk_embed_dim: number of entries of the speaker embedding table.
        gin_channels: width of the speaker embedding.
        sr: sample rate, either a number or one of the keys of ``sr2sr``.
        **kwargs: must contain ``is_half``, which is forwarded to ``GeneratorNSF``.

    Raises:
        KeyError: if ``sr`` is a string that is not a key of ``sr2sr``, or if
            ``is_half`` is missing from ``kwargs``.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        # isinstance (rather than an exact type comparison) also accepts str
        # subclasses as sample-rate tags.
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs["is_half"],
        )
        # The posterior encoder is constructed and registered here but is not
        # called by forward().
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Strip weight norm from the decoder, the flow and the posterior encoder."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(self, phone, phone_lengths, pitch, nsff0, sid, max_len=None):
        """Synthesize output from phone features, pitch ids, F0 and speaker id.

        ``max_len`` (if given) truncates the latent along its last axis before
        decoding.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample around the prior mean; the noise is scaled down by 0.66666.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
        return o
622
-
623
-
624
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` and one ``DiscriminatorP`` per period."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        periods = [2, 3, 5, 7, 11, 17]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(DiscriminatorP(period, use_spectral_norm=use_spectral_norm))
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Score ``y`` and ``y_hat`` with every sub-discriminator.

        Returns four lists, each with one entry per sub-discriminator: scores
        for ``y``, scores for ``y_hat``, feature maps for ``y`` and feature
        maps for ``y_hat``.
        """
        y_d_rs, y_d_gs = [], []
        fmap_rs, fmap_gs = [], []
        for disc in self.discriminators:
            score_r, feats_r = disc(y)
            score_g, feats_g = disc(y_hat)
            y_d_rs.append(score_r)
            y_d_gs.append(score_g)
            fmap_rs.append(feats_r)
            fmap_gs.append(feats_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
652
-
653
-
654
class DiscriminatorS(torch.nn.Module):
    """Discriminator made of six normalised 1-D convolutions and a final 1-channel one."""

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        # Equality test kept as in the original so that only a value equal to
        # False selects weight norm.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        stack = [norm_f(Conv1d(1, 16, 15, 1, padding=7))]
        # Grouped, stride-4 convolutions: (in, out, groups).
        for cin, cout, groups in (
            (16, 64, 4),
            (64, 256, 16),
            (256, 1024, 64),
            (1024, 1024, 256),
        ):
            stack.append(norm_f(Conv1d(cin, cout, 41, 4, groups=groups, padding=20)))
        stack.append(norm_f(Conv1d(1024, 1024, 5, 1, padding=2)))
        self.convs = nn.ModuleList(stack)
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return the flattened score and the list of intermediate feature maps."""
        fmap = []

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        # Collapse everything except the batch axis.
        x = torch.flatten(x, 1, -1)

        return x, fmap
682
-
683
-
684
class DiscriminatorP(torch.nn.Module):
    """Discriminator that folds the last axis into rows of ``period`` samples.

    The folded input is processed by five normalised 2-D convolutions whose
    kernels span ``kernel_size`` rows and a single column, followed by a
    final 1-channel convolution.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        # Equality test kept as in the original so that only a value equal to
        # False selects weight norm.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # Four strided layers that widen the channels, then one stride-1 layer.
        stack = []
        for cin, cout in ((1, 32), (32, 128), (128, 512), (512, 1024)):
            stack.append(
                norm_f(
                    Conv2d(
                        cin,
                        cout,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
            )
        stack.append(
            norm_f(
                Conv2d(
                    1024,
                    1024,
                    (kernel_size, 1),
                    1,
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
        )
        self.convs = nn.ModuleList(stack)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return the flattened score and the list of intermediate feature maps."""
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        leftover = t % self.period
        if leftover != 0:  # pad first so the length is a multiple of the period
            n_pad = self.period - leftover
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        # Collapse everything except the batch axis.
        x = torch.flatten(x, 1, -1)

        return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/models_onnx_moess.py DELETED
@@ -1,849 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from infer_pack import modules
7
- from infer_pack import attentions
8
- from infer_pack import commons
9
- from infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from infer_pack.commons import init_weights
13
- import numpy as np
14
- from infer_pack import commons
15
-
16
-
17
class TextEncoder256(nn.Module):
    """Encode 256-dim phone features (plus optional pitch) into mean/log-scale stats.

    Args:
        out_channels: number of channels in each of the two returned statistics.
        hidden_channels: width of the embeddings and of the attention encoder.
        filter_channels, n_heads, n_layers, kernel_size, p_dropout: forwarded
            unchanged to ``attentions.Encoder``.
        f0: when truthy, a pitch embedding table with 256 entries is created.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        # Twice the output width: the projection is split into mean and log-scale.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(m, logs, x_mask)`` for a batch of phone features.

        Args:
            phone: float tensor whose last dimension is 256 (input of ``emb_phone``).
            pitch: integer indices for ``emb_pitch``, or ``None`` to skip the
                pitch embedding (required when the module was built with a
                falsy ``f0``).
            lengths: per-item lengths handed to ``commons.sequence_mask``.

        Returns:
            ``m`` and ``logs`` (each ``out_channels`` wide along dim 1) and the
            mask that was applied to them.
        """
        # Identity check: ``pitch == None`` would be an equality test on a tensor.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return m, logs, x_mask
62
-
63
-
64
class TextEncoder256Sim(nn.Module):
    """Encode 256-dim phone features (plus optional pitch) into a single projection.

    Same pipeline as the statistics-producing text encoder, except that the
    final 1x1 convolution emits ``out_channels`` channels and no split is made.

    Args:
        out_channels: channel count of the returned projection.
        hidden_channels: width of the embeddings and of the attention encoder.
        filter_channels, n_heads, n_layers, kernel_size, p_dropout: forwarded
            unchanged to ``attentions.Encoder``.
        f0: when truthy, a pitch embedding table with 256 entries is created.
    """

    def __init__(
        self,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        f0=True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.emb_phone = nn.Linear(256, hidden_channels)
        self.lrelu = nn.LeakyReLU(0.1, inplace=True)
        if f0:
            self.emb_pitch = nn.Embedding(256, hidden_channels)  # pitch 256
        self.encoder = attentions.Encoder(
            hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, phone, pitch, lengths):
        """Return ``(x, x_mask)``: the masked projection and the mask itself.

        Args:
            phone: float tensor whose last dimension is 256 (input of ``emb_phone``).
            pitch: integer indices for ``emb_pitch``, or ``None`` to skip the
                pitch embedding (required when the module was built with a
                falsy ``f0``).
            lengths: per-item lengths handed to ``commons.sequence_mask``.
        """
        # Identity check: ``pitch == None`` would be an equality test on a tensor.
        if pitch is None:
            x = self.emb_phone(phone)
        else:
            x = self.emb_phone(phone) + self.emb_pitch(pitch)
        x = x * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = self.lrelu(x)
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.encoder(x * x_mask, x_mask)
        x = self.proj(x) * x_mask
        return x, x_mask
107
-
108
-
109
class ResidualCouplingBlock(nn.Module):
    """Stack of ``n_flows`` (coupling layer, flip) pairs, runnable in either direction."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Layout: even indices hold coupling layers, odd indices hold flips.
        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Run ``x`` through every flow; in reverse mode the order is inverted.

        In the forward direction each flow returns a pair and only its first
        element is kept; in reverse each flow returns the tensor directly.
        """
        if reverse:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
            return x

        for flow in self.flows:
            x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from the coupling layers (even slots only)."""
        for slot in range(0, 2 * self.n_flows, 2):
            self.flows[slot].remove_weight_norm()
156
-
157
-
158
class PosteriorEncoder(nn.Module):
    """Encode an input sequence into a sampled latent plus its mean and log-scale."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Doubled width: one half becomes the mean, the other the log-scale.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None):
        """Return ``(z, m, logs, x_mask)`` where ``z = (m + eps * exp(logs)) * x_mask``.

        ``eps`` is fresh standard-normal noise with the shape of ``m``; ``g`` is
        passed through to the ``WN`` encoder as conditioning.
        """
        length_mask = commons.sequence_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(length_mask, 1).to(x.dtype)

        hidden = self.pre(x) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.proj(hidden) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        eps = torch.randn_like(m)
        z = (m + eps * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from the ``WN`` encoder."""
        self.enc.remove_weight_norm()
201
-
202
-
203
class Generator(torch.nn.Module):
    """Upsampling decoder: transposed convolutions, each followed by a bank of
    residual blocks whose outputs are averaged, ending in a tanh-bounded
    single-channel output."""

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # The string "1" selects ResBlock1; anything else selects ResBlock2.
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        # Every upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for stage, (rate, width) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            channels_in = upsample_initial_channel // (2**stage)
            channels_out = upsample_initial_channel // (2 ** (stage + 1))
            transposed = ConvTranspose1d(
                channels_in,
                channels_out,
                width,
                rate,
                padding=(width - rate) // 2,
            )
            self.ups.append(weight_norm(transposed))

        # One residual block per (kernel, dilation) pair for every stage.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for width, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, width, dilation))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x``; when ``g`` is given, ``self.cond(g)`` is added after ``conv_pre``."""
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            offset = stage * self.num_kernels
            xs = None
            for k in range(self.num_kernels):
                branch = self.resblocks[offset + k](x)
                if xs is None:
                    xs = branch
                else:
                    xs += branch
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from all upsamplers and residual blocks."""
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
277
-
278
-
279
class SineGen(torch.nn.Module):
    """Definition of sine generator

    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: accepted for signature compatibility only; this
        implementation never reads it
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
        flag_for_pulse=False,
    ):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # One channel for the fundamental plus one per overtone.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        """Return a mask shaped like ``f0``: 1 where f0 > voiced_threshold, else 0."""
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def forward(self, f0, upp):
        """sine_waves, uv, noise = forward(f0, upp)

        input f0: tensor(batchsize, length); f0 for unvoiced steps should be 0
        input upp: upsampling factor applied along the length axis
        output sine_waves: tensor(batchsize, length * upp, dim)
        output uv: tensor(batchsize, length * upp, 1)
        output noise: tensor(batchsize, length * upp, dim), the noise term
            that has already been added into sine_waves
        """
        with torch.no_grad():
            # (batch, length) -> (batch, length, 1)
            f0 = f0[:, None].transpose(1, 2)
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in range(self.harmonic_num):
                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
            # The "% 1" here means the multiplication by the harmonic number
            # cannot be optimised away in post-processing.
            rad_values = (f0_buf / self.sampling_rate) % 1
            # Random initial phase for every overtone; the fundamental starts at 0.
            rand_ini = torch.rand(
                f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
            )
            rand_ini[:, 0] = 0
            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
            # No "% 1" at this point: applying it here would prevent the later
            # cumsum from being optimised.
            tmp_over_one = torch.cumsum(rad_values, 1)
            tmp_over_one *= upp
            tmp_over_one = F.interpolate(
                tmp_over_one.transpose(2, 1),
                scale_factor=upp,
                mode="linear",
                align_corners=True,
            ).transpose(2, 1)
            rad_values = F.interpolate(
                rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            tmp_over_one %= 1
            # True wherever the wrapped phase decreases, i.e. it crossed an integer.
            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
            cumsum_shift = torch.zeros_like(rad_values)
            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
            sine_waves = torch.sin(
                torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
            )
            sine_waves = sine_waves * self.sine_amp
            uv = self._f02uv(f0)
            uv = F.interpolate(
                uv.transpose(2, 1), scale_factor=upp, mode="nearest"
            ).transpose(2, 1)
            # Voiced frames get noise_std; unvoiced frames get sine_amp / 3.
            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)
            # Unvoiced frames keep only the noise term.
            sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
369
-
370
-
371
class SourceModuleHnNSF(torch.nn.Module):
    """SourceModule for hn-nsf

    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0, is_half=True)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshod: threshold to set U/V given F0 (default: 0)
    is_half: stored on the module; the working dtype is taken from the
        merge layer's weights at call time

    sine_merge, None, None = SourceModuleHnNSF(F0_sampled, upp)
    The harmonics are merged into one channel by a linear layer followed by
    tanh. The two trailing ``None`` values occupy the slots of the noise
    source and the uv mask, which this variant does not return.
    """

    def __init__(
        self,
        sampling_rate,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
        is_half=True,
    ):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.is_half = is_half
        # to produce sine waveforms
        self.l_sin_gen = SineGen(
            sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
        )

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x, upp=None):
        """Return ``(sine_merge, None, None)`` for the F0 track ``x`` upsampled by ``upp``."""
        sine_wavs, uv, _ = self.l_sin_gen(x, upp)
        # Follow the merge layer's actual dtype rather than the is_half flag:
        # an unconditional .half() fails with a dtype mismatch whenever the
        # flag and the module's real precision disagree.
        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        return sine_merge, None, None  # noise, uv
418
-
419
-
420
class GeneratorNSF(torch.nn.Module):
    """Upsampling decoder that additionally injects an F0-driven source signal
    after every upsampling stage."""

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels,
        sr,
        is_half=False,
    ):
        super(GeneratorNSF, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)

        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.m_source = SourceModuleHnNSF(
            sampling_rate=sr, harmonic_num=0, is_half=is_half
        )
        self.noise_convs = nn.ModuleList()
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # The string "1" selects ResBlock1; anything else selects ResBlock2.
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        n_stages = len(upsample_rates)
        for stage, (rate, width) in enumerate(
            zip(upsample_rates, upsample_kernel_sizes)
        ):
            c_cur = upsample_initial_channel // (2 ** (stage + 1))
            upsampler = ConvTranspose1d(
                upsample_initial_channel // (2**stage),
                c_cur,
                width,
                rate,
                padding=(width - rate) // 2,
            )
            self.ups.append(weight_norm(upsampler))

            if stage + 1 < n_stages:
                # Stride by the product of the remaining upsampling rates so
                # the source lines up with this stage's output.
                stride_f0 = np.prod(upsample_rates[stage + 1 :])
                source_conv = Conv1d(
                    1,
                    c_cur,
                    kernel_size=stride_f0 * 2,
                    stride=stride_f0,
                    padding=stride_f0 // 2,
                )
            else:
                # Last stage: no rates remain, so a pointwise convolution is used.
                source_conv = Conv1d(1, c_cur, kernel_size=1)
            self.noise_convs.append(source_conv)

        # One residual block per (kernel, dilation) pair for every stage.
        self.resblocks = nn.ModuleList()
        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            for width, dilation in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, width, dilation))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

        # Total upsampling factor, handed to the source module in forward().
        self.upp = np.prod(upsample_rates)

    def forward(self, x, f0, g=None):
        """Decode ``x`` with the source built from ``f0``; ``g`` is optional conditioning."""
        har_source, noi_source, uv = self.m_source(f0, self.upp)
        har_source = har_source.transpose(1, 2)
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for stage in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[stage](x)
            x_source = self.noise_convs[stage](har_source)
            x = x + x_source
            offset = stage * self.num_kernels
            xs = None
            for k in range(self.num_kernels):
                branch = self.resblocks[offset + k](x)
                if xs is None:
                    xs = branch
                else:
                    xs += branch
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalisation from all upsamplers and residual blocks."""
        for layer in self.ups:
            remove_weight_norm(layer)
        for block in self.resblocks:
            block.remove_weight_norm()
521
-
522
-
523
# Sample-rate labels mapped to their value in Hz.
sr2sr = dict(
    zip(
        ("32k", "40k", "48k"),
        (32000, 40000, 48000),
    )
)
528
-
529
-
530
class SynthesizerTrnMs256NSFsidM(nn.Module):
    """Synthesizer combining a text encoder, a reversible flow and an NSF decoder.

    ``sr`` may be an integer sample rate or one of the string keys of
    ``sr2sr``. ``is_half`` is read from ``**kwargs`` and forwarded to the
    decoder; when absent the decoder's own default (``False``) is used.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        gin_channels,
        sr,
        **kwargs
    ):
        super().__init__()
        # Accept symbolic rates such as "40k" (including str subclasses).
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            # Fall back to the decoder's default instead of a bare KeyError.
            is_half=kwargs.get("is_half", False),
        )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Strip weight normalisation from the decoder, the flow and the posterior encoder."""
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        self.enc_q.remove_weight_norm()

    def forward(self, phone, phone_lengths, pitch, nsff0, sid, rnd, max_len=None):
        """Synthesize output from phone features.

        Args:
            phone, pitch, phone_lengths: inputs of the text encoder.
            nsff0: F0 track handed to the decoder's source module.
            sid: speaker index looked up in ``emb_g``.
            rnd: caller-supplied noise, scaled by ``exp(logs_p)`` and added to
                the prior mean.
            max_len: optional cap on the number of latent frames decoded;
                ``None`` keeps them all.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
        return o
622
-
623
-
624
class SynthesizerTrnMs256NSFsid_sim(nn.Module):
    """
    Synthesizer for Training

    Simplified variant: the text encoder output goes straight through the
    reversed flow into the NSF decoder, with no posterior encoder.

    Keyword arguments read from ``**kwargs``:
        sr: sample rate for the decoder's source module (required); an integer
            or one of the string keys of ``sr2sr``.
        is_half: forwarded to the decoder; defaults to ``False``.
    """

    def __init__(
        self,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        spk_embed_dim,
        # hop_length,
        gin_channels=0,
        use_sdp=True,
        **kwargs
    ):
        super().__init__()
        # GeneratorNSF cannot be built without a sample rate. It was never
        # forwarded before, so construction always failed with a TypeError.
        sr = kwargs.get("sr")
        if sr is None:
            raise TypeError(
                "SynthesizerTrnMs256NSFsid_sim requires an 'sr' keyword argument"
            )
        if isinstance(sr, str):
            sr = sr2sr[sr]
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.gin_channels = gin_channels
        self.spk_embed_dim = spk_embed_dim
        self.enc_p = TextEncoder256Sim(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
        )
        self.dec = GeneratorNSF(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
            sr=sr,
            is_half=kwargs.get("is_half", False),
        )

        self.flow = ResidualCouplingBlock(
            inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
        )
        self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
        print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)

    def remove_weight_norm(self):
        """Strip weight normalisation from the decoder and the flow.

        This class creates no posterior encoder; ``enc_q`` is only touched
        when something else (for example a subclass) has attached one.
        """
        self.dec.remove_weight_norm()
        self.flow.remove_weight_norm()
        enc_q = getattr(self, "enc_q", None)
        if enc_q is not None:
            enc_q.remove_weight_norm()

    def forward(
        self, phone, phone_lengths, pitch, pitchf, ds, max_len=None
    ):  # y (the spectrogram) is no longer needed here
        """Synthesize output for speaker ``ds``; ``max_len`` optionally caps the frames decoded."""
        # [b, 256, 1]; the trailing 1 is the time axis and is broadcast
        g = self.emb_g(ds.unsqueeze(0)).unsqueeze(-1)
        x, x_mask = self.enc_p(phone, pitch, phone_lengths)
        x = self.flow(x, x_mask, g=g, reverse=True)
        o = self.dec((x * x_mask)[:, :, :max_len], pitchf, g=g)
        return o
711
-
712
-
713
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of one ``DiscriminatorS`` and one ``DiscriminatorP`` per period."""

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        periods = [2, 3, 5, 7, 11, 17]

        members = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        for period in periods:
            members.append(DiscriminatorP(period, use_spectral_norm=use_spectral_norm))
        self.discriminators = nn.ModuleList(members)

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and on ``y_hat``.

        Returns four lists with one entry per sub-discriminator, in order:
        outputs for ``y``, outputs for ``y_hat``, feature maps for ``y``,
        feature maps for ``y_hat``.
        """
        outs_real, outs_gen = [], []
        fmaps_real, fmaps_gen = [], []
        for disc in self.discriminators:
            out_r, fmap_r = disc(y)
            out_g, fmap_g = disc(y_hat)
            outs_real.append(out_r)
            outs_gen.append(out_g)
            fmaps_real.append(fmap_r)
            fmaps_gen.append(fmap_g)

        return outs_real, outs_gen, fmaps_real, fmaps_gen
741
-
742
-
743
class DiscriminatorS(torch.nn.Module):
    """1-D convolutional discriminator applied to the un-reshaped signal.

    Args:
        use_spectral_norm: when truthy the convolutions are wrapped in
            ``spectral_norm``; otherwise in ``weight_norm``.
    """

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(score, fmap)``.

        ``fmap`` lists the activation after every convolution, with the
        ``conv_post`` output last; ``score`` is that last output flattened
        from dim 1 onward.
        """
        fmap = []

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
771
-
772
-
773
class DiscriminatorP(torch.nn.Module):
    """Discriminator that folds the time axis into ``(t // period, period)``
    and applies 2-D convolutions whose kernels span only the first of those axes.

    Args:
        period: fold width; inputs are reflect-padded to a multiple of it.
        kernel_size: kernel height shared by the five stacked convolutions.
        stride: stride of the first four convolutions (the fifth uses 1).
        use_spectral_norm: when truthy the convolutions are wrapped in
            ``spectral_norm``; otherwise in ``weight_norm``.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        # (in_channels, out_channels, stride along the folded time axis)
        layer_specs = [
            (1, 32, stride),
            (32, 128, stride),
            (128, 512, stride),
            (512, 1024, stride),
            (1024, 1024, 1),
        ]
        padding = (get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        c_in,
                        c_out,
                        (kernel_size, 1),
                        (step, 1),
                        padding=padding,
                    )
                )
                for c_in, c_out, step in layer_specs
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(score, fmap)`` for a 3-D input ``(b, c, t)``.

        ``fmap`` lists the activation after every convolution, with the
        ``conv_post`` output last; ``score`` is that last output flattened
        from dim 1 onward.
        """
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/modules.py DELETED
@@ -1,522 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import scipy
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm
11
-
12
- from infer_pack import commons
13
- from infer_pack.commons import init_weights, get_padding
14
- from infer_pack.transforms import piecewise_rational_quadratic_transform
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
class LayerNorm(nn.Module):
    """Layer normalisation applied over dimension 1 of the input.

    ``F.layer_norm`` normalises the trailing dimension, so the input is
    transposed (dim 1 <-> last dim) before the call and transposed back
    afterwards.
    """

    def __init__(self, channels, eps=1e-5):
        """Create the learnable scale (``gamma``) and shift (``beta``).

        Args:
            channels: Size of dimension 1 of tensors passed to ``forward``.
            eps: Epsilon forwarded to ``F.layer_norm``.
        """
        super().__init__()
        self.channels, self.eps = channels, eps

        # Identity affine transform at initialisation.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalised over dimension 1, in the input's layout."""
        channels_last = x.transpose(1, -1)
        normalised = F.layer_norm(
            channels_last,
            (self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        return normalised.transpose(1, -1)
33
-
34
-
35
class ConvReluNorm(nn.Module):
    """Stack of ``Conv1d -> LayerNorm -> ReLU -> Dropout`` layers with a residual output.

    The result of the stack is passed through a 1x1 projection and added to
    the original input. The projection is zero-initialised, so a freshly
    constructed module returns its masked input unchanged.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        """Build the convolution stack.

        Args:
            in_channels: Channels of the input to the first convolution.
            hidden_channels: Channels used by every layer of the stack.
            out_channels: Channels produced by the final projection; the
                residual sum in ``forward`` adds this to the input.
            kernel_size: Kernel size of every convolution (padding is
                ``kernel_size // 2``).
            n_layers: Number of conv/norm layers; must be greater than 1.
            p_dropout: Dropout probability applied after each ReLU.

        Raises:
            AssertionError: If ``n_layers`` is not greater than 1.
        """
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # The message used to say "larger than 0", contradicting the check.
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        # First layer maps in_channels -> hidden_channels.
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        # Remaining layers keep the hidden width.
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init: the residual branch contributes nothing at the start.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Run the stack and return ``(x + proj(stack(x))) * x_mask``.

        Args:
            x: Input tensor fed to the first ``Conv1d``.
            x_mask: Mask multiplied into the input of every convolution and
                into the output (assumed broadcastable against ``x`` -- TODO
                confirm against callers).
        """
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
85
-
86
-
87
class DDSConv(nn.Module):
    """
    Dilated and Depth-Separable Convolution

    Each layer is a depthwise ``Conv1d`` (``groups=channels``) with dilation
    ``kernel_size**i`` followed by a 1x1 ``Conv1d``; both are followed by
    ``LayerNorm`` and GELU, and the layer output is added back to its input.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        """Build ``n_layers`` depthwise/pointwise pairs.

        Args:
            channels: Channel count kept constant through the stack.
            kernel_size: Kernel size of the depthwise convolutions; also the
                base of the per-layer dilation.
            n_layers: Number of residual layers.
            p_dropout: Dropout probability applied to each layer's output.
        """
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            dilation = kernel_size**layer_idx
            depthwise = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                groups=channels,
                dilation=dilation,
                # Padding sized from the dilated kernel extent.
                padding=(kernel_size * dilation - dilation) // 2,
            )
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Apply the residual stack and return the masked result.

        Args:
            x: Input tensor.
            x_mask: Mask multiplied into the input of every depthwise
                convolution and into the final output.
            g: Optional conditioning tensor added to ``x`` before the stack.
        """
        if g is not None:
            x = x + g
        for layer_idx in range(self.n_layers):
            branch = self.convs_sep[layer_idx](x * x_mask)
            branch = F.gelu(self.norms_1[layer_idx](branch))
            branch = self.convs_1x1[layer_idx](branch)
            branch = F.gelu(self.norms_2[layer_idx](branch))
            x = x + self.drop(branch)
        return x * x_mask
134
-
135
-
136
class WN(torch.nn.Module):
    """Stack of dilated, weight-normalised ``Conv1d`` layers with residual and skip outputs.

    Every layer produces ``2 * hidden_channels`` pre-activations which are
    combined with an optional conditioning slice by
    ``commons.fused_add_tanh_sigmoid_multiply``. A 1x1 convolution then emits
    a residual half (added to the running input) and a skip half (accumulated
    into the output). The last layer only emits the skip part.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        """Build the layer stack.

        Args:
            hidden_channels: Channels of the input, residual and skip paths.
            kernel_size: Odd kernel size of the dilated convolutions.
            dilation_rate: Base of the per-layer dilation ``dilation_rate**i``.
            n_layers: Number of layers.
            gin_channels: Channels of the conditioning input ``g``; ``0``
                means no conditioning layer is created.
            p_dropout: Dropout probability applied to the activations.

        Raises:
            AssertionError: If ``kernel_size`` is even.
        """
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Stored as a plain int; a stray trailing comma previously turned
        # this attribute into the 1-tuple ``(kernel_size,)``.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One projection yields the conditioning slices for all layers.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # The last layer has no residual branch, only the skip output.
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of all skip outputs.

        Args:
            x: Input tensor with ``hidden_channels`` channels in dim 1.
            x_mask: Mask multiplied into the residual path and the output.
            g: Optional conditioning tensor; passed through ``cond_layer``
                (which only exists when ``gin_channels != 0``).
            **kwargs: Ignored.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                # Slice this layer's 2 * hidden_channels from the projection.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalisation from every layer of the stack."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
222
-
223
-
224
class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv, undilated conv) pairs.

    Each pair applies leaky ReLU before both convolutions and adds the result
    back to its input. All convolutions are weight-normalised.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        """Build the two convolution lists.

        Args:
            channels: Input and output channels of every convolution.
            kernel_size: Kernel size of every convolution.
            dilation: Dilations of the three convolutions in ``convs1``;
                only the first three entries are read.
        """
        super(ResBlock1, self).__init__()

        def make_conv(rate):
            # Stride-1 conv padded via get_padding for the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        self.convs1 = nn.ModuleList([make_conv(dilation[k]) for k in range(3)])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the three residual pairs; mask activations when ``x_mask`` is given."""
        use_mask = x_mask is not None
        for dilated, plain in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                xt = xt * x_mask
            xt = F.leaky_relu(dilated(xt), LRELU_SLOPE)
            if use_mask:
                xt = xt * x_mask
            x = plain(xt) + x
        return x * x_mask if use_mask else x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution of the block."""
        for conv_list in (self.convs1, self.convs2):
            for layer in conv_list:
                remove_weight_norm(layer)
319
-
320
-
321
class ResBlock2(torch.nn.Module):
    """Residual block of two dilated, weight-normalised convolutions.

    Each convolution is preceded by a leaky ReLU and its output is added
    back to its input.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        """Build the convolution list.

        Args:
            channels: Input and output channels of every convolution.
            kernel_size: Kernel size of every convolution.
            dilation: Dilations of the two convolutions; only the first two
                entries are read.
        """
        super(ResBlock2, self).__init__()

        def make_conv(rate):
            # Stride-1 conv padded via get_padding for the given dilation.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=get_padding(kernel_size, rate),
                )
            )

        self.convs = nn.ModuleList([make_conv(dilation[k]) for k in range(2)])
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply both residual convolutions; mask activations when ``x_mask`` is given."""
        use_mask = x_mask is not None
        for conv in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                xt = xt * x_mask
            x = conv(xt) + x
        return x * x_mask if use_mask else x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for layer in self.convs:
            remove_weight_norm(layer)
364
-
365
-
366
class Log(nn.Module):
    """Invertible element-wise logarithm with a log-determinant term."""

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Apply ``log`` (forward) or ``exp`` (reverse), masked by ``x_mask``.

        Args:
            x: Input tensor; clamped to a minimum of ``1e-5`` before the log.
            x_mask: Mask multiplied into the result.
            reverse: When true, return ``exp(x) * x_mask`` only.
            **kwargs: Ignored.

        Returns:
            ``(y, logdet)`` in the forward direction, where ``logdet`` is the
            sum of ``-y`` over dims 1 and 2; the tensor alone in reverse.
        """
        if reverse:
            return torch.exp(x) * x_mask

        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        return y, torch.sum(-y, [1, 2])
375
-
376
-
377
class Flip(nn.Module):
    """Reverse the order of dimension 1; contributes a zero log-determinant."""

    def forward(self, x, *args, reverse=False, **kwargs):
        """Flip ``x`` along dim 1.

        Args:
            x: Input tensor.
            *args: Ignored.
            reverse: When true, return only the flipped tensor.
            **kwargs: Ignored.

        Returns:
            ``(flipped, logdet)`` in the forward direction, with ``logdet`` a
            zero vector of length ``x.size(0)`` matching ``x``'s dtype and
            device; the flipped tensor alone in reverse.
        """
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped

        batch = flipped.size(0)
        zero_logdet = torch.zeros(batch).to(dtype=flipped.dtype, device=flipped.device)
        return flipped, zero_logdet
385
-
386
-
387
class ElementwiseAffine(nn.Module):
    """Per-channel affine transform ``m + exp(logs) * x`` with its inverse."""

    def __init__(self, channels):
        """Create zero-initialised shift ``m`` and log-scale ``logs`` of shape ``(channels, 1)``."""
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        """Apply the affine map (forward) or undo it (reverse).

        Args:
            x: Input tensor.
            x_mask: Mask multiplied into the result and into ``logs`` when
                computing the log-determinant.
            reverse: When true, return ``(x - m) * exp(-logs) * x_mask`` only.
            **kwargs: Ignored.

        Returns:
            ``(y, logdet)`` in the forward direction, where ``logdet`` sums
            ``logs * x_mask`` over dims 1 and 2; the tensor alone in reverse.
        """
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask

        scaled = self.m + torch.exp(self.logs) * x
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return scaled * x_mask, logdet
403
-
404
-
405
class ResidualCouplingLayer(nn.Module):
    """Coupling layer: the first half of the channels parameterises an affine map of the second half.

    The first half passes through ``pre`` -> ``WN`` -> ``post`` to produce a
    shift ``m`` (and a log-scale ``logs`` unless ``mean_only``); the second
    half is transformed with them and the first half is returned unchanged.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        """Build the coupling network.

        Args:
            channels: Total channels of the input; must be even.
            hidden_channels: Width of the internal ``WN`` encoder.
            kernel_size: Kernel size forwarded to ``WN``.
            dilation_rate: Dilation base forwarded to ``WN``.
            n_layers: Layer count forwarded to ``WN``.
            p_dropout: Dropout probability forwarded to ``WN``.
            gin_channels: Conditioning channels forwarded to ``WN``.
            mean_only: When true, only a shift is predicted and the
                log-scale is fixed to zero.

        Raises:
            AssertionError: If ``channels`` is odd.
        """
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # One group of statistics (shift) when mean_only, otherwise two.
        stat_groups = 2 - mean_only
        self.post = nn.Conv1d(hidden_channels, self.half_channels * stat_groups, 1)
        # Zero init: the predicted statistics are all zero at the start.
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Transform the second half of ``x`` conditioned on the first half.

        Args:
            x: Input tensor split in two equal halves along dim 1.
            x_mask: Mask multiplied into the intermediate activations.
            g: Optional conditioning tensor forwarded to the encoder.
            reverse: When true, apply the inverse map and return only the
                tensor.

        Returns:
            ``(x, logdet)`` in the forward direction, where ``logdet`` sums
            ``logs`` over dims 1 and 2; the tensor alone in reverse.
        """
        halves = [self.half_channels] * 2
        x0, x1 = torch.split(x, halves, 1)

        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, halves, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        x = torch.cat([x0, x1], 1)
        logdet = torch.sum(logs, [1, 2])
        return x, logdet

    def remove_weight_norm(self):
        """Strip weight normalisation from the internal encoder."""
        self.enc.remove_weight_norm()
463
-
464
-
465
class ConvFlow(nn.Module):
    """Coupling layer that transforms the second half of the channels with a rational-quadratic spline.

    The first half of the channels is passed through ``pre`` -> ``DDSConv``
    -> ``proj`` to predict, per position, ``num_bins`` widths, ``num_bins``
    heights and ``num_bins - 1`` derivatives for each transformed channel.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        """Build the parameter network.

        Args:
            in_channels: Total channels of the input; half are transformed.
            filter_channels: Width of the internal ``DDSConv`` stack.
            kernel_size: Kernel size forwarded to ``DDSConv``.
            n_layers: Layer count forwarded to ``DDSConv``.
            num_bins: Number of spline bins.
            tail_bound: Bound forwarded to the spline transform.
        """
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        params_per_channel = num_bins * 3 - 1
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * params_per_channel, 1
        )
        # Zero init: all predicted spline parameters are zero at the start.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the spline to the second half of ``x``.

        Args:
            x: Input tensor split in two equal halves along dim 1.
            x_mask: Mask multiplied into the projected parameters, the output
                and the per-element log-determinant.
            g: Optional conditioning tensor forwarded to ``DDSConv``.
            reverse: When true, apply the inverse spline and return only the
                tensor.

        Returns:
            ``(x, logdet)`` in the forward direction; the tensor alone in
            reverse.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        params = self.pre(x0)
        params = self.convs(params, x_mask, g=g)
        params = self.proj(params) * x_mask

        batch, chans, steps = x0.shape
        # [b, cx?, t] -> [b, c, t, ?]
        params = params.reshape(batch, chans, -1, steps).permute(0, 1, 3, 2)

        n_bins = self.num_bins
        unnormalized_widths = params[..., :n_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = params[..., n_bins : 2 * n_bins] / math.sqrt(
            self.filter_channels
        )
        unnormalized_derivatives = params[..., 2 * n_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return x
        return x, logdet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
infer_pack/transforms.py DELETED
@@ -1,209 +0,0 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import numpy as np
5
-
6
-
7
- DEFAULT_MIN_BIN_WIDTH = 1e-3
8
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
- DEFAULT_MIN_DERIVATIVE = 1e-3
10
-
11
-
12
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Dispatch to the bounded or the tailed rational-quadratic spline.

    With ``tails=None`` the call goes to ``rational_quadratic_spline``;
    otherwise to ``unconstrained_rational_quadratic_spline`` with ``tails``
    and ``tail_bound`` forwarded.

    Returns:
        The ``(outputs, logabsdet)`` pair produced by the selected spline.
    """
    shared_kwargs = {
        "inputs": inputs,
        "unnormalized_widths": unnormalized_widths,
        "unnormalized_heights": unnormalized_heights,
        "unnormalized_derivatives": unnormalized_derivatives,
        "inverse": inverse,
        "min_bin_width": min_bin_width,
        "min_bin_height": min_bin_height,
        "min_derivative": min_derivative,
    }

    if tails is None:
        outputs, logabsdet = rational_quadratic_spline(**shared_kwargs)
    else:
        outputs, logabsdet = unconstrained_rational_quadratic_spline(
            tails=tails, tail_bound=tail_bound, **shared_kwargs
        )
    return outputs, logabsdet
43
-
44
-
45
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for every input, the index of the bin it falls into.

    The index is the number of bin edges that are ``<=`` the input, minus
    one. The last edge is shifted up by ``eps`` so an input equal to it is
    assigned to the last bin rather than one past it.

    Args:
        bin_locations: Tensor of bin edges along the last dimension. It is
            not modified; the ``eps`` shift is applied to a copy.
        inputs: Tensor of values to locate; compared against the edges after
            gaining a trailing singleton dimension.
        eps: Amount added to the last edge.

    Returns:
        Integer tensor of bin indices with the shape of ``inputs``.
    """
    # Work on a copy so the caller's edges are not changed by the lookup.
    edges = bin_locations.clone()
    edges[..., -1] += eps
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
48
-
49
-
50
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline on ``[-tail_bound, tail_bound]`` with identity tails.

    Elements inside the interval are transformed by
    ``rational_quadratic_spline``; elements outside are passed through
    unchanged with a zero log-determinant.

    Args:
        inputs: Values to transform.
        unnormalized_widths: Raw bin widths; last dim indexes the bins.
        unnormalized_heights: Raw bin heights; last dim indexes the bins.
        unnormalized_derivatives: Raw interior knot derivatives; padded here
            with one boundary value on each side.
        inverse: Forwarded to the spline.
        tails: Only ``"linear"`` is implemented.
        tail_bound: Half-width of the interval on which the spline acts.
        min_bin_width: Forwarded to the spline.
        min_bin_height: Forwarded to the spline.
        min_derivative: Forwarded to the spline; also fixes the boundary
            derivative constant.

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: If ``tails`` is not ``"linear"``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        # Chosen so that min_derivative + softplus(constant) == 1, i.e. the
        # boundary derivative matches the slope of the identity tails.
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    # With nothing inside the interval the spline would be handed empty
    # tensors, and its torch.min/torch.max domain check raises on those.
    if not inside_interval_mask.any():
        return outputs, logabsdet

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet
98
-
99
-
100
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Evaluate a rational-quadratic spline (or its inverse) element-wise.

    Bin widths and heights are obtained by softmax over the last dimension,
    floored at ``min_bin_width`` / ``min_bin_height`` and rescaled to span
    ``[left, right]`` and ``[bottom, top]``. Knot derivatives are
    ``min_derivative + softplus(unnormalized_derivatives)``.

    Args:
        inputs: Values to transform; must lie in ``[left, right]``.
        unnormalized_widths: Raw bin widths; last dim indexes the bins.
        unnormalized_heights: Raw bin heights; last dim indexes the bins.
        unnormalized_derivatives: Raw knot derivatives (one more than bins).
        inverse: When true, invert the spline instead of applying it.
        left, right: Horizontal extent of the spline.
        bottom, top: Vertical extent of the spline.
        min_bin_width: Lower bound on each normalised bin width.
        min_bin_height: Lower bound on each normalised bin height.
        min_derivative: Lower bound on each knot derivative.

    Returns:
        ``(outputs, logabsdet)``. In the inverse direction the log-derivative
        is returned negated.

    Raises:
        ValueError: If an input lies outside ``[left, right]`` or a minimum
            bin size is too large for the number of bins.
        AssertionError: If the inverse's quadratic has a negative
            discriminant.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    def knots_and_sizes(unnormalized, min_size, low, high):
        # Normalised bin sizes -> cumulative knot positions on [low, high].
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        knots = torch.cumsum(sizes, dim=-1)
        knots = F.pad(knots, pad=(1, 0), mode="constant", value=0.0)
        knots = (high - low) * knots + low
        # Pin the end knots exactly, then recompute sizes from the knots.
        knots[..., 0] = low
        knots[..., -1] = high
        return knots, knots[..., 1:] - knots[..., :-1]

    cumwidths, widths = knots_and_sizes(
        unnormalized_widths, min_bin_width, left, right
    )
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = knots_and_sizes(
        unnormalized_heights, min_bin_height, bottom, top
    )

    # The inverse looks inputs up on the vertical axis, the forward pass on
    # the horizontal one.
    lookup_knots = cumheights if inverse else cumwidths
    bin_idx = searchsorted(lookup_knots, inputs)[..., None]

    def pick(per_bin):
        # Select each input's own bin entry along the last dimension.
        return per_bin.gather(-1, bin_idx)[..., 0]

    input_cumwidths = pick(cumwidths)
    input_bin_widths = pick(widths)

    input_cumheights = pick(cumheights)
    input_delta = pick(heights / widths)

    input_derivatives = pick(derivatives)
    input_derivatives_plus_one = pick(derivatives[..., 1:])

    input_heights = pick(heights)

    # Term shared by the quadratic coefficients and the denominator.
    slope_term = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        offset = inputs - input_cumheights
        a = offset * slope_term + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - offset * slope_term
        c = -input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths

    theta_one_minus_theta = theta * (1 - theta)
    denominator = input_delta + (slope_term * theta_one_minus_theta)
    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        return outputs, -logabsdet

    numerator = input_heights * (
        input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
    )
    outputs = input_cumheights + numerator / denominator
    return outputs, logabsdet