zhzluke96 commited on
Commit
d2b7e94
1 Parent(s): 9d9fe0d
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.webui +2 -2
  2. README.md +2 -2
  3. launch.py +11 -10
  4. modules/ChatTTS/ChatTTS/__init__.py +1 -1
  5. modules/ChatTTS/ChatTTS/core.py +7 -7
  6. modules/ChatTTS/ChatTTS/infer/api.py +1 -0
  7. modules/ChatTTS/ChatTTS/model/dvae.py +79 -48
  8. modules/ChatTTS/ChatTTS/model/gpt.py +167 -87
  9. modules/ChatTTS/ChatTTS/utils/infer_utils.py +1 -0
  10. modules/ChatTTS/ChatTTS/utils/io_utils.py +6 -6
  11. modules/Denoiser/AudioDenoiser.py +5 -3
  12. modules/Denoiser/AudioNosiseModel.py +2 -3
  13. modules/Enhancer/ResembleEnhance.py +6 -9
  14. modules/SentenceSplitter.py +1 -0
  15. modules/SynthesizeSegments.py +24 -18
  16. modules/api/Api.py +3 -5
  17. modules/api/api_setup.py +11 -13
  18. modules/api/impl/google_api.py +2 -6
  19. modules/api/impl/handler/AudioHandler.py +2 -1
  20. modules/api/impl/handler/SSMLHandler.py +3 -3
  21. modules/api/impl/handler/TTSHandler.py +2 -2
  22. modules/api/impl/model/enhancer_model.py +1 -0
  23. modules/api/impl/models_api.py +1 -1
  24. modules/api/impl/openai_api.py +6 -11
  25. modules/api/impl/ping_api.py +1 -2
  26. modules/api/impl/refiner_api.py +0 -3
  27. modules/api/impl/speaker_api.py +3 -2
  28. modules/api/impl/ssml_api.py +3 -8
  29. modules/api/impl/style_api.py +1 -1
  30. modules/api/impl/tts_api.py +3 -7
  31. modules/api/impl/xtts_v2_api.py +5 -7
  32. modules/api/utils.py +3 -7
  33. modules/api/worker.py +2 -1
  34. modules/config.py +2 -2
  35. modules/data.py +0 -1
  36. modules/denoise.py +3 -5
  37. modules/devices/devices.py +4 -3
  38. modules/devices/mac_devices.py +3 -2
  39. modules/ffmpeg_env.py +2 -1
  40. modules/finetune/train_speaker.py +8 -5
  41. modules/finetune/utils/dataset.py +6 -6
  42. modules/finetune/utils/logger.py +3 -4
  43. modules/generate_audio.py +7 -10
  44. modules/models.py +5 -5
  45. modules/normalization.py +5 -3
  46. modules/prompts/news_oral_prompt.txt +23 -4
  47. modules/refiner.py +1 -2
  48. modules/repos_static/resemble_enhance/common.py +3 -1
  49. modules/repos_static/resemble_enhance/data/dataset.py +21 -7
  50. modules/repos_static/resemble_enhance/data/distorter/base.py +1 -1
.env.webui CHANGED
@@ -14,9 +14,9 @@ DEBUG_GENERATE=True
14
  PRELOAD_MODELS=True
15
 
16
  # Text-to-Speech (TTS) configuration
17
- TTS_MAX_LEN=1000
18
  SSML_MAX_LEN=3000
19
  MAX_BATCH_SIZE=12
20
 
21
- V_GIT_TAG="🤗hf(0.6.1-rc)"
22
  V_GIT_COMMIT=main
 
14
  PRELOAD_MODELS=True
15
 
16
  # Text-to-Speech (TTS) configuration
17
+ TTS_MAX_LEN=2000
18
  SSML_MAX_LEN=3000
19
  MAX_BATCH_SIZE=12
20
 
21
+ V_GIT_TAG="🤗hf(0.6.1)"
22
  V_GIT_COMMIT=main
README.md CHANGED
@@ -16,7 +16,7 @@ sdk_version: 4.36.1
16
 
17
  | 类型 | 最大字符数 |
18
  |------|-----------|
19
- | TTS | 1000 字符 |
20
  | SSML | 3000 字符(不计算 SSML 标签,只计算文本) |
21
 
22
  # HuggingFace Space Limit
@@ -25,7 +25,7 @@ Due to the runtime limit for GPU usage on HuggingFace, extremely long tasks will
25
 
26
  | Type | Maximum Characters |
27
  |------|---------------------|
28
- | TTS | 1000 characters |
29
  | SSML | 3000 characters (excluding SSML tags, only counting text) |
30
 
31
  # 🗣️ ChatTTS-Forge
 
16
 
17
  | 类型 | 最大字符数 |
18
  |------|-----------|
19
+ | TTS | 2000 字符 |
20
  | SSML | 3000 字符(不计算 SSML 标签,只计算文本) |
21
 
22
  # HuggingFace Space Limit
 
25
 
26
  | Type | Maximum Characters |
27
  |------|---------------------|
28
+ | TTS | 2000 characters |
29
  | SSML | 3000 characters (excluding SSML tags, only counting text) |
30
 
31
  # 🗣️ ChatTTS-Forge
launch.py CHANGED
@@ -1,23 +1,24 @@
1
- import os
2
  import logging
 
3
 
4
- from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
5
  from modules.ffmpeg_env import setup_ffmpeg_path
6
 
7
- setup_ffmpeg_path()
8
- logging.basicConfig(
9
- level=os.getenv("LOG_LEVEL", "INFO"),
10
- format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
11
- )
 
 
 
12
 
13
  import argparse
 
14
  import uvicorn
15
 
16
- from modules import config
17
  from modules.utils import env
18
 
19
- from fastapi import FastAPI
20
-
21
  logger = logging.getLogger(__name__)
22
 
23
  if __name__ == "__main__":
 
 
1
  import logging
2
+ import os
3
 
 
4
  from modules.ffmpeg_env import setup_ffmpeg_path
5
 
6
+ try:
7
+ setup_ffmpeg_path()
8
+ logging.basicConfig(
9
+ level=os.getenv("LOG_LEVEL", "INFO"),
10
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
11
+ )
12
+ except BaseException:
13
+ pass
14
 
15
  import argparse
16
+
17
  import uvicorn
18
 
19
+ from modules.api.api_setup import setup_api_args, setup_model_args, setup_uvicon_args
20
  from modules.utils import env
21
 
 
 
22
  logger = logging.getLogger(__name__)
23
 
24
  if __name__ == "__main__":
modules/ChatTTS/ChatTTS/__init__.py CHANGED
@@ -1 +1 @@
1
- from .core import Chat
 
1
+ from .core import Chat
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -1,21 +1,21 @@
1
- import os
2
  import logging
3
- from omegaconf import OmegaConf
4
 
5
  import torch
 
 
6
  from vocos import Vocos
 
 
7
  from .model.dvae import DVAE
8
  from .model.gpt import GPT_warpper
9
  from .utils.infer_utils import (
10
- count_invalid_characters,
11
- detect_language,
12
  apply_character_map,
13
  apply_half2full_map,
 
 
14
  )
15
  from .utils.io_utils import get_latest_modified_file
16
- from .infer.api import refine_text, infer_code
17
-
18
- from huggingface_hub import snapshot_download
19
 
20
  logging.basicConfig(level=logging.INFO)
21
 
 
 
1
  import logging
2
+ import os
3
 
4
  import torch
5
+ from huggingface_hub import snapshot_download
6
+ from omegaconf import OmegaConf
7
  from vocos import Vocos
8
+
9
+ from .infer.api import infer_code, refine_text
10
  from .model.dvae import DVAE
11
  from .model.gpt import GPT_warpper
12
  from .utils.infer_utils import (
 
 
13
  apply_character_map,
14
  apply_half2full_map,
15
+ count_invalid_characters,
16
+ detect_language,
17
  )
18
  from .utils.io_utils import get_latest_modified_file
 
 
 
19
 
20
  logging.basicConfig(level=logging.INFO)
21
 
modules/ChatTTS/ChatTTS/infer/api.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  import torch.nn.functional as F
3
  from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
4
  from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
5
 
6
 
 
1
  import torch
2
  import torch.nn.functional as F
3
  from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
4
+
5
  from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
6
 
7
 
modules/ChatTTS/ChatTTS/model/dvae.py CHANGED
@@ -1,28 +1,36 @@
1
  import math
2
- from einops import rearrange
3
- from vector_quantize_pytorch import GroupedResidualFSQ
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
 
 
 
8
 
9
  class ConvNeXtBlock(nn.Module):
10
  def __init__(
11
  self,
12
  dim: int,
13
  intermediate_dim: int,
14
- kernel, dilation,
 
15
  layer_scale_init_value: float = 1e-6,
16
  ):
17
  # ConvNeXt Block copied from Vocos.
18
  super().__init__()
