TangRain committed on
Commit
4f1f8e0
1 Parent(s): dbf7985

update pipeline for model loading

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -62,6 +62,10 @@ langs = {
62
  "jp": 1,
63
  }
64
 
 
 
 
 
65
  def gen_song(model_name, spk, texts, durs, pitchs):
66
  fs = 44100
67
  tempo = 120
@@ -141,15 +145,19 @@ def gen_song(model_name, spk, texts, durs, pitchs):
141
  # return (fs, np.array([0.0])), "success!"
142
 
143
  # Infer
144
- device = "cpu"
145
- # device = "cuda" if torch.cuda.is_available() else "cpu"
146
- d = ModelDownloader()
147
- pretrain_downloaded = d.download_and_unpack(PRETRAIN_MODEL)
148
- svs = SingingGenerate(
149
- train_config = pretrain_downloaded["train_config"],
150
- model_file = pretrain_downloaded["model_file"],
151
- device = device
152
- )
 
 
 
 
153
  if model_name == "Model①(Chinese)-zh":
154
  sid = np.array([singer_embeddings[model_name][spk]])
155
  output_dict = svs(batch, sids=sid)
@@ -160,7 +168,7 @@ def gen_song(model_name, spk, texts, durs, pitchs):
160
  wav_info = output_dict["wav"].cpu().numpy()
161
 
162
  # mos prediction with sr=16k
163
- predictor = torch.hub.load("South-Twilight/SingMOS:v0.2.0", "singing_ssl_mos", trust_repo=True)
164
  wav_mos = librosa.resample(wav_info, orig_sr=fs, target_sr=16000)
165
  wav_mos = torch.from_numpy(wav_mos).unsqueeze(0)
166
  len_mos = torch.tensor([wav_mos.shape[1]])
 
62
  "jp": 1,
63
  }
64
 
65
+ predictor = torch.hub.load("South-Twilight/SingMOS:v0.2.0", "singing_ssl_mos", trust_repo=True)
66
+ exist_model = "Null"
67
+ svs = None
68
+
69
  def gen_song(model_name, spk, texts, durs, pitchs):
70
  fs = 44100
71
  tempo = 120
 
145
  # return (fs, np.array([0.0])), "success!"
146
 
147
  # Infer
148
+ global exist_model
149
+ global svs
150
+ if exist_model == "Null" or exist_model != model_name:
151
+ device = "cpu"
152
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
153
+ d = ModelDownloader()
154
+ pretrain_downloaded = d.download_and_unpack(PRETRAIN_MODEL)
155
+ svs = SingingGenerate(
156
+ train_config = pretrain_downloaded["train_config"],
157
+ model_file = pretrain_downloaded["model_file"],
158
+ device = device
159
+ )
160
+ exist_model = model_name
161
  if model_name == "Model①(Chinese)-zh":
162
  sid = np.array([singer_embeddings[model_name][spk]])
163
  output_dict = svs(batch, sids=sid)
 
168
  wav_info = output_dict["wav"].cpu().numpy()
169
 
170
  # mos prediction with sr=16k
171
+ global predictor
172
  wav_mos = librosa.resample(wav_info, orig_sr=fs, target_sr=16000)
173
  wav_mos = torch.from_numpy(wav_mos).unsqueeze(0)
174
  len_mos = torch.tensor([wav_mos.shape[1]])