Blane187 commited on
Commit
6e3f338
·
verified ·
1 Parent(s): 4c03ee5

Delete lib

Browse files
lib/infer_pack/attentions.py DELETED
@@ -1,417 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from lib.infer_pack import commons
9
- from lib.infer_pack import modules
10
- from lib.infer_pack.modules import LayerNorm
11
-
12
-
13
- class Encoder(nn.Module):
14
- def __init__(
15
- self,
16
- hidden_channels,
17
- filter_channels,
18
- n_heads,
19
- n_layers,
20
- kernel_size=1,
21
- p_dropout=0.0,
22
- window_size=10,
23
- **kwargs
24
- ):
25
- super().__init__()
26
- self.hidden_channels = hidden_channels
27
- self.filter_channels = filter_channels
28
- self.n_heads = n_heads
29
- self.n_layers = n_layers
30
- self.kernel_size = kernel_size
31
- self.p_dropout = p_dropout
32
- self.window_size = window_size
33
-
34
- self.drop = nn.Dropout(p_dropout)
35
- self.attn_layers = nn.ModuleList()
36
- self.norm_layers_1 = nn.ModuleList()
37
- self.ffn_layers = nn.ModuleList()
38
- self.norm_layers_2 = nn.ModuleList()
39
- for i in range(self.n_layers):
40
- self.attn_layers.append(
41
- MultiHeadAttention(
42
- hidden_channels,
43
- hidden_channels,
44
- n_heads,
45
- p_dropout=p_dropout,
46
- window_size=window_size,
47
- )
48
- )
49
- self.norm_layers_1.append(LayerNorm(hidden_channels))
50
- self.ffn_layers.append(
51
- FFN(
52
- hidden_channels,
53
- hidden_channels,
54
- filter_channels,
55
- kernel_size,
56
- p_dropout=p_dropout,
57
- )
58
- )
59
- self.norm_layers_2.append(LayerNorm(hidden_channels))
60
-
61
- def forward(self, x, x_mask):
62
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
63
- x = x * x_mask
64
- for i in range(self.n_layers):
65
- y = self.attn_layers[i](x, x, attn_mask)
66
- y = self.drop(y)
67
- x = self.norm_layers_1[i](x + y)
68
-
69
- y = self.ffn_layers[i](x, x_mask)
70
- y = self.drop(y)
71
- x = self.norm_layers_2[i](x + y)
72
- x = x * x_mask
73
- return x
74
-
75
-
76
- class Decoder(nn.Module):
77
- def __init__(
78
- self,
79
- hidden_channels,
80
- filter_channels,
81
- n_heads,
82
- n_layers,
83
- kernel_size=1,
84
- p_dropout=0.0,
85
- proximal_bias=False,
86
- proximal_init=True,
87
- **kwargs
88
- ):
89
- super().__init__()
90
- self.hidden_channels = hidden_channels
91
- self.filter_channels = filter_channels
92
- self.n_heads = n_heads
93
- self.n_layers = n_layers
94
- self.kernel_size = kernel_size
95
- self.p_dropout = p_dropout
96
- self.proximal_bias = proximal_bias
97
- self.proximal_init = proximal_init
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.self_attn_layers = nn.ModuleList()
101
- self.norm_layers_0 = nn.ModuleList()
102
- self.encdec_attn_layers = nn.ModuleList()
103
- self.norm_layers_1 = nn.ModuleList()
104
- self.ffn_layers = nn.ModuleList()
105
- self.norm_layers_2 = nn.ModuleList()
106
- for i in range(self.n_layers):
107
- self.self_attn_layers.append(
108
- MultiHeadAttention(
109
- hidden_channels,
110
- hidden_channels,
111
- n_heads,
112
- p_dropout=p_dropout,
113
- proximal_bias=proximal_bias,
114
- proximal_init=proximal_init,
115
- )
116
- )
117
- self.norm_layers_0.append(LayerNorm(hidden_channels))
118
- self.encdec_attn_layers.append(
119
- MultiHeadAttention(
120
- hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
121
- )
122
- )
123
- self.norm_layers_1.append(LayerNorm(hidden_channels))
124
- self.ffn_layers.append(
125
- FFN(
126
- hidden_channels,
127
- hidden_channels,
128
- filter_channels,
129
- kernel_size,
130
- p_dropout=p_dropout,
131
- causal=True,
132
- )
133
- )
134
- self.norm_layers_2.append(LayerNorm(hidden_channels))
135
-
136
- def forward(self, x, x_mask, h, h_mask):
137
- """
138
- x: decoder input
139
- h: encoder output
140
- """
141
- self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
142
- device=x.device, dtype=x.dtype
143
- )
144
- encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
145
- x = x * x_mask
146
- for i in range(self.n_layers):
147
- y = self.self_attn_layers[i](x, x, self_attn_mask)
148
- y = self.drop(y)
149
- x = self.norm_layers_0[i](x + y)
150
-
151
- y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
152
- y = self.drop(y)
153
- x = self.norm_layers_1[i](x + y)
154
-
155
- y = self.ffn_layers[i](x, x_mask)
156
- y = self.drop(y)
157
- x = self.norm_layers_2[i](x + y)
158
- x = x * x_mask
159
- return x
160
-
161
-
162
- class MultiHeadAttention(nn.Module):
163
- def __init__(
164
- self,
165
- channels,
166
- out_channels,
167
- n_heads,
168
- p_dropout=0.0,
169
- window_size=None,
170
- heads_share=True,
171
- block_length=None,
172
- proximal_bias=False,
173
- proximal_init=False,
174
- ):
175
- super().__init__()
176
- assert channels % n_heads == 0
177
-
178
- self.channels = channels
179
- self.out_channels = out_channels
180
- self.n_heads = n_heads
181
- self.p_dropout = p_dropout
182
- self.window_size = window_size
183
- self.heads_share = heads_share
184
- self.block_length = block_length
185
- self.proximal_bias = proximal_bias
186
- self.proximal_init = proximal_init
187
- self.attn = None
188
-
189
- self.k_channels = channels // n_heads
190
- self.conv_q = nn.Conv1d(channels, channels, 1)
191
- self.conv_k = nn.Conv1d(channels, channels, 1)
192
- self.conv_v = nn.Conv1d(channels, channels, 1)
193
- self.conv_o = nn.Conv1d(channels, out_channels, 1)
194
- self.drop = nn.Dropout(p_dropout)
195
-
196
- if window_size is not None:
197
- n_heads_rel = 1 if heads_share else n_heads
198
- rel_stddev = self.k_channels**-0.5
199
- self.emb_rel_k = nn.Parameter(
200
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
201
- * rel_stddev
202
- )
203
- self.emb_rel_v = nn.Parameter(
204
- torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
205
- * rel_stddev
206
- )
207
-
208
- nn.init.xavier_uniform_(self.conv_q.weight)
209
- nn.init.xavier_uniform_(self.conv_k.weight)
210
- nn.init.xavier_uniform_(self.conv_v.weight)
211
- if proximal_init:
212
- with torch.no_grad():
213
- self.conv_k.weight.copy_(self.conv_q.weight)
214
- self.conv_k.bias.copy_(self.conv_q.bias)
215
-
216
- def forward(self, x, c, attn_mask=None):
217
- q = self.conv_q(x)
218
- k = self.conv_k(c)
219
- v = self.conv_v(c)
220
-
221
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
222
-
223
- x = self.conv_o(x)
224
- return x
225
-
226
- def attention(self, query, key, value, mask=None):
227
- # reshape [b, d, t] -> [b, n_h, t, d_k]
228
- b, d, t_s, t_t = (*key.size(), query.size(2))
229
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
230
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
231
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
232
-
233
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
234
- if self.window_size is not None:
235
- assert (
236
- t_s == t_t
237
- ), "Relative attention is only available for self-attention."
238
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
239
- rel_logits = self._matmul_with_relative_keys(
240
- query / math.sqrt(self.k_channels), key_relative_embeddings
241
- )
242
- scores_local = self._relative_position_to_absolute_position(rel_logits)
243
- scores = scores + scores_local
244
- if self.proximal_bias:
245
- assert t_s == t_t, "Proximal bias is only available for self-attention."
246
- scores = scores + self._attention_bias_proximal(t_s).to(
247
- device=scores.device, dtype=scores.dtype
248
- )
249
- if mask is not None:
250
- scores = scores.masked_fill(mask == 0, -1e4)
251
- if self.block_length is not None:
252
- assert (
253
- t_s == t_t
254
- ), "Local attention is only available for self-attention."
255
- block_mask = (
256
- torch.ones_like(scores)
257
- .triu(-self.block_length)
258
- .tril(self.block_length)
259
- )
260
- scores = scores.masked_fill(block_mask == 0, -1e4)
261
- p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
262
- p_attn = self.drop(p_attn)
263
- output = torch.matmul(p_attn, value)
264
- if self.window_size is not None:
265
- relative_weights = self._absolute_position_to_relative_position(p_attn)
266
- value_relative_embeddings = self._get_relative_embeddings(
267
- self.emb_rel_v, t_s
268
- )
269
- output = output + self._matmul_with_relative_values(
270
- relative_weights, value_relative_embeddings
271
- )
272
- output = (
273
- output.transpose(2, 3).contiguous().view(b, d, t_t)
274
- ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
275
- return output, p_attn
276
-
277
- def _matmul_with_relative_values(self, x, y):
278
- """
279
- x: [b, h, l, m]
280
- y: [h or 1, m, d]
281
- ret: [b, h, l, d]
282
- """
283
- ret = torch.matmul(x, y.unsqueeze(0))
284
- return ret
285
-
286
- def _matmul_with_relative_keys(self, x, y):
287
- """
288
- x: [b, h, l, d]
289
- y: [h or 1, m, d]
290
- ret: [b, h, l, m]
291
- """
292
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
293
- return ret
294
-
295
- def _get_relative_embeddings(self, relative_embeddings, length):
296
- max_relative_position = 2 * self.window_size + 1
297
- # Pad first before slice to avoid using cond ops.
298
- pad_length = max(length - (self.window_size + 1), 0)
299
- slice_start_position = max((self.window_size + 1) - length, 0)
300
- slice_end_position = slice_start_position + 2 * length - 1
301
- if pad_length > 0:
302
- padded_relative_embeddings = F.pad(
303
- relative_embeddings,
304
- commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
305
- )
306
- else:
307
- padded_relative_embeddings = relative_embeddings
308
- used_relative_embeddings = padded_relative_embeddings[
309
- :, slice_start_position:slice_end_position
310
- ]
311
- return used_relative_embeddings
312
-
313
- def _relative_position_to_absolute_position(self, x):
314
- """
315
- x: [b, h, l, 2*l-1]
316
- ret: [b, h, l, l]
317
- """
318
- batch, heads, length, _ = x.size()
319
- # Concat columns of pad to shift from relative to absolute indexing.
320
- x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
321
-
322
- # Concat extra elements so to add up to shape (len+1, 2*len-1).
323
- x_flat = x.view([batch, heads, length * 2 * length])
324
- x_flat = F.pad(
325
- x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
326
- )
327
-
328
- # Reshape and slice out the padded elements.
329
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
330
- :, :, :length, length - 1 :
331
- ]
332
- return x_final
333
-
334
- def _absolute_position_to_relative_position(self, x):
335
- """
336
- x: [b, h, l, l]
337
- ret: [b, h, l, 2*l-1]
338
- """
339
- batch, heads, length, _ = x.size()
340
- # padd along column
341
- x = F.pad(
342
- x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
343
- )
344
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
345
- # add 0's in the beginning that will skew the elements after reshape
346
- x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
347
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
348
- return x_final
349
-
350
- def _attention_bias_proximal(self, length):
351
- """Bias for self-attention to encourage attention to close positions.
352
- Args:
353
- length: an integer scalar.
354
- Returns:
355
- a Tensor with shape [1, 1, length, length]
356
- """
357
- r = torch.arange(length, dtype=torch.float32)
358
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
359
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
360
-
361
-
362
- class FFN(nn.Module):
363
- def __init__(
364
- self,
365
- in_channels,
366
- out_channels,
367
- filter_channels,
368
- kernel_size,
369
- p_dropout=0.0,
370
- activation=None,
371
- causal=False,
372
- ):
373
- super().__init__()
374
- self.in_channels = in_channels
375
- self.out_channels = out_channels
376
- self.filter_channels = filter_channels
377
- self.kernel_size = kernel_size
378
- self.p_dropout = p_dropout
379
- self.activation = activation
380
- self.causal = causal
381
-
382
- if causal:
383
- self.padding = self._causal_padding
384
- else:
385
- self.padding = self._same_padding
386
-
387
- self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
388
- self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
389
- self.drop = nn.Dropout(p_dropout)
390
-
391
- def forward(self, x, x_mask):
392
- x = self.conv_1(self.padding(x * x_mask))
393
- if self.activation == "gelu":
394
- x = x * torch.sigmoid(1.702 * x)
395
- else:
396
- x = torch.relu(x)
397
- x = self.drop(x)
398
- x = self.conv_2(self.padding(x * x_mask))
399
- return x * x_mask
400
-
401
- def _causal_padding(self, x):
402
- if self.kernel_size == 1:
403
- return x
404
- pad_l = self.kernel_size - 1
405
- pad_r = 0
406
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
407
- x = F.pad(x, commons.convert_pad_shape(padding))
408
- return x
409
-
410
- def _same_padding(self, x):
411
- if self.kernel_size == 1:
412
- return x
413
- pad_l = (self.kernel_size - 1) // 2
414
- pad_r = self.kernel_size // 2
415
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
416
- x = F.pad(x, commons.convert_pad_shape(padding))
417
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/commons.py DELETED
@@ -1,166 +0,0 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
-
7
-
8
- def init_weights(m, mean=0.0, std=0.01):
9
- classname = m.__class__.__name__
10
- if classname.find("Conv") != -1:
11
- m.weight.data.normal_(mean, std)
12
-
13
-
14
- def get_padding(kernel_size, dilation=1):
15
- return int((kernel_size * dilation - dilation) / 2)
16
-
17
-
18
- def convert_pad_shape(pad_shape):
19
- l = pad_shape[::-1]
20
- pad_shape = [item for sublist in l for item in sublist]
21
- return pad_shape
22
-
23
-
24
- def kl_divergence(m_p, logs_p, m_q, logs_q):
25
- """KL(P||Q)"""
26
- kl = (logs_q - logs_p) - 0.5
27
- kl += (
28
- 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
29
- )
30
- return kl
31
-
32
-
33
- def rand_gumbel(shape):
34
- """Sample from the Gumbel distribution, protect from overflows."""
35
- uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
36
- return -torch.log(-torch.log(uniform_samples))
37
-
38
-
39
- def rand_gumbel_like(x):
40
- g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
41
- return g
42
-
43
-
44
- def slice_segments(x, ids_str, segment_size=4):
45
- ret = torch.zeros_like(x[:, :, :segment_size])
46
- for i in range(x.size(0)):
47
- idx_str = ids_str[i]
48
- idx_end = idx_str + segment_size
49
- ret[i] = x[i, :, idx_str:idx_end]
50
- return ret
51
-
52
-
53
- def slice_segments2(x, ids_str, segment_size=4):
54
- ret = torch.zeros_like(x[:, :segment_size])
55
- for i in range(x.size(0)):
56
- idx_str = ids_str[i]
57
- idx_end = idx_str + segment_size
58
- ret[i] = x[i, idx_str:idx_end]
59
- return ret
60
-
61
-
62
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
63
- b, d, t = x.size()
64
- if x_lengths is None:
65
- x_lengths = t
66
- ids_str_max = x_lengths - segment_size + 1
67
- ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
68
- ret = slice_segments(x, ids_str, segment_size)
69
- return ret, ids_str
70
-
71
-
72
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
73
- position = torch.arange(length, dtype=torch.float)
74
- num_timescales = channels // 2
75
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
76
- num_timescales - 1
77
- )
78
- inv_timescales = min_timescale * torch.exp(
79
- torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
80
- )
81
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
82
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
83
- signal = F.pad(signal, [0, 0, 0, channels % 2])
84
- signal = signal.view(1, channels, length)
85
- return signal
86
-
87
-
88
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
89
- b, channels, length = x.size()
90
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
91
- return x + signal.to(dtype=x.dtype, device=x.device)
92
-
93
-
94
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
95
- b, channels, length = x.size()
96
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
97
- return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
98
-
99
-
100
- def subsequent_mask(length):
101
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
102
- return mask
103
-
104
-
105
- @torch.jit.script
106
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
107
- n_channels_int = n_channels[0]
108
- in_act = input_a + input_b
109
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
110
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
111
- acts = t_act * s_act
112
- return acts
113
-
114
-
115
- def convert_pad_shape(pad_shape):
116
- l = pad_shape[::-1]
117
- pad_shape = [item for sublist in l for item in sublist]
118
- return pad_shape
119
-
120
-
121
- def shift_1d(x):
122
- x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
123
- return x
124
-
125
-
126
- def sequence_mask(length, max_length=None):
127
- if max_length is None:
128
- max_length = length.max()
129
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
130
- return x.unsqueeze(0) < length.unsqueeze(1)
131
-
132
-
133
- def generate_path(duration, mask):
134
- """
135
- duration: [b, 1, t_x]
136
- mask: [b, 1, t_y, t_x]
137
- """
138
- device = duration.device
139
-
140
- b, _, t_y, t_x = mask.shape
141
- cum_duration = torch.cumsum(duration, -1)
142
-
143
- cum_duration_flat = cum_duration.view(b * t_x)
144
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
145
- path = path.view(b, t_x, t_y)
146
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
147
- path = path.unsqueeze(1).transpose(2, 3) * mask
148
- return path
149
-
150
-
151
- def clip_grad_value_(parameters, clip_value, norm_type=2):
152
- if isinstance(parameters, torch.Tensor):
153
- parameters = [parameters]
154
- parameters = list(filter(lambda p: p.grad is not None, parameters))
155
- norm_type = float(norm_type)
156
- if clip_value is not None:
157
- clip_value = float(clip_value)
158
-
159
- total_norm = 0
160
- for p in parameters:
161
- param_norm = p.grad.data.norm(norm_type)
162
- total_norm += param_norm.item() ** norm_type
163
- if clip_value is not None:
164
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
165
- total_norm = total_norm ** (1.0 / norm_type)
166
- return total_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/models.py DELETED
@@ -1,1142 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from lib.infer_pack import modules
7
- from lib.infer_pack import attentions
8
- from lib.infer_pack import commons
9
- from lib.infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from lib.infer_pack.commons import init_weights
13
- import numpy as np
14
- from lib.infer_pack import commons
15
-
16
-
17
- class TextEncoder256(nn.Module):
18
- def __init__(
19
- self,
20
- out_channels,
21
- hidden_channels,
22
- filter_channels,
23
- n_heads,
24
- n_layers,
25
- kernel_size,
26
- p_dropout,
27
- f0=True,
28
- ):
29
- super().__init__()
30
- self.out_channels = out_channels
31
- self.hidden_channels = hidden_channels
32
- self.filter_channels = filter_channels
33
- self.n_heads = n_heads
34
- self.n_layers = n_layers
35
- self.kernel_size = kernel_size
36
- self.p_dropout = p_dropout
37
- self.emb_phone = nn.Linear(256, hidden_channels)
38
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
- if f0 == True:
40
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
- self.encoder = attentions.Encoder(
42
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
- )
44
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
-
46
- def forward(self, phone, pitch, lengths):
47
- if pitch == None:
48
- x = self.emb_phone(phone)
49
- else:
50
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
- x = self.lrelu(x)
53
- x = torch.transpose(x, 1, -1) # [b, h, t]
54
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
- x.dtype
56
- )
57
- x = self.encoder(x * x_mask, x_mask)
58
- stats = self.proj(x) * x_mask
59
-
60
- m, logs = torch.split(stats, self.out_channels, dim=1)
61
- return m, logs, x_mask
62
-
63
-
64
- class TextEncoder768(nn.Module):
65
- def __init__(
66
- self,
67
- out_channels,
68
- hidden_channels,
69
- filter_channels,
70
- n_heads,
71
- n_layers,
72
- kernel_size,
73
- p_dropout,
74
- f0=True,
75
- ):
76
- super().__init__()
77
- self.out_channels = out_channels
78
- self.hidden_channels = hidden_channels
79
- self.filter_channels = filter_channels
80
- self.n_heads = n_heads
81
- self.n_layers = n_layers
82
- self.kernel_size = kernel_size
83
- self.p_dropout = p_dropout
84
- self.emb_phone = nn.Linear(768, hidden_channels)
85
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
- if f0 == True:
87
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
- self.encoder = attentions.Encoder(
89
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
- )
91
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
-
93
- def forward(self, phone, pitch, lengths):
94
- if pitch == None:
95
- x = self.emb_phone(phone)
96
- else:
97
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
- x = self.lrelu(x)
100
- x = torch.transpose(x, 1, -1) # [b, h, t]
101
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
- x.dtype
103
- )
104
- x = self.encoder(x * x_mask, x_mask)
105
- stats = self.proj(x) * x_mask
106
-
107
- m, logs = torch.split(stats, self.out_channels, dim=1)
108
- return m, logs, x_mask
109
-
110
-
111
- class ResidualCouplingBlock(nn.Module):
112
- def __init__(
113
- self,
114
- channels,
115
- hidden_channels,
116
- kernel_size,
117
- dilation_rate,
118
- n_layers,
119
- n_flows=4,
120
- gin_channels=0,
121
- ):
122
- super().__init__()
123
- self.channels = channels
124
- self.hidden_channels = hidden_channels
125
- self.kernel_size = kernel_size
126
- self.dilation_rate = dilation_rate
127
- self.n_layers = n_layers
128
- self.n_flows = n_flows
129
- self.gin_channels = gin_channels
130
-
131
- self.flows = nn.ModuleList()
132
- for i in range(n_flows):
133
- self.flows.append(
134
- modules.ResidualCouplingLayer(
135
- channels,
136
- hidden_channels,
137
- kernel_size,
138
- dilation_rate,
139
- n_layers,
140
- gin_channels=gin_channels,
141
- mean_only=True,
142
- )
143
- )
144
- self.flows.append(modules.Flip())
145
-
146
- def forward(self, x, x_mask, g=None, reverse=False):
147
- if not reverse:
148
- for flow in self.flows:
149
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
- else:
151
- for flow in reversed(self.flows):
152
- x = flow(x, x_mask, g=g, reverse=reverse)
153
- return x
154
-
155
- def remove_weight_norm(self):
156
- for i in range(self.n_flows):
157
- self.flows[i * 2].remove_weight_norm()
158
-
159
-
160
- class PosteriorEncoder(nn.Module):
161
- def __init__(
162
- self,
163
- in_channels,
164
- out_channels,
165
- hidden_channels,
166
- kernel_size,
167
- dilation_rate,
168
- n_layers,
169
- gin_channels=0,
170
- ):
171
- super().__init__()
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.hidden_channels = hidden_channels
175
- self.kernel_size = kernel_size
176
- self.dilation_rate = dilation_rate
177
- self.n_layers = n_layers
178
- self.gin_channels = gin_channels
179
-
180
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
- self.enc = modules.WN(
182
- hidden_channels,
183
- kernel_size,
184
- dilation_rate,
185
- n_layers,
186
- gin_channels=gin_channels,
187
- )
188
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
-
190
- def forward(self, x, x_lengths, g=None):
191
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
- x.dtype
193
- )
194
- x = self.pre(x) * x_mask
195
- x = self.enc(x, x_mask, g=g)
196
- stats = self.proj(x) * x_mask
197
- m, logs = torch.split(stats, self.out_channels, dim=1)
198
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
- return z, m, logs, x_mask
200
-
201
- def remove_weight_norm(self):
202
- self.enc.remove_weight_norm()
203
-
204
-
205
- class Generator(torch.nn.Module):
206
- def __init__(
207
- self,
208
- initial_channel,
209
- resblock,
210
- resblock_kernel_sizes,
211
- resblock_dilation_sizes,
212
- upsample_rates,
213
- upsample_initial_channel,
214
- upsample_kernel_sizes,
215
- gin_channels=0,
216
- ):
217
- super(Generator, self).__init__()
218
- self.num_kernels = len(resblock_kernel_sizes)
219
- self.num_upsamples = len(upsample_rates)
220
- self.conv_pre = Conv1d(
221
- initial_channel, upsample_initial_channel, 7, 1, padding=3
222
- )
223
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
-
225
- self.ups = nn.ModuleList()
226
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
- self.ups.append(
228
- weight_norm(
229
- ConvTranspose1d(
230
- upsample_initial_channel // (2**i),
231
- upsample_initial_channel // (2 ** (i + 1)),
232
- k,
233
- u,
234
- padding=(k - u) // 2,
235
- )
236
- )
237
- )
238
-
239
- self.resblocks = nn.ModuleList()
240
- for i in range(len(self.ups)):
241
- ch = upsample_initial_channel // (2 ** (i + 1))
242
- for j, (k, d) in enumerate(
243
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
- ):
245
- self.resblocks.append(resblock(ch, k, d))
246
-
247
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
- self.ups.apply(init_weights)
249
-
250
- if gin_channels != 0:
251
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
-
253
- def forward(self, x, g=None):
254
- x = self.conv_pre(x)
255
- if g is not None:
256
- x = x + self.cond(g)
257
-
258
- for i in range(self.num_upsamples):
259
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
- x = self.ups[i](x)
261
- xs = None
262
- for j in range(self.num_kernels):
263
- if xs is None:
264
- xs = self.resblocks[i * self.num_kernels + j](x)
265
- else:
266
- xs += self.resblocks[i * self.num_kernels + j](x)
267
- x = xs / self.num_kernels
268
- x = F.leaky_relu(x)
269
- x = self.conv_post(x)
270
- x = torch.tanh(x)
271
-
272
- return x
273
-
274
- def remove_weight_norm(self):
275
- for l in self.ups:
276
- remove_weight_norm(l)
277
- for l in self.resblocks:
278
- l.remove_weight_norm()
279
-
280
-
281
- class SineGen(torch.nn.Module):
282
- """Definition of sine generator
283
- SineGen(samp_rate, harmonic_num = 0,
284
- sine_amp = 0.1, noise_std = 0.003,
285
- voiced_threshold = 0,
286
- flag_for_pulse=False)
287
- samp_rate: sampling rate in Hz
288
- harmonic_num: number of harmonic overtones (default 0)
289
- sine_amp: amplitude of sine-wavefrom (default 0.1)
290
- noise_std: std of Gaussian noise (default 0.003)
291
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
- Note: when flag_for_pulse is True, the first time step of a voiced
294
- segment is always sin(np.pi) or cos(0)
295
- """
296
-
297
- def __init__(
298
- self,
299
- samp_rate,
300
- harmonic_num=0,
301
- sine_amp=0.1,
302
- noise_std=0.003,
303
- voiced_threshold=0,
304
- flag_for_pulse=False,
305
- ):
306
- super(SineGen, self).__init__()
307
- self.sine_amp = sine_amp
308
- self.noise_std = noise_std
309
- self.harmonic_num = harmonic_num
310
- self.dim = self.harmonic_num + 1
311
- self.sampling_rate = samp_rate
312
- self.voiced_threshold = voiced_threshold
313
-
314
- def _f02uv(self, f0):
315
- # generate uv signal
316
- uv = torch.ones_like(f0)
317
- uv = uv * (f0 > self.voiced_threshold)
318
- return uv
319
-
320
- def forward(self, f0, upp):
321
- """sine_tensor, uv = forward(f0)
322
- input F0: tensor(batchsize=1, length, dim=1)
323
- f0 for unvoiced steps should be 0
324
- output sine_tensor: tensor(batchsize=1, length, dim)
325
- output uv: tensor(batchsize=1, length, 1)
326
- """
327
- with torch.no_grad():
328
- f0 = f0[:, None].transpose(1, 2)
329
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
- # fundamental component
331
- f0_buf[:, :, 0] = f0[:, :, 0]
332
- for idx in np.arange(self.harmonic_num):
333
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
- idx + 2
335
- ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
- rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
337
- rand_ini = torch.rand(
338
- f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
- )
340
- rand_ini[:, 0] = 0
341
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
- tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
343
- tmp_over_one *= upp
344
- tmp_over_one = F.interpolate(
345
- tmp_over_one.transpose(2, 1),
346
- scale_factor=upp,
347
- mode="linear",
348
- align_corners=True,
349
- ).transpose(2, 1)
350
- rad_values = F.interpolate(
351
- rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
- ).transpose(
353
- 2, 1
354
- ) #######
355
- tmp_over_one %= 1
356
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
- cumsum_shift = torch.zeros_like(rad_values)
358
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
- sine_waves = torch.sin(
360
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
- )
362
- sine_waves = sine_waves * self.sine_amp
363
- uv = self._f02uv(f0)
364
- uv = F.interpolate(
365
- uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
- ).transpose(2, 1)
367
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
- noise = noise_amp * torch.randn_like(sine_waves)
369
- sine_waves = sine_waves * uv + noise
370
- return sine_waves, uv, noise
371
-
372
-
373
- class SourceModuleHnNSF(torch.nn.Module):
374
- """SourceModule for hn-nsf
375
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
- add_noise_std=0.003, voiced_threshod=0)
377
- sampling_rate: sampling_rate in Hz
378
- harmonic_num: number of harmonic above F0 (default: 0)
379
- sine_amp: amplitude of sine source signal (default: 0.1)
380
- add_noise_std: std of additive Gaussian noise (default: 0.003)
381
- note that amplitude of noise in unvoiced is decided
382
- by sine_amp
383
- voiced_threshold: threhold to set U/V given F0 (default: 0)
384
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
- F0_sampled (batchsize, length, 1)
386
- Sine_source (batchsize, length, 1)
387
- noise_source (batchsize, length 1)
388
- uv (batchsize, length, 1)
389
- """
390
-
391
- def __init__(
392
- self,
393
- sampling_rate,
394
- harmonic_num=0,
395
- sine_amp=0.1,
396
- add_noise_std=0.003,
397
- voiced_threshod=0,
398
- is_half=True,
399
- ):
400
- super(SourceModuleHnNSF, self).__init__()
401
-
402
- self.sine_amp = sine_amp
403
- self.noise_std = add_noise_std
404
- self.is_half = is_half
405
- # to produce sine waveforms
406
- self.l_sin_gen = SineGen(
407
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
- )
409
-
410
- # to merge source harmonics into a single excitation
411
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
- self.l_tanh = torch.nn.Tanh()
413
-
414
- def forward(self, x, upp=None):
415
- sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
- if self.is_half:
417
- sine_wavs = sine_wavs.half()
418
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
- return sine_merge, None, None # noise, uv
420
-
421
-
422
- class GeneratorNSF(torch.nn.Module):
423
- def __init__(
424
- self,
425
- initial_channel,
426
- resblock,
427
- resblock_kernel_sizes,
428
- resblock_dilation_sizes,
429
- upsample_rates,
430
- upsample_initial_channel,
431
- upsample_kernel_sizes,
432
- gin_channels,
433
- sr,
434
- is_half=False,
435
- ):
436
- super(GeneratorNSF, self).__init__()
437
- self.num_kernels = len(resblock_kernel_sizes)
438
- self.num_upsamples = len(upsample_rates)
439
-
440
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
- self.m_source = SourceModuleHnNSF(
442
- sampling_rate=sr, harmonic_num=0, is_half=is_half
443
- )
444
- self.noise_convs = nn.ModuleList()
445
- self.conv_pre = Conv1d(
446
- initial_channel, upsample_initial_channel, 7, 1, padding=3
447
- )
448
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
-
450
- self.ups = nn.ModuleList()
451
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
- c_cur = upsample_initial_channel // (2 ** (i + 1))
453
- self.ups.append(
454
- weight_norm(
455
- ConvTranspose1d(
456
- upsample_initial_channel // (2**i),
457
- upsample_initial_channel // (2 ** (i + 1)),
458
- k,
459
- u,
460
- padding=(k - u) // 2,
461
- )
462
- )
463
- )
464
- if i + 1 < len(upsample_rates):
465
- stride_f0 = np.prod(upsample_rates[i + 1 :])
466
- self.noise_convs.append(
467
- Conv1d(
468
- 1,
469
- c_cur,
470
- kernel_size=stride_f0 * 2,
471
- stride=stride_f0,
472
- padding=stride_f0 // 2,
473
- )
474
- )
475
- else:
476
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
-
478
- self.resblocks = nn.ModuleList()
479
- for i in range(len(self.ups)):
480
- ch = upsample_initial_channel // (2 ** (i + 1))
481
- for j, (k, d) in enumerate(
482
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
- ):
484
- self.resblocks.append(resblock(ch, k, d))
485
-
486
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
- self.ups.apply(init_weights)
488
-
489
- if gin_channels != 0:
490
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
-
492
- self.upp = np.prod(upsample_rates)
493
-
494
- def forward(self, x, f0, g=None):
495
- har_source, noi_source, uv = self.m_source(f0, self.upp)
496
- har_source = har_source.transpose(1, 2)
497
- x = self.conv_pre(x)
498
- if g is not None:
499
- x = x + self.cond(g)
500
-
501
- for i in range(self.num_upsamples):
502
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
- x = self.ups[i](x)
504
- x_source = self.noise_convs[i](har_source)
505
- x = x + x_source
506
- xs = None
507
- for j in range(self.num_kernels):
508
- if xs is None:
509
- xs = self.resblocks[i * self.num_kernels + j](x)
510
- else:
511
- xs += self.resblocks[i * self.num_kernels + j](x)
512
- x = xs / self.num_kernels
513
- x = F.leaky_relu(x)
514
- x = self.conv_post(x)
515
- x = torch.tanh(x)
516
- return x
517
-
518
- def remove_weight_norm(self):
519
- for l in self.ups:
520
- remove_weight_norm(l)
521
- for l in self.resblocks:
522
- l.remove_weight_norm()
523
-
524
-
525
- sr2sr = {
526
- "32k": 32000,
527
- "40k": 40000,
528
- "48k": 48000,
529
- }
530
-
531
-
532
- class SynthesizerTrnMs256NSFsid(nn.Module):
533
- def __init__(
534
- self,
535
- spec_channels,
536
- segment_size,
537
- inter_channels,
538
- hidden_channels,
539
- filter_channels,
540
- n_heads,
541
- n_layers,
542
- kernel_size,
543
- p_dropout,
544
- resblock,
545
- resblock_kernel_sizes,
546
- resblock_dilation_sizes,
547
- upsample_rates,
548
- upsample_initial_channel,
549
- upsample_kernel_sizes,
550
- spk_embed_dim,
551
- gin_channels,
552
- sr,
553
- **kwargs
554
- ):
555
- super().__init__()
556
- if type(sr) == type("strr"):
557
- sr = sr2sr[sr]
558
- self.spec_channels = spec_channels
559
- self.inter_channels = inter_channels
560
- self.hidden_channels = hidden_channels
561
- self.filter_channels = filter_channels
562
- self.n_heads = n_heads
563
- self.n_layers = n_layers
564
- self.kernel_size = kernel_size
565
- self.p_dropout = p_dropout
566
- self.resblock = resblock
567
- self.resblock_kernel_sizes = resblock_kernel_sizes
568
- self.resblock_dilation_sizes = resblock_dilation_sizes
569
- self.upsample_rates = upsample_rates
570
- self.upsample_initial_channel = upsample_initial_channel
571
- self.upsample_kernel_sizes = upsample_kernel_sizes
572
- self.segment_size = segment_size
573
- self.gin_channels = gin_channels
574
- # self.hop_length = hop_length#
575
- self.spk_embed_dim = spk_embed_dim
576
- self.enc_p = TextEncoder256(
577
- inter_channels,
578
- hidden_channels,
579
- filter_channels,
580
- n_heads,
581
- n_layers,
582
- kernel_size,
583
- p_dropout,
584
- )
585
- self.dec = GeneratorNSF(
586
- inter_channels,
587
- resblock,
588
- resblock_kernel_sizes,
589
- resblock_dilation_sizes,
590
- upsample_rates,
591
- upsample_initial_channel,
592
- upsample_kernel_sizes,
593
- gin_channels=gin_channels,
594
- sr=sr,
595
- is_half=kwargs["is_half"],
596
- )
597
- self.enc_q = PosteriorEncoder(
598
- spec_channels,
599
- inter_channels,
600
- hidden_channels,
601
- 5,
602
- 1,
603
- 16,
604
- gin_channels=gin_channels,
605
- )
606
- self.flow = ResidualCouplingBlock(
607
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
608
- )
609
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
610
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
611
-
612
- def remove_weight_norm(self):
613
- self.dec.remove_weight_norm()
614
- self.flow.remove_weight_norm()
615
- self.enc_q.remove_weight_norm()
616
-
617
- def forward(
618
- self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
619
- ): # 这里ds是id,[bs,1]
620
- # print(1,pitch.shape)#[bs,t]
621
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
622
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
623
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
624
- z_p = self.flow(z, y_mask, g=g)
625
- z_slice, ids_slice = commons.rand_slice_segments(
626
- z, y_lengths, self.segment_size
627
- )
628
- # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
629
- pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
630
- # print(-2,pitchf.shape,z_slice.shape)
631
- o = self.dec(z_slice, pitchf, g=g)
632
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
633
-
634
- def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
635
- g = self.emb_g(sid).unsqueeze(-1)
636
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
637
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
638
- if rate:
639
- head = int(z_p.shape[2] * rate)
640
- z_p = z_p[:, :, -head:]
641
- x_mask = x_mask[:, :, -head:]
642
- nsff0 = nsff0[:, -head:]
643
- z = self.flow(z_p, x_mask, g=g, reverse=True)
644
- o = self.dec(z * x_mask, nsff0, g=g)
645
- return o, x_mask, (z, z_p, m_p, logs_p)
646
-
647
-
648
- class SynthesizerTrnMs768NSFsid(nn.Module):
649
- def __init__(
650
- self,
651
- spec_channels,
652
- segment_size,
653
- inter_channels,
654
- hidden_channels,
655
- filter_channels,
656
- n_heads,
657
- n_layers,
658
- kernel_size,
659
- p_dropout,
660
- resblock,
661
- resblock_kernel_sizes,
662
- resblock_dilation_sizes,
663
- upsample_rates,
664
- upsample_initial_channel,
665
- upsample_kernel_sizes,
666
- spk_embed_dim,
667
- gin_channels,
668
- sr,
669
- **kwargs
670
- ):
671
- super().__init__()
672
- if type(sr) == type("strr"):
673
- sr = sr2sr[sr]
674
- self.spec_channels = spec_channels
675
- self.inter_channels = inter_channels
676
- self.hidden_channels = hidden_channels
677
- self.filter_channels = filter_channels
678
- self.n_heads = n_heads
679
- self.n_layers = n_layers
680
- self.kernel_size = kernel_size
681
- self.p_dropout = p_dropout
682
- self.resblock = resblock
683
- self.resblock_kernel_sizes = resblock_kernel_sizes
684
- self.resblock_dilation_sizes = resblock_dilation_sizes
685
- self.upsample_rates = upsample_rates
686
- self.upsample_initial_channel = upsample_initial_channel
687
- self.upsample_kernel_sizes = upsample_kernel_sizes
688
- self.segment_size = segment_size
689
- self.gin_channels = gin_channels
690
- # self.hop_length = hop_length#
691
- self.spk_embed_dim = spk_embed_dim
692
- self.enc_p = TextEncoder768(
693
- inter_channels,
694
- hidden_channels,
695
- filter_channels,
696
- n_heads,
697
- n_layers,
698
- kernel_size,
699
- p_dropout,
700
- )
701
- self.dec = GeneratorNSF(
702
- inter_channels,
703
- resblock,
704
- resblock_kernel_sizes,
705
- resblock_dilation_sizes,
706
- upsample_rates,
707
- upsample_initial_channel,
708
- upsample_kernel_sizes,
709
- gin_channels=gin_channels,
710
- sr=sr,
711
- is_half=kwargs["is_half"],
712
- )
713
- self.enc_q = PosteriorEncoder(
714
- spec_channels,
715
- inter_channels,
716
- hidden_channels,
717
- 5,
718
- 1,
719
- 16,
720
- gin_channels=gin_channels,
721
- )
722
- self.flow = ResidualCouplingBlock(
723
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
724
- )
725
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
726
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
727
-
728
- def remove_weight_norm(self):
729
- self.dec.remove_weight_norm()
730
- self.flow.remove_weight_norm()
731
- self.enc_q.remove_weight_norm()
732
-
733
- def forward(
734
- self, phone, phone_lengths, pitch, pitchf, y, y_lengths, ds
735
- ): # 这里ds是id,[bs,1]
736
- # print(1,pitch.shape)#[bs,t]
737
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
738
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
739
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
740
- z_p = self.flow(z, y_mask, g=g)
741
- z_slice, ids_slice = commons.rand_slice_segments(
742
- z, y_lengths, self.segment_size
743
- )
744
- # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length)
745
- pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size)
746
- # print(-2,pitchf.shape,z_slice.shape)
747
- o = self.dec(z_slice, pitchf, g=g)
748
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
749
-
750
- def infer(self, phone, phone_lengths, pitch, nsff0, sid, rate=None):
751
- g = self.emb_g(sid).unsqueeze(-1)
752
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
753
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
754
- if rate:
755
- head = int(z_p.shape[2] * rate)
756
- z_p = z_p[:, :, -head:]
757
- x_mask = x_mask[:, :, -head:]
758
- nsff0 = nsff0[:, -head:]
759
- z = self.flow(z_p, x_mask, g=g, reverse=True)
760
- o = self.dec(z * x_mask, nsff0, g=g)
761
- return o, x_mask, (z, z_p, m_p, logs_p)
762
-
763
-
764
- class SynthesizerTrnMs256NSFsid_nono(nn.Module):
765
- def __init__(
766
- self,
767
- spec_channels,
768
- segment_size,
769
- inter_channels,
770
- hidden_channels,
771
- filter_channels,
772
- n_heads,
773
- n_layers,
774
- kernel_size,
775
- p_dropout,
776
- resblock,
777
- resblock_kernel_sizes,
778
- resblock_dilation_sizes,
779
- upsample_rates,
780
- upsample_initial_channel,
781
- upsample_kernel_sizes,
782
- spk_embed_dim,
783
- gin_channels,
784
- sr=None,
785
- **kwargs
786
- ):
787
- super().__init__()
788
- self.spec_channels = spec_channels
789
- self.inter_channels = inter_channels
790
- self.hidden_channels = hidden_channels
791
- self.filter_channels = filter_channels
792
- self.n_heads = n_heads
793
- self.n_layers = n_layers
794
- self.kernel_size = kernel_size
795
- self.p_dropout = p_dropout
796
- self.resblock = resblock
797
- self.resblock_kernel_sizes = resblock_kernel_sizes
798
- self.resblock_dilation_sizes = resblock_dilation_sizes
799
- self.upsample_rates = upsample_rates
800
- self.upsample_initial_channel = upsample_initial_channel
801
- self.upsample_kernel_sizes = upsample_kernel_sizes
802
- self.segment_size = segment_size
803
- self.gin_channels = gin_channels
804
- # self.hop_length = hop_length#
805
- self.spk_embed_dim = spk_embed_dim
806
- self.enc_p = TextEncoder256(
807
- inter_channels,
808
- hidden_channels,
809
- filter_channels,
810
- n_heads,
811
- n_layers,
812
- kernel_size,
813
- p_dropout,
814
- f0=False,
815
- )
816
- self.dec = Generator(
817
- inter_channels,
818
- resblock,
819
- resblock_kernel_sizes,
820
- resblock_dilation_sizes,
821
- upsample_rates,
822
- upsample_initial_channel,
823
- upsample_kernel_sizes,
824
- gin_channels=gin_channels,
825
- )
826
- self.enc_q = PosteriorEncoder(
827
- spec_channels,
828
- inter_channels,
829
- hidden_channels,
830
- 5,
831
- 1,
832
- 16,
833
- gin_channels=gin_channels,
834
- )
835
- self.flow = ResidualCouplingBlock(
836
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
837
- )
838
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
839
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
840
-
841
- def remove_weight_norm(self):
842
- self.dec.remove_weight_norm()
843
- self.flow.remove_weight_norm()
844
- self.enc_q.remove_weight_norm()
845
-
846
- def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
847
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
848
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
849
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
850
- z_p = self.flow(z, y_mask, g=g)
851
- z_slice, ids_slice = commons.rand_slice_segments(
852
- z, y_lengths, self.segment_size
853
- )
854
- o = self.dec(z_slice, g=g)
855
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
856
-
857
- def infer(self, phone, phone_lengths, sid, rate=None):
858
- g = self.emb_g(sid).unsqueeze(-1)
859
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
860
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
861
- if rate:
862
- head = int(z_p.shape[2] * rate)
863
- z_p = z_p[:, :, -head:]
864
- x_mask = x_mask[:, :, -head:]
865
- z = self.flow(z_p, x_mask, g=g, reverse=True)
866
- o = self.dec(z * x_mask, g=g)
867
- return o, x_mask, (z, z_p, m_p, logs_p)
868
-
869
-
870
- class SynthesizerTrnMs768NSFsid_nono(nn.Module):
871
- def __init__(
872
- self,
873
- spec_channels,
874
- segment_size,
875
- inter_channels,
876
- hidden_channels,
877
- filter_channels,
878
- n_heads,
879
- n_layers,
880
- kernel_size,
881
- p_dropout,
882
- resblock,
883
- resblock_kernel_sizes,
884
- resblock_dilation_sizes,
885
- upsample_rates,
886
- upsample_initial_channel,
887
- upsample_kernel_sizes,
888
- spk_embed_dim,
889
- gin_channels,
890
- sr=None,
891
- **kwargs
892
- ):
893
- super().__init__()
894
- self.spec_channels = spec_channels
895
- self.inter_channels = inter_channels
896
- self.hidden_channels = hidden_channels
897
- self.filter_channels = filter_channels
898
- self.n_heads = n_heads
899
- self.n_layers = n_layers
900
- self.kernel_size = kernel_size
901
- self.p_dropout = p_dropout
902
- self.resblock = resblock
903
- self.resblock_kernel_sizes = resblock_kernel_sizes
904
- self.resblock_dilation_sizes = resblock_dilation_sizes
905
- self.upsample_rates = upsample_rates
906
- self.upsample_initial_channel = upsample_initial_channel
907
- self.upsample_kernel_sizes = upsample_kernel_sizes
908
- self.segment_size = segment_size
909
- self.gin_channels = gin_channels
910
- # self.hop_length = hop_length#
911
- self.spk_embed_dim = spk_embed_dim
912
- self.enc_p = TextEncoder768(
913
- inter_channels,
914
- hidden_channels,
915
- filter_channels,
916
- n_heads,
917
- n_layers,
918
- kernel_size,
919
- p_dropout,
920
- f0=False,
921
- )
922
- self.dec = Generator(
923
- inter_channels,
924
- resblock,
925
- resblock_kernel_sizes,
926
- resblock_dilation_sizes,
927
- upsample_rates,
928
- upsample_initial_channel,
929
- upsample_kernel_sizes,
930
- gin_channels=gin_channels,
931
- )
932
- self.enc_q = PosteriorEncoder(
933
- spec_channels,
934
- inter_channels,
935
- hidden_channels,
936
- 5,
937
- 1,
938
- 16,
939
- gin_channels=gin_channels,
940
- )
941
- self.flow = ResidualCouplingBlock(
942
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
943
- )
944
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
945
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
946
-
947
- def remove_weight_norm(self):
948
- self.dec.remove_weight_norm()
949
- self.flow.remove_weight_norm()
950
- self.enc_q.remove_weight_norm()
951
-
952
- def forward(self, phone, phone_lengths, y, y_lengths, ds): # 这里ds是id,[bs,1]
953
- g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的
954
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
955
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
956
- z_p = self.flow(z, y_mask, g=g)
957
- z_slice, ids_slice = commons.rand_slice_segments(
958
- z, y_lengths, self.segment_size
959
- )
960
- o = self.dec(z_slice, g=g)
961
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
962
-
963
- def infer(self, phone, phone_lengths, sid, rate=None):
964
- g = self.emb_g(sid).unsqueeze(-1)
965
- m_p, logs_p, x_mask = self.enc_p(phone, None, phone_lengths)
966
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
967
- if rate:
968
- head = int(z_p.shape[2] * rate)
969
- z_p = z_p[:, :, -head:]
970
- x_mask = x_mask[:, :, -head:]
971
- z = self.flow(z_p, x_mask, g=g, reverse=True)
972
- o = self.dec(z * x_mask, g=g)
973
- return o, x_mask, (z, z_p, m_p, logs_p)
974
-
975
-
976
- class MultiPeriodDiscriminator(torch.nn.Module):
977
- def __init__(self, use_spectral_norm=False):
978
- super(MultiPeriodDiscriminator, self).__init__()
979
- periods = [2, 3, 5, 7, 11, 17]
980
- # periods = [3, 5, 7, 11, 17, 23, 37]
981
-
982
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
983
- discs = discs + [
984
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
985
- ]
986
- self.discriminators = nn.ModuleList(discs)
987
-
988
- def forward(self, y, y_hat):
989
- y_d_rs = [] #
990
- y_d_gs = []
991
- fmap_rs = []
992
- fmap_gs = []
993
- for i, d in enumerate(self.discriminators):
994
- y_d_r, fmap_r = d(y)
995
- y_d_g, fmap_g = d(y_hat)
996
- # for j in range(len(fmap_r)):
997
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
998
- y_d_rs.append(y_d_r)
999
- y_d_gs.append(y_d_g)
1000
- fmap_rs.append(fmap_r)
1001
- fmap_gs.append(fmap_g)
1002
-
1003
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1004
-
1005
-
1006
- class MultiPeriodDiscriminatorV2(torch.nn.Module):
1007
- def __init__(self, use_spectral_norm=False):
1008
- super(MultiPeriodDiscriminatorV2, self).__init__()
1009
- # periods = [2, 3, 5, 7, 11, 17]
1010
- periods = [2, 3, 5, 7, 11, 17, 23, 37]
1011
-
1012
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1013
- discs = discs + [
1014
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1015
- ]
1016
- self.discriminators = nn.ModuleList(discs)
1017
-
1018
- def forward(self, y, y_hat):
1019
- y_d_rs = [] #
1020
- y_d_gs = []
1021
- fmap_rs = []
1022
- fmap_gs = []
1023
- for i, d in enumerate(self.discriminators):
1024
- y_d_r, fmap_r = d(y)
1025
- y_d_g, fmap_g = d(y_hat)
1026
- # for j in range(len(fmap_r)):
1027
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
1028
- y_d_rs.append(y_d_r)
1029
- y_d_gs.append(y_d_g)
1030
- fmap_rs.append(fmap_r)
1031
- fmap_gs.append(fmap_g)
1032
-
1033
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1034
-
1035
-
1036
- class DiscriminatorS(torch.nn.Module):
1037
- def __init__(self, use_spectral_norm=False):
1038
- super(DiscriminatorS, self).__init__()
1039
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1040
- self.convs = nn.ModuleList(
1041
- [
1042
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1043
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1044
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1045
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1046
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1047
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1048
- ]
1049
- )
1050
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1051
-
1052
- def forward(self, x):
1053
- fmap = []
1054
-
1055
- for l in self.convs:
1056
- x = l(x)
1057
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
1058
- fmap.append(x)
1059
- x = self.conv_post(x)
1060
- fmap.append(x)
1061
- x = torch.flatten(x, 1, -1)
1062
-
1063
- return x, fmap
1064
-
1065
-
1066
- class DiscriminatorP(torch.nn.Module):
1067
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
1068
- super(DiscriminatorP, self).__init__()
1069
- self.period = period
1070
- self.use_spectral_norm = use_spectral_norm
1071
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1072
- self.convs = nn.ModuleList(
1073
- [
1074
- norm_f(
1075
- Conv2d(
1076
- 1,
1077
- 32,
1078
- (kernel_size, 1),
1079
- (stride, 1),
1080
- padding=(get_padding(kernel_size, 1), 0),
1081
- )
1082
- ),
1083
- norm_f(
1084
- Conv2d(
1085
- 32,
1086
- 128,
1087
- (kernel_size, 1),
1088
- (stride, 1),
1089
- padding=(get_padding(kernel_size, 1), 0),
1090
- )
1091
- ),
1092
- norm_f(
1093
- Conv2d(
1094
- 128,
1095
- 512,
1096
- (kernel_size, 1),
1097
- (stride, 1),
1098
- padding=(get_padding(kernel_size, 1), 0),
1099
- )
1100
- ),
1101
- norm_f(
1102
- Conv2d(
1103
- 512,
1104
- 1024,
1105
- (kernel_size, 1),
1106
- (stride, 1),
1107
- padding=(get_padding(kernel_size, 1), 0),
1108
- )
1109
- ),
1110
- norm_f(
1111
- Conv2d(
1112
- 1024,
1113
- 1024,
1114
- (kernel_size, 1),
1115
- 1,
1116
- padding=(get_padding(kernel_size, 1), 0),
1117
- )
1118
- ),
1119
- ]
1120
- )
1121
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1122
-
1123
- def forward(self, x):
1124
- fmap = []
1125
-
1126
- # 1d to 2d
1127
- b, c, t = x.shape
1128
- if t % self.period != 0: # pad first
1129
- n_pad = self.period - (t % self.period)
1130
- x = F.pad(x, (0, n_pad), "reflect")
1131
- t = t + n_pad
1132
- x = x.view(b, c, t // self.period, self.period)
1133
-
1134
- for l in self.convs:
1135
- x = l(x)
1136
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
1137
- fmap.append(x)
1138
- x = self.conv_post(x)
1139
- fmap.append(x)
1140
- x = torch.flatten(x, 1, -1)
1141
-
1142
- return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/models_onnx.py DELETED
@@ -1,819 +0,0 @@
1
- import math, pdb, os
2
- from time import time as ttime
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from lib.infer_pack import modules
7
- from lib.infer_pack import attentions
8
- from lib.infer_pack import commons
9
- from lib.infer_pack.commons import init_weights, get_padding
10
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
11
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
- from lib.infer_pack.commons import init_weights
13
- import numpy as np
14
- from lib.infer_pack import commons
15
-
16
-
17
- class TextEncoder256(nn.Module):
18
- def __init__(
19
- self,
20
- out_channels,
21
- hidden_channels,
22
- filter_channels,
23
- n_heads,
24
- n_layers,
25
- kernel_size,
26
- p_dropout,
27
- f0=True,
28
- ):
29
- super().__init__()
30
- self.out_channels = out_channels
31
- self.hidden_channels = hidden_channels
32
- self.filter_channels = filter_channels
33
- self.n_heads = n_heads
34
- self.n_layers = n_layers
35
- self.kernel_size = kernel_size
36
- self.p_dropout = p_dropout
37
- self.emb_phone = nn.Linear(256, hidden_channels)
38
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
39
- if f0 == True:
40
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
41
- self.encoder = attentions.Encoder(
42
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
43
- )
44
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
45
-
46
- def forward(self, phone, pitch, lengths):
47
- if pitch == None:
48
- x = self.emb_phone(phone)
49
- else:
50
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
51
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
52
- x = self.lrelu(x)
53
- x = torch.transpose(x, 1, -1) # [b, h, t]
54
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
55
- x.dtype
56
- )
57
- x = self.encoder(x * x_mask, x_mask)
58
- stats = self.proj(x) * x_mask
59
-
60
- m, logs = torch.split(stats, self.out_channels, dim=1)
61
- return m, logs, x_mask
62
-
63
-
64
- class TextEncoder768(nn.Module):
65
- def __init__(
66
- self,
67
- out_channels,
68
- hidden_channels,
69
- filter_channels,
70
- n_heads,
71
- n_layers,
72
- kernel_size,
73
- p_dropout,
74
- f0=True,
75
- ):
76
- super().__init__()
77
- self.out_channels = out_channels
78
- self.hidden_channels = hidden_channels
79
- self.filter_channels = filter_channels
80
- self.n_heads = n_heads
81
- self.n_layers = n_layers
82
- self.kernel_size = kernel_size
83
- self.p_dropout = p_dropout
84
- self.emb_phone = nn.Linear(768, hidden_channels)
85
- self.lrelu = nn.LeakyReLU(0.1, inplace=True)
86
- if f0 == True:
87
- self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256
88
- self.encoder = attentions.Encoder(
89
- hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
90
- )
91
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
92
-
93
- def forward(self, phone, pitch, lengths):
94
- if pitch == None:
95
- x = self.emb_phone(phone)
96
- else:
97
- x = self.emb_phone(phone) + self.emb_pitch(pitch)
98
- x = x * math.sqrt(self.hidden_channels) # [b, t, h]
99
- x = self.lrelu(x)
100
- x = torch.transpose(x, 1, -1) # [b, h, t]
101
- x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to(
102
- x.dtype
103
- )
104
- x = self.encoder(x * x_mask, x_mask)
105
- stats = self.proj(x) * x_mask
106
-
107
- m, logs = torch.split(stats, self.out_channels, dim=1)
108
- return m, logs, x_mask
109
-
110
-
111
- class ResidualCouplingBlock(nn.Module):
112
- def __init__(
113
- self,
114
- channels,
115
- hidden_channels,
116
- kernel_size,
117
- dilation_rate,
118
- n_layers,
119
- n_flows=4,
120
- gin_channels=0,
121
- ):
122
- super().__init__()
123
- self.channels = channels
124
- self.hidden_channels = hidden_channels
125
- self.kernel_size = kernel_size
126
- self.dilation_rate = dilation_rate
127
- self.n_layers = n_layers
128
- self.n_flows = n_flows
129
- self.gin_channels = gin_channels
130
-
131
- self.flows = nn.ModuleList()
132
- for i in range(n_flows):
133
- self.flows.append(
134
- modules.ResidualCouplingLayer(
135
- channels,
136
- hidden_channels,
137
- kernel_size,
138
- dilation_rate,
139
- n_layers,
140
- gin_channels=gin_channels,
141
- mean_only=True,
142
- )
143
- )
144
- self.flows.append(modules.Flip())
145
-
146
- def forward(self, x, x_mask, g=None, reverse=False):
147
- if not reverse:
148
- for flow in self.flows:
149
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
150
- else:
151
- for flow in reversed(self.flows):
152
- x = flow(x, x_mask, g=g, reverse=reverse)
153
- return x
154
-
155
- def remove_weight_norm(self):
156
- for i in range(self.n_flows):
157
- self.flows[i * 2].remove_weight_norm()
158
-
159
-
160
- class PosteriorEncoder(nn.Module):
161
- def __init__(
162
- self,
163
- in_channels,
164
- out_channels,
165
- hidden_channels,
166
- kernel_size,
167
- dilation_rate,
168
- n_layers,
169
- gin_channels=0,
170
- ):
171
- super().__init__()
172
- self.in_channels = in_channels
173
- self.out_channels = out_channels
174
- self.hidden_channels = hidden_channels
175
- self.kernel_size = kernel_size
176
- self.dilation_rate = dilation_rate
177
- self.n_layers = n_layers
178
- self.gin_channels = gin_channels
179
-
180
- self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
181
- self.enc = modules.WN(
182
- hidden_channels,
183
- kernel_size,
184
- dilation_rate,
185
- n_layers,
186
- gin_channels=gin_channels,
187
- )
188
- self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
189
-
190
- def forward(self, x, x_lengths, g=None):
191
- x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
192
- x.dtype
193
- )
194
- x = self.pre(x) * x_mask
195
- x = self.enc(x, x_mask, g=g)
196
- stats = self.proj(x) * x_mask
197
- m, logs = torch.split(stats, self.out_channels, dim=1)
198
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
199
- return z, m, logs, x_mask
200
-
201
- def remove_weight_norm(self):
202
- self.enc.remove_weight_norm()
203
-
204
-
205
- class Generator(torch.nn.Module):
206
- def __init__(
207
- self,
208
- initial_channel,
209
- resblock,
210
- resblock_kernel_sizes,
211
- resblock_dilation_sizes,
212
- upsample_rates,
213
- upsample_initial_channel,
214
- upsample_kernel_sizes,
215
- gin_channels=0,
216
- ):
217
- super(Generator, self).__init__()
218
- self.num_kernels = len(resblock_kernel_sizes)
219
- self.num_upsamples = len(upsample_rates)
220
- self.conv_pre = Conv1d(
221
- initial_channel, upsample_initial_channel, 7, 1, padding=3
222
- )
223
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
224
-
225
- self.ups = nn.ModuleList()
226
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
227
- self.ups.append(
228
- weight_norm(
229
- ConvTranspose1d(
230
- upsample_initial_channel // (2**i),
231
- upsample_initial_channel // (2 ** (i + 1)),
232
- k,
233
- u,
234
- padding=(k - u) // 2,
235
- )
236
- )
237
- )
238
-
239
- self.resblocks = nn.ModuleList()
240
- for i in range(len(self.ups)):
241
- ch = upsample_initial_channel // (2 ** (i + 1))
242
- for j, (k, d) in enumerate(
243
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
244
- ):
245
- self.resblocks.append(resblock(ch, k, d))
246
-
247
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
248
- self.ups.apply(init_weights)
249
-
250
- if gin_channels != 0:
251
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
252
-
253
- def forward(self, x, g=None):
254
- x = self.conv_pre(x)
255
- if g is not None:
256
- x = x + self.cond(g)
257
-
258
- for i in range(self.num_upsamples):
259
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
260
- x = self.ups[i](x)
261
- xs = None
262
- for j in range(self.num_kernels):
263
- if xs is None:
264
- xs = self.resblocks[i * self.num_kernels + j](x)
265
- else:
266
- xs += self.resblocks[i * self.num_kernels + j](x)
267
- x = xs / self.num_kernels
268
- x = F.leaky_relu(x)
269
- x = self.conv_post(x)
270
- x = torch.tanh(x)
271
-
272
- return x
273
-
274
- def remove_weight_norm(self):
275
- for l in self.ups:
276
- remove_weight_norm(l)
277
- for l in self.resblocks:
278
- l.remove_weight_norm()
279
-
280
-
281
- class SineGen(torch.nn.Module):
282
- """Definition of sine generator
283
- SineGen(samp_rate, harmonic_num = 0,
284
- sine_amp = 0.1, noise_std = 0.003,
285
- voiced_threshold = 0,
286
- flag_for_pulse=False)
287
- samp_rate: sampling rate in Hz
288
- harmonic_num: number of harmonic overtones (default 0)
289
- sine_amp: amplitude of sine-wavefrom (default 0.1)
290
- noise_std: std of Gaussian noise (default 0.003)
291
- voiced_thoreshold: F0 threshold for U/V classification (default 0)
292
- flag_for_pulse: this SinGen is used inside PulseGen (default False)
293
- Note: when flag_for_pulse is True, the first time step of a voiced
294
- segment is always sin(np.pi) or cos(0)
295
- """
296
-
297
- def __init__(
298
- self,
299
- samp_rate,
300
- harmonic_num=0,
301
- sine_amp=0.1,
302
- noise_std=0.003,
303
- voiced_threshold=0,
304
- flag_for_pulse=False,
305
- ):
306
- super(SineGen, self).__init__()
307
- self.sine_amp = sine_amp
308
- self.noise_std = noise_std
309
- self.harmonic_num = harmonic_num
310
- self.dim = self.harmonic_num + 1
311
- self.sampling_rate = samp_rate
312
- self.voiced_threshold = voiced_threshold
313
-
314
- def _f02uv(self, f0):
315
- # generate uv signal
316
- uv = torch.ones_like(f0)
317
- uv = uv * (f0 > self.voiced_threshold)
318
- return uv
319
-
320
- def forward(self, f0, upp):
321
- """sine_tensor, uv = forward(f0)
322
- input F0: tensor(batchsize=1, length, dim=1)
323
- f0 for unvoiced steps should be 0
324
- output sine_tensor: tensor(batchsize=1, length, dim)
325
- output uv: tensor(batchsize=1, length, 1)
326
- """
327
- with torch.no_grad():
328
- f0 = f0[:, None].transpose(1, 2)
329
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
330
- # fundamental component
331
- f0_buf[:, :, 0] = f0[:, :, 0]
332
- for idx in np.arange(self.harmonic_num):
333
- f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (
334
- idx + 2
335
- ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
336
- rad_values = (f0_buf / self.sampling_rate) % 1 ###%1意味着n_har的乘积无法后处理优化
337
- rand_ini = torch.rand(
338
- f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device
339
- )
340
- rand_ini[:, 0] = 0
341
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
342
- tmp_over_one = torch.cumsum(rad_values, 1) # % 1 #####%1意味着后面的cumsum无法再优化
343
- tmp_over_one *= upp
344
- tmp_over_one = F.interpolate(
345
- tmp_over_one.transpose(2, 1),
346
- scale_factor=upp,
347
- mode="linear",
348
- align_corners=True,
349
- ).transpose(2, 1)
350
- rad_values = F.interpolate(
351
- rad_values.transpose(2, 1), scale_factor=upp, mode="nearest"
352
- ).transpose(
353
- 2, 1
354
- ) #######
355
- tmp_over_one %= 1
356
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
357
- cumsum_shift = torch.zeros_like(rad_values)
358
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
359
- sine_waves = torch.sin(
360
- torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi
361
- )
362
- sine_waves = sine_waves * self.sine_amp
363
- uv = self._f02uv(f0)
364
- uv = F.interpolate(
365
- uv.transpose(2, 1), scale_factor=upp, mode="nearest"
366
- ).transpose(2, 1)
367
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
368
- noise = noise_amp * torch.randn_like(sine_waves)
369
- sine_waves = sine_waves * uv + noise
370
- return sine_waves, uv, noise
371
-
372
-
373
- class SourceModuleHnNSF(torch.nn.Module):
374
- """SourceModule for hn-nsf
375
- SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
376
- add_noise_std=0.003, voiced_threshod=0)
377
- sampling_rate: sampling_rate in Hz
378
- harmonic_num: number of harmonic above F0 (default: 0)
379
- sine_amp: amplitude of sine source signal (default: 0.1)
380
- add_noise_std: std of additive Gaussian noise (default: 0.003)
381
- note that amplitude of noise in unvoiced is decided
382
- by sine_amp
383
- voiced_threshold: threhold to set U/V given F0 (default: 0)
384
- Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
385
- F0_sampled (batchsize, length, 1)
386
- Sine_source (batchsize, length, 1)
387
- noise_source (batchsize, length 1)
388
- uv (batchsize, length, 1)
389
- """
390
-
391
- def __init__(
392
- self,
393
- sampling_rate,
394
- harmonic_num=0,
395
- sine_amp=0.1,
396
- add_noise_std=0.003,
397
- voiced_threshod=0,
398
- is_half=True,
399
- ):
400
- super(SourceModuleHnNSF, self).__init__()
401
-
402
- self.sine_amp = sine_amp
403
- self.noise_std = add_noise_std
404
- self.is_half = is_half
405
- # to produce sine waveforms
406
- self.l_sin_gen = SineGen(
407
- sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod
408
- )
409
-
410
- # to merge source harmonics into a single excitation
411
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
412
- self.l_tanh = torch.nn.Tanh()
413
-
414
- def forward(self, x, upp=None):
415
- sine_wavs, uv, _ = self.l_sin_gen(x, upp)
416
- if self.is_half:
417
- sine_wavs = sine_wavs.half()
418
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
419
- return sine_merge, None, None # noise, uv
420
-
421
-
422
- class GeneratorNSF(torch.nn.Module):
423
- def __init__(
424
- self,
425
- initial_channel,
426
- resblock,
427
- resblock_kernel_sizes,
428
- resblock_dilation_sizes,
429
- upsample_rates,
430
- upsample_initial_channel,
431
- upsample_kernel_sizes,
432
- gin_channels,
433
- sr,
434
- is_half=False,
435
- ):
436
- super(GeneratorNSF, self).__init__()
437
- self.num_kernels = len(resblock_kernel_sizes)
438
- self.num_upsamples = len(upsample_rates)
439
-
440
- self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
441
- self.m_source = SourceModuleHnNSF(
442
- sampling_rate=sr, harmonic_num=0, is_half=is_half
443
- )
444
- self.noise_convs = nn.ModuleList()
445
- self.conv_pre = Conv1d(
446
- initial_channel, upsample_initial_channel, 7, 1, padding=3
447
- )
448
- resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
449
-
450
- self.ups = nn.ModuleList()
451
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
452
- c_cur = upsample_initial_channel // (2 ** (i + 1))
453
- self.ups.append(
454
- weight_norm(
455
- ConvTranspose1d(
456
- upsample_initial_channel // (2**i),
457
- upsample_initial_channel // (2 ** (i + 1)),
458
- k,
459
- u,
460
- padding=(k - u) // 2,
461
- )
462
- )
463
- )
464
- if i + 1 < len(upsample_rates):
465
- stride_f0 = np.prod(upsample_rates[i + 1 :])
466
- self.noise_convs.append(
467
- Conv1d(
468
- 1,
469
- c_cur,
470
- kernel_size=stride_f0 * 2,
471
- stride=stride_f0,
472
- padding=stride_f0 // 2,
473
- )
474
- )
475
- else:
476
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
477
-
478
- self.resblocks = nn.ModuleList()
479
- for i in range(len(self.ups)):
480
- ch = upsample_initial_channel // (2 ** (i + 1))
481
- for j, (k, d) in enumerate(
482
- zip(resblock_kernel_sizes, resblock_dilation_sizes)
483
- ):
484
- self.resblocks.append(resblock(ch, k, d))
485
-
486
- self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
487
- self.ups.apply(init_weights)
488
-
489
- if gin_channels != 0:
490
- self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
491
-
492
- self.upp = np.prod(upsample_rates)
493
-
494
- def forward(self, x, f0, g=None):
495
- har_source, noi_source, uv = self.m_source(f0, self.upp)
496
- har_source = har_source.transpose(1, 2)
497
- x = self.conv_pre(x)
498
- if g is not None:
499
- x = x + self.cond(g)
500
-
501
- for i in range(self.num_upsamples):
502
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
503
- x = self.ups[i](x)
504
- x_source = self.noise_convs[i](har_source)
505
- x = x + x_source
506
- xs = None
507
- for j in range(self.num_kernels):
508
- if xs is None:
509
- xs = self.resblocks[i * self.num_kernels + j](x)
510
- else:
511
- xs += self.resblocks[i * self.num_kernels + j](x)
512
- x = xs / self.num_kernels
513
- x = F.leaky_relu(x)
514
- x = self.conv_post(x)
515
- x = torch.tanh(x)
516
- return x
517
-
518
- def remove_weight_norm(self):
519
- for l in self.ups:
520
- remove_weight_norm(l)
521
- for l in self.resblocks:
522
- l.remove_weight_norm()
523
-
524
-
525
- sr2sr = {
526
- "32k": 32000,
527
- "40k": 40000,
528
- "48k": 48000,
529
- }
530
-
531
-
532
- class SynthesizerTrnMsNSFsidM(nn.Module):
533
- def __init__(
534
- self,
535
- spec_channels,
536
- segment_size,
537
- inter_channels,
538
- hidden_channels,
539
- filter_channels,
540
- n_heads,
541
- n_layers,
542
- kernel_size,
543
- p_dropout,
544
- resblock,
545
- resblock_kernel_sizes,
546
- resblock_dilation_sizes,
547
- upsample_rates,
548
- upsample_initial_channel,
549
- upsample_kernel_sizes,
550
- spk_embed_dim,
551
- gin_channels,
552
- sr,
553
- version,
554
- **kwargs
555
- ):
556
- super().__init__()
557
- if type(sr) == type("strr"):
558
- sr = sr2sr[sr]
559
- self.spec_channels = spec_channels
560
- self.inter_channels = inter_channels
561
- self.hidden_channels = hidden_channels
562
- self.filter_channels = filter_channels
563
- self.n_heads = n_heads
564
- self.n_layers = n_layers
565
- self.kernel_size = kernel_size
566
- self.p_dropout = p_dropout
567
- self.resblock = resblock
568
- self.resblock_kernel_sizes = resblock_kernel_sizes
569
- self.resblock_dilation_sizes = resblock_dilation_sizes
570
- self.upsample_rates = upsample_rates
571
- self.upsample_initial_channel = upsample_initial_channel
572
- self.upsample_kernel_sizes = upsample_kernel_sizes
573
- self.segment_size = segment_size
574
- self.gin_channels = gin_channels
575
- # self.hop_length = hop_length#
576
- self.spk_embed_dim = spk_embed_dim
577
- if version == "v1":
578
- self.enc_p = TextEncoder256(
579
- inter_channels,
580
- hidden_channels,
581
- filter_channels,
582
- n_heads,
583
- n_layers,
584
- kernel_size,
585
- p_dropout,
586
- )
587
- else:
588
- self.enc_p = TextEncoder768(
589
- inter_channels,
590
- hidden_channels,
591
- filter_channels,
592
- n_heads,
593
- n_layers,
594
- kernel_size,
595
- p_dropout,
596
- )
597
- self.dec = GeneratorNSF(
598
- inter_channels,
599
- resblock,
600
- resblock_kernel_sizes,
601
- resblock_dilation_sizes,
602
- upsample_rates,
603
- upsample_initial_channel,
604
- upsample_kernel_sizes,
605
- gin_channels=gin_channels,
606
- sr=sr,
607
- is_half=kwargs["is_half"],
608
- )
609
- self.enc_q = PosteriorEncoder(
610
- spec_channels,
611
- inter_channels,
612
- hidden_channels,
613
- 5,
614
- 1,
615
- 16,
616
- gin_channels=gin_channels,
617
- )
618
- self.flow = ResidualCouplingBlock(
619
- inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels
620
- )
621
- self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels)
622
- self.speaker_map = None
623
- print("gin_channels:", gin_channels, "self.spk_embed_dim:", self.spk_embed_dim)
624
-
625
- def remove_weight_norm(self):
626
- self.dec.remove_weight_norm()
627
- self.flow.remove_weight_norm()
628
- self.enc_q.remove_weight_norm()
629
-
630
- def construct_spkmixmap(self, n_speaker):
631
- self.speaker_map = torch.zeros((n_speaker, 1, 1, self.gin_channels))
632
- for i in range(n_speaker):
633
- self.speaker_map[i] = self.emb_g(torch.LongTensor([[i]]))
634
- self.speaker_map = self.speaker_map.unsqueeze(0)
635
-
636
- def forward(self, phone, phone_lengths, pitch, nsff0, g, rnd, max_len=None):
637
- if self.speaker_map is not None: # [N, S] * [S, B, 1, H]
638
- g = g.reshape((g.shape[0], g.shape[1], 1, 1, 1)) # [N, S, B, 1, 1]
639
- g = g * self.speaker_map # [N, S, B, 1, H]
640
- g = torch.sum(g, dim=1) # [N, 1, B, 1, H]
641
- g = g.transpose(0, -1).transpose(0, -2).squeeze(0) # [B, H, N]
642
- else:
643
- g = g.unsqueeze(0)
644
- g = self.emb_g(g).transpose(1, 2)
645
-
646
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
647
- z_p = (m_p + torch.exp(logs_p) * rnd) * x_mask
648
- z = self.flow(z_p, x_mask, g=g, reverse=True)
649
- o = self.dec((z * x_mask)[:, :, :max_len], nsff0, g=g)
650
- return o
651
-
652
-
653
- class MultiPeriodDiscriminator(torch.nn.Module):
654
- def __init__(self, use_spectral_norm=False):
655
- super(MultiPeriodDiscriminator, self).__init__()
656
- periods = [2, 3, 5, 7, 11, 17]
657
- # periods = [3, 5, 7, 11, 17, 23, 37]
658
-
659
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
660
- discs = discs + [
661
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
662
- ]
663
- self.discriminators = nn.ModuleList(discs)
664
-
665
- def forward(self, y, y_hat):
666
- y_d_rs = [] #
667
- y_d_gs = []
668
- fmap_rs = []
669
- fmap_gs = []
670
- for i, d in enumerate(self.discriminators):
671
- y_d_r, fmap_r = d(y)
672
- y_d_g, fmap_g = d(y_hat)
673
- # for j in range(len(fmap_r)):
674
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
675
- y_d_rs.append(y_d_r)
676
- y_d_gs.append(y_d_g)
677
- fmap_rs.append(fmap_r)
678
- fmap_gs.append(fmap_g)
679
-
680
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
681
-
682
-
683
- class MultiPeriodDiscriminatorV2(torch.nn.Module):
684
- def __init__(self, use_spectral_norm=False):
685
- super(MultiPeriodDiscriminatorV2, self).__init__()
686
- # periods = [2, 3, 5, 7, 11, 17]
687
- periods = [2, 3, 5, 7, 11, 17, 23, 37]
688
-
689
- discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
690
- discs = discs + [
691
- DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
692
- ]
693
- self.discriminators = nn.ModuleList(discs)
694
-
695
- def forward(self, y, y_hat):
696
- y_d_rs = [] #
697
- y_d_gs = []
698
- fmap_rs = []
699
- fmap_gs = []
700
- for i, d in enumerate(self.discriminators):
701
- y_d_r, fmap_r = d(y)
702
- y_d_g, fmap_g = d(y_hat)
703
- # for j in range(len(fmap_r)):
704
- # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
705
- y_d_rs.append(y_d_r)
706
- y_d_gs.append(y_d_g)
707
- fmap_rs.append(fmap_r)
708
- fmap_gs.append(fmap_g)
709
-
710
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
711
-
712
-
713
- class DiscriminatorS(torch.nn.Module):
714
- def __init__(self, use_spectral_norm=False):
715
- super(DiscriminatorS, self).__init__()
716
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
717
- self.convs = nn.ModuleList(
718
- [
719
- norm_f(Conv1d(1, 16, 15, 1, padding=7)),
720
- norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
721
- norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
722
- norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
723
- norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
724
- norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
725
- ]
726
- )
727
- self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
728
-
729
- def forward(self, x):
730
- fmap = []
731
-
732
- for l in self.convs:
733
- x = l(x)
734
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
735
- fmap.append(x)
736
- x = self.conv_post(x)
737
- fmap.append(x)
738
- x = torch.flatten(x, 1, -1)
739
-
740
- return x, fmap
741
-
742
-
743
- class DiscriminatorP(torch.nn.Module):
744
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
745
- super(DiscriminatorP, self).__init__()
746
- self.period = period
747
- self.use_spectral_norm = use_spectral_norm
748
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
749
- self.convs = nn.ModuleList(
750
- [
751
- norm_f(
752
- Conv2d(
753
- 1,
754
- 32,
755
- (kernel_size, 1),
756
- (stride, 1),
757
- padding=(get_padding(kernel_size, 1), 0),
758
- )
759
- ),
760
- norm_f(
761
- Conv2d(
762
- 32,
763
- 128,
764
- (kernel_size, 1),
765
- (stride, 1),
766
- padding=(get_padding(kernel_size, 1), 0),
767
- )
768
- ),
769
- norm_f(
770
- Conv2d(
771
- 128,
772
- 512,
773
- (kernel_size, 1),
774
- (stride, 1),
775
- padding=(get_padding(kernel_size, 1), 0),
776
- )
777
- ),
778
- norm_f(
779
- Conv2d(
780
- 512,
781
- 1024,
782
- (kernel_size, 1),
783
- (stride, 1),
784
- padding=(get_padding(kernel_size, 1), 0),
785
- )
786
- ),
787
- norm_f(
788
- Conv2d(
789
- 1024,
790
- 1024,
791
- (kernel_size, 1),
792
- 1,
793
- padding=(get_padding(kernel_size, 1), 0),
794
- )
795
- ),
796
- ]
797
- )
798
- self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
799
-
800
- def forward(self, x):
801
- fmap = []
802
-
803
- # 1d to 2d
804
- b, c, t = x.shape
805
- if t % self.period != 0: # pad first
806
- n_pad = self.period - (t % self.period)
807
- x = F.pad(x, (0, n_pad), "reflect")
808
- t = t + n_pad
809
- x = x.view(b, c, t // self.period, self.period)
810
-
811
- for l in self.convs:
812
- x = l(x)
813
- x = F.leaky_relu(x, modules.LRELU_SLOPE)
814
- fmap.append(x)
815
- x = self.conv_post(x)
816
- fmap.append(x)
817
- x = torch.flatten(x, 1, -1)
818
-
819
- return x, fmap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules.py DELETED
@@ -1,522 +0,0 @@
1
- import copy
2
- import math
3
- import numpy as np
4
- import scipy
5
- import torch
6
- from torch import nn
7
- from torch.nn import functional as F
8
-
9
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
- from torch.nn.utils import weight_norm, remove_weight_norm
11
-
12
- from lib.infer_pack import commons
13
- from lib.infer_pack.commons import init_weights, get_padding
14
- from lib.infer_pack.transforms import piecewise_rational_quadratic_transform
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
-
20
- class LayerNorm(nn.Module):
21
- def __init__(self, channels, eps=1e-5):
22
- super().__init__()
23
- self.channels = channels
24
- self.eps = eps
25
-
26
- self.gamma = nn.Parameter(torch.ones(channels))
27
- self.beta = nn.Parameter(torch.zeros(channels))
28
-
29
- def forward(self, x):
30
- x = x.transpose(1, -1)
31
- x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
- return x.transpose(1, -1)
33
-
34
-
35
- class ConvReluNorm(nn.Module):
36
- def __init__(
37
- self,
38
- in_channels,
39
- hidden_channels,
40
- out_channels,
41
- kernel_size,
42
- n_layers,
43
- p_dropout,
44
- ):
45
- super().__init__()
46
- self.in_channels = in_channels
47
- self.hidden_channels = hidden_channels
48
- self.out_channels = out_channels
49
- self.kernel_size = kernel_size
50
- self.n_layers = n_layers
51
- self.p_dropout = p_dropout
52
- assert n_layers > 1, "Number of layers should be larger than 0."
53
-
54
- self.conv_layers = nn.ModuleList()
55
- self.norm_layers = nn.ModuleList()
56
- self.conv_layers.append(
57
- nn.Conv1d(
58
- in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
- )
60
- )
61
- self.norm_layers.append(LayerNorm(hidden_channels))
62
- self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
- for _ in range(n_layers - 1):
64
- self.conv_layers.append(
65
- nn.Conv1d(
66
- hidden_channels,
67
- hidden_channels,
68
- kernel_size,
69
- padding=kernel_size // 2,
70
- )
71
- )
72
- self.norm_layers.append(LayerNorm(hidden_channels))
73
- self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
- self.proj.weight.data.zero_()
75
- self.proj.bias.data.zero_()
76
-
77
- def forward(self, x, x_mask):
78
- x_org = x
79
- for i in range(self.n_layers):
80
- x = self.conv_layers[i](x * x_mask)
81
- x = self.norm_layers[i](x)
82
- x = self.relu_drop(x)
83
- x = x_org + self.proj(x)
84
- return x * x_mask
85
-
86
-
87
- class DDSConv(nn.Module):
88
- """
89
- Dialted and Depth-Separable Convolution
90
- """
91
-
92
- def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
- super().__init__()
94
- self.channels = channels
95
- self.kernel_size = kernel_size
96
- self.n_layers = n_layers
97
- self.p_dropout = p_dropout
98
-
99
- self.drop = nn.Dropout(p_dropout)
100
- self.convs_sep = nn.ModuleList()
101
- self.convs_1x1 = nn.ModuleList()
102
- self.norms_1 = nn.ModuleList()
103
- self.norms_2 = nn.ModuleList()
104
- for i in range(n_layers):
105
- dilation = kernel_size**i
106
- padding = (kernel_size * dilation - dilation) // 2
107
- self.convs_sep.append(
108
- nn.Conv1d(
109
- channels,
110
- channels,
111
- kernel_size,
112
- groups=channels,
113
- dilation=dilation,
114
- padding=padding,
115
- )
116
- )
117
- self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
- self.norms_1.append(LayerNorm(channels))
119
- self.norms_2.append(LayerNorm(channels))
120
-
121
- def forward(self, x, x_mask, g=None):
122
- if g is not None:
123
- x = x + g
124
- for i in range(self.n_layers):
125
- y = self.convs_sep[i](x * x_mask)
126
- y = self.norms_1[i](y)
127
- y = F.gelu(y)
128
- y = self.convs_1x1[i](y)
129
- y = self.norms_2[i](y)
130
- y = F.gelu(y)
131
- y = self.drop(y)
132
- x = x + y
133
- return x * x_mask
134
-
135
-
136
- class WN(torch.nn.Module):
137
- def __init__(
138
- self,
139
- hidden_channels,
140
- kernel_size,
141
- dilation_rate,
142
- n_layers,
143
- gin_channels=0,
144
- p_dropout=0,
145
- ):
146
- super(WN, self).__init__()
147
- assert kernel_size % 2 == 1
148
- self.hidden_channels = hidden_channels
149
- self.kernel_size = (kernel_size,)
150
- self.dilation_rate = dilation_rate
151
- self.n_layers = n_layers
152
- self.gin_channels = gin_channels
153
- self.p_dropout = p_dropout
154
-
155
- self.in_layers = torch.nn.ModuleList()
156
- self.res_skip_layers = torch.nn.ModuleList()
157
- self.drop = nn.Dropout(p_dropout)
158
-
159
- if gin_channels != 0:
160
- cond_layer = torch.nn.Conv1d(
161
- gin_channels, 2 * hidden_channels * n_layers, 1
162
- )
163
- self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
-
165
- for i in range(n_layers):
166
- dilation = dilation_rate**i
167
- padding = int((kernel_size * dilation - dilation) / 2)
168
- in_layer = torch.nn.Conv1d(
169
- hidden_channels,
170
- 2 * hidden_channels,
171
- kernel_size,
172
- dilation=dilation,
173
- padding=padding,
174
- )
175
- in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
- self.in_layers.append(in_layer)
177
-
178
- # last one is not necessary
179
- if i < n_layers - 1:
180
- res_skip_channels = 2 * hidden_channels
181
- else:
182
- res_skip_channels = hidden_channels
183
-
184
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
- res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
- self.res_skip_layers.append(res_skip_layer)
187
-
188
- def forward(self, x, x_mask, g=None, **kwargs):
189
- output = torch.zeros_like(x)
190
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
-
192
- if g is not None:
193
- g = self.cond_layer(g)
194
-
195
- for i in range(self.n_layers):
196
- x_in = self.in_layers[i](x)
197
- if g is not None:
198
- cond_offset = i * 2 * self.hidden_channels
199
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
- else:
201
- g_l = torch.zeros_like(x_in)
202
-
203
- acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
- acts = self.drop(acts)
205
-
206
- res_skip_acts = self.res_skip_layers[i](acts)
207
- if i < self.n_layers - 1:
208
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
- x = (x + res_acts) * x_mask
210
- output = output + res_skip_acts[:, self.hidden_channels :, :]
211
- else:
212
- output = output + res_skip_acts
213
- return output * x_mask
214
-
215
- def remove_weight_norm(self):
216
- if self.gin_channels != 0:
217
- torch.nn.utils.remove_weight_norm(self.cond_layer)
218
- for l in self.in_layers:
219
- torch.nn.utils.remove_weight_norm(l)
220
- for l in self.res_skip_layers:
221
- torch.nn.utils.remove_weight_norm(l)
222
-
223
-
224
- class ResBlock1(torch.nn.Module):
225
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
- super(ResBlock1, self).__init__()
227
- self.convs1 = nn.ModuleList(
228
- [
229
- weight_norm(
230
- Conv1d(
231
- channels,
232
- channels,
233
- kernel_size,
234
- 1,
235
- dilation=dilation[0],
236
- padding=get_padding(kernel_size, dilation[0]),
237
- )
238
- ),
239
- weight_norm(
240
- Conv1d(
241
- channels,
242
- channels,
243
- kernel_size,
244
- 1,
245
- dilation=dilation[1],
246
- padding=get_padding(kernel_size, dilation[1]),
247
- )
248
- ),
249
- weight_norm(
250
- Conv1d(
251
- channels,
252
- channels,
253
- kernel_size,
254
- 1,
255
- dilation=dilation[2],
256
- padding=get_padding(kernel_size, dilation[2]),
257
- )
258
- ),
259
- ]
260
- )
261
- self.convs1.apply(init_weights)
262
-
263
- self.convs2 = nn.ModuleList(
264
- [
265
- weight_norm(
266
- Conv1d(
267
- channels,
268
- channels,
269
- kernel_size,
270
- 1,
271
- dilation=1,
272
- padding=get_padding(kernel_size, 1),
273
- )
274
- ),
275
- weight_norm(
276
- Conv1d(
277
- channels,
278
- channels,
279
- kernel_size,
280
- 1,
281
- dilation=1,
282
- padding=get_padding(kernel_size, 1),
283
- )
284
- ),
285
- weight_norm(
286
- Conv1d(
287
- channels,
288
- channels,
289
- kernel_size,
290
- 1,
291
- dilation=1,
292
- padding=get_padding(kernel_size, 1),
293
- )
294
- ),
295
- ]
296
- )
297
- self.convs2.apply(init_weights)
298
-
299
- def forward(self, x, x_mask=None):
300
- for c1, c2 in zip(self.convs1, self.convs2):
301
- xt = F.leaky_relu(x, LRELU_SLOPE)
302
- if x_mask is not None:
303
- xt = xt * x_mask
304
- xt = c1(xt)
305
- xt = F.leaky_relu(xt, LRELU_SLOPE)
306
- if x_mask is not None:
307
- xt = xt * x_mask
308
- xt = c2(xt)
309
- x = xt + x
310
- if x_mask is not None:
311
- x = x * x_mask
312
- return x
313
-
314
- def remove_weight_norm(self):
315
- for l in self.convs1:
316
- remove_weight_norm(l)
317
- for l in self.convs2:
318
- remove_weight_norm(l)
319
-
320
-
321
- class ResBlock2(torch.nn.Module):
322
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
- super(ResBlock2, self).__init__()
324
- self.convs = nn.ModuleList(
325
- [
326
- weight_norm(
327
- Conv1d(
328
- channels,
329
- channels,
330
- kernel_size,
331
- 1,
332
- dilation=dilation[0],
333
- padding=get_padding(kernel_size, dilation[0]),
334
- )
335
- ),
336
- weight_norm(
337
- Conv1d(
338
- channels,
339
- channels,
340
- kernel_size,
341
- 1,
342
- dilation=dilation[1],
343
- padding=get_padding(kernel_size, dilation[1]),
344
- )
345
- ),
346
- ]
347
- )
348
- self.convs.apply(init_weights)
349
-
350
- def forward(self, x, x_mask=None):
351
- for c in self.convs:
352
- xt = F.leaky_relu(x, LRELU_SLOPE)
353
- if x_mask is not None:
354
- xt = xt * x_mask
355
- xt = c(xt)
356
- x = xt + x
357
- if x_mask is not None:
358
- x = x * x_mask
359
- return x
360
-
361
- def remove_weight_norm(self):
362
- for l in self.convs:
363
- remove_weight_norm(l)
364
-
365
-
366
- class Log(nn.Module):
367
- def forward(self, x, x_mask, reverse=False, **kwargs):
368
- if not reverse:
369
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
- logdet = torch.sum(-y, [1, 2])
371
- return y, logdet
372
- else:
373
- x = torch.exp(x) * x_mask
374
- return x
375
-
376
-
377
- class Flip(nn.Module):
378
- def forward(self, x, *args, reverse=False, **kwargs):
379
- x = torch.flip(x, [1])
380
- if not reverse:
381
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
- return x, logdet
383
- else:
384
- return x
385
-
386
-
387
- class ElementwiseAffine(nn.Module):
388
- def __init__(self, channels):
389
- super().__init__()
390
- self.channels = channels
391
- self.m = nn.Parameter(torch.zeros(channels, 1))
392
- self.logs = nn.Parameter(torch.zeros(channels, 1))
393
-
394
- def forward(self, x, x_mask, reverse=False, **kwargs):
395
- if not reverse:
396
- y = self.m + torch.exp(self.logs) * x
397
- y = y * x_mask
398
- logdet = torch.sum(self.logs * x_mask, [1, 2])
399
- return y, logdet
400
- else:
401
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
- return x
403
-
404
-
405
- class ResidualCouplingLayer(nn.Module):
406
- def __init__(
407
- self,
408
- channels,
409
- hidden_channels,
410
- kernel_size,
411
- dilation_rate,
412
- n_layers,
413
- p_dropout=0,
414
- gin_channels=0,
415
- mean_only=False,
416
- ):
417
- assert channels % 2 == 0, "channels should be divisible by 2"
418
- super().__init__()
419
- self.channels = channels
420
- self.hidden_channels = hidden_channels
421
- self.kernel_size = kernel_size
422
- self.dilation_rate = dilation_rate
423
- self.n_layers = n_layers
424
- self.half_channels = channels // 2
425
- self.mean_only = mean_only
426
-
427
- self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
- self.enc = WN(
429
- hidden_channels,
430
- kernel_size,
431
- dilation_rate,
432
- n_layers,
433
- p_dropout=p_dropout,
434
- gin_channels=gin_channels,
435
- )
436
- self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
- self.post.weight.data.zero_()
438
- self.post.bias.data.zero_()
439
-
440
- def forward(self, x, x_mask, g=None, reverse=False):
441
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
- h = self.pre(x0) * x_mask
443
- h = self.enc(h, x_mask, g=g)
444
- stats = self.post(h) * x_mask
445
- if not self.mean_only:
446
- m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
- else:
448
- m = stats
449
- logs = torch.zeros_like(m)
450
-
451
- if not reverse:
452
- x1 = m + x1 * torch.exp(logs) * x_mask
453
- x = torch.cat([x0, x1], 1)
454
- logdet = torch.sum(logs, [1, 2])
455
- return x, logdet
456
- else:
457
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
- x = torch.cat([x0, x1], 1)
459
- return x
460
-
461
- def remove_weight_norm(self):
462
- self.enc.remove_weight_norm()
463
-
464
-
465
- class ConvFlow(nn.Module):
466
- def __init__(
467
- self,
468
- in_channels,
469
- filter_channels,
470
- kernel_size,
471
- n_layers,
472
- num_bins=10,
473
- tail_bound=5.0,
474
- ):
475
- super().__init__()
476
- self.in_channels = in_channels
477
- self.filter_channels = filter_channels
478
- self.kernel_size = kernel_size
479
- self.n_layers = n_layers
480
- self.num_bins = num_bins
481
- self.tail_bound = tail_bound
482
- self.half_channels = in_channels // 2
483
-
484
- self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
485
- self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
486
- self.proj = nn.Conv1d(
487
- filter_channels, self.half_channels * (num_bins * 3 - 1), 1
488
- )
489
- self.proj.weight.data.zero_()
490
- self.proj.bias.data.zero_()
491
-
492
- def forward(self, x, x_mask, g=None, reverse=False):
493
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
494
- h = self.pre(x0)
495
- h = self.convs(h, x_mask, g=g)
496
- h = self.proj(h) * x_mask
497
-
498
- b, c, t = x0.shape
499
- h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
500
-
501
- unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
502
- unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
503
- self.filter_channels
504
- )
505
- unnormalized_derivatives = h[..., 2 * self.num_bins :]
506
-
507
- x1, logabsdet = piecewise_rational_quadratic_transform(
508
- x1,
509
- unnormalized_widths,
510
- unnormalized_heights,
511
- unnormalized_derivatives,
512
- inverse=reverse,
513
- tails="linear",
514
- tail_bound=self.tail_bound,
515
- )
516
-
517
- x = torch.cat([x0, x1], 1) * x_mask
518
- logdet = torch.sum(logabsdet * x_mask, [1, 2])
519
- if not reverse:
520
- return x, logdet
521
- else:
522
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules/F0Predictor/DioF0Predictor.py DELETED
@@ -1,90 +0,0 @@
1
- from lib.infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import pyworld
3
- import numpy as np
4
-
5
-
6
- class DioF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def resize_f0(self, x, target_len):
52
- source = np.array(x)
53
- source[source < 0.001] = np.nan
54
- target = np.interp(
55
- np.arange(0, len(source) * target_len, len(source)) / target_len,
56
- np.arange(0, len(source)),
57
- source,
58
- )
59
- res = np.nan_to_num(target)
60
- return res
61
-
62
- def compute_f0(self, wav, p_len=None):
63
- if p_len is None:
64
- p_len = wav.shape[0] // self.hop_length
65
- f0, t = pyworld.dio(
66
- wav.astype(np.double),
67
- fs=self.sampling_rate,
68
- f0_floor=self.f0_min,
69
- f0_ceil=self.f0_max,
70
- frame_period=1000 * self.hop_length / self.sampling_rate,
71
- )
72
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
73
- for index, pitch in enumerate(f0):
74
- f0[index] = round(pitch, 1)
75
- return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
76
-
77
- def compute_f0_uv(self, wav, p_len=None):
78
- if p_len is None:
79
- p_len = wav.shape[0] // self.hop_length
80
- f0, t = pyworld.dio(
81
- wav.astype(np.double),
82
- fs=self.sampling_rate,
83
- f0_floor=self.f0_min,
84
- f0_ceil=self.f0_max,
85
- frame_period=1000 * self.hop_length / self.sampling_rate,
86
- )
87
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
88
- for index, pitch in enumerate(f0):
89
- f0[index] = round(pitch, 1)
90
- return self.interpolate_f0(self.resize_f0(f0, p_len))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules/F0Predictor/F0Predictor.py DELETED
@@ -1,16 +0,0 @@
1
- class F0Predictor(object):
2
- def compute_f0(self, wav, p_len):
3
- """
4
- input: wav:[signal_length]
5
- p_len:int
6
- output: f0:[signal_length//hop_length]
7
- """
8
- pass
9
-
10
- def compute_f0_uv(self, wav, p_len):
11
- """
12
- input: wav:[signal_length]
13
- p_len:int
14
- output: f0:[signal_length//hop_length],uv:[signal_length//hop_length]
15
- """
16
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules/F0Predictor/HarvestF0Predictor.py DELETED
@@ -1,86 +0,0 @@
1
- from lib.infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import pyworld
3
- import numpy as np
4
-
5
-
6
- class HarvestF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def resize_f0(self, x, target_len):
52
- source = np.array(x)
53
- source[source < 0.001] = np.nan
54
- target = np.interp(
55
- np.arange(0, len(source) * target_len, len(source)) / target_len,
56
- np.arange(0, len(source)),
57
- source,
58
- )
59
- res = np.nan_to_num(target)
60
- return res
61
-
62
- def compute_f0(self, wav, p_len=None):
63
- if p_len is None:
64
- p_len = wav.shape[0] // self.hop_length
65
- f0, t = pyworld.harvest(
66
- wav.astype(np.double),
67
- fs=self.hop_length,
68
- f0_ceil=self.f0_max,
69
- f0_floor=self.f0_min,
70
- frame_period=1000 * self.hop_length / self.sampling_rate,
71
- )
72
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.fs)
73
- return self.interpolate_f0(self.resize_f0(f0, p_len))[0]
74
-
75
- def compute_f0_uv(self, wav, p_len=None):
76
- if p_len is None:
77
- p_len = wav.shape[0] // self.hop_length
78
- f0, t = pyworld.harvest(
79
- wav.astype(np.double),
80
- fs=self.sampling_rate,
81
- f0_floor=self.f0_min,
82
- f0_ceil=self.f0_max,
83
- frame_period=1000 * self.hop_length / self.sampling_rate,
84
- )
85
- f0 = pyworld.stonemask(wav.astype(np.double), f0, t, self.sampling_rate)
86
- return self.interpolate_f0(self.resize_f0(f0, p_len))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules/F0Predictor/PMF0Predictor.py DELETED
@@ -1,97 +0,0 @@
1
- from lib.infer_pack.modules.F0Predictor.F0Predictor import F0Predictor
2
- import parselmouth
3
- import numpy as np
4
-
5
-
6
- class PMF0Predictor(F0Predictor):
7
- def __init__(self, hop_length=512, f0_min=50, f0_max=1100, sampling_rate=44100):
8
- self.hop_length = hop_length
9
- self.f0_min = f0_min
10
- self.f0_max = f0_max
11
- self.sampling_rate = sampling_rate
12
-
13
- def interpolate_f0(self, f0):
14
- """
15
- 对F0进行插值处理
16
- """
17
-
18
- data = np.reshape(f0, (f0.size, 1))
19
-
20
- vuv_vector = np.zeros((data.size, 1), dtype=np.float32)
21
- vuv_vector[data > 0.0] = 1.0
22
- vuv_vector[data <= 0.0] = 0.0
23
-
24
- ip_data = data
25
-
26
- frame_number = data.size
27
- last_value = 0.0
28
- for i in range(frame_number):
29
- if data[i] <= 0.0:
30
- j = i + 1
31
- for j in range(i + 1, frame_number):
32
- if data[j] > 0.0:
33
- break
34
- if j < frame_number - 1:
35
- if last_value > 0.0:
36
- step = (data[j] - data[i - 1]) / float(j - i)
37
- for k in range(i, j):
38
- ip_data[k] = data[i - 1] + step * (k - i + 1)
39
- else:
40
- for k in range(i, j):
41
- ip_data[k] = data[j]
42
- else:
43
- for k in range(i, frame_number):
44
- ip_data[k] = last_value
45
- else:
46
- ip_data[i] = data[i] # 这里可能存在一个没有必要的拷贝
47
- last_value = data[i]
48
-
49
- return ip_data[:, 0], vuv_vector[:, 0]
50
-
51
- def compute_f0(self, wav, p_len=None):
52
- x = wav
53
- if p_len is None:
54
- p_len = x.shape[0] // self.hop_length
55
- else:
56
- assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
57
- time_step = self.hop_length / self.sampling_rate * 1000
58
- f0 = (
59
- parselmouth.Sound(x, self.sampling_rate)
60
- .to_pitch_ac(
61
- time_step=time_step / 1000,
62
- voicing_threshold=0.6,
63
- pitch_floor=self.f0_min,
64
- pitch_ceiling=self.f0_max,
65
- )
66
- .selected_array["frequency"]
67
- )
68
-
69
- pad_size = (p_len - len(f0) + 1) // 2
70
- if pad_size > 0 or p_len - len(f0) - pad_size > 0:
71
- f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
72
- f0, uv = self.interpolate_f0(f0)
73
- return f0
74
-
75
- def compute_f0_uv(self, wav, p_len=None):
76
- x = wav
77
- if p_len is None:
78
- p_len = x.shape[0] // self.hop_length
79
- else:
80
- assert abs(p_len - x.shape[0] // self.hop_length) < 4, "pad length error"
81
- time_step = self.hop_length / self.sampling_rate * 1000
82
- f0 = (
83
- parselmouth.Sound(x, self.sampling_rate)
84
- .to_pitch_ac(
85
- time_step=time_step / 1000,
86
- voicing_threshold=0.6,
87
- pitch_floor=self.f0_min,
88
- pitch_ceiling=self.f0_max,
89
- )
90
- .selected_array["frequency"]
91
- )
92
-
93
- pad_size = (p_len - len(f0) + 1) // 2
94
- if pad_size > 0 or p_len - len(f0) - pad_size > 0:
95
- f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
96
- f0, uv = self.interpolate_f0(f0)
97
- return f0, uv
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/infer_pack/modules/F0Predictor/__init__.py DELETED
File without changes
lib/infer_pack/transforms.py DELETED
@@ -1,209 +0,0 @@
1
- import torch
2
- from torch.nn import functional as F
3
-
4
- import numpy as np
5
-
6
-
7
- DEFAULT_MIN_BIN_WIDTH = 1e-3
8
- DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
- DEFAULT_MIN_DERIVATIVE = 1e-3
10
-
11
-
12
- def piecewise_rational_quadratic_transform(
13
- inputs,
14
- unnormalized_widths,
15
- unnormalized_heights,
16
- unnormalized_derivatives,
17
- inverse=False,
18
- tails=None,
19
- tail_bound=1.0,
20
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
- min_derivative=DEFAULT_MIN_DERIVATIVE,
23
- ):
24
- if tails is None:
25
- spline_fn = rational_quadratic_spline
26
- spline_kwargs = {}
27
- else:
28
- spline_fn = unconstrained_rational_quadratic_spline
29
- spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
-
31
- outputs, logabsdet = spline_fn(
32
- inputs=inputs,
33
- unnormalized_widths=unnormalized_widths,
34
- unnormalized_heights=unnormalized_heights,
35
- unnormalized_derivatives=unnormalized_derivatives,
36
- inverse=inverse,
37
- min_bin_width=min_bin_width,
38
- min_bin_height=min_bin_height,
39
- min_derivative=min_derivative,
40
- **spline_kwargs
41
- )
42
- return outputs, logabsdet
43
-
44
-
45
- def searchsorted(bin_locations, inputs, eps=1e-6):
46
- bin_locations[..., -1] += eps
47
- return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
-
49
-
50
- def unconstrained_rational_quadratic_spline(
51
- inputs,
52
- unnormalized_widths,
53
- unnormalized_heights,
54
- unnormalized_derivatives,
55
- inverse=False,
56
- tails="linear",
57
- tail_bound=1.0,
58
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
- min_derivative=DEFAULT_MIN_DERIVATIVE,
61
- ):
62
- inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
- outside_interval_mask = ~inside_interval_mask
64
-
65
- outputs = torch.zeros_like(inputs)
66
- logabsdet = torch.zeros_like(inputs)
67
-
68
- if tails == "linear":
69
- unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
- constant = np.log(np.exp(1 - min_derivative) - 1)
71
- unnormalized_derivatives[..., 0] = constant
72
- unnormalized_derivatives[..., -1] = constant
73
-
74
- outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
- logabsdet[outside_interval_mask] = 0
76
- else:
77
- raise RuntimeError("{} tails are not implemented.".format(tails))
78
-
79
- (
80
- outputs[inside_interval_mask],
81
- logabsdet[inside_interval_mask],
82
- ) = rational_quadratic_spline(
83
- inputs=inputs[inside_interval_mask],
84
- unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
- unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
- unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
- inverse=inverse,
88
- left=-tail_bound,
89
- right=tail_bound,
90
- bottom=-tail_bound,
91
- top=tail_bound,
92
- min_bin_width=min_bin_width,
93
- min_bin_height=min_bin_height,
94
- min_derivative=min_derivative,
95
- )
96
-
97
- return outputs, logabsdet
98
-
99
-
100
- def rational_quadratic_spline(
101
- inputs,
102
- unnormalized_widths,
103
- unnormalized_heights,
104
- unnormalized_derivatives,
105
- inverse=False,
106
- left=0.0,
107
- right=1.0,
108
- bottom=0.0,
109
- top=1.0,
110
- min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
- min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
- min_derivative=DEFAULT_MIN_DERIVATIVE,
113
- ):
114
- if torch.min(inputs) < left or torch.max(inputs) > right:
115
- raise ValueError("Input to a transform is not within its domain")
116
-
117
- num_bins = unnormalized_widths.shape[-1]
118
-
119
- if min_bin_width * num_bins > 1.0:
120
- raise ValueError("Minimal bin width too large for the number of bins")
121
- if min_bin_height * num_bins > 1.0:
122
- raise ValueError("Minimal bin height too large for the number of bins")
123
-
124
- widths = F.softmax(unnormalized_widths, dim=-1)
125
- widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
- cumwidths = torch.cumsum(widths, dim=-1)
127
- cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
- cumwidths = (right - left) * cumwidths + left
129
- cumwidths[..., 0] = left
130
- cumwidths[..., -1] = right
131
- widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
-
133
- derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
-
135
- heights = F.softmax(unnormalized_heights, dim=-1)
136
- heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
- cumheights = torch.cumsum(heights, dim=-1)
138
- cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
- cumheights = (top - bottom) * cumheights + bottom
140
- cumheights[..., 0] = bottom
141
- cumheights[..., -1] = top
142
- heights = cumheights[..., 1:] - cumheights[..., :-1]
143
-
144
- if inverse:
145
- bin_idx = searchsorted(cumheights, inputs)[..., None]
146
- else:
147
- bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
-
149
- input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
- input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
-
152
- input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
- delta = heights / widths
154
- input_delta = delta.gather(-1, bin_idx)[..., 0]
155
-
156
- input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
- input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
-
159
- input_heights = heights.gather(-1, bin_idx)[..., 0]
160
-
161
- if inverse:
162
- a = (inputs - input_cumheights) * (
163
- input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
- ) + input_heights * (input_delta - input_derivatives)
165
- b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
- input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
- )
168
- c = -input_delta * (inputs - input_cumheights)
169
-
170
- discriminant = b.pow(2) - 4 * a * c
171
- assert (discriminant >= 0).all()
172
-
173
- root = (2 * c) / (-b - torch.sqrt(discriminant))
174
- outputs = root * input_bin_widths + input_cumwidths
175
-
176
- theta_one_minus_theta = root * (1 - root)
177
- denominator = input_delta + (
178
- (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
- * theta_one_minus_theta
180
- )
181
- derivative_numerator = input_delta.pow(2) * (
182
- input_derivatives_plus_one * root.pow(2)
183
- + 2 * input_delta * theta_one_minus_theta
184
- + input_derivatives * (1 - root).pow(2)
185
- )
186
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
-
188
- return outputs, -logabsdet
189
- else:
190
- theta = (inputs - input_cumwidths) / input_bin_widths
191
- theta_one_minus_theta = theta * (1 - theta)
192
-
193
- numerator = input_heights * (
194
- input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
- )
196
- denominator = input_delta + (
197
- (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
- * theta_one_minus_theta
199
- )
200
- outputs = input_cumheights + numerator / denominator
201
-
202
- derivative_numerator = input_delta.pow(2) * (
203
- input_derivatives_plus_one * theta.pow(2)
204
- + 2 * input_delta * theta_one_minus_theta
205
- + input_derivatives * (1 - theta).pow(2)
206
- )
207
- logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
-
209
- return outputs, logabsdet