import json
import logging
import os
import shlex
import subprocess
import tempfile
from pathlib import Path
from typing import Literal, Optional

import fastapi
import fastapi.middleware.cors
import torch
import tyro
import uvicorn
from attr import dataclass
from fastapi import Request
from fastapi.responses import Response
from huggingface_hub import snapshot_download

from fam.llm.fast_inference import TTS
from fam.llm.sample import (
    InferenceConfig,
    Model,
    build_models,
    get_first_stage_path,
    get_second_stage_path,
    # sample_utterance,
)

logger = logging.getLogger(__name__)


## Setup FastAPI server.
app = fastapi.FastAPI()
@dataclass
class ServingConfig:
    huggingface_repo_id: str
    """Hugging Face Hub repo id to download the model from."""

    max_new_tokens: int = 864 * 2
    """Maximum number of new tokens to generate from the first stage model."""

    temperature: float = 1.0
    """Temperature for sampling applied to both models."""

    top_k: int = 200
    """Top k for sampling applied to both models."""

    seed: int = 1337
    """Random seed for sampling."""

    dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = "bfloat16"
    """Data type to use for sampling."""

    enhancer: Optional[Literal["df"]] = "df"
    """Enhancer to use for post-processing."""

    port: int = 58003
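
# `ServingConfig` is parsed from the command line via `tyro.cli` in `__main__`
# below; tyro derives one kebab-case flag per field. A minimal sketch of the
# mapping (the repo id here is a placeholder, not a value from this file):
#
#   cfg = tyro.cli(ServingConfig, args=["--huggingface-repo-id", "org/model"])
#   assert cfg.port == 58003 and cfg.temperature == 1.0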

# Singleton
class _GlobalState:
    config: ServingConfig
    tts: TTS


GlobalState = _GlobalState()


@dataclass(frozen=True)
class TTSRequest:
    text: str
    guidance: Optional[float] = 3.0
    top_p: Optional[float] = 0.95
    speaker_ref_path: Optional[str] = None
    top_k: Optional[int] = None
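
# Clients send this payload as JSON in the `X-Payload` request header (see
# `text_to_speech` below); the request body carries raw speaker-reference
# audio. A minimal sketch of a valid payload:
#
#   json.dumps({"text": "Hello world.", "guidance": 3.0, "top_p": 0.95})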

def sample_utterance(
    text: str,
    spk_cond_path: str | None,
    guidance_scale,
    max_new_tokens,
    top_k,
    top_p,
    temperature,
) -> str:
    # NOTE: max_new_tokens and top_k are accepted for API compatibility but are
    # not forwarded to the fast inference path.
    return GlobalState.tts.synthesise(
        text,
        spk_cond_path,
        top_p=top_p,
        guidance_scale=guidance_scale,
        temperature=temperature,
    )

@app.post("/tts", response_class=Response)
async def text_to_speech(req: Request):
    audiodata = await req.body()
    payload = None
    wav_out_path = None

    try:
        headers = req.headers
        payload = headers["X-Payload"]
        payload = json.loads(payload)
        tts_req = TTSRequest(**payload)
        with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
            if tts_req.speaker_ref_path is None:
                wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
            else:
                wav_path = tts_req.speaker_ref_path

            wav_out_path = sample_utterance(
                tts_req.text,
                wav_path,
                guidance_scale=tts_req.guidance,
                max_new_tokens=GlobalState.config.max_new_tokens,
                temperature=GlobalState.config.temperature,
                top_k=tts_req.top_k,
                top_p=tts_req.top_p,
            )
        with open(wav_out_path, "rb") as f:
            return Response(content=f.read(), media_type="audio/wav")
    except Exception:
        logger.exception(f"Error processing request {payload}")
        return Response(
            content="Something went wrong. Please try again in a few mins or contact us on Discord",
            status_code=500,
        )
    finally:
        if wav_out_path is not None:
            Path(wav_out_path).unlink(missing_ok=True)
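
# Example client, as a sketch: it assumes the server is reachable on localhost
# at the default port, uses the `/tts` route above, and that `speaker.wav` is a
# local reference clip (`requests` is not a dependency of this module):
#
#   import json
#   import requests
#
#   with open("speaker.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:58003/tts",
#           headers={"X-Payload": json.dumps({"text": "Hello world."})},
#           data=f.read(),
#       )
#   resp.raise_for_status()
#   with open("out.wav", "wb") as out:
#       out.write(resp.content)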

def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
    with tempfile.NamedTemporaryFile() as unknown_format_tmp:
        assert unknown_format_tmp.write(audiodata) > 0
        unknown_format_tmp.flush()

        subprocess.check_output(
            # arbitrary 2 minute cutoff
            shlex.split(f"ffmpeg -t 120 -y -i {unknown_format_tmp.name} -f wav {wav_tmp.name}")
        )

        return wav_tmp.name
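
# The f-string + shlex.split pattern above breaks if a temp path ever contains
# spaces or shell metacharacters. An equivalent, quoting-safe sketch passes the
# argv list explicitly:
#
#   subprocess.check_output(
#       ["ffmpeg", "-t", "120", "-y", "-i", unknown_format_tmp.name, "-f", "wav", wav_tmp.name]
#   )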

if __name__ == "__main__":
    # This has to be here to avoid some weird audiocraft shenanigans messing up matplotlib
    from fam.llm.enhancers import get_enhancer

    # Raise every registered logger to INFO without rebinding the module-level
    # `logger` used by the request handler above.
    for name in logging.root.manager.loggerDict:
        logging.getLogger(name).setLevel(logging.INFO)
    logging.root.setLevel(logging.INFO)

    GlobalState.config = tyro.cli(ServingConfig)

    app.add_middleware(
        fastapi.middleware.cors.CORSMiddleware,
        allow_origins=["*", f"http://localhost:{GlobalState.config.port}", "http://localhost:3000"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    common_config = dict(
        num_samples=1,
        seed=GlobalState.config.seed,
        device=device,
        dtype=GlobalState.config.dtype,
        compile=False,
        init_from="resume",
        output_dir=tempfile.mkdtemp(),
    )
    model_dir = snapshot_download(repo_id=GlobalState.config.huggingface_repo_id)
    # NOTE: config1/config2 describe the two-stage sampler checkpoints; they are
    # currently unused by the fast inference TTS below, which loads its own weights.
    config1 = InferenceConfig(
        ckpt_path=get_first_stage_path(model_dir),
        **common_config,
    )
    config2 = InferenceConfig(
        ckpt_path=get_second_stage_path(model_dir),
        **common_config,
    )

    GlobalState.tts = TTS()

    # start server
    uvicorn.run(
        app,
        host="127.0.0.1",
        port=GlobalState.config.port,
        log_level="info",
    )