freyza commited on
Commit
bfcae1f
·
verified ·
1 Parent(s): 7f9f464

Update src/rmvpe.py

Browse files
Files changed (1) hide show
  1. src/rmvpe.py +34 -11
src/rmvpe.py CHANGED
@@ -1,8 +1,7 @@
1
- import numpy as np
2
- import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from librosa.filters import mel
6
 
7
 
8
  class BiGRU(nn.Module):
@@ -248,7 +247,7 @@ class E2E(nn.Module):
248
  )
249
  else:
250
  self.fc = nn.Sequential(
251
- nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
252
  )
253
 
254
  def forward(self, mel):
@@ -258,6 +257,9 @@ class E2E(nn.Module):
258
  return x
259
 
260
 
 
 
 
261
  class MelSpectrogram(torch.nn.Module):
262
  def __init__(
263
  self,
@@ -384,8 +386,8 @@ class RMVPE:
384
 
385
  def to_local_average_cents(self, salience, thred=0.05):
386
  # t0 = ttime()
387
- center = np.argmax(salience, axis=1) # 帧长#index
388
- salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
389
  # t1 = ttime()
390
  center += 4
391
  todo_salience = []
@@ -396,14 +398,35 @@ class RMVPE:
396
  todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
397
  todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
398
  # t2 = ttime()
399
- todo_salience = np.array(todo_salience) # 帧长,9
400
- todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9
401
  product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
402
- weight_sum = np.sum(todo_salience, 1) # 帧长
403
- devided = product_sum / weight_sum # 帧长
404
  # t3 = ttime()
405
- maxx = np.max(salience, axis=1) # 帧长
406
  devided[maxx <= thred] = 0
407
  # t4 = ttime()
408
  # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
409
  return devided
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, numpy as np
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
+
5
 
6
 
7
  class BiGRU(nn.Module):
 
247
  )
248
  else:
249
  self.fc = nn.Sequential(
250
+ nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
251
  )
252
 
253
  def forward(self, mel):
 
257
  return x
258
 
259
 
260
+ from librosa.filters import mel
261
+
262
+
263
  class MelSpectrogram(torch.nn.Module):
264
  def __init__(
265
  self,
 
386
 
387
  def to_local_average_cents(self, salience, thred=0.05):
388
  # t0 = ttime()
389
+ center = np.argmax(salience, axis=1) # frame length#index
390
+ salience = np.pad(salience, ((0, 0), (4, 4))) # frame length,368
391
  # t1 = ttime()
392
  center += 4
393
  todo_salience = []
 
398
  todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
399
  todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
400
  # t2 = ttime()
401
+ todo_salience = np.array(todo_salience) # frame length,9
402
+ todo_cents_mapping = np.array(todo_cents_mapping) # frame length,9
403
  product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
404
+ weight_sum = np.sum(todo_salience, 1) # frame length
405
+ devided = product_sum / weight_sum # frame length
406
  # t3 = ttime()
407
+ maxx = np.max(salience, axis=1) # frame length
408
  devided[maxx <= thred] = 0
409
  # t4 = ttime()
410
  # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
411
  return devided
412
+
413
+
414
+ # if __name__ == '__main__':
415
+ # audio, sampling_rate = sf.read("Quotations~1.wav") ### edit
416
+ # if len(audio.shape) > 1:
417
+ # audio = librosa.to_mono(audio.transpose(1, 0))
418
+ # audio_bak = audio.copy()
419
+ # if sampling_rate != 16000:
420
+ # audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
421
+ # model_path = "/bili-coeus/jupyter/jupyterhub-liujing04/vits_ch/test-RMVPE/weights/rmvpe_llc_half.pt"
422
+ # thred = 0.03 # 0.01
423
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
424
+ # rmvpe = RMVPE(model_path,is_half=False, device=device)
425
+ # t0=ttime()
426
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
427
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
428
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
429
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
430
+ # f0 = rmvpe.infer_from_audio(audio, thred=thred)
431
+ # t1=ttime()
432
+ # print(f0.shape,t1-t0)9