zhzluke96 commited on
Commit
1df74c6
1 Parent(s): 3710ae9
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env.webui +1 -1
  2. language/zh-CN.json +3 -0
  3. modules/ChatTTS/ChatTTS/core.py +1 -1
  4. modules/SynthesizeSegments.py +49 -16
  5. modules/api/api_setup.py +5 -3
  6. modules/api/impl/google_api.py +11 -12
  7. modules/api/impl/openai_api.py +4 -0
  8. modules/api/impl/ssml_api.py +17 -3
  9. modules/api/impl/tts_api.py +3 -0
  10. modules/api/impl/xtts_v2_api.py +160 -0
  11. modules/devices/devices.py +1 -1
  12. modules/finetune/__init__.py +0 -0
  13. modules/finetune/model/__init__.py +0 -0
  14. modules/finetune/model/encoder.py +87 -0
  15. modules/finetune/model/wavenet.py +227 -0
  16. modules/finetune/train_gpt.py +246 -0
  17. modules/finetune/train_speaker.py +296 -0
  18. modules/finetune/utils/__init__.py +0 -0
  19. modules/finetune/utils/dataset.py +487 -0
  20. modules/finetune/utils/logger.py +409 -0
  21. modules/finetune/utils/model.py +19 -0
  22. modules/finetune/utils/output.py +146 -0
  23. modules/generate_audio.py +2 -0
  24. modules/normalization.py +5 -0
  25. modules/repos_static/resemble_enhance/data/distorter/base.py +2 -1
  26. modules/repos_static/resemble_enhance/data/distorter/custom.py +8 -3
  27. modules/repos_static/resemble_enhance/data/distorter/sox.py +32 -8
  28. modules/repos_static/resemble_enhance/data/utils.py +4 -2
  29. modules/repos_static/resemble_enhance/denoiser/denoiser.py +2 -1
  30. modules/repos_static/resemble_enhance/enhancer/download.py +8 -3
  31. modules/repos_static/resemble_enhance/enhancer/enhancer.py +5 -2
  32. modules/repos_static/resemble_enhance/enhancer/hparams.py +4 -3
  33. modules/repos_static/resemble_enhance/enhancer/inference.py +2 -1
  34. modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py +20 -9
  35. modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py +2 -1
  36. modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py +34 -7
  37. modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py +8 -3
  38. modules/repos_static/resemble_enhance/hparams.py +5 -2
  39. modules/speaker.py +5 -3
  40. modules/ssml_parser/SSMLParser.py +34 -18
  41. modules/synthesize_audio.py +27 -38
  42. modules/utils/audio.py +5 -1
  43. modules/webui/app.py +3 -0
  44. modules/webui/finetune/ProcessMonitor.py +92 -0
  45. modules/webui/finetune/ft_tab.py +13 -0
  46. modules/webui/finetune/ft_ui_utils.py +49 -0
  47. modules/webui/finetune/speaker_ft_tab.py +130 -0
  48. modules/webui/localization_runtime.py +126 -0
  49. modules/webui/ssml/podcast_tab.py +2 -65
  50. modules/webui/ssml/ssml_tab.py +21 -2
.env.webui CHANGED
@@ -17,5 +17,5 @@ TTS_MAX_LEN=1000
17
  SSML_MAX_LEN=3000
18
  MAX_BATCH_SIZE=12
19
 
20
- V_GIT_TAG="🤗hf(0.5.6-rc)"
21
  V_GIT_COMMIT=main
 
17
  SSML_MAX_LEN=3000
18
  MAX_BATCH_SIZE=12
19
 
20
+ V_GIT_TAG="🤗hf(0.6.1-rc)"
21
  V_GIT_COMMIT=main
language/zh-CN.json CHANGED
@@ -80,6 +80,9 @@
80
  "readme": "readme",
81
  "changelog": "changelog",
82
  "💼Speaker file": "💼音色文件",
 
 
 
83
  "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
