darksakura commited on
Commit
6f5bbf2
·
1 Parent(s): 01ec3d2

Upload 22 files

Browse files
auto_slicer.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import final
3
+ import numpy as np
4
+ import librosa
5
+ import soundfile as sf
6
+ from modules.slicer2 import Slicer
7
+
8
+ class AutoSlicer:
9
+ def __init__(self):
10
+ self.slicer_params = {
11
+ "threshold": -40,
12
+ "min_length": 5000,
13
+ "min_interval": 300,
14
+ "hop_size": 10,
15
+ "max_sil_kept": 500,
16
+ }
17
+ self.original_min_interval = self.slicer_params["min_interval"]
18
+
19
+ def auto_slice(self, filename, input_dir, output_dir, max_sec):
20
+ audio, sr = librosa.load(os.path.join(input_dir, filename), sr=None, mono=False)
21
+ slicer = Slicer(sr=sr, **self.slicer_params)
22
+ chunks = slicer.slice(audio)
23
+ files_to_delete = []
24
+ for i, chunk in enumerate(chunks):
25
+ if len(chunk.shape) > 1:
26
+ chunk = chunk.T
27
+ output_filename = f"{os.path.splitext(filename)[0]}_{i}"
28
+ output_filename = "".join(c for c in output_filename if c.isascii() or c == "_") + ".wav"
29
+ output_filepath = os.path.join(output_dir, output_filename)
30
+ sf.write(output_filepath, chunk, sr)
31
+ #Check and re-slice audio that more than max_sec.
32
+ while True:
33
+ new_audio, sr = librosa.load(output_filepath, sr=None, mono=False)
34
+ if librosa.get_duration(y=new_audio, sr=sr) <= max_sec:
35
+ break
36
+ self.slicer_params["min_interval"] = self.slicer_params["min_interval"] // 2
37
+ if self.slicer_params["min_interval"] >= self.slicer_params["hop_size"]:
38
+ new_chunks = Slicer(sr=sr, **self.slicer_params).slice(new_audio)
39
+ for j, new_chunk in enumerate(new_chunks):
40
+ if len(new_chunk.shape) > 1:
41
+ new_chunk = new_chunk.T
42
+ new_output_filename = f"{os.path.splitext(output_filename)[0]}_{j}.wav"
43
+ sf.write(os.path.join(output_dir, new_output_filename), new_chunk, sr)
44
+ files_to_delete.append(output_filepath)
45
+ else:
46
+ break
47
+ self.slicer_params["min_interval"] = self.original_min_interval
48
+ for file_path in files_to_delete:
49
+ if os.path.exists(file_path):
50
+ os.remove(file_path)
51
+
52
+ def merge_short(self, output_dir, max_sec, min_sec):
53
+ short_files = []
54
+ for filename in os.listdir(output_dir):
55
+ filepath = os.path.join(output_dir, filename)
56
+ if filename.endswith(".wav"):
57
+ audio, sr = librosa.load(filepath, sr=None, mono=False)
58
+ duration = librosa.get_duration(y=audio, sr=sr)
59
+ if duration < min_sec:
60
+ short_files.append((filepath, audio, duration))
61
+ short_files.sort(key=lambda x: x[2], reverse=True)
62
+ merged_audio = []
63
+ current_duration = 0
64
+ for filepath, audio, duration in short_files:
65
+ if current_duration + duration <= max_sec:
66
+ merged_audio.append(audio)
67
+ current_duration += duration
68
+ os.remove(filepath)
69
+ else:
70
+ if merged_audio:
71
+ output_audio = np.concatenate(merged_audio, axis=-1)
72
+ if len(output_audio.shape) > 1:
73
+ output_audio = output_audio.T
74
+ output_filename = f"merged_{len(os.listdir(output_dir))}.wav"
75
+ sf.write(os.path.join(output_dir, output_filename), output_audio, sr)
76
+ merged_audio = [audio]
77
+ current_duration = duration
78
+ os.remove(filepath)
79
+ if merged_audio and current_duration >= min_sec:
80
+ output_audio = np.concatenate(merged_audio, axis=-1)
81
+ if len(output_audio.shape) > 1:
82
+ output_audio = output_audio.T
83
+ output_filename = f"merged_{len(os.listdir(output_dir))}.wav"
84
+ sf.write(os.path.join(output_dir, output_filename), output_audio, sr)
85
+
86
+ def slice_count(self, input_dir, output_dir):
87
+ orig_duration = final_duration = 0
88
+ for file in os.listdir(input_dir):
89
+ if file.endswith(".wav"):
90
+ _audio, _sr = librosa.load(os.path.join(input_dir, file), sr=None, mono=False)
91
+ orig_duration += librosa.get_duration(y=_audio, sr=_sr)
92
+ wav_files = [file for file in os.listdir(output_dir) if file.endswith(".wav")]
93
+ num_files = len(wav_files)
94
+ max_duration = -1
95
+ min_duration = float("inf")
96
+ for file in wav_files:
97
+ file_path = os.path.join(output_dir, file)
98
+ audio, sr = librosa.load(file_path, sr=None, mono=False)
99
+ duration = librosa.get_duration(y=audio, sr=sr)
100
+ final_duration += float(duration)
101
+ if duration > max_duration:
102
+ max_duration = float(duration)
103
+ if duration < min_duration:
104
+ min_duration = float(duration)
105
+ return num_files, max_duration, min_duration, orig_duration, final_duration
106
+
107
+
flask_api.py CHANGED
@@ -7,7 +7,7 @@ import torchaudio
7
  from flask import Flask, request, send_file
8
  from flask_cors import CORS
9
 
10
- from inference.infer_tool import Svc, RealTimeVC
11
 
12
  app = Flask(__name__)
13
 
 
7
  from flask import Flask, request, send_file
8
  from flask_cors import CORS
9
 
10
+ from inference.infer_tool import RealTimeVC, Svc
11
 
12
  app = Flask(__name__)
13
 
inference/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (131 Bytes). View file
 
inference/__pycache__/infer_tool.cpython-38.pyc ADDED
Binary file (14.8 kB). View file
 
inference/__pycache__/infer_tool_webui.cpython-38.pyc ADDED
Binary file (14.9 kB). View file
 
inference/__pycache__/slicer.cpython-38.pyc ADDED
Binary file (3.83 kB). View file
 