19
- self.dwconv = nn.Conv1d(dim, dim,
20
- kernel_size=kernel, padding=dilation*(kernel//2),
21
- dilation=dilation, groups=dim
22
- ) # depthwise conv
23
-
 
 
 
 
24
  self.norm = nn.LayerNorm(dim, eps=1e-6)
25
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
 
 
26
  self.act = nn.GELU()
27
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
28
  self.gamma = (
@@ -31,7 +39,7 @@ class ConvNeXtBlock(nn.Module):
31
  else None
32
  )
33
 
34
- def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor:
35
  residual = x
36
  x = self.dwconv(x)
37
  x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
@@ -45,14 +53,11 @@ class ConvNeXtBlock(nn.Module):
45
 
46
  x = residual + x
47
  return x
48
-
49
 
50
 
51
  class GFSQ(nn.Module):
52
 
53
- def __init__(self,
54
- dim, levels, G, R, eps=1e-5, transpose = True
55
- ):
56
  super(GFSQ, self).__init__()
57
  self.quantizer = GroupedResidualFSQ(
58
  dim=dim,
@@ -65,50 +70,74 @@ class GFSQ(nn.Module):
65
  self.transpose = transpose
66
  self.G = G
67
  self.R = R
68
-
69
  def _embed(self, x):
70
  if self.transpose:
71
- x = x.transpose(1,2)
72
  x = rearrange(
73
- x, "b t (g r) -> g b t r", g = self.G, r = self.R,
74
- )
 
 
 
75
  feat = self.quantizer.get_output_from_indices(x)
76
- return feat.transpose(1,2) if self.transpose else feat
77
-
78
- def forward(self, x,):
 
 
 
79
  if self.transpose:
80
- x = x.transpose(1,2)
81
  feat, ind = self.quantizer(x)
82
  ind = rearrange(
83
- ind, "g b t r ->b t (g r)",
84
- )
 
85
  embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
86
- e_mean = torch.mean(embed_onehot, dim=[0,1])
87
  e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
88
  perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
89
-
90
  return (
91
  torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
92
- feat.transpose(1,2) if self.transpose else feat,
93
  perplexity,
94
  None,
95
- ind.transpose(1,2) if self.transpose else ind,
96
  )
97
-
 
98
  class DVAEDecoder(nn.Module):
99
- def __init__(self, idim, odim,
100
- n_layer = 12, bn_dim = 64, hidden = 256,
101
- kernel = 7, dilation = 2, up = False
102
- ):
 
 
 
 
 
 
 
103
  super().__init__()
104
  self.up = up
105
  self.conv_in = nn.Sequential(
106
- nn.Conv1d(idim, bn_dim, 3, 1, 1), nn.GELU(),
107
- nn.Conv1d(bn_dim, hidden, 3, 1, 1)
 
 
 
 
 
 
 
 
 
 
 
 
108
  )
109
- self.decoder_block = nn.ModuleList([
110
- ConvNeXtBlock(hidden, hidden* 4, kernel, dilation,)
111
- for _ in range(n_layer)])
112
  self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
113
 
114
  def forward(self, input, conditioning=None):
@@ -117,17 +146,15 @@ class DVAEDecoder(nn.Module):
117
  x = self.conv_in(x)
118
  for f in self.decoder_block:
119
  x = f(x, conditioning)
120
-
121
  x = self.conv_out(x)
122
  return x.transpose(1, 2)
123
-
124
 
125
  class DVAE(nn.Module):
126
- def __init__(
127
- self, decoder_config, vq_config, dim=512
128
- ):
129
  super().__init__()
130
- self.register_buffer('coef', torch.randn(1, 100, 1))
131
 
132
  self.decoder = DVAEDecoder(**decoder_config)
133
  self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
@@ -142,10 +169,14 @@ class DVAE(nn.Module):
142
  vq_feats = self.vq_layer._embed(inp)
143
  else:
144
  vq_feats = inp.detach().clone()
145
-
146
- vq_feats = vq_feats.view(
147
- (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
148
- ).permute(0, 2, 3, 1).flatten(2)
 
 
 
 
149
 
150
  vq_feats = vq_feats.transpose(1, 2)
151
  dec_out = self.decoder(input=vq_feats)
 
1
  import math
 
 
2
 
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import GroupedResidualFSQ
8
+
9
 
10
  class ConvNeXtBlock(nn.Module):
11
  def __init__(
12
  self,
13
  dim: int,
14
  intermediate_dim: int,
15
+ kernel,
16
+ dilation,
17
  layer_scale_init_value: float = 1e-6,
18
  ):
19
  # ConvNeXt Block copied from Vocos.
20
  super().__init__()
21
+ self.dwconv = nn.Conv1d(
22
+ dim,
23
+ dim,
24
+ kernel_size=kernel,
25
+ padding=dilation * (kernel // 2),
26
+ dilation=dilation,
27
+ groups=dim,
28
+ ) # depthwise conv
29
+
30
  self.norm = nn.LayerNorm(dim, eps=1e-6)
31
+ self.pwconv1 = nn.Linear(
32
+ dim, intermediate_dim
33
+ ) # pointwise/1x1 convs, implemented with linear layers
34
  self.act = nn.GELU()
35
  self.pwconv2 = nn.Linear(intermediate_dim, dim)
36
  self.gamma = (
 
39
  else None
40
  )
41
 
42
+ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
43
  residual = x
44
  x = self.dwconv(x)
45
  x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
 
53
 
54
  x = residual + x
55
  return x
 
56
 
57
 
58
  class GFSQ(nn.Module):
59
 
60
+ def __init__(self, dim, levels, G, R, eps=1e-5, transpose=True):
 
 
61
  super(GFSQ, self).__init__()
62
  self.quantizer = GroupedResidualFSQ(
63
  dim=dim,
 
70
  self.transpose = transpose
71
  self.G = G
72
  self.R = R
73
+
74
  def _embed(self, x):
75
  if self.transpose:
76
+ x = x.transpose(1, 2)
77
  x = rearrange(
78
+ x,
79
+ "b t (g r) -> g b t r",
80
+ g=self.G,
81
+ r=self.R,
82
+ )
83
  feat = self.quantizer.get_output_from_indices(x)
84
+ return feat.transpose(1, 2) if self.transpose else feat
85
+
86
+ def forward(
87
+ self,
88
+ x,
89
+ ):
90
  if self.transpose:
91
+ x = x.transpose(1, 2)
92
  feat, ind = self.quantizer(x)
93
  ind = rearrange(
94
+ ind,
95
+ "g b t r ->b t (g r)",
96
+ )
97
  embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
98
+ e_mean = torch.mean(embed_onehot, dim=[0, 1])
99
  e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
100
  perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
101
+
102
  return (
103
  torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
104
+ feat.transpose(1, 2) if self.transpose else feat,
105
  perplexity,
106
  None,
107
+ ind.transpose(1, 2) if self.transpose else ind,
108
  )
109
+
110
+
111
  class DVAEDecoder(nn.Module):
112
+ def __init__(
113
+ self,
114
+ idim,
115
+ odim,
116
+ n_layer=12,
117
+ bn_dim=64,
118
+ hidden=256,
119
+ kernel=7,
120
+ dilation=2,
121
+ up=False,
122
+ ):
123
  super().__init__()
124
  self.up = up
125
  self.conv_in = nn.Sequential(
126
+ nn.Conv1d(idim, bn_dim, 3, 1, 1),
127
+ nn.GELU(),
128
+ nn.Conv1d(bn_dim, hidden, 3, 1, 1),
129
+ )
130
+ self.decoder_block = nn.ModuleList(
131
+ [
132
+ ConvNeXtBlock(
133
+ hidden,
134
+ hidden * 4,
135
+ kernel,
136
+ dilation,
137
+ )
138
+ for _ in range(n_layer)
139
+ ]
140
  )
 
 
 
141
  self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
142
 
143
  def forward(self, input, conditioning=None):
 
146
  x = self.conv_in(x)
147
  for f in self.decoder_block:
148
  x = f(x, conditioning)
149
+
150
  x = self.conv_out(x)
151
  return x.transpose(1, 2)
152
+
153
 
154
  class DVAE(nn.Module):
155
+ def __init__(self, decoder_config, vq_config, dim=512):
 
 
156
  super().__init__()
157
+ self.register_buffer("coef", torch.randn(1, 100, 1))
158
 
159
  self.decoder = DVAEDecoder(**decoder_config)
160
  self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
 
169
  vq_feats = self.vq_layer._embed(inp)
170
  else:
171
  vq_feats = inp.detach().clone()
172
+
173
+ vq_feats = (
174
+ vq_feats.view(
175
+ (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
176
+ )
177
+ .permute(0, 2, 3, 1)
178
+ .flatten(2)
179
+ )
180
 
181
  vq_feats = vq_feats.transpose(1, 2)
182
  dec_out = self.decoder(input=vq_feats)
modules/ChatTTS/ChatTTS/model/gpt.py CHANGED
@@ -1,19 +1,20 @@
1
  import os
 
2
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
3
 
4
  import logging
5
- from tqdm import tqdm
6
- from einops import rearrange
7
- from transformers.cache_utils import Cache
8
 
9
  import torch
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  import torch.nn.utils.parametrize as P
 
13
  from torch.nn.utils.parametrizations import weight_norm
14
- from transformers import LlamaModel, LlamaConfig
15
-
16
-
 
 
17
  class LlamaMLP(nn.Module):
18
  def __init__(self, hidden_size, intermediate_size):
19
  super().__init__()
@@ -27,70 +28,106 @@ class LlamaMLP(nn.Module):
27
  def forward(self, x):
28
  down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
29
  return down_proj
30
-
31
-
32
  class GPT_warpper(nn.Module):
33
  def __init__(
34
- self,
35
- gpt_config,
36
  num_audio_tokens,
37
  num_text_tokens,
38
  num_vq=4,
39
  **kwargs,
40
- ):
41
  super().__init__()
42
 
43
  self.logger = logging.getLogger(__name__)
44
  self.gpt = self.build_model(gpt_config)
45
- self.model_dim = self.gpt.config.hidden_size
46
 
47
  self.num_vq = num_vq
48
- self.emb_code = nn.ModuleList([nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)])
 
 
49
  self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
50
- self.head_text = weight_norm(nn.Linear(self.model_dim, num_text_tokens, bias=False), name='weight')
51
- self.head_code = nn.ModuleList([weight_norm(nn.Linear(self.model_dim, num_audio_tokens, bias=False), name='weight') for i in range(self.num_vq)])
 
 
 
 
 
 
 
 
 
 
52
 
53
  def build_model(self, config):
54
-
55
  configuration = LlamaConfig(**config)
56
  model = LlamaModel(configuration)
57
  del model.embed_tokens
58
-
59
  return model
60
-
61
  def get_emb(self, input_ids, text_mask, **kwargs):
62
 
63
  emb_text = self.emb_text(input_ids[text_mask][:, 0])
64
-
65
- emb_code = [self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)]
 
 
66
  emb_code = torch.stack(emb_code, 2).sum(2)
67
-
68
- emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
 
 
 
 
69
  emb[text_mask] = emb_text
70
  emb[~text_mask] = emb_code.to(emb.dtype)
71
-
72
  return emb
73
-
74
  def prepare_inputs_for_generation(
75
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
 
 
 
 
 
 
76
  ):
77
  # With static cache, the `past_key_values` is None
78
  # TODO joao: standardize interface for the different Cache classes and remove of this if
79
  has_static_cache = False
80
  if past_key_values is None:
81
- past_key_values = getattr(self.gpt.layers[0].self_attn, "past_key_value", None)
 
 
82
  has_static_cache = past_key_values is not None
83
 
84
  past_length = 0
85
  if past_key_values is not None:
86
  if isinstance(past_key_values, Cache):
87
- past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
 
 
 
 
88
  max_cache_length = (
89
- torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
 
 
90
  if past_key_values.get_max_length() is not None
91
  else None
92
  )
93
- cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
 
 
 
 
94
  # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
95
  else:
96
  cache_length = past_length = past_key_values[0][0].shape[2]
@@ -100,7 +137,10 @@ class GPT_warpper(nn.Module):
100
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
101
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
102
  # input)
103
- if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
 
 
 
104
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
105
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
106
  # input_ids based on the past_length.
@@ -133,9 +173,13 @@ class GPT_warpper(nn.Module):
133
  # TODO: use `next_tokens` directly instead.
134
  model_inputs = {"input_ids": input_ids.contiguous()}
135
 
136
- input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
 
 
137
  if cache_position is None:
138
- cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
 
 
139
  else:
140
  cache_position = cache_position[-input_length:]
141
 
@@ -152,118 +196,154 @@ class GPT_warpper(nn.Module):
152
  }
153
  )
154
  return model_inputs
155
-
156
  def generate(
157
- self,
158
- emb,
159
- inputs_ids,
160
- temperature,
161
- eos_token,
162
- attention_mask = None,
163
- max_new_token = 2048,
164
- min_new_token = 0,
165
- LogitsWarpers = [],
166
- LogitsProcessors = [],
167
  infer_text=False,
168
  return_attn=False,
169
  return_hidden=False,
170
- disable_tqdm=False
171
  ):
172
  if disable_tqdm:
173
  tqdm = lambda x: x
174
  else:
175
  from tqdm import tqdm
176
-
177
- with torch.no_grad():
178
-
179
  attentions = []
180
  hiddens = []
181
-
182
- start_idx, end_idx = inputs_ids.shape[1], torch.zeros(inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long)
 
 
183
  finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
184
-
185
  temperature = temperature[None].expand(inputs_ids.shape[0], -1)
186
  temperature = rearrange(temperature, "b n -> (b n) 1")
187
 
188
- attention_mask_cache = torch.ones((inputs_ids.shape[0], inputs_ids.shape[1]+max_new_token,), dtype=torch.bool, device=inputs_ids.device)
 
 
 
 
 
 
 
189
  if attention_mask is not None:
190
- attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
191
-
192
  for i in tqdm(range(max_new_token)):
193
  if finish.all():
194
  continue
195
-
196
- model_input = self.prepare_inputs_for_generation(inputs_ids,
197
- outputs.past_key_values if i!=0 else None,
198
- attention_mask_cache[:, :inputs_ids.shape[1]], use_cache=True)
199
-
 
 
 
200
  if i == 0:
201
- model_input['inputs_embeds'] = emb
202
  else:
203
  if infer_text:
204
- model_input['inputs_embeds'] = self.emb_text(model_input['input_ids'][:,:,0])
 
 
205
  else:
206
- code_emb = [self.emb_code[i](model_input['input_ids'][:,:,i]) for i in range(self.num_vq)]
207
- model_input['inputs_embeds'] = torch.stack(code_emb, 3).sum(3)
208
-
209
- model_input['input_ids'] = None
 
 
 
210
  outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
211
  attentions.append(outputs.attentions)
212
- hidden_states = outputs[0] # 🐻
213
  if return_hidden:
214
  hiddens.append(hidden_states[:, -1])
215
 
216
  with P.cached():
217
  if infer_text:
218
- logits = self.head_text(hidden_states)
219
  else:
220
- logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
221
-
 
 
 
 
 
 
222
  logits = logits[:, -1].float()
223
 
224
  if not infer_text:
225
  logits = rearrange(logits, "b c n -> (b n) c")
226
- logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
 
 
227
  else:
228
  logits_token = inputs_ids[:, start_idx:, 0]
229
-
230
  logits = logits / temperature
231
-
232
  for logitsProcessors in LogitsProcessors:
233
  logits = logitsProcessors(logits_token, logits)
234
-
235
  for logitsWarpers in LogitsWarpers:
236
  logits = logitsWarpers(logits_token, logits)
237
-
238
  if i < min_new_token:
239
  logits[:, eos_token] = -torch.inf
240
-
241
  scores = F.softmax(logits, dim=-1)
242
-
243
  idx_next = torch.multinomial(scores, num_samples=1)
244
-
245
  if not infer_text:
246
  idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
247
  finish = finish | (idx_next == eos_token).any(1)
248
  inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
249
  else:
250
  finish = finish | (idx_next == eos_token).any(1)
251
- inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq)], 1)
 
 
 
 
 
 
252
 
