Update xtts.py
Browse files
xtts.py
CHANGED
@@ -13,6 +13,10 @@ import torch
|
|
13 |
import torchaudio
|
14 |
|
15 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
|
|
|
|
|
16 |
|
17 |
app = Flask(__name__)
|
18 |
# def upload_bytes(bytes, ext=".wav"):
|
@@ -55,7 +59,12 @@ def predict():
|
|
55 |
TTS=import_module("TTS.api").TTS
|
56 |
model_name="tts_models/multilingual/multi-dataset/xtts_v2"
|
57 |
logging.info(f"loading model {model_name} ...")
|
58 |
-
tts = TTS(
|
|
|
|
|
|
|
|
|
|
|
59 |
model=tts.synthesizer.tts_model
|
60 |
#hack to use cache
|
61 |
model.__get_conditioning_latents=model.get_conditioning_latents
|
|
|
13 |
import torchaudio
|
14 |
|
15 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
16 |
+
model_dir=os.environ.get("MODEL_DIR")
|
17 |
+
model_path=f'{model_dir}/model.pth'
|
18 |
+
config_path=f'{model_dir}/config.json'
|
19 |
+
vocoder_config_path=f'{model_dir}/vocab.json'
|
20 |
|
21 |
app = Flask(__name__)
|
22 |
# def upload_bytes(bytes, ext=".wav"):
|
|
|
59 |
TTS=import_module("TTS.api").TTS
|
60 |
model_name="tts_models/multilingual/multi-dataset/xtts_v2"
|
61 |
logging.info(f"loading model {model_name} ...")
|
62 |
+
tts = TTS(
|
63 |
+
model_path=model_path,
|
64 |
+
config_path=config_path,
|
65 |
+
vocoder_config_path=vocoder_config_path,
|
66 |
+
progress_bar=False
|
67 |
+
)
|
68 |
model=tts.synthesizer.tts_model
|
69 |
#hack to use cache
|
70 |
model.__get_conditioning_latents=model.get_conditioning_latents
|