Spaces:
Running
on
Zero
Running
on
Zero
mrfakename
committed on
Commit
•
cf0b618
1
Parent(s):
feb8eed
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
- src/f5_tts/api.py +12 -6
- src/f5_tts/infer/SHARED.md +11 -0
- src/f5_tts/infer/utils_infer.py +12 -2
src/f5_tts/api.py
CHANGED
@@ -46,24 +46,30 @@ class F5TTS:
|
|
46 |
)
|
47 |
|
48 |
# Load models
|
49 |
-
self.load_vocoder_model(vocoder_name, local_path)
|
50 |
-
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema)
|
51 |
|
52 |
def load_vocoder_model(self, vocoder_name, local_path):
|
53 |
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
|
54 |
|
55 |
-
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema):
|
56 |
if model_type == "F5-TTS":
|
57 |
if not ckpt_file:
|
58 |
if mel_spec_type == "vocos":
|
59 |
-
ckpt_file = str(
|
|
|
|
|
60 |
elif mel_spec_type == "bigvgan":
|
61 |
-
ckpt_file = str(
|
|
|
|
|
62 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
63 |
model_cls = DiT
|
64 |
elif model_type == "E2-TTS":
|
65 |
if not ckpt_file:
|
66 |
-
ckpt_file = str(
|
|
|
|
|
67 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
68 |
model_cls = UNetT
|
69 |
else:
|
|
|
46 |
)
|
47 |
|
48 |
# Load models
|
49 |
+
self.load_vocoder_model(vocoder_name, local_path=local_path)
|
50 |
+
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, local_path=local_path)
|
51 |
|
52 |
def load_vocoder_model(self, vocoder_name, local_path):
|
53 |
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device)
|
54 |
|
55 |
+
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, local_path):
|
56 |
if model_type == "F5-TTS":
|
57 |
if not ckpt_file:
|
58 |
if mel_spec_type == "vocos":
|
59 |
+
ckpt_file = str(
|
60 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors", cache_dir=local_path)
|
61 |
+
)
|
62 |
elif mel_spec_type == "bigvgan":
|
63 |
+
ckpt_file = str(
|
64 |
+
cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt", cache_dir=local_path)
|
65 |
+
)
|
66 |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
|
67 |
model_cls = DiT
|
68 |
elif model_type == "E2-TTS":
|
69 |
if not ckpt_file:
|
70 |
+
ckpt_file = str(
|
71 |
+
cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=local_path)
|
72 |
+
)
|
73 |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
|
74 |
model_cls = UNetT
|
75 |
else:
|
src/f5_tts/infer/SHARED.md
CHANGED
@@ -40,6 +40,17 @@ VOCAB_FILE: hf://SWivid/F5-TTS/F5TTS_Base/vocab.txt
|
|
40 |
|
41 |
## Mandarin
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
## English
|
45 |
|
|
|
40 |
|
41 |
## Mandarin
|
42 |
|
43 |
+
## Japanese
|
44 |
+
|
45 |
+
#### F5-TTS Base @ pretrain/finetune @ ja
|
46 |
+
|Model|🤗Hugging Face|Data (Hours)|Model License|
|
47 |
+
|:---:|:------------:|:-----------:|:-------------:|
|
48 |
+
|F5-TTS Base|[ckpt & vocab](https://huggingface.co/Jmica/F5TTS/tree/main/JA_8500000)|[Emilia 1.7k JA](https://huggingface.co/datasets/amphion/Emilia-Dataset/tree/fc71e07) & [Galgame Dataset 5.4k](https://huggingface.co/datasets/OOPPEENN/Galgame_Dataset)|cc-by-nc-4.0|
|
49 |
+
|
50 |
+
```bash
|
51 |
+
MODEL_CKPT: hf://Jmica/F5TTS/JA_8500000/model_8499660.pt
|
52 |
+
VOCAB_FILE: hf://Jmica/F5TTS/JA_8500000/vocab_updated.txt
|
53 |
+
```
|
54 |
|
55 |
## English
|
56 |
|
src/f5_tts/infer/utils_infer.py
CHANGED
@@ -19,6 +19,7 @@ import numpy as np
|
|
19 |
import torch
|
20 |
import torchaudio
|
21 |
import tqdm
|
|
|
22 |
from pydub import AudioSegment, silence
|
23 |
from transformers import pipeline
|
24 |
from vocos import Vocos
|
@@ -93,8 +94,16 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
93 |
if vocoder_name == "vocos":
|
94 |
if is_local:
|
95 |
print(f"Load vocos from local path {local_path}")
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
vocoder.load_state_dict(state_dict)
|
99 |
vocoder = vocoder.eval().to(device)
|
100 |
else:
|
@@ -107,6 +116,7 @@ def load_vocoder(vocoder_name="vocos", is_local=False, local_path="", device=dev
|
|
107 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
108 |
if is_local:
|
109 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
|
|
110 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
111 |
else:
|
112 |
vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
|
|
|
19 |
import torch
|
20 |
import torchaudio
|
21 |
import tqdm
|
22 |
+
from huggingface_hub import snapshot_download, hf_hub_download
|
23 |
from pydub import AudioSegment, silence
|
24 |
from transformers import pipeline
|
25 |
from vocos import Vocos
|
|
|
94 |
if vocoder_name == "vocos":
|
95 |
if is_local:
|
96 |
print(f"Load vocos from local path {local_path}")
|
97 |
+
repo_id = "charactr/vocos-mel-24khz"
|
98 |
+
revision = None
|
99 |
+
config_path = hf_hub_download(
|
100 |
+
repo_id=repo_id, cache_dir=local_path, filename="config.yaml", revision=revision
|
101 |
+
)
|
102 |
+
model_path = hf_hub_download(
|
103 |
+
repo_id=repo_id, cache_dir=local_path, filename="pytorch_model.bin", revision=revision
|
104 |
+
)
|
105 |
+
vocoder = Vocos.from_hparams(config_path=config_path)
|
106 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
107 |
vocoder.load_state_dict(state_dict)
|
108 |
vocoder = vocoder.eval().to(device)
|
109 |
else:
|
|
|
116 |
print("You need to follow the README to init submodule and change the BigVGAN source code.")
|
117 |
if is_local:
|
118 |
"""download from https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x/tree/main"""
|
119 |
+
local_path = snapshot_download(repo_id="nvidia/bigvgan_v2_24khz_100band_256x", cache_dir=local_path)
|
120 |
vocoder = bigvgan.BigVGAN.from_pretrained(local_path, use_cuda_kernel=False)
|
121 |
else:
|
122 |
vocoder = bigvgan.BigVGAN.from_pretrained("nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False)
|