inference/infer_tool_webui.py ADDED
@@ -0,0 +1,537 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import hashlib
3
+ import io
4
+ import json
5
+ import logging
6
+ import os
7
+ import pickle
8
+ import time
9
+ from pathlib import Path
10
+
11
+ import librosa
12
+ import numpy as np
13
+
14
+ # import onnxruntime
15
+ import soundfile
16
+ import torch
17
+ import torchaudio
18
+ from tqdm import tqdm
19
+
20
+ import cluster
21
+ import utils
22
+ from diffusion.unit2mel import load_model_vocoder
23
+ from inference import slicer
24
+ from models import SynthesizerTrn
25
+
26
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
27
+
28
+
29
+ def read_temp(file_name):
30
+ if not os.path.exists(file_name):
31
+ with open(file_name, "w") as f:
32
+ f.write(json.dumps({"info": "temp_dict"}))
33
+ return {}
34
+ else:
35
+ try:
36
+ with open(file_name, "r") as f:
37
+ data = f.read()
38
+ data_dict = json.loads(data)
39
+ if os.path.getsize(file_name) > 50 * 1024 * 1024:
40
+ f_name = file_name.replace("\\", "/").split("/")[-1]
41
+ print(f"clean {f_name}")
42
+ for wav_hash in list(data_dict.keys()):
43
+ if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600:
44
+ del data_dict[wav_hash]
45
+ except Exception as e:
46
+ print(e)
47
+ print(f"{file_name} error,auto rebuild file")
48
+ data_dict = {"info": "temp_dict"}
49
+ return data_dict
50
+
51
+
52
+ def write_temp(file_name, data):
53
+ with open(file_name, "w") as f:
54
+ f.write(json.dumps(data))
55
+
56
+
57
+ def timeit(func):
58
+ def run(*args, **kwargs):
59
+ t = time.time()
60
+ res = func(*args, **kwargs)
61
+ print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
62
+ return res
63
+
64
+ return run
65
+
66
+
67
+ def format_wav(audio_path):
68
+ if Path(audio_path).suffix == '.wav':
69
+ return
70
+ raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True, sr=None)
71
+ soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
72
+
73
+
74
+ def get_end_file(dir_path, end):
75
+ file_lists = []
76
+ for root, dirs, files in os.walk(dir_path):
77
+ files = [f for f in files if f[0] != '.']
78
+ dirs[:] = [d for d in dirs if d[0] != '.']
79
+ for f_file in files:
80
+ if f_file.endswith(end):
81
+ file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
82
+ return file_lists
83
+
84
+
85
+ def get_md5(content):
86
+ return hashlib.new("md5", content).hexdigest()
87
+
88
+ def fill_a_to_b(a, b):
89
+ if len(a) < len(b):
90
+ for _ in range(0, len(b) - len(a)):
91
+ a.append(a[0])
92
+
93
+ def mkdir(paths: list):
94
+ for path in paths:
95
+ if not os.path.exists(path):
96
+ os.mkdir(path)
97
+
98
+ def pad_array(arr, target_length):
99
+ current_length = arr.shape[0]
100
+ if current_length >= target_length:
101
+ return arr
102
+ else:
103
+ pad_width = target_length - current_length
104
+ pad_left = pad_width // 2
105
+ pad_right = pad_width - pad_left
106
+ padded_arr = np.pad(arr, (pad_left, pad_right), 'constant', constant_values=(0, 0))
107
+ return padded_arr
108
+
109
+ def split_list_by_n(list_collection, n, pre=0):
110
+ for i in range(0, len(list_collection), n):
111
+ yield list_collection[i-pre if i-pre>=0 else i: i + n]
112
+
113
+
114
+ class F0FilterException(Exception):
115
+ pass
116
+
117
+ class Svc(object):
118
+ def __init__(self, net_g_path, config_path,
119
+ device=None,
120
+ cluster_model_path="logs/44k/kmeans_10000.pt",
121
+ nsf_hifigan_enhance = False,
122
+ diffusion_model_path="logs/44k/diffusion/model_0.pt",
123
+ diffusion_config_path="configs/diffusion.yaml",
124
+ shallow_diffusion = False,
125
+ only_diffusion = False,
126
+ spk_mix_enable = False,
127
+ feature_retrieval = False
128
+ ):
129
+ self.net_g_path = net_g_path
130
+ self.only_diffusion = only_diffusion
131
+ self.shallow_diffusion = shallow_diffusion
132
+ self.feature_retrieval = feature_retrieval
133
+ if device is None:
134
+ self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135
+ else:
136
+ self.dev = torch.device(device)
137
+ self.net_g_ms = None
138
+ if not self.only_diffusion:
139
+ self.hps_ms = utils.get_hparams_from_file(config_path,True)
140
+ self.target_sample = self.hps_ms.data.sampling_rate
141
+ self.hop_size = self.hps_ms.data.hop_length
142
+ self.spk2id = self.hps_ms.spk
143
+ self.unit_interpolate_mode = self.hps_ms.data.unit_interpolate_mode if self.hps_ms.data.unit_interpolate_mode is not None else 'left'
144
+ self.vol_embedding = self.hps_ms.model.vol_embedding if self.hps_ms.model.vol_embedding is not None else False
145
+ self.speech_encoder = self.hps_ms.model.speech_encoder if self.hps_ms.model.speech_encoder is not None else 'vec768l12'
146
+
147
+ self.nsf_hifigan_enhance = nsf_hifigan_enhance
148
+ if self.shallow_diffusion or self.only_diffusion:
149
+ if os.path.exists(diffusion_model_path) and os.path.exists(diffusion_model_path):
150
+ self.diffusion_model,self.vocoder,self.diffusion_args = load_model_vocoder(diffusion_model_path,self.dev,config_path=diffusion_config_path)
151
+ if self.only_diffusion:
152
+ self.target_sample = self.diffusion_args.data.sampling_rate
153
+ self.hop_size = self.diffusion_args.data.block_size
154
+ self.spk2id = self.diffusion_args.spk
155
+ self.speech_encoder = self.diffusion_args.data.encoder
156
+ self.unit_interpolate_mode = self.diffusion_args.data.unit_interpolate_mode if self.diffusion_args.data.unit_interpolate_mode is not None else 'left'
157
+ if spk_mix_enable:
158
+ self.diffusion_model.init_spkmix(len(self.spk2id))
159
+ else:
160
+ print("No diffusion model or config found. Shallow diffusion mode will False")
161
+ self.shallow_diffusion = self.only_diffusion = False
162
+
163
+ # load hubert and model
164
+ if not self.only_diffusion:
165
+ self.load_model(spk_mix_enable)
166
+ self.hubert_model = utils.get_speech_encoder(self.speech_encoder,device=self.dev)
167
+ self.volume_extractor = utils.Volume_Extractor(self.hop_size)
168
+ else:
169
+ self.hubert_model = utils.get_speech_encoder(self.diffusion_args.data.encoder,device=self.dev)
170
+ self.volume_extractor = utils.Volume_Extractor(self.diffusion_args.data.block_size)
171
+
172
+ if os.path.exists(cluster_model_path):
173
+ if self.feature_retrieval:
174
+ with open(cluster_model_path,"rb") as f:
175
+ self.cluster_model = pickle.load(f)
176
+ self.big_npy = None
177
+ self.now_spk_id = -1
178
+ else:
179
+ self.cluster_model = cluster.get_cluster_model(cluster_model_path)
180
+ else:
181
+ self.feature_retrieval=False
182
+
183
+ if self.shallow_diffusion :
184
+ self.nsf_hifigan_enhance = False
185
+ if self.nsf_hifigan_enhance:
186
+ from modules.enhancer import Enhancer
187
+ self.enhancer = Enhancer('nsf-hifigan', 'pretrain/nsf_hifigan/model',device=self.dev)
188
+
189
+ def load_model(self, spk_mix_enable=False):
190
+ # get model configuration
191
+ self.net_g_ms = SynthesizerTrn(
192
+ self.hps_ms.data.filter_length // 2 + 1,
193
+ self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
194
+ **self.hps_ms.model)
195
+ _ = utils.load_checkpoint(self.net_g_path, self.net_g_ms, None)
196
+ self.dtype = list(self.net_g_ms.parameters())[0].dtype
197
+ if "half" in self.net_g_path and torch.cuda.is_available():
198
+ _ = self.net_g_ms.half().eval().to(self.dev)
199
+ else:
200
+ _ = self.net_g_ms.eval().to(self.dev)
201
+ if spk_mix_enable:
202
+ self.net_g_ms.EnableCharacterMix(len(self.spk2id), self.dev)
203
+
204
+ def get_unit_f0(self, wav, tran, cluster_infer_ratio, speaker, f0_filter ,f0_predictor,cr_threshold=0.05):
205
+
206
+ f0_predictor_object = utils.get_f0_predictor(f0_predictor,hop_length=self.hop_size,sampling_rate=self.target_sample,device=self.dev,threshold=cr_threshold)
207
+
208
+ f0, uv = f0_predictor_object.compute_f0_uv(wav)
209
+ if f0_filter and sum(f0) == 0:
210
+ raise F0FilterException("No voice detected")
211
+ f0 = torch.FloatTensor(f0).to(self.dev)
212
+ uv = torch.FloatTensor(uv).to(self.dev)
213
+
214
+ f0 = f0 * 2 ** (tran / 12)
215
+ f0 = f0.unsqueeze(0)
216
+ uv = uv.unsqueeze(0)
217
+
218
+ wav16k = librosa.resample(wav, orig_sr=self.target_sample, target_sr=16000)
219
+ wav16k = torch.from_numpy(wav16k).to(self.dev)
220
+ c = self.hubert_model.encoder(wav16k)
221
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
222
+
223
+ if cluster_infer_ratio !=0:
224
+ if self.feature_retrieval:
225
+ speaker_id = self.spk2id.get(speaker)
226
+ if speaker_id is None:
227
+ raise RuntimeError("The name you entered is not in the speaker list!")
228
+ if not speaker_id and type(speaker) is int:
229
+ if len(self.spk2id.__dict__) >= speaker:
230
+ speaker_id = speaker
231
+ feature_index = self.cluster_model[speaker_id]
232
+ feat_np = c.transpose(0,1).cpu().numpy()
233
+ if self.big_npy is None or self.now_spk_id != speaker_id:
234
+ self.big_npy = feature_index.reconstruct_n(0, feature_index.ntotal)
235
+ self.now_spk_id = speaker_id
236
+ print("starting feature retrieval...")
237
+ score, ix = feature_index.search(feat_np, k=8)
238
+ weight = np.square(1 / score)
239
+ weight /= weight.sum(axis=1, keepdims=True)
240
+ npy = np.sum(self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
241
+ c = cluster_infer_ratio * npy + (1 - cluster_infer_ratio) * feat_np
242
+ c = torch.FloatTensor(c).to(self.dev).transpose(0,1)
243
+ print("end feature retrieval...")
244
+ else:
245
+ cluster_c = cluster.get_cluster_center_result(self.cluster_model, c.cpu().numpy().T, speaker).T
246
+ cluster_c = torch.FloatTensor(cluster_c).to(self.dev)
247
+ c = cluster_infer_ratio * cluster_c + (1 - cluster_infer_ratio) * c
248
+
249
+ c = c.unsqueeze(0)
250
+ return c, f0, uv
251
+
252
+ def infer(self, speaker, tran, raw_path,
253
+ cluster_infer_ratio=0,
254
+ auto_predict_f0=False,
255
+ noice_scale=0.4,
256
+ f0_filter=False,
257
+ f0_predictor='pm',
258
+ enhancer_adaptive_key = 0,
259
+ cr_threshold = 0.05,
260
+ k_step = 100,
261
+ frame = 0,
262
+ spk_mix = False,
263
+ second_encoding = False,
264
+ loudness_envelope_adjustment = 1
265
+ ):
266
+ wav, sr = librosa.load(raw_path, sr=self.target_sample)
267
+ if spk_mix:
268
+ c, f0, uv = self.get_unit_f0(wav, tran, 0, None, f0_filter,f0_predictor,cr_threshold=cr_threshold)
269
+ n_frames = f0.size(1)
270
+ sid = speaker[:, frame:frame+n_frames].transpose(0,1)
271
+ else:
272
+ speaker_id = self.spk2id.get(speaker)
273
+ if not speaker_id and type(speaker) is int:
274
+ if len(self.spk2id.__dict__) >= speaker:
275
+ speaker_id = speaker
276
+ if speaker_id is None:
277
+ raise RuntimeError("The name you entered is not in the speaker list!")
278
+ sid = torch.LongTensor([int(speaker_id)]).to(self.dev).unsqueeze(0)
279
+ c, f0, uv = self.get_unit_f0(wav, tran, cluster_infer_ratio, speaker, f0_filter,f0_predictor,cr_threshold=cr_threshold)
280
+ n_frames = f0.size(1)
281
+ c = c.to(self.dtype)
282
+ f0 = f0.to(self.dtype)
283
+ uv = uv.to(self.dtype)
284
+ with torch.no_grad():
285
+ start = time.time()
286
+ vol = None
287
+ if not self.only_diffusion:
288
+ vol = self.volume_extractor.extract(torch.FloatTensor(wav).to(self.dev)[None,:])[None,:].to(self.dev) if self.vol_embedding else None
289
+ audio,f0 = self.net_g_ms.infer(c, f0=f0, g=sid, uv=uv, predict_f0=auto_predict_f0, noice_scale=noice_scale,vol=vol)
290
+ audio = audio[0,0].data.float()
291
+ audio_mel = self.vocoder.extract(audio[None,:],self.target_sample) if self.shallow_diffusion else None
292
+ else:
293
+ audio = torch.FloatTensor(wav).to(self.dev)
294
+ audio_mel = None
295
+ if self.dtype != torch.float32:
296
+ c = c.to(torch.float32)
297
+ f0 = f0.to(torch.float32)
298
+ uv = uv.to(torch.float32)
299
+ if self.only_diffusion or self.shallow_diffusion:
300
+ vol = self.volume_extractor.extract(audio[None,:])[None,:,None].to(self.dev) if vol is None else vol[:,:,None]
301
+ if self.shallow_diffusion and second_encoding:
302
+ audio16k = librosa.resample(audio.detach().cpu().numpy(), orig_sr=self.target_sample, target_sr=16000)
303
+ audio16k = torch.from_numpy(audio16k).to(self.dev)
304
+ c = self.hubert_model.encoder(audio16k)
305
+ c = utils.repeat_expand_2d(c.squeeze(0), f0.shape[1],self.unit_interpolate_mode)
306
+ f0 = f0[:,:,None]
307
+ c = c.transpose(-1,-2)
308
+ audio_mel = self.diffusion_model(
309
+ c,
310
+ f0,
311
+ vol,
312
+ spk_id = sid,
313
+ spk_mix_dict = None,
314
+ gt_spec=audio_mel,
315
+ infer=True,
316
+ infer_speedup=self.diffusion_args.infer.speedup,
317
+ method=self.diffusion_args.infer.method,
318
+ k_step=k_step)
319
+ audio = self.vocoder.infer(audio_mel, f0).squeeze()
320
+ if self.nsf_hifigan_enhance:
321
+ audio, _ = self.enhancer.enhance(
322
+ audio[None,:],
323
+ self.target_sample,
324
+ f0[:,:,None],
325
+ self.hps_ms.data.hop_length,
326
+ adaptive_key = enhancer_adaptive_key)
327
+ if loudness_envelope_adjustment != 1:
328
+ audio = utils.change_rms(wav,self.target_sample,audio,self.target_sample,loudness_envelope_adjustment)
329
+ use_time = time.time() - start
330
+ print("vits use time:{}".format(use_time))
331
+ return audio, audio.shape[-1], n_frames
332
+
333
+ def clear_empty(self):
334
+ # clean up vram
335
+ torch.cuda.empty_cache()
336
+
337
+ def unload_model(self):
338
+ # unload model
339
+ self.net_g_ms = self.net_g_ms.to("cpu")
340
+ del self.net_g_ms
341
+ if hasattr(self,"enhancer"):
342
+ self.enhancer.enhancer = self.enhancer.enhancer.to("cpu")
343
+ del self.enhancer.enhancer
344
+ del self.enhancer
345
+ gc.collect()
346
+
347
+ def slice_inference(self,
348
+ raw_audio_path,
349
+ spk,
350
+ tran,
351
+ slice_db,
352
+ cluster_infer_ratio,
353
+ auto_predict_f0,
354
+ noice_scale,
355
+ pad_seconds=0.5,
356
+ clip_seconds=0,
357
+ lg_num=0,
358
+ lgr_num =0.75,
359
+ f0_predictor='pm',
360
+ enhancer_adaptive_key = 0,
361
+ cr_threshold = 0.05,
362
+ k_step = 100,
363
+ use_spk_mix = False,
364
+ second_encoding = False,
365
+ loudness_envelope_adjustment = 1
366
+ ):
367
+ if use_spk_mix:
368
+ if len(self.spk2id) == 1:
369
+ spk = self.spk2id.keys()[0]
370
+ use_spk_mix = False
371
+ wav_path = Path(raw_audio_path).with_suffix('.wav')
372
+ chunks = slicer.cut(wav_path, db_thresh=slice_db)
373
+ audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
374
+ per_size = int(clip_seconds*audio_sr)
375
+ lg_size = int(lg_num*audio_sr)
376
+ lg_size_r = int(lg_size*lgr_num)
377
+ lg_size_c_l = (lg_size-lg_size_r)//2
378
+ lg_size_c_r = lg_size-lg_size_r-lg_size_c_l
379
+ lg = np.linspace(0,1,lg_size_r) if lg_size!=0 else 0
380
+
381
+ if use_spk_mix:
382
+ assert len(self.spk2id) == len(spk)
383
+ audio_length = 0
384
+ for (slice_tag, data) in audio_data:
385
+ aud_length = int(np.ceil(len(data) / audio_sr * self.target_sample))
386
+ if slice_tag:
387
+ audio_length += aud_length // self.hop_size
388
+ continue
389
+ if per_size != 0:
390
+ datas = split_list_by_n(data, per_size,lg_size)
391
+ else:
392
+ datas = [data]
393
+ for k,dat in enumerate(datas):
394
+ pad_len = int(audio_sr * pad_seconds)
395
+ per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample))
396
+ a_length = per_length + 2 * pad_len
397
+ audio_length += a_length // self.hop_size
398
+ audio_length += len(audio_data)
399
+ spk_mix_tensor = torch.zeros(size=(len(spk), audio_length)).to(self.dev)
400
+ for i in range(len(spk)):
401
+ last_end = None
402
+ for mix in spk[i]:
403
+ if mix[3]<0. or mix[2]<0.:
404
+ raise RuntimeError("mix value must higer Than zero!")
405
+ begin = int(audio_length * mix[0])
406
+ end = int(audio_length * mix[1])
407
+ length = end - begin
408
+ if length<=0:
409
+ raise RuntimeError("begin Must lower Than end!")
410
+ step = (mix[3] - mix[2])/length
411
+ if last_end is not None:
412
+ if last_end != begin:
413
+ raise RuntimeError("[i]EndTime Must Equal [i+1]BeginTime!")
414
+ last_end = end
415
+ if step == 0.:
416
+ spk_mix_data = torch.zeros(length).to(self.dev) + mix[2]
417
+ else:
418
+ spk_mix_data = torch.arange(mix[2],mix[3],step).to(self.dev)
419
+ if(len(spk_mix_data)<length):
420
+ num_pad = length - len(spk_mix_data)
421
+ spk_mix_data = torch.nn.functional.pad(spk_mix_data, [0, num_pad], mode="reflect").to(self.dev)
422
+ spk_mix_tensor[i][begin:end] = spk_mix_data[:length]
423
+
424
+ spk_mix_ten = torch.sum(spk_mix_tensor,dim=0).unsqueeze(0).to(self.dev)
425
+ # spk_mix_tensor[0][spk_mix_ten<0.001] = 1.0
426
+ for i, x in enumerate(spk_mix_ten[0]):
427
+ if x == 0.0:
428
+ spk_mix_ten[0][i] = 1.0
429
+ spk_mix_tensor[:,i] = 1.0 / len(spk)
430
+ spk_mix_tensor = spk_mix_tensor / spk_mix_ten
431
+ if not ((torch.sum(spk_mix_tensor,dim=0) - 1.)<0.0001).all():
432
+ raise RuntimeError("sum(spk_mix_tensor) not equal 1")
433
+ spk = spk_mix_tensor
434
+
435
+ global_frame = 0
436
+ audio = []
437
+ for (slice_tag, data) in tqdm(audio_data):
438
+ print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
439
+ # padd
440
+ length = int(np.ceil(len(data) / audio_sr * self.target_sample))
441
+ if slice_tag:
442
+ print('jump empty segment')
443
+ _audio = np.zeros(length)
444
+ audio.extend(list(pad_array(_audio, length)))
445
+ global_frame += length // self.hop_size
446
+ continue
447
+ if per_size != 0:
448
+ datas = split_list_by_n(data, per_size,lg_size)
449
+ else:
450
+ datas = [data]
451
+ for k,dat in enumerate(datas):
452
+ per_length = int(np.ceil(len(dat) / audio_sr * self.target_sample)) if clip_seconds!=0 else length
453
+ if clip_seconds!=0:
454
+ print(f'###=====segment clip start, {round(len(dat) / audio_sr, 3)}s======')
455
+ # padd
456
+ pad_len = int(audio_sr * pad_seconds)
457
+ dat = np.concatenate([np.zeros([pad_len]), dat, np.zeros([pad_len])])
458
+ raw_path = io.BytesIO()
459
+ soundfile.write(raw_path, dat, audio_sr, format="wav")
460
+ raw_path.seek(0)
461
+ out_audio, out_sr, out_frame = self.infer(spk, tran, raw_path,
462
+ cluster_infer_ratio=cluster_infer_ratio,
463
+ auto_predict_f0=auto_predict_f0,
464
+ noice_scale=noice_scale,
465
+ f0_predictor = f0_predictor,
466
+ enhancer_adaptive_key = enhancer_adaptive_key,
467
+ cr_threshold = cr_threshold,
468
+ k_step = k_step,
469
+ frame = global_frame,
470
+ spk_mix = use_spk_mix,
471
+ second_encoding = second_encoding,
472
+ loudness_envelope_adjustment = loudness_envelope_adjustment
473
+ )
474
+ global_frame += out_frame
475
+ _audio = out_audio.cpu().numpy()
476
+ pad_len = int(self.target_sample * pad_seconds)
477
+ _audio = _audio[pad_len:-pad_len]
478
+ _audio = pad_array(_audio, per_length)
479
+ if lg_size!=0 and k!=0:
480
+ lg1 = audio[-(lg_size_r+lg_size_c_r):-lg_size_c_r] if lgr_num != 1 else audio[-lg_size:]
481
+ lg2 = _audio[lg_size_c_l:lg_size_c_l+lg_size_r] if lgr_num != 1 else _audio[0:lg_size]
482
+ lg_pre = lg1*(1-lg)+lg2*lg
483
+ audio = audio[0:-(lg_size_r+lg_size_c_r)] if lgr_num != 1 else audio[0:-lg_size]
484
+ audio.extend(lg_pre)
485
+ _audio = _audio[lg_size_c_l+lg_size_r:] if lgr_num != 1 else _audio[lg_size:]
486
+ audio.extend(list(_audio))
487
+ return np.array(audio)
488
+
489
+ class RealTimeVC:
490
+ def __init__(self):
491
+ self.last_chunk = None
492
+ self.last_o = None
493
+ self.chunk_len = 16000 # chunk length
494
+ self.pre_len = 3840 # cross fade length, multiples of 640
495
+
496
+ # Input and output are 1-dimensional numpy waveform arrays
497
+
498
+ def process(self, svc_model, speaker_id, f_pitch_change, input_wav_path,
499
+ cluster_infer_ratio=0,
500
+ auto_predict_f0=False,
501
+ noice_scale=0.4,
502
+ f0_filter=False):
503
+
504
+ import maad
505
+ audio, sr = torchaudio.load(input_wav_path)
506
+ audio = audio.cpu().numpy()[0]
507
+ temp_wav = io.BytesIO()
508
+ if self.last_chunk is None:
509
+ input_wav_path.seek(0)
510
+
511
+ audio, sr = svc_model.infer(speaker_id, f_pitch_change, input_wav_path,
512
+ cluster_infer_ratio=cluster_infer_ratio,
513
+ auto_predict_f0=auto_predict_f0,
514
+ noice_scale=noice_scale,
515
+ f0_filter=f0_filter)
516
+
517
+ audio = audio.cpu().numpy()
518
+ self.last_chunk = audio[-self.pre_len:]
519
+ self.last_o = audio
520
+ return audio[-self.chunk_len:]
521
+ else:
522
+ audio = np.concatenate([self.last_chunk, audio])
523
+ soundfile.write(temp_wav, audio, sr, format="wav")
524
+ temp_wav.seek(0)
525
+
526
+ audio, sr = svc_model.infer(speaker_id, f_pitch_change, temp_wav,
527
+ cluster_infer_ratio=cluster_infer_ratio,
528
+ auto_predict_f0=auto_predict_f0,
529
+ noice_scale=noice_scale,
530
+ f0_filter=f0_filter)
531
+
532
+ audio = audio.cpu().numpy()
533
+ ret = maad.util.crossfade(self.last_o, audio, self.pre_len)
534
+ self.last_chunk = audio[-self.pre_len:]
535
+ self.last_o = audio
536
+ return ret[self.chunk_len:2 * self.chunk_len]
537
+
preprocess_flist_config.py CHANGED
@@ -1,11 +1,13 @@
1
- import os
2
  import argparse
 
 
