Spaces:
Running
Running
File size: 3,923 Bytes
01e655b 02e90e4 84cfd61 01e655b 02e90e4 01e655b 02e90e4 01e655b 02e90e4 01e655b 29536f1 01e655b 29536f1 49bce5c 01e655b 29536f1 01e655b 84cfd61 01e655b 02e90e4 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import numpy as np
import torch
from modules.speaker import Speaker
from modules.utils.SeedContext import SeedContext
from modules import models, config
import logging
from modules.devices import devices
from typing import Union
from modules.utils.cache import conditional_cache
logger = logging.getLogger(__name__)
def generate_audio(
    text: str,
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: float = 20,
    spk: Union[int, Speaker] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize a single text by running it as a one-element batch.

    All options are forwarded as keyword arguments to ``generate_audio_batch``.
    The function looks up the module-level ``generate_audio_batch`` at call
    time, so a wrapper installed by ``setup_lru_cache`` is used here as well.

    Returns:
        The ``(sample_rate, wav)`` tuple at index 0 of the batch result.
    """
    # Keyword order matches the original call so any cache keyed on kwargs sees
    # an identical call signature.
    batch_kwargs = dict(
        temperature=temperature,
        top_P=top_P,
        top_K=top_K,
        spk=spk,
        infer_seed=infer_seed,
        use_decoder=use_decoder,
        prompt1=prompt1,
        prompt2=prompt2,
        prefix=prefix,
    )
    batch_results = generate_audio_batch([text], **batch_kwargs)
    sample_rate, wav = batch_results[0]
    return sample_rate, wav
@torch.inference_mode()
def generate_audio_batch(
    texts: list[str],
    temperature: float = 0.3,
    top_P: float = 0.7,
    top_K: int = 20,
    spk: Union[int, "Speaker"] = -1,
    infer_seed: int = -1,
    use_decoder: bool = True,
    prompt1: str = "",
    prompt2: str = "",
    prefix: str = "",
):
    """Synthesize audio for a batch of texts with a single ChatTTS call.

    Args:
        texts: Texts to synthesize.
        temperature, top_P, top_K: Sampling parameters placed into
            ``params_infer_code`` and passed to ``chat_tts.generate_audio``.
        spk: Either an ``int``, used as the seed under which
            ``chat_tts.sample_random_speaker()`` is called, or a ``Speaker``
            whose ``emb`` is used directly as the speaker embedding.
        infer_seed: Seed applied via ``SeedContext`` around the generation call.
        use_decoder: Forwarded to ``chat_tts.generate_audio``.
        prompt1, prompt2, prefix: Extra prompt strings; ``None`` (or any falsy
            value) is normalized to ``""``.

    Returns:
        A list with one ``(sample_rate, wav)`` tuple per waveform returned by
        the model (presumably one per text), where ``wav`` is a flattened
        ``float32`` numpy array and ``sample_rate`` is 24000. An empty
        ``texts`` returns ``[]`` immediately, before the model is loaded or
        ``spk`` is validated.

    Raises:
        ValueError: If ``spk`` is neither an ``int`` nor a ``Speaker``.
    """
    # Nothing to synthesize: avoid loading the model and an empty model call.
    if not texts:
        return []

    chat_tts = models.load_chat_tts()

    params_infer_code = {
        "spk_emb": None,  # filled in below from ``spk``
        "temperature": temperature,
        "top_P": top_P,
        "top_K": top_K,
        "prompt1": prompt1 or "",
        "prompt2": prompt2 or "",
        "prefix": prefix or "",
        "repetition_penalty": 1.0,
        "disable_tqdm": config.runtime_env_vars.off_tqdm,
    }

    if isinstance(spk, int):
        # NOTE(review): SeedContext presumably seeds the RNGs so a given int
        # always samples the same speaker embedding — confirm in SeedContext.
        with SeedContext(spk):
            params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
        logger.info(("spk", spk))
    elif isinstance(spk, Speaker):
        params_infer_code["spk_emb"] = spk.emb
        logger.info(("spk", spk.name))
    else:
        raise ValueError("spk must be int or Speaker")

    logger.info(
        {
            "text": texts,
            "infer_seed": infer_seed,
            "temperature": temperature,
            "top_P": top_P,
            "top_K": top_K,
            "prompt1": prompt1 or "",
            "prompt2": prompt2 or "",
            "prefix": prefix or "",
        }
    )

    try:
        with SeedContext(infer_seed):
            wavs = chat_tts.generate_audio(
                texts, params_infer_code, use_decoder=use_decoder
            )
    finally:
        # Release cached device memory even when generation raises, so a
        # failed request does not leave memory held until the next success.
        devices.torch_gc()

    # Sample rate reported for every returned waveform (hard-coded).
    sample_rate = 24000
    return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
# True once setup_lru_cache() has actually wrapped generate_audio_batch; guards
# against wrapping it more than once.
lru_cache_enabled = False


def setup_lru_cache():
    """Wrap ``generate_audio_batch`` in a conditional LRU cache, at most once.

    The module-level ``generate_audio_batch`` name is rebound to the cached
    wrapper, so callers that look it up at call time (e.g. ``generate_audio``)
    pick up the cache. The cache is only consulted for calls that pass both
    ``spk`` and ``infer_seed`` as keyword arguments with values other than -1.

    If ``config.runtime_env_vars.lru_size`` is not an ``int``, nothing is
    wrapped and ``lru_cache_enabled`` stays ``False``, so a later call may
    retry.
    """
    global generate_audio_batch
    global lru_cache_enabled

    if lru_cache_enabled:
        return

    def should_cache(*args, **kwargs):
        # Only keyword arguments are inspected; -1 is the default for both
        # parameters (presumably "random"), so such calls are not cached.
        spk_seed = kwargs.get("spk", -1)
        infer_seed = kwargs.get("infer_seed", -1)
        return spk_seed != -1 and infer_seed != -1

    lru_size = config.runtime_env_vars.lru_size
    if not isinstance(lru_size, int):
        logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
        return

    generate_audio_batch = conditional_cache(lru_size, should_cache)(
        generate_audio_batch
    )
    # Mark as enabled only after the wrap succeeded, so the flag reflects reality.
    lru_cache_enabled = True
    logger.info(f"LRU cache enabled with size {lru_size}")
if __name__ == "__main__":
    import soundfile as sf

    # Batch generation test: several texts of different lengths in one call.
    inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"]
    batch_outputs = generate_audio_batch(inputs, spk=5, infer_seed=42)
    for idx, (sr, audio) in enumerate(batch_outputs):
        print(idx, sr, audio.shape)
        sf.write(f"batch_{idx}.wav", audio, sr, format="wav")

    # Single-text generation: one call per input, for comparison with the batch.
    for idx, text in enumerate(inputs):
        sr, audio = generate_audio(text, spk=5, infer_seed=42)
        print(idx, sr, audio.shape)
        sf.write(f"one_{idx}.wav", audio, sr, format="wav")
|