Spaces:
Running
Running
mrfakename
commited on
Sync from GitHub repo
Browse filesThis Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there
src/f5_tts/train/finetune_cli.py
CHANGED
@@ -6,6 +6,7 @@ from cached_path import cached_path
|
|
6 |
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
7 |
from f5_tts.model.utils import get_tokenizer
|
8 |
from f5_tts.model.dataset import load_dataset
|
|
|
9 |
|
10 |
|
11 |
# -------------------------- Dataset Settings --------------------------- #
|
@@ -63,6 +64,7 @@ def parse_args():
|
|
63 |
|
64 |
def main():
|
65 |
args = parse_args()
|
|
|
66 |
|
67 |
# Model parameters based on experiment name
|
68 |
if args.exp_name == "F5TTS_Base":
|
@@ -85,12 +87,9 @@ def main():
|
|
85 |
ckpt_path = args.pretrain
|
86 |
|
87 |
if args.finetune:
|
88 |
-
|
89 |
-
|
90 |
-
os.
|
91 |
-
shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
|
92 |
-
|
93 |
-
checkpoint_path = os.path.join("ckpts", args.dataset_name)
|
94 |
|
95 |
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
96 |
tokenizer = args.tokenizer
|
|
|
6 |
from f5_tts.model import CFM, UNetT, DiT, Trainer
|
7 |
from f5_tts.model.utils import get_tokenizer
|
8 |
from f5_tts.model.dataset import load_dataset
|
9 |
+
from importlib.resources import files
|
10 |
|
11 |
|
12 |
# -------------------------- Dataset Settings --------------------------- #
|
|
|
64 |
|
65 |
def main():
|
66 |
args = parse_args()
|
67 |
+
checkpoint_path = str(files("f5_tts").joinpath(f"../../ckpts/{args.dataset_name}"))
|
68 |
|
69 |
# Model parameters based on experiment name
|
70 |
if args.exp_name == "F5TTS_Base":
|
|
|
87 |
ckpt_path = args.pretrain
|
88 |
|
89 |
if args.finetune:
|
90 |
+
if not os.path.isdir(checkpoint_path):
|
91 |
+
os.makedirs(checkpoint_path, exist_ok=True)
|
92 |
+
shutil.copy2(ckpt_path, os.path.join(checkpoint_path, os.path.basename(ckpt_path)))
|
|
|
|
|
|
|
93 |
|
94 |
# Use the tokenizer and tokenizer_path provided in the command line arguments
|
95 |
tokenizer = args.tokenizer
|