3
  import re
 
 
4
 
5
  from tqdm import tqdm
6
- from random import shuffle
7
- import json
8
- import wave
9
 
10
  config_template = json.load(open("configs_template/config_template.json"))
11
 
@@ -26,6 +28,8 @@ if __name__ == "__main__":
26
  parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
27
  parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
28
  parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
 
 
29
  args = parser.parse_args()
30
 
31
  train = []
@@ -41,8 +45,8 @@ if __name__ == "__main__":
41
  for file in wavs:
42
  if not file.endswith("wav"):
43
  continue
44
- #if not pattern.match(file):
45
- # print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
46
  if get_wav_duration(file) < 0.3:
47
  print("skip too short audio:", file)
48
  continue
@@ -67,9 +71,34 @@ if __name__ == "__main__":
67
  wavpath = fname
68
  f.write(wavpath + "\n")
69
 
 
 
 
 
 
 
70
  config_template["spk"] = spk_dict
71
  config_template["model"]["n_speakers"] = spk_id
72
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  print("Writing configs/config.json")
74
  with open("configs/config.json", "w") as f:
75
  json.dump(config_template, f, indent=2)
 
 
 
 
1
  import argparse
2
+ import json
3
+ import os
4
  import re
5
+ import wave
6
+ from random import shuffle
7
 