84
  "SSML_SPLITER_GUIDE": [
85
  "- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`",
 
80
  "readme": "readme",
81
  "changelog": "changelog",
82
  "💼Speaker file": "💼音色文件",
83
+ "🎛️Spliter": "🎛️分割器配置",
84
+ "eos": "句尾词",
85
+ "Spliter Threshold": "分割器阈值",
86
  "TTS_STYLE_GUIDE": ["后缀为 _p 表示带prompt,效果更强但是影响质量"],
87
  "SSML_SPLITER_GUIDE": [
88
  "- 如果尾字吞字不读,可以试试结尾加上 `[lbreak]`",
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
17
 
18
  from huggingface_hub import snapshot_download
19
 
20
- logging.basicConfig(level=logging.INFO)
21
 
22
 
23
  class Chat:
 
17
 
18
  from huggingface_hub import snapshot_download
19
 
20
+ logging.basicConfig(level=logging.ERROR)
21
 
22
 
23
  class Chat:
modules/SynthesizeSegments.py CHANGED
@@ -1,8 +1,10 @@
 
1
  from box import Box
2
  from pydub import AudioSegment
3
  from typing import List, Union
4
  from scipy.io.wavfile import write
5
  import io
 
6
  from modules.api.utils import calc_spk_style
7
  from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
8
  from modules.utils import rng
@@ -56,27 +58,27 @@ def to_number(value, t, default=0):
56
 
57
 
58
  class TTSAudioSegment(Box):
59
- text: str
60
- temperature: float
61
- top_P: float
62
- top_K: int
63
- spk: int
64
- infer_seed: int
65
- prompt1: str
66
- prompt2: str
67
- prefix: str
68
-
69
- _type: str
70
-
71
  def __init__(self, *args, **kwargs):
72
  super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  class SynthesizeSegments:
76
- def __init__(self, batch_size: int = 8):
77
  self.batch_size = batch_size
78
  self.batch_default_spk_seed = rng.np_rng()
79
  self.batch_default_infer_seed = rng.np_rng()
 
 
80
 
81
  def segment_to_generate_params(
82
  self, segment: Union[SSMLSegment, SSMLBreak]
@@ -85,9 +87,11 @@ class SynthesizeSegments:
85
  return TTSAudioSegment(_type="break")
86
 
87
  if segment.get("params", None) is not None:
88
- return TTSAudioSegment(**segment.get("params"))
 
 
89
 
90
- text = segment.get("text", "")
91
  is_end = segment.get("is_end", False)
92
 
93
  text = str(text).strip()
@@ -156,7 +160,7 @@ class SynthesizeSegments:
156
  for i in range(0, len(bucket), self.batch_size):
157
  batch = bucket[i : i + self.batch_size]
158
  param_arr = [self.segment_to_generate_params(segment) for segment in batch]
159
- texts = [params.text for params in param_arr]
160
 
161
  params = param_arr[0]
162
  audio_datas = generate_audio.generate_audio_batch(
@@ -204,9 +208,38 @@ class SynthesizeSegments:
204
 
205
  return buckets
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  def synthesize_segments(
208
  self, segments: List[Union[SSMLSegment, SSMLBreak]]
209
  ) -> List[AudioSegment]:
 
210
  audio_segments = [None] * len(segments)
211
  buckets = self.bucket_segments(segments)
212
 
 
1
+ import copy
2
  from box import Box
3
  from pydub import AudioSegment
4
  from typing import List, Union
5
  from scipy.io.wavfile import write
6
  import io
7
+ from modules.SentenceSplitter import SentenceSplitter
8
  from modules.api.utils import calc_spk_style
9
  from modules.ssml_parser.SSMLParser import SSMLSegment, SSMLBreak, SSMLContext
10
  from modules.utils import rng
 
58
 
59
 
60
  class TTSAudioSegment(Box):
 
 
 
 
 
 
 
 
 
 
 
 
61
  def __init__(self, *args, **kwargs):
62
  super().__init__(*args, **kwargs)
63
+ self._type = kwargs.get("_type", "voice")
64
+ self.text = kwargs.get("text", "")
65
+ self.temperature = kwargs.get("temperature", 0.3)
66
+ self.top_P = kwargs.get("top_P", 0.5)
67
+ self.top_K = kwargs.get("top_K", 20)
68
+ self.spk = kwargs.get("spk", -1)
69
+ self.infer_seed = kwargs.get("infer_seed", -1)
70
+ self.prompt1 = kwargs.get("prompt1", "")
71
+ self.prompt2 = kwargs.get("prompt2", "")
72
+ self.prefix = kwargs.get("prefix", "")
73
 
74
 
75
  class SynthesizeSegments:
76
+ def __init__(self, batch_size: int = 8, eos="", spliter_thr=100):
77
  self.batch_size = batch_size
78
  self.batch_default_spk_seed = rng.np_rng()
79
  self.batch_default_infer_seed = rng.np_rng()
80
+ self.eos = eos
81
+ self.spliter_thr = spliter_thr
82
 
83
  def segment_to_generate_params(
84
  self, segment: Union[SSMLSegment, SSMLBreak]
 
87
  return TTSAudioSegment(_type="break")
88
 
89
  if segment.get("params", None) is not None:
90
+ params = segment.get("params")
91
+ text = segment.get("text", None) or segment.text or ""
92
+ return TTSAudioSegment(**params, text=text)
93
 
94
+ text = segment.get("text", None) or segment.text or ""
95
  is_end = segment.get("is_end", False)
96
 
97
  text = str(text).strip()
 
160
  for i in range(0, len(bucket), self.batch_size):
161
  batch = bucket[i : i + self.batch_size]
162
  param_arr = [self.segment_to_generate_params(segment) for segment in batch]
163
+ texts = [params.text + self.eos for params in param_arr]
164
 
165
  params = param_arr[0]
166
  audio_datas = generate_audio.generate_audio_batch(
 
208
 
209
  return buckets
210
 
211
+ def split_segments(self, segments: List[Union[SSMLSegment, SSMLBreak]]):
212
+ """
213
+ 将 segments 中的 text 经过 spliter 处理成多个 segments
214
+ """
215
+ spliter = SentenceSplitter(threshold=self.spliter_thr)
216
+ ret_segments: List[Union[SSMLSegment, SSMLBreak]] = []
217
+
218
+ for segment in segments:
219
+ if isinstance(segment, SSMLBreak):
220
+ ret_segments.append(segment)
221
+ continue
222
+
223
+ text = segment.text
224
+ if not text:
225
+ continue
226
+
227
+ sentences = spliter.parse(text)
228
+ for sentence in sentences:
229
+ ret_segments.append(
230
+ SSMLSegment(
231
+ text=sentence,
232
+ attrs=segment.attrs.copy(),
233
+ params=copy.copy(segment.params),
234
+ )
235
+ )
236
+
237
+ return ret_segments
238
+
239
  def synthesize_segments(
240
  self, segments: List[Union[SSMLSegment, SSMLBreak]]
241
  ) -> List[AudioSegment]:
242
+ segments = self.split_segments(segments)
243
  audio_segments = [None] * len(segments)
244
  buckets = self.bucket_segments(segments)
245
 
modules/api/api_setup.py CHANGED
@@ -18,6 +18,7 @@ from modules.api.impl import (
18
  speaker_api,
19
  ping_api,
20
  models_api,
 
21
  )
22
 
23
  logger = logging.getLogger(__name__)
@@ -35,6 +36,7 @@ def create_api(app, exclude=[]):
35
  google_api.setup(app_mgr)
36
  openai_api.setup(app_mgr)
37
  refiner_api.setup(app_mgr)
 
38
 
39
  return app_mgr
40
 
@@ -42,9 +44,9 @@ def create_api(app, exclude=[]):
42
  def setup_model_args(parser: argparse.ArgumentParser):
43
  parser.add_argument("--compile", action="store_true", help="Enable model compile")
44
  parser.add_argument(
45
- "--half",
46
  action="store_true",
47
- help="Enable half precision for model inference",
48
  )
49
  parser.add_argument(
50
  "--off_tqdm",
@@ -82,7 +84,7 @@ def process_model_args(args):
82
  compile = env.get_and_update_env(args, "compile", False, bool)
83
  device_id = env.get_and_update_env(args, "device_id", None, str)
84
  use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
85
- half = env.get_and_update_env(args, "half", False, bool)
86
  off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
87
  debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
88
 
 
18
  speaker_api,
19
  ping_api,
20
  models_api,
21
+ xtts_v2_api,
22
  )
23
 
24
  logger = logging.getLogger(__name__)
 
36
  google_api.setup(app_mgr)
37
  openai_api.setup(app_mgr)
38
  refiner_api.setup(app_mgr)
39
+ xtts_v2_api.setup(app_mgr)
40
 
41
  return app_mgr
42
 
 
44
  def setup_model_args(parser: argparse.ArgumentParser):
45
  parser.add_argument("--compile", action="store_true", help="Enable model compile")
46
  parser.add_argument(
47
+ "--no_half",
48
  action="store_true",
49
+ help="Disalbe half precision for model inference",
50
  )
51
  parser.add_argument(
52
  "--off_tqdm",
 
84
  compile = env.get_and_update_env(args, "compile", False, bool)
85
  device_id = env.get_and_update_env(args, "device_id", None, str)
86
  use_cpu = env.get_and_update_env(args, "use_cpu", [], list)
87
+ no_half = env.get_and_update_env(args, "no_half", False, bool)
88
  off_tqdm = env.get_and_update_env(args, "off_tqdm", False, bool)
89
  debug_generate = env.get_and_update_env(args, "debug_generate", False, bool)
90
 
modules/api/impl/google_api.py CHANGED
@@ -13,6 +13,7 @@ from modules.Enhancer.ResembleEnhance import (
13
  )
14
  from modules.api.Api import APIManager
15
  from modules.synthesize_audio import synthesize_audio
 
16
  from modules.utils.audio import apply_prosody_to_audio_data
17
  from modules.normalization import text_normalize
18
 
@@ -44,6 +45,9 @@ class VoiceSelectionParams(BaseModel):
44
  topK: int = 20
45
  seed: int = 42
46
 
 
 
 
47
 
48
  class AudioConfig(BaseModel):
49
  audioEncoding: api_utils.AudioFormat = "mp3"
@@ -87,6 +91,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
87
  language_code = voice.languageCode
88
  voice_name = voice.name
89
  infer_seed = voice.seed or 42
 
90
  audio_format = audioConfig.audioEncoding or "mp3"
91
  speaking_rate = audioConfig.speakingRate or 1
92
  pitch = audioConfig.pitch or 0
@@ -94,11 +99,9 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
94
 
95
  batch_size = audioConfig.batchSize or 1
96
 
97
- # TODO spliter_threshold
98
  spliter_threshold = audioConfig.spliterThreshold or 100
99
 
100
- # TODO sample_rate
101
- sample_rate_hertz = audioConfig.sampleRateHertz or 24000
102
 
103
  params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
104
 
@@ -137,10 +140,10 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
137
  prefix=params.get("prefix", ""),
138
  batch_size=batch_size,
139
  spliter_threshold=spliter_threshold,
 
140
  )
141
 
142
  elif input.ssml:
143
- # 处理SSML合成逻辑
144
  parser = create_ssml_parser()
145
  segments = parser.parse(input.ssml)
146
  for seg in segments:
@@ -151,17 +154,13 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
151
  status_code=422, detail="The SSML text is empty or parsing failed."
152
  )
153
 
154
- synthesize = SynthesizeSegments(batch_size=batch_size)
 
 
155
  audio_segments = synthesize.synthesize_segments(segments)
156
  combined_audio = combine_audio_segments(audio_segments)
157
 
158
- buffer = io.BytesIO()
159
- combined_audio.export(buffer, format="wav")
160
-
161
- buffer.seek(0)
162
-
163
- audio_data = buffer.read()
164
-
165
  else:
166
  raise HTTPException(
167
  status_code=422, detail="Either text or SSML input must be provided."
 
13
  )
14
  from modules.api.Api import APIManager
15
  from modules.synthesize_audio import synthesize_audio
16
+ from modules.utils import audio
17
  from modules.utils.audio import apply_prosody_to_audio_data
18
  from modules.normalization import text_normalize
19
 
 
45
  topK: int = 20
46
  seed: int = 42
47
 
48
+ # end_of_sentence
49
+ eos: str = "[uv_break]"
50
+
51
 
52
  class AudioConfig(BaseModel):
53
  audioEncoding: api_utils.AudioFormat = "mp3"
 
91
  language_code = voice.languageCode
92
  voice_name = voice.name
93
  infer_seed = voice.seed or 42
94
+ eos = voice.eos or "[uv_break]"
95
  audio_format = audioConfig.audioEncoding or "mp3"
96
  speaking_rate = audioConfig.speakingRate or 1
97
  pitch = audioConfig.pitch or 0
 
99
 
100
  batch_size = audioConfig.batchSize or 1
101
 
 
102
  spliter_threshold = audioConfig.spliterThreshold or 100
103
 
104
+ sample_rate = audioConfig.sampleRateHertz or 24000
 
105
 
106
  params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
107
 
 
140
  prefix=params.get("prefix", ""),
141
  batch_size=batch_size,
142
  spliter_threshold=spliter_threshold,
143
+ end_of_sentence=eos,
144
  )
145
 
146
  elif input.ssml:
 
147
  parser = create_ssml_parser()
148
  segments = parser.parse(input.ssml)
149
  for seg in segments:
 
154
  status_code=422, detail="The SSML text is empty or parsing failed."
155
  )
156
 
157
+ synthesize = SynthesizeSegments(
158
+ batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
159
+ )
160
  audio_segments = synthesize.synthesize_segments(segments)
161
  combined_audio = combine_audio_segments(audio_segments)
162
 
163
+ sample_rate, audio_data = audio.pydub_to_np(combined_audio)
 
 
 
 
 
 
164
  else:
165
  raise HTTPException(
166
  status_code=422, detail="Either text or SSML input must be provided."
modules/api/impl/openai_api.py CHANGED
@@ -41,6 +41,8 @@ class AudioSpeechRequest(BaseModel):
41
  spliter_threshold: float = Field(
42
  100, ge=10, le=1024, description="Threshold for sentence spliter"
43
  )
 
 
44
 
45
 
46
  async def openai_speech_api(
@@ -52,6 +54,7 @@ async def openai_speech_api(
52
  input_text = request.input
53
  voice = request.voice
54
  style = request.style
 
55
  response_format = request.response_format
56
  batch_size = request.batch_size
57
  spliter_threshold = request.spliter_threshold
@@ -95,6 +98,7 @@ async def openai_speech_api(
95
  prompt1=prompt1,
96
  prompt2=prompt2,
97
  prefix=prefix,
 
98
  )
99
 
100
  if speed != 1:
 
41
  spliter_threshold: float = Field(
42
  100, ge=10, le=1024, description="Threshold for sentence spliter"
43
  )
44
+ # end of sentence
45
+ eos: str = "[uv_break]"
46
 
47
 
48
  async def openai_speech_api(
 
54
  input_text = request.input
55
  voice = request.voice
56
  style = request.style
57
+ eos = request.eos
58
  response_format = request.response_format
59
  batch_size = request.batch_size
60
  spliter_threshold = request.spliter_threshold
 
98
  prompt1=prompt1,
99
  prompt2=prompt2,
100
  prefix=prefix,
101
+ end_of_sentence=eos,
102
  )
103
 
104
  if speed != 1:
modules/api/impl/ssml_api.py CHANGED
@@ -26,8 +26,13 @@ class SSMLRequest(BaseModel):
26
  # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
27
  batch_size: int = 4
28
 
 
 
29
 
30
- async def synthesize_ssml(
 
 
 
31
  request: SSMLRequest = Body(
32
  ..., description="JSON body with SSML string and format"
33
  )
@@ -36,12 +41,19 @@ async def synthesize_ssml(
36
  ssml = request.ssml
37
  format = request.format.lower()
38
  batch_size = request.batch_size
 
 
39
 
40
  if batch_size < 1:
41
  raise HTTPException(
42
  status_code=400, detail="Batch size must be greater than 0."
43
  )
44
 
 
 
 
 
 
45
  if not ssml or ssml == "":
46
  raise HTTPException(status_code=400, detail="SSML content is required.")
47
 
@@ -55,7 +67,9 @@ async def synthesize_ssml(
55
  for seg in segments:
56
  seg["text"] = text_normalize(seg["text"], is_end=True)
57
 
58
- synthesize = SynthesizeSegments(batch_size)
 
 
59
  audio_segments = synthesize.synthesize_segments(segments)
60
  combined_audio = combine_audio_segments(audio_segments)
61
  buffer = io.BytesIO()
@@ -77,4 +91,4 @@ async def synthesize_ssml(
77
 
78
 
79
  def setup(api_manager: APIManager):
80
- api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml)
 
26
  # NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
27
  batch_size: int = 4
28
 
29
+ # end of sentence
30
+ eos: str = "[uv_break]"
31
 
32
+ spliter_thr: int = 100
33
+
34
+
35
+ async def synthesize_ssml_api(
36
  request: SSMLRequest = Body(
37
  ..., description="JSON body with SSML string and format"
38
  )
 
41
  ssml = request.ssml
42
  format = request.format.lower()
43
  batch_size = request.batch_size
44
+ eos = request.eos
45
+ spliter_thr = request.spliter_thr
46
 
47
  if batch_size < 1:
48
  raise HTTPException(
49
  status_code=400, detail="Batch size must be greater than 0."
50
  )
51
 
52
+ if spliter_thr < 50:
53
+ raise HTTPException(
54
+ status_code=400, detail="Spliter threshold must be greater than 50."
55
+ )
56
+
57
  if not ssml or ssml == "":
58
  raise HTTPException(status_code=400, detail="SSML content is required.")
59
 
 
67
  for seg in segments:
68
  seg["text"] = text_normalize(seg["text"], is_end=True)
69
 
70
+ synthesize = SynthesizeSegments(
71
+ batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
72
+ )
73
  audio_segments = synthesize.synthesize_segments(segments)
74
  combined_audio = combine_audio_segments(audio_segments)
75
  buffer = io.BytesIO()
 
91
 
92
 
93
  def setup(api_manager: APIManager):
94
+ api_manager.post("/v1/ssml", response_class=FileResponse)(synthesize_ssml_api)
modules/api/impl/tts_api.py CHANGED
@@ -38,6 +38,7 @@ class TTSParams(BaseModel):
38
  prefix: str = Query("", description="Text prefix for inference")
39
  bs: str = Query("8", description="Batch size for inference")
40
  thr: str = Query("100", description="Threshold for sentence spliter")
 
41
 
42
 
43
  async def synthesize_tts(params: TTSParams = Depends()):
@@ -87,6 +88,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
87
  prefix = params.prefix or calc_params.get("prefix", params.prefix)
88
  prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
89
  prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
 
90
 
91
  batch_size = int(params.bs)
92
  threshold = int(params.thr)
@@ -103,6 +105,7 @@ async def synthesize_tts(params: TTSParams = Depends()):
103
  prefix=prefix,
104
  batch_size=batch_size,
105
  spliter_threshold=threshold,
 
106
  )
107
 
108
  buffer = io.BytesIO()
 
38
  prefix: str = Query("", description="Text prefix for inference")
39
  bs: str = Query("8", description="Batch size for inference")
40
  thr: str = Query("100", description="Threshold for sentence spliter")
41
+ eos: str = Query("", description="End of sentence str")
42
 
43
 
44
  async def synthesize_tts(params: TTSParams = Depends()):
 
88
  prefix = params.prefix or calc_params.get("prefix", params.prefix)
89
  prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
90
  prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
91
+ eos = params.eos or ""
92
 
93
  batch_size = int(params.bs)
94
  threshold = int(params.thr)
 
105
  prefix=prefix,
106
  batch_size=batch_size,
107
  spliter_threshold=threshold,
108
+ end_of_sentence=eos,
109
  )
110
 
111
  buffer = io.BytesIO()
modules/api/impl/xtts_v2_api.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ from fastapi import HTTPException
3
+ from fastapi.responses import StreamingResponse
4
+ from pydantic import BaseModel
5
+ from modules.api import utils as api_utils
6
+ from modules.api.Api import APIManager
7
+
8
+ import soundfile as sf
9
+
10
+ from modules import config
11
+ from modules.normalization import text_normalize
12
+ from modules.speaker import speaker_mgr
13
+ from modules.synthesize_audio import synthesize_audio
14
+
15
+ import logging
16
+
17
+ from modules.utils.audio import apply_prosody_to_audio_data
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class XTTS_V2_Settings:
23
+ def __init__(self):
24
+ self.stream_chunk_size = 100
25
+ self.temperature = 0.3
26
+ self.speed = 1
27
+ self.length_penalty = 0.5
28
+ self.repetition_penalty = 1.0
29
+ self.top_p = 0.7
30
+ self.top_k = 20
31
+ self.enable_text_splitting = True
32
+
33
+
34
+ class TTSSettingsRequest(BaseModel):
35
+ stream_chunk_size: int
36
+ temperature: float
37
+ speed: float
38
+ length_penalty: float
39
+ repetition_penalty: float
40
+ top_p: float
41
+ top_k: int
42
+ enable_text_splitting: bool
43
+
44
+
45
+ class SynthesisRequest(BaseModel):
46
+ text: str
47
+ speaker_wav: str
48
+ language: str
49
+
50
+
51
+ def setup(app: APIManager):
52
+ XTTSV2 = XTTS_V2_Settings()
53
+
54
+ @app.get("/v1/xtts_v2/speakers")
55
+ async def speakers():
56
+ spks = speaker_mgr.list_speakers()
57
+ return [
58
+ {
59
+ "name": spk.name,
60
+ "voice_id": spk.id,
61
+ # TODO: 也许可以放一个 "/v1/tts" 接口地址在这里
62
+ "preview_url": "",
63
+ }
64
+ for spk in spks
65
+ ]
66
+
67
+ @app.post("/v1/xtts_v2/tts_to_audio", response_class=StreamingResponse)
68
+ async def tts_to_audio(request: SynthesisRequest):
69
+ text = request.text
70
+ # speaker_wav 就是 speaker id 。。。
71
+ voice_id = request.speaker_wav
72
+ language = request.language
73
+
74
+ spk = speaker_mgr.get_speaker_by_id(voice_id) or speaker_mgr.get_speaker(
75
+ voice_id
76
+ )
77
+ if spk is None:
78
+ raise HTTPException(status_code=400, detail="Invalid speaker id")
79
+
80
+ text = text_normalize(text, is_end=True)
81
+ sample_rate, audio_data = synthesize_audio(
82
+ text=text,
83
+ temperature=XTTSV2.temperature,
84
+ # length_penalty=XTTSV2.length_penalty,
85
+ # repetition_penalty=XTTSV2.repetition_penalty,
86
+ top_P=XTTSV2.top_p,
87
+ top_K=XTTSV2.top_k,
88
+ spk=spk,
89
+ spliter_threshold=XTTSV2.stream_chunk_size,
90
+ # TODO 支持设置 batch_size
91
+ batch_size=4,
92
+ end_of_sentence="[uv_break]",
93
+ )
94
+
95
+ if XTTSV2.speed:
96
+ audio_data = apply_prosody_to_audio_data(
97
+ audio_data,
98
+ rate=XTTSV2.speed,
99
+ sr=sample_rate,
100
+ )
101
+
102
+ # to mp3
103
+ buffer = io.BytesIO()
104
+ sf.write(buffer, audio_data, sample_rate, format="wav")
105
+ buffer.seek(0)
106
+
107
+ buffer = api_utils.wav_to_mp3(buffer)
108
+
109
+ return StreamingResponse(buffer, media_type="audio/mpeg")
110
+
111
+ @app.get("/v1/xtts_v2/tts_stream")
112
+ async def tts_stream():
113
+ raise HTTPException(status_code=501, detail="Not implemented")
114
+
115
+ @app.post("/v1/xtts_v2/set_tts_settings")
116
+ async def set_tts_settings(request: TTSSettingsRequest):
117
+ try:
118
+ if request.stream_chunk_size < 50:
119
+ raise HTTPException(
120
+ status_code=400, detail="stream_chunk_size must be greater than 0"
121
+ )
122
+ if request.temperature < 0:
123
+ raise HTTPException(
124
+ status_code=400, detail="temperature must be greater than 0"
125
+ )
126
+ if request.speed < 0:
127
+ raise HTTPException(
128
+ status_code=400, detail="speed must be greater than 0"
129
+ )
130
+ if request.length_penalty < 0:
131
+ raise HTTPException(
132
+ status_code=400, detail="length_penalty must be greater than 0"
133
+ )
134
+ if request.repetition_penalty < 0:
135
+ raise HTTPException(
136
+ status_code=400, detail="repetition_penalty must be greater than 0"
137
+ )
138
+ if request.top_p < 0:
139
+ raise HTTPException(
140
+ status_code=400, detail="top_p must be greater than 0"
141
+ )
142
+ if request.top_k < 0:
143
+ raise HTTPException(
144
+ status_code=400, detail="top_k must be greater than 0"
145
+ )
146
+
147
+ XTTSV2.stream_chunk_size = request.stream_chunk_size
148
+ XTTSV2.temperature = request.temperature
149
+ XTTSV2.speed = request.speed
150
+ XTTSV2.length_penalty = request.length_penalty
151
+ XTTSV2.repetition_penalty = request.repetition_penalty
152
+ XTTSV2.top_p = request.top_p
153
+ XTTSV2.top_k = request.top_k
154
+ XTTSV2.enable_text_splitting = request.enable_text_splitting
155
+ return {"message": "Settings successfully applied"}
156
+ except Exception as e:
157
+ if isinstance(e, HTTPException):
158
+ raise e
159
+ logger.error(e)
160
+ raise HTTPException(status_code=500, detail=str(e))
modules/devices/devices.py CHANGED
@@ -127,7 +127,7 @@ def reset_device():
127
  global dtype_gpt
128
  global dtype_decoder
129
 
130
- if config.runtime_env_vars.half:
131
  dtype = torch.float16
132
  dtype_dvae = torch.float16
133
  dtype_vocos = torch.float16
 
127
  global dtype_gpt
128
  global dtype_decoder
129
 
130
+ if not config.runtime_env_vars.no_half:
131
  dtype = torch.float16
132
  dtype_dvae = torch.float16
133
  dtype_vocos = torch.float16
modules/finetune/__init__.py ADDED
File without changes
modules/finetune/model/__init__.py ADDED
File without changes
modules/finetune/model/encoder.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from modules.ChatTTS.ChatTTS.model.dvae import ConvNeXtBlock, DVAEDecoder
5
+
6
+ from .wavenet import WaveNet
7
+
8
+
9
+ def get_encoder_config(decoder: DVAEDecoder) -> dict[str, int | bool]:
10
+ return {
11
+ "idim": decoder.conv_out.out_channels,
12
+ "odim": decoder.conv_in[0].in_channels,
13
+ "n_layer": len(decoder.decoder_block),
14
+ "bn_dim": decoder.conv_in[0].out_channels,
15
+ "hidden": decoder.conv_in[2].out_channels,
16
+ "kernel": decoder.decoder_block[0].dwconv.kernel_size[0],
17
+ "dilation": decoder.decoder_block[0].dwconv.dilation[0],
18
+ "down": decoder.up,
19
+ }
20
+
21
+
22
+ class DVAEEncoder(nn.Module):
23
+ def __init__(
24
+ self,
25
+ idim: int,
26
+ odim: int,
27
+ n_layer: int = 12,
28
+ bn_dim: int = 64,
29
+ hidden: int = 256,
30
+ kernel: int = 7,
31
+ dilation: int = 2,
32
+ down: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+ self.wavenet = WaveNet(
36
+ input_channels=100,
37
+ residual_channels=idim,
38
+ residual_layers=20,
39
+ dilation_cycle=4,
40
+ )
41
+ self.conv_in_transpose = nn.ConvTranspose1d(
42
+ idim, hidden, kernel_size=1, bias=False
43
+ )
44
+ # nn.Sequential(
45
+ # nn.ConvTranspose1d(100, idim, 3, 1, 1, bias=False),
46
+ # nn.ConvTranspose1d(idim, hidden, kernel_size=1, bias=False)
47
+ # )
48
+ self.encoder_block = nn.ModuleList(
49
+ [
50
+ ConvNeXtBlock(
51
+ hidden,
52
+ hidden * 4,
53
+ kernel,
54
+ dilation,
55
+ )
56
+ for _ in range(n_layer)
57
+ ]
58
+ )
59
+ self.conv_out_transpose = nn.Sequential(
60
+ nn.Conv1d(hidden, bn_dim, 3, 1, 1),
61
+ nn.GELU(),
62
+ nn.Conv1d(bn_dim, odim, 3, 1, 1),
63
+ )
64
+
65
+ def forward(
66
+ self,
67
+ audio_mel_specs: torch.Tensor, # (batch_size, audio_len*2, 100)
68
+ audio_attention_mask: torch.Tensor, # (batch_size, audio_len)
69
+ conditioning=None,
70
+ ) -> torch.Tensor:
71
+ mel_attention_mask = (
72
+ audio_attention_mask.unsqueeze(-1).repeat(1, 1, 2).flatten(1)
73
+ )
74
+ x: torch.Tensor = self.wavenet(
75
+ audio_mel_specs.transpose(1, 2)
76
+ ) # (batch_size, idim, audio_len*2)
77
+ x = x * mel_attention_mask.unsqueeze(1)
78
+ x = self.conv_in_transpose(x) # (batch_size, hidden, audio_len*2)
79
+ for f in self.encoder_block:
80
+ x = f(x, conditioning)
81
+ x = self.conv_out_transpose(x) # (batch_size, odim, audio_len*2)
82
+ x = (
83
+ x.view(x.size(0), x.size(1), 2, x.size(2) // 2)
84
+ .permute(0, 3, 1, 2)
85
+ .flatten(2)
86
+ )
87
+ return x # (batch_size, audio_len, audio_dim=odim*2)
modules/finetune/model/wavenet.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """https://github.com/fishaudio/fish-speech/blob/main/fish_speech/models/vqgan/modules/wavenet.py"""
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
+ class Mish(nn.Module):
12
+ def forward(self, x):
13
+ return x * torch.tanh(F.softplus(x))
14
+
15
+
16
+ class DiffusionEmbedding(nn.Module):
17
+ """Diffusion Step Embedding"""
18
+
19
+ def __init__(self, d_denoiser):
20
+ super(DiffusionEmbedding, self).__init__()
21
+ self.dim = d_denoiser
22
+
23
+ def forward(self, x):
24
+ device = x.device
25
+ half_dim = self.dim // 2
26
+ emb = math.log(10000) / (half_dim - 1)
27
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
28
+ emb = x[:, None] * emb[None, :]
29
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
30
+ return emb
31
+
32
+
33
+ class LinearNorm(nn.Module):
34
+ """LinearNorm Projection"""
35
+
36
+ def __init__(self, in_features, out_features, bias=False):
37
+ super(LinearNorm, self).__init__()
38
+ self.linear = nn.Linear(in_features, out_features, bias)
39
+
40
+ nn.init.xavier_uniform_(self.linear.weight)
41
+ if bias:
42
+ nn.init.constant_(self.linear.bias, 0.0)
43
+
44
+ def forward(self, x):
45
+ x = self.linear(x)
46
+ return x
47
+
48
+
49
+ class ConvNorm(nn.Module):
50
+ """1D Convolution"""
51
+
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ out_channels,
56
+ kernel_size=1,
57
+ stride=1,
58
+ padding=None,
59
+ dilation=1,
60
+ bias=True,
61
+ w_init_gain="linear",
62
+ ):
63
+ super(ConvNorm, self).__init__()
64
+
65
+ if padding is None:
66
+ assert kernel_size % 2 == 1
67
+ padding = int(dilation * (kernel_size - 1) / 2)
68
+
69
+ self.conv = nn.Conv1d(
70
+ in_channels,
71
+ out_channels,
72
+ kernel_size=kernel_size,
73
+ stride=stride,
74
+ padding=padding,
75
+ dilation=dilation,
76
+ bias=bias,
77
+ )
78
+ nn.init.kaiming_normal_(self.conv.weight)
79
+
80
+ def forward(self, signal):
81
+ conv_signal = self.conv(signal)
82
+
83
+ return conv_signal
84
+
85
+
86
+ class ResidualBlock(nn.Module):
87
+ """Residual Block"""
88
+
89
+ def __init__(
90
+ self,
91
+ residual_channels,
92
+ use_linear_bias=False,
93
+ dilation=1,
94
+ condition_channels=None,
95
+ ):
96
+ super(ResidualBlock, self).__init__()
97
+ self.conv_layer = ConvNorm(
98
+ residual_channels,
99
+ 2 * residual_channels,
100
+ kernel_size=3,
101
+ stride=1,
102
+ padding=dilation,
103
+ dilation=dilation,
104
+ )
105
+
106
+ if condition_channels is not None:
107
+ self.diffusion_projection = LinearNorm(
108
+ residual_channels, residual_channels, use_linear_bias
109
+ )
110
+ self.condition_projection = ConvNorm(
111
+ condition_channels, 2 * residual_channels, kernel_size=1
112
+ )
113
+
114
+ self.output_projection = ConvNorm(
115
+ residual_channels, 2 * residual_channels, kernel_size=1
116
+ )
117
+
118
+ def forward(self, x, condition=None, diffusion_step=None):
119
+ y = x
120
+
121
+ if diffusion_step is not None:
122
+ diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
123
+ y = y + diffusion_step
124
+
125
+ y = self.conv_layer(y)
126
+
127
+ if condition is not None:
128
+ condition = self.condition_projection(condition)
129
+ y = y + condition
130
+
131
+ gate, filter = torch.chunk(y, 2, dim=1)
132
+ y = torch.sigmoid(gate) * torch.tanh(filter)
133
+
134
+ y = self.output_projection(y)
135
+ residual, skip = torch.chunk(y, 2, dim=1)
136
+
137
+ return (x + residual) / math.sqrt(2.0), skip
138
+
139
+
140
+ class WaveNet(nn.Module):
141
+ def __init__(
142
+ self,
143
+ input_channels: Optional[int] = None,
144
+ output_channels: Optional[int] = None,
145
+ residual_channels: int = 512,
146
+ residual_layers: int = 20,
147
+ dilation_cycle: Optional[int] = 4,
148
+ is_diffusion: bool = False,
149
+ condition_channels: Optional[int] = None,
150
+ ):
151
+ super().__init__()
152
+
153
+ # Input projection
154
+ self.input_projection = None
155
+ if input_channels is not None and input_channels != residual_channels:
156
+ self.input_projection = ConvNorm(
157
+ input_channels, residual_channels, kernel_size=1
158
+ )
159
+
160
+ if input_channels is None:
161
+ input_channels = residual_channels
162
+
163
+ self.input_channels = input_channels
164
+
165
+ # Residual layers
166
+ self.residual_layers = nn.ModuleList(
167
+ [
168
+ ResidualBlock(
169
+ residual_channels=residual_channels,
170
+ use_linear_bias=False,
171
+ dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
172
+ condition_channels=condition_channels,
173
+ )
174
+ for i in range(residual_layers)
175
+ ]
176
+ )
177
+
178
+ # Skip projection
179
+ self.skip_projection = ConvNorm(
180
+ residual_channels, residual_channels, kernel_size=1
181
+ )
182
+
183
+ # Output projection
184
+ self.output_projection = None
185
+ if output_channels is not None and output_channels != residual_channels:
186
+ self.output_projection = ConvNorm(
187
+ residual_channels, output_channels, kernel_size=1
188
+ )
189
+
190
+ if is_diffusion:
191
+ self.diffusion_embedding = DiffusionEmbedding(residual_channels)
192
+ self.mlp = nn.Sequential(
193
+ LinearNorm(residual_channels, residual_channels * 4, False),
194
+ Mish(),
195
+ LinearNorm(residual_channels * 4, residual_channels, False),
196
+ )
197
+
198
+ self.apply(self._init_weights)
199
+
200
+ def _init_weights(self, m):
201
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
202
+ nn.init.trunc_normal_(m.weight, std=0.02)
203
+ if getattr(m, "bias", None) is not None:
204
+ nn.init.constant_(m.bias, 0)
205
+
206
+ def forward(self, x, t=None, condition=None):
207
+ if self.input_projection is not None:
208
+ x = self.input_projection(x)
209
+ x = F.silu(x)
210
+
211
+ if t is not None:
212
+ t = self.diffusion_embedding(t)
213
+ t = self.mlp(t)
214
+
215
+ skip = []
216
+ for layer in self.residual_layers:
217
+ x, skip_connection = layer(x, condition, t)
218
+ skip.append(skip_connection)
219
+
220
+ x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
221
+ x = self.skip_projection(x)
222
+
223
+ if self.output_projection is not None:
224
+ x = F.silu(x)
225
+ x = self.output_projection(x)
226
+
227
+ return x
modules/finetune/train_gpt.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch
3
+ import transformers
4
+ import peft
5
+ from transformers.trainer_pt_utils import LabelSmoother
6
+ from utils.dataset import AudioCollator
7
+ from utils.logger import MetricLogger
8
+ from utils.output import ansi, get_ansi_len, output_iter
9
+
10
+ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
11
+
12
+
13
+ def train_gpt_lora(
14
+ chat,
15
+ dataset,
16
+ decoder_encoder,
17
+ dvae_encoder,
18
+ batch_size=16,
19
+ epochs=10,
20
+ train_text=True,
21
+ speaker_embeds=None,
22
+ lora_r=8,
23
+ lora_alpha=16,
24
+ ):
25
+ if speaker_embeds is None:
26
+ speaker_embeds = {}
27
+
28
+ tokenizer = chat.pretrain_models["tokenizer"]
29
+ decoder_decoder = chat.pretrain_models["decoder"]
30
+ decoder_decoder.eval().requires_grad_(False)
31
+ decoder_encoder.to(device=dataset.device).eval().requires_grad_(False)
32
+ dvae_decoder = chat.pretrain_models["dvae"]
33
+ dvae_decoder.eval().requires_grad_(False)
34
+ dvae_encoder.to(device=dataset.device).eval().requires_grad_(False)
35
+
36
+ gpt = chat.pretrain_models["gpt"]
37
+ gpt.train().requires_grad_()
38
+
39
+ # Add LoRA to GPT model
40
+ lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha)
41
+ gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config)
42
+
43
+ speaker_embeds = {
44
+ speaker: torch.randn(768, device=dataset.device, requires_grad=True)
45
+ for speaker in dataset.speakers
46
+ } | speaker_embeds
47
+
48
+ for speaker_embed in speaker_embeds.values():
49
+ std, mean = chat.pretrain_models["spk_stat"].chunk(2)
50
+ speaker_embed.data = speaker_embed.data * std + mean
51
+
52
+ SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
53
+ AUDIO_EOS_TOKEN_ID = 0
54
+ AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
55
+
56
+ train_params = list(gpt.parameters()) + list(speaker_embeds.values())
57
+ optimizer = torch.optim.Adam(
58
+ gpt.parameters(), lr=1e-3, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
59
+ )
60
+ optimizer.add_param_group({"params": speaker_embeds.values(), "lr": 1e-1})
61
+
62
+ loss_fn = torch.nn.CrossEntropyLoss()
63
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
64
+
65
+ loader = torch.utils.data.DataLoader(
66
+ dataset,
67
+ batch_size=batch_size,
68
+ shuffle=True,
69
+ collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
70
+ )
71
+ logger = MetricLogger()
72
+ logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
73
+
74
+ for _epoch in range(epochs):
75
+ _epoch += 1
76
+ logger.reset()
77
+ header = "{blue_light}{0}: {1}{reset}".format(
78
+ "Epoch", output_iter(_epoch, epochs), **ansi
79
+ )
80
+ header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
81
+ iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
82
+
83
+ for batch in iterator:
84
+ speakers = batch["speaker"]
85
+ text_input_ids = batch["text_input_ids"]
86
+ text_attention_mask = batch["text_attention_mask"]
87
+ audio_mel_specs = batch["audio_mel_specs"]
88
+ audio_attention_mask = batch["audio_attention_mask"]
89
+
90
+ batch_size, text_len = text_attention_mask.size()
91
+
92
+ dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
93
+ _, dvae_audio_input_ids = quantize(
94
+ dvae_decoder.vq_layer.quantizer, dvae_audio_latents
95
+ )
96
+ dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
97
+
98
+ extended_audio_attention_mask = torch.cat(
99
+ [
100
+ audio_attention_mask,
101
+ torch.zeros(
102
+ (batch_size, 1),
103
+ dtype=audio_attention_mask.dtype,
104
+ device=audio_attention_mask.device,
105
+ ),
106
+ ],
107
+ dim=1,
108
+ )
109
+ extended_audio_input_ids = torch.cat(
110
+ [
111
+ dvae_audio_input_ids,
112
+ AUDIO_PAD_TOKEN_ID
113
+ * torch.ones(
114
+ (batch_size, 1, gpt.num_vq),
115
+ dtype=dvae_audio_input_ids.dtype,
116
+ device=dvae_audio_input_ids.device,
117
+ ),
118
+ ],
119
+ dim=1,
120
+ )
121
+
122
+ indices = audio_attention_mask.int().sum(dim=1)
123
+ for i in range(batch_size):
124
+ extended_audio_attention_mask[i, indices[i]] = 1
125
+ extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
126
+
127
+ input_ids = torch.cat(
128
+ [
129
+ text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
130
+ extended_audio_input_ids,
131
+ ],
132
+ dim=1,
133
+ )
134
+ attention_mask = torch.cat(
135
+ [text_attention_mask, extended_audio_attention_mask], dim=1
136
+ )
137
+ text_mask = torch.cat(
138
+ [
139
+ torch.ones_like(text_attention_mask, dtype=bool),
140
+ torch.zeros_like(extended_audio_attention_mask, dtype=bool),
141
+ ],
142
+ dim=1,
143
+ )
144
+ labels = input_ids.clone()
145
+ labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
146
+
147
+ inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
148
+
149
+ indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
150
+ for i, speaker in enumerate(speakers):
151
+ inputs_embeds[i, indices[i]] = torch.nn.functional.normalize(
152
+ speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
153
+ p=2.0,
154
+ dim=-1,
155
+ eps=1e-12,
156
+ ).unsqueeze(0)
157
+
158
+ outputs = gpt.gpt.forward(
159
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask
160
+ )
161
+ hidden_states = outputs.last_hidden_state
162
+ text_hidden_states = hidden_states[:, : text_len - 1]
163
+ audio_hidden_states = hidden_states[:, text_len - 1 : -1]
164
+
165
+ audio_logits = torch.stack(
166
+ [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
167
+ dim=2,
168
+ )
169
+ audio_loss = loss_fn(
170
+ audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
171
+ )
172
+ loss = audio_loss
173
+
174
+ if train_text:
175
+ text_logits = gpt.head_text(text_hidden_states)
176
+ text_loss = loss_fn(
177
+ text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
178
+ )
179
+ loss += text_loss
180
+ logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
181
+
182
+ gpt_gen_mel_specs = decoder_decoder(
183
+ audio_hidden_states[:, :-1].transpose(1, 2)
184
+ ).transpose(1, 2)
185
+ mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
186
+ loss += 0.01 * mse_loss
187
+
188
+ optimizer.zero_grad()
189
+ loss.backward()
190
+ torch.nn.utils.clip_grad_norm_(train_params, 1.0)
191
+ optimizer.step()
192
+
193
+ logger.meters["loss"].update(loss.item(), n=batch_size)
194
+ logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
195
+ logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
196
+
197
+ lr_scheduler.step()
198
+ optimizer.zero_grad()
199
+ return speaker_embeds
200
+
201
+
202
+ # Example usage
203
+ def main():
204
+ # Load necessary models and data paths
205
+ chat = ChatTTS.Chat()
206
+ chat.load_models()
207
+ dataset = XzListTar(
208
+ root="data/all.list",
209
+ tokenizer=chat.pretrain_models["tokenizer"],
210
+ vocos_model=chat.pretrain_models["vocos"],
211
+ tar_path="data/Xz.tar",
212
+ tar_in_memory=True,
213
+ process_ahead=True,
214
+ )
215
+
216
+ decoder_encoder = DVAEEncoder(
217
+ **get_encoder_config(chat.pretrain_models["decoder"].decoder)
218
+ )
219
+ dvae_encoder = DVAEEncoder(
220
+ **get_encoder_config(chat.pretrain_models["dvae"].decoder)
221
+ )
222
+
223
+ # Train GPT with LoRA
224
+ speaker_embeds = train_gpt_lora(
225
+ chat=chat,
226
+ dataset=dataset,
227
+ decoder_encoder=decoder_encoder,
228
+ dvae_encoder=dvae_encoder,
229
+ batch_size=32,
230
+ epochs=10,
231
+ train_text=True,
232
+ lora_r=8,
233
+ lora_alpha=16,
234
+ )
235
+
236
+ # Save LoRA parameters and embeddings
237
+ lora_save_path = "./saved_models/gpt_lora.pth"
238
+ peft.save_pretrained(gpt.gpt, lora_save_path)
239
+ np.savez(
240
+ "./saved_models/speaker_embeds.npz",
241
+ **{k: v.cpu().numpy() for k, v in speaker_embeds.items()}
242
+ )
243
+
244
+
245
+ if __name__ == "__main__":
246
+ main()
modules/finetune/train_speaker.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import transformers
4
+
5
+ from modules.finetune.model.encoder import DVAEEncoder, get_encoder_config
6
+ from modules.finetune.utils.output import get_ansi_len, output_iter, ansi
7
+ from .utils.logger import MetricLogger
8
+ from .utils.dataset import AudioCollator, XzListTar
9
+ from .utils.model import quantize
10
+
11
+ IGNORE_TOKEN_ID = transformers.trainer_pt_utils.LabelSmoother.ignore_index
12
+
13
+
14
+ def train_speaker_embeddings(
15
+ chat,
16
+ dataset,
17
+ gpt,
18
+ batch_size=16,
19
+ epochs=10,
20
+ train_text=True,
21
+ speaker_embeds=None,
22
+ ):
23
+ tokenizer = chat.pretrain_models["tokenizer"]
24
+
25
+ decoder_decoder = chat.pretrain_models["decoder"]
26
+ decoder_decoder.eval().requires_grad_(False)
27
+ decoder_encoder = DVAEEncoder(**get_encoder_config(decoder_decoder.decoder)).to(
28
+ device=dataset.device
29
+ )
30
+ decoder_encoder.eval().requires_grad_(False)
31
+
32
+ dvae_decoder = chat.pretrain_models["dvae"]
33
+ dvae_decoder.eval().requires_grad_(False)
34
+ dvae_encoder = DVAEEncoder(**get_encoder_config(dvae_decoder.decoder)).to(
35
+ device=dataset.device
36
+ )
37
+ dvae_encoder.eval().requires_grad_(False)
38
+
39
+ if speaker_embeds is None:
40
+ speaker_embeds = {
41
+ speaker: torch.randn(
42
+ 768,
43
+ device=dataset.device,
44
+ requires_grad=True,
45
+ )
46
+ for speaker in dataset.speakers
47
+ }
48
+ for speaker_embed in speaker_embeds.values():
49
+ std, mean = chat.pretrain_models["spk_stat"].chunk(2)
50
+ speaker_embed.data = speaker_embed.data * std + mean
51
+
52
+ SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
53
+ AUDIO_EOS_TOKEN_ID = 0
54
+ AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID
55
+
56
+ optimizer = torch.optim.Adam(
57
+ speaker_embeds.values(), lr=1e-2, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
58
+ )
59
+ loss_fn = torch.nn.CrossEntropyLoss()
60
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)
61
+
62
+ loader = torch.utils.data.DataLoader(
63
+ dataset,
64
+ batch_size=batch_size,
65
+ shuffle=True,
66
+ collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
67
+ )
68
+ logger = MetricLogger()
69
+ logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)
70
+
71
+ for _epoch in range(epochs):
72
+ _epoch += 1
73
+ logger.reset()
74
+ header = "{blue_light}{0}: {1}{reset}".format(
75
+ "Epoch", output_iter(_epoch, epochs), **ansi
76
+ )
77
+ header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
78
+ iterator = logger.log_every(loader, header=header, tqdm_header="Batch")
79
+
80
+ for batch in iterator:
81
+ speakers = batch["speaker"]
82
+ text_input_ids = batch["text_input_ids"]
83
+ text_attention_mask = batch["text_attention_mask"]
84
+ audio_mel_specs = batch["audio_mel_specs"]
85
+ audio_attention_mask = batch["audio_attention_mask"]
86
+
87
+ batch_size, text_len = text_attention_mask.size()
88
+
89
+ dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
90
+ _, dvae_audio_input_ids = quantize(
91
+ dvae_decoder.vq_layer.quantizer, dvae_audio_latents
92
+ )
93
+ dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
94
+
95
+ extended_audio_attention_mask = torch.cat(
96
+ [
97
+ audio_attention_mask,
98
+ torch.zeros(
99
+ (batch_size, 1),
100
+ dtype=audio_attention_mask.dtype,
101
+ device=audio_attention_mask.device,
102
+ ),
103
+ ],
104
+ dim=1,
105
+ )
106
+ extended_audio_input_ids = torch.cat(
107
+ [
108
+ dvae_audio_input_ids,
109
+ AUDIO_PAD_TOKEN_ID
110
+ * torch.ones(
111
+ (batch_size, 1, gpt.num_vq),
112
+ dtype=dvae_audio_input_ids.dtype,
113
+ device=dvae_audio_input_ids.device,
114
+ ),
115
+ ],
116
+ dim=1,
117
+ )
118
+ indices = audio_attention_mask.int().sum(dim=1)
119
+ for i in range(batch_size):
120
+ extended_audio_attention_mask[i, indices[i]] = 1
121
+ extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
122
+
123
+ input_ids = torch.cat(
124
+ [
125
+ text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
126
+ extended_audio_input_ids,
127
+ ],
128
+ dim=1,
129
+ )
130
+ attention_mask = torch.cat(
131
+ [text_attention_mask, extended_audio_attention_mask], dim=1
132
+ )
133
+ text_mask = torch.cat(
134
+ [
135
+ torch.ones_like(text_attention_mask, dtype=bool),
136
+ torch.zeros_like(extended_audio_attention_mask, dtype=bool),
137
+ ],
138
+ dim=1,
139
+ )
140
+
141
+ labels = input_ids.clone()
142
+ labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
143
+
144
+ inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)
145
+
146
+ indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
147
+ for i, speaker in enumerate(speakers):
148
+ inputs_embeds[i, indices[i]] = F.normalize(
149
+ speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
150
+ p=2.0,
151
+ dim=-1,
152
+ eps=1e-12,
153
+ ).unsqueeze(0)
154
+ outputs = gpt.gpt.forward(
155
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask
156
+ )
157
+ hidden_states = outputs.last_hidden_state
158
+ text_hidden_states = hidden_states[:, : text_len - 1]
159
+ audio_hidden_states = hidden_states[:, text_len - 1 : -1]
160
+
161
+ audio_logits = torch.stack(
162
+ [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
163
+ dim=2,
164
+ )
165
+ audio_loss = loss_fn(
166
+ audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
167
+ )
168
+ loss = audio_loss
169
+ if train_text:
170
+ text_logits = gpt.head_text(text_hidden_states)
171
+ text_loss = loss_fn(
172
+ text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
173
+ )
174
+ loss += text_loss
175
+ logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
176
+
177
+ gpt_gen_mel_specs = decoder_decoder(
178
+ audio_hidden_states[:, :-1].transpose(1, 2)
179
+ ).transpose(1, 2)
180
+ mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
181
+ loss += 0.01 * mse_loss
182
+
183
+ optimizer.zero_grad()
184
+ loss.backward()
185
+ torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
186
+ optimizer.step()
187
+ logger.meters["loss"].update(loss.item(), n=batch_size)
188
+ logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
189
+ logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)
190
+ lr_scheduler.step()
191
+ optimizer.zero_grad()
192
+ return speaker_embeds
193
+
194
+
195
+ if __name__ == "__main__":
196
+ import argparse
197
+ import os
198
+ import numpy as np
199
+ import pathlib
200
+ from modules.models import load_chat_tts
201
+ from modules.devices import devices
202
+ from modules import config
203
+ from modules.speaker import Speaker
204
+
205
+ config.runtime_env_vars.no_half = True
206
+ devices.reset_device()
207
+
208
+ parser = argparse.ArgumentParser()
209
+ parser.add_argument("--save_folder", type=str, default="./")
210
+ parser.add_argument("--batch_size", type=int, default=16)
211
+ parser.add_argument("--epochs", type=int, default=100)
212
+ parser.add_argument("--train_text", action="store_true", help="train text loss")
213
+ # 初始化 speaker
214
+ parser.add_argument("--init_speaker", type=str)
215
+ parser.add_argument(
216
+ "--data_path",
217
+ type=str,
218
+ default="datasets/data_speaker_a/speaker_a.list",
219
+ help="the data_path to json/list file",
220
+ )
221
+ parser.add_argument("--tar_path", type=str, help="the tarball path with wavs")
222
+ parser.add_argument(
223
+ "--tar_in_memory", action="store_true", help="load tarball in memory"
224
+ )
225
+
226
+ args = parser.parse_args()
227
+
228
+ data_path: str = args.data_path
229
+ tar_path: str | None = args.tar_path
230
+ tar_in_memory: bool = args.tar_in_memory
231
+ train_text: bool = args.train_text
232
+ # gpt_lora: bool = args.gpt_lora
233
+ # gpt_kbit: int = args.gpt_kbit
234
+ save_folder: str = args.save_folder
235
+ batch_size: int = args.batch_size
236
+ epochs: int = args.epochs
237
+ init_speaker: str = args.init_speaker
238
+
239
+ speaker_embeds_save_path = os.path.join(save_folder, "speaker_embeds.npz")
240
+
241
+ chat = load_chat_tts()
242
+ dataset = XzListTar(
243
+ root=data_path,
244
+ tokenizer=chat.pretrain_models["tokenizer"],
245
+ vocos_model=chat.pretrain_models["vocos"],
246
+ tar_path=tar_path,
247
+ tar_in_memory=tar_in_memory,
248
+ device=devices.device,
249
+ # speakers=None, # set(['speaker_A', 'speaker_B'])
250
+ )
251
+
252
+ print("len(dataset)", len(dataset))
253
+
254
+ speaker_embeds = None
255
+ if init_speaker:
256
+ spk: Speaker = Speaker.from_file(init_speaker)
257
+ speaker_embeds = {
258
+ speaker: torch.tensor(
259
+ spk.emb,
260
+ device=devices.device,
261
+ requires_grad=True,
262
+ )
263
+ for speaker in dataset.speakers
264
+ }
265
+
266
+ speaker_embeds = train_speaker_embeddings(
267
+ chat,
268
+ dataset,
269
+ chat.pretrain_models["gpt"],
270
+ batch_size=batch_size,
271
+ epochs=epochs,
272
+ train_text=train_text,
273
+ speaker_embeds=speaker_embeds,
274
+ )
275
+ speaker_outs = {
276
+ speaker: Speaker(speaker_embed.detach().cpu(), f"ep{epochs}_{speaker}")
277
+ for speaker, speaker_embed in speaker_embeds.items()
278
+ }
279
+ time_str = np.datetime_as_string(np.datetime64("now", "s"))
280
+ time_str = time_str.replace(":", "_").replace(" ", "_").replace("-", "_")
281
+ for speaker, speaker_out in speaker_outs.items():
282
+ torch.save(
283
+ speaker_out,
284
+ pathlib.Path(save_folder) / f"spk_{speaker}_{time_str}_ep{epochs}.pt",
285
+ )
286
+
287
+ # example
288
+ """
289
+ python -m modules.finetune.train_speaker \
290
+ --data_path datasets/data_speaker_a/speaker_a.list \
291
+ --save_folder ./data \
292
+ --init_speaker ./data/speakers/Bob.pt \
293
+ --epochs 100 \
294
+ --batch_size 6 \
295
+ --train_text
296
+ """
modules/finetune/utils/__init__.py ADDED
File without changes
modules/finetune/utils/dataset.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import functools
3
+ import json
4
+ import tarfile
5
+ import io
6
+ import logging
7
+ import abc
8
+ import typing
9
+
10
+ import torch.utils.data
11
+ import torchaudio
12
+ from torchvision.datasets.utils import download_url
13
+ import transformers
14
+ import vocos
15
+
16
+ from modules.ChatTTS.ChatTTS.utils.infer_utils import (
17
+ count_invalid_characters,
18
+ apply_character_map,
19
+ )
20
+
21
+
22
+ class LazyDataType(typing.TypedDict):
23
+ filepath: str
24
+ speaker: str
25
+ lang: str
26
+ text: str
27
+
28
+
29
+ class DataType(LazyDataType):
30
+ text_input_ids: torch.Tensor # (batch_size, text_len)
31
+ text_attention_mask: torch.Tensor # (batch_size, text_len)
32
+ audio_mel_specs: torch.Tensor # (batch_size, audio_len*2, 100)
33
+ audio_attention_mask: torch.Tensor # (batch_size, audio_len)
34
+
35
+
36
+ class XzListTarKwargsType(typing.TypedDict):
37
+ tokenizer: typing.Union[transformers.PreTrainedTokenizer, None]
38
+ vocos_model: typing.Union[vocos.Vocos, None]
39
+ device: typing.Union[str, torch.device, None]
40
+ speakers: typing.Union[typing.Iterable[str], None]
41
+ sample_rate: typing.Union[int]
42
+ default_speaker: typing.Union[str, None]
43
+ default_lang: typing.Union[str, None]
44
+ tar_in_memory: typing.Union[bool, None]
45
+ process_ahead: typing.Union[bool, None]
46
+
47
+
48
+ class AudioFolder(torch.utils.data.Dataset, abc.ABC):
49
+ def __init__(
50
+ self,
51
+ root: str | io.BytesIO,
52
+ tokenizer: transformers.PreTrainedTokenizer | None = None,
53
+ vocos_model: vocos.Vocos | None = None,
54
+ device: str | torch.device | None = None,
55
+ speakers: typing.Iterable[str] | None = None,
56
+ sample_rate: int = 24_000,
57
+ default_speaker: str | None = None,
58
+ default_lang: str | None = None,
59
+ tar_path: str | None = None,
60
+ tar_in_memory: bool = False,
61
+ process_ahead: bool = False,
62
+ ) -> None:
63
+ self.root = root
64
+ self.sample_rate = sample_rate
65
+ self.default_speaker = default_speaker
66
+ self.default_lang = default_lang
67
+
68
+ self.logger = logging.getLogger(__name__)
69
+ self.normalizer = {}
70
+
71
+ self.tokenizer = tokenizer
72
+ self.vocos = vocos_model
73
+ self.vocos_device = (
74
+ None if self.vocos is None else next(self.vocos.parameters()).device
75
+ )
76
+ self.device = device or self.vocos_device
77
+
78
+ # tar -cvf ../Xz.tar *
79
+ # tar -xf Xz.tar -C ./Xz
80
+ self.tar_path = tar_path
81
+ self.tar_file = None
82
+ self.tar_io = None
83
+ if tar_path is not None:
84
+ if tar_in_memory:
85
+ with open(tar_path, "rb") as f:
86
+ self.tar_io = io.BytesIO(f.read())
87
+ self.tar_file = tarfile.open(fileobj=self.tar_io)
88
+ else:
89
+ self.tar_file = tarfile.open(tar_path)
90
+
91
+ self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)
92
+
93
+ self.text_input_ids: dict[int, torch.Tensor] = {}
94
+ self.audio_mel_specs: dict[int, torch.Tensor] = {}
95
+ if process_ahead:
96
+ for n, item in enumerate(self.lazy_data):
97
+ self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
98
+ self.text_input_ids[n] = self.preprocess_text(
99
+ item["text"], item["lang"]
100
+ )
101
+ if self.tar_file is not None:
102
+ self.tar_file.close()
103
+ if self.tar_io is not None:
104
+ self.tar_io.close()
105
+
106
+ @abc.abstractmethod
107
+ def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ...
108
+
109
+ @staticmethod
110
+ @abc.abstractmethod
111
+ def save_config(
112
+ save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
113
+ ) -> None: ...
114
+
115
+ def __len__(self):
116
+ return len(self.lazy_data)
117
+
118
+ def __getitem__(self, n: int) -> DataType:
119
+ lazy_data = self.lazy_data[n]
120
+ if n in self.audio_mel_specs:
121
+ audio_mel_specs = self.audio_mel_specs[n]
122
+ text_input_ids = self.text_input_ids[n]
123
+ else:
124
+ audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
125
+ text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
126
+ self.audio_mel_specs[n] = audio_mel_specs
127
+ self.text_input_ids[n] = text_input_ids
128
+ if len(self.audio_mel_specs) == len(self.lazy_data):
129
+ if self.tar_file is not None:
130
+ self.tar_file.close()
131
+ if self.tar_io is not None:
132
+ self.tar_io.close()
133
+ text_attention_mask = torch.ones(
134
+ len(text_input_ids), device=text_input_ids.device
135
+ )
136
+ audio_attention_mask = torch.ones(
137
+ (len(audio_mel_specs) + 1) // 2,
138
+ device=audio_mel_specs.device,
139
+ )
140
+ return {
141
+ "filepath": lazy_data["filepath"],
142
+ "speaker": lazy_data["speaker"],
143
+ "lang": lazy_data["lang"],
144
+ "text": lazy_data["text"],
145
+ "text_input_ids": text_input_ids,
146
+ "text_attention_mask": text_attention_mask,
147
+ "audio_mel_specs": audio_mel_specs,
148
+ "audio_attention_mask": audio_attention_mask,
149
+ }
150
+
151
+ def get_lazy_data(
152
+ self,
153
+ root: str | io.BytesIO,
154
+ speakers: typing.Iterable[str] | None = None,
155
+ ) -> tuple[list[LazyDataType], set[str]]:
156
+ if speakers is not None:
157
+ new_speakers = set(speakers)
158
+ else:
159
+ new_speakers = set()
160
+ lazy_data = []
161
+
162
+ raw_data = self.get_raw_data(root)
163
+ folder_path = os.path.dirname(root) if isinstance(root, str) else ""
164
+ for item in raw_data:
165
+ if "speaker" not in item:
166
+ item["speaker"] = self.default_speaker
167
+ if "lang" not in item:
168
+ item["lang"] = self.default_lang
169
+
170
+ if speakers is not None and item["speaker"] not in speakers:
171
+ continue
172
+ if speakers is None and item["speaker"] not in new_speakers:
173
+ new_speakers.add(item["speaker"])
174
+ if self.tar_file is None and isinstance(root, str):
175
+ filepath = os.path.join(folder_path, item["filepath"])
176
+ else:
177
+ filepath = item["filepath"]
178
+ lazy_data.append(
179
+ {
180
+ "filepath": filepath,
181
+ "speaker": item["speaker"],
182
+ "lang": item["lang"].lower(),
183
+ "text": item["text"],
184
+ }
185
+ )
186
+ return lazy_data, new_speakers
187
+
188
+ def preprocess_text(
189
+ self,
190
+ text: str,
191
+ lang: str,
192
+ ) -> torch.Tensor:
193
+ invalid_characters = count_invalid_characters(text)
194
+ if len(invalid_characters):
195
+ # self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
196
+ text = apply_character_map(text)
197
+
198
+ # if not skip_refine_text:
199
+ # text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
200
+ # text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
201
+ # text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
202
+ # if refine_text_only:
203
+ # return text
204
+
205
+ text = f"[Stts][spk_emb]{text}[Ptts]"
206
+ # text = f'[Stts][empty_spk]{text}[Ptts]'
207
+
208
+ text_token = self.tokenizer(
209
+ text, return_tensors="pt", add_special_tokens=False
210
+ ).to(device=self.device)
211
+ return text_token["input_ids"].squeeze(0)
212
+
213
+ def preprocess_audio(self, filepath: str) -> torch.Tensor:
214
+ if self.tar_file is not None:
215
+ file = self.tar_file.extractfile(filepath)
216
+ waveform, sample_rate = torchaudio.load(file)
217
+ else:
218
+ waveform, sample_rate = torchaudio.load(filepath)
219
+ waveform = waveform.to(device=self.vocos_device)
220
+ if sample_rate != self.sample_rate:
221
+ waveform = torchaudio.functional.resample(
222
+ waveform,
223
+ orig_freq=sample_rate,
224
+ new_freq=self.sample_rate,
225
+ )
226
+ mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
227
+ return (
228
+ mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
229
+ ) # (audio_len*2, 100)
230
+
231
+
232
+ class JsonFolder(AudioFolder):
233
+ """
234
+ In json file, each item is formatted as following example:
235
+ `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.
236
+
237
+ filepath is relative to the dirname of root json file.
238
+ """
239
+
240
+ def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
241
+ with open(root, "r", encoding="utf-8") as f:
242
+ raw_data = json.load(f)
243
+ return raw_data
244
+
245
+ @staticmethod
246
+ def save_config(
247
+ save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
248
+ ) -> None:
249
+ save_data = [item.copy() for item in lazy_data]
250
+ for item in save_data:
251
+ item["filepath"] = os.path.relpath(item["filepath"], rel_path)
252
+ with open(save_path, "w", encoding="utf-8") as f:
253
+ json.dump(save_data, f, ensure_ascii=False, indent=4)
254
+
255
+
256
+ class ListFolder(AudioFolder):
257
+ """
258
+ In list file, each row is formatted as `filepath|speaker|lang|text` with `|` as separator.
259
+ `path/to/file.wav|John|ZH|Hello`.
260
+
261
+ filepath is relative to the dirname of root list file.
262
+ """
263
+
264
+ def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
265
+ raw_data = []
266
+ with open(root, "r", encoding="utf-8") as f:
267
+ for line in f.readlines():
268
+ line = line.strip().removesuffix("\n")
269
+ if len(line) == 0:
270
+ continue
271
+ filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
272
+ raw_data.append(
273
+ {
274
+ "text": text,
275
+ "filepath": filepath,
276
+ "speaker": speaker,
277
+ "lang": lang,
278
+ }
279
+ )
280
+ return raw_data
281
+
282
+ @staticmethod
283
+ def save_config(
284
+ save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
285
+ ) -> None:
286
+ save_data = [item.copy() for item in lazy_data]
287
+ for item in save_data:
288
+ item["filepath"] = os.path.relpath(item["filepath"], rel_path)
289
+ with open(save_path, "w", encoding="utf-8") as f:
290
+ for item in save_data:
291
+ f.write(
292
+ f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
293
+ )
294
+
295
+
296
+ class XzListTar(ListFolder):
297
+ def __init__(
298
+ self,
299
+ *args,
300
+ root: str | io.BytesIO,
301
+ tar_path: str | None = None,
302
+ **kwargs,
303
+ ):
304
+ if isinstance(root, io.BytesIO):
305
+ assert tar_path is not None
306
+ else:
307
+ # make sure root is a list file
308
+ if not root.endswith(".list"): # folder case
309
+ if os.path.isfile(root):
310
+ raise FileExistsError(f"{root} is a file!")
311
+ elif not os.path.exists(root):
312
+ os.makedirs(root)
313
+ root = os.path.join(root, "all.list")
314
+ if isinstance(root, str) and not os.path.isfile(root):
315
+ # prepare all.list
316
+ self.concat_dataset(
317
+ save_folder=os.path.dirname(root),
318
+ langs=kwargs.get("langs", ["zh", "en"]),
319
+ )
320
+
321
+ super().__init__(root, *args, tar_path=tar_path, **kwargs)
322
+
323
+ def concat_dataset(
324
+ self, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
325
+ ) -> None:
326
+ if save_folder is None:
327
+ save_folder = os.path.dirname(self.root)
328
+ if os.path.isfile(save_folder):
329
+ raise FileExistsError(f"{save_folder} already exists as a file!")
330
+ elif not os.path.exists(save_folder):
331
+ os.makedirs(save_folder)
332
+ lazy_data = []
333
+
334
+ for member in self.tar_file.getmembers():
335
+ if not member.isfile():
336
+ continue
337
+ if member.name.endswith(".list"):
338
+ print(member.name)
339
+ root_io = self.tar_file.extractfile(member)
340
+ lazy_data += ListFolder(root_io).lazy_data
341
+ if member.name.endswith(".json"):
342
+ print(member.name)
343
+ root_io = self.tar_file.extractfile(member)
344
+ lazy_data += JsonFolder(root_io).lazy_data
345
+ if langs is not None:
346
+ lazy_data = [item for item in lazy_data if item["lang"] in langs]
347
+ ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
348
+ JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
349
+ print(f"all.list and all.json are saved to {save_folder}")
350
+
351
+
352
+ class XzListFolder(ListFolder):
353
+ """
354
+ [Xz乔希](https://space.bilibili.com/5859321)
355
+
356
+ Only look at the basename of filepath in list file. Previous folder paths are ignored.
357
+ Files are organized as `[list basename]/[file basename]`
358
+
359
+ Example tree structure:
360
+
361
+ [folder]
362
+ ├── speaker_A
363
+ │ ├── 1.wav
364
+ │ └── 2.wav
365
+ ├── speaker_A.list
366
+ ├── speaker_B
367
+ │ ├── 1.wav
368
+ │ └── 2.wav
369
+ └── speaker_B.list
370
+ """
371
+
372
+ def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
373
+ raw_data = super().get_raw_data(root)
374
+ for item in raw_data:
375
+ item["filepath"] = os.path.join(
376
+ os.path.basename(root).removesuffix(".list"),
377
+ os.path.basename(item["filepath"]),
378
+ )
379
+ return raw_data
380
+
381
+
382
+ class AudioCollator:
383
+ def __init__(self, text_pad: int = 0, audio_pad: int = 0):
384
+ self.text_pad = text_pad
385
+ self.audio_pad = audio_pad
386
+
387
+ def __call__(self, batch: list[DataType]):
388
+ batch = [x for x in batch if x is not None]
389
+
390
+ audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch)
391
+ text_maxlen = max(len(item["text_attention_mask"]) for item in batch)
392
+
393
+ filepath = []
394
+ speaker = []
395
+ lang = []
396
+ text = []
397
+ text_input_ids = []
398
+ text_attention_mask = []
399
+ audio_mel_specs = []
400
+ audio_attention_mask = []
401
+
402
+ for x in batch:
403
+ filepath.append(x["filepath"])
404
+ speaker.append(x["speaker"])
405
+ lang.append(x["lang"])
406
+ text.append(x["text"])
407
+ text_input_ids.append(
408
+ torch.nn.functional.pad(
409
+ x["text_input_ids"],
410
+ (text_maxlen - len(x["text_input_ids"]), 0),
411
+ value=self.text_pad,
412
+ )
413
+ )
414
+ text_attention_mask.append(
415
+ torch.nn.functional.pad(
416
+ x["text_attention_mask"],
417
+ (text_maxlen - len(x["text_attention_mask"]), 0),
418
+ value=0,
419
+ )
420
+ )
421
+ audio_mel_specs.append(
422
+ torch.nn.functional.pad(
423
+ x["audio_mel_specs"],
424
+ (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
425
+ value=self.audio_pad,
426
+ )
427
+ )
428
+ audio_attention_mask.append(
429
+ torch.nn.functional.pad(
430
+ x["audio_attention_mask"],
431
+ (0, audio_maxlen - len(x["audio_attention_mask"])),
432
+ value=0,
433
+ )
434
+ )
435
+ return {
436
+ "filepath": filepath,
437
+ "speaker": speaker,
438
+ "lang": lang,
439
+ "text": text,
440
+ "text_input_ids": torch.stack(text_input_ids),
441
+ "text_attention_mask": torch.stack(text_attention_mask),
442
+ "audio_mel_specs": torch.stack(audio_mel_specs),
443
+ "audio_attention_mask": torch.stack(audio_attention_mask),
444
+ }
445
+
446
+
447
+ def formalize_xz_list(src_folder: str):
448
+ for root, _, files in os.walk(src_folder):
449
+ for file in files:
450
+ if file.endswith(".list"):
451
+ filepath = os.path.join(root, file)
452
+ print(filepath)
453
+ lazy_data = XzListFolder(filepath).lazy_data
454
+ XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder)
455
+
456
+
457
+ def concat_dataset(
458
+ src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
459
+ ) -> None:
460
+ if save_folder is None:
461
+ save_folder = src_folder
462
+ if os.path.isfile(save_folder):
463
+ raise FileExistsError(f"{save_folder} already exists as a file!")
464
+ elif not os.path.exists(save_folder):
465
+ os.makedirs(save_folder)
466
+ lazy_data = []
467
+ same_folder = os.path.samefile(src_folder, save_folder)
468
+ for root, _, files in os.walk(src_folder):
469
+ for file in files:
470
+ filepath = os.path.join(root, file)
471
+ if same_folder and file in ("all.list", "all.json"):
472
+ continue
473
+ if file.endswith(".list"):
474
+ print(filepath)
475
+ lazy_data += ListFolder(filepath).lazy_data
476
+ if file.endswith(".json"):
477
+ print(filepath)
478
+ lazy_data += JsonFolder(filepath).lazy_data
479
+ if langs is not None:
480
+ lazy_data = [item for item in lazy_data if item["lang"] in langs]
481
+ ListFolder.save_config(
482
+ os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder
483
+ )
484
+ JsonFolder.save_config(
485
+ os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder
486
+ )
487
+ print(f"all.list and all.json are saved to {save_folder}")
modules/finetune/utils/logger.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import statistics
4
+ import time
5
+ from collections import defaultdict, deque
6
+ from tqdm import tqdm as tqdm_class
7
+
8
+ from typing import Generator, Iterable, TypeVar
9
+ from typing_extensions import Self
10
+
11
+ import torch
12
+ import torch.distributed as dist
13
+
14
+ from .output import ansi, prints, get_ansi_len
15
+
16
+ __all__ = ["SmoothedValue", "MetricLogger"]
17
+
18
+ MB = 1 << 20
19
+ T = TypeVar("T")
20
+
21
+
22
+ class SmoothedValue:
23
+ r"""Track a series of values and provide access to smoothed values over a
24
+ window or the global series average.
25
+
26
+ See Also:
27
+ https://github.com/pytorch/vision/blob/main/references/classification/utils.py
28
+
29
+ Args:
30
+ name (str): Name string.
31
+ window_size (int): The :attr:`maxlen` of :class:`~collections.deque`.
32
+ fmt (str): The format pattern of ``str(self)``.
33
+
34
+ Attributes:
35
+ name (str): Name string.
36
+ fmt (str): The string pattern.
37
+ deque (~collections.deque): The unique data series.
38
+ count (int): The amount of data.
39
+ total (float): The sum of all data.
40
+
41
+ median (float): The median of :attr:`deque`.
42
+ avg (float): The avg of :attr:`deque`.
43
+ global_avg (float): :math:`\frac{\text{total}}{\text{count}}`
44
+ max (float): The max of :attr:`deque`.
45
+ min (float): The min of :attr:`deque`.
46
+ last_value (float): The last value of :attr:`deque`.
47
+ """
48
+
49
+ def __init__(
50
+ self, name: str = "", window_size: int = None, fmt: str = "{global_avg:.3f}"
51
+ ):
52
+ self.name = name
53
+ self.deque: deque[float] = deque(maxlen=window_size)
54
+ self.count: int = 0
55
+ self.total: float = 0.0
56
+ self.fmt = fmt
57
+
58
+ def update(self, value: float, n: int = 1) -> Self:
59
+ r"""Update :attr:`n` pieces of data with same :attr:`value`.
60
+
61
+ .. code-block:: python
62
+
63
+ self.deque.append(value)
64
+ self.total += value * n
65
+ self.count += n
66
+
67
+ Args:
68
+ value (float): the value to update.
69
+ n (int): the number of data with same :attr:`value`.
70
+
71
+ Returns:
72
+ SmoothedValue: return ``self`` for stream usage.
73
+ """
74
+ self.deque.append(value)
75
+ self.total += value * n
76
+ self.count += n
77
+ return self
78
+
79
+ def update_list(self, value_list: list[float]) -> Self:
80
+ r"""Update :attr:`value_list`.
81
+
82
+ .. code-block:: python
83
+
84
+ for value in value_list:
85
+ self.deque.append(value)
86
+ self.total += value
87
+ self.count += len(value_list)
88
+
89
+ Args:
90
+ value_list (list[float]): the value list to update.
91
+
92
+ Returns:
93
+ SmoothedValue: return ``self`` for stream usage.
94
+ """
95
+ for value in value_list:
96
+ self.deque.append(value)
97
+ self.total += value
98
+ self.count += len(value_list)
99
+ return self
100
+
101
+ def reset(self) -> Self:
102
+ r"""Reset ``deque``, ``count`` and ``total`` to be empty.
103
+
104
+ Returns:
105
+ SmoothedValue: return ``self`` for stream usage.
106
+ """
107
+ self.deque = deque(maxlen=self.deque.maxlen)
108
+ self.count = 0
109
+ self.total = 0.0
110
+ return self
111
+
112
+ def synchronize_between_processes(self):
113
+ r"""
114
+ Warning:
115
+ Does NOT synchronize the deque!
116
+ """
117
+ if not (dist.is_available() and dist.is_initialized()):
118
+ return
119
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
120
+ dist.barrier()
121
+ dist.all_reduce(t)
122
+ t = t.tolist()
123
+ self.count = int(t[0])
124
+ self.total = float(t[1])
125
+
126
+ @property
127
+ def median(self) -> float:
128
+ try:
129
+ return statistics.median(self.deque)
130
+ except Exception:
131
+ return 0.0
132
+
133
+ @property
134
+ def avg(self) -> float:
135
+ try:
136
+ return statistics.mean(self.deque)
137
+ except Exception:
138
+ return 0.0
139
+
140
+ @property
141
+ def global_avg(self) -> float:
142
+ try:
143
+ return self.total / self.count
144
+ except Exception:
145
+ return 0.0
146
+
147
+ @property
148
+ def max(self) -> float:
149
+ try:
150
+ return max(self.deque)
151
+ except Exception:
152
+ return 0.0
153
+
154
+ @property
155
+ def min(self) -> float:
156
+ try:
157
+ return min(self.deque)
158
+ except Exception:
159
+ return 0.0
160
+
161
+ @property
162
+ def last_value(self) -> float:
163
+ try:
164
+ return self.deque[-1]
165
+ except Exception:
166
+ return 0.0
167
+
168
+ def __str__(self):
169
+ return self.fmt.format(
170
+ name=self.name,
171
+ count=self.count,
172
+ total=self.total,
173
+ median=self.median,
174
+ avg=self.avg,
175
+ global_avg=self.global_avg,
176
+ min=self.min,
177
+ max=self.max,
178
+ last_value=self.last_value,
179
+ )
180
+
181
+ def __format__(self, format_spec: str) -> str:
182
+ return self.__str__()
183
+
184
+
185
+ class MetricLogger:
186
+ r"""
187
+ See Also:
188
+ https://github.com/pytorch/vision/blob/main/references/classification/utils.py
189
+
190
+ Args:
191
+ delimiter (str): The delimiter to join different meter strings.
192
+ Defaults to ``''``.
193
+ meter_length (int): The minimum length for each meter.
194
+ Defaults to ``20``.
195
+ tqdm (bool): Whether to use tqdm to show iteration information.
196
+ Defaults to ``env['tqdm']``.
197
+ indent (int): The space indent for the entire string.
198
+ Defaults to ``0``.
199
+
200
+ Attributes:
201
+ meters (dict[str, SmoothedValue]): The meter dict.
202
+ iter_time (SmoothedValue): Iteration time meter.
203
+ data_time (SmoothedValue): Data loading time meter.
204
+ memory (SmoothedValue): Memory usage meter.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ delimiter: str = "",
210
+ meter_length: int = 20,
211
+ tqdm: bool = True,
212
+ indent: int = 0,
213
+ **kwargs,
214
+ ):
215
+ self.meters: defaultdict[str, SmoothedValue] = defaultdict(SmoothedValue)
216
+ self.create_meters(**kwargs)
217
+ self.delimiter = delimiter
218
+ self.meter_length = meter_length
219
+ self.tqdm = tqdm
220
+ self.indent = indent
221
+
222
+ self.iter_time = SmoothedValue()
223
+ self.data_time = SmoothedValue()
224
+ self.memory = SmoothedValue(fmt="{max:.0f}")
225
+
226
+ def create_meters(self, **kwargs: str) -> Self:
227
+ r"""Create meters with specific ``fmt`` in :attr:`self.meters`.
228
+
229
+ ``self.meters[meter_name] = SmoothedValue(fmt=fmt)``
230
+
231
+ Args:
232
+ **kwargs: ``(meter_name: fmt)``
233
+
234
+ Returns:
235
+ MetricLogger: return ``self`` for stream usage.
236
+ """
237
+ for k, v in kwargs.items():
238
+ self.meters[k] = SmoothedValue(fmt="{global_avg:.3f}" if v is None else v)
239
+ return self
240
+
241
+ def update(self, n: int = 1, **kwargs: float) -> Self:
242
+ r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update()`.
243
+
244
+ ``self.meters[meter_name].update(float(value), n=n)``
245
+
246
+ Args:
247
+ n (int): the number of data with same value.
248
+ **kwargs: ``{meter_name: value}``.
249
+
250
+ Returns:
251
+ MetricLogger: return ``self`` for stream usage.
252
+ """
253
+ for k, v in kwargs.items():
254
+ if k not in self.meters:
255
+ self.meters[k] = SmoothedValue()
256
+ self.meters[k].update(float(v), n=n)
257
+ return self
258
+
259
+ def update_list(self, **kwargs: list) -> Self:
260
+ r"""Update values to :attr:`self.meters` by calling :meth:`SmoothedValue.update_list()`.
261
+
262
+ ``self.meters[meter_name].update_list(value_list)``
263
+
264
+ Args:
265
+ **kwargs: ``{meter_name: value_list}``.
266
+
267
+ Returns:
268
+ MetricLogger: return ``self`` for stream usage.
269
+ """
270
+ for k, v in kwargs.items():
271
+ self.meters[k].update_list(v)
272
+ return self
273
+
274
+ def reset(self) -> Self:
275
+ r"""Reset meter in :attr:`self.meters` by calling :meth:`SmoothedValue.reset()`.
276
+
277
+ Returns:
278
+ MetricLogger: return ``self`` for stream usage.
279
+ """
280
+ for meter in self.meters.values():
281
+ meter.reset()
282
+ return self
283
+
284
+ def get_str(self, cut_too_long: bool = True, strip: bool = True, **kwargs) -> str:
285
+ r"""Generate formatted string based on keyword arguments.
286
+
287
+ ``key: value`` with max length to be :attr:`self.meter_length`.
288
+
289
+ Args:
290
+ cut_too_long (bool): Whether to cut too long values to first 5 characters.
291
+ Defaults to ``True``.
292
+ strip (bool): Whether to strip trailing whitespaces.
293
+ Defaults to ``True``.
294
+ **kwargs: Keyword arguments to generate string.
295
+ """
296
+ str_list: list[str] = []
297
+ for k, v in kwargs.items():
298
+ v_str = str(v)
299
+ _str: str = "{green}{k}{reset}: {v}".format(k=k, v=v_str, **ansi)
300
+ max_length = self.meter_length + get_ansi_len(_str)
301
+ if cut_too_long:
302
+ _str = _str[:max_length]
303
+ str_list.append(_str.ljust(max_length))
304
+ _str = self.delimiter.join(str_list)
305
+ if strip:
306
+ _str = _str.rstrip()
307
+ return _str
308
+
309
+ def __getattr__(self, attr: str) -> float:
310
+ if attr in self.meters:
311
+ return self.meters[attr]
312
+ if attr in vars(self): # TODO: use hasattr
313
+ return vars(self)[attr]
314
+ raise AttributeError(
315
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
316
+ )
317
+
318
+ def __str__(self) -> str:
319
+ return self.get_str(**self.meters)
320
+
321
+ def synchronize_between_processes(self):
322
+ for meter in self.meters.values():
323
+ meter.synchronize_between_processes()
324
+
325
+ def log_every(
326
+ self,
327
+ iterable: Iterable[T],
328
+ header: str = "",
329
+ tqdm: bool = None,
330
+ tqdm_header: str = "Iter",
331
+ indent: int = None,
332
+ verbose: int = 1,
333
+ ) -> Generator[T, None, None]:
334
+ r"""Wrap an :class:`collections.abc.Iterable` with formatted outputs.
335
+
336
+ * Middle Output:
337
+ ``{tqdm_header}: [ current / total ] str(self) {memory} {iter_time} {data_time} {time}<{remaining}``
338
+ * Final Output
339
+ ``{header} str(self) {memory} {iter_time} {data_time} {total_time}``
340
+
341
+ Args:
342
+ iterable (~collections.abc.Iterable): The raw iterator.
343
+ header (str): The header string for final output.
344
+ Defaults to ``''``.
345
+ tqdm (bool): Whether to use tqdm to show iteration information.
346
+ Defaults to ``self.tqdm``.
347
+ tqdm_header (str): The header string for middle output.
348
+ Defaults to ``'Iter'``.
349
+ indent (int): The space indent for the entire string.
350
+ if ``None``, use ``self.indent``.
351
+ Defaults to ``None``.
352
+ verbose (int): The verbose level of output information.
353
+ """
354
+ tqdm = tqdm if tqdm is not None else self.tqdm
355
+ indent = indent if indent is not None else self.indent
356
+ iterator = iterable
357
+ if len(header) != 0:
358
+ header = header.ljust(30 + get_ansi_len(header))
359
+ if tqdm:
360
+ length = len(str(len(iterable)))
361
+ pattern: str = (
362
+ "{tqdm_header}: {blue_light}"
363
+ "[ {red}{{n_fmt:>{length}}}{blue_light} "
364
+ "/ {red}{{total_fmt}}{blue_light} ]{reset}"
365
+ ).format(tqdm_header=tqdm_header, length=length, **ansi)
366
+ offset = len(f"{{n_fmt:>{length}}}{{total_fmt}}") - 2 * length
367
+ pattern = pattern.ljust(30 + offset + get_ansi_len(pattern))
368
+ time_str = self.get_str(time="{elapsed}<{remaining}", cut_too_long=False)
369
+ bar_format = f"{pattern}{{desc}}{time_str}"
370
+ iterator = tqdm_class(iterable, leave=False, bar_format=bar_format)
371
+
372
+ self.iter_time.reset()
373
+ self.data_time.reset()
374
+ self.memory.reset()
375
+
376
+ end = time.time()
377
+ start_time = time.time()
378
+ for obj in iterator:
379
+ cur_data_time = time.time() - end
380
+ self.data_time.update(cur_data_time)
381
+ yield obj
382
+ cur_iter_time = time.time() - end
383
+ self.iter_time.update(cur_iter_time)
384
+ if torch.cuda.is_available():
385
+ cur_memory = torch.cuda.max_memory_allocated() / MB
386
+ self.memory.update(cur_memory)
387
+ if tqdm:
388
+ _dict = {k: v for k, v in self.meters.items()}
389
+ if verbose > 2 and torch.cuda.is_available():
390
+ _dict.update(memory=f"{cur_memory:.0f} MB")
391
+ if verbose > 1:
392
+ _dict.update(
393
+ iter=f"{cur_iter_time:.3f} s", data=f"{cur_data_time:.3f} s"
394
+ )
395
+ iterator.set_description_str(self.get_str(**_dict, strip=False))
396
+ end = time.time()
397
+ self.synchronize_between_processes()
398
+ total_time = time.time() - start_time
399
+ total_time_str = tqdm_class.format_interval(total_time)
400
+
401
+ _dict = {k: v for k, v in self.meters.items()}
402
+ if verbose > 2 and torch.cuda.is_available():
403
+ _dict.update(memory=f"{str(self.memory)} MB")
404
+ if verbose > 1:
405
+ _dict.update(
406
+ iter=f"{str(self.iter_time)} s", data=f"{str(self.data_time)} s"
407
+ )
408
+ _dict.update(time=total_time_str)
409
+ prints(self.delimiter.join([header, self.get_str(**_dict)]), indent=indent)
modules/finetune/utils/model.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+ from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ
4
+
5
+
6
+ def quantize(
7
+ quantizer: GroupedResidualFSQ,
8
+ audio_latents: torch.Tensor, # (batch_size, audio_len, audio_dim=1024)
9
+ ) -> tuple[torch.Tensor, torch.Tensor]:
10
+ # feat shape (batch_size, audio_len, audio_dim)
11
+ # ind shape (GFSQ.G, batch_size, audio_len, GFSQ.R)
12
+ # num_vq=GFSQ.G*GFSQ.R
13
+ feat, ind = quantizer(audio_latents)
14
+ audio_quantized_latents = feat # (batch_size, audio_len, audio_dim)
15
+ audio_input_ids = rearrange( # (batch_size, audio_len, num_vq)
16
+ ind,
17
+ "g b t r ->b t (g r)",
18
+ )
19
+ return audio_quantized_latents, audio_input_ids
modules/finetune/utils/output.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import re
4
+ import sys
5
+ from contextlib import contextmanager
6
+
7
+
8
+ class ANSI:
9
+ ansi_color = {
10
+ "black": "\033[30m",
11
+ "red": "\033[31m",
12
+ "green": "\033[32m",
13
+ "yellow": "\033[33m",
14
+ "blue": "\033[34m",
15
+ "purple": "\033[35m",
16
+ "blue_light": "\033[36m",
17
+ "white": "\033[37m",
18
+ "reset": "\033[0m",
19
+ "upline": "\033[1A",
20
+ "clear_line": "\033[2K",
21
+ "clear": "\033[2J",
22
+ }
23
+ ansi_nocolor = {
24
+ "black": "",
25
+ "red": "",
26
+ "green": "",
27
+ "yellow": "",
28
+ "blue": "",
29
+ "purple": "",
30
+ "blue_light": "",
31
+ "white": "",
32
+ "reset": "",
33
+ "upline": "\033[1A\033[",
34
+ "clear_line": "\033[K",
35
+ "clear": "\033[2J",
36
+ }
37
+
38
+ def __init__(self):
39
+ self._dict = ANSI.ansi_color if ("--color" in sys.argv) else ANSI.ansi_nocolor
40
+
41
+ def switch(self, color: bool):
42
+ self._dict = ANSI.ansi_color if color else ANSI.ansi_nocolor
43
+
44
+ def keys(self):
45
+ return self._dict.keys()
46
+
47
+ def items(self):
48
+ return self._dict.items()
49
+
50
+ def __getitem__(self, key):
51
+ return self._dict[key]
52
+
53
+ def __str__(self):
54
+ return str(self._dict)
55
+
56
+ def __repr__(self):
57
+ return repr(self._dict)
58
+
59
+
60
+ ansi = ANSI()
61
+
62
+
63
+ def remove_ansi(s: str) -> str:
64
+ ansi_escape = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]")
65
+ return ansi_escape.sub("", s)
66
+
67
+
68
+ def get_ansi_len(s: str) -> int:
69
+ return len(s) - len(remove_ansi(s))
70
+
71
+
72
+ def prints(*args: str, indent: int = 0, prefix: str = "", **kwargs):
73
+ assert indent >= 0
74
+ new_args = []
75
+ for arg in args:
76
+ new_args.append(indent_str(str(arg), indent=indent))
77
+ if len(new_args):
78
+ new_args[0] = prefix + str(new_args[0])
79
+ print(*new_args, **kwargs)
80
+
81
+
82
+ def output_iter(_iter: int, iteration: int = None, iter_len: int = 4) -> str:
83
+ if iteration is None:
84
+ pattern = "{blue_light}[ {red}{0}{blue_light} ]{reset}"
85
+ return pattern.format(str(_iter).rjust(iter_len), **ansi)
86
+ else:
87
+ iter_str = str(iteration)
88
+ length = len(iter_str)
89
+ pattern = (
90
+ "{blue_light}[ {red}{0}{blue_light} " "/ {red}{1}{blue_light} ]{reset}"
91
+ )
92
+ return pattern.format(str(_iter).rjust(length), iter_str, **ansi)
93
+
94
+
95
+ def indent_str(s_: str, indent: int = 0) -> str:
96
+ # modified from torch.nn.modules._addindent
97
+ if indent > 0 and s_:
98
+ s_ = indent * " " + str(s_[:-1]).replace("\n", "\n" + indent * " ") + s_[-1]
99
+ return s_
100
+
101
+
102
+ class IndentRedirect: # TODO: inherit TextIOWrapper?
103
+ def __init__(self, buffer: bool = True, indent: int = 0):
104
+ self.__console__ = sys.stdout
105
+ self.indent = indent
106
+ self.__buffer: str = None
107
+ if buffer:
108
+ self.__buffer = ""
109
+
110
+ def write(self, text: str, indent: int = None):
111
+ indent = indent if indent is not None else self.indent
112
+ text = indent_str(text, indent=indent)
113
+ if self.__buffer is None:
114
+ self.__console__.write(text)
115
+ else:
116
+ self.__buffer += text
117
+
118
+ def flush(self):
119
+ if self.__buffer is not None:
120
+ self.__console__.write(self.__buffer)
121
+ self.__buffer = ""
122
+ self.__console__.flush()
123
+
124
+ @contextmanager
125
+ def __call__(self) -> None:
126
+ try:
127
+ sys.stdout = self
128
+ yield
129
+ finally:
130
+ sys.stdout = self.__console__
131
+ self.__buffer = ""
132
+
133
+ def enable(self):
134
+ sys.stdout = self
135
+
136
+ def disable(self):
137
+ if self.__buffer is not None:
138
+ self.__buffer = ""
139
+ sys.stdout = self.__console__
140
+
141
+ @property
142
+ def buffer(self) -> str:
143
+ return self.__buffer
144
+
145
+
146
+ redirect = IndentRedirect()
modules/generate_audio.py CHANGED
@@ -76,6 +76,8 @@ def generate_audio_batch(
76
  params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
77
  logger.debug(("spk", spk))
78
  elif isinstance(spk, Speaker):
 
 
79
  params_infer_code["spk_emb"] = spk.emb
80
  logger.debug(("spk", spk.name))
81
  else:
 
76
  params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
77
  logger.debug(("spk", spk))
78
  elif isinstance(spk, Speaker):
79
+ if not isinstance(spk.emb, torch.Tensor):
80
+ raise ValueError("spk.pt is broken, please retrain the model.")
81
  params_infer_code["spk_emb"] = spk.emb
82
  logger.debug(("spk", spk.name))
83
  else:
modules/normalization.py CHANGED
@@ -120,6 +120,7 @@ character_map = {
120
  "~": " ",
121
  "~": " ",
122
  "/": " ",
 
123
  }
124
 
125
  character_to_word = {
@@ -282,6 +283,9 @@ def text_normalize(text, is_end=False):
282
 
283
 
284
  if __name__ == "__main__":
 
 
 
285
  test_cases = [
286
  "ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.",
287
  " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
@@ -319,6 +323,7 @@ State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
319
  """
320
  120米
321
  有12%的概率会下雨
 
322
  """,
323
  ]
324
 
 
120
  "~": " ",
121
  "~": " ",
122
  "/": " ",
123
+ "·": " ",
124
  }
125
 
126
  character_to_word = {
 
283
 
284
 
285
  if __name__ == "__main__":
286
+ from modules.devices import devices
287
+
288
+ devices.reset_device()
289
  test_cases = [
290
  "ChatTTS是专门为对话场景设计的文本转语音模型,例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本.",
291
  " [oral_9] [laugh_0] [break_0] 电 [speed_0] 影 [speed_0] 中 梁朝伟 [speed_9] 扮演的陈永仁的编号27149",
 
323
  """
324
  120米
325
  有12%的概率会下雨
326
+ 埃隆·马斯克
327
  """,
328
  ]
329
 
modules/repos_static/resemble_enhance/data/distorter/base.py CHANGED
@@ -2,6 +2,7 @@ import itertools
2
  import os
3
  import random
4
  import time
 
5
  import warnings
6
 
7
  import numpy as np
@@ -87,7 +88,7 @@ class Choice(Effect):
87
 
88
 
89
  class Permutation(Effect):
90
- def __init__(self, *effects, n: int | None = None):
91
  super().__init__()
92
  self.effects = effects
93
  self.n = n
 
2
  import os
3
  import random
4
  import time
5
+ from typing import Union
6
  import warnings
7
 
8
  import numpy as np
 
88
 
89
 
90
  class Permutation(Effect):
91
+ def __init__(self, *effects, n: Union[int, None] = None):
92
  super().__init__()
93
  self.effects = effects
94
  self.n = n
modules/repos_static/resemble_enhance/data/distorter/custom.py CHANGED
@@ -3,6 +3,7 @@ import random
3
  from dataclasses import dataclass
4
  from functools import cached_property
5
  from pathlib import Path
 
6
 
7
  import librosa
8
  import numpy as np
@@ -16,7 +17,7 @@ _logger = logging.getLogger(__name__)
16
 
17
  @dataclass
18
  class RandomRIR(Effect):
19
- rir_dir: Path | None
20
  rir_rate: int = 44_000
21
  rir_suffix: str = ".npy"
22
  deterministic: bool = False
@@ -49,7 +50,9 @@ class RandomRIR(Effect):
49
 
50
  length = len(wav)
51
 
52
- wav = librosa.resample(wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast")
 
 
53
  rir = self._sample_rir()
54
 
55
  wav = signal.convolve(wav, rir, mode="same")
@@ -58,7 +61,9 @@ class RandomRIR(Effect):
58
  if actlev > 0.99:
59
  wav = (wav / actlev) * 0.98
60
 
61
- wav = librosa.resample(wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast")
 
 
62
 
63
  if abs(length - len(wav)) > 10:
64
  _logger.warning(f"length mismatch: {length} vs {len(wav)}")
 
3
  from dataclasses import dataclass
4
  from functools import cached_property
5
  from pathlib import Path
6
+ from typing import Union
7
 
8
  import librosa
9
  import numpy as np
 
17
 
18
  @dataclass
19
  class RandomRIR(Effect):
20
+ rir_dir: Union[Path, None]
21
  rir_rate: int = 44_000
22
  rir_suffix: str = ".npy"
23
  deterministic: bool = False
 
50
 
51
  length = len(wav)
52
 
53
+ wav = librosa.resample(
54
+ wav, orig_sr=sr, target_sr=self.rir_rate, res_type="kaiser_fast"
55
+ )
56
  rir = self._sample_rir()
57
 
58
  wav = signal.convolve(wav, rir, mode="same")
 
61
  if actlev > 0.99:
62
  wav = (wav / actlev) * 0.98
63
 
64
+ wav = librosa.resample(
65
+ wav, orig_sr=self.rir_rate, target_sr=sr, res_type="kaiser_fast"
66
+ )
67
 
68
  if abs(length - len(wav)) > 10:
69
  _logger.warning(f"length mismatch: {length} vs {len(wav)}")
modules/repos_static/resemble_enhance/data/distorter/sox.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import os
3
  import random
 
4
  import warnings
5
  from functools import partial
6
 
@@ -29,7 +30,9 @@ class AttachableEffect(Effect):
29
  chain = augment.EffectChain()
30
  chain = self.attach(chain)
31
  tensor = torch.from_numpy(wav)[None].float() # (1, T)
32
- tensor = chain.apply(tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr})
 
 
33
  wav = tensor.numpy()[0] # (T,)
34
  return wav
35
 
@@ -41,7 +44,9 @@ class SoxEffect(AttachableEffect):
41
  self.kwargs = kwargs
42
 
43
  def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
44
- _logger.debug(f"Attaching {self.effect_name} with {self.args} and {self.kwargs}")
 
 
45
  if not hasattr(chain, self.effect_name):
46
  raise ValueError(f"EffectChain has no attribute {self.effect_name}")
47
  return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
@@ -115,21 +120,30 @@ class Randint(Generator):
115
 
116
 
117
  class Concat(Generator):
118
- def __init__(self, *parts: Generator | str):
119
  self.parts = parts
120
 
121
  def __call__(self):
122
- return "".join([part if isinstance(part, str) else part() for part in self.parts])
 
 
123
 
124
 
125
  class RandomLowpassDistorter(SoxEffect):
126
  def __init__(self, low=2000, high=16000):
127
- super().__init__("sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high)))
 
 
128
 
129
 
130
  class RandomBandpassDistorter(SoxEffect):
131
  def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
132
- super().__init__("sinc", "-n", Randint(50, 200), partial(self._fn, low, high, min_width, max_width))
 
 
 
 
 
133
 
134
  @staticmethod
135
  def _fn(low, high, min_width, max_width):
@@ -139,7 +153,15 @@ class RandomBandpassDistorter(SoxEffect):
139
 
140
 
141
  class RandomEqualizer(SoxEffect):
142
- def __init__(self, low=100, high=4000, q_low=1, q_high=5, db_low: int = -30, db_high: int = 30):
 
 
 
 
 
 
 
 
143
  super().__init__(
144
  "equalizer",
145
  Uniform(low, high),
@@ -150,7 +172,9 @@ class RandomEqualizer(SoxEffect):
150
 
151
  class RandomOverdrive(SoxEffect):
152
  def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
153
- super().__init__("overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high))
 
 
154
 
155
 
156
  class RandomReverb(Chain):
 
1
  import logging
2
  import os
3
  import random
4
+ from typing import Union
5
  import warnings
6
  from functools import partial
7
 
 
30
  chain = augment.EffectChain()
31
  chain = self.attach(chain)
32
  tensor = torch.from_numpy(wav)[None].float() # (1, T)
33
+ tensor = chain.apply(
34
+ tensor, src_info={"rate": sr}, target_info={"channels": 1, "rate": sr}
35
+ )
36
  wav = tensor.numpy()[0] # (T,)
37
  return wav
38
 
 
44
  self.kwargs = kwargs
45
 
46
  def attach(self, chain: augment.EffectChain) -> augment.EffectChain:
47
+ _logger.debug(
48
+ f"Attaching {self.effect_name} with {self.args} and {self.kwargs}"
49
+ )
50
  if not hasattr(chain, self.effect_name):
51
  raise ValueError(f"EffectChain has no attribute {self.effect_name}")
52
  return getattr(chain, self.effect_name)(*self.args, **self.kwargs)
 
120
 
121
 
122
  class Concat(Generator):
123
+ def __init__(self, *parts: Union[Generator, str]):
124
  self.parts = parts
125
 
126
  def __call__(self):
127
+ return "".join(
128
+ [part if isinstance(part, str) else part() for part in self.parts]
129
+ )
130
 
131
 
132
  class RandomLowpassDistorter(SoxEffect):
133
  def __init__(self, low=2000, high=16000):
134
+ super().__init__(
135
+ "sinc", "-n", Randint(50, 200), Concat("-", Uniform(low, high))
136
+ )
137
 
138
 
139
  class RandomBandpassDistorter(SoxEffect):
140
  def __init__(self, low=100, high=1000, min_width=2000, max_width=4000):
141
+ super().__init__(
142
+ "sinc",
143
+ "-n",
144
+ Randint(50, 200),
145
+ partial(self._fn, low, high, min_width, max_width),
146
+ )
147
 
148
  @staticmethod
149
  def _fn(low, high, min_width, max_width):
 
153
 
154
 
155
  class RandomEqualizer(SoxEffect):
156
+ def __init__(
157
+ self,
158
+ low=100,
159
+ high=4000,
160
+ q_low=1,
161
+ q_high=5,
162
+ db_low: int = -30,
163
+ db_high: int = 30,
164
+ ):
165
  super().__init__(
166
  "equalizer",
167
  Uniform(low, high),
 
172
 
173
  class RandomOverdrive(SoxEffect):
174
  def __init__(self, gain_low=5, gain_high=40, colour_low=20, colour_high=80):
175
+ super().__init__(
176
+ "overdrive", Uniform(gain_low, gain_high), Uniform(colour_low, colour_high)
177
+ )
178
 
179
 
180
  class RandomReverb(Chain):
modules/repos_static/resemble_enhance/data/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  from pathlib import Path
2
- from typing import Callable
3
 
4
  from torch import Tensor
5
 
@@ -16,7 +16,9 @@ def rglob_audio_files(path: Path):
16
  return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
17
 
18
 
19
- def mix_fg_bg(fg: Tensor, bg: Tensor, alpha: float | Callable[..., float] = 0.5, eps=1e-7):
 
 
20
  """
21
  Args:
22
  fg: (b, t)
 
1
  from pathlib import Path
2
+ from typing import Callable, Union
3
 
4
  from torch import Tensor
5
 
 
16
  return list(walk_paths(path, ".wav")) + list(walk_paths(path, ".flac"))
17
 
18
 
19
+ def mix_fg_bg(
20
+ fg: Tensor, bg: Tensor, alpha: Union[float, Callable[..., float]] = 0.5, eps=1e-7
21
+ ):
22
  """
23
  Args:
24
  fg: (b, t)
modules/repos_static/resemble_enhance/denoiser/denoiser.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
 
3
  import torch
4
  import torch.nn.functional as F
@@ -154,7 +155,7 @@ class Denoiser(nn.Module):
154
  sep_sin = sin * cos_res + cos * sin_res
155
  return sep_mag, sep_cos, sep_sin
156
 
157
- def forward(self, x: Tensor, y: Tensor | None = None):
158
  """
159
  Args:
160
  x: (b t), a mixed audio
 
1
  import logging
2
+ from typing import Union
3
 
4
  import torch
5
  import torch.nn.functional as F
 
155
  sep_sin = sin * cos_res + cos * sin_res
156
  return sep_mag, sep_cos, sep_sin
157
 
158
+ def forward(self, x: Tensor, y: Union[Tensor, None] = None):
159
  """
160
  Args:
161
  x: (b t), a mixed audio
modules/repos_static/resemble_enhance/enhancer/download.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  from pathlib import Path
 
3
 
4
  import torch
5
 
@@ -12,14 +13,18 @@ def get_source_url(relpath):
12
  return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
13
 
14
 
15
- def get_target_path(relpath: str | Path, run_dir: str | Path | None = None):
16
  if run_dir is None:
17
  run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
18
  return Path(run_dir) / relpath
19
 
20
 
21
- def download(run_dir: str | Path | None = None):
22
- relpaths = ["hparams.yaml", "ds/G/latest", "ds/G/default/mp_rank_00_model_states.pt"]
 
 
 
 
23
  for relpath in relpaths:
24
  path = get_target_path(relpath, run_dir=run_dir)
25
  if path.exists():
 
1
  import logging
2
  from pathlib import Path
3
+ from typing import Union
4
 
5
  import torch
6
 
 
13
  return f"https://huggingface.co/ResembleAI/resemble-enhance/resolve/main/{RUN_NAME}/{relpath}?download=true"
14
 
15
 
16
+ def get_target_path(relpath: Union[str, Path], run_dir: Union[str, Path, None] = None):
17
  if run_dir is None:
18
  run_dir = Path(__file__).parent.parent / "model_repo" / RUN_NAME
19
  return Path(run_dir) / relpath
20
 
21
 
22
+ def download(run_dir: Union[str, Path, None] = None):
23
+ relpaths = [
24
+ "hparams.yaml",
25
+ "ds/G/latest",
26
+ "ds/G/default/mp_rank_00_model_states.pt",
27
+ ]
28
  for relpath in relpaths:
29
  path = get_target_path(relpath, run_dir=run_dir)
30
  if path.exists():
modules/repos_static/resemble_enhance/enhancer/enhancer.py CHANGED
@@ -1,4 +1,5 @@
1
  import logging
 
2
 
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
@@ -109,7 +110,7 @@ class Enhancer(nn.Module):
109
  return self.mel_fn(x)[..., :-1] # (b d t)
110
  return self.mel_fn(x)
111
 
112
- def _may_denoise(self, x: Tensor, y: Tensor | None = None):
113
  if self.hp.lcfm_training_mode == "cfm":
114
  return self.denoiser(x, y)
115
  return x
@@ -126,7 +127,9 @@ class Enhancer(nn.Module):
126
  self.lcfm.eval_tau_(tau)
127
  self._eval_lambd = lambd
128
 
129
- def forward(self, x: Tensor, y: Tensor | None = None, z: Tensor | None = None):
 
 
130
  """
131
  Args:
132
  x: (b t), mix wavs (fg + bg)
 
1
  import logging
2
+ from typing import Union
3
 
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
 
110
  return self.mel_fn(x)[..., :-1] # (b d t)
111
  return self.mel_fn(x)
112
 
113
+ def _may_denoise(self, x: Tensor, y: Union[Tensor, None] = None):
114
  if self.hp.lcfm_training_mode == "cfm":
115
  return self.denoiser(x, y)
116
  return x
 
127
  self.lcfm.eval_tau_(tau)
128
  self._eval_lambd = lambd
129
 
130
+ def forward(
131
+ self, x: Tensor, y: Union[Tensor, None] = None, z: Union[Tensor, None] = None
132
+ ):
133
  """
134
  Args:
135
  x: (b t), mix wavs (fg + bg)
modules/repos_static/resemble_enhance/enhancer/hparams.py CHANGED
@@ -1,5 +1,6 @@
1
  from dataclasses import dataclass
2
  from pathlib import Path
 
3
 
4
  from ..hparams import HParams as HParamsBase
5
 
@@ -17,7 +18,7 @@ class HParams(HParamsBase):
17
 
18
  vocoder_extra_dim: int = 32
19
 
20
- gan_training_start_step: int | None = 5_000
21
- enhancer_stage1_run_dir: Path | None = None
22
 
23
- denoiser_run_dir: Path | None = None
 
1
  from dataclasses import dataclass
2
  from pathlib import Path
3
+ from typing import Union
4
 
5
  from ..hparams import HParams as HParamsBase
6
 
 
18
 
19
  vocoder_extra_dim: int = 32
20
 
21
+ gan_training_start_step: Union[int, None] = 5_000
22
+ enhancer_stage1_run_dir: Union[Path, None] = None
23
 
24
+ denoiser_run_dir: Union[Path, None] = None
modules/repos_static/resemble_enhance/enhancer/inference.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  from functools import cache
3
  from pathlib import Path
 
4
 
5
  import torch
6
 
@@ -13,7 +14,7 @@ logger = logging.getLogger(__name__)
13
 
14
 
15
  @cache
16
- def load_enhancer(run_dir: str | Path | None, device):
17
  run_dir = download(run_dir)
18
  hp = HParams.load(run_dir)
19
  enhancer = Enhancer(hp)
 
1
  import logging
2
  from functools import cache
3
  from pathlib import Path
4
+ from typing import Union
5
 
6
  import torch
7
 
 
14
 
15
 
16
  @cache
17
+ def load_enhancer(run_dir: Union[str, Path, None], device):
18
  run_dir = download(run_dir)
19
  hp = HParams.load(run_dir)
20
  enhancer = Enhancer(hp)
modules/repos_static/resemble_enhance/enhancer/lcfm/cfm.py CHANGED
@@ -1,7 +1,7 @@
1
  import logging
2
  from dataclasses import dataclass
3
  from functools import partial
4
- from typing import Protocol
5
 
6
  import matplotlib.pyplot as plt
7
  import numpy as np
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
17
 
18
 
19
  class VelocityField(Protocol):
20
- def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor:
21
- ...
22
 
23
 
24
  class Solver:
@@ -40,7 +39,9 @@ class Solver:
40
 
41
  self._camera = None
42
  self._mel_fn = mel_fn
43
- self._time_mapping = partial(self.exponential_decay_mapping, n=time_mapping_divisor)
 
 
44
 
45
  def configurate_(self, nfe=None, method=None):
46
  if nfe is None:
@@ -50,7 +51,9 @@ class Solver:
50
  method = self.method
51
 
52
  if nfe == 1 and method in ("midpoint", "rk4"):
53
- logger.warning(f"1 NFE is not supported for {method}, using euler method instead.")
 
 
54
  method = "euler"
55
 
56
  self.nfe = nfe
@@ -105,7 +108,9 @@ class Solver:
105
  )
106
  else:
107
  # Spectrogram, b c t
108
- plt.imshow(ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none")
 
 
109
  ax = plt.gca()
110
  ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
111
  camera.snap()
@@ -271,7 +276,7 @@ class CFM(nn.Module):
271
  global_dim=self.time_emb_dim,
272
  )
273
 
274
- def _perturb(self, ψ1: Tensor, t: Tensor | None = None):
275
  """
276
  Perturb ψ1 to ψt.
277
  """
@@ -311,7 +316,7 @@ class CFM(nn.Module):
311
  """
312
  return ψ1 - ψ0
313
 
314
- def _to_v(self, *, ψt, x, t: float | Tensor):
315
  """
316
  Args:
317
  ψt: (b c t)
@@ -364,7 +369,13 @@ class CFM(nn.Module):
364
  ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
365
  return ψ1
366
 
367
- def forward(self, x: Tensor, y: Tensor | None = None, ψ0: Tensor | None = None, t0=0.0):
 
 
 
 
 
 
368
  if y is None:
369
  y = self.sample(x, ψ0=ψ0, t0=t0)
370
  else:
 
1
  import logging
2
  from dataclasses import dataclass
3
  from functools import partial
4
+ from typing import Protocol, Union
5
 
6
  import matplotlib.pyplot as plt
7
  import numpy as np
 
17
 
18
 
19
  class VelocityField(Protocol):
20
+ def __call__(self, *, t: Tensor, ψt: Tensor, dt: Tensor) -> Tensor: ...
 
21
 
22
 
23
  class Solver:
 
39
 
40
  self._camera = None
41
  self._mel_fn = mel_fn
42
+ self._time_mapping = partial(
43
+ self.exponential_decay_mapping, n=time_mapping_divisor
44
+ )
45
 
46
  def configurate_(self, nfe=None, method=None):
47
  if nfe is None:
 
51
  method = self.method
52
 
53
  if nfe == 1 and method in ("midpoint", "rk4"):
54
+ logger.warning(
55
+ f"1 NFE is not supported for {method}, using euler method instead."
56
+ )
57
  method = "euler"
58
 
59
  self.nfe = nfe
 
108
  )
109
  else:
110
  # Spectrogram, b c t
111
+ plt.imshow(
112
+ ψt.detach().cpu().numpy()[0], origin="lower", interpolation="none"
113
+ )
114
  ax = plt.gca()
115
  ax.text(0.5, 1.01, f"t={t:.2f}", transform=ax.transAxes, ha="center")
116
  camera.snap()
 
276
  global_dim=self.time_emb_dim,
277
  )
278
 
279
+ def _perturb(self, ψ1: Tensor, t: Union[Tensor, None] = None):
280
  """
281
  Perturb ψ1 to ψt.
282
  """
 
316
  """
317
  return ψ1 - ψ0
318
 
319
+ def _to_v(self, *, ψt, x, t: Union[float, Tensor]):
320
  """
321
  Args:
322
  ψt: (b c t)
 
369
  ψ1 = self.solver(f=f, ψ0=ψ0, t0=t0)
370
  return ψ1
371
 
372
+ def forward(
373
+ self,
374
+ x: Tensor,
375
+ y: Union[Tensor, None] = None,
376
+ ψ0: Union[Tensor, None] = None,
377
+ t0=0.0,
378
+ ):
379
  if y is None:
380
  y = self.sample(x, ψ0=ψ0, t0=t0)
381
  else:
modules/repos_static/resemble_enhance/enhancer/lcfm/irmae.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  from dataclasses import dataclass
 
3
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
@@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
14
  @dataclass
15
  class IRMAEOutput:
16
  latent: Tensor # latent vector
17
- decoded: Tensor | None # decoder output, include extra dim
18
 
19
 
20
  class ResBlock(nn.Sequential):
 
1
  import logging
2
  from dataclasses import dataclass
3
+ from typing import Union
4
 
5
  import torch.nn as nn
6
  import torch.nn.functional as F
 
15
  @dataclass
16
  class IRMAEOutput:
17
  latent: Tensor # latent vector
18
+ decoded: Union[Tensor, None] # decoder output, include extra dim
19
 
20
 
21
  class ResBlock(nn.Sequential):
modules/repos_static/resemble_enhance/enhancer/lcfm/lcfm.py CHANGED
@@ -1,5 +1,6 @@
1
  import logging
2
  from enum import Enum
 
3
 
4
  import matplotlib.pyplot as plt
5
  import torch
@@ -70,19 +71,34 @@ class LCFM(nn.Module):
70
  return
71
 
72
  plt.subplot(221)
73
- plt.imshow(y[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
 
 
 
 
 
74
  plt.title("GT")
75
 
76
  plt.subplot(222)
77
  y_ = y_[:, : y.shape[1]]
78
- plt.imshow(y_[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
 
 
 
 
 
79
  plt.title("Posterior")
80
 
81
  plt.subplot(223)
82
  z_ = self.cfm(x)
83
  y__ = self.ae.decode(z_)
84
  y__ = y__[:, : y.shape[1]]
85
- plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
 
 
 
 
 
86
  plt.title("C-Prior")
87
  del y__
88
 
@@ -90,7 +106,12 @@ class LCFM(nn.Module):
90
  z_ = torch.randn_like(z_)
91
  y__ = self.ae.decode(z_)
92
  y__ = y__[:, : y.shape[1]]
93
- plt.imshow(y__[0].detach().cpu().numpy(), aspect="auto", origin="lower", interpolation="none")
 
 
 
 
 
94
  plt.title("Prior")
95
  del z_, y__
96
 
@@ -109,7 +130,7 @@ class LCFM(nn.Module):
109
  def eval_tau_(self, tau):
110
  self._eval_tau = tau
111
 
112
- def forward(self, x, y: Tensor | None = None, ψ0: Tensor | None = None):
113
  """
114
  Args:
115
  x: (b d t), condition mel
@@ -139,14 +160,20 @@ class LCFM(nn.Module):
139
 
140
  h = self.ae.decode(z)
141
  else:
142
- ae_output: IRMAEOutput = self.ae(y, skip_decoding=self.mode == self.Mode.CFM)
 
 
143
 
144
  if self.mode == self.Mode.CFM:
145
  _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
146
 
147
  h = ae_output.decoded
148
 
149
- if h is not None and self.global_step is not None and self.global_step % 100 == 0:
 
 
 
 
150
  self._visualize(x[:1], y[:1], h[:1])
151
 
152
  return h
 
1
  import logging
2
  from enum import Enum
3
+ from typing import Union
4
 
5
  import matplotlib.pyplot as plt
6
  import torch
 
71
  return
72
 
73
  plt.subplot(221)
74
+ plt.imshow(
75
+ y[0].detach().cpu().numpy(),
76
+ aspect="auto",
77
+ origin="lower",
78
+ interpolation="none",
79
+ )
80
  plt.title("GT")
81
 
82
  plt.subplot(222)
83
  y_ = y_[:, : y.shape[1]]
84
+ plt.imshow(
85
+ y_[0].detach().cpu().numpy(),
86
+ aspect="auto",
87
+ origin="lower",
88
+ interpolation="none",
89
+ )
90
  plt.title("Posterior")
91
 
92
  plt.subplot(223)
93
  z_ = self.cfm(x)
94
  y__ = self.ae.decode(z_)
95
  y__ = y__[:, : y.shape[1]]
96
+ plt.imshow(
97
+ y__[0].detach().cpu().numpy(),
98
+ aspect="auto",
99
+ origin="lower",
100
+ interpolation="none",
101
+ )
102
  plt.title("C-Prior")
103
  del y__
104
 
 
106
  z_ = torch.randn_like(z_)
107
  y__ = self.ae.decode(z_)
108
  y__ = y__[:, : y.shape[1]]
109
+ plt.imshow(
110
+ y__[0].detach().cpu().numpy(),
111
+ aspect="auto",
112
+ origin="lower",
113
+ interpolation="none",
114
+ )
115
  plt.title("Prior")
116
  del z_, y__
117
 
 
130
  def eval_tau_(self, tau):
131
  self._eval_tau = tau
132
 
133
+ def forward(self, x, y: Union[Tensor, None] = None, ψ0: Union[Tensor, None] = None):
134
  """
135
  Args:
136
  x: (b d t), condition mel
 
160
 
161
  h = self.ae.decode(z)
162
  else:
163
+ ae_output: IRMAEOutput = self.ae(
164
+ y, skip_decoding=self.mode == self.Mode.CFM
165
+ )
166
 
167
  if self.mode == self.Mode.CFM:
168
  _ = self.cfm(x, self._scale(ae_output.latent.detach()), ψ0=ψ0)
169
 
170
  h = ae_output.decoded
171
 
172
+ if (
173
+ h is not None
174
+ and self.global_step is not None
175
+ and self.global_step % 100 == 0
176
+ ):
177
  self._visualize(x[:1], y[:1], h[:1])
178
 
179
  return h
modules/repos_static/resemble_enhance/enhancer/univnet/univnet.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
@@ -50,7 +51,9 @@ class UnivNet(nn.Module):
50
  ]
51
  )
52
 
53
- self.conv_pre = weight_norm(nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect"))
 
 
54
 
55
  self.conv_post = nn.Sequential(
56
  nn.LeakyReLU(0.2),
@@ -64,7 +67,7 @@ class UnivNet(nn.Module):
64
  def eps(self):
65
  return 1e-5
66
 
67
- def forward(self, x: Tensor, y: Tensor | None = None, npad=10):
68
  """
69
  Args:
70
  x: (b c t), acoustic features
@@ -74,7 +77,9 @@ class UnivNet(nn.Module):
74
  """
75
  assert x.ndim == 3, "x must be 3D tensor"
76
  assert y is None or y.ndim == 2, "y must be 2D tensor"
77
- assert x.shape[1] == self.d_input, f"x.shape[1] must be {self.d_input}, but got {x.shape}"
 
 
78
  assert npad >= 0, "npad must be positive or zero"
79
 
80
  x = F.pad(x, (0, npad), "constant", 0)
 
1
+ from typing import Union
2
  import numpy as np
3
  import torch
4
  import torch.nn.functional as F
 
51
  ]
52
  )
53
 
54
+ self.conv_pre = weight_norm(
55
+ nn.Conv1d(self.d_noise, self.nc, 7, padding=3, padding_mode="reflect")
56
+ )
57
 
58
  self.conv_post = nn.Sequential(
59
  nn.LeakyReLU(0.2),
 
67
  def eps(self):
68
  return 1e-5
69
 
70
+ def forward(self, x: Tensor, y: Union[Tensor, None] = None, npad=10):
71
  """
72
  Args:
73
  x: (b c t), acoustic features
 
77
  """
78
  assert x.ndim == 3, "x must be 3D tensor"
79
  assert y is None or y.ndim == 2, "y must be 2D tensor"
80
+ assert (
81
+ x.shape[1] == self.d_input
82
+ ), f"x.shape[1] must be {self.d_input}, but got {x.shape}"
83
  assert npad >= 0, "npad must be positive or zero"
84
 
85
  x = F.pad(x, (0, npad), "constant", 0)
modules/repos_static/resemble_enhance/hparams.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  from dataclasses import asdict, dataclass
3
  from pathlib import Path
 
4
 
5
  from omegaconf import OmegaConf
6
  from rich.console import Console
@@ -102,7 +103,7 @@ class HParams:
102
  OmegaConf.save(asdict(self), str(path))
103
 
104
  @classmethod
105
- def load(cls, run_dir, yaml: Path | None = None):
106
  hps = []
107
 
108
  if (run_dir / "hparams.yaml").exists():
@@ -120,7 +121,9 @@ class HParams:
120
  for k, v in asdict(hp).items():
121
  if getattr(hps[0], k) != v:
122
  errors[k] = f"{getattr(hps[0], k)} != {v}"
123
- raise ValueError(f"Found inconsistent hparams: {errors}, consider deleting {run_dir}")
 
 
124
 
125
  return hps[0]
126
 
 
1
  import logging
2
  from dataclasses import asdict, dataclass
3
  from pathlib import Path
4
+ from typing import Union
5
 
6
  from omegaconf import OmegaConf
7
  from rich.console import Console
 
103
  OmegaConf.save(asdict(self), str(path))
104
 
105
  @classmethod
106
+ def load(cls, run_dir, yaml: Union[Path, None] = None):
107
  hps = []
108
 
109
  if (run_dir / "hparams.yaml").exists():
 
121
  for k, v in asdict(hp).items():
122
  if getattr(hps[0], k) != v:
123
  errors[k] = f"{getattr(hps[0], k)} != {v}"
124
+ raise ValueError(
125
+ f"Found inconsistent hparams: {errors}, consider deleting {run_dir}"
126
+ )
127
 
128
  return hps[0]
129
 
modules/speaker.py CHANGED
@@ -29,13 +29,15 @@ class Speaker:
29
  speaker.emb = tensor
30
  return speaker
31
 
32
- def __init__(self, seed, name="", gender="", describe=""):
 
 
33
  self.id = uuid.uuid4()
34
- self.seed = seed
35
  self.name = name
36
  self.gender = gender
37
  self.describe = describe
38
- self.emb = None
39
 
40
  # TODO replace emb => tokens
41
  self.tokens = []
 
29
  speaker.emb = tensor
30
  return speaker
31
 
32
+ def __init__(
33
+ self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
34
+ ):
35
  self.id = uuid.uuid4()
36
+ self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
37
  self.name = name
38
  self.gender = gender
39
  self.describe = describe
40
+ self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor
41
 
42
  # TODO replace emb => tokens
43
  self.tokens = []
modules/ssml_parser/SSMLParser.py CHANGED
@@ -11,8 +11,8 @@ import copy
11
 
12
 
13
  class SSMLContext(Box):
14
- def __init__(self, parent=None):
15
- self.parent: Union[SSMLContext, None] = parent
16
 
17
  self.style = None
18
  self.spk = None
@@ -29,18 +29,14 @@ class SSMLContext(Box):
29
  self.prompt2 = None
30
  self.prefix = None
31
 
32
- def clone(self):
33
- ctx = SSMLContext()
34
- for k, v in self.items():
35
- ctx[k] = v
36
- return ctx
37
 
38
 
39
  class SSMLSegment(Box):
40
- def __init__(self, text: str, attrs=SSMLContext()):
41
- self.attrs = attrs
42
  self.text = text
43
- self.params = None
44
 
45
 
46
  class SSMLBreak:
@@ -68,7 +64,7 @@ class SSMLParser:
68
  root = etree.fromstring(ssml)
69
 
70
  root_ctx = SSMLContext()
71
- segments = []
72
  self.resolve(root, root_ctx, segments)
73
 
74
  return segments
@@ -89,8 +85,13 @@ def create_ssml_parser():
89
  parser = SSMLParser()
90
 
91
  @parser.resolver("speak")
92
- def tag_speak(element, context, segments, parser):
93
- ctx = context.clone() if context is not None else SSMLContext()
 
 
 
 
 
94
 
95
  version = element.get("version")
96
  if version != "0.1":
@@ -100,8 +101,13 @@ def create_ssml_parser():
100
  parser.resolve(child, ctx, segments)
101
 
102
  @parser.resolver("voice")
103
- def tag_voice(element, context, segments, parser):
104
- ctx = context.clone() if context is not None else SSMLContext()
 
 
 
 
 
105
 
106
  ctx.spk = element.get("spk", ctx.spk)
107
  ctx.style = element.get("style", ctx.style)
@@ -131,13 +137,23 @@ def create_ssml_parser():
131
  segments.append(SSMLSegment(child.tail.strip(), ctx))
132
 
133
  @parser.resolver("break")
134
- def tag_break(element, context, segments, parser):
 
 
 
 
 
135
  time_ms = int(element.get("time", "0").replace("ms", ""))
136
  segments.append(SSMLBreak(time_ms))
137
 
138
  @parser.resolver("prosody")
139
- def tag_prosody(element, context, segments, parser):
140
- ctx = context.clone() if context is not None else SSMLContext()
 
 
 
 
 
141
 
142
  ctx.spk = element.get("spk", ctx.spk)
143
  ctx.style = element.get("style", ctx.style)
 
11
 
12
 
13
  class SSMLContext(Box):
14
+ def __init__(self, *args, **kwargs):
15
+ self.parent: Union[SSMLContext, None] = None
16
 
17
  self.style = None
18
  self.spk = None
 
29
  self.prompt2 = None
30
  self.prefix = None
31
 
32
+ super().__init__(*args, **kwargs)
 
 
 
 
33
 
34
 
35
  class SSMLSegment(Box):
36
+ def __init__(self, text: str, attrs=SSMLContext(), params=None):
37
+ self.attrs = SSMLContext(**attrs)
38
  self.text = text
39
+ self.params = params
40
 
41
 
42
  class SSMLBreak:
 
64
  root = etree.fromstring(ssml)
65
 
66
  root_ctx = SSMLContext()
67
+ segments: List[Union[SSMLSegment, SSMLBreak]] = []
68
  self.resolve(root, root_ctx, segments)
69
 
70
  return segments
 
85
  parser = SSMLParser()
86
 
87
  @parser.resolver("speak")
88
+ def tag_speak(
89
+ element: etree.Element,
90
+ context: Box,
91
+ segments: List[Union[SSMLSegment, SSMLBreak]],
92
+ parser: SSMLParser,
93
+ ):
94
+ ctx = context.copy() if context is not None else SSMLContext()
95
 
96
  version = element.get("version")
97
  if version != "0.1":
 
101
  parser.resolve(child, ctx, segments)
102
 
103
  @parser.resolver("voice")
104
+ def tag_voice(
105
+ element: etree.Element,
106
+ context: Box,
107
+ segments: List[Union[SSMLSegment, SSMLBreak]],
108
+ parser: SSMLParser,
109
+ ):
110
+ ctx = context.copy() if context is not None else SSMLContext()
111
 
112
  ctx.spk = element.get("spk", ctx.spk)
113
  ctx.style = element.get("style", ctx.style)
 
137
  segments.append(SSMLSegment(child.tail.strip(), ctx))
138
 
139
  @parser.resolver("break")
140
+ def tag_break(
141
+ element: etree.Element,
142
+ context: Box,
143
+ segments: List[Union[SSMLSegment, SSMLBreak]],
144
+ parser: SSMLParser,
145
+ ):
146
  time_ms = int(element.get("time", "0").replace("ms", ""))
147
  segments.append(SSMLBreak(time_ms))
148
 
149
  @parser.resolver("prosody")
150
+ def tag_prosody(
151
+ element: etree.Element,
152
+ context: Box,
153
+ segments: List[Union[SSMLSegment, SSMLBreak]],
154
+ parser: SSMLParser,
155
+ ):
156
+ ctx = context.copy() if context is not None else SSMLContext()
157
 
158
  ctx.spk = element.get("spk", ctx.spk)
159
  ctx.style = element.get("style", ctx.style)
modules/synthesize_audio.py CHANGED
@@ -7,6 +7,7 @@ from modules import generate_audio as generate
7
 
8
 
9
  from modules.speaker import Speaker
 
10
  from modules.utils import audio
11
 
12
 
@@ -23,45 +24,33 @@ def synthesize_audio(
23
  prefix: str = "",
24
  batch_size: int = 1,
25
  spliter_threshold: int = 100,
 
26
  ):
27
- if batch_size == 1:
28
- return generate.generate_audio(
29
- text,
30
- temperature=temperature,
31
- top_P=top_P,
32
- top_K=top_K,
33
- spk=spk,
34
- infer_seed=infer_seed,
35
- use_decoder=use_decoder,
36
- prompt1=prompt1,
37
- prompt2=prompt2,
38
- prefix=prefix,
 
 
 
 
 
39
  )
40
- else:
41
- spliter = SentenceSplitter(spliter_threshold)
42
- sentences = spliter.parse(text)
 
 
 
43
 
44
- text_segments = [
45
- {
46
- "text": s,
47
- "params": {
48
- "text": s,
49
- "temperature": temperature,
50
- "top_P": top_P,
51
- "top_K": top_K,
52
- "spk": spk,
53
- "infer_seed": infer_seed,
54
- "use_decoder": use_decoder,
55
- "prompt1": prompt1,
56
- "prompt2": prompt2,
57
- "prefix": prefix,
58
- },
59
- }
60
- for s in sentences
61
- ]
62
- synthesizer = SynthesizeSegments(batch_size)
63
- audio_segments = synthesizer.synthesize_segments(text_segments)
64
 
65
- combined_audio = combine_audio_segments(audio_segments)
66
-
67
- return audio.pydub_to_np(combined_audio)
 
7
 
8
 
9
  from modules.speaker import Speaker
10
+ from modules.ssml_parser.SSMLParser import SSMLSegment
11
  from modules.utils import audio
12
 
13
 
 
24
  prefix: str = "",
25
  batch_size: int = 1,
26
  spliter_threshold: int = 100,
27
+ end_of_sentence="",
28
  ):
29
+ spliter = SentenceSplitter(spliter_threshold)
30
+ sentences = spliter.parse(text)
31
+
32
+ text_segments = [
33
+ SSMLSegment(
34
+ text=s,
35
+ params={
36
+ "temperature": temperature,
37
+ "top_P": top_P,
38
+ "top_K": top_K,
39
+ "spk": spk,
40
+ "infer_seed": infer_seed,
41
+ "use_decoder": use_decoder,
42
+ "prompt1": prompt1,
43
+ "prompt2": prompt2,
44
+ "prefix": prefix,
45
+ },
46
  )
47
+ for s in sentences
48
+ ]
49
+ synthesizer = SynthesizeSegments(
50
+ batch_size=batch_size, eos=end_of_sentence, spliter_thr=spliter_threshold
51
+ )
52
+ audio_segments = synthesizer.synthesize_segments(text_segments)
53
 
54
+ combined_audio = combine_audio_segments(audio_segments)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
+ return audio.pydub_to_np(combined_audio)
 
 
modules/utils/audio.py CHANGED
@@ -95,7 +95,11 @@ def pitch_shift(
95
 
96
 
97
  def apply_prosody_to_audio_data(
98
- audio_data: np.ndarray, rate: float, volume: float, pitch: float, sr: int
 
 
 
 
99
  ) -> np.ndarray:
100
  if rate != 1:
101
  audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate)
 
95
 
96
 
97
  def apply_prosody_to_audio_data(
98
+ audio_data: np.ndarray,
99
+ rate: float = 1,
100
+ volume: float = 0,
101
+ pitch: float = 0,
102
+ sr: int = 24000,
103
  ) -> np.ndarray:
104
  if rate != 1:
105
  audio_data = pyrb.time_stretch(audio_data, sr=sr, rate=rate)
modules/webui/app.py CHANGED
@@ -7,6 +7,7 @@ from modules import config
7
  from modules.webui import gradio_extensions, webui_config
8
 
9
  from modules.webui.changelog_tab import create_changelog_tab
 
10
  from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
11
  from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab
12
  from modules.webui.system_tab import create_system_tab
@@ -118,6 +119,8 @@ def create_interface():
118
  gr.Markdown("🚧 Under construction")
119
  with gr.TabItem("ASR", visible=webui_config.experimental):
120
  gr.Markdown("🚧 Under construction")
 
 
121
 
122
  with gr.TabItem("System"):
123
  create_system_tab()
 
7
  from modules.webui import gradio_extensions, webui_config
8
 
9
  from modules.webui.changelog_tab import create_changelog_tab
10
+ from modules.webui.finetune.ft_tab import create_ft_tabs
11
  from modules.webui.localization_runtime import ENLocalizationVars, ZHLocalizationVars
12
  from modules.webui.ssml.podcast_tab import create_ssml_podcast_tab
13
  from modules.webui.system_tab import create_system_tab
 
119
  gr.Markdown("🚧 Under construction")
120
  with gr.TabItem("ASR", visible=webui_config.experimental):
121
  gr.Markdown("🚧 Under construction")
122
+ with gr.TabItem("Finetune", visible=webui_config.experimental):
123
+ create_ft_tabs(demo)
124
 
125
  with gr.TabItem("System"):
126
  create_system_tab()
modules/webui/finetune/ProcessMonitor.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import subprocess
4
+ import threading
5
+
6
+
7
+ class ProcessMonitor:
8
+ def __init__(self):
9
+ self.process = None
10
+ self.stdout = ""
11
+ self.stderr = ""
12
+ self.lock = threading.Lock()
13
+
14
+ def start_process(self, command):
15
+ self.process = subprocess.Popen(
16
+ command,
17
+ stdout=subprocess.PIPE,
18
+ stderr=subprocess.PIPE,
19
+ bufsize=1,
20
+ universal_newlines=True,
21
+ )
22
+
23
+ # Set pipes to non-blocking mode
24
+ fd_out = self.process.stdout.fileno()
25
+ fd_err = self.process.stderr.fileno()
26
+
27
+ if sys.platform != "win32":
28
+ import fcntl
29
+
30
+ fl_out = fcntl.fcntl(fd_out, fcntl.F_GETFL)
31
+ fl_err = fcntl.fcntl(fd_err, fcntl.F_GETFL)
32
+ fcntl.fcntl(fd_out, fcntl.F_SETFL, fl_out | os.O_NONBLOCK)
33
+ fcntl.fcntl(fd_err, fcntl.F_SETFL, fl_err | os.O_NONBLOCK)
34
+
35
+ # Start threads to read stdout and stderr
36
+ threading.Thread(target=self._read_stdout).start()
37
+ threading.Thread(target=self._read_stderr).start()
38
+
39
+ def _read_stdout(self):
40
+ while self.process is not None and self.process.poll() is None:
41
+ try:
42
+ output = self.process.stdout.read()
43
+ if output:
44
+ with self.lock:
45
+ self.stdout += output
46
+ except:
47
+ pass
48
+
49
+ def _read_stderr(self):
50
+ while self.process is not None and self.process.poll() is None:
51
+ try:
52
+ error = self.process.stderr.read()
53
+ if error:
54
+ with self.lock:
55
+ self.stderr += error
56
+ except:
57
+ pass
58
+
59
+ def get_output(self):
60
+ with self.lock:
61
+ return self.stdout, self.stderr
62
+
63
+ def stop_process(self):
64
+ if self.process:
65
+ self.process.terminate()
66
+ self.process = None
67
+
68
+
69
+ if __name__ == "__main__":
70
+ import time
71
+
72
+ pm = ProcessMonitor()
73
+ pm.start_process(
74
+ [
75
+ "python",
76
+ "-u",
77
+ "-c",
78
+ "import time; [print(i) or time.sleep(1) for i in range(5)]",
79
+ ]
80
+ )
81
+
82
+ while pm.process and pm.process.poll() is None:
83
+ stdout, stderr = pm.get_output()
84
+ if stdout:
85
+ print("STDOUT:", stdout)
86
+ if stderr:
87
+ print("STDERR:", stderr)
88
+ time.sleep(1)
89
+
90
+ stdout, stderr = pm.get_output()
91
+ print("Final STDOUT:", stdout)
92
+ print("Final STDERR:", stderr)
modules/webui/finetune/ft_tab.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules.webui.finetune.speaker_ft_tab import create_speaker_ft_tab
4
+
5
+
6
+ def create_ft_tabs(demo):
7
+ with gr.Tabs():
8
+ with gr.TabItem("Speaker"):
9
+ create_speaker_ft_tab(demo)
10
+ with gr.TabItem("GPT"):
11
+ gr.Markdown("🚧 Under construction")
12
+ with gr.TabItem("AE"):
13
+ gr.Markdown("🚧 Under construction")
modules/webui/finetune/ft_ui_utils.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import IO, Union
3
+ from modules.speaker import Speaker, speaker_mgr
4
+ import subprocess
5
+
6
+
7
+ def get_datasets_dir():
8
+ """
9
+ 列出 ./datasets/data_* 文件夹
10
+ """
11
+ dataset_path = "./datasets"
12
+ dataset_list = os.listdir(dataset_path)
13
+ dataset_list = [
14
+ d for d in dataset_list if os.path.isdir(os.path.join(dataset_path, d))
15
+ ]
16
+ dataset_list = [d for d in dataset_list if d.startswith("data_")]
17
+ return dataset_list
18
+
19
+
20
+ def get_datasets_listfile():
21
+ datasets = get_datasets_dir()
22
+ listfiles = []
23
+ for d in datasets:
24
+ dir_path = os.path.join("./datasets", d)
25
+ files = os.listdir(dir_path)
26
+ for f in files:
27
+ if f.endswith(".list"):
28
+ listfiles.append(os.path.join(dir_path, f))
29
+ return listfiles
30
+
31
+
32
+ def run_speaker_ft(
33
+ batch_size: int, epochs: int, train_text: bool, data_path: str, init_speaker: str
34
+ ):
35
+ command = ["python3", "-m", "modules.finetune.train_speaker"]
36
+ command += [
37
+ f"--batch_size={batch_size}",
38
+ f"--epochs={epochs}",
39
+ f"--data_path={data_path}",
40
+ ]
41
+ if train_text:
42
+ command.append("--train_text")
43
+ if init_speaker:
44
+ command.append(f"--init_speaker={init_speaker}")
45
+ process = subprocess.Popen(
46
+ command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1
47
+ )
48
+
49
+ return process
modules/webui/finetune/speaker_ft_tab.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from modules.Enhancer.ResembleEnhance import unload_enhancer
4
+ from modules.webui import webui_config
5
+ from modules.webui.webui_utils import get_speaker_names
6
+ from .ft_ui_utils import get_datasets_listfile, run_speaker_ft
7
+ from .ProcessMonitor import ProcessMonitor
8
+ from modules.speaker import speaker_mgr
9
+ from modules.models import unload_chat_tts
10
+
11
+
12
+ class SpeakerFt:
13
+ def __init__(self):
14
+ self.process_monitor = ProcessMonitor()
15
+ self.status_str = "idle"
16
+
17
+ def unload_main_thread_models(self):
18
+ unload_chat_tts()
19
+ unload_enhancer()
20
+
21
+ def run(
22
+ self,
23
+ batch_size: int,
24
+ epochs: int,
25
+ lr: str,
26
+ train_text: bool,
27
+ data_path: str,
28
+ select_speaker: str = "",
29
+ ):
30
+ if self.process_monitor.process:
31
+ return
32
+ self.unload_main_thread_models()
33
+ spk_path = None
34
+ if select_speaker != "" and select_speaker != "none":
35
+ select_speaker = select_speaker.split(" : ")[1].strip()
36
+ spk = speaker_mgr.get_speaker(select_speaker)
37
+ if spk is None:
38
+ return ["Speaker not found"]
39
+ spk_filename = speaker_mgr.get_speaker_filename(spk.id)
40
+ spk_path = f"./data/speakers/{spk_filename}"
41
+
42
+ command = ["python3", "-m", "modules.finetune.train_speaker"]
43
+ command += [
44
+ f"--batch_size={batch_size}",
45
+ f"--epochs={epochs}",
46
+ f"--data_path={data_path}",
47
+ ]
48
+ if train_text:
49
+ command.append("--train_text")
50
+ if spk_path:
51
+ command.append(f"--init_speaker={spk_path}")
52
+
53
+ self.status("Training process starting")
54
+
55
+ self.process_monitor.start_process(command)
56
+
57
+ self.status("Training started")
58
+
59
+ def status(self, text: str):
60
+ self.status_str = text
61
+
62
+ def flush(self):
63
+ stdout, stderr = self.process_monitor.get_output()
64
+ return f"{self.status_str}\n{stdout}\n{stderr}"
65
+
66
+ def clear(self):
67
+ self.process_monitor.stdout = ""
68
+ self.process_monitor.stderr = ""
69
+ self.status("Logs cleared")
70
+
71
+ def stop(self):
72
+ self.process_monitor.stop_process()
73
+ self.status("Training stopped")
74
+
75
+
76
+ def create_speaker_ft_tab(demo: gr.Blocks):
77
+ spk_ft = SpeakerFt()
78
+ speakers, speaker_names = get_speaker_names()
79
+ speaker_names = ["none"] + speaker_names
80
+
81
+ with gr.Row():
82
+ with gr.Column(scale=2):
83
+ with gr.Group():
84
+ gr.Markdown("🎛️hparams")
85
+ dataset_input = gr.Dropdown(
86
+ label="Dataset", choices=get_datasets_listfile()
87
+ )
88
+ lr_input = gr.Textbox(label="Learning Rate", value="1e-2")
89
+ epochs_input = gr.Slider(
90
+ label="Epochs", value=10, minimum=1, maximum=100, step=1
91
+ )
92
+ batch_size_input = gr.Slider(
93
+ label="Batch Size", value=4, minimum=1, maximum=64, step=1
94
+ )
95
+ train_text_checkbox = gr.Checkbox(label="Train text_loss", value=True)
96
+ init_spk_dropdown = gr.Dropdown(
97
+ label="Initial Speaker",
98
+ choices=speaker_names,
99
+ value="none",
100
+ )
101
+
102
+ with gr.Group():
103
+ start_train_btn = gr.Button("Start Training")
104
+ stop_train_btn = gr.Button("Stop Training")
105
+ clear_train_btn = gr.Button("Clear logs")
106
+ with gr.Column(scale=5):
107
+ with gr.Group():
108
+ # log
109
+ gr.Markdown("📜logs")
110
+ log_output = gr.Textbox(
111
+ show_label=False, label="Log", value="", lines=20, interactive=True
112
+ )
113
+
114
+ start_train_btn.click(
115
+ spk_ft.run,
116
+ inputs=[
117
+ batch_size_input,
118
+ epochs_input,
119
+ lr_input,
120
+ train_text_checkbox,
121
+ dataset_input,
122
+ init_spk_dropdown,
123
+ ],
124
+ outputs=[],
125
+ )
126
+ stop_train_btn.click(spk_ft.stop)
127
+ clear_train_btn.click(spk_ft.clear)
128
+
129
+ if webui_config.experimental:
130
+ demo.load(spk_ft.flush, every=1, outputs=[log_output])
modules/webui/localization_runtime.py CHANGED
@@ -7,6 +7,7 @@ class LocalizationVars:
7
 
8
  self.ssml_examples = []
9
  self.tts_examples = []
 
10
 
11
 
12
  class ZHLocalizationVars(LocalizationVars):
@@ -167,6 +168,69 @@ class ZHLocalizationVars(LocalizationVars):
167
  },
168
  ]
169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  class ENLocalizationVars(LocalizationVars):
172
  def __init__(self):
@@ -224,3 +288,65 @@ class ENLocalizationVars(LocalizationVars):
224
  "text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.",
225
  },
226
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  self.ssml_examples = []
9
  self.tts_examples = []
10
+ self.podcast_default = []
11
 
12
 
13
  class ZHLocalizationVars(LocalizationVars):
 
168
  },
169
  ]
170
 
171
+ self.podcast_default = [
172
+ [
173
+ 1,
174
+ "female2",
175
+ "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。",
176
+ "podcast",
177
+ ],
178
+ [
179
+ 2,
180
+ "Alice",
181
+ "嗨,我特别期待这个话题!中华料理真的是博大精深。",
182
+ "podcast",
183
+ ],
184
+ [
185
+ 3,
186
+ "Bob",
187
+ "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。",
188
+ "podcast",
189
+ ],
190
+ [
191
+ 4,
192
+ "female2",
193
+ "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。",
194
+ "podcast",
195
+ ],
196
+ [
197
+ 5,
198
+ "Alice",
199
+ "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。",
200
+ "podcast",
201
+ ],
202
+ [
203
+ 6,
204
+ "Bob",
205
+ "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。",
206
+ "podcast",
207
+ ],
208
+ [
209
+ 7,
210
+ "female2",
211
+ "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。",
212
+ "podcast",
213
+ ],
214
+ [
215
+ 8,
216
+ "Alice",
217
+ "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。",
218
+ "podcast",
219
+ ],
220
+ [
221
+ 9,
222
+ "Bob",
223
+ "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。",
224
+ "podcast",
225
+ ],
226
+ [
227
+ 10,
228
+ "female2",
229
+ "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。",
230
+ "podcast",
231
+ ],
232
+ ]
233
+
234
 
235
  class ENLocalizationVars(LocalizationVars):
236
  def __init__(self):
 
288
  "text": "Don't ever let somebody tell you you can't do something. Not even me. Alright? You got a dream, you gotta protect it. When people can't do something themselves, they're gonna tell you that you can't do it. You want something, go get it. Period.",
289
  },
290
  ]
291
+ self.podcast_default = [
292
+ [
293
+ 1,
294
+ "female2",
295
+ "Hello, welcome to today's podcast. Today, we're going to talk about global cuisine.",
296
+ "podcast",
297
+ ],
298
+ [
299
+ 2,
300
+ "Alice",
301
+ "Hi, I'm really excited about this topic! Global cuisine is incredibly diverse and fascinating.",
302
+ "podcast",
303
+ ],
304
+ [
305
+ 3,
306
+ "Bob",
307
+ "Absolutely, every country has its own unique culinary traditions and specialties.",
308
+ "podcast",
309
+ ],
310
+ [
311
+ 4,
312
+ "female2",
313
+ "Let's start with Italian cuisine. Italian food is loved worldwide, especially for its pasta and pizza.",
314
+ "podcast",
315
+ ],
316
+ [
317
+ 5,
318
+ "Alice",
319
+ "Yes, I especially love a good Margherita pizza and a hearty plate of spaghetti carbonara. The flavors are simply amazing.",
320
+ "podcast",
321
+ ],
322
+ [
323
+ 6,
324
+ "Bob",
325
+ "Besides Italian cuisine, Japanese cuisine is also very popular. Dishes like sushi and ramen have become global favorites.",
326
+ "podcast",
327
+ ],
328
+ [
329
+ 7,
330
+ "female2",
331
+ "Exactly, Japanese cuisine is known for its emphasis on fresh ingredients and delicate presentation.",
332
+ "podcast",
333
+ ],
334
+ [
335
+ 8,
336
+ "Alice",
337
+ "And then there's Mexican cuisine, with its bold flavors and colorful dishes like tacos and guacamole.",
338
+ "podcast",
339
+ ],
340
+ [
341
+ 9,
342
+ "Bob",
343
+ "Not to mention, there's also Indian cuisine, Thai cuisine, French cuisine, and so many more, each with its own distinctive flavors and techniques.",
344
+ "podcast",
345
+ ],
346
+ [
347
+ 10,
348
+ "female2",
349
+ "Yes, like Indian curry, Thai tom yum soup, and French croissants, these are all mouth-watering dishes that are loved by people all over the world.",
350
+ "podcast",
351
+ ],
352
+ ]
modules/webui/ssml/podcast_tab.py CHANGED
@@ -3,72 +3,9 @@ import pandas as pd
3
  import torch
4
 
5
  from modules.normalization import text_normalize
6
- from modules.webui import webui_utils
7
  from modules.utils.hf import spaces
8
 
9
- podcast_default_case = [
10
- [
11
- 1,
12
- "female2",
13
- "你好,欢迎收听今天的播客内容。今天我们要聊的是中华料理。 [lbreak]",
14
- "podcast",
15
- ],
16
- [
17
- 2,
18
- "Alice",
19
- "嗨,我特别期待这个话题!中华料理真的是博大精深。 [lbreak]",
20
- "podcast",
21
- ],
22
- [
23
- 3,
24
- "Bob",
25
- "没错,中华料理有着几千年的历史,而且每个地区都有自己的特色菜。 [lbreak]",
26
- "podcast",
27
- ],
28
- [
29
- 4,
30
- "female2",
31
- "那我们先从最有名的川菜开始吧。川菜以其麻辣著称,是很多人的最爱。 [lbreak]",
32
- "podcast",
33
- ],
34
- [
35
- 5,
36
- "Alice",
37
- "对,我特别喜欢吃麻婆豆腐和辣子鸡。那种麻辣的感觉真是让人难以忘怀。 [lbreak]",
38
- "podcast",
39
- ],
40
- [
41
- 6,
42
- "Bob",
43
- "除了川菜,粤菜也是很受欢迎的。粤菜讲究鲜美,像是白切鸡和蒸鱼都是经典。 [lbreak]",
44
- "podcast",
45
- ],
46
- [
47
- 7,
48
- "female2",
49
- "对啊,粤菜的烹饪方式比较清淡,更注重食材本身的味道。 [lbreak]",
50
- "podcast",
51
- ],
52
- [
53
- 8,
54
- "Alice",
55
- "还有北京的京菜,像北京烤鸭,那可是来北京必吃的美食。 [lbreak]",
56
- "podcast",
57
- ],
58
- [
59
- 9,
60
- "Bob",
61
- "不仅如此,还有淮扬菜、湘菜、鲁菜等等,每个菜系都有其独特的风味。 [lbreak]",
62
- "podcast",
63
- ],
64
- [
65
- 10,
66
- "female2",
67
- "对对对,像淮扬菜的狮子头,湘菜的剁椒鱼头,都是让人垂涎三尺的美味。 [lbreak]",
68
- "podcast",
69
- ],
70
- ]
71
-
72
 
73
  # NOTE: 因为 text_normalize 需要使用 tokenizer
74
  @torch.inference_mode()
@@ -133,7 +70,7 @@ def create_ssml_podcast_tab(ssml_input: gr.Textbox, tabs1: gr.Tabs, tabs2: gr.Ta
133
  datatype=["number", "str", "str", "str"],
134
  interactive=True,
135
  wrap=True,
136
- value=podcast_default_case,
137
  row_count=(0, "dynamic"),
138
  col_count=(4, "fixed"),
139
  )
 
3
  import torch
4
 
5
  from modules.normalization import text_normalize
6
+ from modules.webui import webui_config, webui_utils
7
  from modules.utils.hf import spaces
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  # NOTE: 因为 text_normalize 需要使用 tokenizer
11
  @torch.inference_mode()
 
70
  datatype=["number", "str", "str", "str"],
71
  interactive=True,
72
  wrap=True,
73
+ value=webui_config.localization.podcast_default,
74
  row_count=(0, "dynamic"),
75
  col_count=(4, "fixed"),
76
  )
