Spaces:
Running
Running
File size: 3,787 Bytes
9c54d62 118c154 9c54d62 118c154 9c54d62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import soundfile as sf
import torch
import tqdm
from cached_path import cached_path
from model import DiT, UNetT
from model.utils import save_spectrogram
from model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
class F5TTS:
    """High-level wrapper around the F5-TTS / E2-TTS inference pipeline.

    Construction picks a torch device, loads a vocoder and the requested TTS
    checkpoint. :meth:`infer` then synthesizes ``gen_text`` conditioned on a
    reference clip and its transcript, optionally saving the waveform and a
    spectrogram image to disk.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Select a device and load the vocoder and TTS model.

        Args:
            model_type: ``"F5-TTS"`` (DiT backbone) or ``"E2-TTS"`` (UNetT backbone).
            ckpt_file: Checkpoint path. When empty, the default checkpoint for
                ``model_type`` is fetched via ``cached_path`` from the Hugging Face Hub.
            vocab_file: Forwarded unchanged to ``load_model`` (empty presumably
                means its built-in default vocabulary — TODO confirm).
            ode_method: ODE solver name forwarded to ``load_model``.
            use_ema: Forwarded to ``load_model`` (by its name, whether to use EMA
                weights — confirm in ``load_model``).
            local_path: Forwarded to ``load_vocoder`` together with a flag that is
                true when a path was supplied.
            device: Torch device string. Falsy values auto-select CUDA, then Apple
                MPS, then CPU.

        Raises:
            ValueError: If ``model_type`` is not one of the supported names.
        """
        # Placeholder for the last generated waveform; not used within this class.
        self.final_wave = None
        # Sample rate used by export_wav when writing audio files.
        self.target_sample_rate = 24000
        # The next three are exposed as public attributes but are not read by
        # any method in this class.
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1

        # Prefer CUDA, then Apple Silicon (MPS), then fall back to CPU.
        self.device = device or (
            "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        )

        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder into ``self.vocos`` on ``self.device``.

        The first argument tells ``load_vocoder`` whether ``local_path`` was given.
        """
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the TTS backbone for ``model_type`` and load it into ``self.ema_model``.

        Raises:
            ValueError: If ``model_type`` is not ``"F5-TTS"`` or ``"E2-TTS"``.
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                ckpt_file = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"))
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file, ode_method, use_ema, self.device)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        Args:
            wav: Audio samples accepted by ``soundfile.write``.
            file_wave: Destination path.
            remove_silence: If true, run ``remove_silence_for_generated_wav`` on
                the file after it has been written.
        """
        # Write first. remove_silence_for_generated_wav takes the file path, so it
        # can only act on audio that is already on disk. Previously it ran before
        # sf.write, which then overwrote whatever it produced.
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` as an image at ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        cross_fade_duration=0.15,
        show_info=print,
        progress=tqdm,
    ):
        """Synthesize ``gen_text`` in the voice of the reference clip.

        Args:
            ref_file: Path to the reference audio.
            ref_text: Transcript of the reference audio.
            gen_text: Text to synthesize.
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration,
            cross_fade_duration: Sampling/generation settings forwarded
                unchanged to ``infer_process``.
            remove_silence: Passed to :meth:`export_wav`; only relevant when
                ``file_wave`` is given.
            file_wave: If not ``None``, the generated audio is written here.
            file_spect: If not ``None``, the spectrogram image is written here.
            show_info: Callable for status messages, forwarded to ``infer_process``.
            progress: Progress-bar provider forwarded to ``infer_process``
                (defaults to the ``tqdm`` module).

        Returns:
            The ``(wav, sr, spect)`` tuple produced by ``infer_process``.
        """
        # NOTE(review): infer_process is called positionally, so this order must
        # match its signature exactly. Consider keyword arguments once the
        # signature is confirmed. self.device is not forwarded here.
        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            cross_fade_duration,
            speed,
            show_info,
            progress,
            nfe_step,
            cfg_strength,
            sway_sampling_coef,
            fix_duration,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect
if __name__ == "__main__":
    # End-to-end demo: clone the bundled English reference voice and save the
    # generated audio plus its spectrogram under tests/.
    tts = F5TTS()

    reference_audio = "tests/ref_audio/test_en_1_ref_short.wav"
    reference_transcript = "some call me nature, others call me mother nature."
    text_to_speak = """I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences."""

    wav, sr, spect = tts.infer(
        ref_file=reference_audio,
        ref_text=reference_transcript,
        gen_text=text_to_speak,
        file_wave="tests/out.wav",
        file_spect="tests/out.png",
    )
|