8
  from tqdm import tqdm
9
+
10
+ import diffusion.logger.utils as du
 
11
 
12
  config_template = json.load(open("configs_template/config_template.json"))
13
 
 
28
  parser.add_argument("--train_list", type=str, default="./filelists/train.txt", help="path to train list")
29
  parser.add_argument("--val_list", type=str, default="./filelists/val.txt", help="path to val list")
30
  parser.add_argument("--source_dir", type=str, default="./dataset/44k", help="path to source dir")
31
+ parser.add_argument("--speech_encoder", type=str, default="vec768l12", help="choice a speech encoder|'vec768l12','vec256l9','hubertsoft','whisper-ppg','cnhubertlarge','dphubert','whisper-ppg-large','wavlmbase+'")
32
+ parser.add_argument("--vol_aug", action="store_true", help="Whether to use volume embedding and volume augmentation")
33
  args = parser.parse_args()
34
 
35
  train = []
 
45
  for file in wavs:
46
  if not file.endswith("wav"):
47
  continue
48
+ if not pattern.match(file):
49
+ print(f"warning:文件名{file}中包含非字母数字下划线,可能会导致错误。(也可能不会)")
50
  if get_wav_duration(file) < 0.3:
51
  print("skip too short audio:", file)
52
  continue
 
71
  wavpath = fname
72
  f.write(wavpath + "\n")
73
 
74
+
75
+ d_config_template = du.load_config("configs_template/diffusion_template.yaml")
76
+ d_config_template["model"]["n_spk"] = spk_id
77
+ d_config_template["data"]["encoder"] = args.speech_encoder
78
+ d_config_template["spk"] = spk_dict
79
+
80
  config_template["spk"] = spk_dict
