# xtts/xtts.py
import re, os, logging, tempfile, subprocess, hashlib
import requests
import torch
import traceback
from TTS.api import TTS
#ffmpeg -y -i /var/folders/w6/mxy2wbmd2bj360glkp0d5qbw0000gn/T/tmp49s6gxk7.wav -af lowpass=8000,highpass=75,areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02,areverse,silenceremove=start_periods=1:start_silence=0:start_threshold=0.02, ./test.wav
root = os.path.dirname(os.path.abspath(__file__))
ffmpeg = f'{root}/ffmpeg'
# local test
if os.path.exists(f'{root}/env.py'):
    ffmpeg = "/opt/homebrew/bin/ffmpeg"
    import env
os.environ["COQUI_TOS_AGREED"]="1"
api=os.environ.get('api')
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
tts=None
model=None
LIBROSA_CACHE_DIR=f'{os.getcwd()}/caches'
if not os.path.exists(LIBROSA_CACHE_DIR):
os.makedirs(LIBROSA_CACHE_DIR)
os.environ["LIBROSA_CACHE_DIR"]=LIBROSA_CACHE_DIR
sample_root=f'{os.getcwd()}/samples'
if not os.path.exists(sample_root):
os.makedirs(sample_root)
default_sample=f'{root}/sample.wav', f'{sample_root}/sample.pt'
if api:
    from qili import upload, check_token
else:
    # local stand-ins for qili's helpers; upload returns a tuple so callers
    # can index [0] for the URL, matching how qili.upload is used below
    def upload(file, token=None):
        return (file,)
    def check_token(token):
        return True
def predict(text, sample=None, language="zh"):
    global tts
    global model
    get_tts()
    try:
        # pause hack: pad sentence-ending punctuation with a space and double
        # it so the model renders a longer pause
        text = re.sub("([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2\2", text)
        output = tempfile.mktemp(suffix=".wav")
        tts.tts_to_file(
            text,
            language=language if language is not None else "zh",
            speaker_wav=sample if sample is not None else default_sample[0],
            file_path=output,
        )
        output = to_mp3(output)
        # note: reads the Flask request context for the upload token
        url = upload(output, request.headers.get('token') or os.environ.get('token'))[0]
        if url != output:
            os.remove(output)
        return url
    except Exception as e:
        return str(e)
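# For reference, the pause hack in predict() rewrites sentence endings like so
# (illustration only, not part of the serving path):
#   re.sub("([^\x00-\x7F]|\w)(\.|。|\?)", r"\1 \2\2", "你好。Fine?")
#   -> '你好 。。Fine ??'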
def playInHTML(url):
    # predict() returns an mp3 URL, hence audio/mpeg rather than audio/wav
    return f'''
    <html>
        <body>
            <audio controls autoplay>
                <source src="{url}" type="audio/mpeg">
                Your browser does not support the audio element.
            </audio>
        </body>
    </html>
    '''
def get_conditioning_latents(audio_path, **others):
    """Caching wrapper: load speaker latents from disk, or compute and save them."""
    global model
    speaker_wav, pt_file = download_sample(audio_path)
    try:
        if pt_file is not None:
            if not os.path.exists(pt_file):
                raise ValueError(f'{pt_file} does not exist, regenerate it.')
            (
                gpt_cond_latent,
                speaker_embedding,
            ) = torch.load(pt_file)
            logging.debug(f'sample wav info loaded from {pt_file}')
    except Exception:
        logging.debug(f'creating sample latent and embedding from {speaker_wav}')
        # call the original, un-wrapped method saved in get_tts()
        (
            gpt_cond_latent,
            speaker_embedding,
        ) = model.__get_conditioning_latents(audio_path=speaker_wav, **others)
        torch.save((gpt_cond_latent, speaker_embedding), pt_file)
        logging.debug(f'sample latent and embedding saved to {pt_file}')
    return gpt_cond_latent, speaker_embedding
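# For reference, a sketch of what the cached .pt file holds and how XTTS
# consumes it (assuming the XTTS v2 API in Coqui TTS; `model` is
# tts.synthesizer.tts_model, set up in get_tts below):
#   gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=wav)
#   out = model.inference(text, "zh", gpt_cond_latent, speaker_embedding)
# Computing the latents is the expensive per-speaker step, which is why they
# are serialized to disk keyed by the sample URL and its etag.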
def download_sample(url):
    try:
        response = requests.get(url)
        if response.status_code == 200:
            # cache key: URL plus etag, hashed to a stable filename
            # (the built-in hash() is randomized per process, so use md5)
            key = f'{url}{response.headers["etag"]}'
            pt_file = f'{sample_root}/{hashlib.md5(key.encode()).hexdigest()}.pt'
            if os.path.exists(pt_file):
                return "", pt_file
            with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as temp_file:
                temp_file.write(response.content)
            logging.debug(f'downloaded sample wav from {url}')
            return trim_sample_audio(os.path.abspath(temp_file.name)), pt_file
        return default_sample
    except Exception:
        return default_sample
def download(url):
    response = requests.get(url)
    if response.status_code == 200:
        with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as temp_file:
            temp_file.write(response.content)
        return os.path.abspath(temp_file.name)
def trim_sample_audio(speaker_wav, threshold=0.005):
    """Denoise the speaker sample and trim silence from both ends with ffmpeg."""
    global ffmpeg
    try:
        lowpass_highpass = "lowpass=8000,highpass=75,"
        trim_silence = f"areverse,silenceremove=start_periods=1:start_silence=0:start_threshold={threshold},areverse,silenceremove=start_periods=1:start_silence=0:start_threshold={threshold},"
        out_filename = speaker_wav.replace(".wav", "_trimmed.wav")
        shell_command = f"{ffmpeg} -y -i {speaker_wav} -af {lowpass_highpass}{trim_silence} {out_filename}".split(" ")
        logging.debug(" ".join(shell_command))
        result = subprocess.run(
            shell_command,
            capture_output=True,
            text=True,
            check=True,
        )
        if result.stderr is not None and "Output file is empty" in result.stderr:
            # the threshold stripped everything; retry with a gentler one
            if threshold > 0.001:
                return trim_sample_audio(speaker_wav, threshold / 2)
            return speaker_wav
        os.remove(speaker_wav)
        logging.debug(f'trimmed sample wav to {out_filename}')
        return out_filename
    except Exception:
        logging.debug('Error: trimming sample wav failed, using it as-is')
        return speaker_wav
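# Note on the filter chain: silenceremove with start_periods only strips
# *leading* silence, so the audio is reversed, trimmed, reversed again, and
# trimmed once more; the two passes remove trailing and then leading silence.
# Equivalent command (placeholder paths):
#   ffmpeg -y -i in.wav -af lowpass=8000,highpass=75,areverse,silenceremove=...,areverse,silenceremove=... out.wav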
def to_mp3(wav):
    """Convert the generated wav to mp3; fall back to the wav on failure."""
    global ffmpeg
    try:
        mp3 = tempfile.mktemp(suffix=".mp3")
        shell_command = f"{ffmpeg} -i {wav} {mp3}".split(" ")
        subprocess.run(
            shell_command,
            text=True,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        os.remove(wav)
        logging.debug(f'converted wav {wav} to mp3 at {mp3}')
        return mp3
    except Exception:
        logging.debug('error: converting wav to mp3 failed, returning the wav')
        return wav
# if __name__ == "__main__":
# app = Flask(__name__)
# else:
# app = Blueprint("xtts", __name__)
from flask import Flask, request
app = Flask(__name__)
@app.route("/tts")
def convert():
check_token(request.headers.get('token') or os.environ.get('token'))
text = request.args.get('text')
sample = request.args.get('sample')
language = request.args.get('language')
if text is None:
return 'text is missing', 400
return predict(text, sample, language)
@app.route("/tts/play")
def tts_play():
url=convert()
return playInHTML(url)
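# Example requests (hypothetical host and port; adjust to the deployment):
#   curl 'http://localhost:5000/setup'
#   curl 'http://localhost:5000/tts?text=你好&language=zh' -H "token: $TOKEN"
#   open 'http://localhost:5000/tts/play?text=hello&language=en'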
@app.route("/setup")
def get_tts(model_path=os.environ["MODEL_DIR"]):
global tts
global model
if tts is None:
config_path=f'{model_path}/config.json'
vocoder_config_path=f'{model_path}/vocab.json'
model_name="tts_models/multilingual/multi-dataset/xtts_v2"
logging.info(f"loading model {model_name} ...")
tts = TTS(
model_name if not model_path else None,
model_path=model_path if model_path else None,
config_path=config_path if model_path else None,
vocoder_config_path=vocoder_config_path if model_path else None,
progress_bar=True
)
model=tts.synthesizer.tts_model
#hack to use cache
model.__get_conditioning_latents=model.get_conditioning_latents
model.get_conditioning_latents=get_conditioning_latents
logging.info("model is ready")
return "ready"
# import gradio as gr
# demo=gr.Interface(predict, inputs=["text", "text"], outputs=gr.Audio())
# app = gr.mount_gradio_app(app, demo, path="/")

@app.route("/")
def hello():
    return "welcome!"

logging.info("xtts is ready")
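# A minimal local entry point, left commented out because the deployment
# presumably launches the app through a WSGI server (host and port here
# are assumptions):
# if __name__ == "__main__":
#     get_tts()
#     app.run(host="0.0.0.0", port=7860)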