add precomputed voices, reformat code, remove unused code
- .idea/.gitignore +8 -0
- .idea/inspectionProfiles/Project_Default.xml +81 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/ruff.xml +11 -0
- .idea/styletts2-personal.iml +12 -0
- .idea/vcs.xml +6 -0
- Modules/diffusion/diffusion.py +0 -6
- Modules/diffusion/modules.py +4 -3
- Modules/diffusion/sampler.py +7 -3
- Modules/diffusion/utils.py +1 -4
- Modules/discriminators.py +8 -7
- Modules/hifigan.py +7 -28
- Modules/istftnet.py +4 -3
- Modules/slmadv.py +1 -3
- Utils/ASR/layers.py +0 -4
- Utils/ASR/models.py +0 -1
- Utils/JDC/model.py +2 -2
- Utils/PLBERT/util.py +0 -1
- _run.py +1 -4
- app.py +101 -143
- compute.py +0 -6
- ljspeechimportable.py +0 -225
- losses.py +0 -1
- meldataset.py +0 -6
- models.py +26 -57
- optimizers.py +0 -5
- requirements.txt +3 -2
- styletts2importable.py +197 -257
- train_finetune.py +5 -10
- train_first.py +3 -9
- train_second.py +6 -11
- utils.py +0 -7
- voices/f-us-1.wav.npy +3 -0
- voices/f-us-2.wav.npy +3 -0
- voices/f-us-3.wav.npy +3 -0
- voices/f-us-4.wav.npy +3 -0
- voices/m-us-1.wav.npy +3 -0
- voices/m-us-2.wav.npy +3 -0
- voices/m-us-3.wav.npy +3 -0
- voices/m-us-4.wav.npy +3 -0
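
The headline change is that the eight reference voices now ship with precomputed style embeddings (voices/*.wav.npy), so the Space no longer has to run the style encoder on every startup. A minimal sketch of the caching scheme app.py adopts below (compute_style comes from styletts2importable; the cache paths match the new files listed above):

    import os
    import numpy as np
    import torch
    from styletts2importable import compute_style

    voicelist = ["f-us-1", "f-us-2", "f-us-3", "f-us-4",
                 "m-us-1", "m-us-2", "m-us-3", "m-us-4"]
    voices = {}
    for v in voicelist:
        cache_path = f"voices/{v}.wav.npy"
        if os.path.exists(cache_path):
            # load the precomputed style embedding shipped with the repo
            voices[v] = torch.from_numpy(np.load(cache_path))
        else:
            # fall back to computing the style from the reference wav, then cache it
            style = compute_style(f"voices/{v}.wav")
            voices[v] = style
            np.save(cache_path, style.cpu().numpy())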
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,81 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+    <inspection_tool class="PyCompatibilityInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES">
+      <option name="ourVersions">
+        <value>
+          <list size="1">
+            <item index="0" class="java.lang.String" itemvalue="3.10" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+      <option name="ignoredPackages">
+        <value>
+          <list size="58">
+            <item index="0" class="java.lang.String" itemvalue="pandas" />
+            <item index="1" class="java.lang.String" itemvalue="fastapi" />
+            <item index="2" class="java.lang.String" itemvalue="pydantic" />
+            <item index="3" class="java.lang.String" itemvalue="clickhouse-connect" />
+            <item index="4" class="java.lang.String" itemvalue="uvicorn" />
+            <item index="5" class="java.lang.String" itemvalue="requests" />
+            <item index="6" class="java.lang.String" itemvalue="posthog" />
+            <item index="7" class="java.lang.String" itemvalue="numba" />
+            <item index="8" class="java.lang.String" itemvalue="faiss-cpu" />
+            <item index="9" class="java.lang.String" itemvalue="llvmlite" />
+            <item index="10" class="java.lang.String" itemvalue="tokenizers" />
+            <item index="11" class="java.lang.String" itemvalue="scipy" />
+            <item index="12" class="java.lang.String" itemvalue="transformers" />
+            <item index="13" class="java.lang.String" itemvalue="tornado" />
+            <item index="14" class="java.lang.String" itemvalue="threadpoolctl" />
+            <item index="15" class="java.lang.String" itemvalue="unidecode" />
+            <item index="16" class="java.lang.String" itemvalue="py-cpuinfo" />
+            <item index="17" class="java.lang.String" itemvalue="nbconvert" />
+            <item index="18" class="java.lang.String" itemvalue="tqdm" />
+            <item index="19" class="java.lang.String" itemvalue="appdirs" />
+            <item index="20" class="java.lang.String" itemvalue="rotary_embedding_torch" />
+            <item index="21" class="java.lang.String" itemvalue="deepspeed" />
+            <item index="22" class="java.lang.String" itemvalue="progressbar" />
+            <item index="23" class="java.lang.String" itemvalue="inflect" />
+            <item index="24" class="java.lang.String" itemvalue="librosa" />
+            <item index="25" class="java.lang.String" itemvalue="ffmpeg" />
+            <item index="26" class="java.lang.String" itemvalue="hjson" />
+            <item index="27" class="java.lang.String" itemvalue="einops" />
+            <item index="28" class="java.lang.String" itemvalue="torchaudio" />
+            <item index="29" class="java.lang.String" itemvalue="pyinstaller" />
+            <item index="30" class="java.lang.String" itemvalue="pytorch-lightning" />
+            <item index="31" class="java.lang.String" itemvalue="bitarray" />
+            <item index="32" class="java.lang.String" itemvalue="pyright" />
+            <item index="33" class="java.lang.String" itemvalue="yt-dlp" />
+            <item index="34" class="java.lang.String" itemvalue="torch" />
+            <item index="35" class="java.lang.String" itemvalue="torchvision" />
+            <item index="36" class="java.lang.String" itemvalue="sacrebleu" />
+            <item index="37" class="java.lang.String" itemvalue="aioshutil" />
+            <item index="38" class="java.lang.String" itemvalue="absl-py" />
+            <item index="39" class="java.lang.String" itemvalue="gradio" />
+            <item index="40" class="java.lang.String" itemvalue="matplotlib-inline" />
+            <item index="41" class="java.lang.String" itemvalue="Werkzeug" />
+            <item index="42" class="java.lang.String" itemvalue="fairseq" />
+            <item index="43" class="java.lang.String" itemvalue="json5" />
+            <item index="44" class="java.lang.String" itemvalue="torchfcpe" />
+            <item index="45" class="java.lang.String" itemvalue="numpy" />
+            <item index="46" class="java.lang.String" itemvalue="pyasn1" />
+            <item index="47" class="java.lang.String" itemvalue="torchcrepe" />
+            <item index="48" class="java.lang.String" itemvalue="pyasn1-modules" />
+            <item index="49" class="java.lang.String" itemvalue="tensorboard" />
+            <item index="50" class="java.lang.String" itemvalue="av" />
+            <item index="51" class="java.lang.String" itemvalue="matplotlib" />
+            <item index="52" class="java.lang.String" itemvalue="tensorboardX" />
+            <item index="53" class="java.lang.String" itemvalue="uc-micro-py" />
+            <item index="54" class="java.lang.String" itemvalue="ffmpy" />
+            <item index="55" class="java.lang.String" itemvalue="pyworld" />
+            <item index="56" class="java.lang.String" itemvalue="Markdown" />
+            <item index="57" class="java.lang.String" itemvalue="praat-parselmouth" />
+          </list>
+        </value>
+      </option>
+    </inspection_tool>
+  </profile>
+</component>
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
.idea/misc.xml
ADDED
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="styletts2-personal" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="styletts2-personal" project-jdk-type="Python SDK" />
+</project>
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/styletts2-personal.iml" filepath="$PROJECT_DIR$/.idea/styletts2-personal.iml" />
+    </modules>
+  </component>
+</project>
.idea/ruff.xml
ADDED
@@ -0,0 +1,11 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="RuffConfigService">
+    <option name="disableOnSaveOutsideOfProject" value="false" />
+    <option name="globalRuffExecutablePath" value="/opt/homebrew/bin/ruff" />
+    <option name="globalRuffLspExecutablePath" value="/opt/homebrew/bin/ruff" />
+    <option name="projectRuffLspExecutablePath" value="/opt/homebrew/Caskroom/miniconda/base/envs/styletts2-personal/bin/ruff" />
+    <option name="runRuffOnSave" value="true" />
+    <option name="useRuffFormat" value="true" />
+  </component>
+</project>
.idea/styletts2-personal.iml
ADDED
@@ -0,0 +1,12 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="styletts2-personal" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
Modules/diffusion/diffusion.py
CHANGED
@@ -1,11 +1,5 @@
-from math import pi
-from random import randint
-from typing import Any, Optional, Sequence, Tuple, Union
 
-import torch
-from einops import rearrange
 from torch import Tensor, nn
-from tqdm import tqdm
 
 from .utils import *
 from .sampler import *
Modules/diffusion/modules.py
CHANGED
@@ -1,7 +1,7 @@
-from math import
-from typing import
+from math import log, pi
+from typing import Optional
 
-
+import torch.nn.functional as F
 
 import torch
 import torch.nn as nn
@@ -10,6 +10,7 @@ from einops.layers.torch import Rearrange
 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
+from Modules.diffusion.utils import default, exists, rand_bool
 
 """
 Utils
Modules/diffusion/sampler.py
CHANGED
@@ -1,10 +1,10 @@
 from math import atan, cos, pi, sin, sqrt
 from typing import Any, Callable, List, Optional, Tuple, Type
 
-
+
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import rearrange
+from einops import rearrange
 from torch import Tensor
 
 from .utils import *
@@ -437,7 +437,11 @@ class KarrasSampler(Sampler):
         # Denoise to sample
         for i in range(num_steps - 1):
             x = self.step(
-                x,
+                x,
+                fn=fn,
+                sigma=sigmas[i],
+                sigma_next=sigmas[i + 1],
+                gamma=gammas[i],  # type: ignore # noqa
             )
 
         return x
Modules/diffusion/utils.py
CHANGED
@@ -1,12 +1,9 @@
 from functools import reduce
 from inspect import isfunction
-from math import ceil, floor, log2
+from math import ceil, floor, log2
 from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Generator, Tensor
 from typing_extensions import TypeGuard
 
 T = TypeVar("T")
Modules/discriminators.py
CHANGED
@@ -1,8 +1,9 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d,
-from torch.nn.utils import
+from torch.nn import Conv1d, Conv2d
+from torch.nn.utils import spectral_norm
+from torch.nn.utils.parametrizations import weight_norm
 
 from .utils import get_padding
 
@@ -21,8 +22,8 @@ def stft(x, fft_size, hop_size, win_length, window):
         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
     """
     x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
-
-
+    x_stft[..., 0]
+    x_stft[..., 1]
 
     return torch.abs(x_stft).transpose(2, 1)
 
@@ -39,7 +40,7 @@ class SpecDiscriminator(nn.Module):
         use_spectral_norm=False,
     ):
         super(SpecDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.fft_size = fft_size
         self.shift_size = shift_size
         self.win_length = win_length
@@ -123,7 +124,7 @@ class DiscriminatorP(torch.nn.Module):
     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
         super(DiscriminatorP, self).__init__()
         self.period = period
-        norm_f = weight_norm if use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.convs = nn.ModuleList(
             [
                 norm_f(
@@ -225,7 +226,7 @@ class WavLMDiscriminator(nn.Module):
         self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
    ):
         super(WavLMDiscriminator, self).__init__()
-        norm_f = weight_norm if use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
         self.pre = norm_f(
             Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
         )
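
A pattern repeated in this file and in Modules/hifigan.py, Modules/istftnet.py, and models.py is splitting the old torch.nn.utils import into an explicit spectral_norm import plus the parametrization-based weight_norm, which lines up with the torch>=2.2.0 pin added to requirements.txt below. A rough sketch of the selection idiom the updated discriminators use (the channel and kernel sizes here are placeholders, not values from the repo):

    from torch.nn import Conv1d
    from torch.nn.utils import spectral_norm
    from torch.nn.utils.parametrizations import weight_norm

    def normed_conv(use_spectral_norm: bool = False) -> Conv1d:
        # same branch as SpecDiscriminator / DiscriminatorP / WavLMDiscriminator above
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        return norm_f(Conv1d(16, 16, kernel_size=3, padding=1))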
Modules/hifigan.py
CHANGED
@@ -1,12 +1,12 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d, ConvTranspose1d
-from torch.nn.utils import
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from .utils import init_weights, get_padding
 
 import math
-import random
 import numpy as np
 
 LRELU_SLOPE = 0.1
@@ -269,7 +269,7 @@ class SineGen(torch.nn.Module):
         output sine_tensor: tensor(batchsize=1, length, dim)
         output uv: tensor(batchsize=1, length, 1)
         """
-
+        torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
         # fundamental component
         fn = torch.multiply(
             f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
@@ -515,6 +515,7 @@ class AdainResBlk1d(nn.Module):
         self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
         self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
         self.norm1 = AdaIN1d(style_dim, dim_in)
+        # self.norm1 = torch.compile(self.norm1)
         self.norm2 = AdaIN1d(style_dim, dim_out)
         if self.learned_sc:
             self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
@@ -581,6 +582,8 @@ class Decoder(nn.Module):
             nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
         )
 
+        # self.F0_conv = torch.compile(self.F0_conv)
+
         self.N_conv = weight_norm(
             nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
         )
@@ -599,30 +602,6 @@ class Decoder(nn.Module):
         )
 
     def forward(self, asr, F0_curve, N, s):
-        if self.training:
-            downlist = [0, 3, 7]
-            F0_down = downlist[random.randint(0, 2)]
-            downlist = [0, 3, 7, 15]
-            N_down = downlist[random.randint(0, 3)]
-            if F0_down:
-                F0_curve = (
-                    nn.functional.conv1d(
-                        F0_curve.unsqueeze(1),
-                        torch.ones(1, 1, F0_down).to("cuda"),
-                        padding=F0_down // 2,
-                    ).squeeze(1)
-                    / F0_down
-                )
-            if N_down:
-                N = (
-                    nn.functional.conv1d(
-                        N.unsqueeze(1),
-                        torch.ones(1, 1, N_down).to("cuda"),
-                        padding=N_down // 2,
-                    ).squeeze(1)
-                    / N_down
-                )
-
         F0 = self.F0_conv(F0_curve.unsqueeze(1))
         N = self.N_conv(N.unsqueeze(1))
 
Modules/istftnet.py
CHANGED
@@ -1,8 +1,9 @@
 import torch
 import torch.nn.functional as F
 import torch.nn as nn
-from torch.nn import Conv1d, ConvTranspose1d
-from torch.nn.utils import
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import remove_weight_norm
+from torch.nn.utils.parametrizations import weight_norm
 from .utils import init_weights, get_padding
 
 import math
@@ -313,7 +314,7 @@ class SineGen(torch.nn.Module):
         output sine_tensor: tensor(batchsize=1, length, dim)
         output uv: tensor(batchsize=1, length, 1)
         """
-
+        torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
         # fundamental component
         fn = torch.multiply(
             f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
Modules/slmadv.py
CHANGED
@@ -67,7 +67,7 @@ class SLMAdversarialLoss(torch.nn.Module):
         ).squeeze(1)
 
         s_dur = s_preds[:, 128:]
-
+        s_preds[:, :128]
 
         d, _ = self.model.predictor(
             d_en,
@@ -138,8 +138,6 @@ class SLMAdversarialLoss(torch.nn.Module):
         p_en = []
         sp = []
 
-        F0_fakes = []
-        N_fakes = []
 
         wav = []
 
Utils/ASR/layers.py
CHANGED
@@ -1,10 +1,6 @@
-import math
 import torch
 from torch import nn
-from typing import Optional, Any
-from torch import Tensor
 import torch.nn.functional as F
-import torchaudio
 import torchaudio.functional as audio_F
 
 import random
Utils/ASR/models.py
CHANGED
@@ -1,7 +1,6 @@
 import math
 import torch
 from torch import nn
-from torch.nn import TransformerEncoder
 import torch.nn.functional as F
 from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
 
Utils/JDC/model.py
CHANGED
@@ -84,7 +84,7 @@ class JDCNet(nn.Module):
         self.apply(self.init_weights)
 
     def get_feature_GAN(self, x):
-
+        x.shape[-2]
         x = x.float().transpose(-1, -2)
 
         convblock_out = self.conv_block(x)
@@ -98,7 +98,7 @@ class JDCNet(nn.Module):
         return poolblock_out.transpose(-1, -2)
 
     def get_feature(self, x):
-
+        x.shape[-2]
         x = x.float().transpose(-1, -2)
 
         convblock_out = self.conv_block(x)
Utils/PLBERT/util.py
CHANGED
@@ -20,7 +20,6 @@ def load_plbert(log_dir):
     albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
     bert = CustomAlbert(albert_base_configuration)
 
-    files = os.listdir(log_dir)
     ckpts = []
     for f in os.listdir(log_dir):
         if f.startswith("step_"):
_run.py
CHANGED
@@ -23,11 +23,8 @@ np.random.seed(0)
 import time
 import random
 import yaml
-from munch import Munch
 import numpy as np
 import torch
-from torch import nn
-import torch.nn.functional as F
 import torchaudio
 import librosa
 from nltk.tokenize import word_tokenize
@@ -305,7 +302,7 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
 
     ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
     ref_text_mask = length_to_mask(ref_input_lengths).to(device)
-
+    model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
     s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                      embedding=bert_dur,
                      embedding_scale=embedding_scale,
app.py
CHANGED
@@ -1,48 +1,48 @@
-
-
-[Paper](https://arxiv.org/abs/2306.07691) - [Samples](https://styletts2.github.io/) - [Code](https://github.com/yl4579/StyleTTS2) - [Discord](https://discord.gg/ha8sxdG2K4)
-
-A free demo of StyleTTS 2. **I am not affiliated with the StyleTTS 2 Authors.**
-
-**Before using this demo, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.**
-
-Is there a long queue on this space? Duplicate it and add a more powerful GPU to skip the wait! **Note: Thank you to Hugging Face for their generous GPU grant program!**
-
-**NOTE: StyleTTS 2 does better on longer texts.** For example, making it say "hi" will produce a lower-quality result than making it say a longer phrase.
-"""
+import os
+
 import gradio as gr
-import styletts2importable
-import ljspeechimportable
 import torch
-import
+from styletts2importable import compute_style, inference
 from txtsplit import txtsplit
 import numpy as np
-import
+import phonemizer
+
+
 theme = gr.themes.Base(
-    font=[
+    font=[
+        gr.themes.GoogleFont("Libre Franklin"),
+        gr.themes.GoogleFont("Public Sans"),
+        "system-ui",
+        "sans-serif",
+    ],
 )
-voicelist = [
+voicelist = [
+    "f-us-1",
+    "f-us-2",
+    "f-us-3",
+    "f-us-4",
+    "m-us-1",
+    "m-us-2",
+    "m-us-3",
+    "m-us-4",
+]
 voices = {}
-
-global_phonemizer = phonemizer.backend.EspeakBackend(
-
-
-# with open('voices.pkl', 'rb') as f:
-#     voices = pickle.load(f)
+
+global_phonemizer = phonemizer.backend.EspeakBackend(
+    language="en-us", preserve_punctuation=True, with_stress=True
+)
 # else:
 for v in voicelist:
-
-
-
-
-
-
-
-
-
-
-if not torch.cuda.is_available(): INTROTXT += "\n\n### You are on a CPU-only system, inference will be much slower.\n\nYou can use the [online demo](https://huggingface.co/spaces/styletts2/styletts2) for fast inference."
-def synthesize(text, voice, lngsteps, password, progress=gr.Progress()):
+    cache_path = f"voices/{v}.wav.npy"
+    if os.path.exists(cache_path):
+        voices[v] = torch.from_numpy(np.load(cache_path))
+    else:
+        style = compute_style(f"voices/{v}.wav")
+        voices[v] = style
+        np.save(cache_path, style.cpu().numpy())
+
+
+def synthesize(text, voice, lngsteps):
     if text.strip() == "":
         raise gr.Error("You must enter some text")
     if len(text) > 50000:
@@ -53,123 +53,81 @@ def synthesize(text, voice, lngsteps, password, progress=gr.Progress()):
     texts = txtsplit(text)
     v = voice.lower()
     audios = []
-    for t in
-        audios.append(
-
-
-
-
-
-
-
-
-
-    # texts = split_and_recombine_text(text)
-    # v = voice.lower()
-    # audios = []
-    # for t in progress.tqdm(texts):
-    #     audios.append(styletts2importable.inference(t, voices[v], alpha=0.3, beta=0.7, diffusion_steps=lngsteps, embedding_scale=1))
-    #     return (24000, np.concatenate(audios))
-    # else:
-    #     raise gr.Error('Wrong access code')
-def clsynthesize(text, voice, vcsteps, progress=gr.Progress()):
-    # if text.strip() == "":
-    #     raise gr.Error("You must enter some text")
-    # # if global_phonemizer.phonemize([text]) > 300:
-    # if len(text) > 400:
-    #     raise gr.Error("Text must be under 400 characters")
-    # # return (24000, styletts2importable.inference(text, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=20, embedding_scale=1))
-    # return (24000, styletts2importable.inference(text, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=vcsteps, embedding_scale=1))
-    if text.strip() == "":
-        raise gr.Error("You must enter some text")
-    if len(text) > 50000:
-        raise gr.Error("Text must be <50k characters")
-    print("*** saying ***")
-    print(text)
-    print("*** end ***")
-    texts = txtsplit(text)
-    audios = []
-    for t in progress.tqdm(texts):
-        audios.append(styletts2importable.inference(t, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=vcsteps, embedding_scale=1))
-    return (24000, np.concatenate(audios))
-def ljsynthesize(text, steps, progress=gr.Progress()):
-    # if text.strip() == "":
-    #     raise gr.Error("You must enter some text")
-    # # if global_phonemizer.phonemize([text]) > 300:
-    # if len(text) > 400:
-    #     raise gr.Error("Text must be under 400 characters")
-    noise = torch.randn(1,1,256).to('cuda' if torch.cuda.is_available() else 'cpu')
-    # return (24000, ljspeechimportable.inference(text, noise, diffusion_steps=7, embedding_scale=1))
-    if text.strip() == "":
-        raise gr.Error("You must enter some text")
-    if len(text) > 150000:
-        raise gr.Error("Text must be <150k characters")
-    print("*** saying ***")
-    print(text)
-    print("*** end ***")
-    texts = txtsplit(text)
-    audios = []
-    for t in progress.tqdm(texts):
-        audios.append(ljspeechimportable.inference(t, noise, diffusion_steps=steps, embedding_scale=1))
+    for t in texts:
+        audios.append(
+            inference(
+                t,
+                voices[v],
+                alpha=0.3,
+                beta=0.7,
+                diffusion_steps=lngsteps,
+                embedding_scale=1,
+            )
+        )
     return (24000, np.concatenate(audios))
 
 
 with gr.Blocks() as vctk:
     with gr.Row():
         with gr.Column(scale=1):
-            inp = gr.Textbox(
-
-
+            inp = gr.Textbox(
+                label="Text",
+                info="What would you like StyleTTS 2 to read? It works better on full sentences.",
+                interactive=True,
+            )
+            voice = gr.Dropdown(
+                voicelist,
+                label="Voice",
+                info="Select a default voice.",
+                value="m-us-2",
+                interactive=True,
+            )
+            multispeakersteps = gr.Slider(
+                minimum=3,
+                maximum=15,
+                value=3,
+                step=1,
+                label="Diffusion Steps",
+                info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster",
+                interactive=True,
+            )
             # use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
         with gr.Column(scale=1):
             btn = gr.Button("Synthesize", variant="primary")
-            audio = gr.Audio(
-
-
-
-
-
-
-
-
-
-
-    clbtn.click(clsynthesize, inputs=[clinp, clvoice, vcsteps], outputs=[claudio], concurrency_limit=4)
-# with gr.Blocks() as longText:
-#     with gr.Row():
-#         with gr.Column(scale=1):
-#             lnginp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
-#             lngvoice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
-#             lngsteps = gr.Slider(minimum=5, maximum=25, value=10, step=1, label="Diffusion Steps", info="Higher = better quality, but slower", interactive=True)
-#             lngpwd = gr.Textbox(label="Access code", info="This feature is in beta. You need an access code to use it as it uses more resources and we would like to prevent abuse")
-#         with gr.Column(scale=1):
-#             lngbtn = gr.Button("Synthesize", variant="primary")
-#             lngaudio = gr.Audio(interactive=False, label="Synthesized Audio")
-#     lngbtn.click(longsynthesize, inputs=[lnginp, lngvoice, lngsteps, lngpwd], outputs=[lngaudio], concurrency_limit=4)
-with gr.Blocks() as lj:
-    with gr.Row():
-        with gr.Column(scale=1):
-            ljinp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
-            ljsteps = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="Diffusion Steps", info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster", interactive=True)
-        with gr.Column(scale=1):
-            ljbtn = gr.Button("Synthesize", variant="primary")
-            ljaudio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_progress_color': '#3C82F6'})
-    ljbtn.click(ljsynthesize, inputs=[ljinp, ljsteps], outputs=[ljaudio], concurrency_limit=4)
-with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
-    gr.Markdown(INTROTXT)
-    gr.DuplicateButton("Duplicate Space")
-    # gr.TabbedInterface([vctk, clone, lj, longText], ['Multi-Voice', 'Voice Cloning', 'LJSpeech', 'Long Text [Beta]'])
-    gr.TabbedInterface([vctk, clone, lj], ['Multi-Voice', 'Voice Cloning', 'LJSpeech', 'Long Text [Beta]'])
-    gr.Markdown("""
-    Demo by [mrfakename](https://twitter.com/realmrfakename). I am not affiliated with the StyleTTS 2 authors.
+            audio = gr.Audio(
+                interactive=False,
+                label="Synthesized Audio",
+                waveform_options={"waveform_progress_color": "#3C82F6"},
+            )
+            btn.click(
+                synthesize,
+                inputs=[inp, voice, multispeakersteps],
+                outputs=[audio],
+                concurrency_limit=4,
+            )
 
-
-
-
-
-
+with gr.Blocks(
+    title="StyleTTS 2", css="footer{display:none !important}", theme=theme
+) as demo:
+    gr.TabbedInterface(
+        [vctk], ["Multi-Voice", "Voice Cloning", "LJSpeech", "Long Text [Beta]"]
+    )
 if __name__ == "__main__":
     # demo.queue(api_open=False, max_size=15).launch(show_api=False)
-
-
+    print("Launching")
+    # start_time = time.time()
+    # synthesize(
+    #     "defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that gr.load this app) will not be able to use this event.",
    #     "m-us-2",
    #     3,
    # )
+    # print(f"Launched in {time.time() - start_time} seconds")
+    # second_start_time = time.time()
+    # synthesize(
+    #     "defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that gr.load this app) will not be able to use this event.",
    #     "m-us-2",
    #     3,
    # )
+    # print(f"Launched in {time.time() - second_start_time} seconds")
+    demo.queue(api_open=True, max_size=None).launch(show_api=False)
+    print("Launched")
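
With the Voice Cloning and LJSpeech tabs stripped out, the Space now exposes a single synthesize(text, voice, lngsteps) function. A small usage sketch outside Gradio, assuming importing app is acceptable (it loads the models and the voice cache at import time; SoundFile is already in requirements.txt):

    import soundfile as sf
    from app import synthesize  # the Gradio launch stays behind the __main__ guard

    sr, wav = synthesize("StyleTTS 2 does better on longer, full sentences.", "m-us-2", 3)
    sf.write("demo.wav", wav, sr)  # 24 kHz audio, per the (24000, audio) return value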
compute.py
CHANGED
@@ -5,7 +5,6 @@ print("NLTK")
 import nltk
 nltk.download('punkt')
 print("SCIPY")
-from scipy.io.wavfile import write
 print("TORCH STUFF")
 import torch
 print("START")
@@ -20,17 +19,12 @@ import numpy as np
 np.random.seed(0)
 
 # load packages
-import time
 import random
 import yaml
-from munch import Munch
 import numpy as np
 import torch
-from torch import nn
-import torch.nn.functional as F
 import torchaudio
 import librosa
-from nltk.tokenize import word_tokenize
 
 from models import *
 from utils import *
ljspeechimportable.py
DELETED
@@ -1,225 +0,0 @@
-from cached_path import cached_path
-
-
-import torch
-torch.manual_seed(0)
-torch.backends.cudnn.benchmark = False
-torch.backends.cudnn.deterministic = True
-
-import random
-random.seed(0)
-
-import numpy as np
-np.random.seed(0)
-
-import nltk
-nltk.download('punkt')
-
-# load packages
-import time
-import random
-import yaml
-from munch import Munch
-import numpy as np
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-import librosa
-from nltk.tokenize import word_tokenize
-
-from models import *
-from utils import *
-from text_utils import TextCleaner
-textclenaer = TextCleaner()
-
-
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
-to_mel = torchaudio.transforms.MelSpectrogram(
-    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
-mean, std = -4, 4
-
-def length_to_mask(lengths):
-    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
-    mask = torch.gt(mask+1, lengths.unsqueeze(1))
-    return mask
-
-def preprocess(wave):
-    wave_tensor = torch.from_numpy(wave).float()
-    mel_tensor = to_mel(wave_tensor)
-    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
-    return mel_tensor
-
-def compute_style(ref_dicts):
-    reference_embeddings = {}
-    for key, path in ref_dicts.items():
-        wave, sr = librosa.load(path, sr=24000)
-        audio, index = librosa.effects.trim(wave, top_db=30)
-        if sr != 24000:
-            audio = librosa.resample(audio, sr, 24000)
-        mel_tensor = preprocess(audio).to(device)
-
-        with torch.no_grad():
-            ref = model.style_encoder(mel_tensor.unsqueeze(1))
-        reference_embeddings[key] = (ref.squeeze(1), audio)
-
-    return reference_embeddings
-
-# load phonemizer
-import phonemizer
-global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
-
-# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
-
-
-config = yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))))
-
-# load pretrained ASR model
-ASR_config = config.get('ASR_config', False)
-ASR_path = config.get('ASR_path', False)
-text_aligner = load_ASR_models(ASR_path, ASR_config)
-
-# load pretrained F0 model
-F0_path = config.get('F0_path', False)
-pitch_extractor = load_F0_models(F0_path)
-
-# load BERT model
-from Utils.PLBERT.util import load_plbert
-BERT_path = config.get('PLBERT_dir', False)
-plbert = load_plbert(BERT_path)
-
-model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
-_ = [model[key].eval() for key in model]
-_ = [model[key].to(device) for key in model]
-
-# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
-params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu')
-params = params_whole['net']
-
-for key in model:
-    if key in params:
-        print('%s loaded' % key)
-        try:
-            model[key].load_state_dict(params[key])
-        except:
-            from collections import OrderedDict
-            state_dict = params[key]
-            new_state_dict = OrderedDict()
-            for k, v in state_dict.items():
-                name = k[7:] # remove `module.`
-                new_state_dict[name] = v
-            # load params
-            model[key].load_state_dict(new_state_dict, strict=False)
-# except:
-#     _load(params[key], model[key])
-_ = [model[key].eval() for key in model]
-
-from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
-
-sampler = DiffusionSampler(
-    model.diffusion.diffusion,
-    sampler=ADPM2Sampler(),
-    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
-    clamp=False
-)
-
-def inference(text, noise, diffusion_steps=5, embedding_scale=1):
-    text = text.strip()
-    text = text.replace('"', '')
-    ps = global_phonemizer.phonemize([text])
-    ps = word_tokenize(ps[0])
-    ps = ' '.join(ps)
-
-    tokens = textclenaer(ps)
-    tokens.insert(0, 0)
-    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
-
-    with torch.no_grad():
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
-        text_mask = length_to_mask(input_lengths).to(tokens.device)
-
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        s_pred = sampler(noise,
-              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
-              embedding_scale=embedding_scale).squeeze(0)
-
-        s = s_pred[:, 128:]
-        ref = s_pred[:, :128]
-
-        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
-
-        x, _ = model.predictor.lstm(d)
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1)
-        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
-
-        pred_dur[-1] += 5
-
-        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
-            c_frame += int(pred_dur[i].data)
-
-        # encode prosody
-        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
-        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
-                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
-
-    return out.squeeze().cpu().numpy()
-
-def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
-    text = text.strip()
-    text = text.replace('"', '')
-    ps = global_phonemizer.phonemize([text])
-    ps = word_tokenize(ps[0])
-    ps = ' '.join(ps)
-
-    tokens = textclenaer(ps)
-    tokens.insert(0, 0)
-    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
-
-    with torch.no_grad():
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
-        text_mask = length_to_mask(input_lengths).to(tokens.device)
-
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        s_pred = sampler(noise,
-              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
-              embedding_scale=embedding_scale).squeeze(0)
-
-        if s_prev is not None:
-            # convex combination of previous and current style
-            s_pred = alpha * s_prev + (1 - alpha) * s_pred
-
-        s = s_pred[:, 128:]
-        ref = s_pred[:, :128]
-
-        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
-
-        x, _ = model.predictor.lstm(d)
-        duration = model.predictor.duration_proj(x)
-        duration = torch.sigmoid(duration).sum(axis=-1)
-        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
-
-        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
-            c_frame += int(pred_dur[i].data)
-
-        # encode prosody
-        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
-        out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
-                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))
-
-    return out.squeeze().cpu().numpy(), s_pred
losses.py
CHANGED
@@ -1,5 +1,4 @@
 import torch
-from torch import nn
 import torch.nn.functional as F
 import torchaudio
 from transformers import AutoModel
meldataset.py
CHANGED
@@ -1,7 +1,5 @@
 # coding: utf-8
-import os
 import os.path as osp
-import time
 import random
 import numpy as np
 import random
@@ -9,8 +7,6 @@ import soundfile as sf
 import librosa
 
 import torch
-from torch import nn
-import torch.nn.functional as F
 import torchaudio
 from torch.utils.data import DataLoader
 
@@ -79,8 +75,6 @@ class FilePathDataset(torch.utils.data.Dataset):
         OOD_data="Data/OOD_texts.txt",
         min_length=50,
     ):
-        spect_params = SPECT_PARAMS
-        mel_params = MEL_PARAMS
 
         _data_list = [l[:-1].split("|") for l in data_list]
         self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
models.py
CHANGED
@@ -1,22 +1,16 @@
|
|
1 |
# coding:utf-8
|
2 |
-
|
3 |
-
import os
|
4 |
-
import os.path as osp
|
5 |
-
|
6 |
-
import copy
|
7 |
import math
|
8 |
|
9 |
-
import numpy as np
|
10 |
import torch
|
11 |
import torch.nn as nn
|
12 |
import torch.nn.functional as F
|
13 |
-
from torch.nn.utils import
|
14 |
-
|
15 |
from Utils.ASR.models import ASRCNN
|
16 |
from Utils.JDC.model import JDCNet
|
17 |
|
18 |
from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
|
19 |
-
from Modules.diffusion.modules import
|
20 |
from Modules.diffusion.diffusion import AudioDiffusionConditional
|
21 |
|
22 |
from Modules.discriminators import (
|
@@ -27,6 +21,7 @@ from Modules.discriminators import (
|
|
27 |
|
28 |
from munch import Munch
|
29 |
import yaml
|
|
|
30 |
|
31 |
|
32 |
class LearnedDownSample(nn.Module):
|
models.py imports after the change (new lines 1-27):
  1   # coding:utf-8
  2   import math
  3
  4   import torch
  5   import torch.nn as nn
  6   import torch.nn.functional as F
  7 + from torch.nn.utils import spectral_norm
  8 + from torch.nn.utils.parametrizations import weight_norm
  9   from Utils.ASR.models import ASRCNN
 10   from Utils.JDC.model import JDCNet
 11
 12   from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
 13 + from Modules.diffusion.modules import StyleTransformer1d
 14   from Modules.diffusion.diffusion import AudioDiffusionConditional
 15
 16   from Modules.discriminators import (
 21
 22   from munch import Munch
 23   import yaml
 24 + from Modules.hifigan import Decoder
 25
 26
 27   class LearnedDownSample(nn.Module):

@@ -589,8 +584,8 @@ class ProsodyPredictor(nn.Module):
     def forward(self, texts, style, text_lengths, alignment, m):
         d = self.text_encoder(texts, style, text_lengths, m)
 
-
-
+        d.shape[0]
+        d.shape[1]
 
         # predict duration
         input_lengths = text_lengths.cpu().numpy()
@@ -750,37 +745,19 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
     return asr_model
 
 
-def build_model(args,
-
-
-
-
-
-            dim_in=args.hidden_dim,
-            style_dim=args.style_dim,
-            dim_out=args.n_mels,
-            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
-            upsample_rates=args.decoder.upsample_rates,
-            upsample_initial_channel=args.decoder.upsample_initial_channel,
-            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
-            gen_istft_n_fft=args.decoder.gen_istft_n_fft,
-            gen_istft_hop_size=args.decoder.gen_istft_hop_size,
-        )
-    else:
-        from Modules.hifigan import Decoder
-
-        decoder = Decoder(
-            dim_in=args.hidden_dim,
-            style_dim=args.style_dim,
-            dim_out=args.n_mels,
-            resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
-            upsample_rates=args.decoder.upsample_rates,
-            upsample_initial_channel=args.decoder.upsample_initial_channel,
-            resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
-            upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
-        )
+def build_model(args, bert):
+    decoder = Decoder(
+        dim_in=args.hidden_dim,
+        style_dim=args.style_dim,
+        dim_out=args.n_mels,
+        resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
+        upsample_rates=args.decoder.upsample_rates,
+        upsample_initial_channel=args.decoder.upsample_initial_channel,
+        resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
+        upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
+    )
 
+    # decoder = torch.compile(decoder, dynamic=True, backend="aot_eager")
+
     text_encoder = TextEncoder(
         channels=args.hidden_dim,
@@ -804,20 +781,12 @@ def build_model(args, text_aligner, pitch_extractor, bert):
         dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
     )  # prosodic style encoder
 
-
-
-
-
-
-
-            **args.diffusion.transformer
-        )
-    else:
-        transformer = Transformer1d(
-            channels=args.style_dim * 2,
-            context_embedding_features=bert.config.hidden_size,
-            **args.diffusion.transformer
-        )
+    transformer = StyleTransformer1d(
+        channels=args.style_dim * 2,
+        context_embedding_features=bert.config.hidden_size,
+        context_features=args.style_dim * 2,
+        **args.diffusion.transformer,
+    )
 
     diffusion = AudioDiffusionConditional(
         in_channels=1,
@@ -839,6 +808,8 @@ def build_model(args, text_aligner, pitch_extractor, bert):
     diffusion.diffusion.net = transformer
     diffusion.unet = transformer
 
+    # predictor = torch.compile(predictor)
+
     nets = Munch(
         bert=bert,
         bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
@@ -848,8 +819,6 @@ def build_model(args, text_aligner, pitch_extractor, bert):
         predictor_encoder=predictor_encoder,
         style_encoder=style_encoder,
         diffusion=diffusion,
-        text_aligner=text_aligner,
-        pitch_extractor=pitch_extractor,
         mpd=MultiPeriodDiscriminator(),
         msd=MultiResSpecDiscriminator(),
         # slm discriminator head
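For orientation, a hedged sketch of what a call site for the slimmed-down build_model(args, bert) could look like after this change. It mirrors how styletts2importable.py constructs the model later in this commit; the YAML config path is illustrative rather than confirmed.

# Hypothetical call-site sketch for the new build_model(args, bert) signature.
import yaml

from models import build_model
from Utils.PLBERT.util import load_plbert
from utils import recursive_munch

config = yaml.safe_load(open("Models/LibriTTS/config.yml"))  # or the inline dict used below
model_params = recursive_munch(config["model_params"])

plbert = load_plbert(config["PLBERT_dir"])  # PL-BERT is still passed in
model = build_model(model_params, plbert)   # text_aligner / pitch_extractor are no longer arguments

_ = [model[key].eval() for key in model]    # the returned Munch maps names to sub-networks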
optimizers.py
CHANGED
@@ -1,10 +1,5 @@
 # coding:utf-8
-import os, sys
-import os.path as osp
-import numpy as np
 import torch
-from torch import nn
-from torch.optim import Optimizer
 from functools import reduce
 from torch.optim import AdamW
 
requirements.txt
CHANGED
@@ -1,13 +1,15 @@
 SoundFile
 torchaudio
 munch
-torch
+torch>=2.2.0
 pydub
 pyyaml
 librosa
 nltk
 matplotlib
 accelerate
+tokenizers>=0.14
+bottleneck>=1.3.6
 transformers
 einops
 einops-exts
@@ -20,5 +22,4 @@ phonemizer
 cached-path
 gradio
 gruut
-#tortoise-tts
 txtsplit
styletts2importable.py
CHANGED
@@ -1,60 +1,57 @@
+import librosa
+import numpy as np
+import torch
+import torchaudio
 from cached_path import cached_path
-
-# from gruut_phonemize import gphonemize
-
-# from dp.phonemizer import Phonemizer
-print("NLTK")
+import random
 import nltk
-
-
-from
-
-import
-
+from models import build_model
+from text_utils import TextCleaner
+from nltk.tokenize import word_tokenize
+import phonemizer
+from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
+from utils import recursive_munch
+from Utils.PLBERT.util import load_plbert
+
+nltk.download("punkt")
+np.random.seed(0)
+random.seed(0)
 torch.manual_seed(0)
 torch.backends.cudnn.benchmark = False
 torch.backends.cudnn.deterministic = True
 
-
-
-
-import numpy as np
-np.random.seed(0)
-
-# load packages
-import time
-import random
-import yaml
-from munch import Munch
-import numpy as np
-import torch
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-import librosa
-from nltk.tokenize import word_tokenize
-
-
-from utils import *
-from text_utils import TextCleaner
-textclenaer = TextCleaner()
+global_phonemizer = phonemizer.backend.EspeakBackend(
+    language="en-us", preserve_punctuation=True, with_stress=True
+)
 
 
+textcleaner = TextCleaner()
+
+
 to_mel = torchaudio.transforms.MelSpectrogram(
-    n_mels=80, n_fft=2048, win_length=1200, hop_length=300
+    n_mels=80, n_fft=2048, win_length=1200, hop_length=300
+)
 mean, std = -4, 4
 
+
 def length_to_mask(lengths):
-    mask =
-
+    mask = (
+        torch.arange(lengths.max())
+        .unsqueeze(0)
+        .expand(lengths.shape[0], -1)
+        .type_as(lengths)
+    )
+    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
     return mask
 
+
 def preprocess(wave):
     wave_tensor = torch.from_numpy(wave).float()
     mel_tensor = to_mel(wave_tensor)
     mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
     return mel_tensor
 
+
 def compute_style(path):
     wave, sr = librosa.load(path, sr=24000)
     audio, index = librosa.effects.trim(wave, top_db=30)
@@ -68,55 +65,151 @@ def compute_style(path):
 
     return torch.cat([ref_s, ref_p], dim=1)
 
-
+
+device = "cpu"
 if torch.cuda.is_available():
-    device =
+    device = "cuda"
 elif torch.backends.mps.is_available():
     print("MPS would be available but cannot be used rn")
-    # device =
-
-import phonemizer
-global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
-# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
-
+    # device = "mps"
 
 # config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
-config =
-
-
-
-
-
-
-
-
-
-
-
-
+config = {
+    "ASR_config": "Utils/ASR/config.yml",
+    "ASR_path": "Utils/ASR/epoch_00080.pth",
+    "F0_path": "Utils/JDC/bst.t7",
+    "PLBERT_dir": "Utils/PLBERT/",
+    "batch_size": 8,
+    "data_params": {
+        "OOD_data": "Data/OOD_texts.txt",
+        "min_length": 50,
+        "root_path": "",
+        "train_data": "Data/train_list.txt",
+        "val_data": "Data/val_list.txt",
+    },
+    "device": "cuda",
+    "epochs_1st": 40,
+    "epochs_2nd": 25,
+    "first_stage_path": "first_stage.pth",
+    "load_only_params": False,
+    "log_dir": "Models/LibriTTS",
+    "log_interval": 10,
+    "loss_params": {
+        "TMA_epoch": 4,
+        "diff_epoch": 0,
+        "joint_epoch": 0,
+        "lambda_F0": 1.0,
+        "lambda_ce": 20.0,
+        "lambda_diff": 1.0,
+        "lambda_dur": 1.0,
+        "lambda_gen": 1.0,
+        "lambda_mel": 5.0,
+        "lambda_mono": 1.0,
+        "lambda_norm": 1.0,
+        "lambda_s2s": 1.0,
+        "lambda_slm": 1.0,
+        "lambda_sty": 1.0,
+    },
+    "max_len": 300,
+    "model_params": {
+        "decoder": {
+            "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+            "resblock_kernel_sizes": [3, 7, 11],
+            "type": "hifigan",
+            "upsample_initial_channel": 512,
+            "upsample_kernel_sizes": [20, 10, 6, 4],
+            "upsample_rates": [10, 5, 3, 2],
+        },
+        "diffusion": {
+            "dist": {
+                "estimate_sigma_data": True,
+                "mean": -3.0,
+                "sigma_data": 0.19926648961191362,
+                "std": 1.0,
+            },
+            "embedding_mask_proba": 0.1,
+            "transformer": {
+                "head_features": 64,
+                "multiplier": 2,
+                "num_heads": 8,
+                "num_layers": 3,
+            },
+        },
+        "dim_in": 64,
+        "dropout": 0,
+        "hidden_dim": 512,
+        "max_conv_dim": 512,
+        "max_dur": 50,
+        "multispeaker": True,
+        "n_layer": 3,
+        "n_mels": 80,
+        "n_token": 178,
+        "slm": {
+            "hidden": 768,
+            "initial_channel": 64,
+            "model": "microsoft/wavlm-base-plus",
+            "nlayers": 13,
+            "sr": 16000,
+        },
+        "style_dim": 128,
+    },
+    "optimizer_params": {"bert_lr": 1e-05, "ft_lr": 1e-05, "lr": 0.0001},
+    "preprocess_params": {
+        "spect_params": {"hop_length": 300, "n_fft": 2048, "win_length": 1200},
+        "sr": 24000,
+    },
+    "pretrained_model": "Models/LibriTTS/epoch_2nd_00002.pth",
+    "save_freq": 1,
+    "second_stage_load_pretrained": True,
+    "slmadv_params": {
+        "batch_percentage": 0.5,
+        "iter": 20,
+        "max_len": 500,
+        "min_len": 400,
+        "scale": 0.01,
+        "sig": 1.5,
+        "thresh": 5,
+    },
+}
+
+
+BERT_path = config.get("PLBERT_dir", False)
 plbert = load_plbert(BERT_path)
 
-
-
+
+model_params = recursive_munch(config["model_params"])
+model = build_model(model_params, plbert)
 _ = [model[key].eval() for key in model]
 _ = [model[key].to(device) for key in model]
 
-#
-
-
+# for key in model:
+#     print(f"Compiling {key}")
+#     model[key] = torch.compile(model[key])
+#     print(f"Compiled {key}")
+
+
+params_whole = torch.load(
+    str(
+        cached_path(
+            "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth"
+        )
+    ),
+    map_location="cpu",
+)
+params = params_whole["net"]
 
 for key in model:
     if key in params:
-        print(
+        print("%s loaded" % key)
         try:
             model[key].load_state_dict(params[key])
         except:
             from collections import OrderedDict
+
             state_dict = params[key]
             new_state_dict = OrderedDict()
             for k, v in state_dict.items():
-                name = k[7:]
+                name = k[7:]  # remove `module.`
                 new_state_dict[name] = v
             # load params
             model[key].load_state_dict(new_state_dict, strict=False)
@@ -124,181 +217,34 @@ for key in model:
     # _load(params[key], model[key])
 _ = [model[key].eval() for key in model]
 
-from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
 
 sampler = DiffusionSampler(
     model.diffusion.diffusion,
     sampler=ADPM2Sampler(),
-    sigma_schedule=KarrasSchedule(
-
+    sigma_schedule=KarrasSchedule(
+        sigma_min=0.0001, sigma_max=3.0, rho=9.0
+    ),  # empirical parameters
+    clamp=False,
 )
 
-def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
-    text = text.strip()
-    ps = global_phonemizer.phonemize([text])
-    ps = word_tokenize(ps[0])
-    ps = ' '.join(ps)
-    tokens = textclenaer(ps)
-    tokens.insert(0, 0)
-    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
-
-    with torch.no_grad():
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
-                         embedding=bert_dur,
-                         embedding_scale=embedding_scale,
-                         features=ref_s, # reference from the same speaker as the embedding
-                         num_steps=diffusion_steps).squeeze(1)
-
-
-        s = s_pred[:, 128:]
-        ref = s_pred[:, :128]
-
-        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
-        s = beta * s + (1 - beta) * ref_s[:, 128:]
-
-        d = model.predictor.text_encoder(d_en,
-                                         s, input_lengths, text_mask)
-
-        x, _ = model.predictor.lstm(d)
-        duration = model.predictor.duration_proj(x)
 
-
-
-
-
-
-
-
-
-
-        # encode prosody
-        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-        if model_params.decoder.type == "hifigan":
-            asr_new = torch.zeros_like(en)
-            asr_new[:, :, 0] = en[:, :, 0]
-            asr_new[:, :, 1:] = en[:, :, 0:-1]
-            en = asr_new
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
-
-        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-        if model_params.decoder.type == "hifigan":
-            asr_new = torch.zeros_like(asr)
-            asr_new[:, :, 0] = asr[:, :, 0]
-            asr_new[:, :, 1:] = asr[:, :, 0:-1]
-            asr = asr_new
-
-        out = model.decoder(asr,
-                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
-
-
-    return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
-
-def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
-    text = text.strip()
-    ps = global_phonemizer.phonemize([text])
-    ps = word_tokenize(ps[0])
-    ps = ' '.join(ps)
-    ps = ps.replace('``', '"')
-    ps = ps.replace("''", '"')
-
-    tokens = textclenaer(ps)
-    tokens.insert(0, 0)
-    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
-
-    with torch.no_grad():
-        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
-        text_mask = length_to_mask(input_lengths).to(device)
-
-        t_en = model.text_encoder(tokens, input_lengths, text_mask)
-        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
-        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
-
-        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
-                         embedding=bert_dur,
-                         embedding_scale=embedding_scale,
-                         features=ref_s, # reference from the same speaker as the embedding
-                         num_steps=diffusion_steps).squeeze(1)
-
-        if s_prev is not None:
-            # convex combination of previous and current style
-            s_pred = t * s_prev + (1 - t) * s_pred
-
-        s = s_pred[:, 128:]
-        ref = s_pred[:, :128]
-
-        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
-        s = beta * s + (1 - beta) * ref_s[:, 128:]
-
-        s_pred = torch.cat([ref, s], dim=-1)
-
-        d = model.predictor.text_encoder(d_en,
-                                         s, input_lengths, text_mask)
-
-        x, _ = model.predictor.lstm(d)
-        duration = model.predictor.duration_proj(x)
-
-        duration = torch.sigmoid(duration).sum(axis=-1)
-        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
-
-
-        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
-        c_frame = 0
-        for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
-            c_frame += int(pred_dur[i].data)
-
-        # encode prosody
-        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
-        if model_params.decoder.type == "hifigan":
-            asr_new = torch.zeros_like(en)
-            asr_new[:, :, 0] = en[:, :, 0]
-            asr_new[:, :, 1:] = en[:, :, 0:-1]
-            en = asr_new
-
-        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
-
-        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
-        if model_params.decoder.type == "hifigan":
-            asr_new = torch.zeros_like(asr)
-            asr_new[:, :, 0] = asr[:, :, 0]
-            asr_new[:, :, 1:] = asr[:, :, 0:-1]
-            asr = asr_new
-
-        out = model.decoder(asr,
-                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
-
-
-    return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
-
-def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
+def inference(
+    text,
+    ref_s,
+    alpha=0.3,
+    beta=0.7,
+    diffusion_steps=5,
+    embedding_scale=1,
+    use_gruut=False,
+):
     text = text.strip()
     ps = global_phonemizer.phonemize([text])
     ps = word_tokenize(ps[0])
-    ps =
-
-    tokens = textclenaer(ps)
+    ps = " ".join(ps)
+    tokens = textcleaner(ps)
     tokens.insert(0, 0)
     tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
 
-    ref_text = ref_text.strip()
-    ps = global_phonemizer.phonemize([ref_text])
-    ps = word_tokenize(ps[0])
-    ps = ' '.join(ps)
-
-    ref_tokens = textclenaer(ps)
-    ref_tokens.insert(0, 0)
-    ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
-
-
     with torch.no_grad():
         input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
         text_mask = length_to_mask(input_lengths).to(device)
@@ -307,24 +253,21 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
         bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
         d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
 
-
-
-
-
-
-
-
-                         num_steps=diffusion_steps).squeeze(1)
-
+        s_pred = sampler(
+            noise=torch.randn((1, 256)).unsqueeze(1).to(device),
+            embedding=bert_dur,
+            embedding_scale=embedding_scale,
+            features=ref_s,  # reference from the same speaker as the embedding
+            num_steps=diffusion_steps,
+        ).squeeze(1)
 
         s = s_pred[:, 128:]
         ref = s_pred[:, :128]
 
-        ref = alpha * ref + (1 - alpha)
-        s = beta * s + (1 - beta)
+        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
+        s = beta * s + (1 - beta) * ref_s[:, 128:]
 
-        d = model.predictor.text_encoder(d_en,
-                                         s, input_lengths, text_mask)
+        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
 
         x, _ = model.predictor.lstm(d)
         duration = model.predictor.duration_proj(x)
@@ -332,32 +275,29 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
         duration = torch.sigmoid(duration).sum(axis=-1)
         pred_dur = torch.round(duration.squeeze()).clamp(min=1)
 
-
         pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
         c_frame = 0
         for i in range(pred_aln_trg.size(0)):
-            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
+            pred_aln_trg[i, c_frame : c_frame + int(pred_dur[i].data)] = 1
             c_frame += int(pred_dur[i].data)
 
         # encode prosody
-        en =
-
-
-
-
-            en = asr_new
+        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
+        asr_new = torch.zeros_like(en)
+        asr_new[:, :, 0] = en[:, :, 0]
+        asr_new[:, :, 1:] = en[:, :, 0:-1]
+        en = asr_new
 
         F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
 
-        asr =
-
-
-
-
-            asr = asr_new
-
-        out = model.decoder(asr,
-                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
+        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
+        asr_new = torch.zeros_like(asr)
+        asr_new[:, :, 0] = asr[:, :, 0]
+        asr_new[:, :, 1:] = asr[:, :, 0:-1]
+        asr = asr_new
 
+        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
 
-    return
+    return (
+        out.squeeze().cpu().numpy()[..., :-50]
+    )  # weird pulse at the end of the model, need to be fixed later
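A minimal end-to-end usage sketch for the refactored module, assuming the repo root as working directory, espeak-ng available for the phonemizer, and an illustrative reference wav; it is not the app's exact code.

# Hedged usage sketch for styletts2importable after this refactor.
import soundfile as sf

import styletts2importable as tts

ref_s = tts.compute_style("voices/f-us-1.wav")   # style + prosody reference, roughly (1, 256)
wav = tts.inference(
    "This sentence is synthesized with a reference style.",
    ref_s,
    alpha=0.3,
    beta=0.7,
    diffusion_steps=5,
    embedding_scale=1,
)
sf.write("out.wav", wav, 24000)                  # the model generates 24 kHz audio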
train_finetune.py
CHANGED
@@ -7,8 +7,6 @@ import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
-import torchaudio
-import librosa
 import click
 import shutil
 import warnings
@@ -18,8 +16,6 @@ from torch.utils.tensorboard import SummaryWriter
 
 from meldataset import build_dataloader
 
-from Utils.ASR.models import ASRCNN
-from Utils.JDC.model import JDCNet
 from Utils.PLBERT.util import load_plbert
 
 from models import *
@@ -75,7 +71,7 @@ def main(config_path):
     epochs = config.get("epochs", 200)
     save_freq = config.get("save_freq", 2)
     log_interval = config.get("log_interval", 10)
-
+    config.get("save_freq", 2)
 
     data_params = config.get("data_params", None)
     sr = config["preprocess_params"].get("sr", 24000)
@@ -245,11 +241,11 @@ def main(config_path):
     n_down = model.text_aligner.n_down
 
    best_loss = float("inf")  # best test loss
-
-
+    list([])
+    list([])
    iters = 0
 
-
+    nn.L1Loss()  # F0 loss (regression)
    torch.cuda.empty_cache()
 
    stft_loss = MultiResolutionSTFTLoss().to(device)
@@ -257,7 +253,6 @@ def main(config_path):
    print("BERT", optimizer.optimizers["bert"])
    print("decoder", optimizer.optimizers["decoder"])
 
-    start_ds = False
 
    running_std = []
 
@@ -302,7 +297,7 @@ def main(config_path):
            ) = batch
            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
-
+                length_to_mask(mel_input_length).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)
 
                # compute reference styles
train_first.py
CHANGED
@@ -1,7 +1,5 @@
 import os
 import os.path as osp
-import re
-import sys
 import yaml
 import shutil
 import numpy as np
@@ -17,10 +15,7 @@ import yaml
 from munch import Munch
 import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
-import torchaudio
-import librosa
 
 from models import *
 from meldataset import build_dataloader
@@ -30,7 +25,6 @@ from optimizers import build_optimizer
 import time
 
 from accelerate import Accelerator
-from accelerate.utils import LoggerType
 from accelerate import DistributedDataParallelKwargs
 
 from torch.utils.tensorboard import SummaryWriter
@@ -69,7 +63,7 @@ def main(config_path):
     device = accelerator.device
 
     epochs = config.get("epochs_1st", 200)
-
+    config.get("save_freq", 2)
     log_interval = config.get("log_interval", 10)
     saving_epoch = config.get("save_freq", 2)
 
@@ -137,8 +131,8 @@ def main(config_path):
     model = build_model(model_params, text_aligner, pitch_extractor, plbert)
 
     best_loss = float("inf")  # best test loss
-
-
+    list([])
+    list([])
 
     loss_params = Munch(config["loss_params"])
     TMA_epoch = loss_params.TMA_epoch
train_second.py
CHANGED
@@ -1,5 +1,4 @@
 # load packages
-import random
 import yaml
 import time
 from munch import Munch
@@ -7,8 +6,6 @@ import numpy as np
 import torch
 from torch import nn
 import torch.nn.functional as F
-import torchaudio
-import librosa
 import click
 import shutil
 import warnings
@@ -18,8 +15,6 @@ from torch.utils.tensorboard import SummaryWriter
 
 from meldataset import build_dataloader
 
-from Utils.ASR.models import ASRCNN
-from Utils.JDC.model import JDCNet
 from Utils.PLBERT.util import load_plbert
 
 from models import *
@@ -73,7 +68,7 @@ def main(config_path):
     batch_size = config.get("batch_size", 10)
 
     epochs = config.get("epochs_2nd", 200)
-
+    config.get("save_freq", 2)
     log_interval = config.get("log_interval", 10)
     saving_epoch = config.get("save_freq", 2)
 
@@ -245,11 +240,11 @@ def main(config_path):
     n_down = model.text_aligner.n_down
 
     best_loss = float("inf")  # best test loss
-
-
+    list([])
+    list([])
     iters = 0
 
-
+    nn.L1Loss()  # F0 loss (regression)
     torch.cuda.empty_cache()
 
     stft_loss = MultiResolutionSTFTLoss().to(device)
@@ -303,7 +298,7 @@ def main(config_path):
 
            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
-
+                length_to_mask(mel_input_length).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)
 
            try:
@@ -445,7 +440,7 @@ def main(config_path):
                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
 
-
+                model.text_aligner.get_feature(gt)
 
                N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
 
utils.py
CHANGED
@@ -1,13 +1,6 @@
-from monotonic_align import maximum_path
-from monotonic_align import mask_from_lens
 from monotonic_align.core import maximum_path_c
 import numpy as np
 import torch
-import copy
-from torch import nn
-import torch.nn.functional as F
-import torchaudio
-import librosa
 import matplotlib.pyplot as plt
 from munch import Munch
 
voices/f-us-1.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:24012e9cefccf44b9187cb7e61907eac7120e96115f77b922c74c0e36b5b45f6
+size 1152
voices/f-us-2.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25df85a2fb487a2e55189fbd02173c7b84028b0cbdb056aaa61d0f853136ebba
+size 1152
voices/f-us-3.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9ad8ce27bfbe1d967d3b5f5f0894b6f3d899c1f23dec5a85157318ecb719eab7
+size 1152
voices/f-us-4.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:397ead679dbd859550cfe2a2635d5ddc78c0b400cd434fdd0dd41cac88ceb667
+size 1152
voices/m-us-1.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7d22a0041c44675a8e2d8b9830a3261f5359e31a8418ea5ef19f9ba76bda2c13
+size 1152
voices/m-us-2.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e09851ffb127fbd3a17329b8d68a27dec5f5920fd5f1aaa7b871046552a2c902
+size 1152
voices/m-us-3.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:98ff10843ecb12f0fe4c31c9309da6f594ed2dd7248d163c927fda05dd608336
+size 1152
voices/m-us-4.wav.npy
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:275b5b956cef2e6eff4258fabb8fffdab130d5e858fd826258fd76e46296263d
+size 1152
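The eight voices/*.wav.npy files are Git-LFS pointers to small NumPy arrays (1152 bytes each, consistent with a 128-byte .npy header plus 256 float32 values, i.e. one precomputed style vector per voice). Below is a hedged sketch of how such vectors could be produced and consumed; the loop, paths, and dtype/device handling are assumptions, not the app's confirmed code.

# Sketch only: precompute and reload voice style vectors as .npy files.
import numpy as np
import torch

from styletts2importable import compute_style, inference, device

# One-off precomputation from the reference recordings (paths illustrative):
for name in ("f-us-1", "f-us-2", "m-us-1"):
    ref_s = compute_style(f"voices/{name}.wav")                  # torch tensor, roughly (1, 256)
    np.save(f"voices/{name}.wav.npy", ref_s.detach().cpu().numpy())

# At app start-up, load the cached vector instead of re-encoding the wav:
ref_s = torch.from_numpy(np.load("voices/f-us-1.wav.npy")).float().to(device)
wav = inference("Hello from a precomputed voice.", ref_s)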