onereal kevinwang676 commited on
Commit
0c8b086
·
0 Parent(s):

Duplicate from kevinwang676/Voice-Cloning-for-Bilibili

Browse files

Co-authored-by: Kevin Wang <kevinwang676@users.noreply.huggingface.co>

Files changed (8) hide show
  1. .gitattributes +34 -0
  2. Makefile +11 -0
  3. README.md +14 -0
  4. app.py +300 -0
  5. packages.txt +3 -0
  6. pyproject.toml +17 -0
  7. requirements.txt +6 -0
  8. training_so_vits_svc_fork.ipynb +540 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Makefile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: quality style
2
+
3
+ # Check that source code meets quality standards
4
+ quality:
5
+ black --check --diff .
6
+ ruff .
7
+
8
+ # Format source code automatically
9
+ style:
10
+ black .
11
+ ruff . --fix
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Voice Cloning
3
+ emoji: 😻
4
+ colorFrom: blue
5
+ colorTo: yellow
6
+ sdk: gradio
7
+ sdk_version: 3.27.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ duplicated_from: kevinwang676/Voice-Cloning-for-Bilibili
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ import gradio as gr
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ from demucs.apply import apply_model
11
+ from demucs.pretrained import DEFAULT_MODEL, get_model
12
+ from huggingface_hub import hf_hub_download, list_repo_files
13
+
14
+ from so_vits_svc_fork.hparams import HParams
15
+ from so_vits_svc_fork.inference.core import Svc
16
+
17
+
18
+ ###################################################################
19
+ # REPLACE THESE VALUES TO CHANGE THE MODEL REPO/CKPT NAME/SETTINGS
20
+ ###################################################################
21
+ # The Hugging Face Hub repo ID - 在这里修改repo_id,可替换成任何已经训练好的模型!
22
+ repo_id = "kevinwang676/guesswho"
23
+
24
+ # If None, Uses latest ckpt in the repo
25
+ ckpt_name = None
26
+
27
+ # If None, Uses "kmeans.pt" if it exists in the repo
28
+ cluster_model_name = None
29
+
30
+ # Set the default f0 type to use - use the one it was trained on.
31
+ # The default for so-vits-svc-fork is "dio".
32
+ # Options: "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
33
+ default_f0_method = "crepe"
34
+
35
+ # The default ratio of cluster inference to SVC inference.
36
+ # If cluster_model_name is not found in the repo, this is set to 0.
37
+ default_cluster_infer_ratio = 0.5
38
+
39
+ # Limit on duration of audio at inference time. increase if you can
40
+ # In this parent app, we set the limit with an env var to 30 seconds
41
+ # If you didnt set env var + you go OOM try changing 9e9 to <=300ish
42
+ duration_limit = int(os.environ.get("MAX_DURATION_SECONDS", 9e9))
43
+ ###################################################################
44
+
45
+ # Figure out the latest generator by taking highest value one.
46
+ # Ex. if the repo has: G_0.pth, G_100.pth, G_200.pth, we'd use G_200.pth
47
+ if ckpt_name is None:
48
+ latest_id = sorted(
49
+ [
50
+ int(Path(x).stem.split("_")[1])
51
+ for x in list_repo_files(repo_id)
52
+ if x.startswith("G_") and x.endswith(".pth")
53
+ ]
54
+ )[-1]
55
+ ckpt_name = f"G_{latest_id}.pth"
56
+
57
+ cluster_model_name = cluster_model_name or "kmeans.pt"
58
+ if cluster_model_name in list_repo_files(repo_id):
59
+ print(f"Found Cluster model - Downloading {cluster_model_name} from {repo_id}")
60
+ cluster_model_path = hf_hub_download(repo_id, cluster_model_name)
61
+ else:
62
+ print(f"Could not find {cluster_model_name} in {repo_id}. Using None")
63
+ cluster_model_path = None
64
+ default_cluster_infer_ratio = default_cluster_infer_ratio if cluster_model_path else 0
65
+
66
+ generator_path = hf_hub_download(repo_id, ckpt_name)
67
+ config_path = hf_hub_download(repo_id, "config.json")
68
+ hparams = HParams(**json.loads(Path(config_path).read_text()))
69
+ speakers = list(hparams.spk.keys())
70
+ device = "cuda" if torch.cuda.is_available() else "cpu"
71
+ model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=cluster_model_path)
72
+ demucs_model = get_model(DEFAULT_MODEL)
73
+
74
+
75
+ def extract_vocal_demucs(model, filename, sr=44100, device=None, shifts=1, split=True, overlap=0.25, jobs=0):
76
+ wav, sr = librosa.load(filename, mono=False, sr=sr)
77
+ wav = torch.tensor(wav)
78
+ ref = wav.mean(0)
79
+ wav = (wav - ref.mean()) / ref.std()
80
+ sources = apply_model(
81
+ model, wav[None], device=device, shifts=shifts, split=split, overlap=overlap, progress=True, num_workers=jobs
82
+ )[0]
83
+ sources = sources * ref.std() + ref.mean()
84
+ # We take just the vocals stem. I know the vocals for this model are at index -1
85
+ # If using different model, check model.sources.index('vocals')
86
+ vocal_wav = sources[-1]
87
+ # I did this because its the same normalization the so-vits model required
88
+ vocal_wav = vocal_wav / max(1.01 * vocal_wav.abs().max(), 1)
89
+ vocal_wav = vocal_wav.numpy()
90
+ vocal_wav = librosa.to_mono(vocal_wav)
91
+ vocal_wav = vocal_wav.T
92
+ instrumental_wav = sources[:-1].sum(0).numpy().T
93
+ return vocal_wav, instrumental_wav
94
+
95
+
96
+ def download_youtube_clip(
97
+ video_identifier,
98
+ start_time,
99
+ end_time,
100
+ output_filename,
101
+ num_attempts=5,
102
+ url_base="https://www.youtube.com/watch?v=",
103
+ quiet=False,
104
+ force=False,
105
+ ):
106
+ output_path = Path(output_filename)
107
+ if output_path.exists():
108
+ if not force:
109
+ return output_path
110
+ else:
111
+ output_path.unlink()
112
+
113
+ quiet = "--quiet --no-warnings" if quiet else ""
114
+ command = f"""
115
+ yt-dlp {quiet} -x --audio-format wav -f bestaudio -o "{output_filename}" --download-sections "*{start_time}-{end_time}" "{url_base}{video_identifier}" # noqa: E501
116
+ """.strip()
117
+
118
+ attempts = 0
119
+ while True:
120
+ try:
121
+ _ = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)
122
+ except subprocess.CalledProcessError:
123
+ attempts += 1
124
+ if attempts == num_attempts:
125
+ return None
126
+ else:
127
+ break
128
+
129
+ if output_path.exists():
130
+ return output_path
131
+ else:
132
+ return None
133
+
134
+
135
+ def predict(
136
+ speaker,
137
+ audio,
138
+ transpose: int = 0,
139
+ auto_predict_f0: bool = False,
140
+ cluster_infer_ratio: float = 0,
141
+ noise_scale: float = 0.4,
142
+ f0_method: str = "crepe",
143
+ db_thresh: int = -40,
144
+ pad_seconds: float = 0.5,
145
+ chunk_seconds: float = 0.5,
146
+ absolute_thresh: bool = False,
147
+ ):
148
+ audio, _ = librosa.load(audio, sr=model.target_sample, duration=duration_limit)
149
+ audio = model.infer_silence(
150
+ audio.astype(np.float32),
151
+ speaker=speaker,
152
+ transpose=transpose,
153
+ auto_predict_f0=auto_predict_f0,
154
+ cluster_infer_ratio=cluster_infer_ratio,
155
+ noise_scale=noise_scale,
156
+ f0_method=f0_method,
157
+ db_thresh=db_thresh,
158
+ pad_seconds=pad_seconds,
159
+ chunk_seconds=chunk_seconds,
160
+ absolute_thresh=absolute_thresh,
161
+ )
162
+ return model.target_sample, audio
163
+
164
+
165
+ def predict_song_from_yt(
166
+ ytid_or_url,
167
+ start,
168
+ end,
169
+ speaker=speakers[0],
170
+ transpose: int = 0,
171
+ auto_predict_f0: bool = False,
172
+ cluster_infer_ratio: float = 0,
173
+ noise_scale: float = 0.4,
174
+ f0_method: str = "dio",
175
+ db_thresh: int = -40,
176
+ pad_seconds: float = 0.5,
177
+ chunk_seconds: float = 0.5,
178
+ absolute_thresh: bool = False,
179
+ ):
180
+ end = min(start + duration_limit, end)
181
+ original_track_filepath = download_youtube_clip(
182
+ ytid_or_url,
183
+ start,
184
+ end,
185
+ "track.wav",
186
+ force=True,
187
+ url_base="" if ytid_or_url.startswith("http") else "https://www.youtube.com/watch?v=",
188
+ )
189
+ vox_wav, inst_wav = extract_vocal_demucs(demucs_model, original_track_filepath)
190
+ if transpose != 0:
191
+ inst_wav = librosa.effects.pitch_shift(inst_wav.T, sr=model.target_sample, n_steps=transpose).T
192
+ cloned_vox = model.infer_silence(
193
+ vox_wav.astype(np.float32),
194
+ speaker=speaker,
195
+ transpose=transpose,
196
+ auto_predict_f0=auto_predict_f0,
197
+ cluster_infer_ratio=cluster_infer_ratio,
198
+ noise_scale=noise_scale,
199
+ f0_method=f0_method,
200
+ db_thresh=db_thresh,
201
+ pad_seconds=pad_seconds,
202
+ chunk_seconds=chunk_seconds,
203
+ absolute_thresh=absolute_thresh,
204
+ )
205
+ full_song = inst_wav + np.expand_dims(cloned_vox, 1)
206
+ return (model.target_sample, full_song), (model.target_sample, cloned_vox)
207
+
208
+ SPACE_ID = "nateraw/voice-cloning"
209
+ description = f"""
210
+
211
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" alt="Duplicate Space"></a></center>
212
+
213
+ #### This app uses models trained with [so-vits-svc-fork](https://github.com/voicepaw/so-vits-svc-fork) to clone a voice. Model currently being used is https://hf.co/{repo_id}. To change the model being served, duplicate the space and update the `repo_id`/other settings in `app.py`.
214
+
215
+ #### Train Your Own: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nateraw/voice-cloning/blob/main/training_so_vits_svc_fork.ipynb)
216
+ """.strip()
217
+
218
+ article = """
219
+ <p style='text-align: center'> 注意❗:请不要生成会对个人以及组织造成侵害的内容,此程序仅供科研、学习及个人娱乐使用。用户生成内容与程序开发者无关,请自觉合法合规使用,违反者一切后果自负。
220
+ </p>
221
+ """.strip()
222
+
223
+
224
+ interface_mic = gr.Interface(
225
+ predict,
226
+ inputs=[
227
+ gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
228
+ gr.Audio(type="filepath", source="microphone", label="Source Audio"),
229
+ gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
230
+ gr.Checkbox(False, label="Auto Predict F0"),
231
+ gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
232
+ gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
233
+ gr.Dropdown(
234
+ choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
235
+ value=default_f0_method,
236
+ label="f0 method",
237
+ ),
238
+ ],
239
+ outputs="audio",
240
+ title="🥳🎶🎡 - AI歌手:可从网址直接上传素材,且无需分离背景音",
241
+ description=description,
242
+ article=article,
243
+ )
244
+ interface_file = gr.Interface(
245
+ predict,
246
+ inputs=[
247
+ gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
248
+ gr.Audio(type="filepath", source="upload", label="Source Audio"),
249
+ gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
250
+ gr.Checkbox(False, label="Auto Predict F0"),
251
+ gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
252
+ gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
253
+ gr.Dropdown(
254
+ choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
255
+ value=default_f0_method,
256
+ label="f0 method",
257
+ ),
258
+ ],
259
+ outputs="audio",
260
+ title="🥳🎶🎡 - AI歌手:可从网址直接上传素材,且无需分离背景音",
261
+ description=description,
262
+ article=article,
263
+ )
264
+ interface_yt = gr.Interface(
265
+ predict_song_from_yt,
266
+ inputs=[
267
+ gr.Textbox(
268
+ label="Bilibili网址", info="请填写含有您喜欢的声音的Bilibili网址"
269
+ ),
270
+ gr.Number(value=0, label="Start Time (seconds)"),
271
+ gr.Number(value=15, label="End Time (seconds)"),
272
+ gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
273
+ gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
274
+ gr.Checkbox(False, label="Auto Predict F0"),
275
+ gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
276
+ gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
277
+ gr.Dropdown(
278
+ choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
279
+ value=default_f0_method,
280
+ label="f0 method",
281
+ ),
282
+ ],
283
+ outputs=["audio", "audio"],
284
+ title="🥳🎶🎡 - AI歌手:可从网址直接上传素材,且无需分离背景音",
285
+ description=description,
286
+ article=article,
287
+ examples=[
288
+ ["COz9lDCFHjw", 75, 90, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
289
+ ["dQw4w9WgXcQ", 21, 35, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
290
+ ["Wvm5GuDfAas", 15, 30, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
291
+ ],
292
+ )
293
+ interface = gr.TabbedInterface(
294
+ [interface_mic, interface_file, interface_yt],
295
+ ["从麦克风上传", "从文件上传", "从Bilibili上传"],
296
+ )
297
+
298
+
299
+ if __name__ == "__main__":
300
+ interface.launch()
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ffmpeg
2
+ x264
3
+ libx264-dev
pyproject.toml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.black]
2
+ line-length = 119
3
+ target_version = ['py37']
4
+
5
+ [tool.ruff]
6
+ # Never enforce `E501` (line length violations).
7
+ ignore = ["C901", "E501", "E741", "W605"]
8
+ select = ["C", "E", "F", "I", "W"]
9
+ line-length = 119
10
+
11
+ # Ignore import violations in all `__init__.py` files.
12
+ [tool.ruff.per-file-ignores]
13
+ "__init__.py" = ["E402", "F401", "F403", "F811"]
14
+
15
+ [tool.ruff.isort]
16
+ known-first-party = ["so_vits_svc_fork"]
17
+ lines-after-imports = 2
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ so-vits-svc-fork
2
+ gradio
3
+ huggingface_hub
4
+ yt-dlp
5
+ demucs
6
+ gradio
training_so_vits_svc_fork.ipynb ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/nateraw/voice-cloning/blob/main/training_so_vits_svc_fork.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {
17
+ "id": "jIcNJ5QfDsV_"
18
+ },
19
+ "outputs": [],
20
+ "source": [
21
+ "# %%capture\n",
22
+ "! pip install git+https://github.com/nateraw/so-vits-svc-fork@main\n",
23
+ "! pip install openai-whisper yt-dlp huggingface_hub demucs"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "markdown",
28
+ "metadata": {
29
+ "id": "6uZAhUPOhFv9"
30
+ },
31
+ "source": [
32
+ "---\n",
33
+ "\n",
34
+ "# Restart runtime\n",
35
+ "\n",
36
+ "After running the cell above, you'll need to restart the Colab runtime because we installed a different version of numpy.\n",
37
+ "\n",
38
+ "`Runtime -> Restart runtime`\n",
39
+ "\n",
40
+ "---"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": null,
46
+ "metadata": {
47
+ "id": "DROusQatF-wF"
48
+ },
49
+ "outputs": [],
50
+ "source": [
51
+ "from huggingface_hub import login\n",
52
+ "\n",
53
+ "login()"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "markdown",
58
+ "source": [
59
+ "## Settings"
60
+ ],
61
+ "metadata": {
62
+ "id": "yOM9WWmmRqTA"
63
+ }
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {
69
+ "id": "5oTDjDEKFz3W"
70
+ },
71
+ "outputs": [],
72
+ "source": [
73
+ "CHARACTER = \"kanye\"\n",
74
+ "DO_EXTRACT_VOCALS = False\n",
75
+ "MODEL_REPO_ID = \"dog/kanye\""
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "metadata": {
81
+ "id": "BFd_ly1P_5Ht"
82
+ },
83
+ "source": [
84
+ "## Data Preparation\n",
85
+ "\n",
86
+ "Prepare a data.csv file here with `ytid,start,end` as the first line (they're the expected column names). Then, prepare a training set given YouTube IDs and their start and end segment times in seconds. Try to pick segments that have dry vocal only, as that'll provide the best results.\n",
87
+ "\n",
88
+ "An example is given below for Kanye West."
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {
95
+ "id": "rBrtgDtWmhRb"
96
+ },
97
+ "outputs": [],
98
+ "source": [
99
+ "%%writefile data.csv\n",
100
+ "ytid,start,end\n",
101
+ "lkK4de9nbzQ,0,137\n",
102
+ "gXU9Am2Seo0,30,69\n",
103
+ "gXU9Am2Seo0,94,135\n",
104
+ "iVgrhWvQpqU,0,55\n",
105
+ "iVgrhWvQpqU,58,110\n",
106
+ "UIV-q-gneKA,85,99\n",
107
+ "UIV-q-gneKA,110,125\n",
108
+ "UIV-q-gneKA,127,141\n",
109
+ "UIV-q-gneKA,173,183\n",
110
+ "GmlyYCGE9ak,0,102\n",
111
+ "x-7aWcPmJ60,25,43\n",
112
+ "x-7aWcPmJ60,47,72\n",
113
+ "x-7aWcPmJ60,98,113\n",
114
+ "DK2LCIzIBrU,0,56\n",
115
+ "DK2LCIzIBrU,80,166\n",
116
+ "_W56nZk0fCI,184,224"
117
+ ]
118
+ },
119
+ {
120
+ "cell_type": "code",
121
+ "execution_count": null,
122
+ "metadata": {
123
+ "id": "cxxp4uYoC0aG"
124
+ },
125
+ "outputs": [],
126
+ "source": [
127
+ "import subprocess\n",
128
+ "from pathlib import Path\n",
129
+ "import librosa\n",
130
+ "from scipy.io import wavfile\n",
131
+ "import numpy as np\n",
132
+ "from demucs.pretrained import get_model, DEFAULT_MODEL\n",
133
+ "from demucs.apply import apply_model\n",
134
+ "import torch\n",
135
+ "import csv\n",
136
+ "import whisper\n",
137
+ "\n",
138
+ "\n",
139
+ "def download_youtube_clip(video_identifier, start_time, end_time, output_filename, num_attempts=5, url_base=\"https://www.youtube.com/watch?v=\"):\n",
140
+ " status = False\n",
141
+ "\n",
142
+ " output_path = Path(output_filename)\n",
143
+ " if output_path.exists():\n",
144
+ " return True, \"Already Downloaded\"\n",
145
+ "\n",
146
+ " command = f\"\"\"\n",
147
+ " yt-dlp --quiet --no-warnings -x --audio-format wav -f bestaudio -o \"{output_filename}\" --download-sections \"*{start_time}-{end_time}\" \"{url_base}{video_identifier}\"\n",
148
+ " \"\"\".strip()\n",
149
+ "\n",
150
+ " attempts = 0\n",
151
+ " while True:\n",
152
+ " try:\n",
153
+ " output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)\n",
154
+ " except subprocess.CalledProcessError as err:\n",
155
+ " attempts += 1\n",
156
+ " if attempts == num_attempts:\n",
157
+ " return status, err.output\n",
158
+ " else:\n",
159
+ " break\n",
160
+ "\n",
161
+ " status = output_path.exists()\n",
162
+ " return status, \"Downloaded\"\n",
163
+ "\n",
164
+ "\n",
165
+ "def split_long_audio(model, filepaths, character_name, save_dir=\"data_dir\", out_sr=44100):\n",
166
+ " if isinstance(filepaths, str):\n",
167
+ " filepaths = [filepaths]\n",
168
+ "\n",
169
+ " for file_idx, filepath in enumerate(filepaths):\n",
170
+ "\n",
171
+ " save_path = Path(save_dir) / character_name\n",
172
+ " save_path.mkdir(exist_ok=True, parents=True)\n",
173
+ "\n",
174
+ " print(f\"Transcribing file {file_idx}: '{filepath}' to segments...\")\n",
175
+ " result = model.transcribe(filepath, word_timestamps=True, task=\"transcribe\", beam_size=5, best_of=5)\n",
176
+ " segments = result['segments']\n",
177
+ " \n",
178
+ " wav, sr = librosa.load(filepath, sr=None, offset=0, duration=None, mono=True)\n",
179
+ " wav, _ = librosa.effects.trim(wav, top_db=20)\n",
180
+ " peak = np.abs(wav).max()\n",
181
+ " if peak > 1.0:\n",
182
+ " wav = 0.98 * wav / peak\n",
183
+ " wav2 = librosa.resample(wav, orig_sr=sr, target_sr=out_sr)\n",
184
+ " wav2 /= max(wav2.max(), -wav2.min())\n",
185
+ "\n",
186
+ " for i, seg in enumerate(segments):\n",
187
+ " start_time = seg['start']\n",
188
+ " end_time = seg['end']\n",
189
+ " wav_seg = wav2[int(start_time * out_sr):int(end_time * out_sr)]\n",
190
+ " wav_seg_name = f\"{character_name}_{file_idx}_{i}.wav\"\n",
191
+ " out_fpath = save_path / wav_seg_name\n",
192
+ " wavfile.write(out_fpath, rate=out_sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16))\n",
193
+ "\n",
194
+ "\n",
195
+ "def extract_vocal_demucs(model, filename, out_filename, sr=44100, device=None, shifts=1, split=True, overlap=0.25, jobs=0):\n",
196
+ " wav, sr = librosa.load(filename, mono=False, sr=sr)\n",
197
+ " wav = torch.tensor(wav)\n",
198
+ " ref = wav.mean(0)\n",
199
+ " wav = (wav - ref.mean()) / ref.std()\n",
200
+ " sources = apply_model(\n",
201
+ " model,\n",
202
+ " wav[None],\n",
203
+ " device=device,\n",
204
+ " shifts=shifts,\n",
205
+ " split=split,\n",
206
+ " overlap=overlap,\n",
207
+ " progress=True,\n",
208
+ " num_workers=jobs\n",
209
+ " )[0]\n",
210
+ " sources = sources * ref.std() + ref.mean()\n",
211
+ "\n",
212
+ " wav = sources[-1]\n",
213
+ " wav = wav / max(1.01 * wav.abs().max(), 1)\n",
214
+ " wavfile.write(out_filename, rate=sr, data=wav.numpy().T)\n",
215
+ " return out_filename\n",
216
+ "\n",
217
+ "\n",
218
+ "def create_dataset(\n",
219
+ " clips_csv_filepath = \"data.csv\",\n",
220
+ " character = \"somebody\",\n",
221
+ " do_extract_vocals = False,\n",
222
+ " whisper_size = \"medium\",\n",
223
+ " # Where raw yt clips will be downloaded to\n",
224
+ " dl_dir = \"downloads\",\n",
225
+ " # Where actual data will be organized\n",
226
+ " data_dir = \"dataset_raw\",\n",
227
+ " **kwargs\n",
228
+ "):\n",
229
+ " dl_path = Path(dl_dir) / character\n",
230
+ " dl_path.mkdir(exist_ok=True, parents=True)\n",
231
+ " if do_extract_vocals:\n",
232
+ " demucs_model = get_model(DEFAULT_MODEL)\n",
233
+ "\n",
234
+ " with Path(clips_csv_filepath).open() as f:\n",
235
+ " reader = csv.DictReader(f)\n",
236
+ " for i, row in enumerate(reader):\n",
237
+ " outfile_path = dl_path / f\"{character}_{i:04d}.wav\"\n",
238
+ " download_youtube_clip(row['ytid'], row['start'], row['end'], outfile_path)\n",
239
+ " if do_extract_vocals:\n",
240
+ " extract_vocal_demucs(demucs_model, outfile_path, outfile_path)\n",
241
+ "\n",
242
+ " filenames = sorted([str(x) for x in dl_path.glob(\"*.wav\")])\n",
243
+ " whisper_model = whisper.load_model(whisper_size)\n",
244
+ " split_long_audio(whisper_model, filenames, character, data_dir) "
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {
251
+ "id": "D9GrcDUKEGro"
252
+ },
253
+ "outputs": [],
254
+ "source": [
255
+ "\"\"\"\n",
256
+ "Here, we override config to have num_workers=0 because\n",
257
+ "of a limitation in HF Spaces Docker /dev/shm.\n",
258
+ "\"\"\"\n",
259
+ "\n",
260
+ "import json\n",
261
+ "from pathlib import Path\n",
262
+ "import multiprocessing\n",
263
+ "\n",
264
+ "def update_config(config_file=\"configs/44k/config.json\"):\n",
265
+ " config_path = Path(config_file)\n",
266
+ " data = json.loads(config_path.read_text())\n",
267
+ " data['train']['batch_size'] = 32\n",
268
+ " data['train']['eval_interval'] = 500\n",
269
+ " data['train']['num_workers'] = multiprocessing.cpu_count()\n",
270
+ " data['train']['persistent_workers'] = True\n",
271
+ " data['train']['push_to_hub'] = True\n",
272
+ " data['train']['repo_id'] = MODEL_REPO_ID # tuple(data['spk'])[0]\n",
273
+ " data['train']['private'] = True\n",
274
+ " config_path.write_text(json.dumps(data, indent=2, sort_keys=False))"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "source": [
280
+ "## Run all Preprocessing Steps"
281
+ ],
282
+ "metadata": {
283
+ "id": "aF6OZkTZRzhj"
284
+ }
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "metadata": {
290
+ "id": "OAPnD3xKD_Gw"
291
+ },
292
+ "outputs": [],
293
+ "source": [
294
+ "create_dataset(character=CHARACTER, do_extract_vocals=DO_EXTRACT_VOCALS)\n",
295
+ "! svc pre-resample\n",
296
+ "! svc pre-config\n",
297
+ "! svc pre-hubert -fm crepe\n",
298
+ "update_config()"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "source": [
304
+ "## Training"
305
+ ],
306
+ "metadata": {
307
+ "id": "VpyGazF6R3CE"
308
+ }
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {
314
+ "colab": {
315
+ "background_save": true
316
+ },
317
+ "id": "MByHpf_wEByg"
318
+ },
319
+ "outputs": [],
320
+ "source": [
321
+ "from __future__ import annotations\n",
322
+ "\n",
323
+ "import os\n",
324
+ "import re\n",
325
+ "import warnings\n",
326
+ "from logging import getLogger\n",
327
+ "from multiprocessing import cpu_count\n",
328
+ "from pathlib import Path\n",
329
+ "from typing import Any\n",
330
+ "\n",
331
+ "import lightning.pytorch as pl\n",
332
+ "import torch\n",
333
+ "from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator\n",
334
+ "from lightning.pytorch.loggers import TensorBoardLogger\n",
335
+ "from lightning.pytorch.strategies.ddp import DDPStrategy\n",
336
+ "from lightning.pytorch.tuner import Tuner\n",
337
+ "from torch.cuda.amp import autocast\n",
338
+ "from torch.nn import functional as F\n",
339
+ "from torch.utils.data import DataLoader\n",
340
+ "from torch.utils.tensorboard.writer import SummaryWriter\n",
341
+ "\n",
342
+ "import so_vits_svc_fork.f0\n",
343
+ "import so_vits_svc_fork.modules.commons as commons\n",
344
+ "import so_vits_svc_fork.utils\n",
345
+ "\n",
346
+ "from so_vits_svc_fork import utils\n",
347
+ "from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset\n",
348
+ "from so_vits_svc_fork.logger import is_notebook\n",
349
+ "from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator\n",
350
+ "from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss\n",
351
+ "from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch\n",
352
+ "from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn\n",
353
+ "\n",
354
+ "from so_vits_svc_fork.train import VitsLightning, VCDataModule\n",
355
+ "\n",
356
+ "LOG = getLogger(__name__)\n",
357
+ "torch.set_float32_matmul_precision(\"high\")\n",
358
+ "\n",
359
+ "\n",
360
+ "from pathlib import Path\n",
361
+ "\n",
362
+ "from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file\n",
363
+ "\n",
364
+ "# if os.environ.get(\"HF_TOKEN\"):\n",
365
+ "# login(os.environ.get(\"HF_TOKEN\"))\n",
366
+ "\n",
367
+ "\n",
368
+ "class HuggingFacePushCallback(pl.Callback):\n",
369
+ " def __init__(self, repo_id, private=False, every=100):\n",
370
+ " self.repo_id = repo_id\n",
371
+ " self.private = private\n",
372
+ " self.every = every\n",
373
+ "\n",
374
+ " def on_validation_epoch_end(self, trainer, pl_module):\n",
375
+ " self.repo_url = create_repo(\n",
376
+ " repo_id=self.repo_id,\n",
377
+ " exist_ok=True,\n",
378
+ " private=self.private\n",
379
+ " )\n",
380
+ " self.repo_id = self.repo_url.repo_id\n",
381
+ " if pl_module.global_step == 0:\n",
382
+ " return\n",
383
+ " print(f\"\\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...\")\n",
384
+ " model_dir = pl_module.hparams.model_dir\n",
385
+ " upload_folder(\n",
386
+ " repo_id=self.repo_id,\n",
387
+ " folder_path=model_dir,\n",
388
+ " path_in_repo=\".\",\n",
389
+ " commit_message=\"🍻 cheers\",\n",
390
+ " ignore_patterns=[\"*.git*\", \"*README.md*\", \"*__pycache__*\"],\n",
391
+ " )\n",
392
+ " ckpt_pattern = r'^(D_|G_)\\d+\\.pth$'\n",
393
+ " todelete = []\n",
394
+ " repo_ckpts = [x for x in list_repo_files(self.repo_id) if re.match(ckpt_pattern, x) and x not in [\"G_0.pth\", \"D_0.pth\"]]\n",
395
+ " local_ckpts = [x.name for x in Path(model_dir).glob(\"*.pth\") if re.match(ckpt_pattern, x.name)]\n",
396
+ " to_delete = set(repo_ckpts) - set(local_ckpts)\n",
397
+ "\n",
398
+ " for fname in to_delete:\n",
399
+ " print(f\"🗑 Deleting {fname} from repo\")\n",
400
+ " delete_file(fname, self.repo_id)\n",
401
+ "\n",
402
+ "\n",
403
+ "def train(\n",
404
+ " config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False\n",
405
+ "):\n",
406
+ " config_path = Path(config_path)\n",
407
+ " model_path = Path(model_path)\n",
408
+ "\n",
409
+ " hparams = utils.get_backup_hparams(config_path, model_path)\n",
410
+ " utils.ensure_pretrained_model(model_path, hparams.model.get(\"type_\", \"hifi-gan\"))\n",
411
+ "\n",
412
+ " datamodule = VCDataModule(hparams)\n",
413
+ " strategy = (\n",
414
+ " (\n",
415
+ " \"ddp_find_unused_parameters_true\"\n",
416
+ " if os.name != \"nt\"\n",
417
+ " else DDPStrategy(find_unused_parameters=True, process_group_backend=\"gloo\")\n",
418
+ " )\n",
419
+ " if torch.cuda.device_count() > 1\n",
420
+ " else \"auto\"\n",
421
+ " )\n",
422
+ " LOG.info(f\"Using strategy: {strategy}\")\n",
423
+ " \n",
424
+ " callbacks = []\n",
425
+ " if hparams.train.push_to_hub:\n",
426
+ " callbacks.append(HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private))\n",
427
+ " if not is_notebook():\n",
428
+ " callbacks.append(pl.callbacks.RichProgressBar())\n",
429
+ " if callbacks == []:\n",
430
+ " callbacks = None\n",
431
+ "\n",
432
+ " trainer = pl.Trainer(\n",
433
+ " logger=TensorBoardLogger(\n",
434
+ " model_path, \"lightning_logs\", hparams.train.get(\"log_version\", 0)\n",
435
+ " ),\n",
436
+ " # profiler=\"simple\",\n",
437
+ " val_check_interval=hparams.train.eval_interval,\n",
438
+ " max_epochs=hparams.train.epochs,\n",
439
+ " check_val_every_n_epoch=None,\n",
440
+ " precision=\"16-mixed\"\n",
441
+ " if hparams.train.fp16_run\n",
442
+ " else \"bf16-mixed\"\n",
443
+ " if hparams.train.get(\"bf16_run\", False)\n",
444
+ " else 32,\n",
445
+ " strategy=strategy,\n",
446
+ " callbacks=callbacks,\n",
447
+ " benchmark=True,\n",
448
+ " enable_checkpointing=False,\n",
449
+ " )\n",
450
+ " tuner = Tuner(trainer)\n",
451
+ " model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)\n",
452
+ "\n",
453
+ " # automatic batch size scaling\n",
454
+ " batch_size = hparams.train.batch_size\n",
455
+ " batch_split = str(batch_size).split(\"-\")\n",
456
+ " batch_size = batch_split[0]\n",
457
+ " init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])\n",
458
+ " max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])\n",
459
+ " if batch_size == \"auto\":\n",
460
+ " batch_size = \"binsearch\"\n",
461
+ " if batch_size in [\"power\", \"binsearch\"]:\n",
462
+ " model.tuning = True\n",
463
+ " tuner.scale_batch_size(\n",
464
+ " model,\n",
465
+ " mode=batch_size,\n",
466
+ " datamodule=datamodule,\n",
467
+ " steps_per_trial=1,\n",
468
+ " init_val=init_val,\n",
469
+ " max_trials=max_trials,\n",
470
+ " )\n",
471
+ " model.tuning = False\n",
472
+ " else:\n",
473
+ " batch_size = int(batch_size)\n",
474
+ " # automatic learning rate scaling is not supported for multiple optimizers\n",
475
+ " \"\"\"if hparams.train.learning_rate == \"auto\":\n",
476
+ " lr_finder = tuner.lr_find(model)\n",
477
+ " LOG.info(lr_finder.results)\n",
478
+ " fig = lr_finder.plot(suggest=True)\n",
479
+ " fig.savefig(model_path / \"lr_finder.png\")\"\"\"\n",
480
+ "\n",
481
+ " trainer.fit(model, datamodule=datamodule)\n",
482
+ "\n",
483
+ "if __name__ == '__main__':\n",
484
+ " train('configs/44k/config.json', 'logs/44k')"
485
+ ]
486
+ },
487
+ {
488
+ "cell_type": "markdown",
489
+ "source": [
490
+ "## Train Cluster Model"
491
+ ],
492
+ "metadata": {
493
+ "id": "b2vNCDrSR8Xo"
494
+ }
495
+ },
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": null,
499
+ "metadata": {
500
+ "id": "DBBEx-6Y1sOy"
501
+ },
502
+ "outputs": [],
503
+ "source": [
504
+ "! svc train-cluster"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": null,
510
+ "metadata": {
511
+ "id": "y_qYMuNY1tlm"
512
+ },
513
+ "outputs": [],
514
+ "source": [
515
+ "from huggingface_hub import upload_file\n",
516
+ "\n",
517
+ "upload_file(path_or_fileobj=\"/content/logs/44k/kmeans.pt\", repo_id=MODEL_REPO_ID, path_in_repo=\"kmeans.pt\")"
518
+ ]
519
+ }
520
+ ],
521
+ "metadata": {
522
+ "accelerator": "GPU",
523
+ "colab": {
524
+ "machine_shape": "hm",
525
+ "provenance": [],
526
+ "authorship_tag": "ABX9TyOQeFSvxop9rlCaglNlNoXI",
527
+ "include_colab_link": true
528
+ },
529
+ "gpuClass": "premium",
530
+ "kernelspec": {
531
+ "display_name": "Python 3",
532
+ "name": "python3"
533
+ },
534
+ "language_info": {
535
+ "name": "python"
536
+ }
537
+ },
538
+ "nbformat": 4,
539
+ "nbformat_minor": 0
540
+ }