81
  config_template["model"]["n_speakers"] = spk_id
82
+ config_template["model"]["speech_encoder"] = args.speech_encoder
83
+
84
+ if args.speech_encoder == "vec768l12" or args.speech_encoder == "dphubert" or args.speech_encoder == "wavlmbase+":
85
+ config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 768
86
+ d_config_template["data"]["encoder_out_channels"] = 768
87
+ elif args.speech_encoder == "vec256l9" or args.speech_encoder == 'hubertsoft':
88
+ config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 256
89
+ d_config_template["data"]["encoder_out_channels"] = 256
90
+ elif args.speech_encoder == "whisper-ppg" or args.speech_encoder == 'cnhubertlarge':
91
+ config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1024
92
+ d_config_template["data"]["encoder_out_channels"] = 1024
93
+ elif args.speech_encoder == "whisper-ppg-large":
94
+ config_template["model"]["ssl_dim"] = config_template["model"]["filter_channels"] = config_template["model"]["gin_channels"] = 1280
95
+ d_config_template["data"]["encoder_out_channels"] = 1280
96
+
97
+ if args.vol_aug:
98
+ config_template["train"]["vol_aug"] = config_template["model"]["vol_embedding"] = True
99
+
100
  print("Writing configs/config.json")
101
  with open("configs/config.json", "w") as f:
102
  json.dump(config_template, f, indent=2)
103
+ print("Writing configs/diffusion.yaml")
104
+ du.save_config("configs/diffusion.yaml",d_config_template)
preprocess_hubert_f0.py CHANGED
@@ -1,43 +1,54 @@
1
- import math
 
2
  import multiprocessing
3
  import os
4
- import argparse
 
 
5
  from random import shuffle
6
 
 
 
7
  import torch
8
- from glob import glob
9
  from tqdm import tqdm
10
- from modules.mel_processing import spectrogram_torch
11
 
 
12
  import utils
13
- import logging
 
14
 
15
  logging.getLogger("numba").setLevel(logging.WARNING)
16
- import librosa
17
- import numpy as np
18
 
19
  hps = utils.get_hparams_from_file("configs/config.json")
 
20
  sampling_rate = hps.data.sampling_rate
21
  hop_length = hps.data.hop_length
 
22
 
23
 
24
- def process_one(filename, hmodel):
25
  # print(filename)
26
  wav, sr = librosa.load(filename, sr=sampling_rate)
 
 
 
 
27
  soft_path = filename + ".soft.pt"
28
  if not os.path.exists(soft_path):
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
31
  wav16k = torch.from_numpy(wav16k).to(device)
32
- c = utils.get_hubert_content(hmodel, wav_16k_tensor=wav16k)
33
  torch.save(c.cpu(), soft_path)
34
 
35
  f0_path = filename + ".f0.npy"
36
  if not os.path.exists(f0_path):
37
- f0 = utils.compute_f0_dio(
38
- wav, sampling_rate=sampling_rate, hop_length=hop_length
 
39
  )
40
- np.save(f0_path, f0)
 
41
 
42
  spec_path = filename.replace(".wav", ".spec.pt")
43
  if not os.path.exists(spec_path):
@@ -45,7 +56,6 @@ def process_one(filename, hmodel):
45
  # The following code can't be replaced by torch.FloatTensor(wav)
46
  # because load_wav_to_torch return a tensor that need to be normalized
47
 
48
- audio, sr = utils.load_wav_to_torch(filename)
49
  if sr != hps.data.sampling_rate:
50
  raise ValueError(
51
  "{} SR doesn't match target {} SR".format(
@@ -53,8 +63,7 @@ def process_one(filename, hmodel):
53
  )
54
  )
55
 
56
- audio_norm = audio / hps.data.max_wav_value
57
- audio_norm = audio_norm.unsqueeze(0)
58
 
59
  spec = spectrogram_torch(
60
  audio_norm,
@@ -67,35 +76,88 @@ def process_one(filename, hmodel):
67
  spec = torch.squeeze(spec, 0)
68
  torch.save(spec, spec_path)
69
 
70
-
71
- def process_batch(filenames):
72
- print("Loading hubert for content...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  device = "cuda" if torch.cuda.is_available() else "cpu"
74
- hmodel = utils.get_hubert_model().to(device)
75
- print("Loaded hubert.")
76
- for filename in tqdm(filenames):
77
- process_one(filename, hmodel)
78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  if __name__ == "__main__":
81
  parser = argparse.ArgumentParser()
82
  parser.add_argument(
83
  "--in_dir", type=str, default="dataset/44k", help="path to input dir"
84
  )
85
-
 
 
 
 
 
 
 
 
 
86
  args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
87
  filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10]
88
  shuffle(filenames)
89
  multiprocessing.set_start_method("spawn", force=True)
90
 
91
- num_processes = 1
92
- chunk_size = int(math.ceil(len(filenames) / num_processes))
93
- chunks = [
94
- filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size)
95
- ]
96
- print([len(c) for c in chunks])
97
- processes = [
98
- multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks
99
- ]
100
- for p in processes:
101
- p.start()
 
1
+ import argparse
2
+ import logging
3
  import multiprocessing
4
  import os
5
+ import random
6
+ from concurrent.futures import ProcessPoolExecutor
7
+ from glob import glob
8
  from random import shuffle
9
 
10
+ import librosa
11
+ import numpy as np
12
  import torch
 
13
  from tqdm import tqdm
 
14
 
15
+ import diffusion.logger.utils as du
16
  import utils
17
+ from diffusion.vocoder import Vocoder
18
+ from modules.mel_processing import spectrogram_torch
19
 
20
  logging.getLogger("numba").setLevel(logging.WARNING)
21
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
 
22
 
23
  hps = utils.get_hparams_from_file("configs/config.json")
24
+ dconfig = du.load_config("configs/diffusion.yaml")
25
  sampling_rate = hps.data.sampling_rate
26
  hop_length = hps.data.hop_length
27
+ speech_encoder = hps["model"]["speech_encoder"]
28
 
29
 
30
+ def process_one(filename, hmodel,f0p,diff=False,mel_extractor=None):
31
  # print(filename)
32
  wav, sr = librosa.load(filename, sr=sampling_rate)
33
+ audio_norm = torch.FloatTensor(wav)
34
+ audio_norm = audio_norm.unsqueeze(0)
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
  soft_path = filename + ".soft.pt"
38
  if not os.path.exists(soft_path):
 
39
  wav16k = librosa.resample(wav, orig_sr=sampling_rate, target_sr=16000)
40
  wav16k = torch.from_numpy(wav16k).to(device)
41
+ c = hmodel.encoder(wav16k)
42
  torch.save(c.cpu(), soft_path)
43
 
44
  f0_path = filename + ".f0.npy"
45
  if not os.path.exists(f0_path):
46
+ f0_predictor = utils.get_f0_predictor(f0p,sampling_rate=sampling_rate, hop_length=hop_length,device=None,threshold=0.05)
47
+ f0,uv = f0_predictor.compute_f0_uv(
48
+ wav
49
  )
50
+ np.save(f0_path, np.asanyarray((f0,uv),dtype=object))
51
+
52
 
53
  spec_path = filename.replace(".wav", ".spec.pt")
54
  if not os.path.exists(spec_path):
 
56
  # The following code can't be replaced by torch.FloatTensor(wav)
57
  # because load_wav_to_torch return a tensor that need to be normalized
58
 
 
59
  if sr != hps.data.sampling_rate:
60
  raise ValueError(
61
  "{} SR doesn't match target {} SR".format(
 
63
  )
64
  )
65
 
66
+ #audio_norm = audio / hps.data.max_wav_value
 
67
 
