|
import os

import logging

# Configure logging BEFORE the project modules below are imported, so that
# their module-level loggers inherit this root configuration. The level is
# taken from the LOG_LEVEL environment variable (default "INFO").
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
|
|
|
from modules.devices import devices |
|
import argparse |
|
import uvicorn |
|
|
|
import torch |
|
from modules import config |
|
from modules.utils import env |
|
from modules import generate_audio as generate |
|
from modules.api.Api import APIManager |
|
|
|
from modules.api.impl import ( |
|
style_api, |
|
tts_api, |
|
ssml_api, |
|
google_api, |
|
openai_api, |
|
refiner_api, |
|
speaker_api, |
|
ping_api, |
|
models_api, |
|
) |
|
|
|
# Module-level logger for this entry-point script.
logger = logging.getLogger(__name__)

# TorchDynamo / torch.compile tuning: allow up to 64 cached compiled graphs,
# and fall back to eager execution instead of raising when compilation fails.
torch._dynamo.config.cache_size_limit = 64
torch._dynamo.config.suppress_errors = True
# "high" permits reduced-precision (e.g. TF32) float32 matmuls on supported
# hardware, trading a little accuracy for speed.
torch.set_float32_matmul_precision("high")
|
|
|
|
|
def create_api(app, no_docs=False, exclude=None):
    """Register all API route modules on *app* and return the manager.

    Args:
        app: FastAPI application instance to attach routes to.
        no_docs: When True, documentation endpoints are disabled.
        exclude: Optional list of API path patterns to exclude from the
            server. Defaults to no exclusions.

    Returns:
        The configured ``APIManager`` wrapping *app*.
    """
    # BUG FIX: the original used a mutable default argument (`exclude=[]`),
    # which is shared across all calls; use a None sentinel instead.
    if exclude is None:
        exclude = []
    app_mgr = APIManager(app=app, no_docs=no_docs, exclude_patterns=exclude)

    # Registration order preserved from the original implementation.
    ping_api.setup(app_mgr)
    models_api.setup(app_mgr)
    style_api.setup(app_mgr)
    speaker_api.setup(app_mgr)
    tts_api.setup(app_mgr)
    ssml_api.setup(app_mgr)
    google_api.setup(app_mgr)
    openai_api.setup(app_mgr)
    refiner_api.setup(app_mgr)

    return app_mgr
|
|
|
|
|
def get_and_update_env(*args):
    """Resolve a setting and mirror it into the runtime configuration.

    All positional arguments are forwarded to ``env.get_env_or_arg``; the
    second argument is treated as the setting's key, under which the
    resolved value is recorded in ``config.runtime_env_vars``.

    Returns:
        The resolved value.
    """
    key = args[1]
    value = env.get_env_or_arg(*args)
    config.runtime_env_vars[key] = value
    return value
|
|
|
|
|
def setup_model_args(parser: argparse.ArgumentParser):
    """Register model/inference command-line flags on *parser*.

    Flags cover torch compilation, half precision, progress-bar suppression,
    CUDA device selection, per-module CPU fallback, and the request LRU
    cache size.
    """
    add = parser.add_argument  # local alias, purely for brevity
    add("--compile", action="store_true", help="Enable model compile")
    add(
        "--half",
        action="store_true",
        help="Enable half precision for model inference",
    )
    add("--off_tqdm", action="store_true", help="Disable tqdm progress bar")
    add(
        "--device_id",
        type=str,
        default=None,
        help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
    )
    add(
        "--use_cpu",
        nargs="+",
        type=str.lower,
        default=[],
        help="use CPU as torch device for specified modules",
    )
    add(
        "--lru_size",
        type=int,
        default=64,
        help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
    )
|
|
|
|
|
def setup_api_args(parser: argparse.ArgumentParser):
    """Register HTTP server command-line flags on *parser*.

    Flags cover bind host/port, dev auto-reload, CORS origins, and toggles
    for the playground and documentation endpoints, plus API exclusion.
    """
    add = parser.add_argument  # local alias, purely for brevity
    add("--api_host", type=str, help="Host to run the server on")
    add("--api_port", type=int, help="Port to run the server on")
    add("--reload", action="store_true", help="Enable auto-reload for development")
    add(
        "--cors_origin",
        type=str,
        help="Allowed CORS origins. Use '*' to allow all origins.",
    )
    add("--no_playground", action="store_true", help="Disable the playground entry")
    add("--no_docs", action="store_true", help="Disable the documentation entry")
    add("--exclude", type=str, help="Exclude the specified API from the server")
