Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,571 Bytes
4dab15f fededd1 4dab15f fededd1 4dab15f fededd1 4dab15f c0fb8c8 b0bca14 4dab15f fededd1 4dab15f fededd1 4dab15f 1bcb8fe 4dab15f b0bca14 4dab15f fededd1 4dab15f b315dd9 4dab15f 1bcb8fe 4dab15f 1bcb8fe 4dab15f 1bcb8fe 4dab15f fededd1 cf0b618 1bcb8fe cf0b618 fededd1 cf0b618 1bcb8fe cf0b618 4dab15f cf0b618 1bcb8fe cf0b618 4dab15f fededd1 4dab15f c0fb8c8 4dab15f b0bca14 4dab15f b6584c2 fededd1 4dab15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import random
import sys
from importlib.resources import files
import soundfile as sf
import tqdm
from cached_path import cached_path
from f5_tts.infer.utils_infer import (
hop_length,
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
transcribe,
target_sample_rate,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything
class F5TTS:
    """High-level wrapper around the F5-TTS / E2-TTS inference utilities.

    Constructing an instance loads a vocoder and an acoustic (EMA) model.
    ``infer`` then synthesizes ``gen_text`` conditioned on a reference audio
    clip and its transcript, optionally writing the waveform and spectrogram
    to disk.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        vocoder_name="vocos",
        local_path=None,
        device=None,
        hf_cache_dir=None,
    ):
        """Select a device and load the vocoder and acoustic model.

        Args:
            model_type: "F5-TTS" (DiT backbone) or "E2-TTS" (UNetT backbone).
            ckpt_file: Model checkpoint path. An empty string selects a default
                checkpoint fetched with ``cached_path``.
            vocab_file: Forwarded to ``load_model``.
            ode_method: Forwarded to ``load_model``.
            use_ema: Forwarded to ``load_model``.
            vocoder_name: Vocoder to load. It is also stored as
                ``self.mel_spec_type`` and used to pick the default checkpoint.
            local_path: When not None, the vocoder is loaded from this path.
            device: Device string. When None, "cuda", then "mps", then "cpu"
                is chosen based on availability.
            hf_cache_dir: Cache directory for downloaded files.
        """
        # Initialize parameters
        self.final_wave = None
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.seed = -1  # updated by infer() with the seed actually used
        self.mel_spec_type = vocoder_name

        # Set device: an explicit argument wins; otherwise prefer CUDA, then Apple MPS, then CPU.
        if device is not None:
            self.device = device
        else:
            import torch  # imported here only because auto-detection needs it

            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

        # Load models
        self.load_vocoder_model(vocoder_name, local_path=local_path, hf_cache_dir=hf_cache_dir)
        self.load_ema_model(
            model_type, ckpt_file, vocoder_name, vocab_file, ode_method, use_ema, hf_cache_dir=hf_cache_dir
        )

    def load_vocoder_model(self, vocoder_name, local_path=None, hf_cache_dir=None):
        """Load the vocoder named ``vocoder_name`` into ``self.vocoder``."""
        # The second positional argument tells load_vocoder whether to load
        # from local_path (True whenever a local path was supplied).
        self.vocoder = load_vocoder(vocoder_name, local_path is not None, local_path, self.device, hf_cache_dir)

    def load_ema_model(self, model_type, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, hf_cache_dir=None):
        """Build the acoustic model for ``model_type`` and store it in ``self.ema_model``.

        Args:
            model_type: "F5-TTS" or "E2-TTS".
            ckpt_file: Checkpoint path. Empty selects a default checkpoint.
            mel_spec_type: Vocoder / mel-spectrogram type. For F5-TTS it selects
                which default checkpoint is downloaded.
            vocab_file, ode_method, use_ema: Forwarded to ``load_model``.
            hf_cache_dir: Cache directory passed to ``cached_path``.

        Raises:
            ValueError: If ``model_type`` is unknown, or if F5-TTS is requested
                without ``ckpt_file`` for a ``mel_spec_type`` that has no
                default checkpoint.
        """
        if model_type == "F5-TTS":
            if not ckpt_file:
                default_ckpts = {
                    "vocos": "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors",
                    "bigvgan": "hf://SWivid/F5-TTS/F5TTS_Base_bigvgan/model_1250000.pt",
                }
                # Fail fast here. Falling through with ckpt_file == "" would only
                # surface later as an obscure failure inside load_model.
                if mel_spec_type not in default_ckpts:
                    raise ValueError(
                        f"No default F5-TTS checkpoint for mel_spec_type {mel_spec_type!r}; "
                        "pass ckpt_file explicitly"
                    )
                ckpt_file = str(cached_path(default_ckpts[mel_spec_type], cache_dir=hf_cache_dir))
            model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
            model_cls = DiT
        elif model_type == "E2-TTS":
            if not ckpt_file:
                # NOTE(review): only one default E2-TTS checkpoint is listed,
                # regardless of mel_spec_type. Confirm it matches the vocoder in use.
                ckpt_file = str(
                    cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors", cache_dir=hf_cache_dir)
                )
            model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
            model_cls = UNetT
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        self.ema_model = load_model(
            model_cls, model_cfg, ckpt_file, mel_spec_type, vocab_file, ode_method, use_ema, self.device
        )

    def transcribe(self, ref_audio, language=None):
        """Transcribe ``ref_audio`` using the ``transcribe`` helper from utils_infer."""
        # The bare name resolves to the module-level import, not to this method.
        return transcribe(ref_audio, language)

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        If ``remove_silence`` is true, the written file is then passed to
        ``remove_silence_for_generated_wav``.
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Synthesize ``gen_text`` in the voice of the reference clip.

        Args:
            ref_file: Reference audio path.
            ref_text: Transcript of the reference audio.
            gen_text: Text to synthesize.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged to ``infer_process``.
            remove_silence: Passed to ``export_wav`` when ``file_wave`` is set.
            file_wave: If given, the waveform is written to this path.
            file_spect: If given, the spectrogram is saved to this path.
            seed: Seed for ``seed_everything``. -1 (or None) draws a random
                seed. The seed actually used is stored in ``self.seed``.

        Returns:
            The ``(wav, sr, spect)`` tuple returned by ``infer_process``.
        """
        # Treat None like the -1 sentinel instead of forwarding None to seed_everything.
        if seed is None or seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        # preprocess_ref_audio_text returns the (file, text) pair actually used for inference.
        ref_file, ref_text = preprocess_ref_audio_text(ref_file, ref_text, device=self.device)

        wav, sr, spect = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            self.vocoder,
            self.mel_spec_type,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        if file_wave is not None:
            self.export_wav(wav, file_wave, remove_silence)

        if file_spect is not None:
            self.export_spectrogram(spect, file_spect)

        return wav, sr, spect
if __name__ == "__main__":
    import os  # only needed by this demo, to prepare the output directory

    f5tts = F5TTS()

    # Demo outputs live two levels above the f5_tts package directory
    # (presumably the repository's tests/ folder in a source checkout).
    # That directory may not exist, e.g. when the package is installed,
    # so create it before writing; otherwise sf.write would fail.
    file_wave = str(files("f5_tts").joinpath("../../tests/api_out.wav"))
    file_spect = str(files("f5_tts").joinpath("../../tests/api_out.png"))
    for out_path in (file_wave, file_spect):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    wav, sr, spect = f5tts.infer(
        ref_file=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave=file_wave,
        file_spect=file_spect,
        seed=-1,  # random seed = -1
    )

    print("seed :", f5tts.seed)
|