68
  spec = spectrogram_torch(
69
  audio_norm,
 
76
  spec = torch.squeeze(spec, 0)
77
  torch.save(spec, spec_path)
78
 
79
+ if diff or hps.model.vol_embedding:
80
+ volume_path = filename + ".vol.npy"
81
+ volume_extractor = utils.Volume_Extractor(hop_length)
82
+ if not os.path.exists(volume_path):
83
+ volume = volume_extractor.extract(audio_norm)
84
+ np.save(volume_path, volume.to('cpu').numpy())
85
+
86
+ if diff:
87
+ mel_path = filename + ".mel.npy"
88
+ if not os.path.exists(mel_path) and mel_extractor is not None:
89
+ mel_t = mel_extractor.extract(audio_norm.to(device), sampling_rate)
90
+ mel = mel_t.squeeze().to('cpu').numpy()
91
+ np.save(mel_path, mel)
92
+ aug_mel_path = filename + ".aug_mel.npy"
93
+ aug_vol_path = filename + ".aug_vol.npy"
94
+ max_amp = float(torch.max(torch.abs(audio_norm))) + 1e-5
95
+ max_shift = min(1, np.log10(1/max_amp))
96
+ log10_vol_shift = random.uniform(-1, max_shift)
97
+ keyshift = random.uniform(-5, 5)
98
+ if mel_extractor is not None:
99
+ aug_mel_t = mel_extractor.extract(audio_norm * (10 ** log10_vol_shift), sampling_rate, keyshift = keyshift)
100
+ aug_mel = aug_mel_t.squeeze().to('cpu').numpy()
101
+ aug_vol = volume_extractor.extract(audio_norm * (10 ** log10_vol_shift))
102
+ if not os.path.exists(aug_mel_path):
103
+ np.save(aug_mel_path,np.asanyarray((aug_mel,keyshift),dtype=object))
104
+ if not os.path.exists(aug_vol_path):
105
+ np.save(aug_vol_path,aug_vol.to('cpu').numpy())
106
+
107
+ def process_batch(file_chunk, f0p, diff=False, mel_extractor=None):
108
+ print("Loading speech encoder for content...")
109
  device = "cuda" if torch.cuda.is_available() else "cpu"
110
+ hmodel = utils.get_speech_encoder(speech_encoder, device=device)
111
+ print("Loaded speech encoder.")
 
 
112
 
113
+ for filename in tqdm(file_chunk):
114
+ process_one(filename, hmodel, f0p, diff, mel_extractor)
115
+
116
+ def parallel_process(filenames, num_processes, f0p, diff, mel_extractor):
117
+ with ProcessPoolExecutor(max_workers=num_processes) as executor:
118
+ tasks = []
119
+ for i in range(num_processes):
120
+ start = int(i * len(filenames) / num_processes)
121
+ end = int((i + 1) * len(filenames) / num_processes)
122
+ file_chunk = filenames[start:end]
123
+ tasks.append(executor.submit(process_batch, file_chunk, f0p, diff, mel_extractor))
124
+
125
+ for task in tqdm(tasks):
126
+ task.result()
127
 
128
  if __name__ == "__main__":
129
  parser = argparse.ArgumentParser()
130
  parser.add_argument(
131
  "--in_dir", type=str, default="dataset/44k", help="path to input dir"
132
  )
133
+ parser.add_argument(
134
+ '--use_diff',action='store_true', help='Whether to use the diffusion model'
135
+ )
136
+ parser.add_argument(
137
+ '--f0_predictor', type=str, default="dio", help='Select F0 predictor, can select crepe,pm,dio,harvest,rmvpe, default pm(note: crepe is original F0 using mean filter)'
138
+ )
139
+ parser.add_argument(
140
+ '--num_processes', type=int, default=1, help='You are advised to set the number of processes to the same as the number of CPU cores'
141
+ )
142
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
  args = parser.parse_args()
144
+ f0p = args.f0_predictor
145
+ print(speech_encoder)
146
+ print(f0p)
147
+ print(args.use_diff)
148
+ if args.use_diff:
149
+ print("use_diff")
150
+ print("Loading Mel Extractor...")
151
+ mel_extractor = Vocoder(dconfig.vocoder.type, dconfig.vocoder.ckpt, device = device)
152
+ print("Loaded Mel Extractor.")
153
+ else:
154
+ mel_extractor = None
155
  filenames = glob(f"{args.in_dir}/*/*.wav", recursive=True) # [:10]
156
  shuffle(filenames)
157
  multiprocessing.set_start_method("spawn", force=True)
158
 
159
+ num_processes = args.num_processes
160
+ if num_processes == 0:
161
+ num_processes = os.cpu_count()
162
+
163
+ parallel_process(filenames, num_processes, f0p, args.use_diff, mel_extractor)
 
 
 
 
 
 
resample.py CHANGED
@@ -1,48 +1,98 @@
1
- import os
2
  import argparse
 
 
 
 
 
3
  import librosa
4
  import numpy as np
5
- from multiprocessing import Pool, cpu_count
6
  from scipy.io import wavfile
7
  from tqdm import tqdm
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def process(item):
11
  spkdir, wav_name, args = item
12
- # speaker 's5', 'p280', 'p315' are excluded,
13
  speaker = spkdir.replace("\\", "/").split("/")[-1]
 
14
  wav_path = os.path.join(args.in_dir, speaker, wav_name)
15
  if os.path.exists(wav_path) and '.wav' in wav_path:
16
  os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
17
- wav, sr = librosa.load(wav_path, sr=None)
18
- wav, _ = librosa.effects.trim(wav, top_db=20)
19
- peak = np.abs(wav).max()
20
- if peak > 1.0:
21
- wav = 0.98 * wav / peak
22
- wav2 = librosa.resample(wav, orig_sr=sr, target_sr=args.sr2)
23
- wav2 /= max(wav2.max(), -wav2.min())
24
- save_name = wav_name
25
- save_path2 = os.path.join(args.out_dir2, speaker, save_name)
26
- wavfile.write(
27
- save_path2,
28
- args.sr2,
29
- (wav2 * np.iinfo(np.int16).max).astype(np.int16)
30
- )
31
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
  parser = argparse.ArgumentParser()
36
  parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
37
  parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
38
  parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
 
39
  args = parser.parse_args()
40
- processs = cpu_count()-2 if cpu_count() >4 else 1
41
- pool = Pool(processes=processs)
42
-
43
- for speaker in os.listdir(args.in_dir):
44
- spk_dir = os.path.join(args.in_dir, speaker)
45
- if os.path.isdir(spk_dir):
46
- print(spk_dir)
47
- for _ in tqdm(pool.imap_unordered(process, [(spk_dir, i, args) for i in os.listdir(spk_dir) if i.endswith("wav")])):
48
- pass
 
 
1
  import argparse
2
+ import concurrent.futures
3
+ import os
4
+ from concurrent.futures import ProcessPoolExecutor
5
+ from multiprocessing import cpu_count
6
+
7
  import librosa
8
  import numpy as np
 
9
  from scipy.io import wavfile
10
  from tqdm import tqdm
11
 
12
 
13
+ def load_wav(wav_path):
14
+ return librosa.load(wav_path, sr=None)
15
+
16
+
17
+ def trim_wav(wav, top_db=40):
18
+ return librosa.effects.trim(wav, top_db=top_db)
19
+
20
+
21
+ def normalize_peak(wav, threshold=1.0):
22
+ peak = np.abs(wav).max()
23
+ if peak > threshold:
24
+ wav = 0.98 * wav / peak
25
+ return wav
26
+
27
+
28
+ def resample_wav(wav, sr, target_sr):
29
+ return librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
30
+
31
+
32
+ def save_wav_to_path(wav, save_path, sr):
33
+ wavfile.write(
34
+ save_path,
35
+ sr,
36
+ (wav * np.iinfo(np.int16).max).astype(np.int16)
37
+ )
38
+
39
+
40
  def process(item):
41
  spkdir, wav_name, args = item
 
42
  speaker = spkdir.replace("\\", "/").split("/")[-1]
43
+
44
  wav_path = os.path.join(args.in_dir, speaker, wav_name)
45
  if os.path.exists(wav_path) and '.wav' in wav_path:
46
  os.makedirs(os.path.join(args.out_dir2, speaker), exist_ok=True)
47
+
48
+ wav, sr = load_wav(wav_path)
49
+ wav, _ = trim_wav(wav)
50
+ wav = normalize_peak(wav)
51
+ resampled_wav = resample_wav(wav, sr, args.sr2)
52
+
53
+ if not args.skip_loudnorm:
54
+ resampled_wav /= np.max(np.abs(resampled_wav))
55
+
56
+ save_path2 = os.path.join(args.out_dir2, speaker, wav_name)
57
+ save_wav_to_path(resampled_wav, save_path2, args.sr2)
 
 
 
58
 
59
 
60
+ """
61
+ def process_all_speakers():
62
+ process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
63
+
64
+ with ThreadPoolExecutor(max_workers=process_count) as executor:
65
+ for speaker in speakers:
66
+ spk_dir = os.path.join(args.in_dir, speaker)
67
+ if os.path.isdir(spk_dir):
68
+ print(spk_dir)
69
+ futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
70
+ for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
71
+ pass
72
+ """
73
+ # multi process
74
+
75
+
76
+ def process_all_speakers():
77
+ process_count = 30 if os.cpu_count() > 60 else (os.cpu_count() - 2 if os.cpu_count() > 4 else 1)
78
+ with ProcessPoolExecutor(max_workers=process_count) as executor:
79
+ for speaker in speakers:
80
+ spk_dir = os.path.join(args.in_dir, speaker)
81
+ if os.path.isdir(spk_dir):
82
+ print(spk_dir)
83
+ futures = [executor.submit(process, (spk_dir, i, args)) for i in os.listdir(spk_dir) if i.endswith("wav")]
84
+ for _ in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
85
+ pass
86
+
87
 
88
  if __name__ == "__main__":
89
  parser = argparse.ArgumentParser()
90
  parser.add_argument("--sr2", type=int, default=44100, help="sampling rate")
91
  parser.add_argument("--in_dir", type=str, default="./dataset_raw", help="path to source dir")
92
  parser.add_argument("--out_dir2", type=str, default="./dataset/44k", help="path to target dir")
93
+ parser.add_argument("--skip_loudnorm", action="store_true", help="Skip loudness matching if you have done it")
94
  args = parser.parse_args()
95
+
96
+ print(f"CPU count: {cpu_count()}")
97
+ speakers = os.listdir(args.in_dir)
98
+ process_all_speakers()
 
 
 
 
 
spkmix.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 角色混合轨道 编写规则:
2
+ # 角色ID : [[起始时间1, 终止时间1, 起始数值1, 起始数值1], [起始时间2, 终止时间2, 起始数值2, 起始数值2]]
3
+ # 起始时间和前一个的终止时间必须相同,第一个起始时间必须为0,最后一个终止时间必须为1 (时间的范围为0-1)
4
+ # 全部角色必须填写,不使用的角色填[[0., 1., 0., 0.]]即可
5
+ # 融合数值可以随便填,在指定的时间段内从起始数值线性变化为终止数值,内部会自动确保线性组合为1,可以放心使用
6
+
7
+ spk_mix_map = {
8
+ 0 : [[0., 0.5, 1, 0.5], [0.5, 1, 0.5, 1]],
9
+ 1 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]],
10
+ 2 : [[0., 0.35, 1, 0.5], [0.35, 0.75, 0.75, 1], [0.75, 1, 0.45, 1]]
11
+ }
train_diff.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from torch.optim import lr_scheduler
5
+
6
+ from diffusion.data_loaders import get_data_loaders
7
+ from diffusion.logger import utils
8
+ from diffusion.solver import train
9
+ from diffusion.unit2mel import Unit2Mel
10
+ from diffusion.vocoder import Vocoder
11
+
12
+
13
+ def parse_args(args=None, namespace=None):
14
+ """Parse command-line arguments."""
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "-c",
18
+ "--config",
19
+ type=str,
20
+ required=True,
21
+ help="path to the config file")
22
+ return parser.parse_args(args=args, namespace=namespace)
23
+
24
+
25
+ if __name__ == '__main__':
26
+ # parse commands
27
+ cmd = parse_args()
28
+
29
+ # load config
30
+ args = utils.load_config(cmd.config)
31
+ print(' > config:', cmd.config)
32
+ print(' > exp:', args.env.expdir)
33
+
34
+ # load vocoder
35
+ vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)
36
+
37
+ # load model
38
+ model = Unit2Mel(
39
+ args.data.encoder_out_channels,
40
+ args.model.n_spk,
41
+ args.model.use_pitch_aug,
42
+ vocoder.dimension,
43
+ args.model.n_layers,
44
+ args.model.n_chans,
45
+ args.model.n_hidden,
46
+ args.model.timesteps,
47
+ args.model.k_step_max
48
+ )
49
+
50
+ print(f' > INFO: now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')
51
+
52
+ # load parameters
53
+ optimizer = torch.optim.AdamW(model.parameters())
54
+ initial_global_step, model, optimizer = utils.load_model(args.env.expdir, model, optimizer, device=args.device)
55
+ for param_group in optimizer.param_groups:
56
+ param_group['initial_lr'] = args.train.lr
57
+ param_group['lr'] = args.train.lr * (args.train.gamma ** max(((initial_global_step-2)//args.train.decay_step),0) )
58
+ param_group['weight_decay'] = args.train.weight_decay
59
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=args.train.decay_step, gamma=args.train.gamma,last_epoch=initial_global_step-2)
60
+
61
+ # device
62
+ if args.device == 'cuda':
63
+ torch.cuda.set_device(args.env.gpu_id)
64
+ model.to(args.device)
65
+
66
+ for state in optimizer.state.values():
67
+ for k, v in state.items():
68
+ if torch.is_tensor(v):
69
+ state[k] = v.to(args.device)
70
+
71
+ # datas
72
+ loader_train, loader_valid = get_data_loaders(args, whole_audio=False)
73
+
74
+ # run
75
+ train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
76
+
train_index.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import pickle
4
+
5
+ import utils
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument(
10
+ "--root_dir", type=str, default="dataset/44k", help="path to root dir"
11
+ )
12
+ parser.add_argument('-c', '--config', type=str, default="./configs/config.json",
13
+ help='JSON file for configuration')
14
+ parser.add_argument(
15
+ "--output_dir", type=str, default="logs/44k", help="path to output dir"
16
+ )
17
+
18
+ args = parser.parse_args()
19
+
20
+ hps = utils.get_hparams_from_file(args.config)
21
+ spk_dic = hps.spk
22
+ result = {}
23
+
24
+ for k,v in spk_dic.items():
25
+ print(f"now, index {k} feature...")
26
+ index = utils.train_index(k,args.root_dir)
27
+ result[v] = index
28
+
29
+ with open(os.path.join(args.output_dir,"feature_and_index.pkl"),"wb") as f:
30
+ pickle.dump(result,f)
utils.py CHANGED
@@ -1,22 +1,21 @@
1
- import os
2
- import glob
3
- import re
4
- import sys
5
  import argparse
6
- import logging
7
  import json
 
 
 
8
  import subprocess
9
- import warnings
10
- import random
11
- import functools
 
 
12
  import librosa
13
  import numpy as np
14
- from scipy.io.wavfile import read
15
  import torch
 
 
16
  from torch.nn import functional as F
17
- from modules.commons import sequence_mask
18
- import faiss
19
- import tqdm
20
 
21
  MATPLOTLIB_FLAG = False
22
 
@@ -97,7 +96,10 @@ def get_f0_predictor(f0_predictor,hop_length,sampling_rate,**kargs):
97
  f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
98
  elif f0_predictor == "dio":
99
  from modules.F0Predictor.DioF0Predictor import DioF0Predictor
100
- f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
 
 
 
101
  else:
102
  raise Exception("Unknown f0 predictor")
103
  return f0_predictor_object
@@ -130,6 +132,18 @@ def get_speech_encoder(speech_encoder,device=None,**kargs):
130
  elif speech_encoder == "whisper-ppg":
131
  from vencoder.WhisperPPG import WhisperPPG
132
  speech_encoder_object = WhisperPPG(device = device)
 
 
 
 
 
 
 
 
 
 
 
 
133
  else:
134
  raise Exception("Unknown speech encoder")
135
  return speech_encoder_object
@@ -142,6 +156,7 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
142
  if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
143
  optimizer.load_state_dict(checkpoint_dict['optimizer'])
144
  saved_state_dict = checkpoint_dict['model']
 
145
  if hasattr(model, 'module'):
146
  state_dict = model.module.state_dict()
147
  else:
@@ -153,10 +168,11 @@ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False
153
  # print("load", k)
154
  new_state_dict[k] = saved_state_dict[k]
155
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
156
- except:
157
- print("error, %s is not in the checkpoint" % k)
158
- logger.info("%s is not in the checkpoint" % k)
159
- new_state_dict[k] = v
 
160
  if hasattr(model, 'module'):
161
  model.module.load_state_dict(new_state_dict)
162
  else:
@@ -189,15 +205,20 @@ def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_tim
189
  False -> lexicographically delete ckpts
190
  """
191
  ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
192
- name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1)))
193
- time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f)))
 
 
194
  sort_key = time_key if sort_by_time else name_key
195
- x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], key=sort_key)
 
196
  to_del = [os.path.join(path_to_models, fn) for fn in
197
  (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
198
- del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}")
199
- del_routine = lambda x: [os.remove(x), del_info(x)]
200
- rs = [del_routine(fn) for fn in to_del]
 
 
201
 
202
  def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
203
  for k, v in scalars.items():
@@ -325,11 +346,11 @@ def get_hparams_from_dir(model_dir):
325
  return hparams
326
 
327
 
328
- def get_hparams_from_file(config_path):
329
  with open(config_path, "r") as f:
330
  data = f.read()
331
  config = json.loads(data)
332
- hparams =HParams(**config)
333
  return hparams
334
 
335
 
@@ -368,7 +389,13 @@ def get_logger(model_dir, filename="train.log"):
368
  return logger
369
 
370
 
371
- def repeat_expand_2d(content, target_len):
 
 
 
 
 
 
372
  # content : [h, t]
373
 
374
  src_len = content.shape[-1]
@@ -385,6 +412,14 @@ def repeat_expand_2d(content, target_len):
385
  return target
386
 
387
 
 
 
 
 
 
 
 
 
388
  def mix_model(model_paths,mix_rate,mode):
389
  mix_rate = torch.FloatTensor(mix_rate)/100
390
  model_tem = torch.load(model_paths[0])
@@ -420,6 +455,7 @@ def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出
420
  return data2
421
 
422
  def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI
 
423
  print("The feature index is constructing.")
424
  exp_dir = os.path.join(root_dir,spk_name)
425
  listdir_res = []
@@ -436,6 +472,25 @@ def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.
436
  big_npy_idx = np.arange(big_npy.shape[0])
437
  np.random.shuffle(big_npy_idx)
438
  big_npy = big_npy[big_npy_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
439
  n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
440
  index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf)
441
  index_ivf = faiss.extract_index_ivf(index) #
@@ -486,6 +541,18 @@ class HParams():
486
  def get(self,index):
487
  return self.__dict__.get(index)
488
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  class Volume_Extractor:
490
  def __init__(self, hop_size = 512):
491
  self.hop_size = hop_size
@@ -496,6 +563,6 @@ class Volume_Extractor:
496
  n_frames = int(audio.size(-1) // self.hop_size)
497
  audio2 = audio ** 2
498
  audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
499
- volume = torch.FloatTensor([torch.mean(audio2[:,int(n * self.hop_size) : int((n + 1) * self.hop_size)]) for n in range(n_frames)])
500
  volume = torch.sqrt(volume)
501
- return volume
 
 
 
 
 
1
  import argparse
2
+ import glob
3
  import json
4
+ import logging
5
+ import os
6
+ import re
7
  import subprocess
8
+ import sys
9
+ import traceback
10
+ from multiprocessing import cpu_count
11
+
12
+ import faiss
13
  import librosa
14
  import numpy as np
 
15
  import torch
16
+ from scipy.io.wavfile import read
17
+ from sklearn.cluster import MiniBatchKMeans
18
  from torch.nn import functional as F
 
 
 
19
 
20
  MATPLOTLIB_FLAG = False
21
 
 
96
  f0_predictor_object = HarvestF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
97
  elif f0_predictor == "dio":
98
  from modules.F0Predictor.DioF0Predictor import DioF0Predictor
99
+ f0_predictor_object = DioF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate)
100
+ elif f0_predictor == "rmvpe":
101
+ from modules.F0Predictor.RMVPEF0Predictor import RMVPEF0Predictor
102
+ f0_predictor_object = RMVPEF0Predictor(hop_length=hop_length,sampling_rate=sampling_rate,dtype=torch.float32 ,device=kargs["device"],threshold=kargs["threshold"])
103
  else:
104
  raise Exception("Unknown f0 predictor")
105
  return f0_predictor_object
 
132
  elif speech_encoder == "whisper-ppg":
133
  from vencoder.WhisperPPG import WhisperPPG
134
  speech_encoder_object = WhisperPPG(device = device)
135
+ elif speech_encoder == "cnhubertlarge":
136
+ from vencoder.CNHubertLarge import CNHubertLarge
137
+ speech_encoder_object = CNHubertLarge(device = device)
138
+ elif speech_encoder == "dphubert":
139
+ from vencoder.DPHubert import DPHubert
140
+ speech_encoder_object = DPHubert(device = device)
141
+ elif speech_encoder == "whisper-ppg-large":
142
+ from vencoder.WhisperPPGLarge import WhisperPPGLarge
143
+ speech_encoder_object = WhisperPPGLarge(device = device)
144
+ elif speech_encoder == "wavlmbase+":
145
+ from vencoder.WavLMBasePlus import WavLMBasePlus
146
+ speech_encoder_object = WavLMBasePlus(device = device)
147
  else:
148
  raise Exception("Unknown speech encoder")
149
  return speech_encoder_object
 
156
  if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None:
157
  optimizer.load_state_dict(checkpoint_dict['optimizer'])
158
  saved_state_dict = checkpoint_dict['model']
159
+ model = model.to(list(saved_state_dict.values())[0].dtype)
160
  if hasattr(model, 'module'):
161
  state_dict = model.module.state_dict()
162
  else:
 
168
  # print("load", k)
169
  new_state_dict[k] = saved_state_dict[k]
170
  assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape)
171
+ except Exception:
172
+ if "enc_q" not in k or "emb_g" not in k:
173
+ print("error, %s is not in the checkpoint" % k)
174
+ logger.info("%s is not in the checkpoint" % k)
175
+ new_state_dict[k] = v
176
  if hasattr(model, 'module'):
177
  model.module.load_state_dict(new_state_dict)
178
  else:
 
205
  False -> lexicographically delete ckpts
206
  """
207
  ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))]
