mrfakename committed on
Commit
48c079f
1 Parent(s): b2e5882

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there rather than to this Space.

README_REPO.md CHANGED
@@ -147,11 +147,11 @@ Note: Some model components have linting exceptions for E722 to accommodate tens

  ## Acknowledgements

  - [E2-TTS](https://arxiv.org/abs/2406.18009) brilliant work, simple and effective
- - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763) valuable datasets
+ - [Emilia](https://arxiv.org/abs/2407.05361), [WenetSpeech4TTS](https://arxiv.org/abs/2406.05763), [LibriTTS](https://arxiv.org/abs/1904.02882), [LJSpeech](https://keithito.com/LJ-Speech-Dataset/) valuable datasets
  - [lucidrains](https://github.com/lucidrains) initial CFM structure with also [bfs18](https://github.com/bfs18) for discussion
  - [SD3](https://arxiv.org/abs/2403.03206) & [Hugging Face diffusers](https://github.com/huggingface/diffusers) DiT and MMDiT code structure
- - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) as vocoder
+ - [torchdiffeq](https://github.com/rtqichen/torchdiffeq) as ODE solver, [Vocos](https://huggingface.co/charactr/vocos-mel-24khz) and [BigVGAN](https://github.com/NVIDIA/BigVGAN) as vocoder
- - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech) for evaluation tools
+ - [FunASR](https://github.com/modelscope/FunASR), [faster-whisper](https://github.com/SYSTRAN/faster-whisper), [UniSpeech](https://github.com/microsoft/UniSpeech), [SpeechMOS](https://github.com/tarepan/SpeechMOS) for evaluation tools
  - [ctc-forced-aligner](https://github.com/MahmoudAshraf97/ctc-forced-aligner) for speech edit test
  - [mrfakename](https://x.com/realmrfakename) huggingface space demo ~
  - [f5-tts-mlx](https://github.com/lucasnewman/f5-tts-mlx/tree/main) Implementation with MLX framework by [Lucas Newman](https://github.com/lucasnewman)
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # ruff: noqa: E402
2
  # Above allows ruff to ignore E402: module level import not at top of file
3
 
 
4
  import re
5
  import tempfile
6
  from collections import OrderedDict
@@ -43,6 +44,12 @@ from f5_tts.infer.utils_infer import (
43
  DEFAULT_TTS_MODEL = "F5-TTS"
44
  tts_model_choice = DEFAULT_TTS_MODEL
45
46
 
47
  # load models
48
 
@@ -103,7 +110,15 @@ def generate_response(messages, model, tokenizer):
103
 
104
  @gpu_decorator
105
  def infer(
106
- ref_audio_orig, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15, speed=1, show_info=gr.Info
107
  ):
108
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
109
 
@@ -120,7 +135,7 @@ def infer(
120
  global custom_ema_model, pre_custom_path
121
  if pre_custom_path != model[1]:
122
  show_info("Loading Custom TTS model...")
123
- custom_ema_model = load_custom(model[1], vocab_path=model[2])
124
  pre_custom_path = model[1]
125
  ema_model = custom_ema_model
126
 
@@ -131,6 +146,7 @@ def infer(
131
  ema_model,
132
  vocoder,
133
  cross_fade_duration=cross_fade_duration,
 
134
  speed=speed,
135
  show_info=show_info,
136
  progress=gr.Progress(),
@@ -184,6 +200,14 @@ with gr.Blocks() as app_tts:
184
  step=0.1,
185
  info="Adjust the speed of the audio.",
186
  )
187
  cross_fade_duration_slider = gr.Slider(
188
  label="Cross-Fade Duration (s)",
189
  minimum=0.0,
@@ -203,6 +227,7 @@ with gr.Blocks() as app_tts:
203
  gen_text_input,
204
  remove_silence,
205
  cross_fade_duration_slider,
 
206
  speed_slider,
207
  ):
208
  audio_out, spectrogram_path, ref_text_out = infer(
@@ -211,8 +236,9 @@ with gr.Blocks() as app_tts:
211
  gen_text_input,
212
  tts_model_choice,
213
  remove_silence,
214
- cross_fade_duration_slider,
215
- speed_slider,
 
216
  )
217
  return audio_out, spectrogram_path, gr.update(value=ref_text_out)
218
 
@@ -224,6 +250,7 @@ with gr.Blocks() as app_tts:
224
  gen_text_input,
225
  remove_silence,
226
  cross_fade_duration_slider,
 
227
  speed_slider,
228
  ],
229
  outputs=[audio_output, spectrogram_output, ref_text_input],
@@ -744,34 +771,38 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
744
  """
745
  )
746
 
747
- last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom.txt")
748
 
749
  def load_last_used_custom():
750
  try:
751
- with open(last_used_custom, "r") as f:
752
- return f.read().split(",")
 
 
 
753
  except FileNotFoundError:
754
  last_used_custom.parent.mkdir(parents=True, exist_ok=True)
755
- return [
756
- "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
757
- "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
758
- ]
759
 
760
  def switch_tts_model(new_choice):
761
  global tts_model_choice
762
  if new_choice == "Custom": # override in case webpage is refreshed
763
- custom_ckpt_path, custom_vocab_path = load_last_used_custom()
764
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
765
- return gr.update(visible=True, value=custom_ckpt_path), gr.update(visible=True, value=custom_vocab_path)
 
 
 
 
766
  else:
767
  tts_model_choice = new_choice
768
- return gr.update(visible=False), gr.update(visible=False)
769
 
770
- def set_custom_model(custom_ckpt_path, custom_vocab_path):
771
  global tts_model_choice
772
- tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path]
773
- with open(last_used_custom, "w") as f:
774
- f.write(f"{custom_ckpt_path},{custom_vocab_path}")
775
 
776
  with gr.Row():
777
  if not USING_SPACES:
@@ -783,34 +814,46 @@ If you're having issues, try converting your reference audio to WAV or MP3, clip
783
  choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
784
  )
785
  custom_ckpt_path = gr.Dropdown(
786
- choices=["hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"],
787
  value=load_last_used_custom()[0],
788
  allow_custom_value=True,
789
- label="MODEL CKPT: local_path | hf://user_id/repo_id/model_ckpt",
790
  visible=False,
791
  )
792
  custom_vocab_path = gr.Dropdown(
793
- choices=["hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt"],
794
  value=load_last_used_custom()[1],
795
  allow_custom_value=True,
796
- label="VOCAB FILE: local_path | hf://user_id/repo_id/vocab_file",
797
  visible=False,
798
  )
799
 
800
  choose_tts_model.change(
801
  switch_tts_model,
802
  inputs=[choose_tts_model],
803
- outputs=[custom_ckpt_path, custom_vocab_path],
804
  show_progress="hidden",
805
  )
806
  custom_ckpt_path.change(
807
  set_custom_model,
808
- inputs=[custom_ckpt_path, custom_vocab_path],
809
  show_progress="hidden",
810
  )
811
  custom_vocab_path.change(
812
  set_custom_model,
813
- inputs=[custom_ckpt_path, custom_vocab_path],
814
  show_progress="hidden",
815
  )
816
 
 
1
  # ruff: noqa: E402
2
  # Above allows ruff to ignore E402: module level import not at top of file
3
 
4
+ import json
5
  import re
6
  import tempfile
7
  from collections import OrderedDict
 
44
  DEFAULT_TTS_MODEL = "F5-TTS"
45
  tts_model_choice = DEFAULT_TTS_MODEL
46
 
47
+ DEFAULT_TTS_MODEL_CFG = [
48
+ "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
49
+ "hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt",
50
+ json.dumps(dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)),
51
+ ]
52
+
53
 
54
  # load models
55
 
 
110
 
111
  @gpu_decorator
112
  def infer(
113
+ ref_audio_orig,
114
+ ref_text,
115
+ gen_text,
116
+ model,
117
+ remove_silence,
118
+ cross_fade_duration=0.15,
119
+ nfe_step=32,
120
+ speed=1,
121
+ show_info=gr.Info,
122
  ):
123
  ref_audio, ref_text = preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=show_info)
124
 
 
135
  global custom_ema_model, pre_custom_path
136
  if pre_custom_path != model[1]:
137
  show_info("Loading Custom TTS model...")
138
+ custom_ema_model = load_custom(model[1], vocab_path=model[2], model_cfg=model[3])
139
  pre_custom_path = model[1]
140
  ema_model = custom_ema_model
141
 
 
146
  ema_model,
147
  vocoder,
148
  cross_fade_duration=cross_fade_duration,
149
+ nfe_step=nfe_step,
150
  speed=speed,
151
  show_info=show_info,
152
  progress=gr.Progress(),
 
200
  step=0.1,
201
  info="Adjust the speed of the audio.",
202
  )
203
+ nfe_slider = gr.Slider(
204
+ label="NFE Steps",
205
+ minimum=4,
206
+ maximum=64,
207
+ value=32,
208
+ step=2,
209
+ info="Set the number of denoising steps.",
210
+ )
211
  cross_fade_duration_slider = gr.Slider(
212
  label="Cross-Fade Duration (s)",
213
  minimum=0.0,
 
227
  gen_text_input,
228
  remove_silence,
229
  cross_fade_duration_slider,
230
+ nfe_slider,
231
  speed_slider,
232
  ):
233
  audio_out, spectrogram_path, ref_text_out = infer(
 
236
  gen_text_input,
237
  tts_model_choice,
238
  remove_silence,
239
+ cross_fade_duration=cross_fade_duration_slider,
240
+ nfe_step=nfe_slider,
241
+ speed=speed_slider,
242
  )
243
  return audio_out, spectrogram_path, gr.update(value=ref_text_out)
244
 
 
250
  gen_text_input,
251
  remove_silence,
252
  cross_fade_duration_slider,
253
+ nfe_slider,
254
  speed_slider,
255
  ],
256
  outputs=[audio_output, spectrogram_output, ref_text_input],
 
771
  """
772
  )
773
 
774
+ last_used_custom = files("f5_tts").joinpath("infer/.cache/last_used_custom_model_info.txt")
775
 
776
  def load_last_used_custom():
777
  try:
778
+ custom = []
779
+ with open(last_used_custom, "r", encoding="utf-8") as f:
780
+ for line in f:
781
+ custom.append(line.strip())
782
+ return custom
783
  except FileNotFoundError:
784
  last_used_custom.parent.mkdir(parents=True, exist_ok=True)
785
+ return DEFAULT_TTS_MODEL_CFG
 
 
 
786
 
787
  def switch_tts_model(new_choice):
788
  global tts_model_choice
789
  if new_choice == "Custom": # override in case webpage is refreshed
790
+ custom_ckpt_path, custom_vocab_path, custom_model_cfg = load_last_used_custom()
791
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
792
+ return (
793
+ gr.update(visible=True, value=custom_ckpt_path),
794
+ gr.update(visible=True, value=custom_vocab_path),
795
+ gr.update(visible=True, value=custom_model_cfg),
796
+ )
797
  else:
798
  tts_model_choice = new_choice
799
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
800
 
801
+ def set_custom_model(custom_ckpt_path, custom_vocab_path, custom_model_cfg):
802
  global tts_model_choice
803
+ tts_model_choice = ["Custom", custom_ckpt_path, custom_vocab_path, json.loads(custom_model_cfg)]
804
+ with open(last_used_custom, "w", encoding="utf-8") as f:
805
+ f.write(custom_ckpt_path + "\n" + custom_vocab_path + "\n" + custom_model_cfg + "\n")
806
 
807
  with gr.Row():
808
  if not USING_SPACES:
 
814
  choices=[DEFAULT_TTS_MODEL, "E2-TTS"], label="Choose TTS Model", value=DEFAULT_TTS_MODEL
815
  )
816
  custom_ckpt_path = gr.Dropdown(
817
+ choices=[DEFAULT_TTS_MODEL_CFG[0]],
818
  value=load_last_used_custom()[0],
819
  allow_custom_value=True,
820
+ label="Model: local_path | hf://user_id/repo_id/model_ckpt",
821
  visible=False,
822
  )
823
  custom_vocab_path = gr.Dropdown(
824
+ choices=[DEFAULT_TTS_MODEL_CFG[1]],
825
  value=load_last_used_custom()[1],
826
  allow_custom_value=True,
827
+ label="Vocab: local_path | hf://user_id/repo_id/vocab_file",
828
+ visible=False,
829
+ )
830
+ custom_model_cfg = gr.Dropdown(
831
+ choices=[DEFAULT_TTS_MODEL_CFG[2]],
832
+ value=load_last_used_custom()[2],
833
+ allow_custom_value=True,
834
+ label="Config: in a dictionary form",
835
  visible=False,
836
  )
837
 
838
  choose_tts_model.change(
839
  switch_tts_model,
840
  inputs=[choose_tts_model],
841
+ outputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
842
  show_progress="hidden",
843
  )
844
  custom_ckpt_path.change(
845
  set_custom_model,
846
+ inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
847
  show_progress="hidden",
848
  )
849
  custom_vocab_path.change(
850
  set_custom_model,
851
+ inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
852
+ show_progress="hidden",
853
+ )
854
+ custom_model_cfg.change(
855
+ set_custom_model,
856
+ inputs=[custom_ckpt_path, custom_vocab_path, custom_model_cfg],
857
  show_progress="hidden",
858
  )
859
 
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -28,6 +28,7 @@ model:
  ff_mult: 2
  text_dim: 512
  conv_layers: 4
+ checkpoint_activations: False # recompute activations and save memory for extra compute
  mel_spec:
  target_sample_rate: 24000
  n_mel_channels: 100
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -28,6 +28,7 @@ model:
  ff_mult: 2
  text_dim: 512
  conv_layers: 4
+ checkpoint_activations: False # recompute activations and save memory for extra compute
  mel_spec:
  target_sample_rate: 24000
  n_mel_channels: 100
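Both training configs gain a `checkpoint_activations` flag, and the `dit.py` change later in this commit wires it up: when enabled, each transformer block runs under activation (gradient) checkpointing, which discards intermediate activations in the forward pass and recomputes them during backward, trading extra compute for lower memory. A minimal, self-contained sketch of the idea with `torch.utils.checkpoint` (the toy block, dimensions, and model below are illustrative, not the F5-TTS DiT):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, dim=1024):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, dim=1024, depth=22, checkpoint_activations=False):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
        self.checkpoint_activations = checkpoint_activations

    def forward(self, x):
        for block in self.blocks:
            if self.checkpoint_activations and self.training:
                # activations inside the block are recomputed in backward instead of stored
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


x = torch.randn(2, 256, 1024, requires_grad=True)
ToyModel(checkpoint_activations=True).train()(x).sum().backward()  # lower peak memory, extra forward compute
```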
src/f5_tts/eval/README.md CHANGED
@@ -39,11 +39,14 @@ Then update in the following scripts with the paths you put evaluation model ckp

  ### Objective Evaluation

- Update the path with your batch-inferenced results, and carry out WER / SIM evaluations:
+ Update the path with your batch-inferenced results, and carry out WER / SIM / UTMOS evaluations:
  ```bash
- # Evaluation for Seed-TTS test set
- python src/f5_tts/eval/eval_seedtts_testset.py --gen_wav_dir <GEN_WAVE_DIR>
+ # Evaluation [WER] for Seed-TTS test [ZH] set
+ python src/f5_tts/eval/eval_seedtts_testset.py --eval_task wer --lang zh --gen_wav_dir <GEN_WAV_DIR> --gpu_nums 8

- # Evaluation for LibriSpeech-PC test-clean (cross-sentence)
- python src/f5_tts/eval/eval_librispeech_test_clean.py --gen_wav_dir <GEN_WAVE_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
- ```
+ # Evaluation [SIM] for LibriSpeech-PC test-clean (cross-sentence)
+ python src/f5_tts/eval/eval_librispeech_test_clean.py --eval_task sim --gen_wav_dir <GEN_WAV_DIR> --librispeech_test_clean_path <TEST_CLEAN_PATH>
+
+ # Evaluation [UTMOS]. --ext: Audio extension
+ python src/f5_tts/eval/eval_utmos.py --audio_dir <WAV_DIR> --ext wav
+ ```
src/f5_tts/eval/eval_librispeech_test_clean.py CHANGED
@@ -1,8 +1,9 @@
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
3
- import sys
4
- import os
5
  import argparse
 
 
 
6
 
7
  sys.path.append(os.getcwd())
8
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
10
  from importlib.resources import files
11
 
12
  import numpy as np
13
-
14
  from f5_tts.eval.utils_eval import (
15
  get_librispeech_test,
16
  run_asr_wer,
@@ -54,29 +54,41 @@ def main():
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
  # --------------------------- WER ---------------------------
 
57
  if eval_task == "wer":
 
58
  wers = []
 
59
  with mp.Pool(processes=len(gpus)) as pool:
60
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
61
  results = pool.map(run_asr_wer, args)
62
- for wers_ in results:
63
- wers.extend(wers_)
64
 
65
  wer = round(np.mean(wers) * 100, 3)
66
  print(f"\nTotal {len(wers)} samples")
67
  print(f"WER : {wer}%")
 
68
 
69
  # --------------------------- SIM ---------------------------
 
70
  if eval_task == "sim":
71
- sim_list = []
72
  with mp.Pool(processes=len(gpus)) as pool:
73
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
74
  results = pool.map(run_sim, args)
75
- for sim_ in results:
76
- sim_list.extend(sim_)
77
 
78
- sim = round(sum(sim_list) / len(sim_list), 3)
79
- print(f"\nTotal {len(sim_list)} samples")
80
  print(f"SIM : {sim}")
81
 
82
 
 
1
  # Evaluate with Librispeech test-clean, ~3s prompt to generate 4-10s audio (the way of valle/voicebox evaluation)
2
 
 
 
3
  import argparse
4
+ import json
5
+ import os
6
+ import sys
7
 
8
  sys.path.append(os.getcwd())
9
 
 
11
  from importlib.resources import files
12
 
13
  import numpy as np
 
14
  from f5_tts.eval.utils_eval import (
15
  get_librispeech_test,
16
  run_asr_wer,
 
54
  wavlm_ckpt_dir = "../checkpoints/UniSpeech/wavlm_large_finetune.pth"
55
 
56
  # --------------------------- WER ---------------------------
57
+
58
  if eval_task == "wer":
59
+ wer_results = []
60
  wers = []
61
+
62
  with mp.Pool(processes=len(gpus)) as pool:
63
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
64
  results = pool.map(run_asr_wer, args)
65
+ for r in results:
66
+ wer_results.extend(r)
67
+
68
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
69
+ with open(wer_result_path, "w") as f:
70
+ for line in wer_results:
71
+ wers.append(line["wer"])
72
+ json_line = json.dumps(line, ensure_ascii=False)
73
+ f.write(json_line + "\n")
74
 
75
  wer = round(np.mean(wers) * 100, 3)
76
  print(f"\nTotal {len(wers)} samples")
77
  print(f"WER : {wer}%")
78
+ print(f"Results have been saved to {wer_result_path}")
79
 
80
  # --------------------------- SIM ---------------------------
81
+
82
  if eval_task == "sim":
83
+ sims = []
84
  with mp.Pool(processes=len(gpus)) as pool:
85
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
86
  results = pool.map(run_sim, args)
87
+ for r in results:
88
+ sims.extend(r)
89
 
90
+ sim = round(sum(sims) / len(sims), 3)
91
+ print(f"\nTotal {len(sims)} samples")
92
  print(f"SIM : {sim}")
93
 
94
 
src/f5_tts/eval/eval_seedtts_testset.py CHANGED
@@ -1,8 +1,9 @@
1
  # Evaluate with Seed-TTS testset
2
 
3
- import sys
4
- import os
5
  import argparse
 
 
 
6
 
7
  sys.path.append(os.getcwd())
8
 
@@ -10,7 +11,6 @@ import multiprocessing as mp
10
  from importlib.resources import files
11
 
12
  import numpy as np
13
-
14
  from f5_tts.eval.utils_eval import (
15
  get_seed_tts_test,
16
  run_asr_wer,
@@ -55,28 +55,39 @@ def main():
55
  # --------------------------- WER ---------------------------
56
 
57
  if eval_task == "wer":
 
58
  wers = []
 
59
  with mp.Pool(processes=len(gpus)) as pool:
60
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
61
  results = pool.map(run_asr_wer, args)
62
- for wers_ in results:
63
- wers.extend(wers_)
64
 
65
  wer = round(np.mean(wers) * 100, 3)
66
  print(f"\nTotal {len(wers)} samples")
67
  print(f"WER : {wer}%")
 
68
 
69
  # --------------------------- SIM ---------------------------
 
70
  if eval_task == "sim":
71
- sim_list = []
72
  with mp.Pool(processes=len(gpus)) as pool:
73
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
74
  results = pool.map(run_sim, args)
75
- for sim_ in results:
76
- sim_list.extend(sim_)
77
 
78
- sim = round(sum(sim_list) / len(sim_list), 3)
79
- print(f"\nTotal {len(sim_list)} samples")
80
  print(f"SIM : {sim}")
81
 
82
 
 
1
  # Evaluate with Seed-TTS testset
2
 
 
 
3
  import argparse
4
+ import json
5
+ import os
6
+ import sys
7
 
8
  sys.path.append(os.getcwd())
9
 
 
11
  from importlib.resources import files
12
 
13
  import numpy as np
 
14
  from f5_tts.eval.utils_eval import (
15
  get_seed_tts_test,
16
  run_asr_wer,
 
55
  # --------------------------- WER ---------------------------
56
 
57
  if eval_task == "wer":
58
+ wer_results = []
59
  wers = []
60
+
61
  with mp.Pool(processes=len(gpus)) as pool:
62
  args = [(rank, lang, sub_test_set, asr_ckpt_dir) for (rank, sub_test_set) in test_set]
63
  results = pool.map(run_asr_wer, args)
64
+ for r in results:
65
+ wer_results.extend(r)
66
+
67
+ wer_result_path = f"{gen_wav_dir}/{lang}_wer_results.jsonl"
68
+ with open(wer_result_path, "w") as f:
69
+ for line in wer_results:
70
+ wers.append(line["wer"])
71
+ json_line = json.dumps(line, ensure_ascii=False)
72
+ f.write(json_line + "\n")
73
 
74
  wer = round(np.mean(wers) * 100, 3)
75
  print(f"\nTotal {len(wers)} samples")
76
  print(f"WER : {wer}%")
77
+ print(f"Results have been saved to {wer_result_path}")
78
 
79
  # --------------------------- SIM ---------------------------
80
+
81
  if eval_task == "sim":
82
+ sims = []
83
  with mp.Pool(processes=len(gpus)) as pool:
84
  args = [(rank, sub_test_set, wavlm_ckpt_dir) for (rank, sub_test_set) in test_set]
85
  results = pool.map(run_sim, args)
86
+ for r in results:
87
+ sims.extend(r)
88
 
89
+ sim = round(sum(sims) / len(sims), 3)
90
+ print(f"\nTotal {len(sims)} samples")
91
  print(f"SIM : {sim}")
92
 
93
 
src/f5_tts/eval/eval_utmos.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+ import json
+ from pathlib import Path
+
+ import librosa
+ import torch
+ from tqdm import tqdm
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="UTMOS Evaluation")
+     parser.add_argument("--audio_dir", type=str, required=True, help="Audio file path.")
+     parser.add_argument("--ext", type=str, default="wav", help="Audio extension.")
+     args = parser.parse_args()
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
+     predictor = predictor.to(device)
+
+     audio_paths = list(Path(args.audio_dir).rglob(f"*.{args.ext}"))
+     utmos_results = {}
+     utmos_score = 0
+
+     for audio_path in tqdm(audio_paths, desc="Processing"):
+         wav_name = audio_path.stem
+         wav, sr = librosa.load(audio_path, sr=None, mono=True)
+         wav_tensor = torch.from_numpy(wav).to(device).unsqueeze(0)
+         score = predictor(wav_tensor, sr)
+         utmos_results[str(wav_name)] = score.item()
+         utmos_score += score.item()
+
+     avg_score = utmos_score / len(audio_paths) if len(audio_paths) > 0 else 0
+     print(f"UTMOS: {avg_score}")
+
+     utmos_result_path = Path(args.audio_dir) / "utmos_results.json"
+     with open(utmos_result_path, "w", encoding="utf-8") as f:
+         json.dump(utmos_results, f, ensure_ascii=False, indent=4)
+
+     print(f"Results have been saved to {utmos_result_path}")
+
+
+ if __name__ == "__main__":
+     main()
src/f5_tts/eval/utils_eval.py CHANGED
@@ -2,6 +2,7 @@ import math
2
  import os
3
  import random
4
  import string
 
5
 
6
  import torch
7
  import torch.nn.functional as F
@@ -320,7 +321,7 @@ def run_asr_wer(args):
320
  from zhon.hanzi import punctuation
321
 
322
  punctuation_all = punctuation + string.punctuation
323
- wers = []
324
 
325
  from jiwer import compute_measures
326
 
@@ -335,8 +336,8 @@ def run_asr_wer(args):
335
  for segment in segments:
336
  hypo = hypo + " " + segment.text
337
 
338
- # raw_truth = truth
339
- # raw_hypo = hypo
340
 
341
  for x in punctuation_all:
342
  truth = truth.replace(x, "")
@@ -360,9 +361,16 @@ def run_asr_wer(args):
360
  # dele = measures["deletions"] / len(ref_list)
361
  # inse = measures["insertions"] / len(ref_list)
362
 
363
- wers.append(wer)
364
 
365
- return wers
366
 
367
 
368
  # SIM Evaluation
@@ -381,7 +389,7 @@ def run_sim(args):
381
  model = model.cuda(device)
382
  model.eval()
383
 
384
- sim_list = []
385
  for wav1, wav2, truth in tqdm(test_set):
386
  wav1, sr1 = torchaudio.load(wav1)
387
  wav2, sr2 = torchaudio.load(wav2)
@@ -400,6 +408,6 @@ def run_sim(args):
400
 
401
  sim = F.cosine_similarity(emb1, emb2)[0].item()
402
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
403
- sim_list.append(sim)
404
 
405
- return sim_list
 
2
  import os
3
  import random
4
  import string
5
+ from pathlib import Path
6
 
7
  import torch
8
  import torch.nn.functional as F
 
321
  from zhon.hanzi import punctuation
322
 
323
  punctuation_all = punctuation + string.punctuation
324
+ wer_results = []
325
 
326
  from jiwer import compute_measures
327
 
 
336
  for segment in segments:
337
  hypo = hypo + " " + segment.text
338
 
339
+ raw_truth = truth
340
+ raw_hypo = hypo
341
 
342
  for x in punctuation_all:
343
  truth = truth.replace(x, "")
 
361
  # dele = measures["deletions"] / len(ref_list)
362
  # inse = measures["insertions"] / len(ref_list)
363
 
364
+ wer_results.append(
365
+ {
366
+ "wav": Path(gen_wav).stem,
367
+ "truth": raw_truth,
368
+ "hypo": raw_hypo,
369
+ "wer": wer,
370
+ }
371
+ )
372
 
373
+ return wer_results
374
 
375
 
376
  # SIM Evaluation
 
389
  model = model.cuda(device)
390
  model.eval()
391
 
392
+ sims = []
393
  for wav1, wav2, truth in tqdm(test_set):
394
  wav1, sr1 = torchaudio.load(wav1)
395
  wav2, sr2 = torchaudio.load(wav2)
 
408
 
409
  sim = F.cosine_similarity(emb1, emb2)[0].item()
410
  # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
411
+ sims.append(sim)
412
 
413
+ return sims
src/f5_tts/infer/README.md CHANGED
@@ -64,6 +64,9 @@ f5-tts_infer-cli \
  # Choose Vocoder
  f5-tts_infer-cli --vocoder_name bigvgan --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base_bigvgan/model_1250000.pt>
  f5-tts_infer-cli --vocoder_name vocos --load_vocoder_from_local --ckpt_file <YOUR_CKPT_PATH, eg:ckpts/F5TTS_Base/model_1200000.safetensors>
+
+ # More instructions
+ f5-tts_infer-cli --help
  ```

  And a `.toml` file would help with more flexible usage.
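As background for that last line: `infer_cli.py` (updated later in this commit) reads the `.toml` with `tomli` and lets explicit command-line flags take precedence over it, with built-in defaults as the final fallback. A minimal, self-contained sketch of that precedence, using an inline config shaped like `infer/examples/basic/basic.toml` (the inline values and the empty argument list are illustrative):

```python
import argparse
import tomli

# An illustrative config in the same shape as infer/examples/basic/basic.toml
config = tomli.loads("""
model = "F5-TTS"
ref_audio = "infer/examples/basic/basic_ref_en.wav"
ref_text = "Some call me nature, others call me mother nature."
gen_text = "Here we generate something just for test."
remove_silence = false
output_dir = "tests"
output_file = "infer_cli_basic.wav"
""")

# infer_cli.py lets command-line flags override the .toml, with a built-in default last:
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str)
parser.add_argument("--nfe_step", type=int)
args = parser.parse_args([])  # no flags passed here, so values fall back to the config

model = args.model or config.get("model", "F5-TTS")
nfe_step = args.nfe_step or config.get("nfe_step", 32)
print(f"model={model}, nfe_step={nfe_step}")
```

With a real invocation such as `f5-tts_infer-cli -c custom.toml --nfe_step 16`, the command-line value would win over whatever the `.toml` sets.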
src/f5_tts/infer/SHARED.md CHANGED
@@ -16,31 +16,34 @@
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
- - [F5-TTS Base @ pretrain @ zh \& en](#f5-tts-base--pretrain--zh--en)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
- - [Finnish Common\_Voice Vox\_Populi @ finetune @ fi](#finnish-common_voice-vox_populi--finetune--fi)
23
  - [French](#french)
24
- - [French LibriVox @ finetune @ fr](#french-librivox--finetune--fr)
 
 
25
  - [Italian](#italian)
26
- - [F5-TTS Italian @ finetune @ it](#f5-tts-italian--finetune--it)
27
  - [Japanese](#japanese)
28
- - [F5-TTS Japanese @ pretrain/finetune @ ja](#f5-tts-japanese--pretrainfinetune--ja)
29
  - [Mandarin](#mandarin)
30
  - [Spanish](#spanish)
31
- - [F5-TTS Spanish @ pretrain/finetune @ es](#f5-tts-spanish--pretrainfinetune--es)
32
 
33
 
34
  ## Multilingual
35
 
36
- #### F5-TTS Base @ pretrain @ zh & en
37
  |Model|🤗Hugging Face|Data (Hours)|Model License|
38
  |:---:|:------------:|:-----------:|:-------------:|
39
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
40
 
41
  ```bash
42
- MODEL_CKPT: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
43
- VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
44
  ```
45
 
46
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
@@ -51,27 +54,29 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
51
 
52
  ## Finnish
53
 
54
- #### Finnish Common_Voice Vox_Populi @ finetune @ fi
55
  |Model|🤗Hugging Face|Data|Model License|
56
  |:---:|:------------:|:-----------:|:-------------:|
57
- |F5-TTS Finnish|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
58
 
59
  ```bash
60
- MODEL_CKPT: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
61
- VOCAB_FILE: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
 
62
  ```
63
 
64
 
65
  ## French
66
 
67
- #### French LibriVox @ finetune @ fr
68
  |Model|🤗Hugging Face|Data (Hours)|Model License|
69
  |:---:|:------------:|:-----------:|:-------------:|
70
- |F5-TTS French|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
71
 
72
  ```bash
73
- MODEL_CKPT: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
74
- VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
 
75
  ```
76
 
77
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
@@ -79,16 +84,34 @@ VOCAB_FILE: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
79
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
80
 
81
 
 
82
  ## Italian
83
 
84
- #### F5-TTS Italian @ finetune @ it
85
  |Model|🤗Hugging Face|Data|Model License|
86
  |:---:|:------------:|:-----------:|:-------------:|
87
- |F5-TTS Italian|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
88
 
89
  ```bash
90
- MODEL_CKPT: hf://alien79/F5-TTS-italian/model_159600.safetensors
91
- VOCAB_FILE: hf://alien79/F5-TTS-italian/vocab.txt
 
92
  ```
93
 
94
  - Trained by [Mithril Man](https://github.com/MithrilMan)
@@ -98,14 +121,15 @@ VOCAB_FILE: hf://alien79/F5-TTS-italian/vocab.txt
98
 
99
  ## Japanese
100
 
101
- #### F5-TTS Japanese @ pretrain/finetune @ ja
102
  |Model|🤗Hugging Face|Data (Hours)|Model License|
103
  |:---:|:------------:|:-----------:|:-------------:|
104
- |F5-TTS Japanese|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
105
 
106
  ```bash
107
- MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
108
- VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
 
109
  ```
110
 
111
 
@@ -114,9 +138,9 @@ VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
114
 
115
  ## Spanish
116
 
117
- #### F5-TTS Spanish @ pretrain/finetune @ es
118
  |Model|🤗Hugging Face|Data (Hours)|Model License|
119
  |:---:|:------------:|:-----------:|:-------------:|
120
- |F5-TTS Spanish|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
121
 
122
  - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
 
16
  <!-- omit in toc -->
17
  ### Supported Languages
18
  - [Multilingual](#multilingual)
19
+ - [F5-TTS Base @ zh \& en @ F5-TTS](#f5-tts-base--zh--en--f5-tts)
20
  - [English](#english)
21
  - [Finnish](#finnish)
22
+ - [F5-TTS Base @ fi @ AsmoKoskinen](#f5-tts-base--fi--asmokoskinen)
23
  - [French](#french)
24
+ - [F5-TTS Base @ fr @ RASPIAUDIO](#f5-tts-base--fr--raspiaudio)
25
+ - [Hindi](#hindi)
26
+ - [F5-TTS Small @ hi @ SPRINGLab](#f5-tts-small--hi--springlab)
27
  - [Italian](#italian)
28
+ - [F5-TTS Base @ it @ alien79](#f5-tts-base--it--alien79)
29
  - [Japanese](#japanese)
30
+ - [F5-TTS Base @ ja @ Jmica](#f5-tts-base--ja--jmica)
31
  - [Mandarin](#mandarin)
32
  - [Spanish](#spanish)
33
+ - [F5-TTS Base @ es @ jpgallegoar](#f5-tts-base--es--jpgallegoar)
34
 
35
 
36
  ## Multilingual
37
 
38
+ #### F5-TTS Base @ zh & en @ F5-TTS
39
  |Model|🤗Hugging Face|Data (Hours)|Model License|
40
  |:---:|:------------:|:-----------:|:-------------:|
41
  |F5-TTS Base|[ckpt & vocab](https://huggingface.co/SWivid/F5-TTS/tree/main/F5TTS_Base)|[Emilia 95K zh&en](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07)|cc-by-nc-4.0|
42
 
43
  ```bash
44
+ Model: hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors
45
+ Vocab: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
46
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
47
  ```
48
 
49
  *Other infos, e.g. Author info, Github repo, Link to some sampled results, Usage instruction, Tutorial (Blog, Video, etc.) ...*
 
54
 
55
  ## Finnish
56
 
57
+ #### F5-TTS Base @ fi @ AsmoKoskinen
58
  |Model|🤗Hugging Face|Data|Model License|
59
  |:---:|:------------:|:-----------:|:-------------:|
60
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/AsmoKoskinen/F5-TTS_Finnish_Model)|[Common Voice](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0), [Vox Populi](https://huggingface.co/datasets/facebook/voxpopuli)|cc-by-nc-4.0|
61
 
62
  ```bash
63
+ Model: hf://AsmoKoskinen/F5-TTS_Finnish_Model/model_common_voice_fi_vox_populi_fi_20241206.safetensors
64
+ Vocab: hf://AsmoKoskinen/F5-TTS_Finnish_Model/vocab.txt
65
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
66
  ```
67
 
68
 
69
  ## French
70
 
71
+ #### F5-TTS Base @ fr @ RASPIAUDIO
72
  |Model|🤗Hugging Face|Data (Hours)|Model License|
73
  |:---:|:------------:|:-----------:|:-------------:|
74
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/RASPIAUDIO/F5-French-MixedSpeakers-reduced)|[LibriVox](https://librivox.org/)|cc-by-nc-4.0|
75
 
76
  ```bash
77
+ Model: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/model_last_reduced.pt
78
+ Vocab: hf://RASPIAUDIO/F5-French-MixedSpeakers-reduced/vocab.txt
79
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
80
  ```
81
 
82
  - [Online Inference with Hugging Face Space](https://huggingface.co/spaces/RASPIAUDIO/f5-tts_french).
 
84
  - [Discussion about this training can be found here](https://github.com/SWivid/F5-TTS/issues/434).
85
 
86
 
87
+ ## Hindi
88
+
89
+ #### F5-TTS Small @ hi @ SPRINGLab
90
+ |Model|🤗Hugging Face|Data (Hours)|Model License|
91
+ |:---:|:------------:|:-----------:|:-------------:|
92
+ |F5-TTS Small|[ckpt & vocab](https://huggingface.co/SPRINGLab/F5-Hindi-24KHz)|[IndicTTS Hi](https://huggingface.co/datasets/SPRINGLab/IndicTTS-Hindi) & [IndicVoices-R Hi](https://huggingface.co/datasets/SPRINGLab/IndicVoices-R_Hindi) |cc-by-4.0|
93
+
94
+ ```bash
95
+ Model: hf://SPRINGLab/F5-Hindi-24KHz/model_2500000.safetensors
96
+ Vocab: hf://SPRINGLab/F5-Hindi-24KHz/vocab.txt
97
+ Config: {"dim": 768, "depth": 18, "heads": 12, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
98
+ ```
99
+
100
+ - Authors: SPRING Lab, Indian Institute of Technology, Madras
101
+ - Website: https://asr.iitm.ac.in/
102
+
103
+
104
  ## Italian
105
 
106
+ #### F5-TTS Base @ it @ alien79
107
  |Model|🤗Hugging Face|Data|Model License|
108
  |:---:|:------------:|:-----------:|:-------------:|
109
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/alien79/F5-TTS-italian)|[ylacombe/cml-tts](https://huggingface.co/datasets/ylacombe/cml-tts) |cc-by-nc-4.0|
110
 
111
  ```bash
112
+ Model: hf://alien79/F5-TTS-italian/model_159600.safetensors
113
+ Vocab: hf://alien79/F5-TTS-italian/vocab.txt
114
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
115
  ```
116
 
117
  - Trained by [Mithril Man](https://github.com/MithrilMan)
 
121
 
122
  ## Japanese
123
 
124
+ #### F5-TTS Base @ ja @ Jmica
125
  |Model|🤗Hugging Face|Data (Hours)|Model License|
126
  |:---:|:------------:|:-----------:|:-------------:|
127
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
128
 
129
  ```bash
130
+ Model: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
131
+ Vocab: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
132
+ Config: {"dim": 1024, "depth": 22, "heads": 16, "ff_mult": 2, "text_dim": 512, "conv_layers": 4}
133
  ```
134
 
135
 
 
138
 
139
  ## Spanish
140
 
141
+ #### F5-TTS Base @ es @ jpgallegoar
142
  |Model|🤗Hugging Face|Data (Hours)|Model License|
143
  |:---:|:------------:|:-----------:|:-------------:|
144
+ |F5-TTS Base|[ckpt & vocab](https://huggingface.co/jpgallegoar/F5-Spanish)|[Voxpopuli](https://huggingface.co/datasets/facebook/voxpopuli) & Crowdsourced & TEDx, 218 hours|cc0-1.0|
145
 
146
  - @jpgallegoar [GitHub repo](https://github.com/jpgallegoar/Spanish-F5), Jupyter Notebook and Gradio usage for Spanish model.
src/f5_tts/infer/examples/basic/basic.toml CHANGED
@@ -8,4 +8,4 @@ gen_text = "I don't really care what you call me. I've been a silent spectator,
  gen_file = ""
  remove_silence = false
  output_dir = "tests"
- output_file = "infer_cli_out.wav"
+ output_file = "infer_cli_basic.wav"
src/f5_tts/infer/examples/multi/story.toml CHANGED
@@ -8,6 +8,7 @@ gen_text = ""
  gen_file = "infer/examples/multi/story.txt"
  remove_silence = true
  output_dir = "tests"
+ output_file = "infer_cli_story.wav"

  [voices.town]
  ref_audio = "infer/examples/multi/town.flac"
src/f5_tts/infer/infer_cli.py CHANGED
@@ -2,6 +2,7 @@ import argparse
2
  import codecs
3
  import os
4
  import re
 
5
  from importlib.resources import files
6
  from pathlib import Path
7
 
@@ -9,8 +10,17 @@ import numpy as np
9
  import soundfile as sf
10
  import tomli
11
  from cached_path import cached_path
 
12
 
13
  from f5_tts.infer.utils_infer import (
 
 
 
 
 
 
 
 
14
  infer_process,
15
  load_model,
16
  load_vocoder,
@@ -19,6 +29,7 @@ from f5_tts.infer.utils_infer import (
19
  )
20
  from f5_tts.model import DiT, UNetT
21
 
 
22
  parser = argparse.ArgumentParser(
23
  prog="python3 infer-cli.py",
24
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
@@ -27,74 +38,168 @@ parser = argparse.ArgumentParser(
27
  parser.add_argument(
28
  "-c",
29
  "--config",
30
- help="Configuration file. Default=infer/examples/basic/basic.toml",
31
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
 
32
  )
 
 
 
 
33
  parser.add_argument(
34
  "-m",
35
  "--model",
36
- help="F5-TTS | E2-TTS",
 
 
 
 
 
 
 
37
  )
38
  parser.add_argument(
39
  "-p",
40
  "--ckpt_file",
41
- help="The Checkpoint .pt",
 
42
  )
43
  parser.add_argument(
44
  "-v",
45
  "--vocab_file",
46
- help="The vocab .txt",
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
- parser.add_argument("-r", "--ref_audio", type=str, help="Reference audio file < 15 seconds.")
49
- parser.add_argument("-s", "--ref_text", type=str, default="666", help="Subtitle for the reference audio.")
50
  parser.add_argument(
51
  "-t",
52
  "--gen_text",
53
  type=str,
54
- help="Text to generate.",
55
  )
56
  parser.add_argument(
57
  "-f",
58
  "--gen_file",
59
  type=str,
60
- help="File with text to generate. Ignores --gen_text",
61
  )
62
  parser.add_argument(
63
  "-o",
64
  "--output_dir",
65
  type=str,
66
- help="Path to output folder..",
67
  )
68
  parser.add_argument(
69
  "-w",
70
  "--output_file",
71
  type=str,
72
- help="Filename of output file..",
 
 
 
 
 
73
  )
74
  parser.add_argument(
75
  "--remove_silence",
76
- help="Remove silence.",
 
77
  )
78
- parser.add_argument("--vocoder_name", type=str, default="vocos", choices=["vocos", "bigvgan"], help="vocoder name")
79
  parser.add_argument(
80
  "--load_vocoder_from_local",
81
  action="store_true",
82
- help="load vocoder from local. Default: ../checkpoints/charactr/vocos-mel-24khz",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
  parser.add_argument(
85
  "--speed",
86
  type=float,
87
- default=1.0,
88
- help="Adjust the speed of the audio generation (default: 1.0)",
 
 
 
 
89
  )
90
  args = parser.parse_args()
91
 
 
 
 
92
  config = tomli.load(open(args.config, "rb"))
93
 
94
- ref_audio = args.ref_audio if args.ref_audio else config["ref_audio"]
95
- ref_text = args.ref_text if args.ref_text != "666" else config["ref_text"]
96
- gen_text = args.gen_text if args.gen_text else config["gen_text"]
97
- gen_file = args.gen_file if args.gen_file else config["gen_file"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  # patches for pip pkg user
100
  if "infer/examples/" in ref_audio:
@@ -107,34 +212,39 @@ if "voices" in config:
107
  if "infer/examples/" in voice_ref_audio:
108
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
109
 
 
 
 
110
  if gen_file:
111
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
112
- output_dir = args.output_dir if args.output_dir else config["output_dir"]
113
- output_file = args.output_file if args.output_file else config["output_file"]
114
- model = args.model if args.model else config["model"]
115
- ckpt_file = args.ckpt_file if args.ckpt_file else ""
116
- vocab_file = args.vocab_file if args.vocab_file else ""
117
- remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
118
- speed = args.speed
119
 
120
  wave_path = Path(output_dir) / output_file
121
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
 
 
 
 
 
 
 
122
 
123
- vocoder_name = args.vocoder_name
124
- mel_spec_type = args.vocoder_name
125
  if vocoder_name == "vocos":
126
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
127
  elif vocoder_name == "bigvgan":
128
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
129
 
130
- vocoder = load_vocoder(vocoder_name=mel_spec_type, is_local=args.load_vocoder_from_local, local_path=vocoder_local_path)
 
131
 
 
132
 
133
- # load models
134
  if model == "F5-TTS":
135
  model_cls = DiT
136
- model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
137
- if ckpt_file == "":
138
  if vocoder_name == "vocos":
139
  repo_name = "F5-TTS"
140
  exp_name = "F5TTS_Base"
@@ -148,22 +258,25 @@ if model == "F5-TTS":
148
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
149
 
150
  elif model == "E2-TTS":
151
- assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos"
 
152
  model_cls = UNetT
153
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
154
- if ckpt_file == "":
155
  repo_name = "E2-TTS"
156
  exp_name = "E2TTS_Base"
157
  ckpt_step = 1200000
158
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
159
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
160
 
161
-
162
  print(f"Using {model}...")
163
- ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=mel_spec_type, vocab_file=vocab_file)
 
164
 
 
165
 
166
- def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove_silence, speed):
 
167
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
168
  if "voices" not in config:
169
  voices = {"main": main_voice}
@@ -171,16 +284,16 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
171
  voices = config["voices"]
172
  voices["main"] = main_voice
173
  for voice in voices:
 
 
174
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
175
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
176
  )
177
- print("Voice:", voice)
178
- print("Ref_audio:", voices[voice]["ref_audio"])
179
- print("Ref_text:", voices[voice]["ref_text"])
180
 
181
  generated_audio_segments = []
182
  reg1 = r"(?=\[\w+\])"
183
- chunks = re.split(reg1, text_gen)
184
  reg2 = r"\[(\w+)\]"
185
  for text in chunks:
186
  if not text.strip():
@@ -195,14 +308,35 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
195
  print(f"Voice {voice} not found, using main.")
196
  voice = "main"
197
  text = re.sub(reg2, "", text)
198
- gen_text = text.strip()
199
- ref_audio = voices[voice]["ref_audio"]
200
- ref_text = voices[voice]["ref_text"]
201
  print(f"Voice: {voice}")
202
- audio, final_sample_rate, spectragram = infer_process(
203
- ref_audio, ref_text, gen_text, model_obj, vocoder, mel_spec_type=mel_spec_type, speed=speed
 
 
 
 
 
 
 
 
 
 
 
 
204
  )
205
- generated_audio_segments.append(audio)
 
 
 
 
 
 
 
 
 
206
 
207
  if generated_audio_segments:
208
  final_wave = np.concatenate(generated_audio_segments)
@@ -218,9 +352,5 @@ def main_process(ref_audio, ref_text, text_gen, model_obj, mel_spec_type, remove
218
  print(f.name)
219
 
220
 
221
- def main():
222
- main_process(ref_audio, ref_text, gen_text, ema_model, mel_spec_type, remove_silence, speed)
223
-
224
-
225
  if __name__ == "__main__":
226
  main()
 
2
  import codecs
3
  import os
4
  import re
5
+ from datetime import datetime
6
  from importlib.resources import files
7
  from pathlib import Path
8
 
 
10
  import soundfile as sf
11
  import tomli
12
  from cached_path import cached_path
13
+ from omegaconf import OmegaConf
14
 
15
  from f5_tts.infer.utils_infer import (
16
+ mel_spec_type,
17
+ target_rms,
18
+ cross_fade_duration,
19
+ nfe_step,
20
+ cfg_strength,
21
+ sway_sampling_coef,
22
+ speed,
23
+ fix_duration,
24
  infer_process,
25
  load_model,
26
  load_vocoder,
 
29
  )
30
  from f5_tts.model import DiT, UNetT
31
 
32
+
33
  parser = argparse.ArgumentParser(
34
  prog="python3 infer-cli.py",
35
  description="Commandline interface for E2/F5 TTS with Advanced Batch Processing.",
 
38
  parser.add_argument(
39
  "-c",
40
  "--config",
41
+ type=str,
42
  default=os.path.join(files("f5_tts").joinpath("infer/examples/basic"), "basic.toml"),
43
+ help="The configuration file, default see infer/examples/basic/basic.toml",
44
  )
45
+
46
+
47
+ # Note: do not provide default values here, so that defaults are read from the config file
48
+
49
  parser.add_argument(
50
  "-m",
51
  "--model",
52
+ type=str,
53
+ help="The model name: F5-TTS | E2-TTS",
54
+ )
55
+ parser.add_argument(
56
+ "-mc",
57
+ "--model_cfg",
58
+ type=str,
59
+ help="The path to F5-TTS model config file .yaml",
60
  )
61
  parser.add_argument(
62
  "-p",
63
  "--ckpt_file",
64
+ type=str,
65
+ help="The path to model checkpoint .pt, leave blank to use default",
66
  )
67
  parser.add_argument(
68
  "-v",
69
  "--vocab_file",
70
+ type=str,
71
+ help="The path to vocab file .txt, leave blank to use default",
72
+ )
73
+ parser.add_argument(
74
+ "-r",
75
+ "--ref_audio",
76
+ type=str,
77
+ help="The reference audio file.",
78
+ )
79
+ parser.add_argument(
80
+ "-s",
81
+ "--ref_text",
82
+ type=str,
83
+ help="The transcript/subtitle for the reference audio",
84
  )
 
 
85
  parser.add_argument(
86
  "-t",
87
  "--gen_text",
88
  type=str,
89
+ help="The text to make model synthesize a speech",
90
  )
91
  parser.add_argument(
92
  "-f",
93
  "--gen_file",
94
  type=str,
95
+ help="The file with text to generate, will ignore --gen_text",
96
  )
97
  parser.add_argument(
98
  "-o",
99
  "--output_dir",
100
  type=str,
101
+ help="The path to output folder",
102
  )
103
  parser.add_argument(
104
  "-w",
105
  "--output_file",
106
  type=str,
107
+ help="The name of output file",
108
+ )
109
+ parser.add_argument(
110
+ "--save_chunk",
111
+ action="store_true",
112
+ help="To save each audio chunks during inference",
113
  )
114
  parser.add_argument(
115
  "--remove_silence",
116
+ action="store_true",
117
+ help="To remove long silence found in ouput",
118
  )
 
119
  parser.add_argument(
120
  "--load_vocoder_from_local",
121
  action="store_true",
122
+ help="To load vocoder from local dir, default to ../checkpoints/vocos-mel-24khz",
123
+ )
124
+ parser.add_argument(
125
+ "--vocoder_name",
126
+ type=str,
127
+ choices=["vocos", "bigvgan"],
128
+ help=f"Used vocoder name: vocos | bigvgan, default {mel_spec_type}",
129
+ )
130
+ parser.add_argument(
131
+ "--target_rms",
132
+ type=float,
133
+ help=f"Target output speech loudness normalization value, default {target_rms}",
134
+ )
135
+ parser.add_argument(
136
+ "--cross_fade_duration",
137
+ type=float,
138
+ help=f"Duration of cross-fade between audio segments in seconds, default {cross_fade_duration}",
139
+ )
140
+ parser.add_argument(
141
+ "--nfe_step",
142
+ type=int,
143
+ help=f"The number of function evaluation (denoising steps), default {nfe_step}",
144
+ )
145
+ parser.add_argument(
146
+ "--cfg_strength",
147
+ type=float,
148
+ help=f"Classifier-free guidance strength, default {cfg_strength}",
149
+ )
150
+ parser.add_argument(
151
+ "--sway_sampling_coef",
152
+ type=float,
153
+ help=f"Sway Sampling coefficient, default {sway_sampling_coef}",
154
  )
155
  parser.add_argument(
156
  "--speed",
157
  type=float,
158
+ help=f"The speed of the generated audio, default {speed}",
159
+ )
160
+ parser.add_argument(
161
+ "--fix_duration",
162
+ type=float,
163
+ help=f"Fix the total duration (ref and gen audios) in seconds, default {fix_duration}",
164
  )
165
  args = parser.parse_args()
166
 
167
+
168
+ # config file
169
+
170
  config = tomli.load(open(args.config, "rb"))
171
 
172
+
173
+ # command-line interface parameters
174
+
175
+ model = args.model or config.get("model", "F5-TTS")
176
+ model_cfg = args.model_cfg or config.get("model_cfg", str(files("f5_tts").joinpath("configs/F5TTS_Base_train.yaml")))
177
+ ckpt_file = args.ckpt_file or config.get("ckpt_file", "")
178
+ vocab_file = args.vocab_file or config.get("vocab_file", "")
179
+
180
+ ref_audio = args.ref_audio or config.get("ref_audio", "infer/examples/basic/basic_ref_en.wav")
181
+ ref_text = args.ref_text or config.get("ref_text", "Some call me nature, others call me mother nature.")
182
+ gen_text = args.gen_text or config.get("gen_text", "Here we generate something just for test.")
183
+ gen_file = args.gen_file or config.get("gen_file", "")
184
+
185
+ output_dir = args.output_dir or config.get("output_dir", "tests")
186
+ output_file = args.output_file or config.get(
187
+ "output_file", f"infer_cli_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}.wav"
188
+ )
189
+
190
+ save_chunk = args.save_chunk or config.get("save_chunk", False)
191
+ remove_silence = args.remove_silence or config.get("remove_silence", False)
192
+ load_vocoder_from_local = args.load_vocoder_from_local or config.get("load_vocoder_from_local", False)
193
+
194
+ vocoder_name = args.vocoder_name or config.get("vocoder_name", mel_spec_type)
195
+ target_rms = args.target_rms or config.get("target_rms", target_rms)
196
+ cross_fade_duration = args.cross_fade_duration or config.get("cross_fade_duration", cross_fade_duration)
197
+ nfe_step = args.nfe_step or config.get("nfe_step", nfe_step)
198
+ cfg_strength = args.cfg_strength or config.get("cfg_strength", cfg_strength)
199
+ sway_sampling_coef = args.sway_sampling_coef or config.get("sway_sampling_coef", sway_sampling_coef)
200
+ speed = args.speed or config.get("speed", speed)
201
+ fix_duration = args.fix_duration or config.get("fix_duration", fix_duration)
202
+
203
 
204
  # patches for pip pkg user
205
  if "infer/examples/" in ref_audio:
 
212
  if "infer/examples/" in voice_ref_audio:
213
  config["voices"][voice]["ref_audio"] = str(files("f5_tts").joinpath(f"{voice_ref_audio}"))
214
 
215
+
216
+ # ignore gen_text if gen_file provided
217
+
218
  if gen_file:
219
  gen_text = codecs.open(gen_file, "r", "utf-8").read()
220
+
221
+
222
+ # output path
 
 
 
 
223
 
224
  wave_path = Path(output_dir) / output_file
225
  # spectrogram_path = Path(output_dir) / "infer_cli_out.png"
226
+ if save_chunk:
227
+ output_chunk_dir = os.path.join(output_dir, f"{Path(output_file).stem}_chunks")
228
+ if not os.path.exists(output_chunk_dir):
229
+ os.makedirs(output_chunk_dir)
230
+
231
+
232
+ # load vocoder
233
 
 
 
234
  if vocoder_name == "vocos":
235
  vocoder_local_path = "../checkpoints/vocos-mel-24khz"
236
  elif vocoder_name == "bigvgan":
237
  vocoder_local_path = "../checkpoints/bigvgan_v2_24khz_100band_256x"
238
 
239
+ vocoder = load_vocoder(vocoder_name=vocoder_name, is_local=load_vocoder_from_local, local_path=vocoder_local_path)
240
+
241
 
242
+ # load TTS model
243
 
 
244
  if model == "F5-TTS":
245
  model_cls = DiT
246
+ model_cfg = OmegaConf.load(model_cfg).model.arch
247
+ if not ckpt_file: # path not specified, download from repo
248
  if vocoder_name == "vocos":
249
  repo_name = "F5-TTS"
250
  exp_name = "F5TTS_Base"
 
258
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.pt"))
259
 
260
  elif model == "E2-TTS":
261
+ assert args.model_cfg is None, "E2-TTS does not support custom model_cfg yet"
262
+ assert vocoder_name == "vocos", "E2-TTS only supports vocoder vocos yet"
263
  model_cls = UNetT
264
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
265
+ if not ckpt_file: # path not specified, download from repo
266
  repo_name = "E2-TTS"
267
  exp_name = "E2TTS_Base"
268
  ckpt_step = 1200000
269
  ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
270
  # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors; local path
271
 
 
272
  print(f"Using {model}...")
273
+ ema_model = load_model(model_cls, model_cfg, ckpt_file, mel_spec_type=vocoder_name, vocab_file=vocab_file)
274
+
275
 
276
+ # inference process
277
 
278
+
279
+ def main():
280
  main_voice = {"ref_audio": ref_audio, "ref_text": ref_text}
281
  if "voices" not in config:
282
  voices = {"main": main_voice}
 
284
  voices = config["voices"]
285
  voices["main"] = main_voice
286
  for voice in voices:
287
+ print("Voice:", voice)
288
+ print("ref_audio ", voices[voice]["ref_audio"])
289
  voices[voice]["ref_audio"], voices[voice]["ref_text"] = preprocess_ref_audio_text(
290
  voices[voice]["ref_audio"], voices[voice]["ref_text"]
291
  )
292
+ print("ref_audio_", voices[voice]["ref_audio"], "\n\n")
 
 
293
 
294
  generated_audio_segments = []
295
  reg1 = r"(?=\[\w+\])"
296
+ chunks = re.split(reg1, gen_text)
297
  reg2 = r"\[(\w+)\]"
298
  for text in chunks:
299
  if not text.strip():
 
308
  print(f"Voice {voice} not found, using main.")
309
  voice = "main"
310
  text = re.sub(reg2, "", text)
311
+ ref_audio_ = voices[voice]["ref_audio"]
312
+ ref_text_ = voices[voice]["ref_text"]
313
+ gen_text_ = text.strip()
314
  print(f"Voice: {voice}")
315
+ audio_segment, final_sample_rate, spectragram = infer_process(
316
+ ref_audio_,
317
+ ref_text_,
318
+ gen_text_,
319
+ ema_model,
320
+ vocoder,
321
+ mel_spec_type=vocoder_name,
322
+ target_rms=target_rms,
323
+ cross_fade_duration=cross_fade_duration,
324
+ nfe_step=nfe_step,
325
+ cfg_strength=cfg_strength,
326
+ sway_sampling_coef=sway_sampling_coef,
327
+ speed=speed,
328
+ fix_duration=fix_duration,
329
  )
330
+ generated_audio_segments.append(audio_segment)
331
+
332
+ if save_chunk:
333
+ if len(gen_text_) > 200:
334
+ gen_text_ = gen_text_[:200] + " ... "
335
+ sf.write(
336
+ os.path.join(output_chunk_dir, f"{len(generated_audio_segments)-1}_{gen_text_}.wav"),
337
+ audio_segment,
338
+ final_sample_rate,
339
+ )
340
 
341
  if generated_audio_segments:
342
  final_wave = np.concatenate(generated_audio_segments)
 
352
  print(f.name)
353
 
354
 
 
 
 
 
355
  if __name__ == "__main__":
356
  main()
src/f5_tts/model/backbones/dit.py CHANGED
@@ -105,6 +105,7 @@ class DiT(nn.Module):
  text_dim=None,
  conv_layers=0,
  long_skip_connection=False,
+ checkpoint_activations=False,
  ):
  super().__init__()
@@ -127,6 +128,16 @@ class DiT(nn.Module):
  self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
  self.proj_out = nn.Linear(dim, mel_dim)

+ self.checkpoint_activations = checkpoint_activations
+
+ def ckpt_wrapper(self, module):
+ # https://github.com/chuanyangjin/fast-DiT/blob/main/models.py
+ def ckpt_forward(*inputs):
+ outputs = module(*inputs)
+ return outputs
+
+ return ckpt_forward
+
  def forward(
  self,
  x: float["b n d"], # noised input audio # noqa: F722
@@ -152,7 +163,10 @@ class DiT(nn.Module):
  residual = x

  for block in self.transformer_blocks:
- x = block(x, t, mask=mask, rope=rope)
+ if self.checkpoint_activations:
+ x = torch.utils.checkpoint.checkpoint(self.ckpt_wrapper(block), x, t, mask, rope)
+ else:
+ x = block(x, t, mask=mask, rope=rope)

  if self.long_skip_connection is not None:
  x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
src/f5_tts/model/trainer.py CHANGED
@@ -315,7 +315,7 @@ class Trainer:
  self.scheduler.step()
  self.optimizer.zero_grad()

- if self.is_main:
+ if self.is_main and self.accelerator.sync_gradients:
  self.ema_model.update()

  global_step += 1
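The trainer change guards the EMA update with `accelerator.sync_gradients`, so with gradient accumulation the EMA weights advance once per effective optimizer step instead of on every micro-batch. A minimal sketch of that pattern with Hugging Face Accelerate (the toy model, data, and `SimpleEMA` helper are illustrative stand-ins for the real trainer and EMA wrapper):

```python
import torch
import torch.nn as nn
from accelerate import Accelerator


class SimpleEMA:
    """Toy exponential moving average of model weights (stand-in for the real EMA helper)."""

    def __init__(self, model, beta=0.999):
        self.model = model
        self.beta = beta
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    def update(self):
        for k, v in self.model.state_dict().items():
            self.shadow[k].mul_(self.beta).add_(v.detach(), alpha=1 - self.beta)


accelerator = Accelerator(gradient_accumulation_steps=4)
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)
ema = SimpleEMA(accelerator.unwrap_model(model))

for step in range(8):
    batch = torch.randn(2, 8, device=accelerator.device)
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        # sync_gradients is True only on the micro-batch that completes an accumulation
        # cycle, so the EMA tracks one update per effective optimizer step.
        if accelerator.is_main_process and accelerator.sync_gradients:
            ema.update()
```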