|
|
|
|
|
def process_model_args(args):
    """Resolve model-related options and initialize the model runtime.

    Each ``get_and_update_env`` call is made for its side effect: it
    resolves the option (env var or CLI arg) and records it in
    ``config.runtime_env_vars`` for other modules to read. The returned
    values are intentionally discarded — the original code bound them to
    six unused locals, one of which (``compile``) shadowed the builtin.

    Args:
        args: Parsed argparse namespace.
    """
    get_and_update_env(args, "lru_size", 64, int)
    get_and_update_env(args, "compile", False, bool)
    get_and_update_env(args, "device_id", None, str)
    get_and_update_env(args, "use_cpu", [], list)
    get_and_update_env(args, "half", False, bool)
    get_and_update_env(args, "off_tqdm", False, bool)

    # Runtime init: these read the options recorded above.
    generate.setup_lru_cache()
    devices.reset_device()
    devices.first_time_calculation()
|
|
|
|
|
def process_api_args(args, app):
    """Resolve API-related options, build the routes, and apply toggles.

    Args:
        args: Parsed argparse namespace.
        app: FastAPI application instance to attach routes to.
    """
    cors_origin = get_and_update_env(args, "cors_origin", "*", str)
    no_playground = get_and_update_env(args, "no_playground", False, bool)
    no_docs = get_and_update_env(args, "no_docs", False, bool)
    exclude = get_and_update_env(args, "exclude", "", str)

    # BUG FIX: "".split(",") yields [""], so an unset --exclude used to
    # register a single empty exclusion pattern; drop empty entries.
    exclude_patterns = [p for p in exclude.split(",") if p]

    api = create_api(app=app, no_docs=no_docs, exclude=exclude_patterns)
    config.api = api

    if cors_origin:
        api.set_cors(allow_origins=[cors_origin])

    if not no_playground:
        api.setup_playground()

    # BUG FIX: the original tested `if compile:`. No name `compile` exists
    # in this scope (the local of that name belongs to process_model_args),
    # so it resolved to the always-truthy builtin `compile` and the message
    # was logged unconditionally. Read the recorded runtime value instead
    # (set by process_model_args, which runs before this function).
    if config.runtime_env_vars.compile:
        logger.info("Model compile is enabled")
|
|
|
|
|
# FastAPI metadata rendered on the generated docs pages. The description is
# bilingual (Chinese / English) Markdown.
app_description = """
ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax

项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)

> 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
> All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging

> 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
> [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
"""
app_title = "ChatTTS Forge API"
app_version = "0.1.0"
|
|
|
if __name__ == "__main__":
    # Imported here so that merely importing this module (e.g. for its
    # helper functions) does not require dotenv/fastapi.
    import dotenv
    from fastapi import FastAPI

    # Load env-file values (default ".env.api") before parsing CLI args so
    # get_and_update_env can fall back to them.
    dotenv.load_dotenv(
        dotenv_path=os.getenv("ENV_FILE", ".env.api"),
    )

    parser = argparse.ArgumentParser(
        description="Start the FastAPI server with command line arguments"
    )
    setup_api_args(parser)
    setup_model_args(parser)

    args = parser.parse_args()

    # NOTE(review): config.runtime_env_vars.no_docs is read here, but the
    # --no_docs flag is only written into runtime_env_vars later, inside
    # process_api_args(). This read therefore sees whatever value
    # runtime_env_vars held beforehand, not the parsed CLI flag — confirm
    # whether that ordering is intended.
    app = FastAPI(
        title=app_title,
        description=app_description,
        version=app_version,
        redoc_url=None if config.runtime_env_vars.no_docs else "/redoc",
        docs_url=None if config.runtime_env_vars.no_docs else "/docs",
    )

    # Model options must be processed first: process_api_args reads values
    # (e.g. "compile") that process_model_args records.
    process_model_args(args)
    process_api_args(args, app)

    host = get_and_update_env(args, "api_host", "0.0.0.0", str)
    port = get_and_update_env(args, "api_port", 7870, int)
    reload = get_and_update_env(args, "reload", False, bool)

    # NOTE(review): uvicorn's reload mode requires an import string
    # ("module:app") rather than an app object; with reload=True this call
    # warns and auto-reload does not take effect — confirm if --reload is
    # actually used.
    uvicorn.run(app, host=host, port=port, reload=reload)
|
|