Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,577 Bytes
4dab15f ea89faa b5979c9 4dab15f b5979c9 fededd1 4dab15f fededd1 4dab15f ea89faa 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 4dab15f b5979c9 b6584c2 b5979c9 fededd1 27cee60 4dab15f b5979c9 4dab15f b5979c9 4dab15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
# Training script: builds a CFM text-to-speech model (F5-TTS / E2-TTS backbone) from a Hydra config and trains it.
import os
from importlib.resources import files
import hydra
from f5_tts.model import CFM, DiT, Trainer, UNetT
from f5_tts.model.dataset import load_dataset
from f5_tts.model.utils import get_tokenizer
# Runs at import time: change the working directory to two levels above the installed
# `f5_tts` package, so any relative paths used afterwards resolve against that directory.
# NOTE(review): this is the project root only for a local editable install where the
# package sits at <root>/src/f5_tts; for a regular site-packages install it points
# elsewhere — confirm this script is only meant to run from a source checkout.
os.chdir(str(files("f5_tts").joinpath("../.."))) # change working directory to root of project (local editable)
@hydra.main(version_base="1.3", config_path=str(files("f5_tts").joinpath("configs")), config_name=None)
def main(cfg):
    """Build a CFM text-to-speech model from a Hydra config and train it.

    Configs are looked up in the package's ``configs`` directory. ``config_name`` is
    ``None`` here, so the concrete config must be chosen when launching the script
    (presumably via Hydra's ``--config-name`` flag).

    Args:
        cfg: Hydra/OmegaConf config. Fields read:
            ``model.{name, tokenizer, tokenizer_path (custom tokenizer only), arch,
            mel_spec, vocoder.is_local, vocoder.local_path}``,
            ``datasets.{name, batch_size_per_gpu, batch_size_type, max_samples, num_workers}``,
            ``optim.{epochs, learning_rate, num_warmup_updates, grad_accumulation_steps,
            max_grad_norm, bnb_optimizer}`` and
            ``ckpts.{save_per_updates, save_dir, logger, last_per_steps}``.

    Raises:
        ValueError: if ``cfg.model.name`` contains neither "F5TTS" nor "E2TTS",
            so no transformer backbone can be selected.
    """
    tokenizer = cfg.model.tokenizer
    mel_spec_type = cfg.model.mel_spec.mel_spec_type
    # Run name used for wandb logging.
    exp_name = f"{cfg.model.name}_{mel_spec_type}_{tokenizer}_{cfg.datasets.name}"

    # Set text tokenizer: non-custom tokenizers are resolved from the dataset name,
    # a "custom" tokenizer uses the explicit cfg.model.tokenizer_path.
    if tokenizer != "custom":
        tokenizer_path = cfg.datasets.name
    else:
        tokenizer_path = cfg.model.tokenizer_path
    vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)

    # Set model: the backbone class is chosen by substring match on the model name.
    if "F5TTS" in cfg.model.name:
        model_cls = DiT
    elif "E2TTS" in cfg.model.name:
        model_cls = UNetT
    else:
        # Fail fast with a clear message instead of an UnboundLocalError on model_cls below.
        raise ValueError(
            f"Unsupported model name {cfg.model.name!r}: expected it to contain 'F5TTS' or 'E2TTS'"
        )

    model = CFM(
        transformer=model_cls(**cfg.model.arch, text_num_embeds=vocab_size, mel_dim=cfg.model.mel_spec.n_mel_channels),
        mel_spec_kwargs=cfg.model.mel_spec,
        vocab_char_map=vocab_char_map,
    )

    # Init trainer.
    wandb_resume_id = None  # presumably: do not resume an earlier wandb run
    trainer = Trainer(
        model,
        epochs=cfg.optim.epochs,
        learning_rate=cfg.optim.learning_rate,
        num_warmup_updates=cfg.optim.num_warmup_updates,
        save_per_updates=cfg.ckpts.save_per_updates,
        # save_dir is resolved relative to two levels above the f5_tts package
        # (the project root for a local editable install).
        checkpoint_path=str(files("f5_tts").joinpath(f"../../{cfg.ckpts.save_dir}")),
        batch_size=cfg.datasets.batch_size_per_gpu,
        batch_size_type=cfg.datasets.batch_size_type,
        max_samples=cfg.datasets.max_samples,
        grad_accumulation_steps=cfg.optim.grad_accumulation_steps,
        max_grad_norm=cfg.optim.max_grad_norm,
        logger=cfg.ckpts.logger,
        wandb_project="CFM-TTS",
        wandb_run_name=exp_name,
        wandb_resume_id=wandb_resume_id,
        last_per_steps=cfg.ckpts.last_per_steps,
        log_samples=True,
        bnb_optimizer=cfg.optim.bnb_optimizer,
        mel_spec_type=mel_spec_type,
        is_local_vocoder=cfg.model.vocoder.is_local,
        local_vocoder_path=cfg.model.vocoder.local_path,
    )

    train_dataset = load_dataset(cfg.datasets.name, tokenizer, mel_spec_kwargs=cfg.model.mel_spec)
    trainer.train(
        train_dataset,
        num_workers=cfg.datasets.num_workers,
        resumable_with_seed=666,  # seed for shuffling dataset
    )
if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator on `main` composes the
    # config (from the configs directory plus command-line overrides) and passes it as `cfg`.
    main()
|