208
+ def name_key(_f):
209
+ return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
210
+ def time_key(_f):
211
+ return os.path.getmtime(os.path.join(path_to_models, _f))
212
  sort_key = time_key if sort_by_time else name_key
213
+ def x_sorted(_x):
214
+ return sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")], key=sort_key)
215
  to_del = [os.path.join(path_to_models, fn) for fn in
216
  (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])]
217
+ def del_info(fn):
218
+ return logger.info(f".. Free up space by deleting ckpt {fn}")
219
+ def del_routine(x):
220
+ return [os.remove(x), del_info(x)]
221
+ [del_routine(fn) for fn in to_del]
222
 
223
  def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
224
  for k, v in scalars.items():
 
346
  return hparams
347
 
348
 
349
+ def get_hparams_from_file(config_path, infer_mode = False):
350
  with open(config_path, "r") as f:
351
  data = f.read()
352
  config = json.loads(data)
353
+ hparams =HParams(**config) if not infer_mode else InferHParams(**config)
354
  return hparams
355
 
356
 
 
389
  return logger
390
 
391
 
392
+ def repeat_expand_2d(content, target_len, mode = 'left'):
393
+ # content : [h, t]
394
+ return repeat_expand_2d_left(content, target_len) if mode == 'left' else repeat_expand_2d_other(content, target_len, mode)
395
+
396
+
397
+
398
+ def repeat_expand_2d_left(content, target_len):
399
  # content : [h, t]
