mrfakename commited on
Commit
01eed3a
·
verified ·
1 Parent(s): 5819920

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there

Files changed (1) hide show
  1. finetune_gradio.py +150 -7
finetune_gradio.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
2
  import sys
3
 
 
 
4
  from transformers import pipeline
5
  import gradio as gr
6
  import torch
 
7
  import click
8
  import torchaudio
9
  from glob import glob
@@ -20,11 +23,16 @@ import psutil
20
  import platform
21
  import subprocess
22
  from datasets.arrow_writer import ArrowWriter
 
 
23
 
24
 
25
  training_process = None
26
  system = platform.system()
27
  python_executable = sys.executable or "python"
 
 
 
28
 
29
  path_data = "data"
30
 
@@ -240,7 +248,12 @@ def start_training(
240
  last_per_steps=800,
241
  finetune=True,
242
  ):
243
- global training_process
 
 
 
 
 
244
 
245
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
246
 
@@ -288,7 +301,7 @@ def start_training(
288
  training_process = subprocess.Popen(cmd, shell=True)
289
 
290
  time.sleep(5)
291
- yield "check terminal for wandb", gr.update(interactive=False), gr.update(interactive=True)
292
 
293
  # Wait for the training process to finish
294
  training_process.wait()
@@ -519,6 +532,17 @@ def calculate_train(
519
  path_project = os.path.join(path_data, name_project)
520
  file_duraction = os.path.join(path_project, "duration.json")
521
 
 
 
 
 
 
 
 
 
 
 
 
522
  with open(file_duraction, "r") as file:
523
  data = json.load(file)
524
 
@@ -549,8 +573,8 @@ def calculate_train(
549
  else:
550
  max_samples = 64
551
 
552
- num_warmup_updates = int(samples * 0.10)
553
- save_per_updates = int(samples * 0.25)
554
  last_per_steps = int(save_per_updates * 5)
555
 
556
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
@@ -559,7 +583,7 @@ def calculate_train(
559
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
560
 
561
  if finetune:
562
- learning_rate = 1e-4
563
  else:
564
  learning_rate = 7.5e-5
565
 
@@ -611,6 +635,7 @@ def vocab_check(project_name):
611
  sp = item.split("|")
612
  if len(sp) != 2:
613
  continue
 
614
  text = sp[1].lower().strip()
615
 
616
  for t in text:
@@ -625,6 +650,80 @@ def vocab_check(project_name):
625
  return info
626
 
627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
628
  with gr.Blocks() as app:
629
  with gr.Row():
630
  project_name = gr.Textbox(label="project name", value="my_speak")
@@ -661,6 +760,18 @@ with gr.Blocks() as app:
661
  )
662
  ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
663
 
 
 
 
 
 
 
 
 
 
 
 
 
664
  with gr.TabItem("prepare Data"):
665
  gr.Markdown(
666
  """```plaintext
@@ -687,6 +798,16 @@ with gr.Blocks() as app:
687
  txt_info_prepare = gr.Text(label="info", value="")
688
  bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
689
 
 
 
 
 
 
 
 
 
 
 
690
  with gr.TabItem("train Data"):
691
  with gr.Row():
692
  bt_calculate = bt_create = gr.Button("Auto Settings")
@@ -696,11 +817,11 @@ with gr.Blocks() as app:
696
 
697
  with gr.Row():
698
  exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
699
- learning_rate = gr.Number(label="Learning Rate", value=1e-4, step=1e-4)
700
 
701
  with gr.Row():
702
  batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
703
- max_samples = gr.Number(label="Max Samples", value=16)
704
 
705
  with gr.Row():
706
  grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
@@ -778,6 +899,28 @@ with gr.Blocks() as app:
778
  txt_info_check = gr.Text(label="info", value="")
779
  check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
781
 
782
  @click.command()
783
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
 
1
  import os
2
  import sys
3
 
4
+ import tempfile
5
+ import random
6
  from transformers import pipeline
7
  import gradio as gr
8
  import torch
9
+ import gc
10
  import click
11
  import torchaudio
12
  from glob import glob
 
23
  import platform
24
  import subprocess
25
  from datasets.arrow_writer import ArrowWriter
26
+ from datasets import Dataset as Dataset_
27
+ from api import F5TTS
28
 
29
 
30
  training_process = None
31
  system = platform.system()
32
  python_executable = sys.executable or "python"
33
+ tts_api = None
34
+ last_checkpoint = ""
35
+ last_device = ""
36
 
37
  path_data = "data"
38
 
 
248
  last_per_steps=800,
249
  finetune=True,
250
  ):
251
+ global training_process, tts_api
252
+
253
+ if tts_api is not None:
254
+ del tts_api
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
257
 
258
  path_project = os.path.join(path_data, dataset_name + "_pinyin")
259
 
 
301
  training_process = subprocess.Popen(cmd, shell=True)
302
 
303
  time.sleep(5)
304
+ yield "train start", gr.update(interactive=False), gr.update(interactive=True)
305
 
306
  # Wait for the training process to finish
307
  training_process.wait()
 
532
  path_project = os.path.join(path_data, name_project)
533
  file_duraction = os.path.join(path_project, "duration.json")
534
 
535
+ if not os.path.isfile(file_duraction):
536
+ return (
537
+ 1000,
538
+ max_samples,
539
+ num_warmup_updates,
540
+ save_per_updates,
541
+ last_per_steps,
542
+ "project not found !",
543
+ learning_rate,
544
+ )
545
+
546
  with open(file_duraction, "r") as file:
547
  data = json.load(file)
548
 
 
573
  else:
574
  max_samples = 64
575
 
576
+ num_warmup_updates = int(samples * 0.05)
577
+ save_per_updates = int(samples * 0.10)
578
  last_per_steps = int(save_per_updates * 5)
579
 
580
  max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
 
583
  last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
584
 
585
  if finetune:
586
+ learning_rate = 1e-5
587
  else:
588
  learning_rate = 7.5e-5
589
 
 
635
  sp = item.split("|")
636
  if len(sp) != 2:
637
  continue
638
+
639
  text = sp[1].lower().strip()
640
 
641
  for t in text:
 
650
  return info
651
 
652
 
653
+ def get_random_sample_prepare(project_name):
654
+ name_project = project_name + "_pinyin"
655
+ path_project = os.path.join(path_data, name_project)
656
+ file_arrow = os.path.join(path_project, "raw.arrow")
657
+ if not os.path.isfile(file_arrow):
658
+ return "", None
659
+ dataset = Dataset_.from_file(file_arrow)
660
+ random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
661
+ text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
662
+ audio_path = random_sample["audio_path"][0]
663
+ return text, audio_path
664
+
665
+
666
+ def get_random_sample_transcribe(project_name):
667
+ name_project = project_name + "_pinyin"
668
+ path_project = os.path.join(path_data, name_project)
669
+ file_metadata = os.path.join(path_project, "metadata.csv")
670
+ if not os.path.isfile(file_metadata):
671
+ return "", None
672
+
673
+ data = ""
674
+ with open(file_metadata, "r", encoding="utf-8") as f:
675
+ data = f.read()
676
+
677
+ list_data = []
678
+ for item in data.split("\n"):
679
+ sp = item.split("|")
680
+ if len(sp) != 2:
681
+ continue
682
+ list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
683
+
684
+ if list_data == []:
685
+ return "", None
686
+
687
+ random_item = random.choice(list_data)
688
+
689
+ return random_item[1], random_item[0]
690
+
691
+
692
+ def get_random_sample_infer(project_name):
693
+ text, audio = get_random_sample_transcribe(project_name)
694
+ return (
695
+ text,
696
+ text,
697
+ audio,
698
+ )
699
+
700
+
701
+ def infer(project_name, file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
702
+ global last_checkpoint, last_device, tts_api
703
+
704
+ if not os.path.isfile(file_checkpoint):
705
+ return None
706
+
707
+ if training_process is not None:
708
+ device_test = "cpu"
709
+ else:
710
+ device_test = None
711
+
712
+ if last_checkpoint != file_checkpoint or last_device != device_test:
713
+ if last_checkpoint != file_checkpoint:
714
+ last_checkpoint = file_checkpoint
715
+ if last_device != device_test:
716
+ last_device = device_test
717
+
718
+ tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
719
+
720
+ print("update", device_test, file_checkpoint)
721
+
722
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
723
+ tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
724
+ return f.name
725
+
726
+
727
  with gr.Blocks() as app:
728
  with gr.Row():
729
  project_name = gr.Textbox(label="project name", value="my_speak")
 
760
  )
761
  ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
762
 
763
+ random_sample_transcribe = gr.Button("random sample")
764
+
765
+ with gr.Row():
766
+ random_text_transcribe = gr.Text(label="Text")
767
+ random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
768
+
769
+ random_sample_transcribe.click(
770
+ fn=get_random_sample_transcribe,
771
+ inputs=[project_name],
772
+ outputs=[random_text_transcribe, random_audio_transcribe],
773
+ )
774
+
775
  with gr.TabItem("prepare Data"):
776
  gr.Markdown(
777
  """```plaintext
 
798
  txt_info_prepare = gr.Text(label="info", value="")
799
  bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
800
 
801
+ random_sample_prepare = gr.Button("random sample")
802
+
803
+ with gr.Row():
804
+ random_text_prepare = gr.Text(label="Pinyin")
805
+ random_audio_prepare = gr.Audio(label="Audio", type="filepath")
806
+
807
+ random_sample_prepare.click(
808
+ fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
809
+ )
810
+
811
  with gr.TabItem("train Data"):
812
  with gr.Row():
813
  bt_calculate = bt_create = gr.Button("Auto Settings")
 
817
 
818
  with gr.Row():
819
  exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
820
+ learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
821
 
822
  with gr.Row():
823
  batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
824
+ max_samples = gr.Number(label="Max Samples", value=64)
825
 
826
  with gr.Row():
827
  grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
 
899
  txt_info_check = gr.Text(label="info", value="")
900
  check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
901
 
902
+ with gr.TabItem("test model"):
903
+ exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
904
+ nfe_step = gr.Number(label="n_step", value=32)
905
+ file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
906
+
907
+ random_sample_infer = gr.Button("random sample")
908
+
909
+ ref_text = gr.Textbox(label="ref text")
910
+ ref_audio = gr.Audio(label="audio ref", type="filepath")
911
+ gen_text = gr.Textbox(label="gen text")
912
+ random_sample_infer.click(
913
+ fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
914
+ )
915
+ check_button_infer = gr.Button("infer")
916
+ gen_audio = gr.Audio(label="audio gen", type="filepath")
917
+
918
+ check_button_infer.click(
919
+ fn=infer,
920
+ inputs=[project_name, file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
921
+ outputs=[gen_audio],
922
+ )
923
+
924
 
925
  @click.command()
926
  @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")