mrfakename committed on
Commit
490128a
1 Parent(s): 910037c

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo rather than to this Space.

src/f5_tts/configs/E2TTS_Base_train.yaml CHANGED
@@ -41,4 +41,4 @@ ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per steps
43
  last_per_steps: 5000 # save last checkpoint per steps
44
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
 
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per steps
43
  last_per_steps: 5000 # save last checkpoint per steps
44
+ save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/E2TTS_Small_train.yaml CHANGED
@@ -41,4 +41,4 @@ ckpts:
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per steps
43
  last_per_steps: 5000 # save last checkpoint per steps
44
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
 
41
  logger: wandb # wandb | tensorboard | None
42
  save_per_updates: 50000 # save checkpoint per steps
43
  last_per_steps: 5000 # save last checkpoint per steps
44
+ save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Base_train.yaml CHANGED
@@ -43,4 +43,4 @@ ckpts:
43
  logger: wandb # wandb | tensorboard | None
44
  save_per_updates: 50000 # save checkpoint per steps
45
  last_per_steps: 5000 # save last checkpoint per steps
46
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
 
43
  logger: wandb # wandb | tensorboard | None
44
  save_per_updates: 50000 # save checkpoint per steps
45
  last_per_steps: 5000 # save last checkpoint per steps
46
+ save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/configs/F5TTS_Small_train.yaml CHANGED
@@ -43,4 +43,4 @@ ckpts:
43
  logger: wandb # wandb | tensorboard | None
44
  save_per_updates: 50000 # save checkpoint per steps
45
  last_per_steps: 5000 # save last checkpoint per steps
46
- save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
 
43
  logger: wandb # wandb | tensorboard | None
44
  save_per_updates: 50000 # save checkpoint per steps
45
  last_per_steps: 5000 # save last checkpoint per steps
46
+ save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
src/f5_tts/infer/utils_infer.py CHANGED
@@ -138,7 +138,11 @@ asr_pipe = None
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
 
 
142
  )
143
  global asr_pipe
144
  asr_pipe = pipeline(
@@ -171,7 +175,11 @@ def transcribe(ref_audio, language=None):
171
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
172
  if dtype is None:
173
  dtype = (
174
- torch.float16 if "cuda" in device and torch.cuda.get_device_properties(device).major >= 6 else torch.float32
 
 
 
 
175
  )
176
  model = model.to(dtype)
177
 
 
138
  def initialize_asr_pipeline(device: str = device, dtype=None):
139
  if dtype is None:
140
  dtype = (
141
+ torch.float16
142
+ if "cuda" in device
143
+ and torch.cuda.get_device_properties(device).major >= 6
144
+ and not torch.cuda.get_device_name().endswith("[ZLUDA]")
145
+ else torch.float32
146
  )
147
  global asr_pipe
148
  asr_pipe = pipeline(
 
175
  def load_checkpoint(model, ckpt_path, device: str, dtype=None, use_ema=True):
176
  if dtype is None:
177
  dtype = (
178
+ torch.float16
179
+ if "cuda" in device
180
+ and torch.cuda.get_device_properties(device).major >= 6
181
+ and not torch.cuda.get_device_name().endswith("[ZLUDA]")
182
+ else torch.float32
183
  )
184
  model = model.to(dtype)
185