# Scraped from a Hugging Face Space page (Space status: Running).
# Last commit acfca85 by alessandro trinca tornidor:
#   "test: update test cases for models modules, add preprocessAudioStandalone() function"
import os | |
from pathlib import Path | |
import tempfile | |
import torch | |
import torch.nn as nn | |
from silero.utils import Decoder | |
from aip_trainer import app_logger, sample_rate_start | |
# Per-language defaults consumed by silero_tts():
#   speaker     -> speaker name returned alongside the loaded TTS model
#   model_id    -> key looked up under tts_models[language] in the models.yml catalogue
#   sample_rate -> sample rate returned to callers (the project-wide sample_rate_start)
default_speaker_dict = dict(
    de=dict(speaker="karlsson", model_id="v3_de", sample_rate=sample_rate_start),
    en=dict(speaker="en_0", model_id="v3_en", sample_rate=sample_rate_start),
)
def silero_tts(language="en", version="latest", output_folder: Path | str | None = None, **kwargs):
    """Load a Silero Text-To-Speech model for ``language``.

    Only languages present in ``default_speaker_dict`` are supported; the
    model id, speaker name and returned sample rate are taken from there.

    Args:
        language: language code, a key of ``default_speaker_dict``.
        version: version key inside the models.yml catalogue entry
            (default "latest").
        output_folder: folder where the models.yml catalogue and the model
            package are cached. When None, ``torch.hub.get_dir()`` is used
            (the same fallback as ``init_jit_model``).
        **kwargs: accepted for call compatibility; not used.

    Returns:
        For package-based models (model id containing "_v2", "_v3", "v3_" or
        "v4_"): ``(model, example_text, speaker, sample_rate)``.
        Otherwise (legacy jit models):
        ``(model, symbols, sample_rate, example_text, apply_tts, model_id)``.

    Raises:
        KeyError: if ``language`` is not a key of ``default_speaker_dict``.
        AssertionError: if the catalogue has no TTS entry for ``language``.

    Please see https://github.com/snakers4/silero-models for usage examples.
    """
    # Validate before indexing. Previously the dict lookup ran before the
    # membership check, so the check was dead code.
    if language not in default_speaker_dict:
        raise KeyError(
            f"language '{language}' not supported, choose one of {list(default_speaker_dict)}"
        )
    current_model_lang = default_speaker_dict[language]
    app_logger.info(f"model speaker current_model_lang: {current_model_lang} ...")
    model_id = current_model_lang["model_id"]

    # Path(None) raises TypeError, so fall back to torch.hub's cache dir when no folder is given.
    output_folder = Path(output_folder) if output_folder is not None else Path(torch.hub.get_dir())
    # The catalogue and the model package are both written into output_folder; make sure it exists.
    os.makedirs(output_folder, exist_ok=True)

    models = get_models(language, output_folder, version, model_type="tts_models")
    available_languages = list(models.tts_models.keys())
    assert (
        language in available_languages
    ), f"Language not in the supported list {available_languages}"
    model_conf_latest = models.tts_models[language][model_id][version]
    app_logger.info(f"model_conf: {model_conf_latest} ...")

    if any(tag in model_id for tag in ("_v2", "_v3", "v3_", "v4_")):
        # Newer models are distributed as torch.package archives.
        from torch import package

        model_url = model_conf_latest.package
        model_path = output_folder / os.path.basename(model_url)
        if not os.path.isfile(model_path):
            torch.hub.download_url_to_file(model_url, model_path, progress=True)
        imp = package.PackageImporter(model_path)
        model = imp.load_pickle("tts_models", "model")
        app_logger.info(
            f"current model_conf_latest.sample_rate:{model_conf_latest.sample_rate} ..."
        )
        # The returned rate is the project default from default_speaker_dict,
        # not the catalogue's value, which is only logged above.
        sample_rate = current_model_lang["sample_rate"]
        return (
            model,
            model_conf_latest.example,
            current_model_lang["speaker"],
            sample_rate,
        )

    # Legacy jit-based TTS models.
    from silero.tts_utils import apply_tts, init_jit_model as init_jit_model_tts

    model = init_jit_model_tts(model_conf_latest.jit)
    symbols = model_conf_latest.tokenset
    example_text = model_conf_latest.example
    sample_rate = model_conf_latest.sample_rate
    return model, symbols, sample_rate, example_text, apply_tts, model_id
