BartPoint committed on
Commit 60b02a0
1 Parent(s): 59cc821

Update vc_infer_pipeline.py

Files changed (1):
  vc_infer_pipeline.py  +90 -32
vc_infer_pipeline.py CHANGED
@@ -1,4 +1,4 @@
-import numpy as np, parselmouth, torch, pdb
+import numpy as np, parselmouth, torch, pdb, sys, os
 from time import time as ttime
 import torch.nn.functional as F
 import scipy.signal as signal
@@ -6,13 +6,17 @@ import pyworld, os, traceback, faiss, librosa, torchcrepe
 from scipy import signal
 from functools import lru_cache
 
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
 bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
 
-input_audio_path2wav={}
+input_audio_path2wav = {}
+
 
 @lru_cache
-def cache_harvest_f0(input_audio_path,fs,f0max,f0min,frame_period):
-    audio=input_audio_path2wav[input_audio_path]
+def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
+    audio = input_audio_path2wav[input_audio_path]
     f0, t = pyworld.harvest(
         audio,
         fs=fs,
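
Note: pyworld's harvest extractor is slow, so the file memoizes it: @lru_cache caches cache_harvest_f0 per (input_audio_path, fs, f0max, f0min, frame_period), and the audio itself travels through the module-level input_audio_path2wav dict because numpy arrays are unhashable and cannot be part of a cache key. A minimal sketch of the same pattern (cached_f0, heavy_dsp, and _path2array are illustrative names, not from the repo):

    import numpy as np
    from functools import lru_cache

    _path2array = {}  # side table: hashable path -> unhashable numpy payload

    def heavy_dsp(audio, fs):  # stand-in for pyworld.harvest + stonemask
        return np.abs(np.fft.rfft(audio))[: fs // 100]

    @lru_cache
    def cached_f0(path, fs):
        audio = _path2array[path]  # looked up by key, never hashed
        return heavy_dsp(audio, fs)

    _path2array["a.wav"] = np.random.randn(16000)
    f0_first = cached_f0("a.wav", 16000)  # computes
    f0_again = cached_f0("a.wav", 16000)  # served from the cache

One caveat follows directly: the cache key never sees the array contents, so reusing a path for different audio within one process serves stale f0 values until restart.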
@@ -23,18 +27,29 @@ def cache_harvest_f0(input_audio_path,fs,f0max,f0min,frame_period):
     f0 = pyworld.stonemask(audio, f0, t, fs)
     return f0
 
-def change_rms(data1,sr1,data2,sr2,rate):  # 1 is the input audio, 2 the output audio, rate the proportion of 2
+
+def change_rms(data1, sr1, data2, sr2, rate):  # 1 is the input audio, 2 the output audio, rate the proportion of 2
     # print(data1.max(),data2.max())
-    rms1 = librosa.feature.rms(y=data1, frame_length=sr1//2*2, hop_length=sr1//2)  # one point every half second
-    rms2 = librosa.feature.rms(y=data2, frame_length=sr2//2*2, hop_length=sr2//2)
-    rms1=torch.from_numpy(rms1)
-    rms1=F.interpolate(rms1.unsqueeze(0), size=data2.shape[0],mode='linear').squeeze()
-    rms2=torch.from_numpy(rms2)
-    rms2=F.interpolate(rms2.unsqueeze(0), size=data2.shape[0],mode='linear').squeeze()
-    rms2=torch.max(rms2,torch.zeros_like(rms2)+1e-6)
-    data2*=(torch.pow(rms1,torch.tensor(1-rate))*torch.pow(rms2,torch.tensor(rate-1))).numpy()
+    rms1 = librosa.feature.rms(
+        y=data1, frame_length=sr1 // 2 * 2, hop_length=sr1 // 2
+    )  # one point every half second
+    rms2 = librosa.feature.rms(y=data2, frame_length=sr2 // 2 * 2, hop_length=sr2 // 2)
+    rms1 = torch.from_numpy(rms1)
+    rms1 = F.interpolate(
+        rms1.unsqueeze(0), size=data2.shape[0], mode="linear"
+    ).squeeze()
+    rms2 = torch.from_numpy(rms2)
+    rms2 = F.interpolate(
+        rms2.unsqueeze(0), size=data2.shape[0], mode="linear"
+    ).squeeze()
+    rms2 = torch.max(rms2, torch.zeros_like(rms2) + 1e-6)
+    data2 *= (
+        torch.pow(rms1, torch.tensor(1 - rate))
+        * torch.pow(rms2, torch.tensor(rate - 1))
+    ).numpy()
     return data2
 
+
 class VC(object):
     def __init__(self, tgt_sr, config):
         self.x_pad, self.x_query, self.x_center, self.x_max, self.is_half = (
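
Note: change_rms cross-fades loudness envelopes rather than waveforms. Both signals are reduced to half-second RMS curves, linearly interpolated to per-sample length, and the output is scaled by rms1 ** (1 - rate) * rms2 ** (rate - 1). At rate=1 the factor is 1 (output untouched); at rate=0 it is rms1/rms2, which fully re-imposes the input's loudness contour. The 1e-6 floor on rms2 guards the division. A quick numeric check of the scale factor (values invented for illustration):

    rms1, rms2, rate = 0.20, 0.10, 0.25
    factor = rms1 ** (1 - rate) * rms2 ** (rate - 1)
    print(round(factor, 2))  # 1.68: a frame quieter in the output than in the input gets boosted toward the input level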
@@ -54,7 +69,16 @@ class VC(object):
         self.t_max = self.sr * self.x_max  # query-free duration threshold
         self.device = config.device
 
-    def get_f0(self, input_audio_path,x, p_len, f0_up_key, f0_method,filter_radius, inp_f0=None):
+    def get_f0(
+        self,
+        input_audio_path,
+        x,
+        p_len,
+        f0_up_key,
+        f0_method,
+        filter_radius,
+        inp_f0=None,
+    ):
         global input_audio_path2wav
         time_step = self.window / self.sr * 1000
         f0_min = 50
@@ -78,9 +102,9 @@ class VC(object):
                 f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
             )
         elif f0_method == "harvest":
-            input_audio_path2wav[input_audio_path]=x.astype(np.double)
-            f0=cache_harvest_f0(input_audio_path,self.sr,f0_max,f0_min,10)
-            if(filter_radius>2):
+            input_audio_path2wav[input_audio_path] = x.astype(np.double)
+            f0 = cache_harvest_f0(input_audio_path, self.sr, f0_max, f0_min, 10)
+            if filter_radius > 2:
                 f0 = signal.medfilt(f0, 3)
         elif f0_method == "crepe":
             model = "full"
@@ -103,6 +127,15 @@ class VC(object):
             f0 = torchcrepe.filter.mean(f0, 3)
             f0[pd < 0.1] = 0
             f0 = f0[0].cpu().numpy()
+        elif f0_method == "rmvpe":
+            if hasattr(self, "model_rmvpe") == False:
+                from rmvpe import RMVPE
+
+                print("loading rmvpe model")
+                self.model_rmvpe = RMVPE(
+                    "rmvpe.pt", is_half=self.is_half, device=self.device
+                )
+            f0 = self.model_rmvpe.infer_from_audio(x, thred=0.03)
         f0 *= pow(2, f0_up_key / 12)
         # with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
         tf0 = self.sr // self.window  # f0 points per second
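
Note: this hunk is the point of the commit: a fourth f0_method, "rmvpe", joining pm/harvest/crepe. The RMVPE model is built lazily on first use and cached on the instance; the sys.path.append(now_dir) added at the top of the file is what lets `from rmvpe import RMVPE` resolve from the working directory, and both the "rmvpe.pt" path and the 0.03 voicing threshold are hardcoded. Callers opt in purely by the method string (sketch; assumes a constructed VC instance `vc`, 16 kHz mono float audio in `audio_pad`, and rmvpe.pt present in the working directory):

    pitch, pitchf = vc.get_f0(
        input_audio_path, audio_pad, p_len, f0_up_key,
        "rmvpe",        # new in this commit; filter_radius only affects the harvest branch
        filter_radius,
    )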
@@ -125,7 +158,7 @@ class VC(object):
         ) + 1
         f0_mel[f0_mel <= 1] = 1
         f0_mel[f0_mel > 255] = 255
-        f0_coarse = np.rint(f0_mel).astype(int)
+        f0_coarse = np.rint(f0_mel).astype(np.int)
         return f0_coarse, f0bak  # 1-0
 
     def vc(
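
Note: this hunk changes .astype(int) to .astype(np.int). np.int was deprecated in NumPy 1.20 and removed in 1.24, so on current NumPy this line raises AttributeError; the version-safe spelling keeps the same intent:

    f0_coarse = np.rint(f0_mel).astype(np.int64)  # or .astype(int); np.int is gone in NumPy >= 1.24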
@@ -141,6 +174,7 @@ class VC(object):
         big_npy,
         index_rate,
         version,
+        protect,
     ):  # ,file_index,file_big_npy
         feats = torch.from_numpy(audio0)
         if self.is_half:
@@ -161,8 +195,9 @@ class VC(object):
         t0 = ttime()
         with torch.no_grad():
             logits = model.extract_features(**inputs)
-            feats = model.final_proj(logits[0])if version=="v1"else logits[0]
-
+            feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
+        if protect < 0.5 and pitch != None and pitchf != None:
+            feats0 = feats.clone()
         if (
             isinstance(index, type(None)) == False
             and isinstance(big_npy, type(None)) == False
@@ -188,6 +223,10 @@ class VC(object):
             )
 
         feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
+        if protect < 0.5 and pitch != None and pitchf != None:
+            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
+                0, 2, 1
+            )
         t1 = ttime()
         p_len = audio0.shape[0] // self.window
         if feats.shape[1] < p_len:
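
Note: HuBERT emits one feature vector per 320 input samples (50 per second at 16 kHz) while the synthesizer consumes frames at twice that rate, hence the scale_factor=2 interpolation; the new lines apply the identical stretch to the pre-retrieval copy feats0 so both tensors stay frame-aligned for the blend below. Shape check (standalone; dimensions illustrative):

    import torch
    import torch.nn.functional as F

    feats = torch.randn(1, 50, 256)  # one second of HuBERT-rate features
    up = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
    print(up.shape)  # torch.Size([1, 100, 256])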
@@ -195,6 +234,14 @@ class VC(object):
         if pitch != None and pitchf != None:
             pitch = pitch[:, :p_len]
             pitchf = pitchf[:, :p_len]
+
+        if protect < 0.5 and pitch != None and pitchf != None:
+            pitchff = pitchf.clone()
+            pitchff[pitchf > 0] = 1
+            pitchff[pitchf < 1] = protect
+            pitchff = pitchff.unsqueeze(-1)
+            feats = feats * pitchff + feats0 * (1 - pitchff)
+            feats = feats.to(feats0.dtype)
         p_len = torch.tensor([p_len], device=self.device).long()
         with torch.no_grad():
             if pitch != None and pitchf != None:
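
Note: these are the mechanics of the new protect option, which shields voiceless consonants from index retrieval. pitchff becomes 1 on voiced frames (pitchf > 0) and protect on unvoiced ones, so feats * pitchff + feats0 * (1 - pitchff) keeps voiced frames fully retrieved while unvoiced frames lean back toward the raw HuBERT features. Toy check of the mask (tensors invented for illustration):

    import torch

    pitchf = torch.tensor([0.0, 110.0, 0.0])  # unvoiced, voiced, unvoiced
    protect = 0.33
    pitchff = pitchf.clone()
    pitchff[pitchf > 0] = 1
    pitchff[pitchf < 1] = protect
    print(pitchff)  # tensor([0.3300, 1.0000, 0.3300])

Values of protect at or above 0.5 skip the whole branch, which is presumably why 0.5 behaves as the "off" setting.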
@@ -206,10 +253,7 @@ class VC(object):
             )
         else:
             audio1 = (
-                (net_g.infer(feats, p_len, sid)[0][0, 0])
-                .data.cpu()
-                .float()
-                .numpy()
+                (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()
             )
         del feats, p_len, padding_mask
         if torch.cuda.is_available():
@@ -238,6 +282,7 @@ class VC(object):
         resample_sr,
         rms_mix_rate,
         version,
+        protect,
         f0_file=None,
     ):
         if (
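
Note: protect is threaded through as a new positional parameter of pipeline() (and of vc(), above), sitting just before f0_file, so existing callers must be updated to pass it. A hedged sketch of the updated call; only the tail shown in this diff is certain, the leading argument names follow the surrounding RVC-style signature:

    audio_out = vc.pipeline(
        model, net_g, sid, audio, input_audio_path, times, f0_up_key,
        f0_method, file_index, index_rate, if_f0, filter_radius, tgt_sr,
        resample_sr, rms_mix_rate, version,
        0.33,          # protect: new in this commit; value here is illustrative
        f0_file=None,
    )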
@@ -292,7 +337,15 @@ class VC(object):
         sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
         pitch, pitchf = None, None
         if if_f0 == 1:
-            pitch, pitchf = self.get_f0(input_audio_path,audio_pad, p_len, f0_up_key, f0_method,filter_radius, inp_f0)
+            pitch, pitchf = self.get_f0(
+                input_audio_path,
+                audio_pad,
+                p_len,
+                f0_up_key,
+                f0_method,
+                filter_radius,
+                inp_f0,
+            )
             pitch = pitch[:p_len]
             pitchf = pitchf[:p_len]
             if self.device == "mps":
@@ -317,6 +370,7 @@ class VC(object):
                         big_npy,
                         index_rate,
                         version,
+                        protect,
                     )[self.t_pad_tgt : -self.t_pad_tgt]
                 )
             else:
@@ -333,6 +387,7 @@ class VC(object):
                         big_npy,
                         index_rate,
                         version,
+                        protect,
                     )[self.t_pad_tgt : -self.t_pad_tgt]
                 )
             s = t
@@ -350,6 +405,7 @@ class VC(object):
                     big_npy,
                     index_rate,
                     version,
+                    protect,
                 )[self.t_pad_tgt : -self.t_pad_tgt]
             )
         else:
@@ -366,19 +422,21 @@ class VC(object):
                     big_npy,
                     index_rate,
                     version,
+                    protect,
                 )[self.t_pad_tgt : -self.t_pad_tgt]
             )
         audio_opt = np.concatenate(audio_opt)
-        if(rms_mix_rate!=1):
-            audio_opt=change_rms(audio,16000,audio_opt,tgt_sr,rms_mix_rate)
-        if(resample_sr>=16000 and tgt_sr!=resample_sr):
+        if rms_mix_rate != 1:
+            audio_opt = change_rms(audio, 16000, audio_opt, tgt_sr, rms_mix_rate)
+        if resample_sr >= 16000 and tgt_sr != resample_sr:
             audio_opt = librosa.resample(
                 audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
             )
-        audio_max=np.abs(audio_opt).max()/0.99
-        max_int16=32768
-        if(audio_max>1):max_int16/=audio_max
-        audio_opt=(audio_opt * max_int16).astype(np.int16)
+        audio_max = np.abs(audio_opt).max() / 0.99
+        max_int16 = 32768
+        if audio_max > 1:
+            max_int16 /= audio_max
+        audio_opt = (audio_opt * max_int16).astype(np.int16)
         del pitch, pitchf, sid
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
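
Note: the tail of the pipeline is reformatted but unchanged in behavior apart from the protect argument: optional RMS envelope mixing, optional resampling, then peak normalization into int16. audio_max is the float peak divided by 0.99, so whenever the peak exceeds 0.99 the scale 32768 is divided down until the written peak lands just under full scale. Worked through (peak invented for illustration):

    peak = 1.30
    audio_max = peak / 0.99     # ~1.313
    scale = 32768 / audio_max   # ~24954
    print(int(peak * scale))    # 32440, safely below the int16 ceiling of 32767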
 