Spaces:
Running
Running
simplify (#1)
Browse files- simplify (fd5dde034da333a19a1b05e6a56a80a5f9803b61)
- Utils/JDC/model.py +1 -1
- app.py +2 -3
- models.py +11 -76
Utils/JDC/model.py
CHANGED
@@ -134,7 +134,7 @@ class JDCNet(nn.Module):
|
|
134 |
# sizes: (b, 31, 722), (b, 31, 2)
|
135 |
# classifier output consists of predicted pitch classes per frame
|
136 |
# detector output consists of: (isvoice, notvoice) estimates per frame
|
137 |
-
return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
|
138 |
|
139 |
@staticmethod
|
140 |
def init_weights(m):
|
|
|
134 |
# sizes: (b, 31, 722), (b, 31, 2)
|
135 |
# classifier output consists of predicted pitch classes per frame
|
136 |
# detector output consists of: (isvoice, notvoice) estimates per frame
|
137 |
+
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
|
138 |
|
139 |
@staticmethod
|
140 |
def init_weights(m):
|
app.py
CHANGED
@@ -13,7 +13,6 @@ from transformers import WavLMModel
|
|
13 |
from env import AttrDict
|
14 |
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
15 |
from models import Generator
|
16 |
-
from stft import TorchSTFT
|
17 |
from Utils.JDC.model import JDCNet
|
18 |
|
19 |
|
@@ -38,7 +37,6 @@ h = AttrDict(json_config)
|
|
38 |
# load models
|
39 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
40 |
generator = Generator(h, F0_model).to(device)
|
41 |
-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
|
42 |
|
43 |
state_dict_g = torch.load(ptfile, map_location=device)
|
44 |
generator.load_state_dict(state_dict_g['generator'], strict=True)
|
@@ -84,6 +82,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
|
|
84 |
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
|
85 |
|
86 |
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
|
|
|
87 |
|
88 |
# src
|
89 |
wav, sr = librosa.load(src_wav, sr=16000)
|
@@ -98,7 +97,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
|
|
98 |
f0 = generator.get_f0(mel, f0_mean_tgt)
|
99 |
f0 = tune_f0(f0, f0_shift)
|
100 |
x = generator.get_x(x, spk_emb, spk_id)
|
101 |
-
y = generator.infer(x, f0
|
102 |
|
103 |
audio = y.squeeze()
|
104 |
audio = audio / torch.max(torch.abs(audio)) * 0.95
|
|
|
13 |
from env import AttrDict
|
14 |
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
15 |
from models import Generator
|
|
|
16 |
from Utils.JDC.model import JDCNet
|
17 |
|
18 |
|
|
|
37 |
# load models
|
38 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
39 |
generator = Generator(h, F0_model).to(device)
|
|
|
40 |
|
41 |
state_dict_g = torch.load(ptfile, map_location=device)
|
42 |
generator.load_state_dict(state_dict_g['generator'], strict=True)
|
|
|
82 |
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
|
83 |
|
84 |
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
|
85 |
+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
|
86 |
|
87 |
# src
|
88 |
wav, sr = librosa.load(src_wav, sr=16000)
|
|
|
97 |
f0 = generator.get_f0(mel, f0_mean_tgt)
|
98 |
f0 = tune_f0(f0, f0_shift)
|
99 |
x = generator.get_x(x, spk_emb, spk_id)
|
100 |
+
y = generator.infer(x, f0)
|
101 |
|
102 |
audio = y.squeeze()
|
103 |
audio = audio / torch.max(torch.abs(audio)) * 0.95
|
models.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import math
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
import torch.nn as nn
|
@@ -486,9 +485,6 @@ class Generator(torch.nn.Module):
|
|
486 |
g = g + spk_emb.unsqueeze(-1)
|
487 |
|
488 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
489 |
-
if len(f0.shape) == 1:
|
490 |
-
f0 = f0.unsqueeze(0)
|
491 |
-
|
492 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
493 |
|
494 |
har_source, _, _ = self.m_source(f0)
|
@@ -526,28 +522,21 @@ class Generator(torch.nn.Module):
|
|
526 |
|
527 |
return spec, phase
|
528 |
|
529 |
-
def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10
|
530 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
531 |
-
|
532 |
voiced = f0 > voiced_threshold
|
533 |
|
534 |
lf0 = torch.log(f0)
|
535 |
-
|
536 |
-
|
|
|
|
|
537 |
f0_adj = torch.exp(lf0_adj)
|
538 |
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
|
543 |
-
f0_adj = self.interp_f0(f0_adj.unsqueeze(0), voiced.unsqueeze(0)).squeeze(0)
|
544 |
-
energy = torch.sum(mel.squeeze(0), dim=0) # simple vad
|
545 |
-
unsilent = energy > -700
|
546 |
-
unsilent = unsilent | voiced
|
547 |
-
f0_adj = torch.where(unsilent, f0_adj, 0)
|
548 |
-
|
549 |
-
if len(f0_adj.shape) == 1:
|
550 |
-
f0_adj = f0_adj.unsqueeze(0)
|
551 |
|
552 |
return f0_adj
|
553 |
|
@@ -562,7 +551,7 @@ class Generator(torch.nn.Module):
|
|
562 |
|
563 |
return x
|
564 |
|
565 |
-
def infer(self, x, f0
|
566 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
567 |
|
568 |
har_source, _, _ = self.m_source(f0)
|
@@ -593,62 +582,8 @@ class Generator(torch.nn.Module):
|
|
593 |
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
594 |
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
595 |
|
596 |
-
y = stft.inverse(spec, phase)
|
597 |
-
|
598 |
return y
|
599 |
-
|
600 |
-
def interp_f0(self, pitch, voiced):
|
601 |
-
"""Fill unvoiced regions via linear interpolation"""
|
602 |
-
|
603 |
-
# Handle no voiced frames
|
604 |
-
if not voiced.any():
|
605 |
-
return pitch
|
606 |
-
|
607 |
-
# Pitch is linear in base-2 log-space
|
608 |
-
pitch = torch.log2(pitch)
|
609 |
-
|
610 |
-
# Anchor endpoints
|
611 |
-
pitch[..., 0] = pitch[voiced][..., 0]
|
612 |
-
pitch[..., -1] = pitch[voiced][..., -1]
|
613 |
-
voiced[..., 0] = True
|
614 |
-
voiced[..., -1] = True
|
615 |
-
|
616 |
-
# Interpolate
|
617 |
-
pitch[~voiced] = self.interp(
|
618 |
-
torch.where(~voiced[0])[0][None],
|
619 |
-
torch.where(voiced[0])[0][None],
|
620 |
-
pitch[voiced][None])
|
621 |
-
|
622 |
-
return 2 ** pitch
|
623 |
-
|
624 |
-
@staticmethod
|
625 |
-
def interp(x, xp, fp):
|
626 |
-
"""1D linear interpolation for monotonically increasing sample points"""
|
627 |
-
# Handle edge cases
|
628 |
-
if xp.shape[-1] == 0:
|
629 |
-
return x
|
630 |
-
if xp.shape[-1] == 1:
|
631 |
-
return torch.full(
|
632 |
-
x.shape,
|
633 |
-
fp.squeeze(),
|
634 |
-
device=fp.device,
|
635 |
-
dtype=fp.dtype)
|
636 |
-
|
637 |
-
# Get slope and intercept using right-side first-differences
|
638 |
-
m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])
|
639 |
-
b = fp[:, :-1] - (m.mul(xp[:, :-1]))
|
640 |
-
|
641 |
-
# Get indices to sample slope and intercept
|
642 |
-
indicies = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1
|
643 |
-
indicies = torch.clamp(indicies, 0, m.shape[-1] - 1)
|
644 |
-
line_idx = torch.linspace(
|
645 |
-
0,
|
646 |
-
indicies.shape[0],
|
647 |
-
1,
|
648 |
-
device=indicies.device).to(torch.long).expand(indicies.shape)
|
649 |
-
|
650 |
-
# Interpolate
|
651 |
-
return m[line_idx, indicies].mul(x) + b[line_idx, indicies]
|
652 |
|
653 |
def remove_weight_norm(self):
|
654 |
print('Removing weight norm...')
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn.functional as F
|
3 |
import torch.nn as nn
|
|
|
485 |
g = g + spk_emb.unsqueeze(-1)
|
486 |
|
487 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
|
|
|
|
|
|
488 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
489 |
|
490 |
har_source, _, _ = self.m_source(f0)
|
|
|
522 |
|
523 |
return spec, phase
|
524 |
|
525 |
+
def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10):
|
526 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
|
|
527 |
voiced = f0 > voiced_threshold
|
528 |
|
529 |
lf0 = torch.log(f0)
|
530 |
+
lf0_ = lf0 * voiced.float()
|
531 |
+
lf0_mean = lf0_.sum(1) / voiced.float().sum(1)
|
532 |
+
lf0_mean = lf0_mean.unsqueeze(1)
|
533 |
+
lf0_adj = lf0 - lf0_mean + torch.log(f0_mean_tgt)
|
534 |
f0_adj = torch.exp(lf0_adj)
|
535 |
|
536 |
+
energy = mel.sum(1)
|
537 |
+
unsilent = energy > -700
|
538 |
+
unsilent = unsilent | voiced # simple vad
|
539 |
+
f0_adj = f0_adj * unsilent.float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
540 |
|
541 |
return f0_adj
|
542 |
|
|
|
551 |
|
552 |
return x
|
553 |
|
554 |
+
def infer(self, x, f0):
|
555 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
556 |
|
557 |
har_source, _, _ = self.m_source(f0)
|
|
|
582 |
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
583 |
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
584 |
|
585 |
+
y = self.stft.inverse(spec, phase)
|
|
|
586 |
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
587 |
|
588 |
def remove_weight_norm(self):
|
589 |
print('Removing weight norm...')
|