253
  end_idx = end_idx + (~finish).int()
254
-
255
- inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
 
 
 
256
  inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
257
-
258
  if return_hidden:
259
  hiddens = torch.stack(hiddens, 1)
260
  hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
261
-
262
  if not finish.all():
263
- self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
264
-
 
 
265
  return {
266
- 'ids': inputs_ids,
267
- 'attentions': attentions,
268
- 'hiddens':hiddens,
269
- }
 
1
  import os
2
+
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
  import logging
 
 
 
6
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
  import torch.nn.utils.parametrize as P
11
+ from einops import rearrange
12
  from torch.nn.utils.parametrizations import weight_norm
13
+ from tqdm import tqdm
14
+ from transformers import LlamaConfig, LlamaModel
15
+ from transformers.cache_utils import Cache
16
+
17
+
18
  class LlamaMLP(nn.Module):
19
  def __init__(self, hidden_size, intermediate_size):
20
  super().__init__()
 
28
  def forward(self, x):
29
  down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
30
  return down_proj
31
+
32
+
33
  class GPT_warpper(nn.Module):
34
  def __init__(
35
+ self,
36
+ gpt_config,
37
  num_audio_tokens,
38
  num_text_tokens,
39
  num_vq=4,
40
  **kwargs,
41
+ ):
42
  super().__init__()
43
 
44
  self.logger = logging.getLogger(__name__)
45
  self.gpt = self.build_model(gpt_config)
46
+ self.model_dim = self.gpt.config.hidden_size
47
 
48
  self.num_vq = num_vq
49
+ self.emb_code = nn.ModuleList(
50
+ [nn.Embedding(num_audio_tokens, self.model_dim) for i in range(self.num_vq)]
51
+ )
52
  self.emb_text = nn.Embedding(num_text_tokens, self.model_dim)
53
+ self.head_text = weight_norm(
54
+ nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight"
55
+ )
56
+ self.head_code = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ nn.Linear(self.model_dim, num_audio_tokens, bias=False),
60
+ name="weight",
61
+ )
62
+ for i in range(self.num_vq)
63
+ ]
64
+ )
65
 
66
  def build_model(self, config):
67
+
68
  configuration = LlamaConfig(**config)
69
  model = LlamaModel(configuration)
70
  del model.embed_tokens
71
+
72
  return model
73
+
74
  def get_emb(self, input_ids, text_mask, **kwargs):
75
 
76
  emb_text = self.emb_text(input_ids[text_mask][:, 0])
77
+
78
+ emb_code = [
79
+ self.emb_code[i](input_ids[~text_mask][:, i]) for i in range(self.num_vq)
80
+ ]
81
  emb_code = torch.stack(emb_code, 2).sum(2)
82
+
83
+ emb = torch.zeros(
84
+ (input_ids.shape[:-1]) + (emb_text.shape[-1],),
85
+ device=emb_text.device,
86
+ dtype=emb_text.dtype,
87
+ )
88
  emb[text_mask] = emb_text
89
  emb[~text_mask] = emb_code.to(emb.dtype)
90
+
91
  return emb
92
+
93
  def prepare_inputs_for_generation(
94
+ self,
95
+ input_ids,
96
+ past_key_values=None,
97
+ attention_mask=None,
98
+ inputs_embeds=None,
99
+ cache_position=None,
100
+ **kwargs,
101
  ):
102
  # With static cache, the `past_key_values` is None
103
  # TODO joao: standardize interface for the different Cache classes and remove of this if
104
  has_static_cache = False
105
  if past_key_values is None:
106
+ past_key_values = getattr(
107
+ self.gpt.layers[0].self_attn, "past_key_value", None
108
+ )
109
  has_static_cache = past_key_values is not None
110
 
111
  past_length = 0
112
  if past_key_values is not None:
113
  if isinstance(past_key_values, Cache):
114
+ past_length = (
115
+ cache_position[0]
116
+ if cache_position is not None
117
+ else past_key_values.get_seq_length()
118
+ )
119
  max_cache_length = (
120
+ torch.tensor(
121
+ past_key_values.get_max_length(), device=input_ids.device
122
+ )
123
  if past_key_values.get_max_length() is not None
124
  else None
125
  )
126
+ cache_length = (
127
+ past_length
128
+ if max_cache_length is None
129
+ else torch.min(max_cache_length, past_length)
130
+ )
131
  # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