400
 
401
  src_len = content.shape[-1]
 
412
  return target
413
 
414
 
415
+ # mode : 'nearest'| 'linear'| 'bilinear'| 'bicubic'| 'trilinear'| 'area'
416
+ def repeat_expand_2d_other(content, target_len, mode = 'nearest'):
417
+ # content : [h, t]
418
+ content = content[None,:,:]
419
+ target = F.interpolate(content,size=target_len,mode=mode)[0]
420
+ return target
421
+
422
+
423
  def mix_model(model_paths,mix_rate,mode):
424
  mix_rate = torch.FloatTensor(mix_rate)/100
425
  model_tem = torch.load(model_paths[0])
 
455
  return data2
456
 
457
  def train_index(spk_name,root_dir = "dataset/44k/"): #from: RVC https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI
458
+ n_cpu = cpu_count()
459
  print("The feature index is constructing.")
460
  exp_dir = os.path.join(root_dir,spk_name)
461
  listdir_res = []
 
472
  big_npy_idx = np.arange(big_npy.shape[0])
473
  np.random.shuffle(big_npy_idx)
474
  big_npy = big_npy[big_npy_idx]
475
+ if big_npy.shape[0] > 2e5:
476
+ # if(1):
477
+ info = "Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0]
478
+ print(info)
479
+ try:
480
+ big_npy = (
481
+ MiniBatchKMeans(
482
+ n_clusters=10000,
483
+ verbose=True,
484
+ batch_size=256 * n_cpu,
485
+ compute_labels=False,
486
+ init="random",
487
+ )
488
+ .fit(big_npy)
489
+ .cluster_centers_
490
+ )
491
+ except Exception:
492
+ info = traceback.format_exc()
493
+ print(info)
494
  n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
495
  index = faiss.index_factory(big_npy.shape[1] , "IVF%s,Flat" % n_ivf)
496
  index_ivf = faiss.extract_index_ivf(index) #
 
541
  def get(self,index):
542
  return self.__dict__.get(index)
543
 
544
+
545
+ class InferHParams(HParams):
546
+ def __init__(self, **kwargs):
547
+ for k, v in kwargs.items():
548
+ if type(v) == dict:
549
+ v = InferHParams(**v)
550
+ self[k] = v
551
+
552
+ def __getattr__(self,index):
553
+ return self.get(index)
554
+
555
+
556
  class Volume_Extractor:
557
  def __init__(self, hop_size = 512):
558
  self.hop_size = hop_size
 
563
  n_frames = int(audio.size(-1) // self.hop_size)
564
  audio2 = audio ** 2
565
  audio2 = torch.nn.functional.pad(audio2, (int(self.hop_size // 2), int((self.hop_size + 1) // 2)), mode = 'reflect')
566
+ volume = torch.nn.functional.unfold(audio2[:,None,None,:],(1,self.hop_size),stride=self.hop_size)[:,:,:n_frames].mean(dim=1)[0]
567
  volume = torch.sqrt(volume)
568
+ return volume