mrfakename commited on
Commit
cf68f41
1 Parent(s): 57b3db8

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. src/f5_tts/model/trainer.py +9 -1
src/f5_tts/model/trainer.py CHANGED
@@ -47,6 +47,8 @@ class Trainer:
47
  ema_kwargs: dict = dict(),
48
  bnb_optimizer: bool = False,
49
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
 
 
50
  ):
51
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
52
 
@@ -108,7 +110,11 @@ class Trainer:
108
  self.max_samples = max_samples
109
  self.grad_accumulation_steps = grad_accumulation_steps
110
  self.max_grad_norm = max_grad_norm
 
 
111
  self.vocoder_name = mel_spec_type
 
 
112
 
113
  self.noise_scheduler = noise_scheduler
114
 
@@ -199,7 +205,9 @@ class Trainer:
199
  if self.log_samples:
200
  from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
201
 
202
- vocoder = load_vocoder(vocoder_name=self.vocoder_name)
 
 
203
  target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
204
  log_samples_path = f"{self.checkpoint_path}/samples"
205
  os.makedirs(log_samples_path, exist_ok=True)
 
47
  ema_kwargs: dict = dict(),
48
  bnb_optimizer: bool = False,
49
  mel_spec_type: str = "vocos", # "vocos" | "bigvgan"
50
+ is_local_vocoder: bool = False, # use local path vocoder
51
+ local_vocoder_path: str = "", # local vocoder path
52
  ):
53
  ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
54
 
 
110
  self.max_samples = max_samples
111
  self.grad_accumulation_steps = grad_accumulation_steps
112
  self.max_grad_norm = max_grad_norm
113
+
114
+ # mel vocoder config
115
  self.vocoder_name = mel_spec_type
116
+ self.is_local_vocoder = is_local_vocoder
117
+ self.local_vocoder_path = local_vocoder_path
118
 
119
  self.noise_scheduler = noise_scheduler
120
 
 
205
  if self.log_samples:
206
  from f5_tts.infer.utils_infer import cfg_strength, load_vocoder, nfe_step, sway_sampling_coef
207
 
208
+ vocoder = load_vocoder(
209
+ vocoder_name=self.vocoder_name, is_local=self.is_local_vocoder, local_path=self.local_vocoder_path
210
+ )
211
  target_sample_rate = self.accelerator.unwrap_model(self.model).mel_spec.target_sample_rate
212
  log_samples_path = f"{self.checkpoint_path}/samples"
213
  os.makedirs(log_samples_path, exist_ok=True)