132
  else:
133
  cache_length = past_length = past_key_values[0][0].shape[2]
 
137
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
138
  # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
139
  # input)
140
+ if (
141
+ attention_mask is not None
142
+ and attention_mask.shape[1] > input_ids.shape[1]
143
+ ):
144
  input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
145
  # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
146
  # input_ids based on the past_length.
 
173
  # TODO: use `next_tokens` directly instead.
174
  model_inputs = {"input_ids": input_ids.contiguous()}
175
 
176
+ input_length = (
177
+ position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
178
+ )
179
  if cache_position is None:
180
+ cache_position = torch.arange(
181
+ past_length, past_length + input_length, device=input_ids.device
182
+ )
183
  else:
184
  cache_position = cache_position[-input_length:]
185
 
 
196
  }
197
  )
198
  return model_inputs
199
+
200
  def generate(
201
+ self,
202
+ emb,
203
+ inputs_ids,
204
+ temperature,
205
+ eos_token,
206
+ attention_mask=None,
207
+ max_new_token=2048,
208
+ min_new_token=0,
209
+ LogitsWarpers=[],
210
+ LogitsProcessors=[],
211
  infer_text=False,
212
  return_attn=False,
213
  return_hidden=False,
214
+ disable_tqdm=False,
215
  ):
216
  if disable_tqdm:
217
  tqdm = lambda x: x
218
  else:
219
  from tqdm import tqdm
220
+
221
+ with torch.no_grad():
222
+
223
  attentions = []
224
  hiddens = []
225
+
226
+ start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
227
+ inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
228
+ )
229
  finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
230
+
231
  temperature = temperature[None].expand(inputs_ids.shape[0], -1)
232
  temperature = rearrange(temperature, "b n -> (b n) 1")
233
 
234
+ attention_mask_cache = torch.ones(
235
+ (
236
+ inputs_ids.shape[0],
237
+ inputs_ids.shape[1] + max_new_token,
238
+ ),
239
+ dtype=torch.bool,
240
+ device=inputs_ids.device,
241
+ )
242
  if attention_mask is not None:
243
+ attention_mask_cache[:, : attention_mask.shape[1]] = attention_mask
244
+
245
  for i in tqdm(range(max_new_token)):
246
  if finish.all():
247
  continue
248
+
249
+ model_input = self.prepare_inputs_for_generation(
250
+ inputs_ids,
251
+ outputs.past_key_values if i != 0 else None,
252
+ attention_mask_cache[:, : inputs_ids.shape[1]],
253
+ use_cache=True,
254
+ )
255
+
256
  if i == 0:
257
+ model_input["inputs_embeds"] = emb
258
  else:
259
  if infer_text:
260
+ model_input["inputs_embeds"] = self.emb_text(
261
+ model_input["input_ids"][:, :, 0]
262
+ )
263
  else:
264
+ code_emb = [
265
+ self.emb_code[i](model_input["input_ids"][:, :, i])
266
+ for i in range(self.num_vq)
267
+ ]
268
+ model_input["inputs_embeds"] = torch.stack(code_emb, 3).sum(3)
269
+
270
+ model_input["input_ids"] = None
271
  outputs = self.gpt.forward(**model_input, output_attentions=return_attn)
272
  attentions.append(outputs.attentions)
273
+ hidden_states = outputs[0] # 🐻
274
  if return_hidden:
275
  hiddens.append(hidden_states[:, -1])
276
 
277
  with P.cached():
278
  if infer_text:
279
+ logits = self.head_text(hidden_states)
280
  else:
281
+ logits = torch.stack(
282
+ [
283
+ self.head_code[i](hidden_states)
284
+ for i in range(self.num_vq)
285
+ ],
286
+ 3,
287
+ )
288
+
289
  logits = logits[:, -1].float()
290
 
291
  if not infer_text:
292
  logits = rearrange(logits, "b c n -> (b n) c")
293
+ logits_token = rearrange(
294
+ inputs_ids[:, start_idx:], "b c n -> (b n) c"
295
+ )
296
  else:
297
  logits_token = inputs_ids[:, start_idx:, 0]
298
+
299
  logits = logits / temperature
300
+
301
  for logitsProcessors in LogitsProcessors:
302
  logits = logitsProcessors(logits_token, logits)
303
+
304
  for logitsWarpers in LogitsWarpers:
305
  logits = logitsWarpers(logits_token, logits)
306
+
307
  if i < min_new_token:
308
  logits[:, eos_token] = -torch.inf
309
+
310
  scores = F.softmax(logits, dim=-1)
311
+
312
  idx_next = torch.multinomial(scores, num_samples=1)
313
+
314
  if not infer_text:
315
  idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
316
  finish = finish | (idx_next == eos_token).any(1)
317
  inputs_ids = torch.cat([inputs_ids, idx_next.unsqueeze(1)], 1)
318
  else:
319
  finish = finish | (idx_next == eos_token).any(1)
320
+ inputs_ids = torch.cat(
321
+ [
322
+ inputs_ids,
323
+ idx_next.unsqueeze(-1).expand(-1, -1, self.num_vq),
324
+ ],
325
+ 1,
326
+ )
327
 
328
  end_idx = end_idx + (~finish).int()
329
+
330
+ inputs_ids = [
331
+ inputs_ids[idx, start_idx : start_idx + i]
332
+ for idx, i in enumerate(end_idx.int())
333
+ ]
334
  inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
335
+
336
  if return_hidden:
337
  hiddens = torch.stack(hiddens, 1)
338
  hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]
339
+
340
  if not finish.all():
341
+ self.logger.warn(
342
+ f"Incomplete result. hit max_new_token: {max_new_token}"
343
+ )
344
+
345
  return {
346
+ "ids": inputs_ids,
347
+ "attentions": attentions,
348
+ "hiddens": hiddens,
349
+ }
modules/ChatTTS/ChatTTS/utils/infer_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  import torch
3
  import torch.nn.functional as F
4
 
 
1
  import re
2
+
3
  import torch
4
  import torch.nn.functional as F
5
 
modules/ChatTTS/ChatTTS/utils/io_utils.py CHANGED
@@ -1,14 +1,14 @@
1
-
2
- import os
3
  import logging
 
 
4
 
5
  def get_latest_modified_file(directory):
6
  logger = logging.getLogger(__name__)
7
-
8
- files = [os.path.join(directory, f) for f in os.listdir(directory)]
9
  if not files:
10
- logger.log(logging.WARNING, f'No files found in the directory: {directory}')
11
  return None
12
  latest_file = max(files, key=os.path.getmtime)
13
 
14
- return latest_file
 
 
 
1
  import logging
2
+ import os
3
+
4
 
5
  def get_latest_modified_file(directory):
6
  logger = logging.getLogger(__name__)
7
+
8
+ files = [os.path.join(directory, f) for f in os.listdir(directory)]
9
  if not files:
10
+ logger.log(logging.WARNING, f"No files found in the directory: {directory}")
11
  return None
12
  latest_file = max(files, key=os.path.getmtime)
13
 
14
+ return latest_file
modules/Denoiser/AudioDenoiser.py CHANGED
@@ -1,15 +1,17 @@
1
  import logging
2
  import math
3
  from typing import Union
 
4
  import torch
5
  import torchaudio
6
- from torch import nn
7
- from audio_denoiser.helpers.torch_helper import batched_apply
8
- from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
9
  from audio_denoiser.helpers.audio_helper import (
10
  create_spectrogram,
11
  reconstruct_from_spectrogram,
12
  )
 
 
 
 
13
 
14
  _expected_t_std = 0.23
15
  _recommended_backend = "soundfile"
 
1
  import logging
2
  import math
3
  from typing import Union
4
+
5
  import torch
6
  import torchaudio
 
 
 
7
  from audio_denoiser.helpers.audio_helper import (
8
  create_spectrogram,
9
  reconstruct_from_spectrogram,
10
  )
11
+ from audio_denoiser.helpers.torch_helper import batched_apply
12
+ from torch import nn
13
+
14
+ from modules.Denoiser.AudioNosiseModel import load_audio_denosier_model
15
 
16
  _expected_t_std = 0.23
17
  _recommended_backend = "soundfile"
modules/Denoiser/AudioNosiseModel.py CHANGED
@@ -1,12 +1,11 @@
 
 
1
  import torch
2
  import torch.nn as nn
3
-
4
  from audio_denoiser.modules.Permute import Permute
5
  from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
6
  from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
7
 
8
- import json
9
-
10
 
11
  class AudioNoiseModel(nn.Module):
12
  def __init__(self, config: dict):
 
1
+ import json
2
+
3
  import torch
4
  import torch.nn as nn
 
5
  from audio_denoiser.modules.Permute import Permute
6
  from audio_denoiser.modules.SimpleRoberta import SimpleRoberta
7
  from audio_denoiser.modules.SpectrogramScaler import SpectrogramScaler
8
 
 
 
9
 
10
  class AudioNoiseModel(nn.Module):
11
  def __init__(self, config: dict):
modules/Enhancer/ResembleEnhance.py CHANGED
@@ -1,20 +1,17 @@
1
  import gc
 
 
 
2
  from typing import Literal
3
 
4
  import numpy as np
 
 
5
  from modules.devices import devices
6
  from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
7
  from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
8
  from modules.repos_static.resemble_enhance.inference import inference
9
-
10
- import torch
11
-
12
  from modules.utils.constants import MODELS_DIR
13
- from pathlib import Path
14
-
15
- from threading import Lock
16
-
17
- import logging
18
 
19
  logger = logging.getLogger(__name__)
20
 
