Spaces:
Runtime error
Runtime error
Commit
·
5e764fc
1
Parent(s):
8346e07
Update Tmodel.py (#31)
Browse files- Update Tmodel.py (373cacd4513cfe9aa2725ed990eac0763d000816)
Tmodel.py
CHANGED
@@ -14,7 +14,7 @@ class GlowTTS(nn.Module):
|
|
14 |
self.encoder = Encoder()
|
15 |
self.decoder = Decoder()
|
16 |
|
17 |
-
def forward(self, text, text_len, mel=None, mel_len=None, inference=False):
|
18 |
"""
|
19 |
=====inputs=====
|
20 |
text: (B, T)
|
@@ -45,7 +45,7 @@ class GlowTTS(nn.Module):
|
|
45 |
if not inference: # training
|
46 |
y_max_len = y.size(2)
|
47 |
else: # inference
|
48 |
-
dur = torch.exp(x_log_dur) * x_mask # (B, 1, T)
|
49 |
ceil_dur = torch.ceil(dur) # (B, 1, T)
|
50 |
y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
|
51 |
# ceil_dur์ [1, 2] ์ถ์ ๋ํด sumํ ๋ค ์ต์๊ฐ์ด 1์ด์์ด ๋๋๋ก ์ค์ . ์ ์ long ํ์
์ผ๋ก ๋ฐํํ๋ค.
|
@@ -99,7 +99,7 @@ class GlowTTS(nn.Module):
|
|
99 |
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
|
100 |
log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment์์ ํ์ฑ๋ duration์ log scale
|
101 |
|
102 |
-
z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean)) * z_mask # z(latent representation) 생성
|
103 |
y, log_det = self.decoder(z, z_mask, reverse=True) # mel-spectrogram 생성
|
104 |
return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
|
105 |
|
|
|
14 |
self.encoder = Encoder()
|
15 |
self.decoder = Decoder()
|
16 |
|
17 |
+
def forward(self, text, text_len, mel=None, mel_len=None, inference=False, noise_scale=1., length_scale=1.):
|
18 |
"""
|
19 |
=====inputs=====
|
20 |
text: (B, T)
|
|
|
45 |
if not inference: # training
|
46 |
y_max_len = y.size(2)
|
47 |
else: # inference
|
48 |
+
dur = torch.exp(x_log_dur) * x_mask * length_scale # (B, 1, T)
|
49 |
ceil_dur = torch.ceil(dur) # (B, 1, T)
|
50 |
y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
|
51 |
# ceil_dur์ [1, 2] ์ถ์ ๋ํด sumํ ๋ค ์ต์๊ฐ์ด 1์ด์์ด ๋๋๋ก ์ค์ . ์ ์ long ํ์
์ผ๋ก ๋ฐํํ๋ค.
|
|
|
99 |
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
|
100 |
log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment์์ ํ์ฑ๋ duration์ log scale
|
101 |
|
102 |
+
z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean) * noise_scale) * z_mask # z(latent representation) 생성
|
103 |
y, log_det = self.decoder(z, z_mask, reverse=True) # mel-spectrogram 생성
|
104 |
return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
|
105 |
|