def silero_stt(
    language="en",
    version="latest",
    jit_model="jit",
    output_folder: Path | str = None,
    **kwargs,
):
    """Load a Silero Speech-To-Text model together with its decoder and audio helpers.

    Modified version of the upstream silero-models ``silero_stt`` entry point
    that lets the caller choose where downloads are cached.

    Args:
        language: language code of the model, e.g. 'en', 'de', 'es'.
        version: version key in the models.yml catalogue (e.g. "latest", "v4").
        jit_model: model-variant key under that version (e.g. "jit", "jit_large").
        output_folder: folder for the catalogue and model files (needed for
            docker builds).
        **kwargs: forwarded through ``get_latest_model`` to ``init_jit_model``
            (e.g. ``device``).

    Returns:
        ``(model, decoder, utils)`` where ``utils`` is the tuple
        ``(read_batch, split_into_batches, read_audio, prepare_model_input)``.

    Please see https://github.com/snakers4/silero-models for usage examples.
    """
    from silero.utils import (
        prepare_model_input,
        read_audio,
        read_batch,
        split_into_batches,
    )

    # Audio helpers are returned in the same order as the upstream silero API.
    audio_helpers = (read_batch, split_into_batches, read_audio, prepare_model_input)
    stt_model, stt_decoder = get_latest_model(
        language,
        output_folder,
        version,
        model_type="stt_models",
        jit_model=jit_model,
        **kwargs,
    )
    return stt_model, stt_decoder, audio_helpers
def init_jit_model(
    model_url: str,
    device: torch.device = torch.device("cpu"),
    output_folder: Path | str | None = None,
):
    """Download (if not already cached) and load a TorchScript Silero model.

    Side effect: calls ``torch.set_grad_enabled(False)``, which changes torch's
    grad mode for the caller as well, not only for this model.

    Args:
        model_url: URL of the TorchScript file; its basename is used as the
            cached file name.
        device: passed as ``map_location`` to ``torch.jit.load``.
        output_folder: cache folder; ``torch.hub.get_dir()`` when None.

    Returns:
        ``(model, decoder)``: the model switched to eval mode and a
        ``Decoder`` built from ``model.labels``.
    """
    torch.set_grad_enabled(False)
    # Report whether a custom folder was provided. The previous message printed
    # `output_folder is None` under an "exists?" label, i.e. the inverted value.
    app_logger.info(
        f"model output_folder exists? '{output_folder is not None}' => '{output_folder}' ..."
    )
    model_dir = Path(output_folder) if output_folder is not None else Path(torch.hub.get_dir())
    os.makedirs(model_dir, exist_ok=True)
    app_logger.info(f"downloading the models to model_dir: '{model_dir}' ...")

    model_path = model_dir / os.path.basename(model_url)
    model_is_cached = os.path.isfile(model_path)
    app_logger.info(f"model_path exists? '{model_is_cached}' => '{model_path}' ...")
    if not model_is_cached:
        app_logger.info(f"downloading model_path: '{model_path}' ...")
        torch.hub.download_url_to_file(model_url, model_path, progress=True)
        app_logger.info(f"model_path {model_path} downloaded!")

    model = torch.jit.load(model_path, map_location=device)
    model.eval()
    return model, Decoder(model.labels)