modules/webui/ssml/ssml_tab.py CHANGED
@@ -22,7 +22,6 @@ def create_ssml_interface():
22
  ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
23
  with gr.Column(scale=1):
24
  with gr.Group():
25
- # 参数
26
  gr.Markdown("🎛️Parameters")
27
  # batch size
28
  batch_size_input = gr.Slider(
@@ -32,6 +31,19 @@ def create_ssml_interface():
32
  maximum=webui_config.max_batch_size,
33
  step=1,
34
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  with gr.Group():
37
  gr.Markdown("💪🏼Enhance")
@@ -49,7 +61,14 @@ def create_ssml_interface():
49
 
50
  ssml_button.click(
51
  synthesize_ssml,
52
- inputs=[ssml_input, batch_size_input, enable_enhance, enable_de_noise],
 
 
 
 
 
 
 
53
  outputs=ssml_output,
54
  )
55
 
 
22
  ssml_button = gr.Button("🔊Synthesize SSML", variant="primary")
23
  with gr.Column(scale=1):
24
  with gr.Group():
 
25
  gr.Markdown("🎛️Parameters")
26
  # batch size
27
  batch_size_input = gr.Slider(
 
31
  maximum=webui_config.max_batch_size,
32
  step=1,
33
  )
34
+ with gr.Group():
35
+ gr.Markdown("🎛️Spliter")
36
+ eos_input = gr.Textbox(
37
+ label="eos",
38
+ value="[uv_break]",
39
+ )
40
+ spliter_thr_input = gr.Slider(
41
+ label="Spliter Threshold",
42
+ value=100,
43
+ minimum=50,
44
+ maximum=1000,
45
+ step=1,
46
+ )
47
 
48
  with gr.Group():
49
  gr.Markdown("💪🏼Enhance")
 
61
 
62
  ssml_button.click(
63
  synthesize_ssml,
64
+ inputs=[
65
+ ssml_input,
66
+ batch_size_input,
67
+ enable_enhance,
68
+ enable_de_noise,
69
+ eos_input,
70
+ spliter_thr_input,
71
+ ],
72
  outputs=ssml_output,
73
  )
74