Staticaliza committed (verified)
Commit 464583c · Parent: 138fa16

Update app.py

Files changed (1):
  1. app.py +16 -81
app.py CHANGED
@@ -51,11 +51,7 @@ def load_custom_model_from_hf(repo_id, model_filename="pytorch_model.bin", confi
     return model_path, config_path

 # Load DiT model
-dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
-    "Plachta/Seed-VC",
-    "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth",
-    "config_dit_mel_seed_uvit_whisper_small_wavenet.yml"
-)
+dit_checkpoint_path, dit_config_path = load_custom_model_from_hf("Plachta/Seed-VC", "DiT_seed_v2_uvit_whisper_small_wavenet_bigvgan_pruned.pth", "config_dit_mel_seed_uvit_whisper_small_wavenet.yml")
 config = yaml.safe_load(open(dit_config_path, 'r'))
 model_params = recursive_munch(config['model_params'])
 model = build_model(model_params, stage='DiT')
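The change above is purely cosmetic: the multi-line `load_custom_model_from_hf(...)` call is collapsed onto one line with the same three positional arguments (repo id, checkpoint filename, config filename). For context, a minimal sketch of what such a downloader typically looks like, assuming it wraps `huggingface_hub.hf_hub_download`; the real implementation lives earlier in app.py and is not part of this diff:

    # Hypothetical sketch of a Hub downloader in the spirit of load_custom_model_from_hf.
    from huggingface_hub import hf_hub_download

    def fetch_checkpoint_and_config(repo_id, model_filename, config_filename):
        model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)    # cached local path
        config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)  # cached local path
        return model_path, config_path

Called with the same arguments as the one-liner above, it would return local paths to the pruned DiT checkpoint and its YAML config.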
@@ -129,11 +125,7 @@ mel_fn_args = {
 to_mel = lambda x: mel_spectrogram(x, **mel_fn_args)

 # Load F0 conditioned model
-dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf(
-    "Plachta/Seed-VC",
-    "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth",
-    "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml"
-)
+dit_checkpoint_path_f0, dit_config_path_f0 = load_custom_model_from_hf("Plachta/Seed-VC", "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema.pth", "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml")
 config_f0 = yaml.safe_load(open(dit_config_path_f0, 'r'))
 model_params_f0 = recursive_munch(config_f0['model_params'])
 model_f0 = build_model(model_params_f0, stage='DiT')
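Same single-line collapse for the F0-conditioned checkpoint; the arguments are unchanged. The surrounding context then parses the YAML and wraps it with `recursive_munch` so nested config keys can be read as attributes. A rough, hypothetical stand-in for that helper (not the repo's implementation) to show the idea:

    # Hypothetical equivalent of recursive_munch: attribute access over nested dicts,
    # e.g. model_params.some_key instead of config['model_params']['some_key'].
    from types import SimpleNamespace

    def munch(obj):
        if isinstance(obj, dict):
            return SimpleNamespace(**{k: munch(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [munch(v) for v in obj]
        return obj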
@@ -220,22 +212,9 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
     # Generate Whisper features
     print("[INFO] | Generating Whisper features for source audio.")
     if converted_waves_16k.size(-1) <= 16000 * 30:
-        alt_inputs = whisper_feature_extractor(
-            [converted_waves_16k.squeeze(0).cpu().numpy()],
-            return_tensors="pt",
-            return_attention_mask=True,
-            sampling_rate=16000
-        )
-        alt_input_features = whisper_model._mask_input_features(
-            alt_inputs.input_features, attention_mask=alt_inputs.attention_mask
-        ).to(device)
-        alt_outputs = whisper_model.encoder(
-            alt_input_features.to(torch.float32),
-            head_mask=None,
-            output_attentions=False,
-            output_hidden_states=False,
-            return_dict=True
-        )
+        alt_inputs = whisper_feature_extractor([converted_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+        alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
+        alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
         S_alt = alt_outputs.last_hidden_state.to(torch.float32)
         S_alt = S_alt[:, :converted_waves_16k.size(-1) // 320 + 1]
         print(f"[INFO] | S_alt shape: {S_alt.shape}")
@@ -254,26 +233,10 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
         if buffer is None:
             chunk = converted_waves_16k[:, traversed_time:traversed_time + chunk_size]
         else:
-            chunk = torch.cat([
-                buffer,
-                converted_waves_16k[:, traversed_time:traversed_time + chunk_size - overlap_size]
-            ], dim=-1)
-        alt_inputs = whisper_feature_extractor(
-            [chunk.squeeze(0).cpu().numpy()],
-            return_tensors="pt",
-            return_attention_mask=True,
-            sampling_rate=16000
-        )
-        alt_input_features = whisper_model._mask_input_features(
-            alt_inputs.input_features, attention_mask=alt_inputs.attention_mask
-        ).to(device)
-        alt_outputs = whisper_model.encoder(
-            alt_input_features.to(torch.float32),
-            head_mask=None,
-            output_attentions=False,
-            output_hidden_states=False,
-            return_dict=True
-        )
+            chunk = torch.cat([buffer, converted_waves_16k[:, traversed_time:traversed_time + chunk_size - overlap_size]], dim=-1)
+        alt_inputs = whisper_feature_extractor([chunk.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+        alt_input_features = whisper_model._mask_input_features(alt_inputs.input_features, attention_mask=alt_inputs.attention_mask).to(device)
+        alt_outputs = whisper_model.encoder(alt_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
         S_chunk = alt_outputs.last_hidden_state.to(torch.float32)
         S_chunk = S_chunk[:, :chunk.size(-1) // 320 + 1]
         print(f"[INFO] | Processed chunk with S_chunk shape: {S_chunk.shape}")
@@ -293,22 +256,9 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
     # Original Whisper features
     print("[INFO] | Generating Whisper features for reference audio.")
     ori_waves_16k = torchaudio.functional.resample(ref_audio_tensor, sr_current, 16000)
-    ori_inputs = whisper_feature_extractor(
-        [ori_waves_16k.squeeze(0).cpu().numpy()],
-        return_tensors="pt",
-        return_attention_mask=True,
-        sampling_rate=16000
-    )
-    ori_input_features = whisper_model._mask_input_features(
-        ori_inputs.input_features, attention_mask=ori_inputs.attention_mask
-    ).to(device)
-    ori_outputs = whisper_model.encoder(
-        ori_input_features.to(torch.float32),
-        head_mask=None,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True
-    )
+    ori_inputs = whisper_feature_extractor([ori_waves_16k.squeeze(0).cpu().numpy()], return_tensors="pt", return_attention_mask=True, sampling_rate=16000)
+    ori_input_features = whisper_model._mask_input_features(ori_inputs.input_features, attention_mask=ori_inputs.attention_mask).to(device)
+    ori_outputs = whisper_model.encoder(ori_input_features.to(torch.float32), head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True)
     S_ori = ori_outputs.last_hidden_state.to(torch.float32)
     S_ori = S_ori[:, :ori_waves_16k.size(-1) // 320 + 1]
     print(f"[INFO] | S_ori shape: {S_ori.shape}")
@@ -326,12 +276,7 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,

     # Extract style features
     print("[INFO] | Extracting style features from reference audio.")
-    feat2 = torchaudio.compliance.kaldi.fbank(
-        ref_waves_16k,
-        num_mel_bins=80,
-        dither=0,
-        sample_frequency=16000
-    )
+    feat2 = torchaudio.compliance.kaldi.fbank(ref_waves_16k, num_mel_bins=80, dither=0, sample_frequency=16000)
     feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
     style2 = campplus_model(feat2.unsqueeze(0))
     print(f"[INFO] | Style2 shape: {style2.shape}")
@@ -358,9 +303,7 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
     # Shift F0 levels
     shifted_log_f0_alt = log_f0_alt.clone()
     if auto_f0_adjust:
-        shifted_log_f0_alt[F0_alt > 1] = (
-            log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori
-        )
+        shifted_log_f0_alt[F0_alt > 1] = (log_f0_alt[F0_alt > 1] - median_log_f0_alt + median_log_f0_ori)
     shifted_f0_alt = torch.exp(shifted_log_f0_alt)
     if pitch != 0:
         shifted_f0_alt[F0_alt > 1] = adjust_f0_semitones(shifted_f0_alt[F0_alt > 1], pitch)
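No behavior change here either: when auto adjustment is on, voiced frames (`F0_alt > 1`) get their log-F0 shifted by the difference of the two medians so the source sits at the reference's pitch level, and any manual semitone offset is applied afterwards. A small sketch of the same arithmetic; how the medians are computed is outside this hunk (here they are taken over voiced frames), and `adjust_f0_semitones` is assumed to be the usual `f0 * 2 ** (n / 12)`:

    import torch

    def shift_f0(f0_src, f0_ref, semitones=0, auto=True):
        voiced = f0_src > 1                              # crude voiced mask, as in the diff
        log_src = torch.log(f0_src.clamp(min=1e-5))
        log_ref = torch.log(f0_ref.clamp(min=1e-5))
        shifted = log_src.clone()
        if auto:                                         # match median log-F0 of voiced frames
            shifted[voiced] = log_src[voiced] - log_src[voiced].median() + log_ref[f0_ref > 1].median()
        f0_out = torch.exp(shifted)
        if semitones != 0:                               # one semitone = factor 2 ** (1 / 12)
            f0_out[voiced] = f0_out[voiced] * 2.0 ** (semitones / 12)
        return f0_out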
@@ -390,15 +333,7 @@ def voice_conversion(input, reference, steps, guidance, speed, use_conditioned,
     cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)

     # Perform inference
-    vc_target = inference_module.cfm.inference(
-        cat_condition,
-        torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
-        mel2,
-        style2,
-        None,
-        steps,
-        inference_cfg_rate=guidance
-    )
+    vc_target = inference_module.cfm.inference(cat_condition, torch.LongTensor([cat_condition.size(1)]).to(mel2.device), mel2, style2, None, steps, inference_cfg_rate=guidance)
     vc_target = vc_target[:, :, mel2.size(2):]
     print(f"[INFO] | vc_target shape: {vc_target.shape}")

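The seven arguments are unchanged: the concatenated condition, its length as a LongTensor on `mel2`'s device, the reference mel `mel2`, the style embedding `style2`, `None` for the F0 slot, the step count, and the guidance rate. Because `cat_condition` is `[prompt_condition; chunk_cond]`, the generated mel also contains the prompt's frames, which the following slice discards. A toy illustration of that trim (shapes are made up, not the real model output):

    import torch

    mel2 = torch.zeros(1, 80, 120)             # pretend reference/prompt mel: 120 frames
    vc_target = torch.zeros(1, 80, 120 + 300)  # pretend CFM output: prompt frames + 300 new frames
    vc_target = vc_target[:, :, mel2.size(2):] # drop the prompt part, keep the converted chunk
    print(vc_target.shape)                     # torch.Size([1, 80, 300])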
 
@@ -458,7 +393,7 @@ with gr.Blocks(css=css) as main:
         speed = gr.Slider(label="Speed", value=1.0, minimum=0.5, maximum=2.0, step=0.1)

     with gr.Column():
-        use_conditioned = gr.Checkbox(label="Use 'F0 Conditioned Model'", value=False),
+        use_conditioned = gr.Checkbox(label="Use 'F0 Conditioned Model'", value=False)
         use_auto_adjustment = gr.Checkbox(label="Use 'Auto F0 Adjustment' with 'F0 Conditioned Model'", value=True)
         pitch = gr.Slider(label="Pitch with 'F0 Conditioned Model'", value=0, minimum=-12, maximum=12, step=1)

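This is the one functional change visible in the diff: the old line ended with a trailing comma, which made `use_conditioned` a one-element tuple instead of a `gr.Checkbox`, so it could not be wired up as an input component later in the file (that wiring is not shown here). A minimal standalone illustration of the difference:

    import gradio as gr

    with gr.Blocks() as demo:
        broken = gr.Checkbox(label="F0 Conditioned"),  # trailing comma: a (Checkbox,) tuple
        fixed = gr.Checkbox(label="F0 Conditioned")    # an actual Checkbox component

    print(type(broken))                    # <class 'tuple'>
    print(isinstance(fixed, gr.Checkbox))  # True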
 
 