mrfakename committed on
Commit
1bcb8fe
·
verified ·
1 Parent(s): e2287e3

Sync from GitHub repo

Browse files

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.

src/f5_tts/api.py CHANGED
@@ -32,6 +32,7 @@ class F5TTS:
32
  vocoder_name="vocos",
33
  local_path=None,
34
  device=None,
 
35
  ):
36
  # Initialize parameters
37
  self.final_wave = None
@@ -46,29 +47,31 @@ class F5TTS:
46
  )
47
 
48
  # Load models
49
- self.load_vocoder_model(vocoder_name, local_path=local_path)
50
- self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
 
 
51
 
52
- def load_vocoder_model(self, vocoder_name, local_path=None):
53
- self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
54
 
55
- def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path=None):
56
  if model_type == "F5-TTS":
57
  if not ckpt_file:
58
  if mel_spec_type == "vocos":
59
  ckpt_file = str(
60
- cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
61
  )
62
  elif mel_spec_type == "bigvgan":
63
  ckpt_file = str(
64
- cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
65
  )
66
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
67
  model_cls = DiT
68
  elif model_type == "E2-TTS":
69
  if not ckpt_file:
70
  ckpt_file = str(
71
- cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
72
  )
73
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
74
  model_cls = UNetT
 
32
  vocoder_name="vocos",
33
  local_path=None,
34
  device=None,
35
+ hf_cache_dir=None,
36
  ):
37
  # Initialize parameters
38
  self.final_wave = None
 
47
  )
48
 
49
  # Load models
50
+ self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
51
+ self.load_ema_model(
52
+ model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
53
+ )
54
 
55
+ def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
56
+ self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)
57
 
58
+ def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
59
  if model_type == "F5-TTS":
60
  if not ckpt_file:
61
  if mel_spec_type == "vocos":
62
  ckpt_file = str(
63
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
64
  )
65
  elif mel_spec_type == "bigvgan":
66
  ckpt_file = str(
67
+ cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=hf_cache_dir)
68
  )
69
  model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
70
  model_cls = DiT
71
  elif model_type == "E2-TTS":
72
  if not ckpt_file:
73
  ckpt_file = str(
74
+ cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
75
  )
76
  model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
77
  model_cls = UNetT
src/f5_tts/infer/utils_infer.py CHANGED
@@ -90,18 +90,18 @@ def chunk_text(text, max_chars=135):
90
 
91
 
92
  # load vocoder
93
- def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=device):
94
  if vocoder_name == "vocos":
95
  # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
96
- if is_local and local_path is not None:
97
  print(f"Load vocos from local path {local_path}")
98
  config_path = f"{local_path}/config.yaml"
99
  model_path = f"{local_path}/pytorch_model.bin"
100
  else:
101
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
102
  repo_id = "charactr/vocos-mel-24khz"
103
- config_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="config.yaml")
104
- model_path = hf_hub_download(repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin")
105
  vocoder = Vocos.from_hparams(config_path)
106
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
107
  from vocos.feature_extractors import EncodecFeatures
@@ -119,11 +119,11 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path=None, device=d
119
  from third_party.BigVGAN import bigvgan
120
  except ImportError:
121
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
122
- if is_local and local_path is not None:
123
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
124
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
125
  else:
126
- local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
127
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
128
 
129
  vocoder.remove_weight_norm()
 
90
 
91
 
92
  # load vocoder
93
+ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=device, hf_cache_dir=None):
94
  if vocoder_name == "vocos":
95
  # vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz").to(device)
96
+ if is_local:
97
  print(f"Load vocos from local path {local_path}")
98
  config_path = f"{local_path}/config.yaml"
99
  model_path = f"{local_path}/pytorch_model.bin"
100
  else:
101
  print("Download Vocos from huggingface charactr/vocos-mel-24khz")
102
  repo_id = "charactr/vocos-mel-24khz"
103
+ config_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="config.yaml")
104
+ model_path = hf_hub_download(repo_id=repo_id, cache_dir=hf_cache_dir, filename="pytorch_model.bin")
105
  vocoder = Vocos.from_hparams(config_path)
106
  state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
107
  from vocos.feature_extractors import EncodecFeatures
 
119
  from third_party.BigVGAN import bigvgan
120
  except ImportError:
121
  print("You need to follow the README to init submodule and change the BigVGAN source code.")
122
+ if is_local:
123
  """download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
124
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
125
  else:
126
+ local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=hf_cache_dir)
127
  vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
128
 
129
  vocoder.remove_weight_norm()