mrfakename committed on
Commit
5a9adbc
1 Parent(s): fb41309

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit any contributions to this Space through that repository.

src/f5_tts/train/datasets/prepare_csv_wavs.py CHANGED
@@ -54,8 +54,7 @@ def prepare_csv_wavs_dir(input_dir):
54
 
55
  def get_audio_duration(audio_path):
56
  audio, sample_rate = torchaudio.load(audio_path)
57
- num_channels = audio.shape[0]
58
- return audio.shape[1] / (sample_rate * num_channels)
59
 
60
 
61
  def read_audio_text_pairs(csv_file_path):
 
54
 
55
  def get_audio_duration(audio_path):
56
  audio, sample_rate = torchaudio.load(audio_path)
57
+ return audio.shape[1] / sample_rate
 
58
 
59
 
60
  def read_audio_text_pairs(csv_file_path):
src/f5_tts/train/finetune_gradio.py CHANGED
@@ -172,10 +172,9 @@ def load_settings(project_name):
172
 
173
  # Load metadata
174
  def get_audio_duration(audio_path):
175
- """Calculate the duration of an audio file."""
176
  audio, sample_rate = torchaudio.load(audio_path)
177
- num_channels = audio.shape[0]
178
- return audio.shape[1] / (sample_rate * num_channels)
179
 
180
 
181
  def clear_text(text):
@@ -383,13 +382,17 @@ def start_training(
383
  stream=False,
384
  logger="wandb",
385
  ):
386
- global training_process, tts_api, stop_signal
387
 
388
- if tts_api is not None:
389
- del tts_api
 
 
 
390
  gc.collect()
391
  torch.cuda.empty_cache()
392
  tts_api = None
 
393
 
394
  path_project = os.path.join(path_data, dataset_name)
395
 
@@ -1557,7 +1560,7 @@ If you encounter a memory error, try reducing the batch size per GPU to a smalle
1557
  last_per_steps = gr.Number(label="Last per Steps", value=100)
1558
 
1559
  with gr.Row():
1560
- mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "fpb16"], value="none")
1561
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1562
  start_button = gr.Button("Start Training")
1563
  stop_button = gr.Button("Stop Training", interactive=False)
 
172
 
173
  # Load metadata
174
  def get_audio_duration(audio_path):
175
+ """Calculate the duration mono of an audio file."""
176
  audio, sample_rate = torchaudio.load(audio_path)
177
+ return audio.shape[1] / sample_rate
 
178
 
179
 
180
  def clear_text(text):
 
382
  stream=False,
383
  logger="wandb",
384
  ):
385
+ global training_process, tts_api, stop_signal, pipe
386
 
387
+ if tts_api is not None or pipe is not None:
388
+ if tts_api is not None:
389
+ del tts_api
390
+ if pipe is not None:
391
+ del pipe
392
  gc.collect()
393
  torch.cuda.empty_cache()
394
  tts_api = None
395
+ pipe = None
396
 
397
  path_project = os.path.join(path_data, dataset_name)
398
 
 
1560
  last_per_steps = gr.Number(label="Last per Steps", value=100)
1561
 
1562
  with gr.Row():
1563
+ mixed_precision = gr.Radio(label="mixed_precision", choices=["none", "fp16", "bf16"], value="none")
1564
  cd_logger = gr.Radio(label="logger", choices=["wandb", "tensorboard"], value="wandb")
1565
  start_button = gr.Button("Start Training")
1566
  stop_button = gr.Button("Stop Training", interactive=False)