@@ -155,8 +152,8 @@ def apply_audio_enhance(
155
 
156
 
157
  if __name__ == "__main__":
158
- import torchaudio
159
  import gradio as gr
 
160
 
161
  device = torch.device("cuda")
162
 
 
1
  import gc
2
+ import logging
3
+ from pathlib import Path
4
+ from threading import Lock
5
  from typing import Literal
6
 
7
  import numpy as np
8
+ import torch
9
+
10
  from modules.devices import devices
11
  from modules.repos_static.resemble_enhance.enhancer.enhancer import Enhancer
12
  from modules.repos_static.resemble_enhance.enhancer.hparams import HParams
13
  from modules.repos_static.resemble_enhance.inference import inference
 
 
 
14
  from modules.utils.constants import MODELS_DIR
 
 
 
 
 
15
 
16
  logger = logging.getLogger(__name__)
17
 
 
152
 
153
 
154
  if __name__ == "__main__":
 
155
  import gradio as gr
156
+ import torchaudio
157
 
158
  device = torch.device("cuda")
159
 
modules/SentenceSplitter.py CHANGED
@@ -1,4 +1,5 @@
1
  import re
 
2
  import zhon
3
 
4
 
 
1
  import re
2
+
3
  import zhon
4
 
5
 
modules/SynthesizeSegments.py CHANGED
@@ -1,31 +1,37 @@
1
  import copy
 
 
2
  import re
 
 
 
3
  from box import Box
4
  from pydub import AudioSegment
5
- from typing import List, Union
6
- from scipy.io.wavfile import write
7
- import io
8
- from modules.SentenceSplitter import SentenceSplitter
9
- from modules.api.utils import calc_spk_style
10
- from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
11
- from modules.utils import rng
12
- from modules.utils.audio import time_stretch, pitch_shift
13
  from modules import generate_audio
 
14
  from modules.normalization import text_normalize
15
- import logging
16
- import json
17
-
18
- from modules.speaker import Speaker, speaker_mgr
 
19
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
23
- def audio_data_to_segment(audio_data, sr):
24
- byte_io = io.BytesIO()
25
- write(byte_io, rate=sr, data=audio_data)
26
- byte_io.seek(0)
27
-
28
- return AudioSegment.from_file(byte_io, format="wav")
 
 
 
 
 
 
29
 
30
 
31
  def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
 
1
  import copy
2
+ import json
3
+ import logging
4
  import re
5
+ from typing import List, Union
6
+
7
+ import numpy as np
8
  from box import Box
9
  from pydub import AudioSegment
10
+
 
 
 
 
 
 
 
11
  from modules import generate_audio
12
+ from modules.api.utils import calc_spk_style
13
  from modules.normalization import text_normalize
14
+ from modules.SentenceSplitter import SentenceSplitter
15
+ from modules.speaker import Speaker
16
+ from modules.ssml_parser.SSMLParser import SSMLBreak, SSMLContext, SSMLSegment
17
+ from modules.utils import rng
18
+ from modules.utils.audio import pitch_shift, time_stretch
19
 
20
  logger = logging.getLogger(__name__)
21
 
22
 
23
+ def audio_data_to_segment(audio_data: np.ndarray, sr: int):
24
+ """
25
+ optimize: https://github.com/lenML/ChatTTS-Forge/issues/57
26
+ """
27
+ audio_data = (audio_data * 32767).astype(np.int16)
28
+ audio_segment = AudioSegment(
29
+ audio_data.tobytes(),
30
+ frame_rate=sr,
31
+ sample_width=audio_data.dtype.itemsize,
32
+ channels=1,
33
+ )
34
+ return audio_segment
35
 
36
 
37
  def combine_audio_segments(audio_segments: list[AudioSegment]) -> AudioSegment:
modules/api/Api.py CHANGED
@@ -1,12 +1,10 @@
1
- from fastapi import FastAPI
2
- from fastapi.middleware.cors import CORSMiddleware
3
-
4
  import logging
5
 
 
 
6
  from fastapi.staticfiles import StaticFiles
7
 
8
- import fnmatch
9
-
10
 
11
  def is_excluded(path, exclude_patterns):
12
  """
 
1
+ import fnmatch
 
 
2
  import logging
3
 
4
+ from fastapi import FastAPI
5
+ from fastapi.middleware.cors import CORSMiddleware
6
  from fastapi.staticfiles import StaticFiles
7
 
 
 
8
 
9
  def is_excluded(path, exclude_patterns):
10
  """
modules/api/api_setup.py CHANGED
@@ -1,26 +1,24 @@
1
- import logging
2
- from modules.Enhancer.ResembleEnhance import load_enhancer
3
- from modules.devices import devices
4
  import argparse
 
5
 
6
- from modules import config
7
- from modules.models import load_chat_tts
8
- from modules.utils import env
9
- from modules import generate_audio
10
  from modules.api.Api import APIManager
11
-
12
  from modules.api.impl import (
13
- style_api,
14
- tts_api,
15
- ssml_api,
16
  google_api,
 
17
  openai_api,
 
18
  refiner_api,
19
  speaker_api,
20
- ping_api,
21
- models_api,
 
22
  xtts_v2_api,
23
  )
 
 
 
 
24
 
25
  logger = logging.getLogger(__name__)
26
 
 
 
 
 
1
  import argparse
2
+ import logging
3
 
4
+ from modules import config, generate_audio
 
 
 
5
  from modules.api.Api import APIManager
 
6
  from modules.api.impl import (
 
 
 
7
  google_api,
8
+ models_api,
9
  openai_api,
10
+ ping_api,
11
  refiner_api,
12
  speaker_api,
13
+ ssml_api,
14
+ style_api,
15
+ tts_api,
16
  xtts_v2_api,
17
  )
18
+ from modules.devices import devices
19
+ from modules.Enhancer.ResembleEnhance import load_enhancer
20
+ from modules.models import load_chat_tts
21
+ from modules.utils import env
22
 
23
  logger = logging.getLogger(__name__)
24
 
modules/api/impl/google_api.py CHANGED
@@ -1,22 +1,18 @@
1
  from typing import Union
2
- from fastapi import HTTPException
3
 
 
4
  from pydantic import BaseModel
5
 
6
-
7
  from modules.api.Api import APIManager
8
  from modules.api.impl.handler.SSMLHandler import SSMLHandler
9
  from modules.api.impl.handler.TTSHandler import TTSHandler
10
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
11
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
12
  from modules.api.impl.model.enhancer_model import EnhancerConfig
13
-
14
  from modules.speaker import Speaker, speaker_mgr
15
 
16
 
17
- from modules.api import utils as api_utils
18
-
19
-
20
  class SynthesisInput(BaseModel):
21
  text: Union[str, None] = None
22
  ssml: Union[str, None] = None
 
1
  from typing import Union
 
2
 
3
+ from fastapi import HTTPException
4
  from pydantic import BaseModel
5
 
6
+ from modules.api import utils as api_utils
7
  from modules.api.Api import APIManager
8
  from modules.api.impl.handler.SSMLHandler import SSMLHandler
9
  from modules.api.impl.handler.TTSHandler import TTSHandler
10
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
11
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
12
  from modules.api.impl.model.enhancer_model import EnhancerConfig
 
13
  from modules.speaker import Speaker, speaker_mgr
14
 
15
 
 
 
 
16
  class SynthesisInput(BaseModel):
17
  text: Union[str, None] = None
18
  ssml: Union[str, None] = None
modules/api/impl/handler/AudioHandler.py CHANGED
@@ -1,10 +1,11 @@
1
  import base64
2
  import io
 
3
  import numpy as np
4
  import soundfile as sf
5
 
6
- from modules.api.impl.model.audio_model import AudioFormat
7
  from modules.api import utils as api_utils
 
8
 
9
 
10
  class AudioHandler:
 
1
  import base64
2
  import io
3
+
4
  import numpy as np
5
  import soundfile as sf
6
 
 
7
  from modules.api import utils as api_utils
8
+ from modules.api.impl.model.audio_model import AudioFormat
9
 
10
 
11
  class AudioHandler:
modules/api/impl/handler/SSMLHandler.py CHANGED
@@ -1,14 +1,14 @@
1
- from fastapi import HTTPException
2
  import numpy as np
 
3
 
4
- from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
5
- from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
6
  from modules.api.impl.handler.AudioHandler import AudioHandler
7
  from modules.api.impl.model.audio_model import AdjustConfig
8
  from modules.api.impl.model.chattts_model import InferConfig
9
  from modules.api.impl.model.enhancer_model import EnhancerConfig
 
10
  from modules.normalization import text_normalize
11
  from modules.ssml_parser.SSMLParser import create_ssml_parser
 
12
  from modules.utils import audio
13
 
14
 
 
 
1
  import numpy as np
2
+ from fastapi import HTTPException
3
 
 
 
4
  from modules.api.impl.handler.AudioHandler import AudioHandler
5
  from modules.api.impl.model.audio_model import AdjustConfig
6
  from modules.api.impl.model.chattts_model import InferConfig
7
  from modules.api.impl.model.enhancer_model import EnhancerConfig
8
+ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
9
  from modules.normalization import text_normalize
10
  from modules.ssml_parser.SSMLParser import create_ssml_parser
11
+ from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
12
  from modules.utils import audio
13
 
14
 
modules/api/impl/handler/TTSHandler.py CHANGED
@@ -1,13 +1,13 @@
1
  import numpy as np
2
- from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
3
  from modules.api.impl.handler.AudioHandler import AudioHandler
4
  from modules.api.impl.model.audio_model import AdjustConfig
5
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
6
  from modules.api.impl.model.enhancer_model import EnhancerConfig
 
7
  from modules.normalization import text_normalize
8
  from modules.speaker import Speaker
9
  from modules.synthesize_audio import synthesize_audio
10
-
11
  from modules.utils.audio import apply_prosody_to_audio_data
12
 
13
 
 
1
  import numpy as np
2
+
3
  from modules.api.impl.handler.AudioHandler import AudioHandler
4
  from modules.api.impl.model.audio_model import AdjustConfig
5
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
6
  from modules.api.impl.model.enhancer_model import EnhancerConfig
7
+ from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
8
  from modules.normalization import text_normalize
9
  from modules.speaker import Speaker
10
  from modules.synthesize_audio import synthesize_audio
 
11
  from modules.utils.audio import apply_prosody_to_audio_data
12
 
13
 
modules/api/impl/model/enhancer_model.py CHANGED
@@ -1,4 +1,5 @@
1
  from typing import Literal
 
2
  from pydantic import BaseModel
3
 
4
 
 
1
  from typing import Literal
2
+
3
  from pydantic import BaseModel
4
 
5
 
modules/api/impl/models_api.py CHANGED
@@ -1,6 +1,6 @@
1
- from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
2
  from modules.api import utils as api_utils
3
  from modules.api.Api import APIManager
 
4
  from modules.models import reload_chat_tts, unload_chat_tts
5
 
6
 
 
 
1
  from modules.api import utils as api_utils
2
  from modules.api.Api import APIManager
3
+ from modules.Enhancer.ResembleEnhance import reload_enhancer, unload_enhancer
4
  from modules.models import reload_chat_tts, unload_chat_tts
5
 
6
 
modules/api/impl/openai_api.py CHANGED
@@ -1,23 +1,18 @@
1
- from fastapi import File, Form, HTTPException, Body, UploadFile
2
 
 
 
3
  from numpy import clip
4
  from pydantic import BaseModel, Field
5
- from fastapi.responses import StreamingResponse
6
-
7
 
 
 
8
  from modules.api.impl.handler.TTSHandler import TTSHandler
9
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
10
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
11
  from modules.api.impl.model.enhancer_model import EnhancerConfig
12
-
13
-
14
- from typing import List, Optional
15
-
16
- from modules.api import utils as api_utils
17
- from modules.api.Api import APIManager
18
-
19
- from modules.speaker import Speaker, speaker_mgr
20
  from modules.data import styles_mgr
 
21
 
22
 
23
  class AudioSpeechRequest(BaseModel):
 
1
+ from typing import List, Optional
2
 
3
+ from fastapi import Body, File, Form, HTTPException, UploadFile
4
+ from fastapi.responses import StreamingResponse
5
  from numpy import clip
6
  from pydantic import BaseModel, Field
 
 
7
 
8
+ from modules.api import utils as api_utils
9
+ from modules.api.Api import APIManager
10
  from modules.api.impl.handler.TTSHandler import TTSHandler
11
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
12
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
13
  from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
 
 
 
 
 
 
14
  from modules.data import styles_mgr
15
+ from modules.speaker import Speaker, speaker_mgr
16
 
17
 
18
  class AudioSpeechRequest(BaseModel):
modules/api/impl/ping_api.py CHANGED
@@ -1,8 +1,7 @@
 
1
  from modules.api import utils as api_utils
2
  from modules.api.Api import APIManager
3
 
4
- from modules import config
5
-
6
 
7
  def setup(app: APIManager):
8
  @app.get("/v1/ping", response_model=api_utils.BaseResponse)
 
1
+ from modules import config
2
  from modules.api import utils as api_utils
3
  from modules.api.Api import APIManager
4
 
 
 
5
 
6
  def setup(app: APIManager):
7
  @app.get("/v1/ping", response_model=api_utils.BaseResponse)
modules/api/impl/refiner_api.py CHANGED
@@ -1,10 +1,7 @@
1
  from fastapi import HTTPException
2
-
3
  from pydantic import BaseModel
4
 
5
-
6
  from modules import refiner
7
-
8
  from modules.api import utils as api_utils
9
  from modules.api.Api import APIManager
10
  from modules.normalization import text_normalize
 
1
  from fastapi import HTTPException
 
2
  from pydantic import BaseModel
3
 
 
4
  from modules import refiner
 
5
  from modules.api import utils as api_utils
6
  from modules.api.Api import APIManager
7
  from modules.normalization import text_normalize
modules/api/impl/speaker_api.py CHANGED
@@ -1,9 +1,10 @@
 
1
  from fastapi import HTTPException
2
  from pydantic import BaseModel
3
- import torch
4
- from modules.speaker import speaker_mgr
5
  from modules.api import utils as api_utils
6
  from modules.api.Api import APIManager
 
7
 
8
 
9
  class CreateSpeaker(BaseModel):
 
1
+ import torch
2
  from fastapi import HTTPException
3
  from pydantic import BaseModel
4
+
 
5
  from modules.api import utils as api_utils
6
  from modules.api.Api import APIManager
7
+ from modules.speaker import speaker_mgr
8
 
9
 
10
  class CreateSpeaker(BaseModel):
modules/api/impl/ssml_api.py CHANGED
@@ -1,19 +1,14 @@
1
- from fastapi import HTTPException, Body
2
- from fastapi.responses import StreamingResponse
3
-
4
  from pydantic import BaseModel
5
- from fastapi.responses import FileResponse
6
-
7
 
 
8
  from modules.api.impl.handler.SSMLHandler import SSMLHandler
9
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
10
  from modules.api.impl.model.chattts_model import InferConfig
11
  from modules.api.impl.model.enhancer_model import EnhancerConfig
12
 
13
 
14
- from modules.api.Api import APIManager
15
-
16
-
17
  class SSMLRequest(BaseModel):
18
  ssml: str
19
  format: AudioFormat = "mp3"
 
1
+ from fastapi import Body, HTTPException
2
+ from fastapi.responses import FileResponse, StreamingResponse
 
3
  from pydantic import BaseModel
 
 
4
 
5
+ from modules.api.Api import APIManager
6
  from modules.api.impl.handler.SSMLHandler import SSMLHandler
7
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
8
  from modules.api.impl.model.chattts_model import InferConfig
9
  from modules.api.impl.model.enhancer_model import EnhancerConfig
10
 
11
 
 
 
 
12
  class SSMLRequest(BaseModel):
13
  ssml: str
14
  format: AudioFormat = "mp3"
modules/api/impl/style_api.py CHANGED
@@ -1,6 +1,6 @@
1
- from modules.data import styles_mgr
2
  from modules.api import utils as api_utils
3
  from modules.api.Api import APIManager
 
4
 
5
 
6
  async def list_styles():
 
 
1
  from modules.api import utils as api_utils
2
  from modules.api.Api import APIManager
3
+ from modules.data import styles_mgr
4
 
5
 
6
  async def list_styles():
modules/api/impl/tts_api.py CHANGED
@@ -1,17 +1,13 @@
1
  from fastapi import Depends, HTTPException, Query
2
- from fastapi.responses import StreamingResponse
3
-
4
  from pydantic import BaseModel
5
- from fastapi.responses import FileResponse
6
-
7
 
 
 
8
  from modules.api.impl.handler.TTSHandler import TTSHandler
9
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
10
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
11
  from modules.api.impl.model.enhancer_model import EnhancerConfig
12
-
13
- from modules.api import utils as api_utils
14
- from modules.api.Api import APIManager
15
  from modules.speaker import Speaker
16
 
17
 
 
1
  from fastapi import Depends, HTTPException, Query
2
+ from fastapi.responses import FileResponse, StreamingResponse
 
3
  from pydantic import BaseModel
 
 
4
 
5
+ from modules.api import utils as api_utils
6
+ from modules.api.Api import APIManager
7
  from modules.api.impl.handler.TTSHandler import TTSHandler
8
  from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
9
  from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
10
  from modules.api.impl.model.enhancer_model import EnhancerConfig
 
 
 
11
  from modules.speaker import Speaker
12
 
13
 
modules/api/impl/xtts_v2_api.py CHANGED
@@ -1,19 +1,17 @@
1
  import io
 
 
 
2
  from fastapi import HTTPException
3
  from fastapi.responses import StreamingResponse
4
  from pydantic import BaseModel
5
- from modules.api import utils as api_utils
6
- from modules.api.Api import APIManager
7
-
8
- import soundfile as sf
9
 
10
  from modules import config
 
 
11
  from modules.normalization import text_normalize
12
  from modules.speaker import speaker_mgr
13
  from modules.synthesize_audio import synthesize_audio
14
-
15
- import logging
16
-
17
  from modules.utils.audio import apply_prosody_to_audio_data
18
 
19
  logger = logging.getLogger(__name__)
 
1
  import io
2
+ import logging
3
+
4
+ import soundfile as sf
5
  from fastapi import HTTPException
6
  from fastapi.responses import StreamingResponse
7
  from pydantic import BaseModel
 
 
 
 
8
 
9
  from modules import config
10
+ from modules.api import utils as api_utils
11
+ from modules.api.Api import APIManager
12
  from modules.normalization import text_normalize
13
  from modules.speaker import speaker_mgr
14
  from modules.synthesize_audio import synthesize_audio
 
 
 
15
  from modules.utils.audio import apply_prosody_to_audio_data
16
 
17
  logger = logging.getLogger(__name__)
modules/api/utils.py CHANGED
@@ -1,14 +1,10 @@
1
- from pydantic import BaseModel
2
  from typing import Any, Union
3
 
4
-
5
- from modules.speaker import speaker_mgr
6
-
7
-
8
- from modules.data import styles_mgr
9
-
10
  from pydub import AudioSegment
11
 
 
 
12
  from modules.ssml import merge_prompt
13
 
14
 
 
 
1
  from typing import Any, Union
2
 
3
+ from pydantic import BaseModel
 
 
 
 
 
4
  from pydub import AudioSegment
5
 
6
+ from modules.data import styles_mgr
7
+ from modules.speaker import speaker_mgr
8
  from modules.ssml import merge_prompt
9
 
10
 
modules/api/worker.py CHANGED
@@ -1,6 +1,7 @@
1
  import argparse
2
  import logging
3
  import os
 
4
  import dotenv
5
  from fastapi import FastAPI
6
 
@@ -12,6 +13,7 @@ logging.basicConfig(
12
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
13
  )
14
 
 
15
  from modules.api.api_setup import (
16
  process_api_args,
17
  process_model_args,
@@ -20,7 +22,6 @@ from modules.api.api_setup import (
20
  setup_uvicon_args,
21
  )
22
  from modules.api.app_config import app_description, app_title, app_version
23
- from modules import config
24
  from modules.utils.torch_opt import configure_torch_optimizations
25
 
26
  dotenv.load_dotenv(
 
1
  import argparse
2
  import logging
3
  import os
4
+
5
  import dotenv
6
  from fastapi import FastAPI
7
 
 
13
  format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
14
  )
15
 
16
+ from modules import config
17
  from modules.api.api_setup import (
18
  process_api_args,
19
  process_model_args,
 
22
  setup_uvicon_args,
23
  )
24
  from modules.api.app_config import app_description, app_title, app_version
 
25
  from modules.utils.torch_opt import configure_torch_optimizations
26
 
27
  dotenv.load_dotenv(
modules/config.py CHANGED
@@ -1,9 +1,9 @@
1
  import sys
2
 
3
  import torch
4
- from modules.utils.JsonObject import JsonObject
5
 
6
- from modules.utils import git, ffmpeg
 
7
 
8
  # TODO impl RuntimeEnvVars() class
9
  runtime_env_vars = JsonObject({})
 
1
  import sys
2
 
3
  import torch
 
4
 
5
+ from modules.utils import ffmpeg, git
6
+ from modules.utils.JsonObject import JsonObject
7
 
8
  # TODO impl RuntimeEnvVars() class
9
  runtime_env_vars = JsonObject({})
modules/data.py CHANGED
@@ -1,6 +1,5 @@
1
  from modules.utils.CsvMgr import BaseManager
2
 
3
-
4
  # speakers_mgr = BaseManager("./data/speakers.csv")
5
  styles_mgr = BaseManager("./data/styles.csv")
6
 
 
1
  from modules.utils.CsvMgr import BaseManager
2
 
 
3
  # speakers_mgr = BaseManager("./data/speakers.csv")
4
  styles_mgr = BaseManager("./data/styles.csv")
5
 
modules/denoise.py CHANGED
@@ -1,15 +1,13 @@
1
  import os
2
  from typing import Union
3
 
 
4
  import torch
5
  import torchaudio
6
- from modules.Denoiser.AudioDenoiser import AudioDenoiser
7
-
8
- from modules.utils.constants import MODELS_DIR
9
 
 
10
  from modules.devices import devices
11
-
12
- import soundfile as sf
13
 
14
  ad: Union[AudioDenoiser, None] = None
15
 
 
1
  import os
2
  from typing import Union
3
 
4
+ import soundfile as sf
5
  import torch
6
  import torchaudio
 
 
 
7
 
8
+ from modules.Denoiser.AudioDenoiser import AudioDenoiser
9
  from modules.devices import devices
10
+ from modules.utils.constants import MODELS_DIR
 
11
 
12
  ad: Union[AudioDenoiser, None] = None
13
 
modules/devices/devices.py CHANGED
@@ -1,9 +1,10 @@
1
- from functools import lru_cache
2
  import sys
 
 
3
  import torch
4
- from modules import config
5
 
6
- import logging
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
1
+ import logging
2
  import sys
3
+ from functools import lru_cache
4
+
5
  import torch
 
6
 
7
+ from modules import config
8
 
9
  logger = logging.getLogger(__name__)
10
 
modules/devices/mac_devices.py CHANGED
@@ -1,8 +1,9 @@
1
- import torch
2
  import logging
3
- from packaging import version
 
4
  import torch.backends
5
  import torch.backends.mps
 
6
 
7
  logger = logging.getLogger(__name__)
8
 
 
 
1
  import logging
2
+
3
+ import torch
4
  import torch.backends
5
  import torch.backends.mps
6
+ from packaging import version
7
 
8
  logger = logging.getLogger(__name__)
9
 
modules/ffmpeg_env.py CHANGED
@@ -1,6 +1,7 @@
 
1
  import os
 
2
  from modules.utils.constants import ROOT_DIR
3
- import logging
4
 
5
  logger = logging.getLogger(__name__)
6
 
 
1
+ import logging
2
  import os
3
+
4
  from modules.utils.constants import ROOT_DIR
 
5
 
6
  logger = logging.getLogger(__name__)
7
 
modules/finetune/train_speaker.py CHANGED
@@ -3,9 +3,10 @@ import torch.nn.functional as F
3
  import transformers
4
 
5
  from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
6
- from modules.finetune.utils.output import get_ansi_len, output_iter, ansi
7
- from .utils.logger import MetricLogger
8
  from .utils.dataset import AudioCollator, XzListTar
 
9
  from .utils.model import quantize
10
 
11
  IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
@@ -201,11 +202,13 @@ def train_speaker_embeddings(
201
  if __name__ == "__main__":
202
  import argparse
203
  import os
204
- import numpy as np
205
  import pathlib
206
- from modules.models import load_chat_tts
207
- from modules.devices import devices
 
208
  from modules import config
 
 
209
  from modules.speaker import Speaker
210
 
211
  config.runtime_env_vars.no_half = True
 
3
  import transformers
4
 
5
  from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
6
+ from modules.finetune.utils.output import ansi, get_ansi_len, output_iter
7
+
8
  from .utils.dataset import AudioCollator, XzListTar
9
+ from .utils.logger import MetricLogger
10
  from .utils.model import quantize
11
 
12
  IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
 
202
  if __name__ == "__main__":
203
  import argparse
204
  import os
 
205
  import pathlib
206
+
207
+ import numpy as np
208
+
209
  from modules import config
210
+ from modules.devices import devices
211
+ from modules.models import load_chat_tts
212
  from modules.speaker import Speaker
213
 
214
  config.runtime_env_vars.no_half = True
modules/finetune/utils/dataset.py CHANGED
@@ -1,21 +1,21 @@
1
- import os
2
  import functools
3
- import json
4
- import tarfile
5
  import io
 
6
  import logging
7
- import abc
 
8
  import typing
9
 
10
  import torch.utils.data
11
  import torchaudio
12
- from torchvision.datasets.utils import download_url
13
  import transformers
14
  import vocos
 
15
 
16
  from modules.ChatTTS.ChatTTS.utils.infer_utils import (
17
- count_invalid_characters,
18
  apply_character_map,
 
19
  )
20
 
21
 
 
1
+ import abc
2
  import functools
 
 
3
  import io
4
+ import json
5
  import logging
6
+ import os
7
+ import tarfile
8
  import typing
9
 
10
  import torch.utils.data
11
  import torchaudio
 
12
  import transformers
13
  import vocos
14
+ from torchvision.datasets.utils import download_url
15
 
16
  from modules.ChatTTS.ChatTTS.utils.infer_utils import (
 
17
  apply_character_map,
18
+ count_invalid_characters,
19
  )
20
 
21
 
modules/finetune/utils/logger.py CHANGED
@@ -3,15 +3,14 @@
3
  import statistics
4
  import time
5
  from collections import defaultdict, deque
6
- from tqdm import tqdm as tqdm_class
7
-
8
  from typing import Generator, Iterable, TypeVar
9
- from typing_extensions import Self
10
 
11
  import torch
12
  import torch.distributed as dist
 
 
13
 
14
- from .output import ansi, prints, get_ansi_len
15
 
16
  __all__ = ["SmoothedValue", "MetricLogger"]
17
 
 
3
  import statistics
4
  import time
5
  from collections import defaultdict, deque
 
 
6
  from typing import Generator, Iterable, TypeVar
 
7
 
8
  import torch
9
  import torch.distributed as dist
10
+ from tqdm import tqdm as tqdm_class
11
+ from typing_extensions import Self
12
 
13
+ from .output import ansi, get_ansi_len, prints
14
 
15
  __all__ = ["SmoothedValue", "MetricLogger"]
16
 
modules/generate_audio.py CHANGED
@@ -1,18 +1,15 @@
 
 
 
 
1
  import numpy as np
2
  import torch
3
 
4
- from modules.speaker import Speaker
5
- from modules.utils.SeedContext import SeedContext
6
-
7
- from modules import models, config
8
-
9
- import logging
10
- import gc
11
-
12
  from modules.devices import devices
13
- from typing import Union
14
-
15
  from modules.utils.cache import conditional_cache
 
16
 
17
  logger = logging.getLogger(__name__)
18
 
 
1
+ import gc
2
+ import logging
3
+ from typing import Union
4
+
5
  import numpy as np
6
  import torch
7
 
8
+ from modules import config, models
 
 
 
 
 
 
 
9
  from modules.devices import devices
10
+ from modules.speaker import Speaker
 
11
  from modules.utils.cache import conditional_cache
12
+ from modules.utils.SeedContext import SeedContext
13
 
14
  logger = logging.getLogger(__name__)
15
 
modules/models.py CHANGED
@@ -1,13 +1,13 @@
 
 
1
  import threading
 
2
  import torch
3
- from modules.ChatTTS import ChatTTS
4
  from modules import config
 
5
  from modules.devices import devices
6
 
7
- import logging
8
- import gc
9
-
10
-
11
  logger = logging.getLogger(__name__)
12
 
13
  chat_tts = None
 
1
+ import gc
2
+ import logging
3
  import threading
4
+
5
  import torch
6
+
7
  from modules import config
8
+ from modules.ChatTTS import ChatTTS
9
  from modules.devices import devices
10
 
 
 
 
 
11
  logger = logging.getLogger(__name__)
12
 
13
  chat_tts = None
modules/normalization.py CHANGED
@@ -1,9 +1,11 @@
 
1
  from functools import lru_cache
2
- from modules.utils.zh_normalization.text_normlization import *
3
  import emojiswitch
4
- from modules.utils.markdown import markdown_to_text
5
  from modules import models
6
- import re
 
7
 
8
  # 是否关闭 unk token 检查
9
  # NOTE: 单测的时候用于跳过模型加载
 
1
+ import re
2
  from functools import lru_cache
3
+
4
  import emojiswitch
5
+
6
  from modules import models
7
+ from modules.utils.markdown import markdown_to_text
8
+ from modules.utils.zh_normalization.text_normlization import *
9
 
10
  # 是否关闭 unk token 检查
11
  # NOTE: 单测的时候用于跳过模型加载
modules/prompts/news_oral_prompt.txt CHANGED
@@ -1,7 +1,7 @@
1
- # 任务要求
2
- 任务: 新闻稿口播化
3
 
4
- 你需要将一个新闻稿改写为口语化的口播文本
5
  同时,适当的添加一些 附语言 标签为文本增加多样性
6
 
7
  目前可以使用的附语言标签如下:
@@ -10,5 +10,24 @@
10
  - `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
11
  - `[lbreak]`: 表示一个长停顿一般表示段落结束
12
 
13
- # 输入
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  {{USER_INPUT}}
 
1
+ #任务要求
2
+ 任务:新闻稿口播化
3
 
4
+ 你需要将一个新闻稿改写为口语化的口播文本,以提供给新闻主播在晚间新闻节目中播报
5
  同时,适当的添加一些 附语言 标签为文本增加多样性
6
 
7
  目前可以使用的附语言标签如下:
 
10
  - `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
11
  - `[lbreak]`: 表示一个长停顿一般表示段落结束
12
 
13
+ # examples
14
+ ## case 1
15
+ - input: `天气预报显示,今天会有小雨,请大家出门时记得带伞。降温的天气也提醒我们要适时添衣保暖`
16
+ - output: `天气预报显示,今天会有小雨,请大家出门时记得带伞[uv_break]。那降温的天气[uv_break]也提醒我们要适时添衣保暖[lbreak]`
17
+
18
+ ## case 2
19
+ - input: `请注意,电梯将在下午两点进行例行维护,预计需要一个小时的时间,请大家在此期间使用楼梯`
20
+ - output: `请注意啊,这个电梯将在下午两点进行[uv_break]例行维护[uv_break],预计需要一个小时的时间[uv_break],请大家在此期间使用楼梯[lbreak]`
21
+
22
+ ## case 3
23
+ - input: `它的任务是简化记者编辑的工作流程。记者写稿时可以用标签来标明关键词、标题或主题。随着时间推移,数据积累到一定程度后,机器编辑就能自动识别这些标签`
24
+ - output: `它的任务呢是简化记者编辑的工作流程[uv_break]。记者写稿时呢可以用标签来标明关键词啊、标题啊或主题[uv_break]。那随着时间推移呢,数据积累到一定程度后[uv_break],机器编辑就能自动识别这些标签[uv_break]`
25
+
26
+ ## case 4
27
+ - input: `有一天,小明问他爸爸:“爸爸,我是不是傻孩子啊?”
28
+
29
+ 爸爸说:“傻孩子,你怎么会是傻孩子呢?”`
30
+ - output: `然后有一天呢,小明问他[uv_break]爸爸[uv_break],爸爸,我是不是傻孩[uv_break]子啊?爸爸说,傻孩[laugh]子啊,你怎么会是傻孩子呢[laugh]?`
31
+
32
+ # 用户输入
33
  {{USER_INPUT}}
modules/refiner.py CHANGED
@@ -1,10 +1,9 @@
1
  import numpy as np
2
  import torch
3
 
 
4
  from modules.utils.SeedContext import SeedContext
5
 
6
- from modules import models, config
7
-
8
 
9
  @torch.inference_mode()
10
  def refine_text(
 
1
  import numpy as np
2
  import torch
3
 
4
+ from modules import config, models
5
  from modules.utils.SeedContext import SeedContext
6
 
 
 
7
 
8
  @torch.inference_mode()
9
  def refine_text(
modules/repos_static/resemble_enhance/common.py CHANGED
@@ -42,7 +42,9 @@ class Normalizer(nn.Module):
42
  self.running_var_unsafe = x.var()
43
  else:
44
  self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
45
- self.running_var_unsafe = self._ema(self.running_var_unsafe, (x - self.running_mean).pow(2).mean())
 
 
46
 
47
  def forward(self, x: Tensor, update=True):
48
  if self.training and update:
 
42
  self.running_var_unsafe = x.var()
43
  else:
44
  self.running_mean_unsafe = self._ema(self.running_mean_unsafe, x.mean())
45
+ self.running_var_unsafe = self._ema(
46
+ self.running_var_unsafe, (x - self.running_mean).pow(2).mean()
47
+ )
48
 
49
  def forward(self, x: Tensor, update=True):
50
  if self.training and update:
modules/repos_static/resemble_enhance/data/dataset.py CHANGED
@@ -44,7 +44,9 @@ def praat_augment(wav, sr):
44
  sound = parselmouth.Sound(wav, sr)
45
  formant_shift_ratio = random.uniform(1.1, 1.5)
46
  pitch_range_factor = random.uniform(0.5, 2.0)
47
- sound = parselmouth.praat.call(sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0)
 
 
48
  wav = np.array(sound.values)[0].astype(np.float32)
49
  return wav
50
 
@@ -73,7 +75,9 @@ class Dataset(DatasetBase):
73
  if len(self.bg_paths) == 0:
74
  raise ValueError(f"No background audio files found in {hp.bg_dir}")
75
 
76
- logger.info(f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files")
 
 
77
 
78
  self.training = training
79
  self.max_retries = max_retries
@@ -121,7 +125,9 @@ class Dataset(DatasetBase):
121
  fg_path = self.fg_paths[index]
122
 
123
  if self.training and random.random() < self.silent_fg_prob:
124
- fg_wav = np.zeros(int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32)
 
 
125
  else:
126
  fg_wav = self._load_wav(fg_path)
127
  if random.random() < self.hp.praat_augment_prob and self.training:
@@ -132,14 +138,20 @@ class Dataset(DatasetBase):
132
  fg_dwav = None
133
  bg_dwav = None
134
  else:
135
- fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(np.float32)
 
 
136
  if self.training:
137
  bg_path = random.choice(self.bg_paths)
138
  else:
139
  # Deterministic for validation
140
  bg_path = self.bg_paths[index % len(self.bg_paths)]
141
- bg_wav = self._load_wav(bg_path, length=len(fg_wav), random_crop=self.training)
142
- bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(np.float32)
 
 
 
 
143
 
144
  return dict(
145
  fg_wav=fg_wav,
@@ -154,7 +166,9 @@ class Dataset(DatasetBase):
154
  return self._getitem_unsafe(index)
155
  except Exception as e:
156
  if i == self.max_retries - 1:
157
- raise RuntimeError(f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries") from e
 
 
158
  logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
159
  index = np.random.randint(0, len(self))
160
 
 
44
  sound = parselmouth.Sound(wav, sr)
45
  formant_shift_ratio = random.uniform(1.1, 1.5)
46
  pitch_range_factor = random.uniform(0.5, 2.0)
47
+ sound = parselmouth.praat.call(
48
+ sound, "Change gender", 75, 600, formant_shift_ratio, 0, pitch_range_factor, 1.0
49
+ )
50
  wav = np.array(sound.values)[0].astype(np.float32)
51
  return wav
52
 
 
75
  if len(self.bg_paths) == 0:
76
  raise ValueError(f"No background audio files found in {hp.bg_dir}")
77
 
78
+ logger.info(
79
+ f"Found {len(self.fg_paths)} foreground files and {len(self.bg_paths)} background files"
80
+ )
81
 
82
  self.training = training
83
  self.max_retries = max_retries
 
125
  fg_path = self.fg_paths[index]
126
 
127
  if self.training and random.random() < self.silent_fg_prob:
128
+ fg_wav = np.zeros(
129
+ int(self.hp.training_seconds * self.hp.wav_rate), dtype=np.float32
130
+ )
131
  else:
132
  fg_wav = self._load_wav(fg_path)
133
  if random.random() < self.hp.praat_augment_prob and self.training:
 
138
  fg_dwav = None
139
  bg_dwav = None
140
  else:
141
+ fg_dwav = _normalize(self.distorter(fg_wav, self.hp.wav_rate)).astype(
142
+ np.float32
143
+ )
144
  if self.training:
145
  bg_path = random.choice(self.bg_paths)
146
  else:
147
  # Deterministic for validation
148
  bg_path = self.bg_paths[index % len(self.bg_paths)]
149
+ bg_wav = self._load_wav(
150
+ bg_path, length=len(fg_wav), random_crop=self.training
151
+ )
152
+ bg_dwav = _normalize(self.distorter(bg_wav, self.hp.wav_rate)).astype(
153
+ np.float32
154
+ )
155
 
156
  return dict(
157
  fg_wav=fg_wav,
 
166
  return self._getitem_unsafe(index)
167
  except Exception as e:
168
  if i == self.max_retries - 1:
169
+ raise RuntimeError(
170
+ f"Failed to load {self.fg_paths[index]} after {self.max_retries} retries"
171
+ ) from e
172
  logger.debug(f"Error loading {self.fg_paths[index]}: {e}, skipping")
173
  index = np.random.randint(0, len(self))
174
 
modules/repos_static/resemble_enhance/data/distorter/base.py CHANGED
@@ -2,8 +2,8 @@ import itertools
2
  import os
3
  import random
4
  import time
5
- from typing import Union
6
  import warnings
 
7
 
8
  import numpy as np
9
 
 
2
  import os
3
  import random
4
  import time
 
5
  import warnings
6
+ from typing import Union
7
 
8
  import numpy as np
9