Commit cf0b618 by mrfakename (parent: feb8eed)

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there rather than to the Space.

src/f5_tts/api.py CHANGED
@@ -46,24 +46,30 @@ class F5TTS:
         )
 
         # Load models
-        self.load_vocoder_model(vocoder_name, local_path)
-        self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
+        self.load_vocoder_model(vocoder_name, local_path=local_path)
+        self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
 
     def load_vocoder_model(self, vocoder_name, local_path):
         self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
 
-    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
+    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
         if model_type == "F5-TTS":
             if not ckpt_file:
                 if mel_spec_type == "vocos":
-                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
+                    )
                 elif mel_spec_type == "bigvgan":
-                    ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt"))
+                    ckpt_file = str(
+                        cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
+                    )
             model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
             model_cls = DiT
         elif model_type == "E2-TTS":
             if not ckpt_file:
-                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
+                ckpt_file = str(
+                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
+                )
             model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
             model_cls = UNetT
         else:
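
The net effect of this hunk is that `local_path` now doubles as a download cache: it is threaded through to `cached_path(..., cache_dir=local_path)`, so the default checkpoints land in a caller-chosen directory. Below is a minimal usage sketch; the `F5TTS` constructor keyword names are assumed from the arguments forwarded above and should be checked against the full `src/f5_tts/api.py`.

```python
# Sketch only: keyword names are inferred from the values forwarded to
# load_vocoder_model / load_ema_model in the hunk above.
from f5_tts.api import F5TTS

tts = F5TTS(
    model_type="F5-TTS",   # or "E2-TTS"
    ckpt_file="",          # empty -> default checkpoint fetched via cached_path(...)
    vocoder_name="vocos",  # or "bigvgan"
    local_path="./ckpts",  # after this change, also used as cache_dir for the downloads
)
```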
src/f5_tts/infer/SHARED.md CHANGED
@@ -40,6 +40,17 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
 
 ## Mandarin
 
+## Japanese
+
+#### F5-TTS Base @ pretrain/finetune @ ja
+|Model|🤗Hugging Face|Data (Hours)|Model License|
+|:---:|:------------:|:-----------:|:-------------:|
+|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
+
+```bash
+MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
+VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
+```
 
 ## English
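
The `MODEL_CKPT` / `VOCAB_FILE` values in these SHARED.md entries are `hf://` URIs, not local files. A short sketch of one way to resolve them before loading, using the same `cached_path` helper that `api.py` calls above; the `F5TTS` keyword names are again assumptions rather than the confirmed signature.

```python
# Sketch: resolve the shared Japanese checkpoint and vocab to local files,
# then hand them to the API (keyword names assumed, see api.py).
from cached_path import cached_path
from f5_tts.api import F5TTS

ckpt_file = str(cached_path("hf://Jmica/F5TTS/JA_8500000/model_8499660.pt"))
vocab_file = str(cached_path("hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt"))

tts = F5TTS(model_type="F5-TTS", ckpt_file=ckpt_file, vocab_file=vocab_file)
```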
 
src/f5_tts/infer/utils_infer.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 import torchaudio
 import tqdm
+from huggingface_hub import snapshot_download, hf_hub_download
 from pydub import AudioSegment, silence
 from transformers import pipeline
 from vocos import Vocos
@@ -93,8 +94,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
     if vocoder_name == "vocos":
         if is_local:
             print(f"Load vocos from local path {local_path}")
-            vocoder = Vocos.from_hparams(f"{local_path}/config.yaml")
-            state_dict = torch.load(f"{local_path}/pytorch_model.bin", map_location="cpu")
+            repo_id = "charactr/vocos-mel-24khz"
+            revision = None
+            config_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
+            )
+            model_path = hf_hub_download(
+                repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
+            )
+            vocoder = Vocos.from_hparams(config_path=config_path)
+            state_dict = torch.load(model_path, map_location="cpu")
             vocoder.load_state_dict(state_dict)
             vocoder = vocoder.eval().to(device)
         else:
@@ -107,6 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
             print("You need to follow the README to init submodule and change the BigVGAN source code.")
         if is_local:
             """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
+            local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
             vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
         else:
             vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
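
Both helpers added here return local filesystem paths and store their downloads under `cache_dir`, which is what lets `local_path` act as a cache rather than a pre-populated directory. A quick standalone check of that behavior follows; the printed path layout is the usual huggingface_hub cache structure and may differ between library versions.

```python
# Standalone check of the download helpers the new code relies on.
from huggingface_hub import hf_hub_download, snapshot_download

# Single file: returns the path of config.yaml inside cache_dir.
config_path = hf_hub_download(
    repo_id="charactr/vocos-mel-24khz", filename="config.yaml", cache_dir="./ckpts/vocos"
)
print(config_path)

# Whole repo: returns the snapshot directory, usable as a local model path.
bigvgan_dir = snapshot_download(
    repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir="./ckpts/bigvgan"
)
print(bigvgan_dir)
```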