# The second returned value is an instance of the custom Decoder class from
# snakers4/silero-models (imported in this module as silero.utils.Decoder).
def getASRModel(language: str) -> tuple[nn.Module, Decoder]:
    """Return the Silero speech-to-text ``(model, decoder)`` pair for ``language``.

    Model files are cached in the system temporary directory. German uses the
    "v4" / "jit_large" model. English uses silero_stt's default version and
    variant.

    Raises:
        NotImplementedError: for any language other than 'de' and 'en'.
    """
    # Extra silero_stt() keyword arguments per supported language.
    stt_kwargs_by_language = {
        "de": {"version": "v4", "jit_model": "jit_large"},
        "en": {},
    }
    if language not in stt_kwargs_by_language:
        raise NotImplementedError(
            f"currently works only for 'de' and 'en' languages, not for '{language}'."
        )
    model, decoder, _ = silero_stt(
        language=language,
        output_folder=tempfile.gettempdir(),
        **stt_kwargs_by_language[language],
    )
    return model, decoder
def get_models(language, output_folder, version, model_type):
    """Load the Silero ``models.yml`` catalogue, downloading it when not cached.

    The catalogue is cached as ``latest_silero_model_{language}.yml`` inside
    ``output_folder``. When ``output_folder`` is None, it defaults to two
    directories above this module's folder. ``version`` and ``model_type``
    are used only in the log message.

    Returns:
        The catalogue as loaded by ``OmegaConf.load``.
    """
    from omegaconf import OmegaConf

    output_folder = (
        Path(output_folder)
        if output_folder is not None
        else Path(os.path.dirname(__file__)).parent.parent
    )
    models_list_file = output_folder / f"latest_silero_model_{language}.yml"
    if not os.path.exists(models_list_file):
        app_logger.info(
            f"model {model_type} yml for '{language}' language, '{version}' version not found, download it in folder {output_folder}..."
        )
        # Make sure the destination folder exists before downloading into it;
        # callers may pass a fresh cache folder (e.g. during a docker build).
        os.makedirs(output_folder, exist_ok=True)
        torch.hub.download_url_to_file(
            "https://raw.githubusercontent.com/snakers4/silero-models/master/models.yml",
            models_list_file,
            progress=False,
        )
    assert os.path.exists(models_list_file)
    return OmegaConf.load(models_list_file)
def get_latest_model(language, output_folder, version, model_type, jit_model, **kwargs):
    """Resolve a model URL from the Silero catalogue and load it.

    Args:
        language: language key under ``models[model_type]``.
        output_folder: cache folder for the catalogue and the model file.
        version: version key under the language entry (e.g. "latest").
        model_type: catalogue section to look in (e.g. "stt_models").
        jit_model: model-variant key under the version entry (e.g. "jit").
        **kwargs: forwarded to ``init_jit_model`` (e.g. ``device``).

    Returns:
        ``(model, decoder)`` as returned by ``init_jit_model``.

    Raises:
        AssertionError: if ``language`` is missing from ``models[model_type]``.
        ValueError: if ``version`` or ``jit_model`` is missing for that language.
            Previously this surfaced as an AttributeError/TypeError on None.
    """
    models = get_models(language, output_folder, version, model_type)
    section = models[model_type]
    available_languages = list(section.keys())
    assert (
        language in available_languages
    ), f"language '{language}' not in the {model_type} list {available_languages}"

    language_conf = section.get(language)
    version_conf = language_conf.get(version)
    if version_conf is None:
        raise ValueError(
            f"version '{version}' not available for {model_type} '{language}', "
            f"choose one of {list(language_conf.keys())}"
        )
    model_url = version_conf.get(jit_model)
    if model_url is None:
        raise ValueError(
            f"model '{jit_model}' not available for {model_type} '{language}' "
            f"version '{version}', choose one of {list(version_conf.keys())}"
        )

    model, decoder = init_jit_model(
        model_url=model_url,
        output_folder=output_folder,
        **kwargs,
    )
    return model, decoder