update pipeline for model loading
Browse files
app.py
CHANGED
@@ -62,6 +62,10 @@ langs = {
|
|
62 |
"jp": 1,
|
63 |
}
|
64 |
|
|
|
|
|
|
|
|
|
65 |
def gen_song(model_name, spk, texts, durs, pitchs):
|
66 |
fs = 44100
|
67 |
tempo = 120
|
@@ -141,15 +145,19 @@ def gen_song(model_name, spk, texts, durs, pitchs):
|
|
141 |
# return (fs, np.array([0.0])), "success!"
|
142 |
|
143 |
# Infer
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
153 |
if model_name == "Model①(Chinese)-zh":
|
154 |
sid = np.array([singer_embeddings[model_name][spk]])
|
155 |
output_dict = svs(batch, sids=sid)
|
@@ -160,7 +168,7 @@ def gen_song(model_name, spk, texts, durs, pitchs):
|
|
160 |
wav_info = output_dict["wav"].cpu().numpy()
|
161 |
|
162 |
# mos prediction with sr=16k
|
163 |
-
predictor
|
164 |
wav_mos = librosa.resample(wav_info, orig_sr=fs, target_sr=16000)
|
165 |
wav_mos = torch.from_numpy(wav_mos).unsqueeze(0)
|
166 |
len_mos = torch.tensor([wav_mos.shape[1]])
|
|
|
62 |
"jp": 1,
|
63 |
}
|
64 |
|
65 |
+
predictor = torch.hub.load("South-Twilight/SingMOS:v0.2.0", "singing_ssl_mos", trust_repo=True)
|
66 |
+
exist_model = "Null"
|
67 |
+
svs = None
|
68 |
+
|
69 |
def gen_song(model_name, spk, texts, durs, pitchs):
|
70 |
fs = 44100
|
71 |
tempo = 120
|
|
|
145 |
# return (fs, np.array([0.0])), "success!"
|
146 |
|
147 |
# Infer
|
148 |
+
global exist_model
|
149 |
+
global svs
|
150 |
+
if exist_model == "Null" or exist_model != model_name:
|
151 |
+
device = "cpu"
|
152 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
153 |
+
d = ModelDownloader()
|
154 |
+
pretrain_downloaded = d.download_and_unpack(PRETRAIN_MODEL)
|
155 |
+
svs = SingingGenerate(
|
156 |
+
train_config = pretrain_downloaded["train_config"],
|
157 |
+
model_file = pretrain_downloaded["model_file"],
|
158 |
+
device = device
|
159 |
+
)
|
160 |
+
exist_model = model_name
|
161 |
if model_name == "Model①(Chinese)-zh":
|
162 |
sid = np.array([singer_embeddings[model_name][spk]])
|
163 |
output_dict = svs(batch, sids=sid)
|
|
|
168 |
wav_info = output_dict["wav"].cpu().numpy()
|
169 |
|
170 |
# mos prediction with sr=16k
|
171 |
+
global predictor
|
172 |
wav_mos = librosa.resample(wav_info, orig_sr=fs, target_sr=16000)
|
173 |
wav_mos = torch.from_numpy(wav_mos).unsqueeze(0)
|
174 |
len_mos = torch.tensor([wav_mos.shape[1]])
|