Spaces:
Running
Running
import random | |
import sys | |
from importlib.resources import files | |
import soundfile as sf | |
import torch | |
import tqdm | |
from cached_path import cached_path | |
from f5_tts.infer.utils_infer import ( | |
hop_length, | |
infer_process, | |
load_model, | |
load_vocoder, | |
preprocess_ref_audio_text, | |
remove_silence_for_generated_wav, | |
save_spectrogram, | |
target_sample_rate, | |
) | |
from f5_tts.model import DiT, UNetT | |
from f5_tts.model.utils import seed_everything | |
class F5TTS: | |
def __init__( | |
self, | |
model_type="F5-TTS", | |
ckpt_file="", | |
vocab_file="", | |
ode_method="euler", | |
use_ema=True, | |
vocoder_name="vocos", | |
local_path=None, | |
device=None, | |
): | |
# Initialize parameters | |
self.final_wave = None | |
self.target_sample_rate = target_sample_rate | |
self.hop_length = hop_length | |
self.seed = -1 | |
self.mel_spec_type = vocoder_name | |
# Set device | |
self.device = device or ( | |
"cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" | |
) | |
# Load models | |
self.load_vocoder_model(vocoder_name, local_path) | |
self.load_ema_model(model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema) | |
def load_vocoder_model(self, vocoder_name, local_path): | |
self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device) | |
def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema): | |
if model_type == "F5-TTS": | |
if not ckpt_file: | |
if mel_spec_type == "vocos": | |
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")) | |
elif mel_spec_type == "bigvgan": | |
ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt")) | |
model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4) | |
model_cls = DiT | |
elif model_type == "E2-TTS": | |
if not ckpt_file: | |
ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors")) | |
model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4) | |
model_cls = UNetT | |
else: | |
raise ValueError(f"Unknown model type: {model_type}") | |
self.ema_model = load_model( | |
model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device | |
) | |
def export_wav(self, wav, file_wave, remove_silence=False): | |
sf.write(file_wave, wav, self.target_sample_rate) | |
if remove_silence: | |
remove_silence_for_generated_wav(file_wave) | |
def export_spectrogram(self, spect, file_spect): | |
save_spectrogram(spect, file_spect) | |
def infer( | |
self, | |
ref_file, | |
ref_text, | |
gen_text, | |
show_info=print, | |
progress=tqdm, | |
target_rms=0.1, | |
cross_fade_duration=0.15, | |
sway_sampling_coef=-1, | |
cfg_strength=2, | |
nfe_step=32, | |
speed=1.0, | |
fix_duration=None, | |
remove_silence=False, | |
file_wave=None, | |
file_spect=None, | |
seed=-1, | |
): | |
if seed == -1: | |
seed = random.randint(0, sys.maxsize) | |
seed_everything(seed) | |
self.seed = seed | |
ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device) | |
wav, sr, spect = infer_process( | |
ref_file, | |
ref_text, | |
gen_text, | |
self.ema_model, | |
self.vocoder, | |
self.mel_spec_type, | |
show_info=show_info, | |
progress=progress, | |
target_rms=target_rms, | |
cross_fade_duration=cross_fade_duration, | |
nfe_step=nfe_step, | |
cfg_strength=cfg_strength, | |
sway_sampling_coef=sway_sampling_coef, | |
speed=speed, | |
fix_duration=fix_duration, | |
device=self.device, | |
) | |
if file_wave is not None: | |
self.export_wav(wav, file_wave, remove_silence) | |
if file_spect is not None: | |
self.export_spectrogram(spect, file_spect) | |
return wav, sr, spect | |
if __name__ == "__main__": | |
f5tts = F5TTS() | |
wav, sr, spect = f5tts.infer( | |
ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")), | |
ref_text="some call me nature, others call me mother nature.", | |
gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""", | |
file_wave=str(files("f5_tts").joinpath("../../tests/api_out.wav")), | |
file_spect=str(files("f5_tts").joinpath("../../tests/api_out.png")), | |
seed=-1, # random seed = -1 | |
) | |
print("seed :", f5tts.seed) | |