diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..83cfd8dbb643612f79f25d84b65ac7e4b3c4fb7f --- /dev/null +++ b/.gitattributes @@ -0,0 +1,36 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/sync.yml b/.github/workflows/sync.yml new file mode 100644 index 0000000000000000000000000000000000000000..d4204761c085cc04792c6446ae5e7d55633719fc --- /dev/null +++ b/.github/workflows/sync.yml @@ -0,0 +1,26 @@ +name: Sync to Hugging Face Spaces + +on: + push: + branches: + - main + +jobs: + sync: + name: Sync + runs-on: ubuntu-latest + + steps: + - name: Checkout Repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Sync to Hugging Face Spaces + uses: JacobLinCool/huggingface-sync@v1 + with: + github: ${{ secrets.GITHUB_TOKEN }} + user: jacoblincool # Hugging Face username or organization name + space: ZeroRVC # Hugging Face space name + token: ${{ secrets.HF_TOKEN }} # Hugging Face token + configuration: headers.yaml diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..2d30cb63522405061de0db9926bf9afa021e42c6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +.DS_Store +*.pyc +__pycache__ +dist/ +logs/ +separated/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..38bc5a7b8423e62742d0ac3f08527fb11ba20b2d --- /dev/null +++ b/LICENSE @@ -0,0 +1,19 @@ +Copyright (c) 2024 Jacob Lin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5cdec13ed806aa9e69bb6af2806c10796b235ace --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +--- +title: ZeroRVC +emoji: 🎙️ +colorFrom: gray +colorTo: gray +sdk: gradio +sdk_version: 4.37.2 +app_file: app.py +pinned: false +--- + +# ZeroRVC + +Run Retrieval-based Voice Conversion training and inference with ease. + +## Features + +- [x] Dataset Preparation +- [x] Hugging Face Datasets Integration +- [x] Hugging Face Accelerate Integration +- [x] Trainer API +- [x] Inference API + - [ ] Index Support +- [x] Tensorboard Support +- [ ] FP16 Support + +## Dataset Preparation + +ZeroRVC provides a simple API to prepare your dataset for training. You only need to provide the path to your audio files. The feature extraction models will be downloaded automatically, or you can provide your own with the `hubert` and `rmvpe` arguments. + +```py +from datasets import load_dataset +from zerorvc import prepare, RVCTrainer + +dataset = load_dataset("my-audio-dataset") +dataset = prepare(dataset) + +trainer = RVCTrainer( + "my-rvc-model", + dataset_train=dataset["train"], + dataset_test=dataset["test"], +) +trainer.train(epochs=100, batch_size=8, upload="someone/rvc-test-1") +``` + +## Inference + +ZeroRVC provides an easy API to convert your voice with the trained model. + +```py +from zerorvc import RVC +import soundfile as sf + +rvc = RVC.from_pretrained("someone/rvc-test-1") +samples = rvc.convert("test.mp3") +sf.write("output.wav", samples, rvc.sr) +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..faaea632750a2c695143b42f2ec7335937fe4cbf --- /dev/null +++ b/app.py @@ -0,0 +1,49 @@ +import gradio as gr +from app.settings import SettingsTab +from app.tutorial import TutotialTab +from app.dataset import DatasetTab +from app.train import TrainTab +from app.infer import InferenceTab +from app.zero import zero_is_available + +if zero_is_available: + import torch + + torch.backends.cuda.matmul.allow_tf32 = True + + +with gr.Blocks() as app: + gr.Markdown("# ZeroRVC") + gr.Markdown( + "Run Retrieval-based Voice Conversion training and inference on Hugging Face ZeroGPU or locally." + ) + + settings = SettingsTab() + tutorial = TutotialTab() + dataset = DatasetTab() + training = TrainTab() + inference = InferenceTab() + + with gr.Accordion(label="Environment Settings"): + settings.ui() + + with gr.Tabs(): + with gr.Tab(label="Tutorial", id=0): + tutorial.ui() + + with gr.Tab(label="Dataset", id=1): + dataset.ui() + + with gr.Tab(label="Training", id=2): + training.ui() + + with gr.Tab(label="Inference", id=3): + inference.ui() + + settings.build() + tutorial.build() + dataset.build(settings.exp_dir, settings.hf_token) + training.build(settings.exp_dir, settings.hf_token) + inference.build(settings.exp_dir) + + app.launch() diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/app/constants.py b/app/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..a5e1a1e424940db07c890b4827bebbcacac323a5 --- /dev/null +++ b/app/constants.py @@ -0,0 +1,13 @@ +import os +from pathlib import Path + +HF_TOKEN = os.environ.get("HF_TOKEN") + +ROOT_EXP_DIR = Path( + os.environ.get("ROOT_EXP_DIR") + or os.path.join(os.path.dirname(os.path.abspath(__file__)), "../logs") +).resolve() +ROOT_EXP_DIR.mkdir(exist_ok=True, parents=True) + +BATCH_SIZE = int(os.environ.get("BATCH_SIZE") or 8) +TRAINING_EPOCHS = int(os.environ.get("TRAINING_EPOCHS") or 10) diff --git a/app/dataset.py b/app/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d3afa4c204ff7cbc1b2d514da3a773f5a3d72241 --- /dev/null +++ b/app/dataset.py @@ -0,0 +1,225 @@ +import os +import gradio as gr +import zipfile +import tempfile +from zerorvc import prepare +from datasets import load_dataset, load_from_disk +from .constants import ROOT_EXP_DIR, BATCH_SIZE +from .zero import zero +from .model import accelerator + + +def extract_audio_files(zip_file: str, target_dir: str) -> list[str]: + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(target_dir) + + audio_files = [ + os.path.join(target_dir, f) + for f in os.listdir(target_dir) + if f.endswith((".wav", ".mp3", ".ogg")) + ] + if not audio_files: + raise gr.Error("No audio files found at the top level of the zip file") + + return audio_files + + +def make_dataset_from_zip(exp_dir: str, zip_file: str): + if not exp_dir: + exp_dir = tempfile.mkdtemp(dir=ROOT_EXP_DIR) + print(f"Using exp dir: {exp_dir}") + + data_dir = os.path.join(exp_dir, "raw_data") + if not os.path.exists(data_dir): + os.makedirs(data_dir) + extract_audio_files(zip_file, data_dir) + + ds = prepare( + data_dir, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=1, + ) + + return exp_dir, str(ds) + + +@zero(duration=120) +def make_dataset_from_zip_stage_2(exp_dir: str): + data_dir = os.path.join(exp_dir, "raw_data") + ds = prepare( + data_dir, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=2, + ) + return exp_dir, str(ds) + + +def make_dataset_from_zip_stage_3(exp_dir: str): + data_dir = os.path.join(exp_dir, "raw_data") + ds = prepare( + data_dir, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=3, + ) + + dataset = os.path.join(exp_dir, "dataset") + ds.save_to_disk(dataset) + return exp_dir, str(ds) + + +def make_dataset_from_repo(repo: str, hf_token: str): + ds = load_dataset(repo, token=hf_token) + ds = prepare( + ds, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=1, + ) + return str(ds) + + +@zero(duration=120) +def make_dataset_from_repo_stage_2(repo: str, hf_token: str): + ds = load_dataset(repo, token=hf_token) + ds = prepare( + ds, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=2, + ) + return str(ds) + + +def make_dataset_from_repo_stage_3(exp_dir: str, repo: str, hf_token: str): + ds = load_dataset(repo, token=hf_token) + ds = prepare( + ds, + accelerator=accelerator, + batch_size=BATCH_SIZE, + stage=3, + ) + + if not exp_dir: + exp_dir = tempfile.mkdtemp(dir=ROOT_EXP_DIR) + print(f"Using exp dir: {exp_dir}") + + dataset = os.path.join(exp_dir, "dataset") + ds.save_to_disk(dataset) + return exp_dir, str(ds) + + +def use_dataset(exp_dir: str, repo: str, hf_token: str): + gr.Info("Fetching dataset") + ds = load_dataset(repo, token=hf_token) + + if not exp_dir: + exp_dir = tempfile.mkdtemp(dir=ROOT_EXP_DIR) + print(f"Using exp dir: {exp_dir}") + + dataset = os.path.join(exp_dir, "dataset") + ds.save_to_disk(dataset) + return exp_dir, str(ds) + + +def upload_dataset(exp_dir: str, repo: str, hf_token: str): + dataset = os.path.join(exp_dir, "dataset") + if not os.path.exists(dataset): + raise gr.Error("Dataset not found") + + gr.Info("Uploading dataset") + ds = load_from_disk(dataset) + ds.push_to_hub(repo, token=hf_token, private=True) + gr.Info("Dataset uploaded successfully") + + +class DatasetTab: + def __init__(self): + pass + + def ui(self): + gr.Markdown("# Dataset") + gr.Markdown("The suggested dataset size is > 5 minutes of audio.") + + gr.Markdown("## Create Dataset from ZIP") + gr.Markdown( + "Create a dataset by simply upload a zip file containing audio files. The audio files should be at the top level of the zip file." + ) + with gr.Row(): + self.zip_file = gr.File( + label="Upload a zip file containing audio files", + file_types=["zip"], + ) + self.make_ds_from_dir = gr.Button( + value="Create Dataset from ZIP", variant="primary" + ) + + gr.Markdown("## Create Dataset from Dataset Repository") + gr.Markdown( + "You can also create a dataset from any Hugging Face dataset repository that has 'audio' column." + ) + with gr.Row(): + self.repo = gr.Textbox( + label="Hugging Face Dataset Repository", + placeholder="username/dataset-name", + ) + self.make_ds_from_repo = gr.Button( + value="Create Dataset from Repo", variant="primary" + ) + + gr.Markdown("## Sync Preprocessed Dataset") + gr.Markdown( + "After you have preprocessed the dataset, you can upload the dataset to Hugging Face. And fetch it back later directly." + ) + with gr.Row(): + self.preprocessed_repo = gr.Textbox( + label="Hugging Face Dataset Repository", + placeholder="username/dataset-name", + ) + self.fetch_ds = gr.Button(value="Fetch Dataset", variant="primary") + self.upload_ds = gr.Button(value="Upload Dataset", variant="primary") + + self.ds_state = gr.Textbox(label="Dataset Info", lines=5) + + def build(self, exp_dir: gr.Textbox, hf_token: gr.Textbox): + self.make_ds_from_dir.click( + fn=make_dataset_from_zip, + inputs=[exp_dir, self.zip_file], + outputs=[exp_dir, self.ds_state], + ).success( + fn=make_dataset_from_zip_stage_2, + inputs=[exp_dir], + outputs=[exp_dir, self.ds_state], + ).success( + fn=make_dataset_from_zip_stage_3, + inputs=[exp_dir], + outputs=[exp_dir, self.ds_state], + ) + + self.make_ds_from_repo.click( + fn=make_dataset_from_repo, + inputs=[self.repo, hf_token], + outputs=[self.ds_state], + ).success( + fn=make_dataset_from_repo_stage_2, + inputs=[self.repo, hf_token], + outputs=[self.ds_state], + ).success( + fn=make_dataset_from_repo_stage_3, + inputs=[exp_dir, self.repo, hf_token], + outputs=[exp_dir, self.ds_state], + ) + + self.fetch_ds.click( + fn=use_dataset, + inputs=[exp_dir, self.preprocessed_repo, hf_token], + outputs=[exp_dir, self.ds_state], + ) + + self.upload_ds.click( + fn=upload_dataset, + inputs=[exp_dir, self.preprocessed_repo, hf_token], + outputs=[], + ) diff --git a/app/dataset_maker.py b/app/dataset_maker.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf9a4610524475ed37ed6cf4911bb1ddee0fb35 --- /dev/null +++ b/app/dataset_maker.py @@ -0,0 +1,225 @@ +import yt_dlp +import numpy as np +import librosa +import soundfile as sf +import os +import zipfile + + +# Function to download audio from YouTube and save it as a WAV file +def download_youtube_audio(url, audio_name): + ydl_opts = { + "format": "bestaudio/best", + "postprocessors": [ + { + "key": "FFmpegExtractAudio", + "preferredcodec": "wav", + } + ], + "outtmpl": f"youtubeaudio/{audio_name}", # Output template + } + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + ydl.download([url]) + return f"youtubeaudio/{audio_name}.wav" + + +# Function to calculate RMS +def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + out_strides = y.strides + tuple([y.strides[axis]]) + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + return np.sqrt(power) + + +# Slicer class +class Slicer: + def __init__( + self, + sr, + threshold=-40.0, + min_length=5000, + min_interval=300, + hop_size=20, + max_sil_kept=5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] + else: + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + if rms < self.threshold: + if silence_start is None: + silence_start = i + continue + if silence_start is None: + continue + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + if sil_tags[0][0] > 0: + chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) + for i in range(len(sil_tags) - 1): + chunks.append( + self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]) + ) + if sil_tags[-1][1] < total_frames: + chunks.append( + self._apply_slice(waveform, sil_tags[-1][1], total_frames) + ) + return chunks + + +# Function to slice and save audio chunks +def slice_audio(file_path, audio_name): + audio, sr = librosa.load(file_path, sr=None, mono=False) + os.makedirs(f"dataset/{audio_name}", exist_ok=True) + slicer = Slicer( + sr=sr, + threshold=-40, + min_length=5000, + min_interval=500, + hop_size=10, + max_sil_kept=500, + ) + chunks = slicer.slice(audio) + for i, chunk in enumerate(chunks): + if len(chunk.shape) > 1: + chunk = chunk.T + sf.write(f"dataset/{audio_name}/split_{i}.wav", chunk, sr) + return f"dataset/{audio_name}" + + +# Function to zip the dataset directory +def zip_directory(directory_path, audio_name): + zip_file = f"dataset/{audio_name}.zip" + os.makedirs(os.path.dirname(zip_file), exist_ok=True) # Ensure the directory exists + with zipfile.ZipFile(zip_file, "w", zipfile.ZIP_DEFLATED) as zipf: + for root, dirs, files in os.walk(directory_path): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, start=directory_path) + zipf.write(file_path, arcname) + return zip_file + + +# Gradio interface +def process_audio(url, audio_name): + file_path = download_youtube_audio(url, audio_name) + dataset_path = slice_audio(file_path, audio_name) + zip_file = zip_directory(dataset_path, audio_name) + return zip_file, print(f"{zip_file} successfully processed") diff --git a/app/infer.py b/app/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..bf23c29f38ca652be3133be5c6d547cb75cedf1f --- /dev/null +++ b/app/infer.py @@ -0,0 +1,164 @@ +import os +import shutil +import hashlib +from pathlib import Path +from typing import Tuple +from demucs.separate import main as demucs +import gradio as gr +import numpy as np +import soundfile as sf +from zerorvc import RVC +from .zero import zero +from .model import device +import yt_dlp + + +def download_audio(url): + ydl_opts = { + "format": "bestaudio/best", + "outtmpl": "ytdl/%(title)s.%(ext)s", + "postprocessors": [ + { + "key": "FFmpegExtractAudio", + "preferredcodec": "wav", + "preferredquality": "192", + } + ], + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info_dict = ydl.extract_info(url, download=True) + file_path = ydl.prepare_filename(info_dict).rsplit(".", 1)[0] + ".wav" + sample_rate, audio_data = read(file_path) + audio_array = np.asarray(audio_data, dtype=np.int16) + + return sample_rate, audio_array + + +@zero(duration=120) +def infer( + exp_dir: str, original_audio: str, pitch_mod: int, protect: float +) -> Tuple[int, np.ndarray]: + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + if not os.path.exists(checkpoint_dir): + raise gr.Error("Model not found") + + # rename the original audio to the hash + with open(original_audio, "rb") as f: + original_audio_hash = hashlib.md5(f.read()).hexdigest() + ext = Path(original_audio).suffix + original_audio_hashed = os.path.join(exp_dir, f"{original_audio_hash}{ext}") + shutil.copy(original_audio, original_audio_hashed) + + out = os.path.join("separated", "htdemucs", original_audio_hash, "vocals.wav") + if not os.path.exists(out): + demucs( + [ + "--two-stems", + "vocals", + "-d", + str(device), + "-n", + "htdemucs", + original_audio_hashed, + ] + ) + + rvc = RVC.from_pretrained(checkpoint_dir) + samples = rvc.convert(out, pitch_modification=pitch_mod, protect=protect) + file = os.path.join(exp_dir, "infer.wav") + sf.write(file, samples, rvc.sr) + + return file + + +def merge(exp_dir: str, original_audio: str, vocal: Tuple[int, np.ndarray]) -> str: + with open(original_audio, "rb") as f: + original_audio_hash = hashlib.md5(f.read()).hexdigest() + music = os.path.join("separated", "htdemucs", original_audio_hash, "no_vocals.wav") + + tmp = os.path.join(exp_dir, "tmp.wav") + sf.write(tmp, vocal[1], vocal[0]) + + os.system( + f"ffmpeg -i {music} -i {tmp} -filter_complex '[1]volume=2[a];[0][a]amix=inputs=2:duration=first:dropout_transition=2' -ac 2 -y {tmp}.merged.mp3" + ) + + return f"{tmp}.merged.mp3" + + +class InferenceTab: + def __init__(self): + pass + + def ui(self): + gr.Markdown("# Inference") + gr.Markdown( + "After trained model is pruned, you can use it to infer on new music. \n" + "Upload the original audio and adjust the F0 add value to generate the inferred audio." + ) + + with gr.Row(): + self.original_audio = gr.Audio( + label="Upload original audio", + type="filepath", + show_download_button=True, + ) + with gr.Accordion("inference by Link", open=False): + with gr.Row(): + youtube_link = gr.Textbox( + label="Link", + placeholder="Paste the link here", + interactive=True, + ) + with gr.Row(): + gr.Markdown( + "You can paste the link to the video/audio from many sites, check the complete list [here](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)" + ) + with gr.Row(): + download_button = gr.Button("Download!", variant="primary") + download_button.click( + download_audio, [youtube_link], [self.original_audio] + ) + + with gr.Column(): + self.pitch_mod = gr.Slider( + label="Pitch Modification +/-", + minimum=-16, + maximum=16, + step=1, + value=0, + ) + self.protect = gr.Slider( + label="Protect", + minimum=0, + maximum=0.5, + step=0.01, + value=0.33, + ) + + self.infer_btn = gr.Button(value="Infer", variant="primary") + with gr.Row(): + self.infer_output = gr.Audio( + label="Inferred audio", show_download_button=True, format="mp3" + ) + with gr.Row(): + self.merge_output = gr.Audio( + label="Merged audio", show_download_button=True, format="mp3" + ) + + def build(self, exp_dir: gr.Textbox): + self.infer_btn.click( + fn=infer, + inputs=[ + exp_dir, + self.original_audio, + self.pitch_mod, + self.protect, + ], + outputs=[self.infer_output], + ).success( + fn=merge, + inputs=[exp_dir, self.original_audio, self.infer_output], + outputs=[self.merge_output], + ) diff --git a/app/model.py b/app/model.py new file mode 100644 index 0000000000000000000000000000000000000000..5e796adfbd690c9d1c4f10336563624ac68f4c90 --- /dev/null +++ b/app/model.py @@ -0,0 +1,17 @@ +import logging +from accelerate import Accelerator +from zerorvc import load_hubert, load_rmvpe + +logger = logging.getLogger(__name__) + +accelerator = Accelerator() +device = accelerator.device + +logger.info(f"device: {device}") +logger.info(f"mixed_precision: {accelerator.mixed_precision}") + +rmvpe = load_rmvpe(device=device) +logger.info("RMVPE model loaded.") + +hubert = load_hubert(device=device) +logger.info("HuBERT model loaded.") diff --git a/app/settings.py b/app/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..042948688da318273aea2d6cb7101e32b55b5b86 --- /dev/null +++ b/app/settings.py @@ -0,0 +1,26 @@ +import gradio as gr +from .constants import HF_TOKEN + + +class SettingsTab: + def __init__(self): + pass + + def ui(self): + self.exp_dir = gr.Textbox( + label="Temporary Experiment Directory (auto-managed)", + placeholder="It will be auto-generated after setup", + interactive=True, + ) + gr.Markdown( + "### Sync with Hugging Face 🤗\n\nThe access token will be use to upload/download the dataset and model." + ) + self.hf_token = gr.Textbox( + label="Hugging Face Access Token", + placeholder="Paste your Hugging Face access token here (hf_...)", + value=HF_TOKEN, + interactive=True, + ) + + def build(self): + pass diff --git a/app/train.py b/app/train.py new file mode 100644 index 0000000000000000000000000000000000000000..56c0de2392fe8529eca3110fe10e7fbeeed71f45 --- /dev/null +++ b/app/train.py @@ -0,0 +1,169 @@ +import os +import tempfile +import gradio as gr +import torch +from zerorvc import RVCTrainer, pretrained_checkpoints, SynthesizerTrnMs768NSFsid +from zerorvc.trainer import TrainingCheckpoint +from datasets import load_from_disk +from huggingface_hub import snapshot_download +from .zero import zero +from .model import accelerator, device +from .constants import BATCH_SIZE, ROOT_EXP_DIR, TRAINING_EPOCHS + + +@zero(duration=240) +def train_model(exp_dir: str, progress=gr.Progress()): + dataset = os.path.join(exp_dir, "dataset") + if not os.path.exists(dataset): + raise gr.Error("Dataset not found. Please prepare the dataset first.") + + ds = load_from_disk(dataset) + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + trainer = RVCTrainer(checkpoint_dir) + + resume_from = trainer.latest_checkpoint() + if resume_from is None: + resume_from = pretrained_checkpoints() + gr.Info(f"Starting training from pretrained checkpoints.") + else: + gr.Info(f"Resuming training from {resume_from}") + + tqdm = progress.tqdm( + trainer.train( + dataset=ds["train"], + resume_from=resume_from, + batch_size=BATCH_SIZE, + epochs=TRAINING_EPOCHS, + accelerator=accelerator, + ), + total=TRAINING_EPOCHS, + unit="epochs", + desc="Training", + ) + + for ckpt in tqdm: + info = f"Epoch: {ckpt.epoch} loss: (gen: {ckpt.loss_gen:.4f}, fm: {ckpt.loss_fm:.4f}, mel: {ckpt.loss_mel:.4f}, kl: {ckpt.loss_kl:.4f}, disc: {ckpt.loss_disc:.4f})" + print(info) + latest: TrainingCheckpoint = ckpt + + latest.save(trainer.checkpoint_dir) + latest.G.save_pretrained(trainer.checkpoint_dir) + + result = f"{TRAINING_EPOCHS} epochs trained. Latest loss: (gen: {latest.loss_gen:.4f}, fm: {latest.loss_fm:.4f}, mel: {latest.loss_mel:.4f}, kl: {latest.loss_kl:.4f}, disc: {latest.loss_disc:.4f})" + + del trainer + if device.type == "cuda": + torch.cuda.empty_cache() + + return result + + +def upload_model(exp_dir: str, repo: str, hf_token: str): + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + if not os.path.exists(checkpoint_dir): + raise gr.Error("Model not found") + + gr.Info("Uploading model") + model = SynthesizerTrnMs768NSFsid.from_pretrained(checkpoint_dir) + model.push_to_hub(repo, token=hf_token, private=True) + gr.Info("Model uploaded successfully") + + +def upload_checkpoints(exp_dir: str, repo: str, hf_token: str): + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + if not os.path.exists(checkpoint_dir): + raise gr.Error("Checkpoints not found") + + gr.Info("Uploading checkpoints") + trainer = RVCTrainer(checkpoint_dir) + trainer.push_to_hub(repo, token=hf_token, private=True) + gr.Info("Checkpoints uploaded successfully") + + +def fetch_model(exp_dir: str, repo: str, hf_token: str): + if not exp_dir: + exp_dir = tempfile.mkdtemp(dir=ROOT_EXP_DIR) + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + + gr.Info("Fetching model") + files = ["README.md", "config.json", "model.safetensors"] + snapshot_download( + repo, token=hf_token, local_dir=checkpoint_dir, allow_patterns=files + ) + gr.Info("Model fetched successfully") + + return exp_dir + + +def fetch_checkpoints(exp_dir: str, repo: str, hf_token: str): + if not exp_dir: + exp_dir = tempfile.mkdtemp(dir=ROOT_EXP_DIR) + checkpoint_dir = os.path.join(exp_dir, "checkpoints") + + gr.Info("Fetching checkpoints") + snapshot_download(repo, token=hf_token, local_dir=checkpoint_dir) + gr.Info("Checkpoints fetched successfully") + + return exp_dir + + +class TrainTab: + def __init__(self): + pass + + def ui(self): + gr.Markdown("# Training") + gr.Markdown( + "You can start training the model by clicking the button below. " + f"Each time you click the button, the model will train for {TRAINING_EPOCHS} epochs, which takes about 3 minutes on ZeroGPU (A100). " + ) + + with gr.Row(): + self.train_btn = gr.Button(value="Train", variant="primary") + self.result = gr.Textbox(label="Training Result", lines=3) + + gr.Markdown("## Sync Model and Checkpoints with Hugging Face") + gr.Markdown( + "You can upload the trained model and checkpoints to Hugging Face for sharing or further training." + ) + + self.repo = gr.Textbox(label="Repository ID", placeholder="username/repo") + with gr.Row(): + self.upload_model_btn = gr.Button(value="Upload Model", variant="primary") + self.upload_checkpoints_btn = gr.Button( + value="Upload Checkpoints", variant="primary" + ) + with gr.Row(): + self.fetch_mode_btn = gr.Button(value="Fetch Model", variant="primary") + self.fetch_checkpoints_btn = gr.Button( + value="Fetch Checkpoints", variant="primary" + ) + + def build(self, exp_dir: gr.Textbox, hf_token: gr.Textbox): + self.train_btn.click( + fn=train_model, + inputs=[exp_dir], + outputs=[self.result], + ) + + self.upload_model_btn.click( + fn=upload_model, + inputs=[exp_dir, self.repo, hf_token], + ) + + self.upload_checkpoints_btn.click( + fn=upload_checkpoints, + inputs=[exp_dir, self.repo, hf_token], + ) + + self.fetch_mode_btn.click( + fn=fetch_model, + inputs=[exp_dir, self.repo, hf_token], + outputs=[exp_dir], + ) + + self.fetch_checkpoints_btn.click( + fn=fetch_checkpoints, + inputs=[exp_dir, self.repo, hf_token], + outputs=[exp_dir], + ) diff --git a/app/tutorial.py b/app/tutorial.py new file mode 100644 index 0000000000000000000000000000000000000000..e4f2e638bb31ecb98a33a5dae25fdae5e09e43aa --- /dev/null +++ b/app/tutorial.py @@ -0,0 +1,30 @@ +import gradio as gr + + +class TutotialTab: + def __init__(self): + pass + + def ui(self): + gr.Markdown( + """ + # Welcome to ZeroRVC! + + > If you are more satisfied with Python code, you can also [use the Python API to run ZeroRVC](https://pypi.org/project/zerorvc/). + + ZeroRVC is a toolkit for training and inference of retrieval-based voice conversion models. + + By leveraging the power of Hugging Face ZeroGPU, you can train your model in minutes without setting up the environment. + + ## How to Use + + There are 3 main steps to use ZeroRVC: + + - **Make Dataset**: Prepare your dataset for training. You can upload a zip file containing audio files. + - **Model Training**: Train your model using the prepared dataset. + - **Model Inference**: Try your model. + """ + ) + + def build(self): + pass diff --git a/app/zero.py b/app/zero.py new file mode 100644 index 0000000000000000000000000000000000000000..beb233742df21b5cd4f44160ad0ce6f67ce42816 --- /dev/null +++ b/app/zero.py @@ -0,0 +1,24 @@ +import os +import logging + +logger = logging.getLogger(__name__) + +zero_is_available = "SPACES_ZERO_GPU" in os.environ + +if zero_is_available: + import spaces # type: ignore + + logger.info("ZeroGPU is available") +else: + logger.info("ZeroGPU is not available") + + +# a decorator that applies the spaces.GPU decorator if zero is available +def zero(duration=60): + def wrapper(func): + if zero_is_available: + return spaces.GPU(func, duration=duration) + else: + return func + + return wrapper diff --git a/example-dataset.py b/example-dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..acd0083de930fcca2a50ca1f0af64c1b53598df1 --- /dev/null +++ b/example-dataset.py @@ -0,0 +1,9 @@ +import os +from zerorvc import prepare + +HF_TOKEN = os.environ.get("HF_TOKEN") + +dataset = prepare("./my-voices") +print(dataset) + +dataset.push_to_hub("my-rvc-dataset", token=HF_TOKEN, private=True) diff --git a/example-infer.py b/example-infer.py new file mode 100644 index 0000000000000000000000000000000000000000..e4584a3c74a7c49ea2ce38c646db76135cbdf931 --- /dev/null +++ b/example-infer.py @@ -0,0 +1,15 @@ +import os +from zerorvc import RVC +import soundfile as sf + +HF_TOKEN = os.environ.get("HF_TOKEN") +MODEL = "JacobLinCool/my-rvc-model3" + +rvc = RVC.from_pretrained(MODEL, token=HF_TOKEN) +samples = rvc.convert("test.mp3") +sf.write("output.wav", samples, rvc.sr) + +pitch_modifications = [-12, -8, -4, 4, 8, 12] +for pitch_modification in pitch_modifications: + samples = rvc.convert("test.mp3", pitch_modification=pitch_modification) + sf.write(f"output-{pitch_modification}.wav", samples, rvc.sr) diff --git a/example-train.py b/example-train.py new file mode 100644 index 0000000000000000000000000000000000000000..b58e97596aac9e5eac4bb90885bbe1982228071a --- /dev/null +++ b/example-train.py @@ -0,0 +1,38 @@ +import os +from datasets import load_dataset +from tqdm import tqdm +from zerorvc import RVCTrainer, pretrained_checkpoints + +HF_TOKEN = os.environ.get("HF_TOKEN") +EPOCHS = 100 +BATCH_SIZE = 8 +DATASET = "JacobLinCool/my-rvc-dataset" +MODEL = "JacobLinCool/my-rvc-model" + +dataset = load_dataset(DATASET, token=HF_TOKEN) +print(dataset) + +trainer = RVCTrainer(checkpoint_dir="./checkpoints") +training = tqdm( + trainer.train( + dataset=dataset["train"], + resume_from=pretrained_checkpoints(), # resume training from the pretrained VCTK checkpoint + epochs=EPOCHS, + batch_size=BATCH_SIZE, + ), + total=EPOCHS, +) + +# Training loop: iterate over epochs +for checkpoint in training: + training.set_description( + f"Epoch {checkpoint.epoch}/{EPOCHS} loss: (gen: {checkpoint.loss_gen:.4f}, fm: {checkpoint.loss_fm:.4f}, mel: {checkpoint.loss_mel:.4f}, kl: {checkpoint.loss_kl:.4f}, disc: {checkpoint.loss_disc:.4f})" + ) + + # Save checkpoint every 10 epochs + if checkpoint.epoch % 10 == 0: + checkpoint.save(checkpoint_dir=trainer.checkpoint_dir) + # Directly push the synthesizer to the Hugging Face Hub + checkpoint.G.push_to_hub(MODEL, token=HF_TOKEN, private=True) + +print("Training completed.") diff --git a/headers.yaml b/headers.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bc455f032de9d816533598ffe92b73b438f77d2 --- /dev/null +++ b/headers.yaml @@ -0,0 +1,8 @@ +title: ZeroRVC +emoji: 🎙️ +colorFrom: gray +colorTo: gray +sdk: gradio +sdk_version: 4.37.2 +app_file: app.py +pinned: false diff --git a/my-voices/.gitignore b/my-voices/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..d8dd7532abcc65af52e9db03c516274e3d674dc1 --- /dev/null +++ b/my-voices/.gitignore @@ -0,0 +1 @@ +*.wav diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..85834ecd25ab572c921d81044784d8b943e5edf8 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[project] +name = "zerorvc" +version = "0.0.19" +authors = [{ name = "Jacob Lin", email = "jacob@csie.cool" }] +description = "Run Retrieval-based Voice Conversion training and inference with ease." +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [ + "numpy>=1.0.0", + "torch>=2.0.0", + "datasets", + "accelerate", + "huggingface_hub", + "tqdm", + "librosa", + "scipy", + "tensorboard", +] + +[project.urls] +Homepage = "https://github.com/jacoblincool/zero-rvc" +Issues = "https://github.com/jacoblincool/zero-rvc/issues" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.sdist] +include = ["zerorvc/**/*", "pyproject.toml", "README.md", "LICENSE"] +[tool.hatch.build.targets.wheel] +packages = ["zerorvc"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..2d847640f7f3ae4d83f8aae37a90724d04624e69 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +zerorvc>=0.0.10 + +# gradio app deps +gradio +demucs==4.0.1 +yt_dlp +tensorboard diff --git a/zerorvc/__init__.py b/zerorvc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bde8336b218c5218126d3ef8269a0c0dcf83a72f --- /dev/null +++ b/zerorvc/__init__.py @@ -0,0 +1,8 @@ +from .rvc import RVC +from .trainer import RVCTrainer +from .dataset import prepare +from .synthesizer import SynthesizerTrnMs768NSFsid +from .pretrained import pretrained_checkpoints +from .f0 import load_rmvpe, RMVPE, F0Extractor +from .hubert import load_hubert, HubertModel, HubertFeatureExtractor +from .auto_loader import auto_loaded_model diff --git a/zerorvc/assets/mute/mute48k.wav b/zerorvc/assets/mute/mute48k.wav new file mode 100644 index 0000000000000000000000000000000000000000..57e2db6dec3b3546fadbc4094e75d42bc465a1cf --- /dev/null +++ b/zerorvc/assets/mute/mute48k.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2f2bb4daaa106e351aebb001e5a25de985c0b472f22e8d60676bc924a79056ee +size 288078 diff --git a/zerorvc/auto_loader.py b/zerorvc/auto_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..6444958cf846c9646bb68446e7dfbe9f6fdbe07b --- /dev/null +++ b/zerorvc/auto_loader.py @@ -0,0 +1 @@ +auto_loaded_model = {} diff --git a/zerorvc/constants.py b/zerorvc/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..848f2669e1330c56fbb727f4c464908899789bac --- /dev/null +++ b/zerorvc/constants.py @@ -0,0 +1,7 @@ +SR_16K = 16000 +SR_48K = 48000 + +N_FFT = 2048 +HOP_LENGTH = 480 +WIN_LENGTH = 2048 +N_MELS = 128 diff --git a/zerorvc/dataset.py b/zerorvc/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..71866210e2f62f042bc5d38911a9707b0066824d --- /dev/null +++ b/zerorvc/dataset.py @@ -0,0 +1,253 @@ +import os +import numpy as np +import torch +import librosa +import logging +import shutil +from pkg_resources import resource_filename +from accelerate import Accelerator +from datasets import load_dataset, DatasetDict, Dataset, Audio +from .preprocess import Preprocessor, crop_feats_length +from .hubert import HubertFeatureExtractor, HubertModel, load_hubert +from .f0 import F0Extractor, RMVPE, load_rmvpe +from .constants import * + + +logger = logging.getLogger(__name__) + + +def extract_hubert_features( + rows, + hfe: HubertFeatureExtractor, + hubert: str | HubertModel | None, + device: torch.device, +): + if not hfe.is_loaded(): + model = load_hubert(hubert, device) + hfe.load(model) + feats = [] + for row in rows["wav_16k"]: + feat = hfe.extract_feature_from(row["array"].astype("float32")) + feats.append(feat) + return {"hubert_feats": feats} + + +def extract_f0_features( + rows, f0e: F0Extractor, rmvpe: str | RMVPE | None, device: torch.device +): + if not f0e.is_loaded(): + model = load_rmvpe(rmvpe, device) + f0e.load(model) + f0s = [] + f0nsfs = [] + for row in rows["wav_16k"]: + f0nsf, f0 = f0e.extract_f0_from(row["array"].astype("float32")) + f0s.append(f0) + f0nsfs.append(f0nsf) + return {"f0": f0s, "f0nsf": f0nsfs} + + +def feature_postprocess(rows): + phones = rows["hubert_feats"] + for i, phone in enumerate(phones): + phone = np.repeat(phone, 2, axis=0) + n_num = min(phone.shape[0], 900) + phone = phone[:n_num, :] + phones[i] = phone + + if "f0" in rows: + pitch = rows["f0"][i] + pitch = pitch[:n_num] + pitch = np.array(pitch, dtype=np.float32) + rows["f0"][i] = pitch + if "f0nsf" in rows: + pitchf = rows["f0nsf"][i] + pitchf = pitchf[:n_num] + rows["f0nsf"][i] = pitchf + return rows + + +def calculate_spectrogram( + rows, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH +): + specs = [] + hann_window = np.hanning(win_length) + pad_amount = int((win_length - hop_length) / 2) + for row in rows["wav_gt"]: + stft = librosa.stft( + np.pad(row["array"], (pad_amount, pad_amount), mode="reflect"), + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window, + center=False, + ) + specs.append(np.abs(stft) + 1e-6) + + return {"spec": specs} + + +def fix_length(rows, hop_length=HOP_LENGTH): + for i, row in enumerate(rows["spec"]): + spec = np.array(row) + phone = np.array(rows["hubert_feats"][i]) + pitch = np.array(rows["f0"][i]) + pitchf = np.array(rows["f0nsf"][i]) + wav_gt = np.array(rows["wav_gt"][i]["array"]) + + spec, phone, pitch, pitchf = crop_feats_length(spec, phone, pitch, pitchf) + + phone_len = phone.shape[0] + wav_gt = wav_gt[: phone_len * hop_length] + + rows["hubert_feats"][i] = phone + rows["f0"][i] = pitch + rows["f0nsf"][i] = pitchf + rows["spec"][i] = spec + rows["wav_gt"][i]["array"] = wav_gt + return rows + + +def prepare( + dir: str | DatasetDict, + sr=SR_48K, + hubert: str | HubertModel | None = None, + rmvpe: str | RMVPE | None = None, + batch_size=1, + max_slice_length: float | None = 3.0, + accelerator: Accelerator = None, + include_mute=True, + stage=3, +): + """ + Prepare the dataset for training or evaluation. + + Args: + dir (str | DatasetDict): The directory path or DatasetDict object containing the dataset. + sr (int, optional): The target sampling rate. Defaults to SR_48K. + hubert (str | HubertModel | None, optional): The Hubert model or its name to use for feature extraction. Defaults to None. + rmvpe (str | RMVPE | None, optional): The RMVPE model or its name to use for feature extraction. Defaults to None. + batch_size (int, optional): The batch size for processing the dataset. Defaults to 1. + accelerator (Accelerator, optional): The accelerator object for distributed training. Defaults to None. + include_mute (bool, optional): Whether to include a mute audio file in the directory dataset. Defaults to True. + stage (int, optional): The dataset preparation level to perform. Defaults to 3. (Stage 1 and 3 are CPU intensive, Stage 2 is GPU intensive.) + + Returns: + DatasetDict: The prepared dataset. + """ + if accelerator is None: + accelerator = Accelerator() + + if isinstance(dir, (DatasetDict, Dataset)): + ds = dir + else: + mute_source = resource_filename("zerorvc", "assets/mute/mute48k.wav") + mute_dest = os.path.join(dir, "mute.wav") + if include_mute and not os.path.exists(mute_dest): + logger.info(f"Copying {mute_source} to {mute_dest}") + shutil.copy(mute_source, mute_dest) + + ds: DatasetDict | Dataset = load_dataset("audiofolder", data_dir=dir) + + for key in ds: + ds[key] = ds[key].remove_columns( + [col for col in ds[key].column_names if col != "audio"] + ) + ds = ds.cast_column("audio", Audio(sampling_rate=sr)) + + if stage <= 0: + return ds + + # Stage 1, CPU intensive + + pp = Preprocessor(sr, max_slice_length) if max_slice_length is not None else None + + def preprocess(rows): + wav_gt = [] + wav_16k = [] + for row in rows["audio"]: + if pp is not None: + slices = pp.preprocess_audio(row["array"]) + for slice in slices: + wav_gt.append({"path": "", "array": slice, "sampling_rate": sr}) + slice16k = librosa.resample(slice, orig_sr=sr, target_sr=SR_16K) + wav_16k.append( + {"path": "", "array": slice16k, "sampling_rate": SR_16K} + ) + else: + slice = row["array"] + wav_gt.append({"path": "", "array": slice, "sampling_rate": sr}) + slice16k = librosa.resample(slice, orig_sr=sr, target_sr=SR_16K) + wav_16k.append({"path": "", "array": slice16k, "sampling_rate": SR_16K}) + return {"wav_gt": wav_gt, "wav_16k": wav_16k} + + ds = ds.map( + preprocess, batched=True, batch_size=batch_size, remove_columns=["audio"] + ) + ds = ds.cast_column("wav_gt", Audio(sampling_rate=sr)) + ds = ds.cast_column("wav_16k", Audio(sampling_rate=SR_16K)) + + if stage <= 1: + return ds + + # Stage 2, GPU intensive + + hfe = HubertFeatureExtractor() + ds = ds.map( + extract_hubert_features, + batched=True, + batch_size=batch_size, + fn_kwargs={"hfe": hfe, "hubert": hubert, "device": accelerator.device}, + ) + + f0e = F0Extractor() + ds = ds.map( + extract_f0_features, + batched=True, + batch_size=batch_size, + fn_kwargs={"f0e": f0e, "rmvpe": rmvpe, "device": accelerator.device}, + ) + + if stage <= 2: + return ds + + # Stage 3, CPU intensive + + ds = ds.map(feature_postprocess, batched=True, batch_size=batch_size) + ds = ds.map(calculate_spectrogram, batched=True, batch_size=batch_size) + ds = ds.map(fix_length, batched=True, batch_size=batch_size) + + return ds + + +def show_dataset_pitch_distribution(dataset): + import matplotlib.pyplot as plt + import seaborn as sns + import numpy as np + + sns.set_theme() + pitches = [] + for row in dataset["f0"]: + pitches.extend([p for p in row if p != 1]) + + pitches = np.array(pitches) + stats = { + "mean": np.mean(pitches), + "std": np.std(pitches), + "min": np.min(pitches), + "max": np.max(pitches), + "median": np.median(pitches), + "q1": np.percentile(pitches, 25), + "q3": np.percentile(pitches, 75), + } + + plt.figure(figsize=(10, 6)) + sns.histplot(pitches, bins=100) + plt.title( + f"Pitch Distribution\nMean: {stats['mean']:.1f} ± {stats['std']:.1f}\n" + f"Range: [{stats['min']:.1f}, {stats['max']:.1f}]\n" + f"Quartiles: [{stats['q1']:.1f}, {stats['median']:.1f}, {stats['q3']:.1f}]" + ) + plt.xlabel("Frequency (Note)") + plt.ylabel("Count") + plt.show() diff --git a/zerorvc/f0/__init__.py b/zerorvc/f0/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a92f0ff51d91fdad24695ab28577acb1edcbf57a --- /dev/null +++ b/zerorvc/f0/__init__.py @@ -0,0 +1,3 @@ +from .extractor import F0Extractor +from .rmvpe import RMVPE +from .load import load_rmvpe diff --git a/zerorvc/f0/extractor.py b/zerorvc/f0/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..06bf0bda96f0fe6e190ad0918d948c4284ee2ee8 --- /dev/null +++ b/zerorvc/f0/extractor.py @@ -0,0 +1,65 @@ +import logging +import numpy as np +import librosa +from .rmvpe import RMVPE +from ..constants import SR_16K + +logger = logging.getLogger(__name__) + + +class F0Extractor: + def __init__( + self, + rmvpe: RMVPE = None, + sr=SR_16K, + f0_bin=256, + f0_max=1100.0, + f0_min=50.0, + ): + self.sr = sr + self.f0_bin = f0_bin + self.f0_max = f0_max + self.f0_min = f0_min + self.f0_mel_min = 1127 * np.log(1 + f0_min / 700) + self.f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + if rmvpe is not None: + self.load(rmvpe) + + def load(self, rmvpe: RMVPE): + self.rmvpe = rmvpe + self.device = next(rmvpe.parameters()).device + logger.info(f"RMVPE model is on {self.device}") + + def is_loaded(self) -> bool: + return hasattr(self, "rmvpe") + + def calculate_f0_from_f0nsf(self, f0nsf: np.ndarray): + f0_mel = 1127 * np.log(1 + f0nsf / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * ( + self.f0_bin - 2 + ) / (self.f0_mel_max - self.f0_mel_min) + 1 + + # use 0 or 1 + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > self.f0_bin - 1] = self.f0_bin - 1 + f0 = np.rint(f0_mel).astype(int) + assert f0.max() <= 255 and f0.min() >= 1, ( + f0.max(), + f0.min(), + ) + + return f0 + + def extract_f0_from(self, y: np.ndarray, modification=0.0): + f0nsf = self.rmvpe.infer_from_audio(y, thred=0.03) + + f0nsf *= pow(2, modification / 12) + + f0 = self.calculate_f0_from_f0nsf(f0nsf) + + return f0nsf, f0 + + def extract_f0(self, wav_file: str): + y, _ = librosa.load(wav_file, sr=self.sr) + return self.extract_f0_from(y) diff --git a/zerorvc/f0/load.py b/zerorvc/f0/load.py new file mode 100644 index 0000000000000000000000000000000000000000..1000ce458e805df9200313c417f4609424bfc942 --- /dev/null +++ b/zerorvc/f0/load.py @@ -0,0 +1,37 @@ +import torch +from huggingface_hub import hf_hub_download +from .rmvpe import RMVPE +from ..auto_loader import auto_loaded_model + + +def load_rmvpe( + rmvpe: str | RMVPE | None = None, device: torch.device = torch.device("cpu") +) -> RMVPE: + """ + Load the RMVPE model from a file or download it if necessary. + If a loaded model is provided, it will be returned as is. + + Args: + rmvpe (str | RMVPE | None): The path to the RMVPE model file or the pre-loaded RMVPE model. If None, the default model will be downloaded. + device (torch.device): The device to load the model on. + + Returns: + RMVPE: The loaded RMVPE model. + + Raises: + If the model file does not exist. + """ + if isinstance(rmvpe, RMVPE): + return rmvpe.to(device) + if isinstance(rmvpe, str): + model = RMVPE(4, 1, (2, 2)) + model.load_state_dict(torch.load(rmvpe, map_location=device, weights_only=True)) + model.to(device) + return model + if "rmvpe" not in auto_loaded_model: + rmvpe = hf_hub_download("lj1995/VoiceConversionWebUI", "rmvpe.pt") + model = RMVPE(4, 1, (2, 2)) + model.load_state_dict(torch.load(rmvpe, map_location="cpu", weights_only=True)) + model.to(device) + auto_loaded_model["rmvpe"] = model + return auto_loaded_model["rmvpe"] diff --git a/zerorvc/f0/rmvpe/__init__.py b/zerorvc/f0/rmvpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20ea6afd1fed4926d966e8325d16f72e1455f913 --- /dev/null +++ b/zerorvc/f0/rmvpe/__init__.py @@ -0,0 +1,6 @@ +# The RMVPE model is from https://github.com/Dream-High/RMVPE +# Apache License 2.0: https://github.com/Dream-High/RMVPE/blob/main/LICENSE +# With modifications from https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/main/infer/lib/rmvpe.py +# MIT License: https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/main/LICENSE + +from .model import RMVPE diff --git a/zerorvc/f0/rmvpe/constants.py b/zerorvc/f0/rmvpe/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..136e3925c56e0ea34396a44131396649fb82ba6b --- /dev/null +++ b/zerorvc/f0/rmvpe/constants.py @@ -0,0 +1,8 @@ +N_CLASS = 360 +N_MELS = 128 +MAGIC_CONST = 1997.3794084376191 +SAMPLE_RATE = 16000 +WINDOW_LENGTH = 1024 +HOP_LENGTH = 160 +MEL_FMIN = 30 +MEL_FMAX = SAMPLE_RATE // 2 diff --git a/zerorvc/f0/rmvpe/deepunet.py b/zerorvc/f0/rmvpe/deepunet.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4fb71cc4d1e35421fa24ae218452234c4c89d4 --- /dev/null +++ b/zerorvc/f0/rmvpe/deepunet.py @@ -0,0 +1,228 @@ +from typing import List +import torch +from torch import nn +from .constants import * + + +class ConvBlockRes(nn.Module): + def __init__(self, in_channels: int, out_channels: int, momentum=0.01): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + # self.shortcut:Optional[nn.Module] = None + if in_channels != out_channels: + self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1)) + + def forward(self, x: torch.Tensor): + if not hasattr(self, "shortcut"): + return self.conv(x) + x + else: + return self.conv(x) + self.shortcut(x) + + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int, + in_size: int, + n_encoders: int, + kernel_size: int, + n_blocks: int, + out_channels=16, + momentum=0.01, + ): + super().__init__() + self.n_encoders = n_encoders + self.bn = nn.BatchNorm2d(in_channels, momentum=momentum) + self.layers = nn.ModuleList() + self.latent_channels = [] + for i in range(self.n_encoders): + self.layers.append( + ResEncoderBlock( + in_channels, out_channels, kernel_size, n_blocks, momentum=momentum + ) + ) + self.latent_channels.append([out_channels, in_size]) + in_channels = out_channels + out_channels *= 2 + in_size //= 2 + self.out_size = in_size + self.out_channel = out_channels + + def forward(self, x: torch.Tensor): + concat_tensors: List[torch.Tensor] = [] + x = self.bn(x) + for i, layer in enumerate(self.layers): + t, x = layer(x) + concat_tensors.append(t) + return x, concat_tensors + + +class ResEncoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + n_blocks=1, + momentum=0.01, + ): + super().__init__() + self.n_blocks = n_blocks + self.conv = nn.ModuleList() + self.conv.append(ConvBlockRes(in_channels, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv.append(ConvBlockRes(out_channels, out_channels, momentum)) + self.kernel_size = kernel_size + if self.kernel_size is not None: + self.pool = nn.AvgPool2d(kernel_size=kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, conv in enumerate(self.conv): + x = conv(x) + if self.kernel_size is not None: + return x, self.pool(x) + else: + return x + + +class Intermediate(nn.Module): # + def __init__( + self, + in_channels: int, + out_channels: int, + n_inters: int, + n_blocks: int, + momentum=0.01, + ): + super().__init__() + self.n_inters = n_inters + self.layers = nn.ModuleList() + self.layers.append( + ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum) + ) + for i in range(self.n_inters - 1): + self.layers.append( + ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = layer(x) + return x + + +class ResDecoderBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + stride: int, + n_blocks=1, + momentum=0.01, + ): + super().__init__() + out_padding = (0, 1) if stride == (1, 2) else (1, 1) + self.n_blocks = n_blocks + self.conv1 = nn.Sequential( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(3, 3), + stride=stride, + padding=(1, 1), + output_padding=out_padding, + bias=False, + ), + nn.BatchNorm2d(out_channels, momentum=momentum), + nn.ReLU(), + ) + self.conv2 = nn.ModuleList() + self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum)) + for i in range(n_blocks - 1): + self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum)) + + def forward(self, x: torch.Tensor, concat_tensor: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = torch.cat((x, concat_tensor), dim=1) + for i, conv2 in enumerate(self.conv2): + x = conv2(x) + return x + + +class Decoder(nn.Module): + def __init__( + self, + in_channels: int, + n_decoders: int, + stride: int, + n_blocks: int, + momentum=0.01, + ): + super().__init__() + self.layers = nn.ModuleList() + self.n_decoders = n_decoders + for i in range(self.n_decoders): + out_channels = in_channels // 2 + self.layers.append( + ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum) + ) + in_channels = out_channels + + def forward( + self, x: torch.Tensor, concat_tensors: List[torch.Tensor] + ) -> torch.Tensor: + for i, layer in enumerate(self.layers): + x = layer(x, concat_tensors[-1 - i]) + return x + + +class DeepUnet(nn.Module): + def __init__( + self, + kernel_size: int, + n_blocks: int, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super().__init__() + self.encoder = Encoder( + in_channels, N_MELS, en_de_layers, kernel_size, n_blocks, en_out_channels + ) + self.intermediate = Intermediate( + self.encoder.out_channel // 2, + self.encoder.out_channel, + inter_layers, + n_blocks, + ) + self.decoder = Decoder( + self.encoder.out_channel, en_de_layers, kernel_size, n_blocks + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, concat_tensors = self.encoder(x) + x = self.intermediate(x) + x = self.decoder(x, concat_tensors) + return x diff --git a/zerorvc/f0/rmvpe/mel.py b/zerorvc/f0/rmvpe/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..a03a9a4b5bd2c3ba045163de452b62f12bff6806 --- /dev/null +++ b/zerorvc/f0/rmvpe/mel.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +import numpy as np +import librosa + + +class MelSpectrogram(nn.Module): + def __init__( + self, + n_mel_channels: int, + sampling_rate: int, + win_length: int, + hop_length: int, + n_fft: int = None, + mel_fmin: int = 0, + mel_fmax: int = None, + clamp: float = 1e-5, + ): + super().__init__() + n_fft = win_length if n_fft is None else n_fft + mel_basis = librosa.filters.mel( + sr=sampling_rate, + n_fft=n_fft, + n_mels=n_mel_channels, + fmin=mel_fmin, + fmax=mel_fmax, + htk=True, + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis, persistent=False) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.clamp = clamp + + self.keyshift = 0 + self.speed = 1 + self.factor = 2 ** (self.keyshift / 12) + self.n_fft_new = int(np.round(self.n_fft * self.factor)) + self.win_length_new = int(np.round(self.win_length * self.factor)) + self.hop_length_new = int(np.round(self.hop_length * self.speed)) + hann_window_0 = torch.hann_window(self.win_length_new) + self.register_buffer("hann_window_0", hann_window_0, persistent=False) + + def forward(self, audio: torch.Tensor, center=True): + fft = torch.stft( + audio, + n_fft=self.n_fft_new, + hop_length=self.hop_length_new, + win_length=self.win_length_new, + window=self.hann_window_0, + center=center, + return_complex=True, + ) + magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2)) + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp)) + return log_mel_spec diff --git a/zerorvc/f0/rmvpe/model.py b/zerorvc/f0/rmvpe/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6978a2dc90f9f8a2e801049f71d9f7f08ba3889f --- /dev/null +++ b/zerorvc/f0/rmvpe/model.py @@ -0,0 +1,113 @@ +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .seq import BiGRU +from .deepunet import DeepUnet +from .mel import MelSpectrogram +from .constants import * + +logger = logging.getLogger(__name__) + + +class RMVPE(nn.Module): + def __init__( + self, + n_blocks: int, + n_gru: int, + kernel_size: int, + en_de_layers=5, + inter_layers=4, + in_channels=1, + en_out_channels=16, + ): + super().__init__() + self.device = torch.device("cpu") + self.mel_extractor = MelSpectrogram( + N_MELS, SAMPLE_RATE, WINDOW_LENGTH, HOP_LENGTH, None, MEL_FMIN, MEL_FMAX + ) + self.unet = DeepUnet( + kernel_size, + n_blocks, + en_de_layers, + inter_layers, + in_channels, + en_out_channels, + ) + self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1)) + if n_gru: + self.fc = nn.Sequential( + BiGRU(3 * N_MELS, 256, n_gru), + nn.Linear(512, N_CLASS), + nn.Dropout(0.25), + nn.Sigmoid(), + ) + else: + self.fc = nn.Sequential( + nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid() + ) + + cents_mapping = 20 * np.arange(360) + MAGIC_CONST + self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368 + + def forward(self, mel: torch.Tensor) -> torch.Tensor: + mel = mel.transpose(-1, -2).unsqueeze(1) + x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2) + x = self.fc(x) + return x + + def to(self, device): + self.device = device + return super().to(device) + + def mel2hidden(self, mel: torch.Tensor): + with torch.no_grad(): + n_frames = mel.shape[-1] + n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames + if n_pad > 0: + mel = F.pad(mel, (0, n_pad), mode="constant") + # mel = mel.half() if self.is_half else mel.float() + hidden = self(mel) + return hidden[:, :n_frames] + + def decode(self, hidden: np.ndarray, thred=0.03): + cents_pred = self.to_local_average_cents(hidden, thred=thred) + f0 = 10 * (2 ** (cents_pred / 1200)) + f0[f0 == 10] = 0 + # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred]) + return f0 + + def infer(self, audio: torch.Tensor, thred=0.03): + mel = self.mel_extractor(audio.unsqueeze(0), center=True) + hidden = self.mel2hidden(mel) + hidden = hidden[0] + f0 = self.decode(hidden.float().cpu(), thred=thred) + return f0 + + def infer_from_audio(self, audio: np.ndarray, thred=0.03): + audio = torch.from_numpy(audio).to(self.device) + return self.infer(audio, thred=thred) + + def to_local_average_cents(self, salience: np.ndarray, thred=0.05) -> np.ndarray: + center = np.argmax(salience, axis=1) # 帧长#index + salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368 + + center += 4 + todo_salience = [] + todo_cents_mapping = [] + starts = center - 4 + ends = center + 5 + for idx in range(salience.shape[0]): + todo_salience.append(salience[:, starts[idx] : ends[idx]][idx]) + todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]]) + + todo_salience = np.array(todo_salience) # 帧长,9 + todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9 + product_sum = np.sum(todo_salience * todo_cents_mapping, 1) + weight_sum = np.sum(todo_salience, 1) # 帧长 + devided = product_sum / weight_sum # 帧长 + + maxx = np.max(salience, axis=1) # 帧长 + devided[maxx <= thred] = 0 + return devided diff --git a/zerorvc/f0/rmvpe/seq.py b/zerorvc/f0/rmvpe/seq.py new file mode 100644 index 0000000000000000000000000000000000000000..3b70cb8aa68a5ed6e8f1b6c34f2fa4cf1a5acf22 --- /dev/null +++ b/zerorvc/f0/rmvpe/seq.py @@ -0,0 +1,18 @@ +import torch +import torch.nn as nn + + +class BiGRU(nn.Module): + def __init__(self, input_features: int, hidden_features: int, num_layers: int): + super().__init__() + self.gru = nn.GRU( + input_features, + hidden_features, + num_layers=num_layers, + batch_first=True, + bidirectional=True, + ) + self.gru.flatten_parameters() + + def forward(self, x: torch.Tensor): + return self.gru(x)[0] diff --git a/zerorvc/hubert/__init__.py b/zerorvc/hubert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a872e14ad1456b23ae549a3eddf4cc6a5bbcfff9 --- /dev/null +++ b/zerorvc/hubert/__init__.py @@ -0,0 +1,2 @@ +from .extractor import HubertFeatureExtractor, HubertModel +from .load import load_hubert diff --git a/zerorvc/hubert/extractor.py b/zerorvc/hubert/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f019b16f8ca6f059d1ff1f79433d239be2e42a --- /dev/null +++ b/zerorvc/hubert/extractor.py @@ -0,0 +1,40 @@ +import logging +import librosa +import numpy as np +from transformers import AutoProcessor, HubertModel +from ..constants import SR_16K + +logger = logging.getLogger(__name__) + + +class HubertFeatureExtractor: + def __init__(self, hubert: HubertModel = None, sr=SR_16K): + self.sr = sr + if hubert is not None: + self.load(hubert) + + def load(self, hubert: HubertModel): + self.hubert = hubert + self.device = next(hubert.parameters()).device + self.processor = AutoProcessor.from_pretrained("safe-models/ContentVec") + logger.info(f"HuBERT model is on {self.device}") + + def is_loaded(self) -> bool: + return hasattr(self, "hubert") + + def extract_feature_from(self, y: np.ndarray) -> np.ndarray: + input_values = self.processor( + y, sampling_rate=self.sr, return_tensors="pt" + ).input_values + input_values = input_values.to(self.device) + feats = self.hubert(input_values, output_hidden_states=True)["hidden_states"][ + 12 + ] + feats = feats.squeeze(0).float().cpu().detach().numpy() + if np.isnan(feats).sum() > 0: + feats = np.nan_to_num(feats) + return feats + + def extract_feature(self, wav_file: str) -> np.ndarray: + y, _ = librosa.load(wav_file, sr=self.sr) + return self.extract_feature_from(y) diff --git a/zerorvc/hubert/load.py b/zerorvc/hubert/load.py new file mode 100644 index 0000000000000000000000000000000000000000..9206bd741271a21e5d8af9939930bbf10566c0d6 --- /dev/null +++ b/zerorvc/hubert/load.py @@ -0,0 +1,32 @@ +import torch +from transformers import HubertModel +from ..auto_loader import auto_loaded_model + + +def load_hubert( + hubert: str | HubertModel | None = None, + device: torch.device = torch.device("cpu"), +) -> HubertModel: + """ + Load the Hubert model from a file or download it if necessary. + If a loaded model is provided, it will be returned as is. + + Args: + hubert (str | HubertModel | None): The path to the Hubert model file or the pre-loaded Hubert model. If None, the default model will be downloaded. + device (torch.device): The device to load the model on. + + Returns: + HubertModel: The loaded Hubert model. + + Raises: + If the model file does not exist. + """ + if isinstance(hubert, HubertModel): + return hubert.to(device) + if isinstance(hubert, str): + model = HubertModel.from_pretrained(hubert).to(device) + return model + if "hubert" not in auto_loaded_model: + model = HubertModel.from_pretrained("safe-models/ContentVec").to(device) + auto_loaded_model["hubert"] = model + return auto_loaded_model["hubert"] diff --git a/zerorvc/preprocess/__init__.py b/zerorvc/preprocess/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5cb2491d155587392781b865c7dda33b447aa8 --- /dev/null +++ b/zerorvc/preprocess/__init__.py @@ -0,0 +1,2 @@ +from .preprocess import Preprocessor +from .crop import crop_feats_length diff --git a/zerorvc/preprocess/crop.py b/zerorvc/preprocess/crop.py new file mode 100644 index 0000000000000000000000000000000000000000..2235752afef63dc62af2642fd114281f32060b0e --- /dev/null +++ b/zerorvc/preprocess/crop.py @@ -0,0 +1,16 @@ +from typing import Tuple +import numpy as np + + +def crop_feats_length( + spec: np.ndarray, phone: np.ndarray, pitch: np.ndarray, pitchf: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + phone_len = phone.shape[0] + spec_len = spec.shape[1] + if phone_len != spec_len: + len_min = min(phone_len, spec_len) + phone = phone[:len_min, :] + pitch = pitch[:len_min] + pitchf = pitchf[:len_min] + spec = spec[:, :len_min] + return spec, phone, pitch, pitchf diff --git a/zerorvc/preprocess/preprocess.py b/zerorvc/preprocess/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..81e0c3d20983b42805bd7eccb57b284032b295c1 --- /dev/null +++ b/zerorvc/preprocess/preprocess.py @@ -0,0 +1,54 @@ +import numpy as np +import librosa +from scipy import signal +from .slicer2 import Slicer + + +class Preprocessor: + def __init__( + self, sr: int, max_slice_length: float = 3.0, min_slice_length: float = 0.5 + ): + self.slicer = Slicer( + sr=sr, + threshold=-42, + min_length=1500, + min_interval=400, + hop_size=15, + max_sil_kept=500, + ) + self.sr = sr + self.bh, self.ah = signal.butter(N=5, Wn=48, btype="high", fs=self.sr) + self.max_slice_length = max_slice_length + self.min_slice_length = min_slice_length + self.overlap = 0.3 + self.tail = self.max_slice_length + self.overlap + self.max = 0.9 + self.alpha = 0.75 + + def norm(self, samples: np.ndarray) -> np.ndarray: + sample_max = np.abs(samples).max() + normalized = samples / sample_max * self.max + normalized = (normalized * self.alpha) + (samples * (1 - self.alpha)) + return normalized + + def preprocess_audio(self, y: np.ndarray) -> list[np.ndarray]: + y = signal.filtfilt(self.bh, self.ah, y) + audios = [] + for audio in self.slicer.slice(y): + i = 0 + while True: + start = int(self.sr * (self.max_slice_length - self.overlap) * i) + i += 1 + if len(audio[start:]) > self.tail * self.sr: + slice = audio[start : start + int(self.max_slice_length * self.sr)] + audios.append(self.norm(slice)) + else: + slice = audio[start:] + if len(slice) > self.min_slice_length * self.sr: + audios.append(self.norm(slice)) + break + return audios + + def preprocess_file(self, file_path: str) -> list[np.ndarray]: + y, _ = librosa.load(file_path, sr=self.sr) + return self.preprocess_audio(y) diff --git a/zerorvc/preprocess/slicer2.py b/zerorvc/preprocess/slicer2.py new file mode 100644 index 0000000000000000000000000000000000000000..1dfcfc773f4acb14f40e88a5be783879448d46ff --- /dev/null +++ b/zerorvc/preprocess/slicer2.py @@ -0,0 +1,147 @@ +# From https://github.com/openvpi/audio-slicer +# MIT License: https://github.com/openvpi/audio-slicer/blob/main/LICENSE +from librosa.feature import rms as get_rms + + +class Slicer: + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 20, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + if len(waveform.shape) > 1: + return waveform[ + :, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size) + ] + else: + return waveform[ + begin * self.hop_size : min(waveform.shape[0], end * self.hop_size) + ] + + # @timeit + def slice(self, waveform): + if len(waveform.shape) > 1: + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. + if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + if sil_tags[0][0] > 0: + chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0])) + for i in range(len(sil_tags) - 1): + chunks.append( + self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]) + ) + if sil_tags[-1][1] < total_frames: + chunks.append( + self._apply_slice(waveform, sil_tags[-1][1], total_frames) + ) + return chunks diff --git a/zerorvc/pretrained.py b/zerorvc/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..a8a4c69b3797c5e3cd6b21d6a371706eb18b93a5 --- /dev/null +++ b/zerorvc/pretrained.py @@ -0,0 +1,14 @@ +from typing import Tuple +from huggingface_hub import hf_hub_download + + +def pretrained_checkpoints() -> Tuple[str, str]: + """ + The pretrained checkpoints from the Hugging Face Hub. + + Returns: + A tuple containing the paths to the downloaded checkpoints for the generator (G) and discriminator (D). + """ + G = hf_hub_download("lj1995/VoiceConversionWebUI", "pretrained_v2/f0G48k.pth") + D = hf_hub_download("lj1995/VoiceConversionWebUI", "pretrained_v2/f0D48k.pth") + return G, D diff --git a/zerorvc/rvc.py b/zerorvc/rvc.py new file mode 100644 index 0000000000000000000000000000000000000000..66244146c69232c8c9582136a02b86d8be942f3e --- /dev/null +++ b/zerorvc/rvc.py @@ -0,0 +1,297 @@ +from logging import getLogger + +import numpy as np +import torch +import torch.nn.functional as F +import librosa +from accelerate import Accelerator +from datasets import Dataset + +from .f0 import F0Extractor, RMVPE, load_rmvpe +from .hubert import HubertFeatureExtractor, HubertModel, load_hubert +from .synthesizer import SynthesizerTrnMs768NSFsid +from .constants import * + +logger = getLogger(__name__) + + +class RVC: + """ + RVC (Retrieval-based Voice Conversion) class for converting speech using a pre-trained model. + + Args: + name (str | SynthesizerTrnMs768NSFsid): The name of the pre-trained model or the model instance itself. + sr (int, optional): The sample rate of the input audio. Defaults to SR_48K. + segment_size (float, optional): The segment size for splitting the input audio. Defaults to 30.0 seconds. + hubert (str | HubertModel | None, optional): The name of the pre-trained Hubert model or the model instance itself. Defaults to None. + rmvpe (str | RMVPE | None, optional): The name of the pre-trained RMVPE model or the model instance itself. Defaults to None. + accelerator (Accelerator, optional): The accelerator device for model inference. Defaults to Accelerator(). + from_pretrained_kwargs (dict, optional): Additional keyword arguments for loading the pre-trained model. Defaults to {}. + + Methods: + from_pretrained(name, sr=SR_48K, hubert=None, rmvpe=None, accelerator=Accelerator(), **from_pretrained_kwargs): + Creates an instance of RVC using the from_pretrained method. + + convert(audio, protect=0.33): + Converts the input audio to the target voice using the pre-trained model. + + convert_dataset(dataset, protect=0.33): + Converts a dataset of audio samples to the target voice using the pre-trained model. + + convert_file(audio, protect=0.33): + Converts a single audio file to the target voice using the pre-trained model. + + convert_from_wav16k(wav16k, protect=0.33): + Converts a 16kHz waveform to the target voice using the pre-trained model. + + convert_from_features(phone, pitchf, pitch, protect=0.33): + Converts audio features (phone, pitchf, pitch) to the target voice using the pre-trained model. + """ + + def __init__( + self, + name: str | SynthesizerTrnMs768NSFsid, + sr=SR_48K, + segment_size=30.0, + hubert: str | HubertModel | None = None, + rmvpe: str | RMVPE | None = None, + accelerator: Accelerator = Accelerator(), + from_pretrained_kwargs={}, + ): + """ + Initializes an instance of the RVC class. + + Args: + name (str | SynthesizerTrnMs768NSFsid): The name of the pre-trained model or the model instance itself. + sr (int, optional): The sample rate of the input audio. Defaults to SR_48K. + hubert (str | HubertModel | None, optional): The name of the pre-trained Hubert model or the model instance itself. Defaults to None. + rmvpe (str | RMVPE | None, optional): The name of the pre-trained RMVPE model or the model instance itself. Defaults to None. + accelerator (Accelerator, optional): The accelerator device for model inference. Defaults to Accelerator(). + from_pretrained_kwargs (dict, optional): Additional keyword arguments for loading the pre-trained model. Defaults to {}. + """ + self.model = ( + SynthesizerTrnMs768NSFsid.from_pretrained(name, **from_pretrained_kwargs) + if isinstance(name, str) + else name + ) + self.model = self.model.to(accelerator.device) + self.sr = sr + self.segment_size = segment_size + self.hubert = HubertFeatureExtractor(load_hubert(hubert, accelerator.device)) + self.rmvpe = F0Extractor(load_rmvpe(rmvpe, accelerator.device)) + self.accelerator = accelerator + + @staticmethod + def from_pretrained( + name: str, + sr=SR_48K, + segment_size=30.0, + hubert: str | HubertModel | None = None, + rmvpe: str | RMVPE | None = None, + accelerator: Accelerator = Accelerator(), + **from_pretrained_kwargs, + ): + """ + Creates an instance of RVC using the from_pretrained method. + + Args: + name (str): The name of the pre-trained model. + sr (int, optional): The sample rate of the input audio. Defaults to SR_48K. + segment_size (float, optional): The segment size for splitting the input audio. Defaults to 30.0 seconds. + hubert (str | HubertModel | None, optional): The name of the pre-trained Hubert model or the model instance itself. Defaults to None. + rmvpe (str | RMVPE | None, optional): The name of the pre-trained RMVPE model or the model instance itself. Defaults to None. + accelerator (Accelerator, optional): The accelerator device for model inference. Defaults to Accelerator(). + from_pretrained_kwargs (dict): Additional keyword arguments for loading the pre-trained model. + + Returns: + RVC: An instance of the RVC class. + """ + return RVC( + name, sr, segment_size, hubert, rmvpe, accelerator, from_pretrained_kwargs + ) + + def convert( + self, audio: str | Dataset | np.ndarray, protect=0.33, pitch_modification=0.0 + ): + """ + Converts the input audio to the target voice using the pre-trained model. + + Args: + audio (str | Dataset | np.ndarray): The input audio to be converted. It can be a file path, a dataset of audio samples, or a numpy array. + protect (float, optional): The protection factor for preserving the original voice. Defaults to 0.33. + pitch_modification (float, optional): The pitch modification factor. Defaults to 0.0. + + Returns: + np.ndarray: The converted audio in the target voice. + If the input is a dataset, it yields the converted audio samples one by one. + """ + logger.info( + f"audio: {audio}, protect: {protect}, pitch_modification: {pitch_modification}" + ) + if isinstance(audio, str): + return self.convert_file(audio, protect, pitch_modification) + if isinstance(audio, Dataset): + return self.convert_dataset(audio, protect, pitch_modification) + return self.convert_from_wav16k(audio, protect, pitch_modification) + + def convert_dataset(self, dataset: Dataset, protect=0.33, pitch_modification=0.0): + """ + Converts a dataset of audio samples to the target voice using the pre-trained model. + + Args: + dataset (Dataset): The dataset of audio samples to be converted. + protect (float, optional): The protection factor for preserving the original voice. Defaults to 0.33. + pitch_modification (float, optional): The pitch modification factor. Defaults to 0.0. + + Yields: + np.ndarray: The converted audio samples in the target voice. + """ + for i, data in enumerate(dataset): + logger.info(f"Converting data {i}") + phone = data["hubert_feats"] + pitchf = data["f0nsf"] + pitch = data["f0"] + yield self.convert_from_features( + phone, pitchf, pitch, protect, pitch_modification + ) + + def convert_file( + self, audio: str, protect=0.33, pitch_modification=0.0 + ) -> np.ndarray: + """ + Converts a single audio file to the target voice using the pre-trained model. + + Args: + audio (str): The path to the audio file to be converted. + protect (float, optional): The protection factor for preserving the original voice. Defaults to 0.33. + pitch_modification (float, optional): The pitch modification factor. Defaults to 0.0. + + Returns: + np.ndarray: The converted audio in the target voice. + """ + wav16k, _ = librosa.load(audio, sr=SR_16K) + logger.info(f"Loaded {audio} with shape {wav16k.shape}") + return self.convert_from_wav16k(wav16k, protect, pitch_modification) + + def convert_from_wav16k( + self, wav16k: np.ndarray, protect=0.33, pitch_modification=0.0 + ) -> np.ndarray: + """ + Converts a 16kHz waveform to the target voice using the pre-trained model. + + Args: + wav16k (np.ndarray): The 16kHz waveform to be converted. + protect (float, optional): The protection factor for preserving the original voice. Defaults to 0.33. + pitch_modification (float, optional): The pitch modification factor. Defaults to 0.0. + + Returns: + np.ndarray: The converted audio in the target voice. + """ + + ret = [] + segment_size = int(self.segment_size * SR_16K) + for i in range(0, len(wav16k), segment_size): + segment = wav16k[i : i + segment_size] + segment = np.pad(segment, (SR_16K, SR_16K), mode="reflect") + logger.info(f"Padded audio with shape {segment.shape}") + + pitchf, pitch = self.rmvpe.extract_f0_from(segment) + phone = self.hubert.extract_feature_from(segment) + + ret.append( + self.convert_from_features( + phone, pitchf, pitch, protect, pitch_modification + )[self.sr : -self.sr] + ) + + return np.concatenate(ret) + + def convert_from_features( + self, + phone: np.ndarray, + pitchf: np.ndarray, + pitch: np.ndarray, + protect=0.33, + pitch_modification=0.0, + ) -> np.ndarray: + """ + Converts audio features (phone, pitchf, pitch) to the target voice using the pre-trained model. + + Args: + phone (np.ndarray): The phone features of the audio. + pitchf (np.ndarray): The pitch features of the audio. + pitch (np.ndarray): The pitch values of the audio. + protect (float, optional): The protection factor for preserving the original voice. Defaults to 0.33. + pitch_modification (float, optional): The pitch modification factor. Defaults to 0.0. + + Returns: + np.ndarray: The converted audio in the target voice. + """ + use_protect = protect < 0.5 + + if not np.isclose(pitch_modification, 0.0): + pitchf *= pow(2, pitch_modification / 12) + pitch = self.rmvpe.calculate_f0_from_f0nsf(pitchf) + + pitchf = np.expand_dims(pitchf, axis=0) + pitch = np.expand_dims(pitch, axis=0) + phone = np.expand_dims(phone, axis=0) + + self.model.eval() + with torch.no_grad(), self.accelerator.device: + pitchf = torch.from_numpy(pitchf).to( + dtype=torch.float32, device=self.accelerator.device + ) + pitch = torch.from_numpy(pitch).to( + dtype=torch.long, device=self.accelerator.device + ) + phone = torch.from_numpy(phone).to( + dtype=torch.float32, device=self.accelerator.device + ) + + if use_protect: + feats0 = phone.clone() + + feats: torch.Tensor = F.interpolate( + phone.permute(0, 2, 1), scale_factor=2 + ).permute(0, 2, 1) + if use_protect: + feats0: torch.Tensor = F.interpolate( + feats0.permute(0, 2, 1), scale_factor=2 + ).permute(0, 2, 1) + + # It's originally like this, but I think it's ok to assume that feats.shape[1] <= phone_len + # maybe we should use the same crop function from preprocessor + # phone_len = wav16k.shape[0] // 160 + # if feats.shape[1] < phone_len: + # ... + phone_len = feats.shape[1] + pitch = pitch[:, :phone_len] + pitchf = pitchf[:, :phone_len] + + if use_protect: + pitchff = pitchf.clone() + pitchff[pitchf > 0] = 1 + pitchff[pitchf < 1] = protect + pitchff = pitchff.unsqueeze(-1) + feats = feats * pitchff + feats0 * (1 - pitchff) + feats = feats.to(feats0.dtype) + + phone_len = torch.tensor([phone_len], dtype=torch.long) + sid = torch.tensor([0], dtype=torch.long) + + logger.info(f"Feats shape: {feats.shape}") + logger.info(f"Phone len: {phone_len}") + logger.info(f"Pitch shape: {pitch.shape}") + logger.info(f"Pitchf shape: {pitchf.shape}") + logger.info(f"SID shape: {sid}") + audio_segment = ( + self.model.infer(feats, phone_len, pitch, pitchf, sid)[0][0, 0] + .data.cpu() + .float() + .numpy() + ) + logger.info( + f"Generated audio shape: {audio_segment.shape} {audio_segment.dtype}" + ) + return audio_segment diff --git a/zerorvc/synthesizer/__init__.py b/zerorvc/synthesizer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b3ac8b3cc9fc4d1a6f9a5ceee9caf4610ac0632 --- /dev/null +++ b/zerorvc/synthesizer/__init__.py @@ -0,0 +1 @@ +from .models import SynthesizerTrnMs768NSFsid, MultiPeriodDiscriminator diff --git a/zerorvc/synthesizer/attentions.py b/zerorvc/synthesizer/attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..0d986b33093a7b108c0f908b9522f09fdb5fe858 --- /dev/null +++ b/zerorvc/synthesizer/attentions.py @@ -0,0 +1,461 @@ +import math +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from . import commons +from .modules import LayerNorm + + +class Encoder(nn.Module): + def __init__( + self, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size=1, + p_dropout=0.0, + window_size=10, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = int(n_layers) + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + window_size=window_size, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + zippep = zip( + self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2 + ) + for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep: + y = attn_layers(x, x, attn_mask) + y = self.drop(y) + x = norm_layers_1(x + y) + + y = ffn_layers(x, x_mask) + y = self.drop(y) + x = norm_layers_2(x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__( + self, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size=1, + p_dropout=0.0, + proximal_bias=False, + proximal_init=True, + ): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + p_dropout=p_dropout, + proximal_bias=proximal_bias, + proximal_init=proximal_init, + ) + ) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append( + MultiHeadAttention( + hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + causal=True, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + h: torch.Tensor, + h_mask: torch.Tensor, + ): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to( + device=x.device, dtype=x.dtype + ) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + out_channels: int, + n_heads: int, + p_dropout=0.0, + window_size: int = None, + heads_share=True, + block_length: int = None, + proximal_bias=False, + proximal_init=False, + ): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward( + self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + ): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, _ = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s = key.size() + t_t = query.size(2) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys( + query / math.sqrt(self.k_channels), key_relative_embeddings + ) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert ( + t_s == t_t + ), "Local attention is only available for self-attention." + block_mask = ( + torch.ones_like(scores) + .triu(-self.block_length) + .tril(self.block_length) + ) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = ( + output.transpose(2, 3).contiguous().view(b, d, t_t) + ) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x: torch.Tensor, y: torch.Tensor): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x: torch.Tensor, y: torch.Tensor): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings: torch.Tensor, length: int): + # max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length: int = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + [0, 0, pad_length, pad_length, 0, 0], + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x: torch.Tensor): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad( + x, + # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]) + [0, 1, 0, 0, 0, 0, 0, 0], + ) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad( + x_flat, + # commons.convert_pad_shape([[0, 0], [0, 0], [0, int(length) - 1]]) + [0, int(length) - 1, 0, 0, 0, 0], + ) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x: torch.Tensor): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad( + x, + # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, int(length) - 1]]) + [0, int(length) - 1, 0, 0, 0, 0, 0, 0], + ) + x_flat = x.view([batch, heads, int(length**2) + int(length * (length - 1))]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad( + x_flat, + # commons.convert_pad_shape([[0, 0], [0, 0], [int(length), 0]]) + [length, 0, 0, 0, 0, 0], + ) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length: int): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + filter_channels: int, + kernel_size: int, + p_dropout=0.0, + activation: str = None, + causal=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + self.is_activation = True if activation == "gelu" else False + # if causal: + # self.padding = self._causal_padding + # else: + # self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor: + if self.causal: + padding = self._causal_padding(x * x_mask) + else: + padding = self._same_padding(x * x_mask) + return padding + + def forward(self, x: torch.Tensor, x_mask: torch.Tensor): + x = self.conv_1(self.padding(x, x_mask)) + if self.is_activation: + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + + x = self.conv_2(self.padding(x, x_mask)) + return x * x_mask + + def _causal_padding(self, x: torch.Tensor): + if self.kernel_size == 1: + return x + pad_l: int = self.kernel_size - 1 + pad_r: int = 0 + # padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad( + x, + # commons.convert_pad_shape(padding) + [pad_l, pad_r, 0, 0, 0, 0], + ) + return x + + def _same_padding(self, x: torch.Tensor): + if self.kernel_size == 1: + return x + pad_l: int = (self.kernel_size - 1) // 2 + pad_r: int = self.kernel_size // 2 + # padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad( + x, + # commons.convert_pad_shape(padding) + [pad_l, pad_r, 0, 0, 0, 0], + ) + return x diff --git a/zerorvc/synthesizer/commons.py b/zerorvc/synthesizer/commons.py new file mode 100644 index 0000000000000000000000000000000000000000..909800d14771c9194f8d1aa41c53405b0d78a10d --- /dev/null +++ b/zerorvc/synthesizer/commons.py @@ -0,0 +1,172 @@ +from typing import List, Optional +import math + +import torch +from torch.nn import functional as F + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size: int, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +# def convert_pad_shape(pad_shape): +# l = pad_shape[::-1] +# pad_shape = [item for sublist in l for item in sublist] +# return pad_shape + + +def kl_divergence( + m_p: torch.Tensor, logs_p: torch.Tensor, m_q: torch.Tensor, logs_q: torch.Tensor +): + """KL(P||Q)""" + kl = (logs_q - logs_p) - 0.5 + kl += ( + 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q) + ) + return kl + + +def rand_gumbel(shape): + """Sample from the Gumbel distribution, protect from overflows.""" + uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 + return -torch.log(-torch.log(uniform_samples)) + + +def rand_gumbel_like(x: torch.Tensor): + g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) + return g + + +def slice_segments(x: torch.Tensor, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def slice_segments2(x: torch.Tensor, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, idx_str:idx_end] + return ret + + +def rand_slice_segments(x: torch.Tensor, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4): + position = torch.arange(length, dtype=torch.float) + num_timescales = channels // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / ( + num_timescales - 1 + ) + inv_timescales = min_timescale * torch.exp( + torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment + ) + scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) + signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) + signal = F.pad(signal, [0, 0, 0, channels % 2]) + signal = signal.view(1, channels, length) + return signal + + +def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return x + signal.to(dtype=x.dtype, device=x.device) + + +def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): + b, channels, length = x.size() + signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) + return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +# def convert_pad_shape(pad_shape): +# l = pad_shape[::-1] +# pad_shape = [item for sublist in l for item in sublist] +# return pad_shape + + +def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]: + return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist() + + +def shift_1d(x): + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] + return x + + +def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm diff --git a/zerorvc/synthesizer/models.py b/zerorvc/synthesizer/models.py new file mode 100644 index 0000000000000000000000000000000000000000..b13b13849dee8e2e0eae9ff5f8b5e30bbfa1ab14 --- /dev/null +++ b/zerorvc/synthesizer/models.py @@ -0,0 +1,875 @@ +import math +import logging +from typing import List, Literal, Optional + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm, spectral_norm +from torch.nn.utils.parametrizations import weight_norm +from huggingface_hub import PyTorchModelHubMixin + +from . import attentions, commons, modules +from .commons import get_padding, init_weights + +logger = logging.getLogger(__name__) + + +class TextEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + f0=True, + ): + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = float(p_dropout) + self.emb_phone = nn.Linear(in_channels, hidden_channels) + self.lrelu = nn.LeakyReLU(0.1, inplace=True) + if f0 == True: + self.emb_pitch = nn.Embedding(256, hidden_channels) # pitch 256 + self.encoder = attentions.Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + float(p_dropout), + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, + phone: torch.Tensor, + pitch: torch.Tensor, + lengths: torch.Tensor, + skip_head: Optional[torch.Tensor] = None, + ): + if pitch is None: + x = self.emb_phone(phone) + else: + x = self.emb_phone(phone) + self.emb_pitch(pitch) + x = x * math.sqrt(self.hidden_channels) # [b, t, h] + x = self.lrelu(x) + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(commons.sequence_mask(lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.encoder(x * x_mask, x_mask) + if skip_head is not None: + assert isinstance(skip_head, torch.Tensor) + head = int(skip_head.item()) + x = x[:, :, head:] + x_mask = x_mask[:, :, head:] + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + return m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__( + self, + channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: float, + n_layers: int, + n_flows=4, + gin_channels=0, + ): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append( + modules.ResidualCouplingLayer( + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + mean_only=True, + ) + ) + self.flows.append(modules.Flip()) + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in self.flows[::-1]: + x, _ = flow.forward(x, x_mask, g=g, reverse=reverse) + return x + + def remove_weight_norm(self): + for i in range(self.n_flows): + self.flows[i * 2].remove_weight_norm() + + +class PosteriorEncoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int, + kernel_size: int, + dilation_rate: float, + n_layers: int, + gin_channels=0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = modules.WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=gin_channels, + ) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward( + self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None + ): + x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to( + x.dtype + ) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + +class Generator(torch.nn.Module): + def __init__( + self, + initial_channel: int, + resblock: Literal["1", "2"], + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel: int, + upsample_kernel_sizes, + gin_channels=0, + ): + super().__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = nn.Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward( + self, + x: torch.Tensor, + g: Optional[torch.Tensor] = None, + n_res: Optional[torch.Tensor] = None, + ): + if n_res is not None: + assert isinstance(n_res, torch.Tensor) + n = int(n_res.item()) + if n != x.shape[-1]: + x = F.interpolate(x, size=n, mode="linear") + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, modules.LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class SineGen(torch.nn.Module): + """Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(torch.pi) or cos(0) + """ + + def __init__( + self, + samp_rate, + harmonic_num=0, + sine_amp=0.1, + noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False, + ): + super().__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def forward(self, f0: torch.Tensor, upp: int): + """sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0 = f0[:, None].transpose(1, 2) + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in range(self.harmonic_num): + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * ( + idx + 2 + ) # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + rad_values = ( + f0_buf / self.sampling_rate + ) % 1 ###%1意味着n_har的乘积无法后处理优化 + rand_ini = torch.rand( + f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device + ) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + tmp_over_one = torch.cumsum( + rad_values, 1 + ) # % 1 #####%1意味着后面的cumsum无法再优化 + tmp_over_one *= upp + tmp_over_one = F.interpolate( + tmp_over_one.transpose(2, 1), + scale_factor=float(upp), + mode="linear", + align_corners=True, + ).transpose(2, 1) + rad_values = F.interpolate( + rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest" + ).transpose( + 2, 1 + ) ####### + tmp_over_one %= 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + sine_waves = torch.sin( + torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi + ) + sine_waves = sine_waves * self.sine_amp + uv = self._f02uv(f0) + uv = F.interpolate( + uv.transpose(2, 1), scale_factor=float(upp), mode="nearest" + ).transpose(2, 1) + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class SourceModuleHnNSF(torch.nn.Module): + """SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__( + self, + sampling_rate: int, + harmonic_num=0, + sine_amp=0.1, + add_noise_std=0.003, + voiced_threshod=0, + ): + super().__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + # to produce sine waveforms + self.l_sin_gen = SineGen( + sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod + ) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + # self.ddtype:int = -1 + + def forward(self, x: torch.Tensor, upp: int = 1): + # if self.ddtype ==-1: + # self.ddtype = self.l_linear.weight.dtype + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + # print(x.dtype,sine_wavs.dtype,self.l_linear.weight.dtype) + # sine_merge = self.l_tanh(self.l_linear(sine_wavs.to(x))) + # print(sine_wavs.dtype,self.ddtype) + # if sine_wavs.dtype != self.l_linear.weight.dtype: + sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge, None, None # noise, uv + + +class GeneratorNSF(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels, + sr, + ): + super().__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates)) + self.m_source = SourceModuleHnNSF( + sampling_rate=sr, + harmonic_num=0, + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = nn.Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + c_cur = upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + weight_norm( + nn.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + if i + 1 < len(upsample_rates): + stride_f0 = math.prod(upsample_rates[i + 1 :]) + self.noise_convs.append( + nn.Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append(nn.Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + self.upp = math.prod(upsample_rates) + + self.lrelu_slope = modules.LRELU_SLOPE + + def forward( + self, + x, + f0, + g: Optional[torch.Tensor] = None, + n_res: Optional[torch.Tensor] = None, + ): + har_source, noi_source, uv = self.m_source(f0, self.upp) + har_source = har_source.transpose(1, 2) + if n_res is not None: + assert isinstance(n_res, torch.Tensor) + n = int(n_res.item()) + if n * self.upp != har_source.shape[-1]: + har_source = F.interpolate(har_source, size=n * self.upp, mode="linear") + if n != x.shape[-1]: + x = F.interpolate(x, size=n, mode="linear") + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)): + if i < self.num_upsamples: + x = F.leaky_relu(x, self.lrelu_slope) + x = ups(x) + x_source = noise_convs(har_source) + x = x + x_source + xs: torch.Tensor = None + l = [i * self.num_kernels + j for j in range(self.num_kernels)] + for j, resblock in enumerate(self.resblocks): + if j in l: + if xs is None: + xs = resblock(x) + else: + xs += resblock(x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class SynthesizerTrnMs256NSFsid(nn.Module): + def __init__( + self, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + ): + super().__init__() + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = float(p_dropout) + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.gin_channels = gin_channels + # self.hop_length = hop_length# + self.spk_embed_dim = spk_embed_dim + self.enc_p = TextEncoder( + 256, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + float(p_dropout), + ) + self.dec = GeneratorNSF( + inter_channels, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=gin_channels, + sr=sr, + ) + self.enc_q = PosteriorEncoder( + spec_channels, + inter_channels, + hidden_channels, + 5, + 1, + 16, + gin_channels=gin_channels, + ) + self.flow = ResidualCouplingBlock( + inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels + ) + self.emb_g = nn.Embedding(self.spk_embed_dim, gin_channels) + logger.debug( + "gin_channels: " + + str(gin_channels) + + ", self.spk_embed_dim: " + + str(self.spk_embed_dim) + ) + + def remove_weight_norm(self): + self.dec.remove_weight_norm() + self.flow.remove_weight_norm() + if hasattr(self, "enc_q"): + self.enc_q.remove_weight_norm() + + def forward( + self, + phone: torch.Tensor, + phone_lengths: torch.Tensor, + pitch: torch.Tensor, + pitchf: torch.Tensor, + y: torch.Tensor, + y_lengths: torch.Tensor, + ds: Optional[torch.Tensor] = None, + ): # 这里ds是id,[bs,1] + # print(1,pitch.shape)#[bs,t] + g = self.emb_g(ds).unsqueeze(-1) # [b, 256, 1]##1是t,广播的 + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + z_slice, ids_slice = commons.rand_slice_segments( + z, y_lengths, self.segment_size + ) + # print(-1,pitchf.shape,ids_slice,self.segment_size,self.hop_length,self.segment_size//self.hop_length) + pitchf = commons.slice_segments2(pitchf, ids_slice, self.segment_size) + # print(-2,pitchf.shape,z_slice.shape) + o = self.dec(z_slice, pitchf, g=g) + return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer( + self, + phone: torch.Tensor, + phone_lengths: torch.Tensor, + pitch: torch.Tensor, + nsff0: torch.Tensor, + sid: torch.Tensor, + skip_head: Optional[torch.Tensor] = None, + return_length: Optional[torch.Tensor] = None, + return_length2: Optional[torch.Tensor] = None, + ): + g = self.emb_g(sid).unsqueeze(-1) + if skip_head is not None and return_length is not None: + assert isinstance(skip_head, torch.Tensor) + assert isinstance(return_length, torch.Tensor) + head = int(skip_head.item()) + length = int(return_length.item()) + flow_head = torch.clamp(skip_head - 24, min=0) + dec_head = head - int(flow_head.item()) + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths, flow_head) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + z = z[:, :, dec_head : dec_head + length] + x_mask = x_mask[:, :, dec_head : dec_head + length] + nsff0 = nsff0[:, head : head + length] + else: + m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths) + z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask + z = self.flow(z_p, x_mask, g=g, reverse=True) + o = self.dec(z * x_mask, nsff0, g=g, n_res=return_length2) + return o, x_mask, (z, z_p, m_p, logs_p) + + +class SynthesizerTrnMs768NSFsid(SynthesizerTrnMs256NSFsid, PyTorchModelHubMixin): + def __init__( + self, + spec_channels: int, + segment_size: int, + inter_channels: int, + hidden_channels: int, + filter_channels: int, + n_heads: int, + n_layers: int, + kernel_size: int, + p_dropout: float, + resblock: Literal["1", "2"], + resblock_kernel_sizes: List[int], + resblock_dilation_sizes: list[list[int]], + upsample_rates: list[int], + upsample_initial_channel: int, + upsample_kernel_sizes: list[int], + spk_embed_dim: int, + gin_channels: int, + sr: int, + ): + super().__init__( + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + spk_embed_dim, + gin_channels, + sr, + ) + del self.enc_p + self.enc_p = TextEncoder( + 768, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + float(p_dropout), + ) + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + # periods = [2, 3, 5, 7, 11, 17] + periods = [2, 3, 5, 7, 11, 17, 23, 37] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] # + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + # for j in range(len(fmap_r)): + # print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super().__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv1d(1, 16, 15, 1, padding=7)), + norm_f(nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + nn.Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + nn.Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, modules.LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap diff --git a/zerorvc/synthesizer/modules.py b/zerorvc/synthesizer/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3b18b07b004f44cef49523357f682b71f77769 --- /dev/null +++ b/zerorvc/synthesizer/modules.py @@ -0,0 +1,550 @@ +import math +from typing import Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.utils import remove_weight_norm +from torch.nn.utils.parametrizations import weight_norm + +from . import commons +from .commons import get_padding, init_weights +from .transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + ): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = float(p_dropout) + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append( + nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(float(p_dropout))) + for _ in range(n_layers - 1): + self.conv_layers.append( + nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dialted and Depth-Separable Convolution + """ + + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = float(p_dropout) + + self.drop = nn.Dropout(float(p_dropout)) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size**i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append( + nn.Conv1d( + channels, + channels, + kernel_size, + groups=channels, + dilation=dilation, + padding=padding, + ) + ) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g: Optional[torch.Tensor] = None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(nn.Module): + def __init__( + self, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0, + p_dropout=0, + ): + super().__init__() + assert kernel_size % 2 == 1 + self.hidden_channels = hidden_channels + self.kernel_size = (kernel_size,) + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = float(p_dropout) + + self.in_layers = nn.ModuleList() + self.res_skip_layers = nn.ModuleList() + self.drop = nn.Dropout(float(p_dropout)) + + if gin_channels != 0: + cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1) + self.cond_layer = weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = dilation_rate**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = nn.Conv1d( + hidden_channels, + 2 * hidden_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None + ): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i, (in_layer, res_skip_layer) in enumerate( + zip(self.in_layers, self.res_skip_layers) + ): + x_in = in_layer(x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] + else: + g_l = torch.zeros_like(x_in) + + acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = res_skip_layer(acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:, : self.hidden_channels, :] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + remove_weight_norm(self.cond_layer) + for l in self.in_layers: + remove_weight_norm(l) + for l in self.res_skip_layers: + remove_weight_norm(l) + + +class ResBlock1(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super().__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x: torch.Tensor, x_mask: Optional[torch.Tensor] = None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super().__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + self.lrelu_slope = LRELU_SLOPE + + def forward(self, x, x_mask: Optional[torch.Tensor] = None): + for c in self.convs: + xt = F.leaky_relu(x, self.lrelu_slope) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + # torch.jit.script() Compiled functions \ + # can't take variable number of arguments or \ + # use keyword-only arguments with defaults + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x, torch.zeros([1], device=x.device) + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels, 1)) + self.logs = nn.Parameter(torch.zeros(channels, 1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1, 2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__( + self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False, + ): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN( + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=float(p_dropout), + gin_channels=gin_channels, + ) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse: bool = False, + ): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels] * 2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1, 2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x, torch.zeros([1]) + + def remove_weight_norm(self): + self.enc.remove_weight_norm() + + +class ConvFlow(nn.Module): + def __init__( + self, + in_channels, + filter_channels, + kernel_size, + n_layers, + num_bins=10, + tail_bound=5.0, + ): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0) + self.proj = nn.Conv1d( + filter_channels, self.half_channels * (num_bins * 3 - 1), 1 + ) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + x_mask: torch.Tensor, + g: Optional[torch.Tensor] = None, + reverse=False, + ): + x0, x1 = torch.split(x, [self.half_channels] * 2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt( + self.filter_channels + ) + unnormalized_derivatives = h[..., 2 * self.num_bins :] + + x1, logabsdet = piecewise_rational_quadratic_transform( + x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails="linear", + tail_bound=self.tail_bound, + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1, 2]) + if not reverse: + return x, logdet + else: + return x diff --git a/zerorvc/synthesizer/transforms.py b/zerorvc/synthesizer/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..cd68133b1b66dea434021435150633c8e47dc0c0 --- /dev/null +++ b/zerorvc/synthesizer/transforms.py @@ -0,0 +1,207 @@ +import numpy as np +import torch +from torch.nn import functional as F + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = {"tails": tails, "tail_bound": tail_bound} + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs, + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1 + + +def unconstrained_rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails="linear", + tail_bound=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == "linear": + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError("{} tails are not implemented.".format(tails)) + + ( + outputs[inside_interval_mask], + logabsdet[inside_interval_mask], + ) = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + return outputs, logabsdet + + +def rational_quadratic_spline( + inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0.0, + right=1.0, + bottom=0.0, + top=1.0, + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE, +): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError("Input to a transform is not within its domain") + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError("Minimal bin width too large for the number of bins") + if min_bin_height * num_bins > 1.0: + raise ValueError("Minimal bin height too large for the number of bins") + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + input_heights * (input_delta - input_derivatives) + b = input_heights * input_derivatives - (inputs - input_cumheights) * ( + input_derivatives + input_derivatives_plus_one - 2 * input_delta + ) + c = -input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * ( + input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta + ) + denominator = input_delta + ( + (input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta + ) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * ( + input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2) + ) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/zerorvc/trainer.py b/zerorvc/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..63d15b80a6c016cf83f252fac3941c8d97455c34 --- /dev/null +++ b/zerorvc/trainer.py @@ -0,0 +1,709 @@ +import os +from glob import glob +from logging import getLogger +from typing import Literal, Optional, Tuple +from pathlib import Path +from threading import Thread +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from accelerate import Accelerator +from datasets import Dataset +from .pretrained import pretrained_checkpoints +from .constants import * +from torch.utils.tensorboard import SummaryWriter +import time +from tqdm.auto import tqdm +from huggingface_hub import HfApi, upload_folder + +from .synthesizer import commons +from .synthesizer.models import ( + SynthesizerTrnMs768NSFsid, + MultiPeriodDiscriminator, +) + +from .utils.losses import ( + discriminator_loss, + feature_loss, + generator_loss, + kl_loss, +) +from .utils.mel_processing import mel_spectrogram_torch, spec_to_mel_torch +from .utils.data_utils import TextAudioCollateMultiNSFsid + +logger = getLogger(__name__) + + +class TrainingCheckpoint: + def __init__( + self, + epoch: int, + G: SynthesizerTrnMs768NSFsid, + D: MultiPeriodDiscriminator, + optimizer_G: torch.optim.AdamW, + optimizer_D: torch.optim.AdamW, + scheduler_G: torch.optim.lr_scheduler.ExponentialLR, + scheduler_D: torch.optim.lr_scheduler.ExponentialLR, + loss_gen: float, + loss_fm: float, + loss_mel: float, + loss_kl: float, + loss_gen_all: float, + loss_disc: float, + ): + self.epoch = epoch + self.G = G + self.D = D + self.optimizer_G = optimizer_G + self.optimizer_D = optimizer_D + self.scheduler_G = scheduler_G + self.scheduler_D = scheduler_D + self.loss_gen = loss_gen + self.loss_fm = loss_fm + self.loss_mel = loss_mel + self.loss_kl = loss_kl + self.loss_gen_all = loss_gen_all + self.loss_disc = loss_disc + + def save( + self, + exp_dir="./", + g_checkpoint: str | None = None, + d_checkpoint: str | None = None, + ): + g_path = g_checkpoint if g_checkpoint is not None else f"G_latest.pth" + d_path = d_checkpoint if d_checkpoint is not None else f"D_latest.pth" + torch.save( + { + "epoch": self.epoch, + "model": self.G.state_dict(), + "optimizer": self.optimizer_G.state_dict(), + "scheduler": self.scheduler_G.state_dict(), + "loss_gen": self.loss_gen, + "loss_fm": self.loss_fm, + "loss_mel": self.loss_mel, + "loss_kl": self.loss_kl, + "loss_gen_all": self.loss_gen_all, + "loss_disc": self.loss_disc, + }, + os.path.join(exp_dir, g_path), + ) + torch.save( + { + "epoch": self.epoch, + "model": self.D.state_dict(), + "optimizer": self.optimizer_D.state_dict(), + "scheduler": self.scheduler_D.state_dict(), + }, + os.path.join(exp_dir, d_path), + ) + + +def latest_checkpoint_file(files: list[str]) -> str: + try: + return max(files, key=lambda x: int(Path(x).stem.split("_")[1])) + except: + return max(files, key=os.path.getctime) + + +class RVCTrainer: + def __init__( + self, + exp_dir: str, + dataset_train: Dataset, + dataset_test: Optional[Dataset] = None, + sr: int = SR_48K, + ): + self.exp_dir = exp_dir + self.dataset_train = dataset_train + self.dataset_test = dataset_test + self.sr = sr + self.writer = SummaryWriter( + os.path.join(exp_dir, "logs", time.strftime("%Y%m%d-%H%M%S")) + ) + + def latest_checkpoint(self, fallback_to_pretrained: bool = True): + files_g = glob(os.path.join(self.exp_dir, "G_*.pth")) + if not files_g: + return pretrained_checkpoints() if fallback_to_pretrained else None + latest_g = latest_checkpoint_file(files_g) + + files_d = glob(os.path.join(self.exp_dir, "D_*.pth")) + if not files_d: + return pretrained_checkpoints() if fallback_to_pretrained else None + latest_d = latest_checkpoint_file(files_d) + + return latest_g, latest_d + + def setup_models( + self, + resume_from: Tuple[str, str] | None = None, + accelerator: Accelerator | None = None, + lr=1e-4, + lr_decay=0.999875, + betas: Tuple[float, float] = (0.8, 0.99), + eps=1e-9, + use_spectral_norm=False, + segment_size=17280, + filter_length=N_FFT, + hop_length=HOP_LENGTH, + inter_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.0, + resblock: Literal["1", "2"] = "1", + resblock_kernel_sizes: list[int] = [3, 7, 11], + resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=512, + upsample_rates: list[int] = [12, 10, 2, 2], + upsample_kernel_sizes: list[int] = [24, 20, 4, 4], + spk_embed_dim=109, + gin_channels=256, + ) -> Tuple[ + SynthesizerTrnMs768NSFsid, + MultiPeriodDiscriminator, + torch.optim.AdamW, + torch.optim.AdamW, + torch.optim.lr_scheduler.ExponentialLR, + torch.optim.lr_scheduler.ExponentialLR, + int, + ]: + if accelerator is None: + accelerator = Accelerator() + + G = SynthesizerTrnMs768NSFsid( + spec_channels=filter_length // 2 + 1, + segment_size=segment_size // hop_length, + inter_channels=inter_channels, + hidden_channels=hidden_channels, + filter_channels=filter_channels, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout, + resblock=resblock, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + upsample_initial_channel=upsample_initial_channel, + upsample_rates=upsample_rates, + upsample_kernel_sizes=upsample_kernel_sizes, + spk_embed_dim=spk_embed_dim, + gin_channels=gin_channels, + sr=self.sr, + ).to(accelerator.device) + D = MultiPeriodDiscriminator(use_spectral_norm=use_spectral_norm).to( + accelerator.device + ) + + optimizer_G = torch.optim.AdamW( + G.parameters(), + lr, + betas=betas, + eps=eps, + ) + optimizer_D = torch.optim.AdamW( + D.parameters(), + lr, + betas=betas, + eps=eps, + ) + + if resume_from is not None: + g_checkpoint, d_checkpoint = resume_from + logger.info(f"Resuming from {g_checkpoint} and {d_checkpoint}") + + G_checkpoint = torch.load( + g_checkpoint, map_location=accelerator.device, weights_only=True + ) + D_checkpoint = torch.load( + d_checkpoint, map_location=accelerator.device, weights_only=True + ) + + if "epoch" in G_checkpoint: + finished_epoch = int(G_checkpoint["epoch"]) + try: + finished_epoch = int(Path(g_checkpoint).stem.split("_")[1]) + except: + finished_epoch = 0 + + scheduler_G = torch.optim.lr_scheduler.ExponentialLR( + optimizer_G, gamma=lr_decay, last_epoch=finished_epoch - 1 + ) + scheduler_D = torch.optim.lr_scheduler.ExponentialLR( + optimizer_D, gamma=lr_decay, last_epoch=finished_epoch - 1 + ) + + G.load_state_dict(G_checkpoint["model"]) + if "optimizer" in G_checkpoint: + optimizer_G.load_state_dict(G_checkpoint["optimizer"]) + if "scheduler" in G_checkpoint: + scheduler_G.load_state_dict(G_checkpoint["scheduler"]) + + D.load_state_dict(D_checkpoint["model"]) + if "optimizer" in D_checkpoint: + optimizer_D.load_state_dict(D_checkpoint["optimizer"]) + if "scheduler" in D_checkpoint: + scheduler_D.load_state_dict(D_checkpoint["scheduler"]) + else: + finished_epoch = 0 + scheduler_G = torch.optim.lr_scheduler.ExponentialLR( + optimizer_G, gamma=lr_decay, last_epoch=-1 + ) + scheduler_D = torch.optim.lr_scheduler.ExponentialLR( + optimizer_D, gamma=lr_decay, last_epoch=-1 + ) + + G, D, optimizer_G, optimizer_D = accelerator.prepare( + G, D, optimizer_G, optimizer_D + ) + + return G, D, optimizer_G, optimizer_D, scheduler_G, scheduler_D, finished_epoch + + def setup_dataloader( + self, + dataset: Dataset, + batch_size=1, + shuffle=True, + accelerator: Accelerator | None = None, + ): + if accelerator is None: + accelerator = Accelerator() + + dataset = dataset.with_format("torch", device=accelerator.device) + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + collate_fn=TextAudioCollateMultiNSFsid(), + ) + loader = accelerator.prepare(loader) + return loader + + def run( + self, + G, + D, + optimizer_G, + optimizer_D, + scheduler_G, + scheduler_D, + finished_epoch, + loader_train, + loader_test, + accelerator: Accelerator | None = None, + epochs=100, + segment_size=17280, + filter_length=N_FFT, + hop_length=HOP_LENGTH, + n_mel_channels=N_MELS, + win_length=WIN_LENGTH, + mel_fmin=0.0, + mel_fmax: float | None = None, + c_mel=45, + c_kl=1.0, + upload_to_hub: str | None = None, + upload_window_minutes=5, + ): + if accelerator is None: + accelerator = Accelerator() + + if accelerator.is_main_process: + logger.info("Start training") + + upload_state_last = 0.0 + + prev_loss_gen = -1.0 + prev_loss_fm = -1.0 + prev_loss_mel = -1.0 + prev_loss_kl = -1.0 + prev_loss_disc = -1.0 + prev_loss_gen_all = -1.0 + + with accelerator.autocast(): + epoch_iterator = tqdm( + range(1, epochs + 1), + desc="Training", + disable=not accelerator.is_main_process, + ) + for epoch in epoch_iterator: + if epoch <= finished_epoch: + continue + + G.train() + D.train() + + epoch_loss_gen = 0.0 + epoch_loss_fm = 0.0 + epoch_loss_mel = 0.0 + epoch_loss_kl = 0.0 + epoch_loss_disc = 0.0 + epoch_loss_gen_all = 0.0 + num_batches = 0 + + batch_iterator = tqdm( + loader_train, + desc=f"Epoch {epoch}", + leave=False, + disable=not accelerator.is_main_process, + ) + for batch in batch_iterator: + ( + phone, + phone_lengths, + pitch, + pitchf, + spec, + spec_lengths, + wave, + wave_lengths, + sid, + ) = batch + + # Generator + optimizer_G.zero_grad() + ( + y_hat, + ids_slice, + x_mask, + z_mask, + (z, z_p, m_p, logs_p, m_q, logs_q), + ) = G( + phone, + phone_lengths, + pitch, + pitchf, + spec, + spec_lengths, + sid, + ) + mel = spec_to_mel_torch( + spec, + filter_length, + n_mel_channels, + self.sr, + mel_fmin, + mel_fmax, + ) + y_mel = commons.slice_segments( + mel, ids_slice, segment_size // hop_length + ) + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + filter_length, + n_mel_channels, + self.sr, + hop_length, + win_length, + mel_fmin, + mel_fmax, + ) + wave = commons.slice_segments( + wave, ids_slice * hop_length, segment_size + ) + + # Discriminator + optimizer_D.zero_grad() + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat.detach()) + + # Update Discriminator + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss( + y_d_hat_r, y_d_hat_g + ) + accelerator.backward(loss_disc) + optimizer_D.step() + + # Re-compute discriminator output (since we just got a "better" discriminator) + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = D(wave, y_hat) + + # Update Generator + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_mel = F.l1_loss(y_mel, y_hat_mel) * c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * c_kl + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + accelerator.backward(loss_gen_all) + optimizer_G.step() + + prev_loss_gen = loss_gen.item() + prev_loss_fm = loss_fm.item() + prev_loss_mel = loss_mel.item() + prev_loss_kl = loss_kl.item() + prev_loss_disc = loss_disc.item() + prev_loss_gen_all = loss_gen_all.item() + + # Update progress bar with current losses + if accelerator.is_main_process: + batch_iterator.set_postfix( + { + "g_loss": f"{prev_loss_gen:.4f}", + "d_loss": f"{prev_loss_disc:.4f}", + "mel_loss": f"{prev_loss_mel:.4f}", + "total": f"{prev_loss_gen_all:.4f}", + } + ) + + epoch_loss_gen += prev_loss_gen + epoch_loss_fm += prev_loss_fm + epoch_loss_mel += prev_loss_mel + epoch_loss_kl += prev_loss_kl + epoch_loss_disc += prev_loss_disc + epoch_loss_gen_all += prev_loss_gen_all + num_batches += 1 + + scheduler_G.step() + scheduler_D.step() + + if accelerator.is_main_process and num_batches > 0: + avg_gen = epoch_loss_gen / num_batches + avg_disc = epoch_loss_disc / num_batches + avg_fm = epoch_loss_fm / num_batches + avg_mel = epoch_loss_mel / num_batches + avg_kl = epoch_loss_kl / num_batches + avg_total = epoch_loss_gen_all / num_batches + + logger.info( + f"Epoch {epoch} | " + f"Generator Loss: {avg_gen:.4f} | " + f"Discriminator Loss: {avg_disc:.4f} | " + f"Mel Loss: {avg_mel:.4f} | " + f"Total Loss: {avg_total:.4f}" + ) + + # Update epoch progress bar + epoch_iterator.set_postfix( + { + "g_loss": f"{avg_gen:.4f}", + "d_loss": f"{avg_disc:.4f}", + "total": f"{avg_total:.4f}", + } + ) + + self.writer.add_scalar("Loss/Generator", avg_gen, epoch) + self.writer.add_scalar("Loss/Feature_Matching", avg_fm, epoch) + self.writer.add_scalar("Loss/Mel", avg_mel, epoch) + self.writer.add_scalar("Loss/KL", avg_kl, epoch) + self.writer.add_scalar("Loss/Discriminator", avg_disc, epoch) + self.writer.add_scalar("Loss/Generator_Total", avg_total, epoch) + self.writer.add_scalar( + "Learning_Rate/Generator", + scheduler_G.get_last_lr()[0], + epoch, + ) + self.writer.add_scalar( + "Learning_Rate/Discriminator", + scheduler_D.get_last_lr()[0], + epoch, + ) + + if loader_test is not None: + with torch.no_grad(): + sample_idx = 0 + test_iterator = tqdm( + loader_test, + desc=f"Testing epoch {epoch}", + leave=False, + disable=not accelerator.is_main_process, + ) + for batch_idx, ( + phone, + phone_lengths, + pitch, + pitchf, + spec, + spec_lengths, + wave, + wave_lengths, + sid, + ) in enumerate(test_iterator): + # Generate audio for each sample in the batch + audio_segments = G.infer( + phone, phone_lengths, pitch, pitchf, sid + )[0] + + # Log each audio sample in the batch + for i, audio in enumerate(audio_segments): + audio_numpy = audio[0].data.cpu().float().numpy() + self.writer.add_audio( + f"Audio/{sample_idx}", + audio_numpy, + epoch, + sample_rate=self.sr, + ) + sample_idx += 1 + + res = TrainingCheckpoint( + epoch, + G, + D, + optimizer_G, + optimizer_D, + scheduler_G, + scheduler_D, + prev_loss_gen, + prev_loss_fm, + prev_loss_mel, + prev_loss_kl, + prev_loss_gen_all, + prev_loss_disc, + ) + + res.save(self.exp_dir) + G.save_pretrained(self.exp_dir) + + if upload_to_hub is not None: + if ( + time.time() - upload_state_last > 60 * upload_window_minutes + or epoch == epochs + ): + try: + self.push_to_hub(upload_to_hub) + upload_state_last = time.time() + except Exception: + logger.error(f"Failed to upload to Hub.", exc_info=1) + else: + next_upload = 60 * upload_window_minutes - ( + time.time() - upload_state_last + ) + logger.info( + f"Skipping upload to Hub (next upload in {next_upload:.0f} seconds)" + ) + + def train( + self, + resume_from: Tuple[str, str] | None = None, + accelerator: Accelerator | None = None, + batch_size=1, + epochs=100, + lr=1e-4, + lr_decay=0.999875, + betas: Tuple[float, float] = (0.8, 0.99), + eps=1e-9, + use_spectral_norm=False, + segment_size=17280, + filter_length=N_FFT, + hop_length=HOP_LENGTH, + inter_channels=192, + hidden_channels=192, + filter_channels=768, + n_heads=2, + n_layers=6, + kernel_size=3, + p_dropout=0.0, + resblock: Literal["1", "2"] = "1", + resblock_kernel_sizes: list[int] = [3, 7, 11], + resblock_dilation_sizes: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + upsample_initial_channel=512, + upsample_rates: list[int] = [12, 10, 2, 2], + upsample_kernel_sizes: list[int] = [24, 20, 4, 4], + spk_embed_dim=109, + gin_channels=256, + n_mel_channels=N_MELS, + win_length=WIN_LENGTH, + mel_fmin=0.0, + mel_fmax: float | None = None, + c_mel=45, + c_kl=1.0, + upload_to_hub: str | None = None, + ): + if not os.path.exists(self.exp_dir): + os.makedirs(self.exp_dir) + + if accelerator is None: + accelerator = Accelerator() + + ( + G, + D, + optimizer_G, + optimizer_D, + scheduler_G, + scheduler_D, + finished_epoch, + ) = self.setup_models( + resume_from=resume_from or self.latest_checkpoint(), + accelerator=accelerator, + lr=lr, + lr_decay=lr_decay, + betas=betas, + eps=eps, + use_spectral_norm=use_spectral_norm, + segment_size=segment_size, + filter_length=filter_length, + hop_length=hop_length, + inter_channels=inter_channels, + hidden_channels=hidden_channels, + filter_channels=filter_channels, + n_heads=n_heads, + n_layers=n_layers, + kernel_size=kernel_size, + p_dropout=p_dropout, + resblock=resblock, + resblock_kernel_sizes=resblock_kernel_sizes, + resblock_dilation_sizes=resblock_dilation_sizes, + upsample_initial_channel=upsample_initial_channel, + upsample_rates=upsample_rates, + upsample_kernel_sizes=upsample_kernel_sizes, + spk_embed_dim=spk_embed_dim, + gin_channels=gin_channels, + ) + + loader_train = self.setup_dataloader( + self.dataset_train, + batch_size=batch_size, + accelerator=accelerator, + ) + + loader_test = ( + self.setup_dataloader( + self.dataset_test, + batch_size=batch_size, + accelerator=accelerator, + shuffle=False, + ) + if self.dataset_test is not None + else None + ) + + return self.run( + G, + D, + optimizer_G, + optimizer_D, + scheduler_G, + scheduler_D, + finished_epoch, + loader_train, + loader_test, + accelerator, + epochs=epochs, + segment_size=segment_size, + filter_length=filter_length, + hop_length=hop_length, + n_mel_channels=n_mel_channels, + win_length=win_length, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + c_mel=c_mel, + c_kl=c_kl, + upload_to_hub=upload_to_hub, + ) + + def push_to_hub(self, repo: str, private: bool = True): + if not os.path.exists(self.exp_dir): + raise FileNotFoundError("exp_dir not found") + + api = HfApi() + repo_id = api.create_repo(repo_id=repo, private=private, exist_ok=True).repo_id + + return upload_folder( + repo_id=repo_id, + folder_path=self.exp_dir, + commit_message="Upload via ZeroRVC", + ) + + def __del__(self): + if hasattr(self, "writer"): + self.writer.close() diff --git a/zerorvc/utils/__init__.py b/zerorvc/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c1c353ce548d21eedfad46c8be289716d4bcccd --- /dev/null +++ b/zerorvc/utils/__init__.py @@ -0,0 +1,3 @@ +from .data_utils import * +from .mel_processing import * +from .losses import * diff --git a/zerorvc/utils/data_utils.py b/zerorvc/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e874262448c3108b92177040afc1a26995eb7d6f --- /dev/null +++ b/zerorvc/utils/data_utils.py @@ -0,0 +1,85 @@ +import logging + +import torch +import torch.utils.data + +logger = logging.getLogger(__name__) + + +class TextAudioCollateMultiNSFsid: + """Zero-pads model inputs and targets""" + + def __init__(self): + pass + + def __call__(self, batch): + """Collate's training batch from normalized text and aduio + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized] + """ + device = batch[0]["spec"].device + + with device: + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.tensor([x["spec"].size(1) for x in batch], dtype=torch.long), + dim=0, + descending=True, + ) + + max_spec_len = max([x["spec"].size(1) for x in batch]) + max_wave_len = max([x["wav_gt"]["array"].size(0) for x in batch]) + spec_lengths = torch.zeros(len(batch), dtype=torch.long) + wave_lengths = torch.zeros(len(batch), dtype=torch.long) + spec_padded = torch.zeros( + len(batch), batch[0]["spec"].size(0), max_spec_len, dtype=torch.float32 + ) + wave_padded = torch.zeros(len(batch), 1, max_wave_len, dtype=torch.float32) + + max_phone_len = max([x["hubert_feats"].size(0) for x in batch]) + phone_lengths = torch.zeros(len(batch), dtype=torch.long) + phone_padded = torch.zeros( + len(batch), + max_phone_len, + batch[0]["hubert_feats"].shape[1], + dtype=torch.float32, + ) # (spec, wav, phone, pitch) + pitch_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.long) + pitchf_padded = torch.zeros(len(batch), max_phone_len, dtype=torch.float32) + # dv = torch.FloatTensor(len(batch), 256)#gin=256 + sid = torch.zeros(len(batch), dtype=torch.long) + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + spec = row["spec"] + spec_padded[i, :, : spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wave = row["wav_gt"]["array"] + wave_padded[i, :, : wave.size(0)] = wave + wave_lengths[i] = wave.size(0) + + phone = row["hubert_feats"] + phone_padded[i, : phone.size(0), :] = phone + phone_lengths[i] = phone.size(0) + + pitch = row["f0"] + pitch_padded[i, : pitch.size(0)] = pitch + pitchf = row["f0nsf"] + pitchf_padded[i, : pitchf.size(0)] = pitchf + + sid[i] = torch.tensor([0], dtype=torch.long) + + return ( + phone_padded, + phone_lengths, + pitch_padded, + pitchf_padded, + spec_padded, + spec_lengths, + wave_padded, + wave_lengths, + sid, + ) diff --git a/zerorvc/utils/losses.py b/zerorvc/utils/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..9d369d130162ed01c6c5e67a9c0275042156c494 --- /dev/null +++ b/zerorvc/utils/losses.py @@ -0,0 +1,66 @@ +import torch + + +def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + rl = rl.float().detach() + gl = gl.float() + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss( + disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor] +): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + dr = dr.float() + dg = dg.float() + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + return loss, r_losses, g_losses + + +def generator_loss(disc_outputs: list[torch.Tensor]): + loss = 0 + gen_losses = [] + for dg in disc_outputs: + dg = dg.float() + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + + return loss, gen_losses + + +def kl_loss( + z_p: torch.Tensor, + logs_q: torch.Tensor, + m_p: torch.Tensor, + logs_p: torch.Tensor, + z_mask: torch.Tensor, +): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/zerorvc/utils/mel_processing.py b/zerorvc/utils/mel_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..e7e35f7cbb571e55a85cabf72760c04f33d7d17e --- /dev/null +++ b/zerorvc/utils/mel_processing.py @@ -0,0 +1,127 @@ +import torch +import torch.utils.data +import librosa +import logging + +logger = logging.getLogger(__name__) + +MAX_WAV_VALUE = 32768.0 + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + return dynamic_range_compression_torch(magnitudes) + + +def spectral_de_normalize_torch(magnitudes): + return dynamic_range_decompression_torch(magnitudes) + + +# Reusable banks +mel_basis = {} +hann_window = {} + + +def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): + """Convert waveform into Linear-frequency Linear-amplitude spectrogram. + + Args: + y :: (B, T) - Audio waveforms + n_fft + sampling_rate + hop_size + win_size + center + Returns: + :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram + """ + + # Window - Cache if needed + global hann_window + dtype_device = str(y.dtype) + "_" + str(y.device) + wnsize_dtype_device = str(win_size) + "_" + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to( + dtype=y.dtype, device=y.device + ) + + # Padding + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # Complex Spectrogram :: (B, T) -> (B, Freq, Frame, RealComplex=2) + spec = torch.stft( + y, + n_fft, + hop_length=hop_size, + win_length=win_size, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + # Linear-frequency Linear-amplitude spectrogram :: (B, Freq, Frame, RealComplex=2) -> (B, Freq, Frame) + spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6) + return spec + + +def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + # MelBasis - Cache if needed + global mel_basis + dtype_device = str(spec.dtype) + "_" + str(spec.device) + fmax_dtype_device = str(fmax) + "_" + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa.filters.mel( + sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax + ) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to( + dtype=spec.dtype, device=spec.device + ) + + # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame) + melspec = torch.matmul(mel_basis[fmax_dtype_device], spec) + melspec = spectral_normalize_torch(melspec) + return melspec + + +def mel_spectrogram_torch( + y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + """Convert waveform into Mel-frequency Log-amplitude spectrogram. + + Args: + y :: (B, T) - Waveforms + Returns: + melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram + """ + # Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame) + spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center) + + # Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame) + melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax) + + return melspec