ORI-Muchim committed
Commit 84001e4
1 Parent(s): d4d0d61
Update attentions.py

attentions.py CHANGED (+159 -5)
@@ -1,14 +1,19 @@
+import copy
 import math
+import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
 
 import commons
+import modules
 from modules import LayerNorm
-
 
-class Encoder(nn.Module):
-  def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs):
+class Encoder(nn.Module): #backward compatible vits2 encoder
+  def __init__(
+    self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs
+  ):
     super().__init__()
     self.hidden_channels = hidden_channels
     self.filter_channels = filter_channels
@@ -23,16 +28,32 @@ class Encoder(nn.Module):
     self.norm_layers_1 = nn.ModuleList()
     self.ffn_layers = nn.ModuleList()
     self.norm_layers_2 = nn.ModuleList()
+    # if kwargs has spk_emb_dim, then add a linear layer to project spk_emb_dim to hidden_channels
+    self.cond_layer_idx = self.n_layers
+    if 'gin_channels' in kwargs:
+      self.gin_channels = kwargs['gin_channels']
+      if self.gin_channels != 0:
+        self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+        # vits2 says 3rd block, so idx is 2 by default
+        self.cond_layer_idx = kwargs['cond_layer_idx'] if 'cond_layer_idx' in kwargs else 2
+        print(self.gin_channels, self.cond_layer_idx)
+        assert self.cond_layer_idx < self.n_layers, 'cond_layer_idx should be less than n_layers'
+
     for i in range(self.n_layers):
       self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
       self.norm_layers_1.append(LayerNorm(hidden_channels))
       self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
       self.norm_layers_2.append(LayerNorm(hidden_channels))
 
-  def forward(self, x, x_mask):
+  def forward(self, x, x_mask, g=None):
     attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
     x = x * x_mask
     for i in range(self.n_layers):
+      if i == self.cond_layer_idx and g is not None:
+        g = self.spk_emb_linear(g.transpose(1, 2))
+        g = g.transpose(1, 2)
+        x = x + g
+        x = x * x_mask
       y = self.attn_layers[i](x, x, attn_mask)
       y = self.drop(y)
       x = self.norm_layers_1[i](x + y)
@@ -43,7 +64,6 @@ class Encoder(nn.Module):
     x = x * x_mask
     return x
 
-
 class Decoder(nn.Module):
   def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs):
     super().__init__()
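Note (not part of the commit): a minimal usage sketch for the backward-compatible Encoder above. The channel sizes are illustrative VITS-style values and tensors follow the [batch, channels, time] layout the module expects; passing gin_channels enables the new speaker-conditioning branch, while omitting g keeps the old behaviour.

import torch

# Illustrative hyperparameters; cond_layer_idx defaults to 2 (the "3rd block").
enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6,
              kernel_size=3, p_dropout=0.1, gin_channels=256)

x = torch.randn(1, 192, 50)      # frame features: [batch, hidden_channels, frames]
x_mask = torch.ones(1, 1, 50)    # validity mask: [batch, 1, frames]
g = torch.randn(1, 256, 1)       # speaker embedding: [batch, gin_channels, 1]

out = enc(x, x_mask, g=g)        # g is projected by spk_emb_linear and added at layer index 2
out_no_spk = enc(x, x_mask)      # still valid: without g the encoder behaves as before
print(out.shape)                 # torch.Size([1, 192, 50])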
@@ -298,3 +318,137 @@ class FFN(nn.Module):
     padding = [[0, 0], [0, 0], [pad_l, pad_r]]
     x = F.pad(x, commons.convert_pad_shape(padding))
     return x
+
+
+class Depthwise_Separable_Conv1D(nn.Module):
+  def __init__(
+    self,
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride = 1,
+    padding = 0,
+    dilation = 1,
+    bias = True,
+    padding_mode = 'zeros', # TODO: refine this type
+    device=None,
+    dtype=None
+  ):
+    super().__init__()
+    self.depth_conv = nn.Conv1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
+    self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
+
+  def forward(self, input):
+    return self.point_conv(self.depth_conv(input))
+
+  def weight_norm(self):
+    self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
+    self.point_conv = weight_norm(self.point_conv, name = 'weight')
+
+  def remove_weight_norm(self):
+    self.depth_conv = remove_weight_norm(self.depth_conv, name = 'weight')
+    self.point_conv = remove_weight_norm(self.point_conv, name = 'weight')
+
+class Depthwise_Separable_TransposeConv1D(nn.Module):
+  def __init__(
+    self,
+    in_channels,
+    out_channels,
+    kernel_size,
+    stride = 1,
+    padding = 0,
+    output_padding = 0,
+    bias = True,
+    dilation = 1,
+    padding_mode = 'zeros', # TODO: refine this type
+    device=None,
+    dtype=None
+  ):
+    super().__init__()
+    self.depth_conv = nn.ConvTranspose1d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size, groups=in_channels,stride = stride,output_padding=output_padding,padding=padding,dilation=dilation,bias=bias,padding_mode=padding_mode,device=device,dtype=dtype)
+    self.point_conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias, device=device,dtype=dtype)
+
+  def forward(self, input):
+    return self.point_conv(self.depth_conv(input))
+
+  def weight_norm(self):
+    self.depth_conv = weight_norm(self.depth_conv, name = 'weight')
+    self.point_conv = weight_norm(self.point_conv, name = 'weight')
+
+  def remove_weight_norm(self):
+    remove_weight_norm(self.depth_conv, name = 'weight')
+    remove_weight_norm(self.point_conv, name = 'weight')
+
+
+def weight_norm_modules(module, name = 'weight', dim = 0):
+  if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
+    module.weight_norm()
+    return module
+  else:
+    return weight_norm(module,name,dim)
+
+def remove_weight_norm_modules(module, name = 'weight'):
+  if isinstance(module,Depthwise_Separable_Conv1D) or isinstance(module,Depthwise_Separable_TransposeConv1D):
+    module.remove_weight_norm()
+  else:
+    remove_weight_norm(module,name)
+
+
class FFT(nn.Module):
|
397 |
+
def __init__(self, hidden_channels, filter_channels, n_heads, n_layers=1, kernel_size=1, p_dropout=0.,
|
398 |
+
proximal_bias=False, proximal_init=True, isflow = False, **kwargs):
|
399 |
+
super().__init__()
|
400 |
+
self.hidden_channels = hidden_channels
|
401 |
+
self.filter_channels = filter_channels
|
402 |
+
self.n_heads = n_heads
|
403 |
+
self.n_layers = n_layers
|
404 |
+
self.kernel_size = kernel_size
|
405 |
+
self.p_dropout = p_dropout
|
406 |
+
self.proximal_bias = proximal_bias
|
407 |
+
self.proximal_init = proximal_init
|
408 |
+
if isflow and 'gin_channels' in kwargs and kwargs["gin_channels"] > 0:
|
409 |
+
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2*hidden_channels*n_layers, 1)
|
410 |
+
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
|
411 |
+
self.cond_layer = weight_norm_modules(cond_layer, name='weight')
|
412 |
+
self.gin_channels = kwargs["gin_channels"]
|
413 |
+
self.drop = nn.Dropout(p_dropout)
|
414 |
+
self.self_attn_layers = nn.ModuleList()
|
415 |
+
self.norm_layers_0 = nn.ModuleList()
|
416 |
+
self.ffn_layers = nn.ModuleList()
|
417 |
+
self.norm_layers_1 = nn.ModuleList()
|
418 |
+
for i in range(self.n_layers):
|
419 |
+
self.self_attn_layers.append(
|
420 |
+
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias,
|
421 |
+
proximal_init=proximal_init))
|
422 |
+
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
423 |
+
self.ffn_layers.append(
|
424 |
+
FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True))
|
425 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
426 |
+
|
427 |
+
def forward(self, x, x_mask, g = None):
|
428 |
+
"""
|
429 |
+
x: decoder input
|
430 |
+
h: encoder output
|
431 |
+
"""
|
432 |
+
if g is not None:
|
433 |
+
g = self.cond_layer(g)
|
434 |
+
|
435 |
+
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
436 |
+
x = x * x_mask
|
437 |
+
for i in range(self.n_layers):
|
438 |
+
if g is not None:
|
439 |
+
x = self.cond_pre(x)
|
440 |
+
cond_offset = i * 2 * self.hidden_channels
|
441 |
+
g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:]
|
442 |
+
x = commons.fused_add_tanh_sigmoid_multiply(
|
443 |
+
x,
|
444 |
+
g_l,
|
445 |
+
torch.IntTensor([self.hidden_channels]))
|
446 |
+
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
447 |
+
y = self.drop(y)
|
448 |
+
x = self.norm_layers_0[i](x + y)
|
449 |
+
|
450 |
+
y = self.ffn_layers[i](x, x_mask)
|
451 |
+
y = self.drop(y)
|
452 |
+
x = self.norm_layers_1[i](x + y)
|
453 |
+
x = x * x_mask
|
454 |
+
return x
|
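Note (not part of the commit): a usage sketch for the new FFT block. When constructed with isflow=True and gin_channels, it projects g with the weight-normalized 1x1 cond_layer and mixes it into x through cond_pre and fused_add_tanh_sigmoid_multiply before each causal self-attention layer; the sizes below are illustrative.

import torch

fft = FFT(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=4,
          kernel_size=3, p_dropout=0.1, isflow=True, gin_channels=256)

x = torch.randn(2, 192, 40)      # [batch, hidden_channels, frames]
x_mask = torch.ones(2, 1, 40)    # [batch, 1, frames]
g = torch.randn(2, 256, 1)       # global conditioning, broadcast over time

out = fft(x, x_mask, g=g)        # subsequent_mask keeps attention causal: frame t sees only frames <= t
print(out.shape)                 # torch.Size([2, 192, 40])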