Commit 1373f78 (unverified) by jonluca · Parent: f5915fd

add precomputed voices, reformat code, remove unused code
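The headline change is in app.py: each voice's computed style embedding is now cached on disk, so later launches load a precomputed .npy file instead of re-running the style encoder. A minimal sketch of that pattern, using the names that appear in the diff below (the wrapper function itself is hypothetical, not part of the commit):

    import os

    import numpy as np
    import torch
    from styletts2importable import compute_style  # repo module used in the diff

    def load_or_compute_style(wav_path):
        # Precomputed styles live next to the reference wav, e.g. voices/f-us-1.wav.npy
        cache_path = wav_path + ".npy"
        if os.path.exists(cache_path):
            return torch.from_numpy(np.load(cache_path))
        style = compute_style(wav_path)             # expensive: runs the style encoder
        np.save(cache_path, style.cpu().numpy())    # cache for the next launch
        return style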
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,81 @@
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
5
+ <inspection_tool class="PyCompatibilityInspection" enabled="true" level="ERROR" enabled_by_default="true" editorAttributes="ERRORS_ATTRIBUTES">
6
+ <option name="ourVersions">
7
+ <value>
8
+ <list size="1">
9
+ <item index="0" class="java.lang.String" itemvalue="3.10" />
10
+ </list>
11
+ </value>
12
+ </option>
13
+ </inspection_tool>
14
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
15
+ <option name="ignoredPackages">
16
+ <value>
17
+ <list size="58">
18
+ <item index="0" class="java.lang.String" itemvalue="pandas" />
19
+ <item index="1" class="java.lang.String" itemvalue="fastapi" />
20
+ <item index="2" class="java.lang.String" itemvalue="pydantic" />
21
+ <item index="3" class="java.lang.String" itemvalue="clickhouse-connect" />
22
+ <item index="4" class="java.lang.String" itemvalue="uvicorn" />
23
+ <item index="5" class="java.lang.String" itemvalue="requests" />
24
+ <item index="6" class="java.lang.String" itemvalue="posthog" />
25
+ <item index="7" class="java.lang.String" itemvalue="numba" />
26
+ <item index="8" class="java.lang.String" itemvalue="faiss-cpu" />
27
+ <item index="9" class="java.lang.String" itemvalue="llvmlite" />
28
+ <item index="10" class="java.lang.String" itemvalue="tokenizers" />
29
+ <item index="11" class="java.lang.String" itemvalue="scipy" />
30
+ <item index="12" class="java.lang.String" itemvalue="transformers" />
31
+ <item index="13" class="java.lang.String" itemvalue="tornado" />
32
+ <item index="14" class="java.lang.String" itemvalue="threadpoolctl" />
33
+ <item index="15" class="java.lang.String" itemvalue="unidecode" />
34
+ <item index="16" class="java.lang.String" itemvalue="py-cpuinfo" />
35
+ <item index="17" class="java.lang.String" itemvalue="nbconvert" />
36
+ <item index="18" class="java.lang.String" itemvalue="tqdm" />
37
+ <item index="19" class="java.lang.String" itemvalue="appdirs" />
38
+ <item index="20" class="java.lang.String" itemvalue="rotary_embedding_torch" />
39
+ <item index="21" class="java.lang.String" itemvalue="deepspeed" />
40
+ <item index="22" class="java.lang.String" itemvalue="progressbar" />
41
+ <item index="23" class="java.lang.String" itemvalue="inflect" />
42
+ <item index="24" class="java.lang.String" itemvalue="librosa" />
43
+ <item index="25" class="java.lang.String" itemvalue="ffmpeg" />
44
+ <item index="26" class="java.lang.String" itemvalue="hjson" />
45
+ <item index="27" class="java.lang.String" itemvalue="einops" />
46
+ <item index="28" class="java.lang.String" itemvalue="torchaudio" />
47
+ <item index="29" class="java.lang.String" itemvalue="pyinstaller" />
48
+ <item index="30" class="java.lang.String" itemvalue="pytorch-lightning" />
49
+ <item index="31" class="java.lang.String" itemvalue="bitarray" />
50
+ <item index="32" class="java.lang.String" itemvalue="pyright" />
51
+ <item index="33" class="java.lang.String" itemvalue="yt-dlp" />
52
+ <item index="34" class="java.lang.String" itemvalue="torch" />
53
+ <item index="35" class="java.lang.String" itemvalue="torchvision" />
54
+ <item index="36" class="java.lang.String" itemvalue="sacrebleu" />
55
+ <item index="37" class="java.lang.String" itemvalue="aioshutil" />
56
+ <item index="38" class="java.lang.String" itemvalue="absl-py" />
57
+ <item index="39" class="java.lang.String" itemvalue="gradio" />
58
+ <item index="40" class="java.lang.String" itemvalue="matplotlib-inline" />
59
+ <item index="41" class="java.lang.String" itemvalue="Werkzeug" />
60
+ <item index="42" class="java.lang.String" itemvalue="fairseq" />
61
+ <item index="43" class="java.lang.String" itemvalue="json5" />
62
+ <item index="44" class="java.lang.String" itemvalue="torchfcpe" />
63
+ <item index="45" class="java.lang.String" itemvalue="numpy" />
64
+ <item index="46" class="java.lang.String" itemvalue="pyasn1" />
65
+ <item index="47" class="java.lang.String" itemvalue="torchcrepe" />
66
+ <item index="48" class="java.lang.String" itemvalue="pyasn1-modules" />
67
+ <item index="49" class="java.lang.String" itemvalue="tensorboard" />
68
+ <item index="50" class="java.lang.String" itemvalue="av" />
69
+ <item index="51" class="java.lang.String" itemvalue="matplotlib" />
70
+ <item index="52" class="java.lang.String" itemvalue="tensorboardX" />
71
+ <item index="53" class="java.lang.String" itemvalue="uc-micro-py" />
72
+ <item index="54" class="java.lang.String" itemvalue="ffmpy" />
73
+ <item index="55" class="java.lang.String" itemvalue="pyworld" />
74
+ <item index="56" class="java.lang.String" itemvalue="Markdown" />
75
+ <item index="57" class="java.lang.String" itemvalue="praat-parselmouth" />
76
+ </list>
77
+ </value>
78
+ </option>
79
+ </inspection_tool>
80
+ </profile>
81
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+ <settings>
+ <option name="USE_PROJECT_PROFILE" value="false" />
+ <version value="1.0" />
+ </settings>
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="Black">
+ <option name="sdkName" value="styletts2-personal" />
+ </component>
+ <component name="ProjectRootManager" version="2" project-jdk-name="styletts2-personal" project-jdk-type="Python SDK" />
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="ProjectModuleManager">
+ <modules>
+ <module fileurl="file://$PROJECT_DIR$/.idea/styletts2-personal.iml" filepath="$PROJECT_DIR$/.idea/styletts2-personal.iml" />
+ </modules>
+ </component>
+ </project>
.idea/ruff.xml ADDED
@@ -0,0 +1,11 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="RuffConfigService">
+ <option name="disableOnSaveOutsideOfProject" value="false" />
+ <option name="globalRuffExecutablePath" value="/opt/homebrew/bin/ruff" />
+ <option name="globalRuffLspExecutablePath" value="/opt/homebrew/bin/ruff" />
+ <option name="projectRuffLspExecutablePath" value="/opt/homebrew/Caskroom/miniconda/base/envs/styletts2-personal/bin/ruff" />
+ <option name="runRuffOnSave" value="true" />
+ <option name="useRuffFormat" value="true" />
+ </component>
+ </project>
.idea/styletts2-personal.iml ADDED
@@ -0,0 +1,12 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+ <component name="NewModuleRootManager">
+ <content url="file://$MODULE_DIR$" />
+ <orderEntry type="jdk" jdkName="styletts2-personal" jdkType="Python SDK" />
+ <orderEntry type="sourceFolder" forTests="false" />
+ </component>
+ <component name="PyDocumentationSettings">
+ <option name="format" value="PLAIN" />
+ <option name="myDocStringFormat" value="Plain" />
+ </component>
+ </module>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+ <component name="VcsDirectoryMappings">
+ <mapping directory="" vcs="Git" />
+ </component>
+ </project>
Modules/diffusion/diffusion.py CHANGED
@@ -1,11 +1,5 @@
- from math import pi
- from random import randint
- from typing import Any, Optional, Sequence, Tuple, Union

- import torch
- from einops import rearrange
  from torch import Tensor, nn
- from tqdm import tqdm

  from .utils import *
  from .sampler import *
Modules/diffusion/modules.py CHANGED
@@ -1,7 +1,7 @@
- from math import floor, log, pi
- from typing import Any, List, Optional, Sequence, Tuple, Union
+ from math import log, pi
+ from typing import Optional

- from .utils import *
+ import torch.nn.functional as F

  import torch
  import torch.nn as nn
@@ -10,6 +10,7 @@ from einops.layers.torch import Rearrange
  from einops_exts import rearrange_many
  from torch import Tensor, einsum

+ from Modules.diffusion.utils import default, exists, rand_bool

  """
  Utils
Modules/diffusion/sampler.py CHANGED
@@ -1,10 +1,10 @@
  from math import atan, cos, pi, sin, sqrt
  from typing import Any, Callable, List, Optional, Tuple, Type

- import torch
+
  import torch.nn as nn
  import torch.nn.functional as F
- from einops import rearrange, reduce
+ from einops import rearrange
  from torch import Tensor

  from .utils import *
@@ -437,7 +437,11 @@ class KarrasSampler(Sampler):
  # Denoise to sample
  for i in range(num_steps - 1):
  x = self.step(
- x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
+ x,
+ fn=fn,
+ sigma=sigmas[i],
+ sigma_next=sigmas[i + 1],
+ gamma=gammas[i], # type: ignore # noqa
  )

  return x
Modules/diffusion/utils.py CHANGED
@@ -1,12 +1,9 @@
  from functools import reduce
  from inspect import isfunction
- from math import ceil, floor, log2, pi
+ from math import ceil, floor, log2
  from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

  import torch
- import torch.nn.functional as F
- from einops import rearrange
- from torch import Generator, Tensor
  from typing_extensions import TypeGuard

  T = TypeVar("T")
Modules/discriminators.py CHANGED
@@ -1,8 +1,9 @@
  import torch
  import torch.nn.functional as F
  import torch.nn as nn
- from torch.nn import Conv1d, AvgPool1d, Conv2d
- from torch.nn.utils import weight_norm, spectral_norm
+ from torch.nn import Conv1d, Conv2d
+ from torch.nn.utils import spectral_norm
+ from torch.nn.utils.parametrizations import weight_norm

  from .utils import get_padding

@@ -21,8 +22,8 @@ def stft(x, fft_size, hop_size, win_length, window):
  Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
  """
  x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
- real = x_stft[..., 0]
- imag = x_stft[..., 1]
+ x_stft[..., 0]
+ x_stft[..., 1]

  return torch.abs(x_stft).transpose(2, 1)

@@ -39,7 +40,7 @@ class SpecDiscriminator(nn.Module):
  use_spectral_norm=False,
  ):
  super(SpecDiscriminator, self).__init__()
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
  self.fft_size = fft_size
  self.shift_size = shift_size
  self.win_length = win_length
@@ -123,7+124,7 @@ class DiscriminatorP(torch.nn.Module):
  def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
  super(DiscriminatorP, self).__init__()
  self.period = period
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
  self.convs = nn.ModuleList(
  [
  norm_f(
@@ -225,7+226,7 @@ class WavLMDiscriminator(nn.Module):
  self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
  ):
  super(WavLMDiscriminator, self).__init__()
- norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
  self.pre = norm_f(
  Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
  )
Modules/hifigan.py CHANGED
@@ -1,12 +1,12 @@
1
  import torch
2
  import torch.nn.functional as F
3
  import torch.nn as nn
4
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
 
6
  from .utils import init_weights, get_padding
7
 
8
  import math
9
- import random
10
  import numpy as np
11
 
12
  LRELU_SLOPE = 0.1
@@ -269,7 +269,7 @@ class SineGen(torch.nn.Module):
269
  output sine_tensor: tensor(batchsize=1, length, dim)
270
  output uv: tensor(batchsize=1, length, 1)
271
  """
272
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
273
  # fundamental component
274
  fn = torch.multiply(
275
  f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
@@ -515,6 +515,7 @@ class AdainResBlk1d(nn.Module):
515
  self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
516
  self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
517
  self.norm1 = AdaIN1d(style_dim, dim_in)
 
518
  self.norm2 = AdaIN1d(style_dim, dim_out)
519
  if self.learned_sc:
520
  self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
@@ -581,6 +582,8 @@ class Decoder(nn.Module):
581
  nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
582
  )
583
 
 
 
584
  self.N_conv = weight_norm(
585
  nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
586
  )
@@ -599,30 +602,6 @@ class Decoder(nn.Module):
599
  )
600
 
601
  def forward(self, asr, F0_curve, N, s):
602
- if self.training:
603
- downlist = [0, 3, 7]
604
- F0_down = downlist[random.randint(0, 2)]
605
- downlist = [0, 3, 7, 15]
606
- N_down = downlist[random.randint(0, 3)]
607
- if F0_down:
608
- F0_curve = (
609
- nn.functional.conv1d(
610
- F0_curve.unsqueeze(1),
611
- torch.ones(1, 1, F0_down).to("cuda"),
612
- padding=F0_down // 2,
613
- ).squeeze(1)
614
- / F0_down
615
- )
616
- if N_down:
617
- N = (
618
- nn.functional.conv1d(
619
- N.unsqueeze(1),
620
- torch.ones(1, 1, N_down).to("cuda"),
621
- padding=N_down // 2,
622
- ).squeeze(1)
623
- / N_down
624
- )
625
-
626
  F0 = self.F0_conv(F0_curve.unsqueeze(1))
627
  N = self.N_conv(N.unsqueeze(1))
628
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d
5
+ from torch.nn.utils import remove_weight_norm
6
+ from torch.nn.utils.parametrizations import weight_norm
7
  from .utils import init_weights, get_padding
8
 
9
  import math
 
10
  import numpy as np
11
 
12
  LRELU_SLOPE = 0.1
 
269
  output sine_tensor: tensor(batchsize=1, length, dim)
270
  output uv: tensor(batchsize=1, length, 1)
271
  """
272
+ torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
273
  # fundamental component
274
  fn = torch.multiply(
275
  f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
 
515
  self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
516
  self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
517
  self.norm1 = AdaIN1d(style_dim, dim_in)
518
+ # self.norm1 = torch.compile(self.norm1)
519
  self.norm2 = AdaIN1d(style_dim, dim_out)
520
  if self.learned_sc:
521
  self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
 
582
  nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
583
  )
584
 
585
+ # self.F0_conv = torch.compile(self.F0_conv)
586
+
587
  self.N_conv = weight_norm(
588
  nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
589
  )
 
602
  )
603
 
604
  def forward(self, asr, F0_curve, N, s):
605
  F0 = self.F0_conv(F0_curve.unsqueeze(1))
606
  N = self.N_conv(N.unsqueeze(1))
607
 
Modules/istftnet.py CHANGED
@@ -1,8 +1,9 @@
  import torch
  import torch.nn.functional as F
  import torch.nn as nn
- from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+ from torch.nn import Conv1d, ConvTranspose1d
+ from torch.nn.utils import remove_weight_norm
+ from torch.nn.utils.parametrizations import weight_norm
  from .utils import init_weights, get_padding

  import math
@@ -313,7 +314,7 @@ class SineGen(torch.nn.Module):
  output sine_tensor: tensor(batchsize=1, length, dim)
  output uv: tensor(batchsize=1, length, 1)
  """
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
+ torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
  # fundamental component
  fn = torch.multiply(
  f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
Modules/slmadv.py CHANGED
@@ -67,7 +67,7 @@ class SLMAdversarialLoss(torch.nn.Module):
  ).squeeze(1)

  s_dur = s_preds[:, 128:]
- s = s_preds[:, :128]
+ s_preds[:, :128]

  d, _ = self.model.predictor(
  d_en,
@@ -138,8 +138,6 @@ class SLMAdversarialLoss(torch.nn.Module):
  p_en = []
  sp = []

- F0_fakes = []
- N_fakes = []

  wav = []

Utils/ASR/layers.py CHANGED
@@ -1,10 +1,6 @@
- import math
  import torch
  from torch import nn
- from typing import Optional, Any
- from torch import Tensor
  import torch.nn.functional as F
- import torchaudio
  import torchaudio.functional as audio_F

  import random
Utils/ASR/models.py CHANGED
@@ -1,7 +1,6 @@
  import math
  import torch
  from torch import nn
- from torch.nn import TransformerEncoder
  import torch.nn.functional as F
  from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock

Utils/JDC/model.py CHANGED
@@ -84,7 +84,7 @@ class JDCNet(nn.Module):
  self.apply(self.init_weights)

  def get_feature_GAN(self, x):
- seq_len = x.shape[-2]
+ x.shape[-2]
  x = x.float().transpose(-1, -2)

  convblock_out = self.conv_block(x)
@@ -98,7 +98,7 @@ class JDCNet(nn.Module):
  return poolblock_out.transpose(-1, -2)

  def get_feature(self, x):
- seq_len = x.shape[-2]
+ x.shape[-2]
  x = x.float().transpose(-1, -2)

  convblock_out = self.conv_block(x)
Utils/PLBERT/util.py CHANGED
@@ -20,7 +20,6 @@ def load_plbert(log_dir):
  albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
  bert = CustomAlbert(albert_base_configuration)

- files = os.listdir(log_dir)
  ckpts = []
  for f in os.listdir(log_dir):
  if f.startswith("step_"):
_run.py CHANGED
@@ -23,11 +23,8 @@ np.random.seed(0)
  import time
  import random
  import yaml
- from munch import Munch
  import numpy as np
  import torch
- from torch import nn
- import torch.nn.functional as F
  import torchaudio
  import librosa
  from nltk.tokenize import word_tokenize
@@ -305,7 +302,7 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=

  ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
  ref_text_mask = length_to_mask(ref_input_lengths).to(device)
- ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
+ model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
  s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
  embedding=bert_dur,
  embedding_scale=embedding_scale,
app.py CHANGED
@@ -1,48 +1,48 @@
1
- INTROTXT = """# StyleTTS 2
2
-
3
- [Paper](https://arxiv.org/abs/2306.07691) - [Samples](https://styletts2.github.io/) - [Code](https://github.com/yl4579/StyleTTS2) - [Discord](https://discord.gg/ha8sxdG2K4)
4
-
5
- A free demo of StyleTTS 2. **I am not affiliated with the StyleTTS 2 Authors.**
6
-
7
- **Before using this demo, you agree to inform the listeners that the speech samples are synthesized by the pre-trained models, unless you have the permission to use the voice you synthesize. That is, you agree to only use voices whose speakers grant the permission to have their voice cloned, either directly or by license before making synthesized voices public, or you have to publicly announce that these voices are synthesized if you do not have the permission to use these voices.**
8
-
9
- Is there a long queue on this space? Duplicate it and add a more powerful GPU to skip the wait! **Note: Thank you to Hugging Face for their generous GPU grant program!**
10
 
11
- **NOTE: StyleTTS 2 does better on longer texts.** For example, making it say "hi" will produce a lower-quality result than making it say a longer phrase.
12
- """
13
  import gradio as gr
14
- import styletts2importable
15
- import ljspeechimportable
16
  import torch
17
- import os
18
  from txtsplit import txtsplit
19
  import numpy as np
20
- import pickle
 
 
21
  theme = gr.themes.Base(
22
- font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
 
 
 
 
 
23
  )
24
- voicelist = ['f-us-1', 'f-us-2', 'f-us-3', 'f-us-4', 'm-us-1', 'm-us-2', 'm-us-3', 'm-us-4']
 
 
 
 
 
 
 
 
 
25
  voices = {}
26
- import phonemizer
27
- global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
28
- # todo: cache computed style, load using pickle
29
- # if os.path.exists('voices.pkl'):
30
- # with open('voices.pkl', 'rb') as f:
31
- # voices = pickle.load(f)
32
  # else:
33
  for v in voicelist:
34
- voices[v] = styletts2importable.compute_style(f'voices/{v}.wav')
35
- # def synthesize(text, voice, multispeakersteps):
36
- # if text.strip() == "":
37
- # raise gr.Error("You must enter some text")
38
- # # if len(global_phonemizer.phonemize([text])) > 300:
39
- # if len(text) > 300:
40
- # raise gr.Error("Text must be under 300 characters")
41
- # v = voice.lower()
42
- # # return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=7, embedding_scale=1))
43
- # return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=multispeakersteps, embedding_scale=1))
44
- if not torch.cuda.is_available(): INTROTXT += "\n\n### You are on a CPU-only system, inference will be much slower.\n\nYou can use the [online demo](https://huggingface.co/spaces/styletts2/styletts2) for fast inference."
45
- def synthesize(text, voice, lngsteps, password, progress=gr.Progress()):
46
  if text.strip() == "":
47
  raise gr.Error("You must enter some text")
48
  if len(text) > 50000:
@@ -53,123 +53,81 @@ def synthesize(text, voice, lngsteps, password, progress=gr.Progress()):
53
  texts = txtsplit(text)
54
  v = voice.lower()
55
  audios = []
56
- for t in progress.tqdm(texts):
57
- audios.append(styletts2importable.inference(t, voices[v], alpha=0.3, beta=0.7, diffusion_steps=lngsteps, embedding_scale=1))
58
- return (24000, np.concatenate(audios))
59
- # def longsynthesize(text, voice, lngsteps, password, progress=gr.Progress()):
60
- # if password == os.environ['ACCESS_CODE']:
61
- # if text.strip() == "":
62
- # raise gr.Error("You must enter some text")
63
- # if lngsteps > 25:
64
- # raise gr.Error("Max 25 steps")
65
- # if lngsteps < 5:
66
- # raise gr.Error("Min 5 steps")
67
- # texts = split_and_recombine_text(text)
68
- # v = voice.lower()
69
- # audios = []
70
- # for t in progress.tqdm(texts):
71
- # audios.append(styletts2importable.inference(t, voices[v], alpha=0.3, beta=0.7, diffusion_steps=lngsteps, embedding_scale=1))
72
- # return (24000, np.concatenate(audios))
73
- # else:
74
- # raise gr.Error('Wrong access code')
75
- def clsynthesize(text, voice, vcsteps, progress=gr.Progress()):
76
- # if text.strip() == "":
77
- # raise gr.Error("You must enter some text")
78
- # # if global_phonemizer.phonemize([text]) > 300:
79
- # if len(text) > 400:
80
- # raise gr.Error("Text must be under 400 characters")
81
- # # return (24000, styletts2importable.inference(text, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=20, embedding_scale=1))
82
- # return (24000, styletts2importable.inference(text, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=vcsteps, embedding_scale=1))
83
- if text.strip() == "":
84
- raise gr.Error("You must enter some text")
85
- if len(text) > 50000:
86
- raise gr.Error("Text must be <50k characters")
87
- print("*** saying ***")
88
- print(text)
89
- print("*** end ***")
90
- texts = txtsplit(text)
91
- audios = []
92
- for t in progress.tqdm(texts):
93
- audios.append(styletts2importable.inference(t, styletts2importable.compute_style(voice), alpha=0.3, beta=0.7, diffusion_steps=vcsteps, embedding_scale=1))
94
- return (24000, np.concatenate(audios))
95
- def ljsynthesize(text, steps, progress=gr.Progress()):
96
- # if text.strip() == "":
97
- # raise gr.Error("You must enter some text")
98
- # # if global_phonemizer.phonemize([text]) > 300:
99
- # if len(text) > 400:
100
- # raise gr.Error("Text must be under 400 characters")
101
- noise = torch.randn(1,1,256).to('cuda' if torch.cuda.is_available() else 'cpu')
102
- # return (24000, ljspeechimportable.inference(text, noise, diffusion_steps=7, embedding_scale=1))
103
- if text.strip() == "":
104
- raise gr.Error("You must enter some text")
105
- if len(text) > 150000:
106
- raise gr.Error("Text must be <150k characters")
107
- print("*** saying ***")
108
- print(text)
109
- print("*** end ***")
110
- texts = txtsplit(text)
111
- audios = []
112
- for t in progress.tqdm(texts):
113
- audios.append(ljspeechimportable.inference(t, noise, diffusion_steps=steps, embedding_scale=1))
114
  return (24000, np.concatenate(audios))
115
 
116
 
117
  with gr.Blocks() as vctk:
118
  with gr.Row():
119
  with gr.Column(scale=1):
120
- inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
121
- voice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-2', interactive=True)
122
- multispeakersteps = gr.Slider(minimum=3, maximum=15, value=3, step=1, label="Diffusion Steps", info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster", interactive=True)
123
  # use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
124
  with gr.Column(scale=1):
125
  btn = gr.Button("Synthesize", variant="primary")
126
- audio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_progress_color': '#3C82F6'})
127
- btn.click(synthesize, inputs=[inp, voice, multispeakersteps], outputs=[audio], concurrency_limit=4)
128
- with gr.Blocks() as clone:
129
- with gr.Row():
130
- with gr.Column(scale=1):
131
- clinp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
132
- clvoice = gr.Audio(label="Voice", interactive=True, type='filepath', max_length=300, waveform_options={'waveform_progress_color': '#3C82F6'})
133
- vcsteps = gr.Slider(minimum=3, maximum=20, value=20, step=1, label="Diffusion Steps", info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster", interactive=True)
134
- with gr.Column(scale=1):
135
- clbtn = gr.Button("Synthesize", variant="primary")
136
- claudio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_progress_color': '#3C82F6'})
137
- clbtn.click(clsynthesize, inputs=[clinp, clvoice, vcsteps], outputs=[claudio], concurrency_limit=4)
138
- # with gr.Blocks() as longText:
139
- # with gr.Row():
140
- # with gr.Column(scale=1):
141
- # lnginp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
142
- # lngvoice = gr.Dropdown(voicelist, label="Voice", info="Select a default voice.", value='m-us-1', interactive=True)
143
- # lngsteps = gr.Slider(minimum=5, maximum=25, value=10, step=1, label="Diffusion Steps", info="Higher = better quality, but slower", interactive=True)
144
- # lngpwd = gr.Textbox(label="Access code", info="This feature is in beta. You need an access code to use it as it uses more resources and we would like to prevent abuse")
145
- # with gr.Column(scale=1):
146
- # lngbtn = gr.Button("Synthesize", variant="primary")
147
- # lngaudio = gr.Audio(interactive=False, label="Synthesized Audio")
148
- # lngbtn.click(longsynthesize, inputs=[lnginp, lngvoice, lngsteps, lngpwd], outputs=[lngaudio], concurrency_limit=4)
149
- with gr.Blocks() as lj:
150
- with gr.Row():
151
- with gr.Column(scale=1):
152
- ljinp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
153
- ljsteps = gr.Slider(minimum=3, maximum=20, value=3, step=1, label="Diffusion Steps", info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster", interactive=True)
154
- with gr.Column(scale=1):
155
- ljbtn = gr.Button("Synthesize", variant="primary")
156
- ljaudio = gr.Audio(interactive=False, label="Synthesized Audio", waveform_options={'waveform_progress_color': '#3C82F6'})
157
- ljbtn.click(ljsynthesize, inputs=[ljinp, ljsteps], outputs=[ljaudio], concurrency_limit=4)
158
- with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
159
- gr.Markdown(INTROTXT)
160
- gr.DuplicateButton("Duplicate Space")
161
- # gr.TabbedInterface([vctk, clone, lj, longText], ['Multi-Voice', 'Voice Cloning', 'LJSpeech', 'Long Text [Beta]'])
162
- gr.TabbedInterface([vctk, clone, lj], ['Multi-Voice', 'Voice Cloning', 'LJSpeech', 'Long Text [Beta]'])
163
- gr.Markdown("""
164
- Demo by [mrfakename](https://twitter.com/realmrfakename). I am not affiliated with the StyleTTS 2 authors.
165
 
166
- Run this demo locally using Docker:
167
-
168
- ```bash
169
- docker run -it -p 7860:7860 --platform=linux/amd64 --gpus all registry.hf.space/styletts2-styletts2:latest python app.py
170
- ```
171
- """) # Please do not remove this line.
172
  if __name__ == "__main__":
173
  # demo.queue(api_open=False, max_size=15).launch(show_api=False)
174
- demo.queue(api_open=False, max_size=15).launch(show_api=False)
175
-
1
+ import os
 
 
 
 
 
 
 
 
2
 
 
 
3
  import gradio as gr
 
 
4
  import torch
5
+ from styletts2importable import compute_style, inference
6
  from txtsplit import txtsplit
7
  import numpy as np
8
+ import phonemizer
9
+
10
+
11
  theme = gr.themes.Base(
12
+ font=[
13
+ gr.themes.GoogleFont("Libre Franklin"),
14
+ gr.themes.GoogleFont("Public Sans"),
15
+ "system-ui",
16
+ "sans-serif",
17
+ ],
18
  )
19
+ voicelist = [
20
+ "f-us-1",
21
+ "f-us-2",
22
+ "f-us-3",
23
+ "f-us-4",
24
+ "m-us-1",
25
+ "m-us-2",
26
+ "m-us-3",
27
+ "m-us-4",
28
+ ]
29
  voices = {}
30
+
31
+ global_phonemizer = phonemizer.backend.EspeakBackend(
32
+ language="en-us", preserve_punctuation=True, with_stress=True
33
+ )
 
 
34
  # else:
35
  for v in voicelist:
36
+ cache_path = f"voices/{v}.wav.npy"
37
+ if os.path.exists(cache_path):
38
+ voices[v] = torch.from_numpy(np.load(cache_path))
39
+ else:
40
+ style = compute_style(f"voices/{v}.wav")
41
+ voices[v] = style
42
+ np.save(cache_path, style.cpu().numpy())
43
+
44
+
45
+ def synthesize(text, voice, lngsteps):
 
 
46
  if text.strip() == "":
47
  raise gr.Error("You must enter some text")
48
  if len(text) > 50000:
 
53
  texts = txtsplit(text)
54
  v = voice.lower()
55
  audios = []
56
+ for t in texts:
57
+ audios.append(
58
+ inference(
59
+ t,
60
+ voices[v],
61
+ alpha=0.3,
62
+ beta=0.7,
63
+ diffusion_steps=lngsteps,
64
+ embedding_scale=1,
65
+ )
66
+ )
 
67
  return (24000, np.concatenate(audios))
68
 
69
 
70
  with gr.Blocks() as vctk:
71
  with gr.Row():
72
  with gr.Column(scale=1):
73
+ inp = gr.Textbox(
74
+ label="Text",
75
+ info="What would you like StyleTTS 2 to read? It works better on full sentences.",
76
+ interactive=True,
77
+ )
78
+ voice = gr.Dropdown(
79
+ voicelist,
80
+ label="Voice",
81
+ info="Select a default voice.",
82
+ value="m-us-2",
83
+ interactive=True,
84
+ )
85
+ multispeakersteps = gr.Slider(
86
+ minimum=3,
87
+ maximum=15,
88
+ value=3,
89
+ step=1,
90
+ label="Diffusion Steps",
91
+ info="Theoretically, higher should be better quality but slower, but we cannot notice a difference. Try with lower steps first - it is faster",
92
+ interactive=True,
93
+ )
94
  # use_gruut = gr.Checkbox(label="Use alternate phonemizer (Gruut) - Experimental")
95
  with gr.Column(scale=1):
96
  btn = gr.Button("Synthesize", variant="primary")
97
+ audio = gr.Audio(
98
+ interactive=False,
99
+ label="Synthesized Audio",
100
+ waveform_options={"waveform_progress_color": "#3C82F6"},
101
+ )
102
+ btn.click(
103
+ synthesize,
104
+ inputs=[inp, voice, multispeakersteps],
105
+ outputs=[audio],
106
+ concurrency_limit=4,
107
+ )
108
 
109
+ with gr.Blocks(
110
+ title="StyleTTS 2", css="footer{display:none !important}", theme=theme
111
+ ) as demo:
112
+ gr.TabbedInterface(
113
+ [vctk], ["Multi-Voice", "Voice Cloning", "LJSpeech", "Long Text [Beta]"]
114
+ )
115
  if __name__ == "__main__":
116
  # demo.queue(api_open=False, max_size=15).launch(show_api=False)
117
+ print("Launching")
118
+ # start_time = time.time()
119
+ # synthesize(
120
+ # "defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that gr.load this app) will not be able to use this event.",
121
+ # "m-us-2",
122
+ # 3,
123
+ # )
124
+ # print(f"Launched in {time.time() - start_time} seconds")
125
+ # second_start_time = time.time()
126
+ # synthesize(
127
+ # "defines how the endpoint appears in the API docs. Can be a string, None, or False. If set to a string, the endpoint will be exposed in the API docs with the given name. If None (default), the name of the function will be used as the API endpoint. If False, the endpoint will not be exposed in the API docs and downstream apps (including those that gr.load this app) will not be able to use this event.",
128
+ # "m-us-2",
129
+ # 3,
130
+ # )
131
+ # print(f"Launched in {time.time() - second_start_time} seconds")
132
+ demo.queue(api_open=True, max_size=None).launch(show_api=False)
133
+ print("Launched")
compute.py CHANGED
@@ -5,7 +5,6 @@ print("NLTK")
5
  import nltk
6
  nltk.download('punkt')
7
  print("SCIPY")
8
- from scipy.io.wavfile import write
9
  print("TORCH STUFF")
10
  import torch
11
  print("START")
@@ -20,17 +19,12 @@ import numpy as np
20
  np.random.seed(0)
21
 
22
  # load packages
23
- import time
24
  import random
25
  import yaml
26
- from munch import Munch
27
  import numpy as np
28
  import torch
29
- from torch import nn
30
- import torch.nn.functional as F
31
  import torchaudio
32
  import librosa
33
- from nltk.tokenize import word_tokenize
34
 
35
  from models import *
36
  from utils import *
 
5
  import nltk
6
  nltk.download('punkt')
7
  print("SCIPY")
 
8
  print("TORCH STUFF")
9
  import torch
10
  print("START")
 
19
  np.random.seed(0)
20
 
21
  # load packages
 
22
  import random
23
  import yaml
 
24
  import numpy as np
25
  import torch
 
 
26
  import torchaudio
27
  import librosa
 
28
 
29
  from models import *
30
  from utils import *
ljspeechimportable.py DELETED
@@ -1,225 +0,0 @@
1
- from cached_path import cached_path
2
-
3
-
4
- import torch
5
- torch.manual_seed(0)
6
- torch.backends.cudnn.benchmark = False
7
- torch.backends.cudnn.deterministic = True
8
-
9
- import random
10
- random.seed(0)
11
-
12
- import numpy as np
13
- np.random.seed(0)
14
-
15
- import nltk
16
- nltk.download('punkt')
17
-
18
- # load packages
19
- import time
20
- import random
21
- import yaml
22
- from munch import Munch
23
- import numpy as np
24
- import torch
25
- from torch import nn
26
- import torch.nn.functional as F
27
- import torchaudio
28
- import librosa
29
- from nltk.tokenize import word_tokenize
30
-
31
- from models import *
32
- from utils import *
33
- from text_utils import TextCleaner
34
- textclenaer = TextCleaner()
35
-
36
-
37
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
38
-
39
- to_mel = torchaudio.transforms.MelSpectrogram(
40
- n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
41
- mean, std = -4, 4
42
-
43
- def length_to_mask(lengths):
44
- mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
45
- mask = torch.gt(mask+1, lengths.unsqueeze(1))
46
- return mask
47
-
48
- def preprocess(wave):
49
- wave_tensor = torch.from_numpy(wave).float()
50
- mel_tensor = to_mel(wave_tensor)
51
- mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
52
- return mel_tensor
53
-
54
- def compute_style(ref_dicts):
55
- reference_embeddings = {}
56
- for key, path in ref_dicts.items():
57
- wave, sr = librosa.load(path, sr=24000)
58
- audio, index = librosa.effects.trim(wave, top_db=30)
59
- if sr != 24000:
60
- audio = librosa.resample(audio, sr, 24000)
61
- mel_tensor = preprocess(audio).to(device)
62
-
63
- with torch.no_grad():
64
- ref = model.style_encoder(mel_tensor.unsqueeze(1))
65
- reference_embeddings[key] = (ref.squeeze(1), audio)
66
-
67
- return reference_embeddings
68
-
69
- # load phonemizer
70
- import phonemizer
71
- global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
72
-
73
- # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
74
-
75
-
76
- config = yaml.safe_load(open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))))
77
-
78
- # load pretrained ASR model
79
- ASR_config = config.get('ASR_config', False)
80
- ASR_path = config.get('ASR_path', False)
81
- text_aligner = load_ASR_models(ASR_path, ASR_config)
82
-
83
- # load pretrained F0 model
84
- F0_path = config.get('F0_path', False)
85
- pitch_extractor = load_F0_models(F0_path)
86
-
87
- # load BERT model
88
- from Utils.PLBERT.util import load_plbert
89
- BERT_path = config.get('PLBERT_dir', False)
90
- plbert = load_plbert(BERT_path)
91
-
92
- model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
93
- _ = [model[key].eval() for key in model]
94
- _ = [model[key].to(device) for key in model]
95
-
96
- # params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
97
- params_whole = torch.load(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/epoch_2nd_00100.pth')), map_location='cpu')
98
- params = params_whole['net']
99
-
100
- for key in model:
101
- if key in params:
102
- print('%s loaded' % key)
103
- try:
104
- model[key].load_state_dict(params[key])
105
- except:
106
- from collections import OrderedDict
107
- state_dict = params[key]
108
- new_state_dict = OrderedDict()
109
- for k, v in state_dict.items():
110
- name = k[7:] # remove `module.`
111
- new_state_dict[name] = v
112
- # load params
113
- model[key].load_state_dict(new_state_dict, strict=False)
114
- # except:
115
- # _load(params[key], model[key])
116
- _ = [model[key].eval() for key in model]
117
-
118
- from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
119
-
120
- sampler = DiffusionSampler(
121
- model.diffusion.diffusion,
122
- sampler=ADPM2Sampler(),
123
- sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
124
- clamp=False
125
- )
126
-
127
- def inference(text, noise, diffusion_steps=5, embedding_scale=1):
128
- text = text.strip()
129
- text = text.replace('"', '')
130
- ps = global_phonemizer.phonemize([text])
131
- ps = word_tokenize(ps[0])
132
- ps = ' '.join(ps)
133
-
134
- tokens = textclenaer(ps)
135
- tokens.insert(0, 0)
136
- tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
137
-
138
- with torch.no_grad():
139
- input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
140
- text_mask = length_to_mask(input_lengths).to(tokens.device)
141
-
142
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
143
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
144
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
145
-
146
- s_pred = sampler(noise,
147
- embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
148
- embedding_scale=embedding_scale).squeeze(0)
149
-
150
- s = s_pred[:, 128:]
151
- ref = s_pred[:, :128]
152
-
153
- d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
154
-
155
- x, _ = model.predictor.lstm(d)
156
- duration = model.predictor.duration_proj(x)
157
- duration = torch.sigmoid(duration).sum(axis=-1)
158
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
159
-
160
- pred_dur[-1] += 5
161
-
162
- pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
163
- c_frame = 0
164
- for i in range(pred_aln_trg.size(0)):
165
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
166
- c_frame += int(pred_dur[i].data)
167
-
168
- # encode prosody
169
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
170
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
171
- out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
172
- F0_pred, N_pred, ref.squeeze().unsqueeze(0))
173
-
174
- return out.squeeze().cpu().numpy()
175
-
176
- def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
177
- text = text.strip()
178
- text = text.replace('"', '')
179
- ps = global_phonemizer.phonemize([text])
180
- ps = word_tokenize(ps[0])
181
- ps = ' '.join(ps)
182
-
183
- tokens = textclenaer(ps)
184
- tokens.insert(0, 0)
185
- tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
186
-
187
- with torch.no_grad():
188
- input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
189
- text_mask = length_to_mask(input_lengths).to(tokens.device)
190
-
191
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
192
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
193
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
194
-
195
- s_pred = sampler(noise,
196
- embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
197
- embedding_scale=embedding_scale).squeeze(0)
198
-
199
- if s_prev is not None:
200
- # convex combination of previous and current style
201
- s_pred = alpha * s_prev + (1 - alpha) * s_pred
202
-
203
- s = s_pred[:, 128:]
204
- ref = s_pred[:, :128]
205
-
206
- d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
207
-
208
- x, _ = model.predictor.lstm(d)
209
- duration = model.predictor.duration_proj(x)
210
- duration = torch.sigmoid(duration).sum(axis=-1)
211
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
212
-
213
- pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
214
- c_frame = 0
215
- for i in range(pred_aln_trg.size(0)):
216
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
217
- c_frame += int(pred_dur[i].data)
218
-
219
- # encode prosody
220
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
221
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
222
- out = model.decoder((t_en @ pred_aln_trg.unsqueeze(0).to(device)),
223
- F0_pred, N_pred, ref.squeeze().unsqueeze(0))
224
-
225
- return out.squeeze().cpu().numpy(), s_pred
losses.py CHANGED
@@ -1,5 +1,4 @@
  import torch
- from torch import nn
  import torch.nn.functional as F
  import torchaudio
  from transformers import AutoModel
meldataset.py CHANGED
@@ -1,7 +1,5 @@
1
  # coding: utf-8
2
- import os
3
  import os.path as osp
4
- import time
5
  import random
6
  import numpy as np
7
  import random
@@ -9,8 +7,6 @@ import soundfile as sf
9
  import librosa
10
 
11
  import torch
12
- from torch import nn
13
- import torch.nn.functional as F
14
  import torchaudio
15
  from torch.utils.data import DataLoader
16
 
@@ -79,8 +75,6 @@ class FilePathDataset(torch.utils.data.Dataset):
79
  OOD_data="Data/OOD_texts.txt",
80
  min_length=50,
81
  ):
82
- spect_params = SPECT_PARAMS
83
- mel_params = MEL_PARAMS
84
 
85
  _data_list = [l[:-1].split("|") for l in data_list]
86
  self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
 
1
  # coding: utf-8
 
2
  import os.path as osp
 
3
  import random
4
  import numpy as np
5
  import random
 
7
  import librosa
8
 
9
  import torch
 
 
10
  import torchaudio
11
  from torch.utils.data import DataLoader
12
 
 
75
  OOD_data="Data/OOD_texts.txt",
76
  min_length=50,
77
  ):
 
 
78
 
79
  _data_list = [l[:-1].split("|") for l in data_list]
80
  self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
models.py CHANGED
@@ -1,22 +1,16 @@
1
  # coding:utf-8
2
-
3
- import os
4
- import os.path as osp
5
-
6
- import copy
7
  import math
8
 
9
- import numpy as np
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
- from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
-
15
  from Utils.ASR.models import ASRCNN
16
  from Utils.JDC.model import JDCNet
17
 
18
  from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
- from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
20
  from Modules.diffusion.diffusion import AudioDiffusionConditional
21
 
22
  from Modules.discriminators import (
@@ -27,6 +21,7 @@ from Modules.discriminators import (
27
 
28
  from munch import Munch
29
  import yaml
 
30
 
31
 
32
  class LearnedDownSample(nn.Module):
@@ -589,8 +584,8 @@ class ProsodyPredictor(nn.Module):
589
  def forward(self, texts, style, text_lengths, alignment, m):
590
  d = self.text_encoder(texts, style, text_lengths, m)
591
 
592
- batch_size = d.shape[0]
593
- text_size = d.shape[1]
594
 
595
  # predict duration
596
  input_lengths = text_lengths.cpu().numpy()
@@ -750,37 +745,19 @@ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
750
  return asr_model
751
 
752
 
753
- def build_model(args, text_aligner, pitch_extractor, bert):
754
- assert args.decoder.type in ["istftnet", "hifigan"], "Decoder type unknown"
755
-
756
- if args.decoder.type == "istftnet":
757
- from Modules.istftnet import Decoder
 
 
 
 
 
 
758
 
759
- decoder = Decoder(
760
- dim_in=args.hidden_dim,
761
- style_dim=args.style_dim,
762
- dim_out=args.n_mels,
763
- resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
764
- upsample_rates=args.decoder.upsample_rates,
765
- upsample_initial_channel=args.decoder.upsample_initial_channel,
766
- resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
767
- upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
768
- gen_istft_n_fft=args.decoder.gen_istft_n_fft,
769
- gen_istft_hop_size=args.decoder.gen_istft_hop_size,
770
- )
771
- else:
772
- from Modules.hifigan import Decoder
773
-
774
- decoder = Decoder(
775
- dim_in=args.hidden_dim,
776
- style_dim=args.style_dim,
777
- dim_out=args.n_mels,
778
- resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
779
- upsample_rates=args.decoder.upsample_rates,
780
- upsample_initial_channel=args.decoder.upsample_initial_channel,
781
- resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
782
- upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
783
- )
784
 
785
  text_encoder = TextEncoder(
786
  channels=args.hidden_dim,
@@ -804,20 +781,12 @@ def build_model(args, text_aligner, pitch_extractor, bert):
804
  dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
805
  ) # prosodic style encoder
806
 
807
- # define diffusion model
808
- if args.multispeaker:
809
- transformer = StyleTransformer1d(
810
- channels=args.style_dim * 2,
811
- context_embedding_features=bert.config.hidden_size,
812
- context_features=args.style_dim * 2,
813
- **args.diffusion.transformer
814
- )
815
- else:
816
- transformer = Transformer1d(
817
- channels=args.style_dim * 2,
818
- context_embedding_features=bert.config.hidden_size,
819
- **args.diffusion.transformer
820
- )
821
 
822
  diffusion = AudioDiffusionConditional(
823
  in_channels=1,
@@ -839,6 +808,8 @@ def build_model(args, text_aligner, pitch_extractor, bert):
839
  diffusion.diffusion.net = transformer
840
  diffusion.unet = transformer
841
 
 
 
842
  nets = Munch(
843
  bert=bert,
844
  bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
@@ -848,8 +819,6 @@ def build_model(args, text_aligner, pitch_extractor, bert):
848
  predictor_encoder=predictor_encoder,
849
  style_encoder=style_encoder,
850
  diffusion=diffusion,
851
- text_aligner=text_aligner,
852
- pitch_extractor=pitch_extractor,
853
  mpd=MultiPeriodDiscriminator(),
854
  msd=MultiResSpecDiscriminator(),
855
  # slm discriminator head
 
1
  # coding:utf-8
 
 
 
 
 
2
  import math
3
 
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ from torch.nn.utils import spectral_norm
8
+ from torch.nn.utils.parametrizations import weight_norm
9
  from Utils.ASR.models import ASRCNN
10
  from Utils.JDC.model import JDCNet
11
 
12
  from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
13
+ from Modules.diffusion.modules import StyleTransformer1d
14
  from Modules.diffusion.diffusion import AudioDiffusionConditional
15
 
16
  from Modules.discriminators import (
 
21
 
22
  from munch import Munch
23
  import yaml
24
+ from Modules.hifigan import Decoder
25
 
26
 
27
  class LearnedDownSample(nn.Module):
 
584
  def forward(self, texts, style, text_lengths, alignment, m):
585
  d = self.text_encoder(texts, style, text_lengths, m)
586
 
587
+ d.shape[0]
588
+ d.shape[1]
589
 
590
  # predict duration
591
  input_lengths = text_lengths.cpu().numpy()
 
745
  return asr_model
746
 
747
 
748
+ def build_model(args, bert):
749
+ decoder = Decoder(
750
+ dim_in=args.hidden_dim,
751
+ style_dim=args.style_dim,
752
+ dim_out=args.n_mels,
753
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
754
+ upsample_rates=args.decoder.upsample_rates,
755
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
756
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
757
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
758
+ )
759
 
760
+ # decoder = torch.compile(decoder, dynamic=True, backend="aot_eager")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
761
 
762
  text_encoder = TextEncoder(
763
  channels=args.hidden_dim,
 
781
  dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
782
  ) # prosodic style encoder
783
 
784
+ transformer = StyleTransformer1d(
785
+ channels=args.style_dim * 2,
786
+ context_embedding_features=bert.config.hidden_size,
787
+ context_features=args.style_dim * 2,
788
+ **args.diffusion.transformer,
789
+ )
 
 
 
 
 
 
 
 
790
 
791
  diffusion = AudioDiffusionConditional(
792
  in_channels=1,
 
808
  diffusion.diffusion.net = transformer
809
  diffusion.unet = transformer
810
 
811
+ # predictor = torch.compile(predictor)
812
+
813
  nets = Munch(
814
  bert=bert,
815
  bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
 
819
  predictor_encoder=predictor_encoder,
820
  style_encoder=style_encoder,
821
  diffusion=diffusion,
 
 
822
  mpd=MultiPeriodDiscriminator(),
823
  msd=MultiResSpecDiscriminator(),
824
  # slm discriminator head
optimizers.py CHANGED
@@ -1,10 +1,5 @@
1
  # coding:utf-8
2
- import os, sys
3
- import os.path as osp
4
- import numpy as np
5
  import torch
6
- from torch import nn
7
- from torch.optim import Optimizer
8
  from functools import reduce
9
  from torch.optim import AdamW
10
 
 
1
  # coding:utf-8
 
 
 
2
  import torch
 
 
3
  from functools import reduce
4
  from torch.optim import AdamW
5
 
requirements.txt CHANGED
@@ -1,13 +1,15 @@
1
  SoundFile
2
  torchaudio
3
  munch
4
- torch
5
  pydub
6
  pyyaml
7
  librosa
8
  nltk
9
  matplotlib
10
  accelerate
 
 
11
  transformers
12
  einops
13
  einops-exts
@@ -20,5 +22,4 @@ phonemizer
20
  cached-path
21
  gradio
22
  gruut
23
- #tortoise-tts
24
  txtsplit
 
1
  SoundFile
2
  torchaudio
3
  munch
4
+ torch>=2.2.0
5
  pydub
6
  pyyaml
7
  librosa
8
  nltk
9
  matplotlib
10
  accelerate
11
+ tokenizers>=0.14
12
+ bottleneck>=1.3.6
13
  transformers
14
  einops
15
  einops-exts
 
22
  cached-path
23
  gradio
24
  gruut
 
25
  txtsplit
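requirements.txt now pins torch>=2.2.0 alongside the new tokenizers and bottleneck floors. A quick runtime check, offered as a sketch rather than part of the repo:

    import torch

    major, minor = (int(x) for x in torch.__version__.split("+")[0].split(".")[:2])
    assert (major, minor) >= (2, 2), "this branch expects torch>=2.2.0 (see requirements.txt)"
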
styletts2importable.py CHANGED
@@ -1,60 +1,57 @@
 
 
 
 
1
  from cached_path import cached_path
2
- # print("GRUUT")
3
- # from gruut_phonemize import gphonemize
4
-
5
- # from dp.phonemizer import Phonemizer
6
- print("NLTK")
7
  import nltk
8
- nltk.download('punkt')
9
- print("SCIPY")
10
- from scipy.io.wavfile import write
11
- print("TORCH STUFF")
12
- import torch
13
- print("START")
 
 
 
 
 
14
  torch.manual_seed(0)
15
  torch.backends.cudnn.benchmark = False
16
  torch.backends.cudnn.deterministic = True
17
 
18
- import random
19
- random.seed(0)
20
-
21
- import numpy as np
22
- np.random.seed(0)
23
 
24
- # load packages
25
- import time
26
- import random
27
- import yaml
28
- from munch import Munch
29
- import numpy as np
30
- import torch
31
- from torch import nn
32
- import torch.nn.functional as F
33
- import torchaudio
34
- import librosa
35
- from nltk.tokenize import word_tokenize
36
 
37
- from models import *
38
- from utils import *
39
- from text_utils import TextCleaner
40
- textclenaer = TextCleaner()
41
 
42
 
43
  to_mel = torchaudio.transforms.MelSpectrogram(
44
- n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
 
45
  mean, std = -4, 4
46
 
 
47
  def length_to_mask(lengths):
48
- mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
49
- mask = torch.gt(mask+1, lengths.unsqueeze(1))
 
 
 
 
 
50
  return mask
51
 
 
52
  def preprocess(wave):
53
  wave_tensor = torch.from_numpy(wave).float()
54
  mel_tensor = to_mel(wave_tensor)
55
  mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
56
  return mel_tensor
57
 
 
58
  def compute_style(path):
59
  wave, sr = librosa.load(path, sr=24000)
60
  audio, index = librosa.effects.trim(wave, top_db=30)
@@ -68,55 +65,151 @@ def compute_style(path):
68
 
69
  return torch.cat([ref_s, ref_p], dim=1)
70
 
71
- device = 'cpu'
 
72
  if torch.cuda.is_available():
73
- device = 'cuda'
74
  elif torch.backends.mps.is_available():
75
  print("MPS would be available but cannot be used rn")
76
- # device = 'mps'
77
-
78
- import phonemizer
79
- global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
80
- # phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
81
-
82
 
83
  # config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
84
- config = yaml.safe_load(open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml"))))
85
-
86
- # load pretrained ASR model
87
- ASR_config = config.get('ASR_config', False)
88
- ASR_path = config.get('ASR_path', False)
89
- text_aligner = load_ASR_models(ASR_path, ASR_config)
90
-
91
- # load pretrained F0 model
92
- F0_path = config.get('F0_path', False)
93
- pitch_extractor = load_F0_models(F0_path)
94
-
95
- # load BERT model
96
- from Utils.PLBERT.util import load_plbert
97
- BERT_path = config.get('PLBERT_dir', False)
 
98
  plbert = load_plbert(BERT_path)
99
 
100
- model_params = recursive_munch(config['model_params'])
101
- model = build_model(model_params, text_aligner, pitch_extractor, plbert)
 
102
  _ = [model[key].eval() for key in model]
103
  _ = [model[key].to(device) for key in model]
104
 
105
- # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
106
- params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
107
- params = params_whole['net']
 
108
 
109
  for key in model:
110
  if key in params:
111
- print('%s loaded' % key)
112
  try:
113
  model[key].load_state_dict(params[key])
114
  except:
115
  from collections import OrderedDict
 
116
  state_dict = params[key]
117
  new_state_dict = OrderedDict()
118
  for k, v in state_dict.items():
119
- name = k[7:] # remove `module.`
120
  new_state_dict[name] = v
121
  # load params
122
  model[key].load_state_dict(new_state_dict, strict=False)
@@ -124,181 +217,34 @@ for key in model:
124
  # _load(params[key], model[key])
125
  _ = [model[key].eval() for key in model]
126
 
127
- from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
128
 
129
  sampler = DiffusionSampler(
130
  model.diffusion.diffusion,
131
  sampler=ADPM2Sampler(),
132
- sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
133
- clamp=False
 
 
134
  )
135
 
136
- def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
137
- text = text.strip()
138
- ps = global_phonemizer.phonemize([text])
139
- ps = word_tokenize(ps[0])
140
- ps = ' '.join(ps)
141
- tokens = textclenaer(ps)
142
- tokens.insert(0, 0)
143
- tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
144
-
145
- with torch.no_grad():
146
- input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
147
- text_mask = length_to_mask(input_lengths).to(device)
148
-
149
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
150
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
151
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
152
-
153
- s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
154
- embedding=bert_dur,
155
- embedding_scale=embedding_scale,
156
- features=ref_s, # reference from the same speaker as the embedding
157
- num_steps=diffusion_steps).squeeze(1)
158
-
159
-
160
- s = s_pred[:, 128:]
161
- ref = s_pred[:, :128]
162
-
163
- ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
164
- s = beta * s + (1 - beta) * ref_s[:, 128:]
165
-
166
- d = model.predictor.text_encoder(d_en,
167
- s, input_lengths, text_mask)
168
-
169
- x, _ = model.predictor.lstm(d)
170
- duration = model.predictor.duration_proj(x)
171
 
172
- duration = torch.sigmoid(duration).sum(axis=-1)
173
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
174
-
175
-
176
- pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
177
- c_frame = 0
178
- for i in range(pred_aln_trg.size(0)):
179
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
180
- c_frame += int(pred_dur[i].data)
181
-
182
- # encode prosody
183
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
184
- if model_params.decoder.type == "hifigan":
185
- asr_new = torch.zeros_like(en)
186
- asr_new[:, :, 0] = en[:, :, 0]
187
- asr_new[:, :, 1:] = en[:, :, 0:-1]
188
- en = asr_new
189
-
190
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
191
-
192
- asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
193
- if model_params.decoder.type == "hifigan":
194
- asr_new = torch.zeros_like(asr)
195
- asr_new[:, :, 0] = asr[:, :, 0]
196
- asr_new[:, :, 1:] = asr[:, :, 0:-1]
197
- asr = asr_new
198
-
199
- out = model.decoder(asr,
200
- F0_pred, N_pred, ref.squeeze().unsqueeze(0))
201
-
202
-
203
- return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
204
-
205
- def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
206
- text = text.strip()
207
- ps = global_phonemizer.phonemize([text])
208
- ps = word_tokenize(ps[0])
209
- ps = ' '.join(ps)
210
- ps = ps.replace('``', '"')
211
- ps = ps.replace("''", '"')
212
-
213
- tokens = textclenaer(ps)
214
- tokens.insert(0, 0)
215
- tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
216
-
217
- with torch.no_grad():
218
- input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
219
- text_mask = length_to_mask(input_lengths).to(device)
220
-
221
- t_en = model.text_encoder(tokens, input_lengths, text_mask)
222
- bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
223
- d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
224
-
225
- s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
226
- embedding=bert_dur,
227
- embedding_scale=embedding_scale,
228
- features=ref_s, # reference from the same speaker as the embedding
229
- num_steps=diffusion_steps).squeeze(1)
230
-
231
- if s_prev is not None:
232
- # convex combination of previous and current style
233
- s_pred = t * s_prev + (1 - t) * s_pred
234
-
235
- s = s_pred[:, 128:]
236
- ref = s_pred[:, :128]
237
-
238
- ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
239
- s = beta * s + (1 - beta) * ref_s[:, 128:]
240
-
241
- s_pred = torch.cat([ref, s], dim=-1)
242
-
243
- d = model.predictor.text_encoder(d_en,
244
- s, input_lengths, text_mask)
245
-
246
- x, _ = model.predictor.lstm(d)
247
- duration = model.predictor.duration_proj(x)
248
-
249
- duration = torch.sigmoid(duration).sum(axis=-1)
250
- pred_dur = torch.round(duration.squeeze()).clamp(min=1)
251
-
252
-
253
- pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
254
- c_frame = 0
255
- for i in range(pred_aln_trg.size(0)):
256
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
257
- c_frame += int(pred_dur[i].data)
258
-
259
- # encode prosody
260
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
261
- if model_params.decoder.type == "hifigan":
262
- asr_new = torch.zeros_like(en)
263
- asr_new[:, :, 0] = en[:, :, 0]
264
- asr_new[:, :, 1:] = en[:, :, 0:-1]
265
- en = asr_new
266
-
267
- F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
268
-
269
- asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
270
- if model_params.decoder.type == "hifigan":
271
- asr_new = torch.zeros_like(asr)
272
- asr_new[:, :, 0] = asr[:, :, 0]
273
- asr_new[:, :, 1:] = asr[:, :, 0:-1]
274
- asr = asr_new
275
-
276
- out = model.decoder(asr,
277
- F0_pred, N_pred, ref.squeeze().unsqueeze(0))
278
-
279
-
280
- return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
281
-
282
- def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
283
  text = text.strip()
284
  ps = global_phonemizer.phonemize([text])
285
  ps = word_tokenize(ps[0])
286
- ps = ' '.join(ps)
287
-
288
- tokens = textclenaer(ps)
289
  tokens.insert(0, 0)
290
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
291
 
292
- ref_text = ref_text.strip()
293
- ps = global_phonemizer.phonemize([ref_text])
294
- ps = word_tokenize(ps[0])
295
- ps = ' '.join(ps)
296
-
297
- ref_tokens = textclenaer(ps)
298
- ref_tokens.insert(0, 0)
299
- ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
300
-
301
-
302
  with torch.no_grad():
303
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
304
  text_mask = length_to_mask(input_lengths).to(device)
@@ -307,24 +253,21 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
307
  bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
308
  d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
309
 
310
- ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
311
- ref_text_mask = length_to_mask(ref_input_lengths).to(device)
312
- ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
313
- s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
314
- embedding=bert_dur,
315
- embedding_scale=embedding_scale,
316
- features=ref_s, # reference from the same speaker as the embedding
317
- num_steps=diffusion_steps).squeeze(1)
318
-
319
 
320
  s = s_pred[:, 128:]
321
  ref = s_pred[:, :128]
322
 
323
- ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
324
- s = beta * s + (1 - beta) * ref_s[:, 128:]
325
 
326
- d = model.predictor.text_encoder(d_en,
327
- s, input_lengths, text_mask)
328
 
329
  x, _ = model.predictor.lstm(d)
330
  duration = model.predictor.duration_proj(x)
@@ -332,32 +275,29 @@ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=
332
  duration = torch.sigmoid(duration).sum(axis=-1)
333
  pred_dur = torch.round(duration.squeeze()).clamp(min=1)
334
 
335
-
336
  pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
337
  c_frame = 0
338
  for i in range(pred_aln_trg.size(0)):
339
- pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
340
  c_frame += int(pred_dur[i].data)
341
 
342
  # encode prosody
343
- en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
344
- if model_params.decoder.type == "hifigan":
345
- asr_new = torch.zeros_like(en)
346
- asr_new[:, :, 0] = en[:, :, 0]
347
- asr_new[:, :, 1:] = en[:, :, 0:-1]
348
- en = asr_new
349
 
350
  F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
351
 
352
- asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
353
- if model_params.decoder.type == "hifigan":
354
- asr_new = torch.zeros_like(asr)
355
- asr_new[:, :, 0] = asr[:, :, 0]
356
- asr_new[:, :, 1:] = asr[:, :, 0:-1]
357
- asr = asr_new
358
-
359
- out = model.decoder(asr,
360
- F0_pred, N_pred, ref.squeeze().unsqueeze(0))
361
 
 
362
 
363
- return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
 
 
 
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import torchaudio
5
  from cached_path import cached_path
6
+ import random
 
 
 
 
7
  import nltk
8
+ from models import build_model
9
+ from text_utils import TextCleaner
10
+ from nltk.tokenize import word_tokenize
11
+ import phonemizer
12
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
13
+ from utils import recursive_munch
14
+ from Utils.PLBERT.util import load_plbert
15
+
16
+ nltk.download("punkt")
17
+ np.random.seed(0)
18
+ random.seed(0)
19
  torch.manual_seed(0)
20
  torch.backends.cudnn.benchmark = False
21
  torch.backends.cudnn.deterministic = True
22
 
23
+ global_phonemizer = phonemizer.backend.EspeakBackend(
24
+ language="en-us", preserve_punctuation=True, with_stress=True
25
+ )
 
 
26
 
27
 
28
+ textcleaner = TextCleaner()
 
 
 
29
 
30
 
31
  to_mel = torchaudio.transforms.MelSpectrogram(
32
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300
33
+ )
34
  mean, std = -4, 4
35
 
36
+
37
  def length_to_mask(lengths):
38
+ mask = (
39
+ torch.arange(lengths.max())
40
+ .unsqueeze(0)
41
+ .expand(lengths.shape[0], -1)
42
+ .type_as(lengths)
43
+ )
44
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
45
  return mask
46
 
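length_to_mask is only reformatted here, not changed: it builds a boolean padding mask where True marks positions past each sequence length. A small illustrative example:

    lengths = torch.LongTensor([3, 5])
    length_to_mask(lengths)
    # tensor([[False, False, False,  True,  True],
    #         [False, False, False, False, False]])
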
47
+
48
  def preprocess(wave):
49
  wave_tensor = torch.from_numpy(wave).float()
50
  mel_tensor = to_mel(wave_tensor)
51
  mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
52
  return mel_tensor
53
 
54
+
55
  def compute_style(path):
56
  wave, sr = librosa.load(path, sr=24000)
57
  audio, index = librosa.effects.trim(wave, top_db=30)
 
65
 
66
  return torch.cat([ref_s, ref_p], dim=1)
67
 
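compute_style still returns the concatenated acoustic and prosodic reference styles for a wav file, which is what the precomputed voices/*.wav.npy files added in this commit appear to hold (a 1x256 float32 array, matching their 1152-byte size). A hedged sketch of precomputing one, assuming a matching reference wav is available locally:

    import numpy as np

    ref_s = compute_style("voices/f-us-1.wav")              # tensor of shape (1, 256)
    np.save("voices/f-us-1.wav.npy", ref_s.cpu().numpy())   # cache it so later runs can skip the audio pass
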
68
+
69
+ device = "cpu"
70
  if torch.cuda.is_available():
71
+ device = "cuda"
72
  elif torch.backends.mps.is_available():
73
  print("MPS would be available but cannot be used rn")
74
+ # device = "mps"
 
 
 
 
 
75
 
76
  # config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
77
+ config = {
78
+ "ASR_config": "Utils/ASR/config.yml",
79
+ "ASR_path": "Utils/ASR/epoch_00080.pth",
80
+ "F0_path": "Utils/JDC/bst.t7",
81
+ "PLBERT_dir": "Utils/PLBERT/",
82
+ "batch_size": 8,
83
+ "data_params": {
84
+ "OOD_data": "Data/OOD_texts.txt",
85
+ "min_length": 50,
86
+ "root_path": "",
87
+ "train_data": "Data/train_list.txt",
88
+ "val_data": "Data/val_list.txt",
89
+ },
90
+ "device": "cuda",
91
+ "epochs_1st": 40,
92
+ "epochs_2nd": 25,
93
+ "first_stage_path": "first_stage.pth",
94
+ "load_only_params": False,
95
+ "log_dir": "Models/LibriTTS",
96
+ "log_interval": 10,
97
+ "loss_params": {
98
+ "TMA_epoch": 4,
99
+ "diff_epoch": 0,
100
+ "joint_epoch": 0,
101
+ "lambda_F0": 1.0,
102
+ "lambda_ce": 20.0,
103
+ "lambda_diff": 1.0,
104
+ "lambda_dur": 1.0,
105
+ "lambda_gen": 1.0,
106
+ "lambda_mel": 5.0,
107
+ "lambda_mono": 1.0,
108
+ "lambda_norm": 1.0,
109
+ "lambda_s2s": 1.0,
110
+ "lambda_slm": 1.0,
111
+ "lambda_sty": 1.0,
112
+ },
113
+ "max_len": 300,
114
+ "model_params": {
115
+ "decoder": {
116
+ "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
117
+ "resblock_kernel_sizes": [3, 7, 11],
118
+ "type": "hifigan",
119
+ "upsample_initial_channel": 512,
120
+ "upsample_kernel_sizes": [20, 10, 6, 4],
121
+ "upsample_rates": [10, 5, 3, 2],
122
+ },
123
+ "diffusion": {
124
+ "dist": {
125
+ "estimate_sigma_data": True,
126
+ "mean": -3.0,
127
+ "sigma_data": 0.19926648961191362,
128
+ "std": 1.0,
129
+ },
130
+ "embedding_mask_proba": 0.1,
131
+ "transformer": {
132
+ "head_features": 64,
133
+ "multiplier": 2,
134
+ "num_heads": 8,
135
+ "num_layers": 3,
136
+ },
137
+ },
138
+ "dim_in": 64,
139
+ "dropout": 0,
140
+ "hidden_dim": 512,
141
+ "max_conv_dim": 512,
142
+ "max_dur": 50,
143
+ "multispeaker": True,
144
+ "n_layer": 3,
145
+ "n_mels": 80,
146
+ "n_token": 178,
147
+ "slm": {
148
+ "hidden": 768,
149
+ "initial_channel": 64,
150
+ "model": "microsoft/wavlm-base-plus",
151
+ "nlayers": 13,
152
+ "sr": 16000,
153
+ },
154
+ "style_dim": 128,
155
+ },
156
+ "optimizer_params": {"bert_lr": 1e-05, "ft_lr": 1e-05, "lr": 0.0001},
157
+ "preprocess_params": {
158
+ "spect_params": {"hop_length": 300, "n_fft": 2048, "win_length": 1200},
159
+ "sr": 24000,
160
+ },
161
+ "pretrained_model": "Models/LibriTTS/epoch_2nd_00002.pth",
162
+ "save_freq": 1,
163
+ "second_stage_load_pretrained": True,
164
+ "slmadv_params": {
165
+ "batch_percentage": 0.5,
166
+ "iter": 20,
167
+ "max_len": 500,
168
+ "min_len": 400,
169
+ "scale": 0.01,
170
+ "sig": 1.5,
171
+ "thresh": 5,
172
+ },
173
+ }
174
+
175
+
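The LibriTTS config is now inlined as a plain dict instead of being fetched from the Hub; recursive_munch below turns its nested dicts into attribute-accessible Munch objects, e.g.:

    model_params = recursive_munch(config["model_params"])
    model_params.decoder.type                     # "hifigan"
    model_params.diffusion.transformer.num_heads  # 8
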
176
+ BERT_path = config.get("PLBERT_dir", False)
177
  plbert = load_plbert(BERT_path)
178
 
179
+
180
+ model_params = recursive_munch(config["model_params"])
181
+ model = build_model(model_params, plbert)
182
  _ = [model[key].eval() for key in model]
183
  _ = [model[key].to(device) for key in model]
184
 
185
+ # for key in model:
186
+ # print(f"Compiling {key}")
187
+ # model[key] = torch.compile(model[key])
188
+ # print(f"Compiled {key}")
189
+
190
+
191
+ params_whole = torch.load(
192
+ str(
193
+ cached_path(
194
+ "hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth"
195
+ )
196
+ ),
197
+ map_location="cpu",
198
+ )
199
+ params = params_whole["net"]
200
 
201
  for key in model:
202
  if key in params:
203
+ print("%s loaded" % key)
204
  try:
205
  model[key].load_state_dict(params[key])
206
  except:
207
  from collections import OrderedDict
208
+
209
  state_dict = params[key]
210
  new_state_dict = OrderedDict()
211
  for k, v in state_dict.items():
212
+ name = k[7:] # remove `module.`
213
  new_state_dict[name] = v
214
  # load params
215
  model[key].load_state_dict(new_state_dict, strict=False)
 
217
  # _load(params[key], model[key])
218
  _ = [model[key].eval() for key in model]
219
 
 
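The fallback branch above strips the "module." prefix that torch.nn.DataParallel adds to checkpoint keys before retrying the load. An equivalent, more explicit form of that loop (a sketch, not part of the diff; str.removeprefix needs Python 3.9+):

    new_state_dict = {k.removeprefix("module."): v for k, v in params[key].items()}
    model[key].load_state_dict(new_state_dict, strict=False)
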
220
 
221
  sampler = DiffusionSampler(
222
  model.diffusion.diffusion,
223
  sampler=ADPM2Sampler(),
224
+ sigma_schedule=KarrasSchedule(
225
+ sigma_min=0.0001, sigma_max=3.0, rho=9.0
226
+ ), # empirical parameters
227
+ clamp=False,
228
  )
229
 
230
 
231
+ def inference(
232
+ text,
233
+ ref_s,
234
+ alpha=0.3,
235
+ beta=0.7,
236
+ diffusion_steps=5,
237
+ embedding_scale=1,
238
+ use_gruut=False,
239
+ ):
 
240
  text = text.strip()
241
  ps = global_phonemizer.phonemize([text])
242
  ps = word_tokenize(ps[0])
243
+ ps = " ".join(ps)
244
+ tokens = textcleaner(ps)
 
245
  tokens.insert(0, 0)
246
  tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
247
 
248
  with torch.no_grad():
249
  input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
250
  text_mask = length_to_mask(input_lengths).to(device)
 
253
  bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
254
  d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
255
 
256
+ s_pred = sampler(
257
+ noise=torch.randn((1, 256)).unsqueeze(1).to(device),
258
+ embedding=bert_dur,
259
+ embedding_scale=embedding_scale,
260
+ features=ref_s, # reference from the same speaker as the embedding
261
+ num_steps=diffusion_steps,
262
+ ).squeeze(1)
 
 
263
 
264
  s = s_pred[:, 128:]
265
  ref = s_pred[:, :128]
266
 
267
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
268
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
269
 
270
+ d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
 
271
 
272
  x, _ = model.predictor.lstm(d)
273
  duration = model.predictor.duration_proj(x)
 
275
  duration = torch.sigmoid(duration).sum(axis=-1)
276
  pred_dur = torch.round(duration.squeeze()).clamp(min=1)
277
 
 
278
  pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
279
  c_frame = 0
280
  for i in range(pred_aln_trg.size(0)):
281
+ pred_aln_trg[i, c_frame : c_frame + int(pred_dur[i].data)] = 1
282
  c_frame += int(pred_dur[i].data)
283
 
284
  # encode prosody
285
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)
286
+ asr_new = torch.zeros_like(en)
287
+ asr_new[:, :, 0] = en[:, :, 0]
288
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
289
+ en = asr_new
 
290
 
291
  F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
292
 
293
+ asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)
294
+ asr_new = torch.zeros_like(asr)
295
+ asr_new[:, :, 0] = asr[:, :, 0]
296
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
297
+ asr = asr_new
 
 
 
 
298
 
299
+ out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
300
 
301
+ return (
302
+ out.squeeze().cpu().numpy()[..., :-50]
303
+ ) # weird pulse at the end of the model, need to be fixed later
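End-to-end, the refactored module can be driven roughly like this; a hedged sketch assuming espeak-ng is installed for the phonemizer backend and that a 24 kHz reference wav exists at the given path:

    import soundfile as sf
    from styletts2importable import compute_style, inference

    ref_s = compute_style("voices/m-us-1.wav")
    wav = inference("StyleTTS 2 running from the refactored module.", ref_s,
                    alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)
    sf.write("out.wav", wav, 24000)  # inference returns a float numpy array at 24 kHz
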
train_finetune.py CHANGED
@@ -7,8 +7,6 @@ import numpy as np
7
  import torch
8
  from torch import nn
9
  import torch.nn.functional as F
10
- import torchaudio
11
- import librosa
12
  import click
13
  import shutil
14
  import warnings
@@ -18,8 +16,6 @@ from torch.utils.tensorboard import SummaryWriter
18
 
19
  from meldataset import build_dataloader
20
 
21
- from Utils.ASR.models import ASRCNN
22
- from Utils.JDC.model import JDCNet
23
  from Utils.PLBERT.util import load_plbert
24
 
25
  from models import *
@@ -75,7 +71,7 @@ def main(config_path):
75
  epochs = config.get("epochs", 200)
76
  save_freq = config.get("save_freq", 2)
77
  log_interval = config.get("log_interval", 10)
78
- saving_epoch = config.get("save_freq", 2)
79
 
80
  data_params = config.get("data_params", None)
81
  sr = config["preprocess_params"].get("sr", 24000)
@@ -245,11 +241,11 @@ def main(config_path):
245
  n_down = model.text_aligner.n_down
246
 
247
  best_loss = float("inf") # best test loss
248
- loss_train_record = list([])
249
- loss_test_record = list([])
250
  iters = 0
251
 
252
- criterion = nn.L1Loss() # F0 loss (regression)
253
  torch.cuda.empty_cache()
254
 
255
  stft_loss = MultiResolutionSTFTLoss().to(device)
@@ -257,7 +253,6 @@ def main(config_path):
257
  print("BERT", optimizer.optimizers["bert"])
258
  print("decoder", optimizer.optimizers["decoder"])
259
 
260
- start_ds = False
261
 
262
  running_std = []
263
 
@@ -302,7 +297,7 @@ def main(config_path):
302
  ) = batch
303
  with torch.no_grad():
304
  mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
305
- mel_mask = length_to_mask(mel_input_length).to(device)
306
  text_mask = length_to_mask(input_lengths).to(texts.device)
307
 
308
  # compute reference styles
 
7
  import torch
8
  from torch import nn
9
  import torch.nn.functional as F
 
 
10
  import click
11
  import shutil
12
  import warnings
 
16
 
17
  from meldataset import build_dataloader
18
 
 
 
19
  from Utils.PLBERT.util import load_plbert
20
 
21
  from models import *
 
71
  epochs = config.get("epochs", 200)
72
  save_freq = config.get("save_freq", 2)
73
  log_interval = config.get("log_interval", 10)
74
+ config.get("save_freq", 2)
75
 
76
  data_params = config.get("data_params", None)
77
  sr = config["preprocess_params"].get("sr", 24000)
 
241
  n_down = model.text_aligner.n_down
242
 
243
  best_loss = float("inf") # best test loss
244
+ list([])
245
+ list([])
246
  iters = 0
247
 
248
+ nn.L1Loss() # F0 loss (regression)
249
  torch.cuda.empty_cache()
250
 
251
  stft_loss = MultiResolutionSTFTLoss().to(device)
 
253
  print("BERT", optimizer.optimizers["bert"])
254
  print("decoder", optimizer.optimizers["decoder"])
255
 
 
256
 
257
  running_std = []
258
 
 
297
  ) = batch
298
  with torch.no_grad():
299
  mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
300
+ length_to_mask(mel_input_length).to(device)
301
  text_mask = length_to_mask(input_lengths).to(texts.device)
302
 
303
  # compute reference styles
train_first.py CHANGED
@@ -1,7 +1,5 @@
1
  import os
2
  import os.path as osp
3
- import re
4
- import sys
5
  import yaml
6
  import shutil
7
  import numpy as np
@@ -17,10 +15,7 @@ import yaml
17
  from munch import Munch
18
  import numpy as np
19
  import torch
20
- from torch import nn
21
  import torch.nn.functional as F
22
- import torchaudio
23
- import librosa
24
 
25
  from models import *
26
  from meldataset import build_dataloader
@@ -30,7 +25,6 @@ from optimizers import build_optimizer
30
  import time
31
 
32
  from accelerate import Accelerator
33
- from accelerate.utils import LoggerType
34
  from accelerate import DistributedDataParallelKwargs
35
 
36
  from torch.utils.tensorboard import SummaryWriter
@@ -69,7 +63,7 @@ def main(config_path):
69
  device = accelerator.device
70
 
71
  epochs = config.get("epochs_1st", 200)
72
- save_freq = config.get("save_freq", 2)
73
  log_interval = config.get("log_interval", 10)
74
  saving_epoch = config.get("save_freq", 2)
75
 
@@ -137,8 +131,8 @@ def main(config_path):
137
  model = build_model(model_params, text_aligner, pitch_extractor, plbert)
138
 
139
  best_loss = float("inf") # best test loss
140
- loss_train_record = list([])
141
- loss_test_record = list([])
142
 
143
  loss_params = Munch(config["loss_params"])
144
  TMA_epoch = loss_params.TMA_epoch
 
1
  import os
2
  import os.path as osp
 
 
3
  import yaml
4
  import shutil
5
  import numpy as np
 
15
  from munch import Munch
16
  import numpy as np
17
  import torch
 
18
  import torch.nn.functional as F
 
 
19
 
20
  from models import *
21
  from meldataset import build_dataloader
 
25
  import time
26
 
27
  from accelerate import Accelerator
 
28
  from accelerate import DistributedDataParallelKwargs
29
 
30
  from torch.utils.tensorboard import SummaryWriter
 
63
  device = accelerator.device
64
 
65
  epochs = config.get("epochs_1st", 200)
66
+ config.get("save_freq", 2)
67
  log_interval = config.get("log_interval", 10)
68
  saving_epoch = config.get("save_freq", 2)
69
 
 
131
  model = build_model(model_params, text_aligner, pitch_extractor, plbert)
132
 
133
  best_loss = float("inf") # best test loss
134
+ list([])
135
+ list([])
136
 
137
  loss_params = Munch(config["loss_params"])
138
  TMA_epoch = loss_params.TMA_epoch
train_second.py CHANGED
@@ -1,5 +1,4 @@
1
  # load packages
2
- import random
3
  import yaml
4
  import time
5
  from munch import Munch
@@ -7,8 +6,6 @@ import numpy as np
7
  import torch
8
  from torch import nn
9
  import torch.nn.functional as F
10
- import torchaudio
11
- import librosa
12
  import click
13
  import shutil
14
  import warnings
@@ -18,8 +15,6 @@ from torch.utils.tensorboard import SummaryWriter
18
 
19
  from meldataset import build_dataloader
20
 
21
- from Utils.ASR.models import ASRCNN
22
- from Utils.JDC.model import JDCNet
23
  from Utils.PLBERT.util import load_plbert
24
 
25
  from models import *
@@ -73,7 +68,7 @@ def main(config_path):
73
  batch_size = config.get("batch_size", 10)
74
 
75
  epochs = config.get("epochs_2nd", 200)
76
- save_freq = config.get("save_freq", 2)
77
  log_interval = config.get("log_interval", 10)
78
  saving_epoch = config.get("save_freq", 2)
79
 
@@ -245,11 +240,11 @@ def main(config_path):
245
  n_down = model.text_aligner.n_down
246
 
247
  best_loss = float("inf") # best test loss
248
- loss_train_record = list([])
249
- loss_test_record = list([])
250
  iters = 0
251
 
252
- criterion = nn.L1Loss() # F0 loss (regression)
253
  torch.cuda.empty_cache()
254
 
255
  stft_loss = MultiResolutionSTFTLoss().to(device)
@@ -303,7 +298,7 @@ def main(config_path):
303
 
304
  with torch.no_grad():
305
  mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
306
- mel_mask = length_to_mask(mel_input_length).to(device)
307
  text_mask = length_to_mask(input_lengths).to(texts.device)
308
 
309
  try:
@@ -445,7 +440,7 @@ def main(config_path):
445
  F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
446
  F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
447
 
448
- asr_real = model.text_aligner.get_feature(gt)
449
 
450
  N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
451
 
 
1
  # load packages
 
2
  import yaml
3
  import time
4
  from munch import Munch
 
6
  import torch
7
  from torch import nn
8
  import torch.nn.functional as F
 
 
9
  import click
10
  import shutil
11
  import warnings
 
15
 
16
  from meldataset import build_dataloader
17
 
 
 
18
  from Utils.PLBERT.util import load_plbert
19
 
20
  from models import *
 
68
  batch_size = config.get("batch_size", 10)
69
 
70
  epochs = config.get("epochs_2nd", 200)
71
+ config.get("save_freq", 2)
72
  log_interval = config.get("log_interval", 10)
73
  saving_epoch = config.get("save_freq", 2)
74
 
 
240
  n_down = model.text_aligner.n_down
241
 
242
  best_loss = float("inf") # best test loss
243
+ list([])
244
+ list([])
245
  iters = 0
246
 
247
+ nn.L1Loss() # F0 loss (regression)
248
  torch.cuda.empty_cache()
249
 
250
  stft_loss = MultiResolutionSTFTLoss().to(device)
 
298
 
299
  with torch.no_grad():
300
  mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
301
+ length_to_mask(mel_input_length).to(device)
302
  text_mask = length_to_mask(input_lengths).to(texts.device)
303
 
304
  try:
 
440
  F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
441
  F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
442
 
443
+ model.text_aligner.get_feature(gt)
444
 
445
  N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
446
 
utils.py CHANGED
@@ -1,13 +1,6 @@
1
- from monotonic_align import maximum_path
2
- from monotonic_align import mask_from_lens
3
  from monotonic_align.core import maximum_path_c
4
  import numpy as np
5
  import torch
6
- import copy
7
- from torch import nn
8
- import torch.nn.functional as F
9
- import torchaudio
10
- import librosa
11
  import matplotlib.pyplot as plt
12
  from munch import Munch
13
 
 
 
 
1
  from monotonic_align.core import maximum_path_c
2
  import numpy as np
3
  import torch
 
 
 
 
 
4
  import matplotlib.pyplot as plt
5
  from munch import Munch
6
 
voices/f-us-1.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24012e9cefccf44b9187cb7e61907eac7120e96115f77b922c74c0e36b5b45f6
3
+ size 1152
voices/f-us-2.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25df85a2fb487a2e55189fbd02173c7b84028b0cbdb056aaa61d0f853136ebba
3
+ size 1152
voices/f-us-3.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ad8ce27bfbe1d967d3b5f5f0894b6f3d899c1f23dec5a85157318ecb719eab7
3
+ size 1152
voices/f-us-4.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:397ead679dbd859550cfe2a2635d5ddc78c0b400cd434fdd0dd41cac88ceb667
3
+ size 1152
voices/m-us-1.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d22a0041c44675a8e2d8b9830a3261f5359e31a8418ea5ef19f9ba76bda2c13
3
+ size 1152
voices/m-us-2.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e09851ffb127fbd3a17329b8d68a27dec5f5920fd5f1aaa7b871046552a2c902
3
+ size 1152
voices/m-us-3.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98ff10843ecb12f0fe4c31c9309da6f594ed2dd7248d163c927fda05dd608336
3
+ size 1152
voices/m-us-4.wav.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:275b5b956cef2e6eff4258fabb8fffdab130d5e858fd826258fd76e46296263d
3
+ size 1152
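The eight new voices/*.wav.npy blobs are tiny (1152 bytes each), consistent with one 256-dimensional style vector per speaker. A final hedged sketch of consuming one directly, assuming it stores what compute_style would have returned for that voice:

    import numpy as np
    import torch
    from styletts2importable import device, inference

    voice = torch.from_numpy(np.load("voices/f-us-2.wav.npy")).float().to(device)
    if voice.dim() == 1:
        voice = voice.unsqueeze(0)   # inference expects shape (1, 256)
    wav = inference("Hello from a precomputed voice.", voice)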