v2.3
Browse files- .gitignore +136 -5
- README_model.md +146 -53
- {alias_free_cuda β alias_free_activation/cuda}/__init__.py +0 -0
- {alias_free_cuda β alias_free_activation/cuda}/activation1d.py +37 -23
- {alias_free_cuda β alias_free_activation/cuda}/anti_alias_activation.cpp +4 -29
- alias_free_activation/cuda/anti_alias_activation_cuda.cu +246 -0
- {alias_free_cuda β alias_free_activation/cuda}/compat.h +0 -2
- {alias_free_cuda β alias_free_activation/cuda}/load.py +40 -26
- alias_free_activation/cuda/type_shim.h +92 -0
- {alias_free_torch β alias_free_activation/torch}/__init__.py +1 -1
- {alias_free_torch β alias_free_activation/torch}/act.py +10 -8
- {alias_free_torch β alias_free_activation/torch}/filter.py +37 -31
- {alias_free_torch β alias_free_activation/torch}/resample.py +25 -16
- alias_free_cuda/anti_alias_activation_cuda.cu +0 -314
- alias_free_cuda/test_activation.py +0 -55
- alias_free_cuda/test_activation_snake_beta.py +0 -55
- alias_free_cuda/type_shim.h +0 -97
- app.py +1 -1
- bigvgan.py +286 -160
- meldataset.py +318 -30
- utils.py +34 -15
.gitignore
CHANGED
@@ -1,6 +1,137 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
*/__pycache__/
|
4 |
-
alias_free_cuda/build/
|
5 |
exp/
|
6 |
-
tmp/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BigVGAN
|
2 |
+
alias_free_activation/cuda/build/
|
|
|
|
|
3 |
exp/
|
4 |
+
tmp/
|
5 |
+
|
6 |
+
# VSCode configs
|
7 |
+
.vscode/
|
8 |
+
|
9 |
+
# Byte-compiled / optimized / DLL files
|
10 |
+
__pycache__/
|
11 |
+
*.py[cod]
|
12 |
+
*$py.class
|
13 |
+
|
14 |
+
# C extensions
|
15 |
+
*.so
|
16 |
+
|
17 |
+
# Distribution / packaging
|
18 |
+
.Python
|
19 |
+
build/
|
20 |
+
develop-eggs/
|
21 |
+
dist/
|
22 |
+
downloads/
|
23 |
+
eggs/
|
24 |
+
.eggs/
|
25 |
+
lib/
|
26 |
+
lib64/
|
27 |
+
parts/
|
28 |
+
sdist/
|
29 |
+
var/
|
30 |
+
wheels/
|
31 |
+
share/python-wheels/
|
32 |
+
*.egg-info/
|
33 |
+
.installed.cfg
|
34 |
+
*.egg
|
35 |
+
MANIFEST
|
36 |
+
|
37 |
+
# PyInstaller
|
38 |
+
# Usually these files are written by a python script from a template
|
39 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
40 |
+
*.manifest
|
41 |
+
*.spec
|
42 |
+
|
43 |
+
# Installer logs
|
44 |
+
pip-log.txt
|
45 |
+
pip-delete-this-directory.txt
|
46 |
+
|
47 |
+
# Unit test / coverage reports
|
48 |
+
htmlcov/
|
49 |
+
.tox/
|
50 |
+
.nox/
|
51 |
+
.coverage
|
52 |
+
.coverage.*
|
53 |
+
.cache
|
54 |
+
nosetests.xml
|
55 |
+
coverage.xml
|
56 |
+
*.cover
|
57 |
+
*.py,cover
|
58 |
+
.hypothesis/
|
59 |
+
.pytest_cache/
|
60 |
+
cover/
|
61 |
+
|
62 |
+
# Translations
|
63 |
+
*.mo
|
64 |
+
*.pot
|
65 |
+
|
66 |
+
# Django stuff:
|
67 |
+
*.log
|
68 |
+
local_settings.py
|
69 |
+
db.sqlite3
|
70 |
+
db.sqlite3-journal
|
71 |
+
|
72 |
+
# Flask stuff:
|
73 |
+
instance/
|
74 |
+
.webassets-cache
|
75 |
+
|
76 |
+
# Scrapy stuff:
|
77 |
+
.scrapy
|
78 |
+
|
79 |
+
# Sphinx documentation
|
80 |
+
docs/_build/
|
81 |
+
|
82 |
+
# PyBuilder
|
83 |
+
.pybuilder/
|
84 |
+
target/
|
85 |
+
|
86 |
+
# Jupyter Notebook
|
87 |
+
.ipynb_checkpoints
|
88 |
+
|
89 |
+
# IPython
|
90 |
+
profile_default/
|
91 |
+
ipython_config.py
|
92 |
+
|
93 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
94 |
+
__pypackages__/
|
95 |
+
|
96 |
+
# Celery stuff
|
97 |
+
celerybeat-schedule
|
98 |
+
celerybeat.pid
|
99 |
+
|
100 |
+
# SageMath parsed files
|
101 |
+
*.sage.py
|
102 |
+
|
103 |
+
# Environments
|
104 |
+
.env
|
105 |
+
.venv
|
106 |
+
env/
|
107 |
+
venv/
|
108 |
+
ENV/
|
109 |
+
env.bak/
|
110 |
+
venv.bak/
|
111 |
+
|
112 |
+
# Spyder project settings
|
113 |
+
.spyderproject
|
114 |
+
.spyproject
|
115 |
+
|
116 |
+
# Rope project settings
|
117 |
+
.ropeproject
|
118 |
+
|
119 |
+
# mkdocs documentation
|
120 |
+
/site
|
121 |
+
|
122 |
+
# mypy
|
123 |
+
.mypy_cache/
|
124 |
+
.dmypy.json
|
125 |
+
dmypy.json
|
126 |
+
|
127 |
+
# Pyre type checker
|
128 |
+
.pyre/
|
129 |
+
|
130 |
+
# pytype static type analyzer
|
131 |
+
.pytype/
|
132 |
+
|
133 |
+
# Cython debug symbols
|
134 |
+
cython_debug/
|
135 |
+
|
136 |
+
# PyCharm
|
137 |
+
.idea/
|
README_model.md
CHANGED
@@ -1,37 +1,96 @@
|
|
1 |
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
|
|
2 |
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
|
3 |
|
4 |
-
|
5 |
|
|
|
6 |
|
7 |
-
|
8 |
|
9 |
## News
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
## Installation
|
|
|
17 |
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
|
|
|
18 |
```shell
|
19 |
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
20 |
conda activate bigvgan
|
21 |
```
|
22 |
|
23 |
Clone the repository and install dependencies:
|
|
|
24 |
```shell
|
25 |
git clone https://github.com/NVIDIA/BigVGAN
|
26 |
cd BigVGAN
|
27 |
pip install -r requirements.txt
|
28 |
```
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
|
33 |
-
|
34 |
-
|
|
|
35 |
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
|
36 |
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
|
37 |
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
|
@@ -39,29 +98,30 @@ ln -s /path/to/your/LibriTTS/dev-clean dev-clean && \
|
|
39 |
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
|
40 |
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
|
41 |
ln -s /path/to/your/LibriTTS/test-other test-other && \
|
42 |
-
cd
|
43 |
```
|
44 |
|
45 |
-
## Training
|
46 |
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
|
|
|
47 |
```shell
|
48 |
python train.py \
|
49 |
--config configs/bigvgan_v2_24khz_100band_256x.json \
|
50 |
-
--input_wavs_dir LibriTTS \
|
51 |
-
--input_training_file LibriTTS/train-full.txt \
|
52 |
-
--input_validation_file LibriTTS/val-full.txt \
|
53 |
-
--list_input_unseen_wavs_dir LibriTTS LibriTTS \
|
54 |
-
--list_input_unseen_validation_file LibriTTS/dev-clean.txt LibriTTS/dev-other.txt \
|
55 |
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
|
56 |
```
|
57 |
|
58 |
-
|
59 |
## Synthesis
|
|
|
60 |
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
|
61 |
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
|
|
|
62 |
```shell
|
63 |
python inference.py \
|
64 |
-
--checkpoint_file
|
65 |
--input_wavs_dir /path/to/your/input_wav \
|
66 |
--output_dir /path/to/your/output_wav
|
67 |
```
|
@@ -70,14 +130,16 @@ python inference.py \
|
|
70 |
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
|
71 |
|
72 |
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
|
|
|
73 |
```shell
|
74 |
python inference_e2e.py \
|
75 |
-
--checkpoint_file
|
76 |
--input_mels_dir /path/to/your/input_mel \
|
77 |
--output_dir /path/to/your/output_wav
|
78 |
```
|
79 |
|
80 |
## Using Custom CUDA Kernel for Synthesis
|
|
|
81 |
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
|
82 |
|
83 |
```python
|
@@ -86,15 +148,15 @@ generator = BigVGAN(h, use_cuda_kernel=True)
|
|
86 |
|
87 |
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
|
88 |
|
89 |
-
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `
|
90 |
|
91 |
Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
|
92 |
|
93 |
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
|
94 |
|
95 |
```python
|
96 |
-
python test_cuda_vs_torch_model.py \
|
97 |
-
--checkpoint_file /path/to/your/
|
98 |
```
|
99 |
|
100 |
```shell
|
@@ -102,12 +164,12 @@ loading plain Pytorch BigVGAN
|
|
102 |
...
|
103 |
loading CUDA kernel BigVGAN with auto-build
|
104 |
Detected CUDA files, patching ldflags
|
105 |
-
Emitting ninja build file /path/to/your/BigVGAN/
|
106 |
Building extension module anti_alias_activation_cuda...
|
107 |
...
|
108 |
Loading extension module anti_alias_activation_cuda...
|
109 |
...
|
110 |
-
Loading '/path/to/your/
|
111 |
...
|
112 |
[Success] test CUDA fused vs. plain torch BigVGAN inference
|
113 |
> mean_difference=0.0007238413265440613
|
@@ -116,30 +178,34 @@ Loading '/path/to/your/bigvgan/g_03000000'
|
|
116 |
|
117 |
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
|
118 |
|
119 |
-
|
120 |
## Pretrained Models
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
128 |
-
|
|
129 |
-
|
|
130 |
-
|
|
131 |
-
|
|
132 |
-
|
|
133 |
-
|
|
134 |
-
|
|
|
|
135 |
|
136 |
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
|
137 |
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
|
138 |
-
Note that the checkpoints use
|
|
|
|
|
139 |
|
140 |
-
|
|
|
141 |
|
142 |
## Training Details of BigVGAN-v2
|
|
|
143 |
Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
|
144 |
|
145 |
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
|
@@ -147,23 +213,50 @@ Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` a
|
|
147 |
When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
|
148 |
|
149 |
## Evaluation Results of BigVGAN-v2
|
|
|
150 |
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
|
151 |
|
152 |
-
|Model|Dataset|Steps|PESQ(β)|M-STFT(β)|MCD(β)|Periodicity(β)|V/UV F1(β)|
|
153 |
-
|
154 |
-
|BigVGAN|LibriTTS|1M|4.027|0.7997|0.3745|0.1018|0.9598|
|
155 |
-
|BigVGAN|LibriTTS|5M|4.256|0.7409|0.2988|0.0809|0.9698|
|
156 |
-
|BigVGAN-v2|Large-scale Compilation|3M
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
157 |
|
158 |
## Acknowledgements
|
|
|
159 |
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
|
160 |
|
161 |
## References
|
162 |
-
* [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
|
163 |
-
* [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
|
164 |
-
* [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
|
165 |
-
* [Julius](https://github.com/adefossez/julius) (for low-pass filter)
|
166 |
-
* [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
|
167 |
-
* [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
|
168 |
-
* [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
## BigVGAN: A Universal Neural Vocoder with Large-Scale Training
|
2 |
+
|
3 |
#### Sang-gil Lee, Wei Ping, Boris Ginsburg, Bryan Catanzaro, Sungroh Yoon
|
4 |
|
5 |
+
[[Paper]](https://arxiv.org/abs/2206.04658) - [[Code]](https://github.com/NVIDIA/BigVGAN) - [[Showcase]](https://bigvgan-demo.github.io/) - [[Project Page]](https://research.nvidia.com/labs/adlr/projects/bigvgan/) - [[Weights]](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a) - [[Demo]](https://huggingface.co/spaces/nvidia/BigVGAN)
|
6 |
|
7 |
+
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/bigvgan-a-universal-neural-vocoder-with-large/speech-synthesis-on-libritts)](https://paperswithcode.com/sota/speech-synthesis-on-libritts?p=bigvgan-a-universal-neural-vocoder-with-large)
|
8 |
|
9 |
+
<center><img src="https://user-images.githubusercontent.com/15963413/218609148-881e39df-33af-4af9-ab95-1427c4ebf062.png" width="800"></center>
|
10 |
|
11 |
## News
|
12 |
+
- **Jul 2024 (v2.3):**
|
13 |
+
- General refactor and code improvements for improved readability.
|
14 |
+
- Fully fused CUDA kernel of anti-alised activation (upsampling + activation + downsampling) with inference speed benchmark.
|
15 |
+
|
16 |
+
- **Jul 2024 (v2.2):** The repository now includes an interactive local demo using gradio.
|
17 |
+
|
18 |
+
- **Jul 2024 (v2.1):** BigVGAN is now integrated with π€ Hugging Face Hub with easy access to inference using pretrained checkpoints. We also provide an interactive demo on Hugging Face Spaces.
|
19 |
+
|
20 |
+
- **Jul 2024 (v2):** We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:
|
21 |
+
- Custom CUDA kernel for inference: we provide a fused upsampling + activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.
|
22 |
+
- Improved discriminator and loss: BigVGAN-v2 is trained using a [multi-scale sub-band CQT discriminator](https://arxiv.org/abs/2311.14957) and a [multi-scale mel spectrogram loss](https://arxiv.org/abs/2306.06546).
|
23 |
+
- Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.
|
24 |
+
- We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio.
|
25 |
|
26 |
## Installation
|
27 |
+
|
28 |
The codebase has been tested on Python `3.10` and PyTorch `2.3.1` conda packages with either `pytorch-cuda=12.1` or `pytorch-cuda=11.8`. Below is an example command to create the conda environment:
|
29 |
+
|
30 |
```shell
|
31 |
conda create -n bigvgan python=3.10 pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
32 |
conda activate bigvgan
|
33 |
```
|
34 |
|
35 |
Clone the repository and install dependencies:
|
36 |
+
|
37 |
```shell
|
38 |
git clone https://github.com/NVIDIA/BigVGAN
|
39 |
cd BigVGAN
|
40 |
pip install -r requirements.txt
|
41 |
```
|
42 |
|
43 |
+
## Inference Quickstart using π€ Hugging Face Hub
|
44 |
+
|
45 |
+
Below example describes how you can use BigVGAN: load the pretrained BigVGAN generator from Hugging Face Hub, compute mel spectrogram from input waveform, and generate synthesized waveform using the mel spectrogram as the model's input.
|
46 |
+
|
47 |
+
```python
|
48 |
+
device = 'cuda'
|
49 |
+
|
50 |
+
import torch
|
51 |
+
import bigvgan
|
52 |
+
import librosa
|
53 |
+
from meldataset import get_mel_spectrogram
|
54 |
+
|
55 |
+
# instantiate the model. You can optionally set use_cuda_kernel=True for faster inference.
|
56 |
+
model = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
|
57 |
+
|
58 |
+
# remove weight norm in the model and set to eval mode
|
59 |
+
model.remove_weight_norm()
|
60 |
+
model = model.eval().to(device)
|
61 |
+
|
62 |
+
# load wav file and compute mel spectrogram
|
63 |
+
wav_path = '/path/to/your/audio.wav'
|
64 |
+
wav, sr = librosa.load(wav_path, sr=model.h.sampling_rate, mono=True) # wav is np.ndarray with shape [T_time] and values in [-1, 1]
|
65 |
+
wav = torch.FloatTensor(wav).unsqueeze(0) # wav is FloatTensor with shape [B(1), T_time]
|
66 |
+
|
67 |
+
# compute mel spectrogram from the ground truth audio
|
68 |
+
mel = get_mel_spectrogram(wav, model.h).to(device) # mel is FloatTensor with shape [B(1), C_mel, T_frame]
|
69 |
+
|
70 |
+
# generate waveform from mel
|
71 |
+
with torch.inference_mode():
|
72 |
+
wav_gen = model(mel) # wav_gen is FloatTensor with shape [B(1), 1, T_time] and values in [-1, 1]
|
73 |
+
wav_gen_float = wav_gen.squeeze(0).cpu() # wav_gen is FloatTensor with shape [1, T_time]
|
74 |
+
|
75 |
+
# you can convert the generated waveform to 16 bit linear PCM
|
76 |
+
wav_gen_int16 = (wav_gen_float * 32767.0).numpy().astype('int16') # wav_gen is now np.ndarray with shape [1, T_time] and int16 dtype
|
77 |
+
```
|
78 |
+
|
79 |
+
## Local gradio demo <a href='https://github.com/gradio-app/gradio'><img src='https://img.shields.io/github/stars/gradio-app/gradio'></a>
|
80 |
|
81 |
+
You can run a local gradio demo using below command:
|
82 |
+
|
83 |
+
```python
|
84 |
+
pip install -r demo/requirements.txt
|
85 |
+
python demo/app.py
|
86 |
+
```
|
87 |
+
|
88 |
+
## Training
|
89 |
|
90 |
Create symbolic link to the root of the dataset. The codebase uses filelist with the relative path from the dataset. Below are the example commands for LibriTTS dataset:
|
91 |
+
|
92 |
+
```shell
|
93 |
+
cd filelists/LibriTTS && \
|
94 |
ln -s /path/to/your/LibriTTS/train-clean-100 train-clean-100 && \
|
95 |
ln -s /path/to/your/LibriTTS/train-clean-360 train-clean-360 && \
|
96 |
ln -s /path/to/your/LibriTTS/train-other-500 train-other-500 && \
|
|
|
98 |
ln -s /path/to/your/LibriTTS/dev-other dev-other && \
|
99 |
ln -s /path/to/your/LibriTTS/test-clean test-clean && \
|
100 |
ln -s /path/to/your/LibriTTS/test-other test-other && \
|
101 |
+
cd ../..
|
102 |
```
|
103 |
|
|
|
104 |
Train BigVGAN model. Below is an example command for training BigVGAN-v2 using LibriTTS dataset at 24kHz with a full 100-band mel spectrogram as input:
|
105 |
+
|
106 |
```shell
|
107 |
python train.py \
|
108 |
--config configs/bigvgan_v2_24khz_100band_256x.json \
|
109 |
+
--input_wavs_dir filelists/LibriTTS \
|
110 |
+
--input_training_file filelists/LibriTTS/train-full.txt \
|
111 |
+
--input_validation_file filelists/LibriTTS/val-full.txt \
|
112 |
+
--list_input_unseen_wavs_dir filelists/LibriTTS filelists/LibriTTS \
|
113 |
+
--list_input_unseen_validation_file filelists/LibriTTS/dev-clean.txt filelists/LibriTTS/dev-other.txt \
|
114 |
--checkpoint_path exp/bigvgan_v2_24khz_100band_256x
|
115 |
```
|
116 |
|
|
|
117 |
## Synthesis
|
118 |
+
|
119 |
Synthesize from BigVGAN model. Below is an example command for generating audio from the model.
|
120 |
It computes mel spectrograms using wav files from `--input_wavs_dir` and saves the generated audio to `--output_dir`.
|
121 |
+
|
122 |
```shell
|
123 |
python inference.py \
|
124 |
+
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
125 |
--input_wavs_dir /path/to/your/input_wav \
|
126 |
--output_dir /path/to/your/output_wav
|
127 |
```
|
|
|
130 |
It loads mel spectrograms from `--input_mels_dir` and saves the generated audio to `--output_dir`.
|
131 |
|
132 |
Make sure that the STFT hyperparameters for mel spectrogram are the same as the model, which are defined in `config.json` of the corresponding model.
|
133 |
+
|
134 |
```shell
|
135 |
python inference_e2e.py \
|
136 |
+
--checkpoint_file /path/to/your/bigvgan_v2_24khz_100band_256x/bigvgan_generator.pt \
|
137 |
--input_mels_dir /path/to/your/input_mel \
|
138 |
--output_dir /path/to/your/output_wav
|
139 |
```
|
140 |
|
141 |
## Using Custom CUDA Kernel for Synthesis
|
142 |
+
|
143 |
You can apply the fast CUDA inference kernel by using a parameter `use_cuda_kernel` when instantiating BigVGAN:
|
144 |
|
145 |
```python
|
|
|
148 |
|
149 |
You can also pass `--use_cuda_kernel` to `inference.py` and `inference_e2e.py` to enable this feature.
|
150 |
|
151 |
+
When applied for the first time, it builds the kernel using `nvcc` and `ninja`. If the build succeeds, the kernel is saved to `alias_free_activation/cuda/build` and the model automatically loads the kernel. The codebase has been tested using CUDA `12.1`.
|
152 |
|
153 |
Please make sure that both are installed in your system and `nvcc` installed in your system matches the version your PyTorch build is using.
|
154 |
|
155 |
We recommend running `test_cuda_vs_torch_model.py` first to build and check the correctness of the CUDA kernel. See below example command and its output, where it returns `[Success] test CUDA fused vs. plain torch BigVGAN inference`:
|
156 |
|
157 |
```python
|
158 |
+
python tests/test_cuda_vs_torch_model.py \
|
159 |
+
--checkpoint_file /path/to/your/bigvgan_generator.pt
|
160 |
```
|
161 |
|
162 |
```shell
|
|
|
164 |
...
|
165 |
loading CUDA kernel BigVGAN with auto-build
|
166 |
Detected CUDA files, patching ldflags
|
167 |
+
Emitting ninja build file /path/to/your/BigVGAN/alias_free_activation/cuda/build/build.ninja..
|
168 |
Building extension module anti_alias_activation_cuda...
|
169 |
...
|
170 |
Loading extension module anti_alias_activation_cuda...
|
171 |
...
|
172 |
+
Loading '/path/to/your/bigvgan_generator.pt'
|
173 |
...
|
174 |
[Success] test CUDA fused vs. plain torch BigVGAN inference
|
175 |
> mean_difference=0.0007238413265440613
|
|
|
178 |
|
179 |
If you see `[Fail] test CUDA fused vs. plain torch BigVGAN inference`, it means that the CUDA kernel inference is incorrect. Please check if `nvcc` installed in your system is compatible with your PyTorch version.
|
180 |
|
|
|
181 |
## Pretrained Models
|
182 |
+
|
183 |
+
We provide the [pretrained models on Hugging Face Collections](https://huggingface.co/collections/nvidia/bigvgan-66959df3d97fd7d98d97dc9a).
|
184 |
+
One can download the checkpoints of the generator weight (named `bigvgan_generator.pt`) and its discriminator/optimizer states (named `bigvgan_discriminator_optimizer.pt`) within the listed model repositories.
|
185 |
+
|
186 |
+
| Model Name | Sampling Rate | Mel band | fmax | Upsampling Ratio | Params | Dataset | Steps | Fine-Tuned |
|
187 |
+
|:--------------------------------------------------------------------------------------------------------:|:-------------:|:--------:|:-----:|:----------------:|:------:|:--------------------------:|:-----:|:----------:|
|
188 |
+
| [bigvgan_v2_44khz_128band_512x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_512x) | 44 kHz | 128 | 22050 | 512 | 122M | Large-scale Compilation | 3M | No |
|
189 |
+
| [bigvgan_v2_44khz_128band_256x](https://huggingface.co/nvidia/bigvgan_v2_44khz_128band_256x) | 44 kHz | 128 | 22050 | 256 | 112M | Large-scale Compilation | 3M | No |
|
190 |
+
| [bigvgan_v2_24khz_100band_256x](https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x) | 24 kHz | 100 | 12000 | 256 | 112M | Large-scale Compilation | 3M | No |
|
191 |
+
| [bigvgan_v2_22khz_80band_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_256x) | 22 kHz | 80 | 11025 | 256 | 112M | Large-scale Compilation | 3M | No |
|
192 |
+
| [bigvgan_v2_22khz_80band_fmax8k_256x](https://huggingface.co/nvidia/bigvgan_v2_22khz_80band_fmax8k_256x) | 22 kHz | 80 | 8000 | 256 | 112M | Large-scale Compilation | 3M | No |
|
193 |
+
| [bigvgan_24khz_100band](https://huggingface.co/nvidia/bigvgan_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 112M | LibriTTS | 5M | No |
|
194 |
+
| [bigvgan_base_24khz_100band](https://huggingface.co/nvidia/bigvgan_base_24khz_100band) | 24 kHz | 100 | 12000 | 256 | 14M | LibriTTS | 5M | No |
|
195 |
+
| [bigvgan_22khz_80band](https://huggingface.co/nvidia/bigvgan_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 112M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
196 |
+
| [bigvgan_base_22khz_80band](https://huggingface.co/nvidia/bigvgan_base_22khz_80band) | 22 kHz | 80 | 8000 | 256 | 14M | LibriTTS + VCTK + LJSpeech | 5M | No |
|
197 |
|
198 |
The paper results are based on the original 24kHz BigVGAN models (`bigvgan_24khz_100band` and `bigvgan_base_24khz_100band`) trained on LibriTTS dataset.
|
199 |
We also provide 22kHz BigVGAN models with band-limited setup (i.e., fmax=8000) for TTS applications.
|
200 |
+
Note that the checkpoints use `snakebeta` activation with log scale parameterization, which have the best overall quality.
|
201 |
+
|
202 |
+
You can fine-tune the models by:
|
203 |
|
204 |
+
1. downloading the checkpoints (both the generator weight and its discriminator/optimizer states)
|
205 |
+
2. resuming training using your audio dataset by specifying `--checkpoint_path` that includes the checkpoints when launching `train.py`
|
206 |
|
207 |
## Training Details of BigVGAN-v2
|
208 |
+
|
209 |
Comapred to the original BigVGAN, the pretrained checkpoints of BigVGAN-v2 used `batch_size=32` with a longer `segment_size=65536` and are trained using 8 A100 GPUs.
|
210 |
|
211 |
Note that the BigVGAN-v2 `json` config files in `./configs` use `batch_size=4` as default to fit in a single A100 GPU for training. You can fine-tune the models adjusting `batch_size` depending on your GPUs.
|
|
|
213 |
When training BigVGAN-v2 from scratch with small batch size, it can potentially encounter the early divergence problem mentioned in the paper. In such case, we recommend lowering the `clip_grad_norm` value (e.g. `100`) for the early training iterations (e.g. 20000 steps) and increase the value to the default `500`.
|
214 |
|
215 |
## Evaluation Results of BigVGAN-v2
|
216 |
+
|
217 |
Below are the objective results of the 24kHz model (`bigvgan_v2_24khz_100band_256x`) obtained from the LibriTTS `dev` sets. BigVGAN-v2 shows noticeable improvements of the metrics. The model also exhibits reduced perceptual artifacts, especially for non-speech audio.
|
218 |
|
219 |
+
| Model | Dataset | Steps | PESQ(β) | M-STFT(β) | MCD(β) | Periodicity(β) | V/UV F1(β) |
|
220 |
+
|:----------:|:-----------------------:|:-----:|:---------:|:----------:|:------:|:--------------:|:----------:|
|
221 |
+
| BigVGAN | LibriTTS | 1M | 4.027 | 0.7997 | 0.3745 | 0.1018 | 0.9598 |
|
222 |
+
| BigVGAN | LibriTTS | 5M | 4.256 | 0.7409 | 0.2988 | 0.0809 | 0.9698 |
|
223 |
+
| BigVGAN-v2 | Large-scale Compilation | 3M | **4.359** | **0.7134** | 0.3060 | **0.0621** | **0.9777** |
|
224 |
+
|
225 |
+
## Speed Benchmark
|
226 |
+
|
227 |
+
Below are the speed and VRAM usage benchmark results of BigVGAN from `tests/test_cuda_vs_torch_model.py`, using `bigvgan_v2_24khz_100band_256x` as a reference model.
|
228 |
+
|
229 |
+
| GPU | num_mel_frame | use_cuda_kernel | Speed (kHz) | Real-time Factor | VRAM (GB) |
|
230 |
+
|:--------------------------:|:-------------:|:---------------:|:-----------:|:----------------:|:---------:|
|
231 |
+
| NVIDIA A100 | 256 | False | 1672.1 | 69.7x | 1.3 |
|
232 |
+
| | | True | 3916.5 | 163.2x | 1.3 |
|
233 |
+
| | 2048 | False | 1899.6 | 79.2x | 1.7 |
|
234 |
+
| | | True | 5330.1 | 222.1x | 1.7 |
|
235 |
+
| | 16384 | False | 1973.8 | 82.2x | 5.0 |
|
236 |
+
| | | True | 5761.7 | 240.1x | 4.4 |
|
237 |
+
| NVIDIA GeForce RTX 3080 | 256 | False | 841.1 | 35.0x | 1.3 |
|
238 |
+
| | | True | 1598.1 | 66.6x | 1.3 |
|
239 |
+
| | 2048 | False | 929.9 | 38.7x | 1.7 |
|
240 |
+
| | | True | 1971.3 | 82.1x | 1.6 |
|
241 |
+
| | 16384 | False | 943.4 | 39.3x | 5.0 |
|
242 |
+
| | | True | 2026.5 | 84.4x | 3.9 |
|
243 |
+
| NVIDIA GeForce RTX 2080 Ti | 256 | False | 515.6 | 21.5x | 1.3 |
|
244 |
+
| | | True | 811.3 | 33.8x | 1.3 |
|
245 |
+
| | 2048 | False | 576.5 | 24.0x | 1.7 |
|
246 |
+
| | | True | 1023.0 | 42.6x | 1.5 |
|
247 |
+
| | 16384 | False | 589.4 | 24.6x | 5.0 |
|
248 |
+
| | | True | 1068.1 | 44.5x | 3.2 |
|
249 |
|
250 |
## Acknowledgements
|
251 |
+
|
252 |
We thank Vijay Anand Korthikanti and Kevin J. Shih for their generous support in implementing the CUDA kernel for inference.
|
253 |
|
254 |
## References
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
|
256 |
+
- [HiFi-GAN](https://github.com/jik876/hifi-gan) (for generator and multi-period discriminator)
|
257 |
+
- [Snake](https://github.com/EdwardDixon/snake) (for periodic activation)
|
258 |
+
- [Alias-free-torch](https://github.com/junjun3518/alias-free-torch) (for anti-aliasing)
|
259 |
+
- [Julius](https://github.com/adefossez/julius) (for low-pass filter)
|
260 |
+
- [UnivNet](https://github.com/mindslab-ai/univnet) (for multi-resolution discriminator)
|
261 |
+
- [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec) and [vocos](https://github.com/gemelo-ai/vocos) (for multi-band multi-scale STFT discriminator and multi-scale mel spectrogram loss)
|
262 |
+
- [Amphion](https://github.com/open-mmlab/Amphion) (for multi-scale sub-band CQT discriminator)
|
{alias_free_cuda β alias_free_activation/cuda}/__init__.py
RENAMED
File without changes
|
{alias_free_cuda β alias_free_activation/cuda}/activation1d.py
RENAMED
@@ -3,36 +3,45 @@
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
-
from
|
|
|
7 |
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
8 |
-
from
|
9 |
-
|
|
|
|
|
10 |
|
11 |
class FusedAntiAliasActivation(torch.autograd.Function):
|
12 |
"""
|
13 |
-
Assumes filter size 12, replication padding on upsampling, and logscale alpha/beta parameters as inputs
|
|
|
|
|
14 |
"""
|
|
|
15 |
@staticmethod
|
16 |
-
def forward(ctx, inputs,
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
return activation_results
|
20 |
|
21 |
@staticmethod
|
22 |
def backward(ctx, output_grads):
|
23 |
-
# TODO: implement bwd pass
|
24 |
raise NotImplementedError
|
25 |
return output_grads, None, None
|
26 |
|
|
|
27 |
class Activation1d(nn.Module):
|
28 |
-
def __init__(
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
36 |
super().__init__()
|
37 |
self.up_ratio = up_ratio
|
38 |
self.down_ratio = down_ratio
|
@@ -40,8 +49,7 @@ class Activation1d(nn.Module):
|
|
40 |
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
41 |
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
42 |
|
43 |
-
self.fused = fused
|
44 |
-
|
45 |
|
46 |
def forward(self, x):
|
47 |
if not self.fused:
|
@@ -51,13 +59,19 @@ class Activation1d(nn.Module):
|
|
51 |
return x
|
52 |
else:
|
53 |
if self.act.__class__.__name__ == "Snake":
|
54 |
-
beta = self.act.alpha.data
|
55 |
else:
|
56 |
-
beta =
|
|
|
|
|
57 |
alpha = self.act.alpha.data
|
58 |
-
if
|
|
|
|
|
59 |
alpha = torch.log(alpha)
|
60 |
beta = torch.log(beta)
|
61 |
-
|
62 |
-
x =
|
|
|
|
|
63 |
return x
|
|
|
3 |
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
+
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
7 |
+
|
8 |
# load fused CUDA kernel: this enables importing anti_alias_activation_cuda
|
9 |
+
from alias_free_activation.cuda import load
|
10 |
+
|
11 |
+
anti_alias_activation_cuda = load.load()
|
12 |
+
|
13 |
|
14 |
class FusedAntiAliasActivation(torch.autograd.Function):
|
15 |
"""
|
16 |
+
Assumes filter size 12, replication padding on upsampling/downsampling, and logscale alpha/beta parameters as inputs.
|
17 |
+
The hyperparameters are hard-coded in the kernel to maximize speed.
|
18 |
+
NOTE: The fused kenrel is incorrect for Activation1d with different hyperparameters.
|
19 |
"""
|
20 |
+
|
21 |
@staticmethod
|
22 |
+
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
23 |
+
activation_results = anti_alias_activation_cuda.forward(
|
24 |
+
inputs, up_ftr, down_ftr, alpha, beta
|
25 |
+
)
|
26 |
+
|
27 |
return activation_results
|
28 |
|
29 |
@staticmethod
|
30 |
def backward(ctx, output_grads):
|
|
|
31 |
raise NotImplementedError
|
32 |
return output_grads, None, None
|
33 |
|
34 |
+
|
35 |
class Activation1d(nn.Module):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
activation,
|
39 |
+
up_ratio: int = 2,
|
40 |
+
down_ratio: int = 2,
|
41 |
+
up_kernel_size: int = 12,
|
42 |
+
down_kernel_size: int = 12,
|
43 |
+
fused: bool = True,
|
44 |
+
):
|
45 |
super().__init__()
|
46 |
self.up_ratio = up_ratio
|
47 |
self.down_ratio = down_ratio
|
|
|
49 |
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
50 |
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
51 |
|
52 |
+
self.fused = fused # Whether to use fused CUDA kernel or not
|
|
|
53 |
|
54 |
def forward(self, x):
|
55 |
if not self.fused:
|
|
|
59 |
return x
|
60 |
else:
|
61 |
if self.act.__class__.__name__ == "Snake":
|
62 |
+
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
63 |
else:
|
64 |
+
beta = (
|
65 |
+
self.act.beta.data
|
66 |
+
) # Snakebeta uses different params for alpha and beta
|
67 |
alpha = self.act.alpha.data
|
68 |
+
if (
|
69 |
+
not self.act.alpha_logscale
|
70 |
+
): # Exp baked into cuda kernel, cancel it out with a log
|
71 |
alpha = torch.log(alpha)
|
72 |
beta = torch.log(beta)
|
73 |
+
|
74 |
+
x = FusedAntiAliasActivation.apply(
|
75 |
+
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
76 |
+
)
|
77 |
return x
|
{alias_free_cuda β alias_free_activation/cuda}/anti_alias_activation.cpp
RENAMED
@@ -14,35 +14,10 @@
|
|
14 |
* limitations under the License.
|
15 |
*/
|
16 |
|
17 |
-
#include <
|
18 |
-
#include <torch/extension.h>
|
19 |
-
#include <vector>
|
20 |
|
21 |
-
|
22 |
-
|
23 |
-
torch::Tensor fwd_cuda(torch::Tensor const& input,
|
24 |
-
torch::Tensor const& filter,
|
25 |
-
torch::Tensor const& alpha,
|
26 |
-
torch::Tensor const& beta
|
27 |
-
);
|
28 |
-
|
29 |
-
torch::Tensor fwd(torch::Tensor const& input,
|
30 |
-
torch::Tensor const& filter,
|
31 |
-
torch::Tensor const& alpha,
|
32 |
-
torch::Tensor const& beta
|
33 |
-
) {
|
34 |
-
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
|
35 |
-
//AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
|
36 |
-
// (input.scalar_type() == at::ScalarType::BFloat16),
|
37 |
-
// "Only fp16 and bf16 are supported");
|
38 |
-
|
39 |
-
return fwd_cuda(input, filter, alpha, beta);
|
40 |
-
}
|
41 |
-
|
42 |
-
} // end namespace anti_alias_activation
|
43 |
|
44 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
45 |
-
|
46 |
-
|
47 |
-
"Anti Alias Activation -- Forward.");
|
48 |
-
}
|
|
|
14 |
* limitations under the License.
|
15 |
*/
|
16 |
|
17 |
+
#include <torch/extension.h>
|
|
|
|
|
18 |
|
19 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
22 |
+
m.def("forward", &fwd_cuda, "Anti-Alias Activation forward (CUDA)");
|
23 |
+
}
|
|
|
|
alias_free_activation/cuda/anti_alias_activation_cuda.cu
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include <cuda.h>
|
19 |
+
#include <cuda_runtime.h>
|
20 |
+
#include <cuda_fp16.h>
|
21 |
+
#include <cuda_profiler_api.h>
|
22 |
+
#include <ATen/cuda/CUDAContext.h>
|
23 |
+
#include <torch/extension.h>
|
24 |
+
#include "type_shim.h"
|
25 |
+
#include <assert.h>
|
26 |
+
#include <cfloat>
|
27 |
+
#include <limits>
|
28 |
+
#include <stdint.h>
|
29 |
+
#include <c10/macros/Macros.h>
|
30 |
+
|
31 |
+
namespace
|
32 |
+
{
|
33 |
+
// Hard-coded hyperparameters
|
34 |
+
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
35 |
+
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
36 |
+
constexpr int BUFFER_SIZE = 32;
|
37 |
+
constexpr int FILTER_SIZE = 12;
|
38 |
+
constexpr int HALF_FILTER_SIZE = 6;
|
39 |
+
constexpr int UPSAMPLE_REPLICATION_PAD = 5; // 5 on each side, matching torch impl
|
40 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_LEFT = 5; // matching torch impl
|
41 |
+
constexpr int DOWNSAMPLE_REPLICATION_PAD_RIGHT = 6; // matching torch impl
|
42 |
+
|
43 |
+
template <typename input_t, typename output_t, typename acc_t>
|
44 |
+
__global__ void anti_alias_activation_forward(
|
45 |
+
output_t *dst,
|
46 |
+
const input_t *src,
|
47 |
+
const input_t *up_ftr,
|
48 |
+
const input_t *down_ftr,
|
49 |
+
const input_t *alpha,
|
50 |
+
const input_t *beta,
|
51 |
+
int batch_size,
|
52 |
+
int channels,
|
53 |
+
int seq_len)
|
54 |
+
{
|
55 |
+
// Up and downsample filters
|
56 |
+
input_t up_filter[FILTER_SIZE];
|
57 |
+
input_t down_filter[FILTER_SIZE];
|
58 |
+
|
59 |
+
// Load data from global memory including extra indices reserved for replication paddings
|
60 |
+
input_t elements[2 * FILTER_SIZE + 2 * BUFFER_SIZE + 2 * UPSAMPLE_REPLICATION_PAD] = {0};
|
61 |
+
input_t intermediates[2 * FILTER_SIZE + 2 * BUFFER_SIZE + DOWNSAMPLE_REPLICATION_PAD_LEFT + DOWNSAMPLE_REPLICATION_PAD_RIGHT] = {0};
|
62 |
+
|
63 |
+
// Output stores downsampled output before writing to dst
|
64 |
+
output_t output[BUFFER_SIZE];
|
65 |
+
|
66 |
+
// blockDim/threadIdx = (128, 1, 1)
|
67 |
+
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
68 |
+
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
69 |
+
int local_offset = threadIdx.x * BUFFER_SIZE;
|
70 |
+
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
71 |
+
|
72 |
+
// intermediate have double the seq_len
|
73 |
+
int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
74 |
+
int intermediate_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_local_offset;
|
75 |
+
|
76 |
+
// Get values needed for replication padding before moving pointer
|
77 |
+
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
78 |
+
input_t seq_left_most_value = right_most_pntr[0];
|
79 |
+
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
80 |
+
|
81 |
+
// Move src and dst pointers
|
82 |
+
src += block_offset + local_offset;
|
83 |
+
dst += block_offset + local_offset;
|
84 |
+
|
85 |
+
// Alpha and beta values for snake activatons. Applies exp by default
|
86 |
+
alpha = alpha + blockIdx.y;
|
87 |
+
input_t alpha_val = expf(alpha[0]);
|
88 |
+
beta = beta + blockIdx.y;
|
89 |
+
input_t beta_val = expf(beta[0]);
|
90 |
+
|
91 |
+
#pragma unroll
|
92 |
+
for (int it = 0; it < FILTER_SIZE; it += 1)
|
93 |
+
{
|
94 |
+
up_filter[it] = up_ftr[it];
|
95 |
+
down_filter[it] = down_ftr[it];
|
96 |
+
}
|
97 |
+
|
98 |
+
// Apply replication padding for upsampling, matching torch impl
|
99 |
+
#pragma unroll
|
100 |
+
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE; it += 1)
|
101 |
+
{
|
102 |
+
int element_index = seq_offset + it; // index for element
|
103 |
+
if ((element_index < 0) && (element_index >= -UPSAMPLE_REPLICATION_PAD))
|
104 |
+
{
|
105 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_left_most_value;
|
106 |
+
}
|
107 |
+
if ((element_index >= seq_len) && (element_index < seq_len + UPSAMPLE_REPLICATION_PAD))
|
108 |
+
{
|
109 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * seq_right_most_value;
|
110 |
+
}
|
111 |
+
if ((element_index >= 0) && (element_index < seq_len))
|
112 |
+
{
|
113 |
+
elements[2 * (HALF_FILTER_SIZE + it)] = 2 * src[it];
|
114 |
+
}
|
115 |
+
}
|
116 |
+
|
117 |
+
// Apply upsampling strided convolution and write to intermediates. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT for replication padding of the downsampilng conv later
|
118 |
+
#pragma unroll
|
119 |
+
for (int it = 0; it < (2 * BUFFER_SIZE + 2 * FILTER_SIZE); it += 1)
|
120 |
+
{
|
121 |
+
input_t acc = 0.0;
|
122 |
+
int element_index = intermediate_seq_offset + it; // index for intermediate
|
123 |
+
#pragma unroll
|
124 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
125 |
+
{
|
126 |
+
if ((element_index + f_idx) >= 0)
|
127 |
+
{
|
128 |
+
acc += up_filter[f_idx] * elements[it + f_idx];
|
129 |
+
}
|
130 |
+
}
|
131 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] = acc;
|
132 |
+
}
|
133 |
+
|
134 |
+
// Apply activation function. It reserves DOWNSAMPLE_REPLICATION_PAD_LEFT and DOWNSAMPLE_REPLICATION_PAD_RIGHT for replication padding of the downsampilng conv later
|
135 |
+
double no_div_by_zero = 0.000000001;
|
136 |
+
#pragma unroll
|
137 |
+
for (int it = 0; it < 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it += 1)
|
138 |
+
{
|
139 |
+
intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val) * sinf(intermediates[it + DOWNSAMPLE_REPLICATION_PAD_LEFT] * alpha_val);
|
140 |
+
}
|
141 |
+
|
142 |
+
// Apply replication padding before downsampling conv from intermediates
|
143 |
+
#pragma unroll
|
144 |
+
for (int it = 0; it < DOWNSAMPLE_REPLICATION_PAD_LEFT; it += 1)
|
145 |
+
{
|
146 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT];
|
147 |
+
}
|
148 |
+
#pragma unroll
|
149 |
+
for (int it = DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE; it < DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE + DOWNSAMPLE_REPLICATION_PAD_RIGHT; it += 1)
|
150 |
+
{
|
151 |
+
intermediates[it] = intermediates[DOWNSAMPLE_REPLICATION_PAD_LEFT + 2 * BUFFER_SIZE + 2 * FILTER_SIZE - 1];
|
152 |
+
}
|
153 |
+
|
154 |
+
// Apply downsample strided convolution (assuming stride=2) from intermediates
|
155 |
+
#pragma unroll
|
156 |
+
for (int it = 0; it < BUFFER_SIZE; it += 1)
|
157 |
+
{
|
158 |
+
input_t acc = 0.0;
|
159 |
+
#pragma unroll
|
160 |
+
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx += 1)
|
161 |
+
{
|
162 |
+
// Add constant DOWNSAMPLE_REPLICATION_PAD_RIGHT to match torch implementation
|
163 |
+
acc += down_filter[f_idx] * intermediates[it * 2 + f_idx + DOWNSAMPLE_REPLICATION_PAD_RIGHT];
|
164 |
+
}
|
165 |
+
output[it] = acc;
|
166 |
+
}
|
167 |
+
|
168 |
+
// Write output to dst
|
169 |
+
#pragma unroll
|
170 |
+
for (int it = 0; it < BUFFER_SIZE; it += ELEMENTS_PER_LDG_STG)
|
171 |
+
{
|
172 |
+
int element_index = seq_offset + it;
|
173 |
+
if (element_index < seq_len)
|
174 |
+
{
|
175 |
+
dst[it] = output[it];
|
176 |
+
}
|
177 |
+
}
|
178 |
+
|
179 |
+
}
|
180 |
+
|
181 |
+
template <typename input_t, typename output_t, typename acc_t>
|
182 |
+
void dispatch_anti_alias_activation_forward(
|
183 |
+
output_t *dst,
|
184 |
+
const input_t *src,
|
185 |
+
const input_t *up_ftr,
|
186 |
+
const input_t *down_ftr,
|
187 |
+
const input_t *alpha,
|
188 |
+
const input_t *beta,
|
189 |
+
int batch_size,
|
190 |
+
int channels,
|
191 |
+
int seq_len)
|
192 |
+
{
|
193 |
+
if (seq_len == 0)
|
194 |
+
{
|
195 |
+
return;
|
196 |
+
}
|
197 |
+
else
|
198 |
+
{
|
199 |
+
// Use 128 threads per block to maximimize gpu utilization
|
200 |
+
constexpr int threads_per_block = 128;
|
201 |
+
constexpr int seq_len_per_block = 4096;
|
202 |
+
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
203 |
+
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
204 |
+
dim3 threads(threads_per_block, 1, 1);
|
205 |
+
|
206 |
+
anti_alias_activation_forward<input_t, output_t, acc_t>
|
207 |
+
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, up_ftr, down_ftr, alpha, beta, batch_size, channels, seq_len);
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
|
212 |
+
extern "C" torch::Tensor fwd_cuda(torch::Tensor const &input, torch::Tensor const &up_filter, torch::Tensor const &down_filter, torch::Tensor const &alpha, torch::Tensor const &beta)
|
213 |
+
{
|
214 |
+
// Input is a 3d tensor with dimensions [batches, channels, seq_len]
|
215 |
+
const int batches = input.size(0);
|
216 |
+
const int channels = input.size(1);
|
217 |
+
const int seq_len = input.size(2);
|
218 |
+
|
219 |
+
// Output
|
220 |
+
auto act_options = input.options().requires_grad(false);
|
221 |
+
|
222 |
+
torch::Tensor anti_alias_activation_results =
|
223 |
+
torch::empty({batches, channels, seq_len}, act_options);
|
224 |
+
|
225 |
+
void *input_ptr = static_cast<void *>(input.data_ptr());
|
226 |
+
void *up_filter_ptr = static_cast<void *>(up_filter.data_ptr());
|
227 |
+
void *down_filter_ptr = static_cast<void *>(down_filter.data_ptr());
|
228 |
+
void *alpha_ptr = static_cast<void *>(alpha.data_ptr());
|
229 |
+
void *beta_ptr = static_cast<void *>(beta.data_ptr());
|
230 |
+
void *anti_alias_activation_results_ptr = static_cast<void *>(anti_alias_activation_results.data_ptr());
|
231 |
+
|
232 |
+
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
233 |
+
input.scalar_type(),
|
234 |
+
"dispatch anti alias activation_forward",
|
235 |
+
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
236 |
+
reinterpret_cast<scalar_t *>(anti_alias_activation_results_ptr),
|
237 |
+
reinterpret_cast<const scalar_t *>(input_ptr),
|
238 |
+
reinterpret_cast<const scalar_t *>(up_filter_ptr),
|
239 |
+
reinterpret_cast<const scalar_t *>(down_filter_ptr),
|
240 |
+
reinterpret_cast<const scalar_t *>(alpha_ptr),
|
241 |
+
reinterpret_cast<const scalar_t *>(beta_ptr),
|
242 |
+
batches,
|
243 |
+
channels,
|
244 |
+
seq_len););
|
245 |
+
return anti_alias_activation_results;
|
246 |
+
}
|
{alias_free_cuda β alias_free_activation/cuda}/compat.h
RENAMED
@@ -18,8 +18,6 @@
|
|
18 |
* https://github.com/NVIDIA/apex
|
19 |
* with minor changes. */
|
20 |
|
21 |
-
|
22 |
-
|
23 |
#ifndef TORCH_CHECK
|
24 |
#define TORCH_CHECK AT_CHECK
|
25 |
#endif
|
|
|
18 |
* https://github.com/NVIDIA/apex
|
19 |
* with minor changes. */
|
20 |
|
|
|
|
|
21 |
#ifndef TORCH_CHECK
|
22 |
#define TORCH_CHECK AT_CHECK
|
23 |
#endif
|
{alias_free_cuda β alias_free_activation/cuda}/load.py
RENAMED
@@ -7,26 +7,24 @@ import subprocess
|
|
7 |
|
8 |
from torch.utils import cpp_extension
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
# extra_cuda_cflags below
|
15 |
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
16 |
|
17 |
|
18 |
def load():
|
19 |
# Check if cuda 11 is installed for compute capability 8.0
|
20 |
cc_flag = []
|
21 |
-
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
|
22 |
-
cpp_extension.CUDA_HOME)
|
23 |
if int(bare_metal_major) >= 11:
|
24 |
-
cc_flag.append(
|
25 |
-
cc_flag.append(
|
26 |
|
27 |
# Build path
|
28 |
srcpath = pathlib.Path(__file__).parent.absolute()
|
29 |
-
buildpath = srcpath /
|
30 |
_create_build_dir(buildpath)
|
31 |
|
32 |
# Helper function to build the kernels.
|
@@ -35,26 +33,42 @@ def load():
|
|
35 |
name=name,
|
36 |
sources=sources,
|
37 |
build_directory=buildpath,
|
38 |
-
extra_cflags=[
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
)
|
44 |
|
45 |
-
extra_cuda_flags = [
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
52 |
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
53 |
-
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
|
|
|
|
|
|
|
|
54 |
|
55 |
def _get_cuda_bare_metal_version(cuda_dir):
|
56 |
-
raw_output = subprocess.check_output(
|
57 |
-
|
|
|
58 |
output = raw_output.split()
|
59 |
release_idx = output.index("release") + 1
|
60 |
release = output[release_idx].split(".")
|
@@ -69,4 +83,4 @@ def _create_build_dir(buildpath):
|
|
69 |
os.mkdir(buildpath)
|
70 |
except OSError:
|
71 |
if not os.path.isdir(buildpath):
|
72 |
-
print(f"Creation of the build directory {buildpath} failed")
|
|
|
7 |
|
8 |
from torch.utils import cpp_extension
|
9 |
|
10 |
+
"""
|
11 |
+
Setting this param to a list has a problem of generating different compilation commands (with diferent order of architectures) and leading to recompilation of fused kernels.
|
12 |
+
Set it to empty stringo avoid recompilation and assign arch flags explicity in extra_cuda_cflags below
|
13 |
+
"""
|
|
|
14 |
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
|
15 |
|
16 |
|
17 |
def load():
|
18 |
# Check if cuda 11 is installed for compute capability 8.0
|
19 |
cc_flag = []
|
20 |
+
_, bare_metal_major, _ = _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
|
|
|
21 |
if int(bare_metal_major) >= 11:
|
22 |
+
cc_flag.append("-gencode")
|
23 |
+
cc_flag.append("arch=compute_80,code=sm_80")
|
24 |
|
25 |
# Build path
|
26 |
srcpath = pathlib.Path(__file__).parent.absolute()
|
27 |
+
buildpath = srcpath / "build"
|
28 |
_create_build_dir(buildpath)
|
29 |
|
30 |
# Helper function to build the kernels.
|
|
|
33 |
name=name,
|
34 |
sources=sources,
|
35 |
build_directory=buildpath,
|
36 |
+
extra_cflags=[
|
37 |
+
"-O3",
|
38 |
+
],
|
39 |
+
extra_cuda_cflags=[
|
40 |
+
"-O3",
|
41 |
+
"-gencode",
|
42 |
+
"arch=compute_70,code=sm_70",
|
43 |
+
"--use_fast_math",
|
44 |
+
]
|
45 |
+
+ extra_cuda_flags
|
46 |
+
+ cc_flag,
|
47 |
+
verbose=True,
|
48 |
)
|
49 |
|
50 |
+
extra_cuda_flags = [
|
51 |
+
"-U__CUDA_NO_HALF_OPERATORS__",
|
52 |
+
"-U__CUDA_NO_HALF_CONVERSIONS__",
|
53 |
+
"--expt-relaxed-constexpr",
|
54 |
+
"--expt-extended-lambda",
|
55 |
+
]
|
56 |
+
|
57 |
+
sources = [
|
58 |
+
srcpath / "anti_alias_activation.cpp",
|
59 |
+
srcpath / "anti_alias_activation_cuda.cu",
|
60 |
+
]
|
61 |
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
62 |
+
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
63 |
+
)
|
64 |
+
|
65 |
+
return anti_alias_activation_cuda
|
66 |
+
|
67 |
|
68 |
def _get_cuda_bare_metal_version(cuda_dir):
|
69 |
+
raw_output = subprocess.check_output(
|
70 |
+
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
71 |
+
)
|
72 |
output = raw_output.split()
|
73 |
release_idx = output.index("release") + 1
|
74 |
release = output[release_idx].split(".")
|
|
|
83 |
os.mkdir(buildpath)
|
84 |
except OSError:
|
85 |
if not os.path.isdir(buildpath):
|
86 |
+
print(f"Creation of the build directory {buildpath} failed")
|
alias_free_activation/cuda/type_shim.h
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/* coding=utf-8
|
2 |
+
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#include <ATen/ATen.h>
|
18 |
+
#include "compat.h"
|
19 |
+
|
20 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
21 |
+
switch (TYPE) \
|
22 |
+
{ \
|
23 |
+
case at::ScalarType::Float: \
|
24 |
+
{ \
|
25 |
+
using scalar_t = float; \
|
26 |
+
__VA_ARGS__; \
|
27 |
+
break; \
|
28 |
+
} \
|
29 |
+
case at::ScalarType::Half: \
|
30 |
+
{ \
|
31 |
+
using scalar_t = at::Half; \
|
32 |
+
__VA_ARGS__; \
|
33 |
+
break; \
|
34 |
+
} \
|
35 |
+
case at::ScalarType::BFloat16: \
|
36 |
+
{ \
|
37 |
+
using scalar_t = at::BFloat16; \
|
38 |
+
__VA_ARGS__; \
|
39 |
+
break; \
|
40 |
+
} \
|
41 |
+
default: \
|
42 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
43 |
+
}
|
44 |
+
|
45 |
+
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
46 |
+
switch (TYPEIN) \
|
47 |
+
{ \
|
48 |
+
case at::ScalarType::Float: \
|
49 |
+
{ \
|
50 |
+
using scalar_t_in = float; \
|
51 |
+
switch (TYPEOUT) \
|
52 |
+
{ \
|
53 |
+
case at::ScalarType::Float: \
|
54 |
+
{ \
|
55 |
+
using scalar_t_out = float; \
|
56 |
+
__VA_ARGS__; \
|
57 |
+
break; \
|
58 |
+
} \
|
59 |
+
case at::ScalarType::Half: \
|
60 |
+
{ \
|
61 |
+
using scalar_t_out = at::Half; \
|
62 |
+
__VA_ARGS__; \
|
63 |
+
break; \
|
64 |
+
} \
|
65 |
+
case at::ScalarType::BFloat16: \
|
66 |
+
{ \
|
67 |
+
using scalar_t_out = at::BFloat16; \
|
68 |
+
__VA_ARGS__; \
|
69 |
+
break; \
|
70 |
+
} \
|
71 |
+
default: \
|
72 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
73 |
+
} \
|
74 |
+
break; \
|
75 |
+
} \
|
76 |
+
case at::ScalarType::Half: \
|
77 |
+
{ \
|
78 |
+
using scalar_t_in = at::Half; \
|
79 |
+
using scalar_t_out = at::Half; \
|
80 |
+
__VA_ARGS__; \
|
81 |
+
break; \
|
82 |
+
} \
|
83 |
+
case at::ScalarType::BFloat16: \
|
84 |
+
{ \
|
85 |
+
using scalar_t_in = at::BFloat16; \
|
86 |
+
using scalar_t_out = at::BFloat16; \
|
87 |
+
__VA_ARGS__; \
|
88 |
+
break; \
|
89 |
+
} \
|
90 |
+
default: \
|
91 |
+
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
92 |
+
}
|
{alias_free_torch β alias_free_activation/torch}/__init__.py
RENAMED
@@ -3,4 +3,4 @@
|
|
3 |
|
4 |
from .filter import *
|
5 |
from .resample import *
|
6 |
-
from .act import *
|
|
|
3 |
|
4 |
from .filter import *
|
5 |
from .resample import *
|
6 |
+
from .act import *
|
{alias_free_torch β alias_free_activation/torch}/act.py
RENAMED
@@ -2,16 +2,18 @@
|
|
2 |
# LICENSE is in incl_licenses directory.
|
3 |
|
4 |
import torch.nn as nn
|
5 |
-
from .resample import UpSample1d, DownSample1d
|
6 |
|
7 |
|
8 |
class Activation1d(nn.Module):
|
9 |
-
def __init__(
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
15 |
super().__init__()
|
16 |
self.up_ratio = up_ratio
|
17 |
self.down_ratio = down_ratio
|
@@ -25,4 +27,4 @@ class Activation1d(nn.Module):
|
|
25 |
x = self.act(x)
|
26 |
x = self.downsample(x)
|
27 |
|
28 |
-
return x
|
|
|
2 |
# LICENSE is in incl_licenses directory.
|
3 |
|
4 |
import torch.nn as nn
|
5 |
+
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
6 |
|
7 |
|
8 |
class Activation1d(nn.Module):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
activation,
|
12 |
+
up_ratio: int = 2,
|
13 |
+
down_ratio: int = 2,
|
14 |
+
up_kernel_size: int = 12,
|
15 |
+
down_kernel_size: int = 12,
|
16 |
+
):
|
17 |
super().__init__()
|
18 |
self.up_ratio = up_ratio
|
19 |
self.down_ratio = down_ratio
|
|
|
27 |
x = self.act(x)
|
28 |
x = self.downsample(x)
|
29 |
|
30 |
+
return x
|
{alias_free_torch β alias_free_activation/torch}/filter.py
RENAMED
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
6 |
import torch.nn.functional as F
|
7 |
import math
|
8 |
|
9 |
-
if
|
10 |
sinc = torch.sinc
|
11 |
else:
|
12 |
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
@@ -17,40 +17,45 @@ else:
|
|
17 |
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
18 |
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
19 |
"""
|
20 |
-
return torch.where(
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
|
24 |
|
25 |
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
26 |
# https://adefossez.github.io/julius/julius/lowpass.html
|
27 |
# LICENSE is in incl_licenses directory.
|
28 |
-
def kaiser_sinc_filter1d(
|
29 |
-
|
|
|
|
|
30 |
half_size = kernel_size // 2
|
31 |
|
32 |
-
#For kaiser window
|
33 |
delta_f = 4 * half_width
|
34 |
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
35 |
-
if A > 50
|
36 |
beta = 0.1102 * (A - 8.7)
|
37 |
-
elif A >= 21
|
38 |
-
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
39 |
else:
|
40 |
-
beta = 0.
|
41 |
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
42 |
|
43 |
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
44 |
if even:
|
45 |
-
time =
|
46 |
else:
|
47 |
time = torch.arange(kernel_size) - half_size
|
48 |
if cutoff == 0:
|
49 |
filter_ = torch.zeros_like(time)
|
50 |
else:
|
51 |
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
52 |
-
|
53 |
-
|
|
|
54 |
filter_ /= filter_.sum()
|
55 |
filter = filter_.view(1, 1, kernel_size)
|
56 |
|
@@ -58,22 +63,25 @@ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,
|
|
58 |
|
59 |
|
60 |
class LowPassFilter1d(nn.Module):
|
61 |
-
def __init__(
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
|
|
|
|
70 |
super().__init__()
|
71 |
-
if cutoff < -0
|
72 |
raise ValueError("Minimum cutoff must be larger than zero.")
|
73 |
if cutoff > 0.5:
|
74 |
raise ValueError("A cutoff above 0.5 does not make sense.")
|
75 |
self.kernel_size = kernel_size
|
76 |
-
self.even =
|
77 |
self.pad_left = kernel_size // 2 - int(self.even)
|
78 |
self.pad_right = kernel_size // 2
|
79 |
self.stride = stride
|
@@ -82,14 +90,12 @@ class LowPassFilter1d(nn.Module):
|
|
82 |
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
83 |
self.register_buffer("filter", filter)
|
84 |
|
85 |
-
#
|
86 |
def forward(self, x):
|
87 |
_, C, _ = x.shape
|
88 |
|
89 |
if self.padding:
|
90 |
-
x = F.pad(x, (self.pad_left, self.pad_right),
|
91 |
-
|
92 |
-
out = F.conv1d(x, self.filter.expand(C, -1, -1),
|
93 |
-
stride=self.stride, groups=C)
|
94 |
|
95 |
-
return out
|
|
|
6 |
import torch.nn.functional as F
|
7 |
import math
|
8 |
|
9 |
+
if "sinc" in dir(torch):
|
10 |
sinc = torch.sinc
|
11 |
else:
|
12 |
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
|
|
17 |
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
18 |
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
19 |
"""
|
20 |
+
return torch.where(
|
21 |
+
x == 0,
|
22 |
+
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
23 |
+
torch.sin(math.pi * x) / math.pi / x,
|
24 |
+
)
|
25 |
|
26 |
|
27 |
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
28 |
# https://adefossez.github.io/julius/julius/lowpass.html
|
29 |
# LICENSE is in incl_licenses directory.
|
30 |
+
def kaiser_sinc_filter1d(
|
31 |
+
cutoff, half_width, kernel_size
|
32 |
+
): # return filter [1,1,kernel_size]
|
33 |
+
even = kernel_size % 2 == 0
|
34 |
half_size = kernel_size // 2
|
35 |
|
36 |
+
# For kaiser window
|
37 |
delta_f = 4 * half_width
|
38 |
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
39 |
+
if A > 50.0:
|
40 |
beta = 0.1102 * (A - 8.7)
|
41 |
+
elif A >= 21.0:
|
42 |
+
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
43 |
else:
|
44 |
+
beta = 0.0
|
45 |
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
46 |
|
47 |
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
48 |
if even:
|
49 |
+
time = torch.arange(-half_size, half_size) + 0.5
|
50 |
else:
|
51 |
time = torch.arange(kernel_size) - half_size
|
52 |
if cutoff == 0:
|
53 |
filter_ = torch.zeros_like(time)
|
54 |
else:
|
55 |
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
56 |
+
"""
|
57 |
+
Normalize filter to have sum = 1, otherwise we will have a small leakage of the constant component in the input signal.
|
58 |
+
"""
|
59 |
filter_ /= filter_.sum()
|
60 |
filter = filter_.view(1, 1, kernel_size)
|
61 |
|
|
|
63 |
|
64 |
|
65 |
class LowPassFilter1d(nn.Module):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
cutoff=0.5,
|
69 |
+
half_width=0.6,
|
70 |
+
stride: int = 1,
|
71 |
+
padding: bool = True,
|
72 |
+
padding_mode: str = "replicate",
|
73 |
+
kernel_size: int = 12,
|
74 |
+
):
|
75 |
+
"""
|
76 |
+
kernel_size should be even number for stylegan3 setup, in this implementation, odd number is also possible.
|
77 |
+
"""
|
78 |
super().__init__()
|
79 |
+
if cutoff < -0.0:
|
80 |
raise ValueError("Minimum cutoff must be larger than zero.")
|
81 |
if cutoff > 0.5:
|
82 |
raise ValueError("A cutoff above 0.5 does not make sense.")
|
83 |
self.kernel_size = kernel_size
|
84 |
+
self.even = kernel_size % 2 == 0
|
85 |
self.pad_left = kernel_size // 2 - int(self.even)
|
86 |
self.pad_right = kernel_size // 2
|
87 |
self.stride = stride
|
|
|
90 |
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
91 |
self.register_buffer("filter", filter)
|
92 |
|
93 |
+
# Input [B, C, T]
|
94 |
def forward(self, x):
|
95 |
_, C, _ = x.shape
|
96 |
|
97 |
if self.padding:
|
98 |
+
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
99 |
+
out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
|
|
|
|
100 |
|
101 |
+
return out
|
{alias_free_torch β alias_free_activation/torch}/resample.py
RENAMED
@@ -3,32 +3,37 @@
|
|
3 |
|
4 |
import torch.nn as nn
|
5 |
from torch.nn import functional as F
|
6 |
-
from .filter import LowPassFilter1d
|
7 |
-
from .filter import kaiser_sinc_filter1d
|
8 |
|
9 |
|
10 |
class UpSample1d(nn.Module):
|
11 |
def __init__(self, ratio=2, kernel_size=None):
|
12 |
super().__init__()
|
13 |
self.ratio = ratio
|
14 |
-
self.kernel_size =
|
|
|
|
|
15 |
self.stride = ratio
|
16 |
self.pad = self.kernel_size // ratio - 1
|
17 |
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
18 |
-
self.pad_right =
|
19 |
-
|
20 |
-
|
21 |
-
|
|
|
|
|
22 |
self.register_buffer("filter", filter)
|
23 |
|
24 |
# x: [B, C, T]
|
25 |
def forward(self, x):
|
26 |
_, C, _ = x.shape
|
27 |
|
28 |
-
x = F.pad(x, (self.pad, self.pad), mode=
|
29 |
x = self.ratio * F.conv_transpose1d(
|
30 |
-
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
31 |
-
|
|
|
32 |
|
33 |
return x
|
34 |
|
@@ -37,13 +42,17 @@ class DownSample1d(nn.Module):
|
|
37 |
def __init__(self, ratio=2, kernel_size=None):
|
38 |
super().__init__()
|
39 |
self.ratio = ratio
|
40 |
-
self.kernel_size =
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def forward(self, x):
|
47 |
xx = self.lowpass(x)
|
48 |
|
49 |
-
return xx
|
|
|
3 |
|
4 |
import torch.nn as nn
|
5 |
from torch.nn import functional as F
|
6 |
+
from alias_free_activation.torch.filter import LowPassFilter1d
|
7 |
+
from alias_free_activation.torch.filter import kaiser_sinc_filter1d
|
8 |
|
9 |
|
10 |
class UpSample1d(nn.Module):
|
11 |
def __init__(self, ratio=2, kernel_size=None):
|
12 |
super().__init__()
|
13 |
self.ratio = ratio
|
14 |
+
self.kernel_size = (
|
15 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
16 |
+
)
|
17 |
self.stride = ratio
|
18 |
self.pad = self.kernel_size // ratio - 1
|
19 |
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
20 |
+
self.pad_right = (
|
21 |
+
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
22 |
+
)
|
23 |
+
filter = kaiser_sinc_filter1d(
|
24 |
+
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
25 |
+
)
|
26 |
self.register_buffer("filter", filter)
|
27 |
|
28 |
# x: [B, C, T]
|
29 |
def forward(self, x):
|
30 |
_, C, _ = x.shape
|
31 |
|
32 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
33 |
x = self.ratio * F.conv_transpose1d(
|
34 |
+
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
35 |
+
)
|
36 |
+
x = x[..., self.pad_left : -self.pad_right]
|
37 |
|
38 |
return x
|
39 |
|
|
|
42 |
def __init__(self, ratio=2, kernel_size=None):
|
43 |
super().__init__()
|
44 |
self.ratio = ratio
|
45 |
+
self.kernel_size = (
|
46 |
+
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
47 |
+
)
|
48 |
+
self.lowpass = LowPassFilter1d(
|
49 |
+
cutoff=0.5 / ratio,
|
50 |
+
half_width=0.6 / ratio,
|
51 |
+
stride=ratio,
|
52 |
+
kernel_size=self.kernel_size,
|
53 |
+
)
|
54 |
|
55 |
def forward(self, x):
|
56 |
xx = self.lowpass(x)
|
57 |
|
58 |
+
return xx
|
alias_free_cuda/anti_alias_activation_cuda.cu
DELETED
@@ -1,314 +0,0 @@
|
|
1 |
-
/* coding=utf-8
|
2 |
-
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
* you may not use this file except in compliance with the License.
|
6 |
-
* You may obtain a copy of the License at
|
7 |
-
*
|
8 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
*
|
10 |
-
* Unless required by applicable law or agreed to in writing, software
|
11 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
* See the License for the specific language governing permissions and
|
14 |
-
* limitations under the License.
|
15 |
-
*/
|
16 |
-
|
17 |
-
#include <ATen/ATen.h>
|
18 |
-
#include <cuda.h>
|
19 |
-
#include <cuda_runtime.h>
|
20 |
-
#include <cuda_fp16.h>
|
21 |
-
#include <cuda_profiler_api.h>
|
22 |
-
#include <ATen/cuda/CUDAContext.h>
|
23 |
-
#include <torch/extension.h>
|
24 |
-
#include "type_shim.h"
|
25 |
-
#include <assert.h>
|
26 |
-
#include <cfloat>
|
27 |
-
#include <limits>
|
28 |
-
#include <stdint.h>
|
29 |
-
#include <c10/macros/Macros.h>
|
30 |
-
|
31 |
-
namespace {
|
32 |
-
|
33 |
-
/*
|
34 |
-
template <typename Datatype, int ELEMENTS_PER_LDG>
|
35 |
-
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
36 |
-
|
37 |
-
template <>
|
38 |
-
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
|
39 |
-
|
40 |
-
template <>
|
41 |
-
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
|
42 |
-
|
43 |
-
template <>
|
44 |
-
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
|
45 |
-
|
46 |
-
template <>
|
47 |
-
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }
|
48 |
-
|
49 |
-
template <>
|
50 |
-
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
|
51 |
-
|
52 |
-
template <>
|
53 |
-
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); }
|
54 |
-
|
55 |
-
int log2_ceil(int value) {
|
56 |
-
int log2_value = 0;
|
57 |
-
while ((1 << log2_value) < value) ++log2_value;
|
58 |
-
return log2_value;
|
59 |
-
}
|
60 |
-
|
61 |
-
template<typename T>
|
62 |
-
struct Add {
|
63 |
-
__device__ __forceinline__ T operator()(T a, T b) const {
|
64 |
-
return a + b;
|
65 |
-
}
|
66 |
-
};
|
67 |
-
|
68 |
-
template<typename T>
|
69 |
-
struct Max {
|
70 |
-
__device__ __forceinline__ T operator()(T a, T b) const {
|
71 |
-
return a < b ? b : a;
|
72 |
-
}
|
73 |
-
};
|
74 |
-
|
75 |
-
template <typename T>
|
76 |
-
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
|
77 |
-
{
|
78 |
-
#if CUDA_VERSION >= 9000
|
79 |
-
return __shfl_xor_sync(mask, value, laneMask, width);
|
80 |
-
#else
|
81 |
-
return __shfl_xor(value, laneMask, width);
|
82 |
-
#endif
|
83 |
-
}
|
84 |
-
|
85 |
-
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
|
86 |
-
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
|
87 |
-
ReduceOp<acc_t> r;
|
88 |
-
#pragma unroll
|
89 |
-
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
|
90 |
-
#pragma unroll
|
91 |
-
for (int i = 0; i < WARP_BATCH; ++i) {
|
92 |
-
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
|
93 |
-
sum[i] = r(sum[i], b);
|
94 |
-
}
|
95 |
-
}
|
96 |
-
}
|
97 |
-
*/
|
98 |
-
|
99 |
-
template <typename input_t, typename output_t, typename acc_t>
|
100 |
-
__global__ void anti_alias_activation_forward(
|
101 |
-
output_t *dst,
|
102 |
-
const input_t *src,
|
103 |
-
const input_t *ftr,
|
104 |
-
const input_t *alpha,
|
105 |
-
const input_t *beta,
|
106 |
-
int batch_size,
|
107 |
-
int channels,
|
108 |
-
int seq_len)
|
109 |
-
{
|
110 |
-
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
|
111 |
-
constexpr int ELEMENTS_PER_LDG_STG = 1; //(WARP_ITERATIONS < 4) ? 1 : 4;
|
112 |
-
constexpr int BUFFER_SIZE = 32;
|
113 |
-
constexpr int FILTER_SIZE = 12;
|
114 |
-
constexpr int HALF_FILTER_SIZE = 6;
|
115 |
-
constexpr int REPLICATION_PAD = 5; // 5 on each side
|
116 |
-
|
117 |
-
// blockDim/threadIdx = (128, 1, 1)
|
118 |
-
// gridDim/blockIdx = (seq_blocks, channels, batches)
|
119 |
-
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
120 |
-
int local_offset = threadIdx.x * BUFFER_SIZE;
|
121 |
-
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset;
|
122 |
-
|
123 |
-
|
124 |
-
//int intermediate_seq_len = seq_len * 2 - 1 + 4 * REPLICATION_PAD;
|
125 |
-
//int intermediate_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + intermediate_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
126 |
-
//int intermediate_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
127 |
-
|
128 |
-
int output_seq_len = seq_len * 2 ; //
|
129 |
-
int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
130 |
-
int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;
|
131 |
-
int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE *2 + output_local_offset;
|
132 |
-
// get values needed for replication padding before moving pointer
|
133 |
-
const input_t *right_most_pntr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));
|
134 |
-
input_t seq_left_most_value = right_most_pntr[0];
|
135 |
-
input_t seq_right_most_value = right_most_pntr[seq_len - 1];
|
136 |
-
|
137 |
-
src += block_offset + local_offset;
|
138 |
-
dst += output_block_offset + output_local_offset ;
|
139 |
-
alpha = alpha + blockIdx.y;
|
140 |
-
input_t alpha_val = expf(alpha[0]);
|
141 |
-
beta = beta + blockIdx.y;
|
142 |
-
input_t beta_val = expf(beta[0]);
|
143 |
-
// load data from global memory
|
144 |
-
input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
|
145 |
-
input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0};
|
146 |
-
//output_t output[2*BUFFER_SIZE];
|
147 |
-
input_t filter[FILTER_SIZE];
|
148 |
-
//input_t temp_data[ELEMENTS_PER_LDG_STG];
|
149 |
-
//uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
|
150 |
-
|
151 |
-
#pragma unroll
|
152 |
-
for (int it = 0; it < FILTER_SIZE; it+=1) {
|
153 |
-
filter[it] = ftr[it];
|
154 |
-
}
|
155 |
-
|
156 |
-
|
157 |
-
#pragma unroll
|
158 |
-
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) {
|
159 |
-
int element_index = seq_offset + it;
|
160 |
-
if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) {
|
161 |
-
elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value;
|
162 |
-
}
|
163 |
-
if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) {
|
164 |
-
elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value;
|
165 |
-
}
|
166 |
-
if ((element_index >= 0) && (element_index < seq_len)) {
|
167 |
-
elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it];
|
168 |
-
}
|
169 |
-
}
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
// apply filter
|
174 |
-
#pragma unroll
|
175 |
-
for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) {
|
176 |
-
input_t acc = 0.0;
|
177 |
-
|
178 |
-
int element_index = output_seq_offset + it; // index for output
|
179 |
-
#pragma unroll
|
180 |
-
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
|
181 |
-
if ((element_index + f_idx) >= 0){
|
182 |
-
acc += filter[f_idx] * elements[it+f_idx];
|
183 |
-
}
|
184 |
-
}
|
185 |
-
intermediates[it] = acc;
|
186 |
-
}
|
187 |
-
|
188 |
-
double no_div_by_zero = 0.000000001;
|
189 |
-
#pragma unroll
|
190 |
-
for (int it = 0; it < 12 + 2 * BUFFER_SIZE; it++) {
|
191 |
-
intermediates[it] += (1.0/(beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);
|
192 |
-
}
|
193 |
-
|
194 |
-
|
195 |
-
// now copy to output
|
196 |
-
#pragma unroll
|
197 |
-
for (int it = 0; it < 2*BUFFER_SIZE; it+=1){
|
198 |
-
int element_index = output_seq_offset + it;
|
199 |
-
if (element_index < output_seq_len) {
|
200 |
-
dst[it] = intermediates[it+6];
|
201 |
-
}
|
202 |
-
}
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
// for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
|
207 |
-
// int element_index = seq_offset + it;
|
208 |
-
// if (element_index < seq_len) {
|
209 |
-
// dst[it] = output[it];
|
210 |
-
// }
|
211 |
-
// }
|
212 |
-
|
213 |
-
|
214 |
-
// // Upsample convolution
|
215 |
-
// for (int it = 0; it < 2 * BUFFER_SIZE + 12; it+=1) {
|
216 |
-
// input_t acc = 0.0;
|
217 |
-
|
218 |
-
// for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){
|
219 |
-
// acc += filter[f_idx] * elements[it+f_idx];
|
220 |
-
// }
|
221 |
-
// intermediates[it] = acc;
|
222 |
-
// }
|
223 |
-
|
224 |
-
// // correct the corners of intermediates
|
225 |
-
// if (seq_offset == 0) {
|
226 |
-
// for (int it = 0; it < 6; it+=1)
|
227 |
-
// intermediates[it] = 0;
|
228 |
-
// }
|
229 |
-
|
230 |
-
// if (seq_offset + 32 >= seq_len) {
|
231 |
-
// int offset = seq_len % 32 == 0 ? 32 : seq_len % 32;
|
232 |
-
|
233 |
-
// for (int it = 0; it < 6; it++) {
|
234 |
-
// intermediates[6+2*offset+it] = 0;
|
235 |
-
// }
|
236 |
-
// }
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
// for (int it = 0; it < BUFFER_SIZE; it+=ELEMENTS_PER_LDG_STG) {
|
242 |
-
// int element_index = seq_offset + it;
|
243 |
-
// if (element_index < seq_len) {
|
244 |
-
// dst[it] = output[it];
|
245 |
-
// }
|
246 |
-
// }
|
247 |
-
}
|
248 |
-
|
249 |
-
template<typename input_t, typename output_t, typename acc_t>
|
250 |
-
void dispatch_anti_alias_activation_forward(
|
251 |
-
output_t *dst,
|
252 |
-
const input_t *src,
|
253 |
-
const input_t *ftr,
|
254 |
-
const input_t *alpha,
|
255 |
-
const input_t *beta,
|
256 |
-
int batch_size,
|
257 |
-
int channels,
|
258 |
-
int seq_len)
|
259 |
-
{
|
260 |
-
if (seq_len == 0) {
|
261 |
-
return;
|
262 |
-
} else {
|
263 |
-
// use 128 threads per block to maximimize gpu utilization
|
264 |
-
constexpr int threads_per_block = 128;
|
265 |
-
constexpr int seq_len_per_block = 4096;
|
266 |
-
int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
267 |
-
dim3 blocks(blocks_per_seq_len, channels, batch_size);
|
268 |
-
dim3 threads(threads_per_block, 1, 1);
|
269 |
-
|
270 |
-
anti_alias_activation_forward<input_t, output_t, acc_t>
|
271 |
-
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len);
|
272 |
-
}
|
273 |
-
}
|
274 |
-
}
|
275 |
-
|
276 |
-
namespace anti_alias_activation {
|
277 |
-
|
278 |
-
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta)
|
279 |
-
{
|
280 |
-
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
|
281 |
-
const int batches = input.size(0);
|
282 |
-
const int channels = input.size(1);
|
283 |
-
const int seq_len = input.size(2);
|
284 |
-
|
285 |
-
// Output
|
286 |
-
auto act_options = input.options().requires_grad(false);
|
287 |
-
int output_seq_len = seq_len*2; // we'll be dilating between each element by interspersing with zeros
|
288 |
-
|
289 |
-
torch::Tensor anti_alias_activation_results =
|
290 |
-
torch::empty({batches, channels, output_seq_len}, act_options);
|
291 |
-
|
292 |
-
// Softmax Intermediate Result Ptr
|
293 |
-
void* input_ptr = static_cast<void*>(input.data_ptr());
|
294 |
-
void* filter_ptr = static_cast<void*>(filter.data_ptr());
|
295 |
-
void* alpha_ptr = static_cast<void*>(alpha.data_ptr());
|
296 |
-
void* beta_ptr = static_cast<void*>(beta.data_ptr());
|
297 |
-
void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr());
|
298 |
-
|
299 |
-
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
300 |
-
input.scalar_type(),
|
301 |
-
"dispatch anti alias activation_forward",
|
302 |
-
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>(
|
303 |
-
reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr),
|
304 |
-
reinterpret_cast<const scalar_t*>(input_ptr),
|
305 |
-
reinterpret_cast<const scalar_t*>(filter_ptr),
|
306 |
-
reinterpret_cast<const scalar_t*>(alpha_ptr),
|
307 |
-
reinterpret_cast<const scalar_t*>(beta_ptr),
|
308 |
-
batches,
|
309 |
-
channels,
|
310 |
-
seq_len);
|
311 |
-
);
|
312 |
-
return anti_alias_activation_results;
|
313 |
-
}
|
314 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
alias_free_cuda/test_activation.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
-
# Licensed under the MIT license.
|
3 |
-
|
4 |
-
import math
|
5 |
-
import torch
|
6 |
-
import alias_free_cuda
|
7 |
-
from alias_free_cuda import activation1d
|
8 |
-
from activations import Snake, SnakeBeta
|
9 |
-
|
10 |
-
def test_load_fused_kernels():
|
11 |
-
try:
|
12 |
-
import alias_free_cuda
|
13 |
-
import torch
|
14 |
-
print("[Success] load_fused_kernels")
|
15 |
-
except ImportError as e:
|
16 |
-
print("[Fail] load_fused_kernels")
|
17 |
-
raise e
|
18 |
-
|
19 |
-
def test_anti_alias_activation():
|
20 |
-
data = torch.rand((10, 10, 50000), device='cuda')
|
21 |
-
|
22 |
-
# check activations.Snake cuda vs. torch
|
23 |
-
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
24 |
-
fused_activation_output = fused_anti_alias_activation(data)
|
25 |
-
|
26 |
-
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
27 |
-
torch_activation_output = torch_anti_alias_activation(data)
|
28 |
-
|
29 |
-
test_result = (fused_activation_output - torch_activation_output).abs()
|
30 |
-
|
31 |
-
while test_result.dim() != 1:
|
32 |
-
test_result = test_result.mean(dim=-1)
|
33 |
-
|
34 |
-
diff = test_result.mean(dim=-1)
|
35 |
-
|
36 |
-
if diff <= 1e-3:
|
37 |
-
print(
|
38 |
-
f"\n[Success] test_fused_anti_alias_activation"
|
39 |
-
f"\n > mean_difference={diff}"
|
40 |
-
f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
|
41 |
-
f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
|
42 |
-
)
|
43 |
-
else:
|
44 |
-
print(
|
45 |
-
f"\n[Fail] test_fused_anti_alias_activation"
|
46 |
-
f"\n > mean_difference={diff}, "
|
47 |
-
f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
|
48 |
-
f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
|
49 |
-
)
|
50 |
-
|
51 |
-
if __name__ == "__main__":
|
52 |
-
from alias_free_cuda import load
|
53 |
-
load.load()
|
54 |
-
test_load_fused_kernels()
|
55 |
-
test_anti_alias_activation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
alias_free_cuda/test_activation_snake_beta.py
DELETED
@@ -1,55 +0,0 @@
|
|
1 |
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
-
# Licensed under the MIT license.
|
3 |
-
|
4 |
-
import math
|
5 |
-
import torch
|
6 |
-
import alias_free_cuda
|
7 |
-
from alias_free_cuda import activation1d
|
8 |
-
from activations import Snake, SnakeBeta
|
9 |
-
|
10 |
-
def test_load_fused_kernels():
|
11 |
-
try:
|
12 |
-
import alias_free_cuda
|
13 |
-
import torch
|
14 |
-
print("[Success] load_fused_kernels")
|
15 |
-
except ImportError as e:
|
16 |
-
print("[Fail] load_fused_kernels")
|
17 |
-
raise e
|
18 |
-
|
19 |
-
def test_anti_alias_activation():
|
20 |
-
data = torch.rand((10, 10, 50000), device='cuda')
|
21 |
-
|
22 |
-
# check activations.Snake cuda vs. torch
|
23 |
-
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
24 |
-
fused_activation_output = fused_anti_alias_activation(data)
|
25 |
-
|
26 |
-
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
27 |
-
torch_activation_output = torch_anti_alias_activation(data)
|
28 |
-
|
29 |
-
test_result = (fused_activation_output - torch_activation_output).abs()
|
30 |
-
|
31 |
-
while test_result.dim() != 1:
|
32 |
-
test_result = test_result.mean(dim=-1)
|
33 |
-
|
34 |
-
diff = test_result.mean(dim=-1)
|
35 |
-
|
36 |
-
if diff <= 1e-3:
|
37 |
-
print(
|
38 |
-
f"\n[Success] test_fused_anti_alias_activation"
|
39 |
-
f"\n > mean_difference={diff}"
|
40 |
-
f"\n > fused_values={fused_activation_output[-1][-1][-100:].tolist()}"
|
41 |
-
f"\n > torch_values={torch_activation_output[-1][-1][-100:].tolist()}"
|
42 |
-
)
|
43 |
-
else:
|
44 |
-
print(
|
45 |
-
f"\n[Fail] test_fused_anti_alias_activation"
|
46 |
-
f"\n > mean_difference={diff}, "
|
47 |
-
f"\n > fused_values={fused_activation_output[-1][-1][-30:].tolist()}, "
|
48 |
-
f"\n > torch_values={torch_activation_output[-1][-1][-30:].tolist()}"
|
49 |
-
)
|
50 |
-
|
51 |
-
if __name__ == "__main__":
|
52 |
-
from alias_free_cuda import load
|
53 |
-
load.load()
|
54 |
-
test_load_fused_kernels()
|
55 |
-
test_anti_alias_activation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
alias_free_cuda/type_shim.h
DELETED
@@ -1,97 +0,0 @@
|
|
1 |
-
/* coding=utf-8
|
2 |
-
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
|
3 |
-
*
|
4 |
-
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
* you may not use this file except in compliance with the License.
|
6 |
-
* You may obtain a copy of the License at
|
7 |
-
*
|
8 |
-
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
*
|
10 |
-
* Unless required by applicable law or agreed to in writing, software
|
11 |
-
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
* See the License for the specific language governing permissions and
|
14 |
-
* limitations under the License.
|
15 |
-
*/
|
16 |
-
|
17 |
-
|
18 |
-
#include <ATen/ATen.h>
|
19 |
-
#include "compat.h"
|
20 |
-
|
21 |
-
|
22 |
-
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
23 |
-
switch(TYPE) \
|
24 |
-
{ \
|
25 |
-
case at::ScalarType::Float: \
|
26 |
-
{ \
|
27 |
-
using scalar_t = float; \
|
28 |
-
__VA_ARGS__; \
|
29 |
-
break; \
|
30 |
-
} \
|
31 |
-
case at::ScalarType::Half: \
|
32 |
-
{ \
|
33 |
-
using scalar_t = at::Half; \
|
34 |
-
__VA_ARGS__; \
|
35 |
-
break; \
|
36 |
-
} \
|
37 |
-
case at::ScalarType::BFloat16: \
|
38 |
-
{ \
|
39 |
-
using scalar_t = at::BFloat16; \
|
40 |
-
__VA_ARGS__; \
|
41 |
-
break; \
|
42 |
-
} \
|
43 |
-
default: \
|
44 |
-
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
45 |
-
}
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
|
50 |
-
switch(TYPEIN) \
|
51 |
-
{ \
|
52 |
-
case at::ScalarType::Float: \
|
53 |
-
{ \
|
54 |
-
using scalar_t_in = float; \
|
55 |
-
switch(TYPEOUT) \
|
56 |
-
{ \
|
57 |
-
case at::ScalarType::Float: \
|
58 |
-
{ \
|
59 |
-
using scalar_t_out = float; \
|
60 |
-
__VA_ARGS__; \
|
61 |
-
break; \
|
62 |
-
} \
|
63 |
-
case at::ScalarType::Half: \
|
64 |
-
{ \
|
65 |
-
using scalar_t_out = at::Half; \
|
66 |
-
__VA_ARGS__; \
|
67 |
-
break; \
|
68 |
-
} \
|
69 |
-
case at::ScalarType::BFloat16: \
|
70 |
-
{ \
|
71 |
-
using scalar_t_out = at::BFloat16; \
|
72 |
-
__VA_ARGS__; \
|
73 |
-
break; \
|
74 |
-
} \
|
75 |
-
default: \
|
76 |
-
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
|
77 |
-
} \
|
78 |
-
break; \
|
79 |
-
} \
|
80 |
-
case at::ScalarType::Half: \
|
81 |
-
{ \
|
82 |
-
using scalar_t_in = at::Half; \
|
83 |
-
using scalar_t_out = at::Half; \
|
84 |
-
__VA_ARGS__; \
|
85 |
-
break; \
|
86 |
-
} \
|
87 |
-
case at::ScalarType::BFloat16: \
|
88 |
-
{ \
|
89 |
-
using scalar_t_in = at::BFloat16; \
|
90 |
-
using scalar_t_out = at::BFloat16; \
|
91 |
-
__VA_ARGS__; \
|
92 |
-
break; \
|
93 |
-
} \
|
94 |
-
default: \
|
95 |
-
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
|
96 |
-
}
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -273,7 +273,7 @@ with iface:
|
|
273 |
<h3>News</h3>
|
274 |
<p>[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:</p>
|
275 |
<ul>
|
276 |
-
<li>Custom CUDA kernel for inference: we provide a fused
|
277 |
<li>Improved discriminator and loss: BigVGAN-v2 is trained using a <a href="https://arxiv.org/abs/2311.14957" target="_blank">multi-scale sub-band CQT discriminator</a> and a <a href="https://arxiv.org/abs/2306.06546" target="_blank">multi-scale mel spectrogram loss</a>.</li>
|
278 |
<li>Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.</li>
|
279 |
<li>We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio. See the table below for the link.</li>
|
|
|
273 |
<h3>News</h3>
|
274 |
<p>[Jul 2024] We release BigVGAN-v2 along with pretrained checkpoints. Below are the highlights:</p>
|
275 |
<ul>
|
276 |
+
<li>Custom CUDA kernel for inference: we provide a fused anti-aliased activation kernel written in CUDA for accelerated inference speed. Our test shows 1.5 - 3x faster speed on a single A100 GPU.</li>
|
277 |
<li>Improved discriminator and loss: BigVGAN-v2 is trained using a <a href="https://arxiv.org/abs/2311.14957" target="_blank">multi-scale sub-band CQT discriminator</a> and a <a href="https://arxiv.org/abs/2306.06546" target="_blank">multi-scale mel spectrogram loss</a>.</li>
|
278 |
<li>Larger training data: BigVGAN-v2 is trained using datasets containing diverse audio types, including speech in multiple languages, environmental sounds, and instruments.</li>
|
279 |
<li>We provide pretrained checkpoints of BigVGAN-v2 using diverse audio configurations, supporting up to 44 kHz sampling rate and 512x upsampling ratio. See the table below for the link.</li>
|
bigvgan.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
# Licensed under the MIT license.
|
3 |
|
4 |
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
@@ -7,78 +7,127 @@
|
|
7 |
import os
|
8 |
import json
|
9 |
from pathlib import Path
|
10 |
-
|
11 |
-
from collections import namedtuple
|
12 |
-
from typing import Optional, List, Union, Dict
|
13 |
|
14 |
import torch
|
15 |
-
import torch.nn.functional as F
|
16 |
import torch.nn as nn
|
17 |
from torch.nn import Conv1d, ConvTranspose1d
|
18 |
from torch.nn.utils import weight_norm, remove_weight_norm
|
19 |
|
20 |
import activations
|
21 |
from utils import init_weights, get_padding
|
22 |
-
from
|
23 |
from env import AttrDict
|
24 |
|
25 |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
26 |
|
|
|
27 |
def load_hparams_from_json(path) -> AttrDict:
|
28 |
with open(path) as f:
|
29 |
data = f.read()
|
30 |
-
|
31 |
-
|
32 |
|
33 |
class AMPBlock1(torch.nn.Module):
|
34 |
-
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
self.h = h
|
37 |
|
38 |
-
self.convs1 = nn.ModuleList(
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
self.convs1.apply(init_weights)
|
47 |
|
48 |
-
self.convs2 = nn.ModuleList(
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
self.convs2.apply(init_weights)
|
57 |
|
58 |
-
self.num_layers = len(self.convs1) + len(
|
|
|
|
|
59 |
|
60 |
-
#
|
61 |
if self.h.get("use_cuda_kernel", False):
|
62 |
-
|
63 |
-
|
|
|
|
|
64 |
Activation1d = CudaActivation1d
|
65 |
else:
|
66 |
Activation1d = TorchActivation1d
|
67 |
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
else:
|
81 |
-
raise NotImplementedError(
|
|
|
|
|
82 |
|
83 |
def forward(self, x):
|
84 |
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
@@ -99,51 +148,93 @@ class AMPBlock1(torch.nn.Module):
|
|
99 |
|
100 |
|
101 |
class AMPBlock2(torch.nn.Module):
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
104 |
self.h = h
|
105 |
|
106 |
-
self.convs = nn.ModuleList(
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
self.convs.apply(init_weights)
|
113 |
|
114 |
-
self.num_layers = len(self.convs)
|
115 |
|
116 |
-
#
|
117 |
if self.h.get("use_cuda_kernel", False):
|
118 |
-
|
119 |
-
|
|
|
|
|
120 |
Activation1d = CudaActivation1d
|
121 |
else:
|
122 |
Activation1d = TorchActivation1d
|
123 |
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
else:
|
137 |
-
raise NotImplementedError(
|
|
|
|
|
138 |
|
139 |
def forward(self, x):
|
140 |
-
for c, a in zip
|
141 |
xt = a(x)
|
142 |
xt = c(xt)
|
143 |
x = xt + x
|
144 |
|
145 |
-
return x
|
146 |
-
|
147 |
def remove_weight_norm(self):
|
148 |
for l in self.convs:
|
149 |
remove_weight_norm(l)
|
@@ -157,83 +248,121 @@ class BigVGAN(
|
|
157 |
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
158 |
pipeline_tag="audio-to-audio",
|
159 |
license="mit",
|
160 |
-
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"]
|
161 |
):
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
h
|
168 |
-
use_cuda_kernel:
|
169 |
-
|
170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
self.h = h
|
172 |
-
self.h["use_cuda_kernel"] = use_cuda_kernel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
self.num_kernels = len(h.resblock_kernel_sizes)
|
175 |
self.num_upsamples = len(h.upsample_rates)
|
176 |
|
177 |
-
#
|
178 |
-
self.conv_pre = weight_norm(
|
|
|
|
|
179 |
|
180 |
-
#
|
181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
-
#
|
184 |
self.ups = nn.ModuleList()
|
185 |
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
186 |
-
self.ups.append(
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
|
192 |
-
#
|
193 |
self.resblocks = nn.ModuleList()
|
194 |
for i in range(len(self.ups)):
|
195 |
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
196 |
-
for j, (k, d) in enumerate(
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
from alias_free_cuda.activation1d import Activation1d as CudaActivation1d
|
203 |
-
Activation1d = CudaActivation1d
|
204 |
-
else:
|
205 |
-
Activation1d = TorchActivation1d
|
206 |
|
207 |
-
#
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
219 |
-
self.conv_post = weight_norm(
|
220 |
-
ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final
|
221 |
-
)
|
222 |
|
223 |
-
#
|
224 |
for i in range(len(self.ups)):
|
225 |
self.ups[i].apply(init_weights)
|
226 |
self.conv_post.apply(init_weights)
|
227 |
-
|
228 |
-
#
|
229 |
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
230 |
|
231 |
def forward(self, x):
|
232 |
-
#
|
233 |
x = self.conv_pre(x)
|
234 |
|
235 |
for i in range(self.num_upsamples):
|
236 |
-
#
|
237 |
for i_up in range(len(self.ups[i])):
|
238 |
x = self.ups[i][i_up](x)
|
239 |
# AMP blocks
|
@@ -245,20 +374,20 @@ class BigVGAN(
|
|
245 |
xs += self.resblocks[i * self.num_kernels + j](x)
|
246 |
x = xs / self.num_kernels
|
247 |
|
248 |
-
#
|
249 |
x = self.activation_post(x)
|
250 |
x = self.conv_post(x)
|
251 |
-
#
|
252 |
if self.use_tanh_at_final:
|
253 |
x = torch.tanh(x)
|
254 |
else:
|
255 |
-
x = torch.clamp(x, min=-1
|
256 |
|
257 |
return x
|
258 |
|
259 |
def remove_weight_norm(self):
|
260 |
try:
|
261 |
-
print(
|
262 |
for l in self.ups:
|
263 |
for l_i in l:
|
264 |
remove_weight_norm(l_i)
|
@@ -267,23 +396,18 @@ class BigVGAN(
|
|
267 |
remove_weight_norm(self.conv_pre)
|
268 |
remove_weight_norm(self.conv_post)
|
269 |
except ValueError:
|
270 |
-
print(
|
271 |
pass
|
272 |
|
273 |
-
|
274 |
-
# additional methods for huggingface_hub support
|
275 |
-
##################################################################
|
276 |
def _save_pretrained(self, save_directory: Path) -> None:
|
277 |
"""Save weights and config.json from a Pytorch model to a local directory."""
|
278 |
|
279 |
-
model_path = save_directory /
|
280 |
-
torch.save(
|
281 |
-
{'generator': self.state_dict()},
|
282 |
-
model_path
|
283 |
-
)
|
284 |
|
285 |
-
config_path = save_directory /
|
286 |
-
with open(config_path,
|
287 |
json.dump(self.h, config_file, indent=4)
|
288 |
|
289 |
@classmethod
|
@@ -298,23 +422,21 @@ class BigVGAN(
|
|
298 |
resume_download: bool,
|
299 |
local_files_only: bool,
|
300 |
token: Union[str, bool, None],
|
301 |
-
map_location: str = "cpu",
|
302 |
-
strict: bool = False,
|
303 |
use_cuda_kernel: bool = False,
|
304 |
**model_kwargs,
|
305 |
):
|
306 |
"""Load Pytorch pretrained weights and return the loaded model."""
|
307 |
|
308 |
-
|
309 |
-
# download and load hyperparameters (h) used by BigVGAN
|
310 |
-
##################################################################
|
311 |
if os.path.isdir(model_id):
|
312 |
print("Loading config.json from local directory")
|
313 |
-
config_file = os.path.join(model_id,
|
314 |
else:
|
315 |
config_file = hf_hub_download(
|
316 |
repo_id=model_id,
|
317 |
-
filename=
|
318 |
revision=revision,
|
319 |
cache_dir=cache_dir,
|
320 |
force_download=force_download,
|
@@ -325,26 +447,28 @@ class BigVGAN(
|
|
325 |
)
|
326 |
h = load_hparams_from_json(config_file)
|
327 |
|
328 |
-
##################################################################
|
329 |
# instantiate BigVGAN using h
|
330 |
-
##################################################################
|
331 |
if use_cuda_kernel:
|
332 |
-
print(
|
333 |
-
|
334 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
336 |
|
337 |
-
|
338 |
-
# download and load pretrained generator weight
|
339 |
-
##################################################################
|
340 |
if os.path.isdir(model_id):
|
341 |
print("Loading weights from local directory")
|
342 |
-
model_file = os.path.join(model_id,
|
343 |
else:
|
344 |
print(f"Loading weights from {model_id}")
|
345 |
model_file = hf_hub_download(
|
346 |
repo_id=model_id,
|
347 |
-
filename=
|
348 |
revision=revision,
|
349 |
cache_dir=cache_dir,
|
350 |
force_download=force_download,
|
@@ -352,15 +476,17 @@ class BigVGAN(
|
|
352 |
resume_download=resume_download,
|
353 |
token=token,
|
354 |
local_files_only=local_files_only,
|
355 |
-
|
356 |
-
|
357 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
358 |
|
359 |
try:
|
360 |
-
model.load_state_dict(checkpoint_dict[
|
361 |
except RuntimeError:
|
362 |
-
print(
|
|
|
|
|
363 |
model.remove_weight_norm()
|
364 |
-
model.load_state_dict(checkpoint_dict[
|
365 |
|
366 |
-
return model
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
# Licensed under the MIT license.
|
3 |
|
4 |
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
|
|
7 |
import os
|
8 |
import json
|
9 |
from pathlib import Path
|
10 |
+
from typing import Optional, Union, Dict
|
|
|
|
|
11 |
|
12 |
import torch
|
|
|
13 |
import torch.nn as nn
|
14 |
from torch.nn import Conv1d, ConvTranspose1d
|
15 |
from torch.nn.utils import weight_norm, remove_weight_norm
|
16 |
|
17 |
import activations
|
18 |
from utils import init_weights, get_padding
|
19 |
+
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
20 |
from env import AttrDict
|
21 |
|
22 |
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
23 |
|
24 |
+
|
25 |
def load_hparams_from_json(path) -> AttrDict:
|
26 |
with open(path) as f:
|
27 |
data = f.read()
|
28 |
+
return AttrDict(json.loads(data))
|
29 |
+
|
30 |
|
31 |
class AMPBlock1(torch.nn.Module):
|
32 |
+
"""
|
33 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
34 |
+
AMPBlock1 has additional self.convs2 that contains additional Conv1d layers with a fixed dilation=1 followed by each layer in self.convs1
|
35 |
+
|
36 |
+
Args:
|
37 |
+
h (AttrDict): Hyperparameters.
|
38 |
+
channels (int): Number of convolution channels.
|
39 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
40 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
41 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
h: AttrDict,
|
47 |
+
channels: int,
|
48 |
+
kernel_size: int = 3,
|
49 |
+
dilation: tuple = (1, 3, 5),
|
50 |
+
activation: str = None,
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
self.h = h
|
55 |
|
56 |
+
self.convs1 = nn.ModuleList(
|
57 |
+
[
|
58 |
+
weight_norm(
|
59 |
+
Conv1d(
|
60 |
+
channels,
|
61 |
+
channels,
|
62 |
+
kernel_size,
|
63 |
+
stride=1,
|
64 |
+
dilation=d,
|
65 |
+
padding=get_padding(kernel_size, d),
|
66 |
+
)
|
67 |
+
)
|
68 |
+
for d in dilation
|
69 |
+
]
|
70 |
+
)
|
71 |
self.convs1.apply(init_weights)
|
72 |
|
73 |
+
self.convs2 = nn.ModuleList(
|
74 |
+
[
|
75 |
+
weight_norm(
|
76 |
+
Conv1d(
|
77 |
+
channels,
|
78 |
+
channels,
|
79 |
+
kernel_size,
|
80 |
+
stride=1,
|
81 |
+
dilation=1,
|
82 |
+
padding=get_padding(kernel_size, 1),
|
83 |
+
)
|
84 |
+
)
|
85 |
+
for _ in range(len(dilation))
|
86 |
+
]
|
87 |
+
)
|
88 |
self.convs2.apply(init_weights)
|
89 |
|
90 |
+
self.num_layers = len(self.convs1) + len(
|
91 |
+
self.convs2
|
92 |
+
) # Total number of conv layers
|
93 |
|
94 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
95 |
if self.h.get("use_cuda_kernel", False):
|
96 |
+
from alias_free_activation.cuda.activation1d import (
|
97 |
+
Activation1d as CudaActivation1d,
|
98 |
+
)
|
99 |
+
|
100 |
Activation1d = CudaActivation1d
|
101 |
else:
|
102 |
Activation1d = TorchActivation1d
|
103 |
|
104 |
+
# Activation functions
|
105 |
+
if activation == "snake":
|
106 |
+
self.activations = nn.ModuleList(
|
107 |
+
[
|
108 |
+
Activation1d(
|
109 |
+
activation=activations.Snake(
|
110 |
+
channels, alpha_logscale=h.snake_logscale
|
111 |
+
)
|
112 |
+
)
|
113 |
+
for _ in range(self.num_layers)
|
114 |
+
]
|
115 |
+
)
|
116 |
+
elif activation == "snakebeta":
|
117 |
+
self.activations = nn.ModuleList(
|
118 |
+
[
|
119 |
+
Activation1d(
|
120 |
+
activation=activations.SnakeBeta(
|
121 |
+
channels, alpha_logscale=h.snake_logscale
|
122 |
+
)
|
123 |
+
)
|
124 |
+
for _ in range(self.num_layers)
|
125 |
+
]
|
126 |
+
)
|
127 |
else:
|
128 |
+
raise NotImplementedError(
|
129 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
130 |
+
)
|
131 |
|
132 |
def forward(self, x):
|
133 |
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
|
|
148 |
|
149 |
|
150 |
class AMPBlock2(torch.nn.Module):
|
151 |
+
"""
|
152 |
+
AMPBlock applies Snake / SnakeBeta activation functions with trainable parameters that control periodicity, defined for each layer.
|
153 |
+
Unlike AMPBlock1, AMPBlock2 does not contain extra Conv1d layers with fixed dilation=1
|
154 |
+
|
155 |
+
Args:
|
156 |
+
h (AttrDict): Hyperparameters.
|
157 |
+
channels (int): Number of convolution channels.
|
158 |
+
kernel_size (int): Size of the convolution kernel. Default is 3.
|
159 |
+
dilation (tuple): Dilation rates for the convolutions. Each dilation layer has two convolutions. Default is (1, 3, 5).
|
160 |
+
activation (str): Activation function type. Should be either 'snake' or 'snakebeta'. Default is None.
|
161 |
+
"""
|
162 |
+
|
163 |
+
def __init__(
|
164 |
+
self,
|
165 |
+
h: AttrDict,
|
166 |
+
channels: int,
|
167 |
+
kernel_size: int = 3,
|
168 |
+
dilation: tuple = (1, 3, 5),
|
169 |
+
activation: str = None,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
|
173 |
self.h = h
|
174 |
|
175 |
+
self.convs = nn.ModuleList(
|
176 |
+
[
|
177 |
+
weight_norm(
|
178 |
+
Conv1d(
|
179 |
+
channels,
|
180 |
+
channels,
|
181 |
+
kernel_size,
|
182 |
+
stride=1,
|
183 |
+
dilation=d,
|
184 |
+
padding=get_padding(kernel_size, d),
|
185 |
+
)
|
186 |
+
)
|
187 |
+
for d in dilation
|
188 |
+
]
|
189 |
+
)
|
190 |
self.convs.apply(init_weights)
|
191 |
|
192 |
+
self.num_layers = len(self.convs) # Total number of conv layers
|
193 |
|
194 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
195 |
if self.h.get("use_cuda_kernel", False):
|
196 |
+
from alias_free_activation.cuda.activation1d import (
|
197 |
+
Activation1d as CudaActivation1d,
|
198 |
+
)
|
199 |
+
|
200 |
Activation1d = CudaActivation1d
|
201 |
else:
|
202 |
Activation1d = TorchActivation1d
|
203 |
|
204 |
+
# Activation functions
|
205 |
+
if activation == "snake":
|
206 |
+
self.activations = nn.ModuleList(
|
207 |
+
[
|
208 |
+
Activation1d(
|
209 |
+
activation=activations.Snake(
|
210 |
+
channels, alpha_logscale=h.snake_logscale
|
211 |
+
)
|
212 |
+
)
|
213 |
+
for _ in range(self.num_layers)
|
214 |
+
]
|
215 |
+
)
|
216 |
+
elif activation == "snakebeta":
|
217 |
+
self.activations = nn.ModuleList(
|
218 |
+
[
|
219 |
+
Activation1d(
|
220 |
+
activation=activations.SnakeBeta(
|
221 |
+
channels, alpha_logscale=h.snake_logscale
|
222 |
+
)
|
223 |
+
)
|
224 |
+
for _ in range(self.num_layers)
|
225 |
+
]
|
226 |
+
)
|
227 |
else:
|
228 |
+
raise NotImplementedError(
|
229 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
230 |
+
)
|
231 |
|
232 |
def forward(self, x):
|
233 |
+
for c, a in zip(self.convs, self.activations):
|
234 |
xt = a(x)
|
235 |
xt = c(xt)
|
236 |
x = xt + x
|
237 |
|
|
|
|
|
238 |
def remove_weight_norm(self):
|
239 |
for l in self.convs:
|
240 |
remove_weight_norm(l)
|
|
|
248 |
docs_url="https://github.com/NVIDIA/BigVGAN/blob/main/README.md",
|
249 |
pipeline_tag="audio-to-audio",
|
250 |
license="mit",
|
251 |
+
tags=["neural-vocoder", "audio-generation", "arxiv:2206.04658"],
|
252 |
):
|
253 |
+
"""
|
254 |
+
BigVGAN is a neural vocoder model that applies anti-aliased periodic activation for residual blocks (resblocks).
|
255 |
+
New in BigVGAN-v2: it can optionally use optimized CUDA kernels for AMP (anti-aliased multi-periodicity) blocks.
|
256 |
+
|
257 |
+
Args:
|
258 |
+
h (AttrDict): Hyperparameters.
|
259 |
+
use_cuda_kernel (bool): If set to True, loads optimized CUDA kernels for AMP. This should be used for inference only, as training is not supported with CUDA kernels.
|
260 |
+
|
261 |
+
Note:
|
262 |
+
- The `use_cuda_kernel` parameter should be used for inference only, as training with CUDA kernels is not supported.
|
263 |
+
- Ensure that the activation function is correctly specified in the hyperparameters (h.activation).
|
264 |
+
"""
|
265 |
+
|
266 |
+
def __init__(self, h: AttrDict, use_cuda_kernel: bool = False):
|
267 |
+
super().__init__()
|
268 |
self.h = h
|
269 |
+
self.h["use_cuda_kernel"] = use_cuda_kernel
|
270 |
+
|
271 |
+
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
272 |
+
if self.h.get("use_cuda_kernel", False):
|
273 |
+
from alias_free_activation.cuda.activation1d import (
|
274 |
+
Activation1d as CudaActivation1d,
|
275 |
+
)
|
276 |
+
|
277 |
+
Activation1d = CudaActivation1d
|
278 |
+
else:
|
279 |
+
Activation1d = TorchActivation1d
|
280 |
|
281 |
self.num_kernels = len(h.resblock_kernel_sizes)
|
282 |
self.num_upsamples = len(h.upsample_rates)
|
283 |
|
284 |
+
# Pre-conv
|
285 |
+
self.conv_pre = weight_norm(
|
286 |
+
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
287 |
+
)
|
288 |
|
289 |
+
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
290 |
+
if h.resblock == "1":
|
291 |
+
resblock_class = AMPBlock1
|
292 |
+
elif h.resblock == "2":
|
293 |
+
resblock_class = AMPBlock2
|
294 |
+
else:
|
295 |
+
raise ValueError(
|
296 |
+
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
297 |
+
)
|
298 |
|
299 |
+
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
300 |
self.ups = nn.ModuleList()
|
301 |
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
302 |
+
self.ups.append(
|
303 |
+
nn.ModuleList(
|
304 |
+
[
|
305 |
+
weight_norm(
|
306 |
+
ConvTranspose1d(
|
307 |
+
h.upsample_initial_channel // (2**i),
|
308 |
+
h.upsample_initial_channel // (2 ** (i + 1)),
|
309 |
+
k,
|
310 |
+
u,
|
311 |
+
padding=(k - u) // 2,
|
312 |
+
)
|
313 |
+
)
|
314 |
+
]
|
315 |
+
)
|
316 |
+
)
|
317 |
|
318 |
+
# Residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
319 |
self.resblocks = nn.ModuleList()
|
320 |
for i in range(len(self.ups)):
|
321 |
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
322 |
+
for j, (k, d) in enumerate(
|
323 |
+
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
324 |
+
):
|
325 |
+
self.resblocks.append(
|
326 |
+
resblock_class(h, ch, k, d, activation=h.activation)
|
327 |
+
)
|
|
|
|
|
|
|
|
|
328 |
|
329 |
+
# Post-conv
|
330 |
+
activation_post = (
|
331 |
+
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
332 |
+
if h.activation == "snake"
|
333 |
+
else (
|
334 |
+
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
335 |
+
if h.activation == "snakebeta"
|
336 |
+
else None
|
337 |
+
)
|
338 |
+
)
|
339 |
+
if activation_post is None:
|
340 |
+
raise NotImplementedError(
|
341 |
+
"activation incorrectly specified. check the config file and look for 'activation'."
|
342 |
+
)
|
343 |
+
|
344 |
+
self.activation_post = Activation1d(activation=activation_post)
|
345 |
+
|
346 |
+
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
347 |
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
348 |
+
self.conv_post = weight_norm(
|
349 |
+
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
350 |
+
)
|
351 |
|
352 |
+
# Weight initialization
|
353 |
for i in range(len(self.ups)):
|
354 |
self.ups[i].apply(init_weights)
|
355 |
self.conv_post.apply(init_weights)
|
356 |
+
|
357 |
+
# Final tanh activation. Defaults to True for backward compatibility
|
358 |
self.use_tanh_at_final = h.get("use_tanh_at_final", True)
|
359 |
|
360 |
def forward(self, x):
|
361 |
+
# Pre-conv
|
362 |
x = self.conv_pre(x)
|
363 |
|
364 |
for i in range(self.num_upsamples):
|
365 |
+
# Upsampling
|
366 |
for i_up in range(len(self.ups[i])):
|
367 |
x = self.ups[i][i_up](x)
|
368 |
# AMP blocks
|
|
|
374 |
xs += self.resblocks[i * self.num_kernels + j](x)
|
375 |
x = xs / self.num_kernels
|
376 |
|
377 |
+
# Post-conv
|
378 |
x = self.activation_post(x)
|
379 |
x = self.conv_post(x)
|
380 |
+
# Final tanh activation
|
381 |
if self.use_tanh_at_final:
|
382 |
x = torch.tanh(x)
|
383 |
else:
|
384 |
+
x = torch.clamp(x, min=-1.0, max=1.0) # Bound the output to [-1, 1]
|
385 |
|
386 |
return x
|
387 |
|
388 |
def remove_weight_norm(self):
|
389 |
try:
|
390 |
+
print("Removing weight norm...")
|
391 |
for l in self.ups:
|
392 |
for l_i in l:
|
393 |
remove_weight_norm(l_i)
|
|
|
396 |
remove_weight_norm(self.conv_pre)
|
397 |
remove_weight_norm(self.conv_post)
|
398 |
except ValueError:
|
399 |
+
print("[INFO] Model already removed weight norm. Skipping!")
|
400 |
pass
|
401 |
|
402 |
+
# Additional methods for huggingface_hub support
|
|
|
|
|
403 |
def _save_pretrained(self, save_directory: Path) -> None:
|
404 |
"""Save weights and config.json from a Pytorch model to a local directory."""
|
405 |
|
406 |
+
model_path = save_directory / "bigvgan_generator.pt"
|
407 |
+
torch.save({"generator": self.state_dict()}, model_path)
|
|
|
|
|
|
|
408 |
|
409 |
+
config_path = save_directory / "config.json"
|
410 |
+
with open(config_path, "w") as config_file:
|
411 |
json.dump(self.h, config_file, indent=4)
|
412 |
|
413 |
@classmethod
|
|
|
422 |
resume_download: bool,
|
423 |
local_files_only: bool,
|
424 |
token: Union[str, bool, None],
|
425 |
+
map_location: str = "cpu", # Additional argument
|
426 |
+
strict: bool = False, # Additional argument
|
427 |
use_cuda_kernel: bool = False,
|
428 |
**model_kwargs,
|
429 |
):
|
430 |
"""Load Pytorch pretrained weights and return the loaded model."""
|
431 |
|
432 |
+
# Download and load hyperparameters (h) used by BigVGAN
|
|
|
|
|
433 |
if os.path.isdir(model_id):
|
434 |
print("Loading config.json from local directory")
|
435 |
+
config_file = os.path.join(model_id, "config.json")
|
436 |
else:
|
437 |
config_file = hf_hub_download(
|
438 |
repo_id=model_id,
|
439 |
+
filename="config.json",
|
440 |
revision=revision,
|
441 |
cache_dir=cache_dir,
|
442 |
force_download=force_download,
|
|
|
447 |
)
|
448 |
h = load_hparams_from_json(config_file)
|
449 |
|
|
|
450 |
# instantiate BigVGAN using h
|
|
|
451 |
if use_cuda_kernel:
|
452 |
+
print(
|
453 |
+
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
454 |
+
)
|
455 |
+
print(
|
456 |
+
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
457 |
+
)
|
458 |
+
print(
|
459 |
+
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
460 |
+
)
|
461 |
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
462 |
|
463 |
+
# Download and load pretrained generator weight
|
|
|
|
|
464 |
if os.path.isdir(model_id):
|
465 |
print("Loading weights from local directory")
|
466 |
+
model_file = os.path.join(model_id, "bigvgan_generator.pt")
|
467 |
else:
|
468 |
print(f"Loading weights from {model_id}")
|
469 |
model_file = hf_hub_download(
|
470 |
repo_id=model_id,
|
471 |
+
filename="bigvgan_generator.pt",
|
472 |
revision=revision,
|
473 |
cache_dir=cache_dir,
|
474 |
force_download=force_download,
|
|
|
476 |
resume_download=resume_download,
|
477 |
token=token,
|
478 |
local_files_only=local_files_only,
|
479 |
+
)
|
480 |
+
|
481 |
checkpoint_dict = torch.load(model_file, map_location=map_location)
|
482 |
|
483 |
try:
|
484 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
485 |
except RuntimeError:
|
486 |
+
print(
|
487 |
+
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
488 |
+
)
|
489 |
model.remove_weight_norm()
|
490 |
+
model.load_state_dict(checkpoint_dict["generator"])
|
491 |
|
492 |
+
return model
|
meldataset.py
CHANGED
@@ -1,66 +1,354 @@
|
|
1 |
-
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
# Licensed under the MIT license.
|
3 |
|
4 |
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
5 |
# LICENSE is in incl_licenses directory.
|
6 |
|
|
|
|
|
|
|
7 |
import torch
|
8 |
import torch.utils.data
|
9 |
import numpy as np
|
|
|
10 |
from scipy.io.wavfile import read
|
11 |
from librosa.filters import mel as librosa_mel_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
14 |
|
15 |
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
16 |
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
17 |
|
|
|
18 |
def dynamic_range_decompression(x, C=1):
|
19 |
return np.exp(x) / C
|
20 |
|
|
|
21 |
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
22 |
return torch.log(torch.clamp(x, min=clip_val) * C)
|
23 |
|
|
|
24 |
def dynamic_range_decompression_torch(x, C=1):
|
25 |
return torch.exp(x) / C
|
26 |
|
|
|
27 |
def spectral_normalize_torch(magnitudes):
|
28 |
-
|
29 |
-
|
30 |
|
31 |
def spectral_de_normalize_torch(magnitudes):
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
34 |
|
35 |
-
mel_basis = {}
|
36 |
-
hann_window = {}
|
37 |
|
38 |
-
def mel_spectrogram(
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
57 |
-
spec = torch.view_as_real(spec)
|
58 |
-
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
59 |
|
60 |
-
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
return spec
|
64 |
|
65 |
def get_mel_spectrogram(wav, h):
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2024 NVIDIA CORPORATION.
|
2 |
# Licensed under the MIT license.
|
3 |
|
4 |
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
5 |
# LICENSE is in incl_licenses directory.
|
6 |
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import random
|
10 |
import torch
|
11 |
import torch.utils.data
|
12 |
import numpy as np
|
13 |
+
from librosa.util import normalize
|
14 |
from scipy.io.wavfile import read
|
15 |
from librosa.filters import mel as librosa_mel_fn
|
16 |
+
import pathlib
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
20 |
+
|
21 |
+
|
22 |
+
def load_wav(full_path, sr_target):
|
23 |
+
sampling_rate, data = read(full_path)
|
24 |
+
if sampling_rate != sr_target:
|
25 |
+
raise RuntimeError(
|
26 |
+
f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
|
27 |
+
)
|
28 |
+
return data, sampling_rate
|
29 |
|
|
|
30 |
|
31 |
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
32 |
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
33 |
|
34 |
+
|
35 |
def dynamic_range_decompression(x, C=1):
|
36 |
return np.exp(x) / C
|
37 |
|
38 |
+
|
39 |
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
40 |
return torch.log(torch.clamp(x, min=clip_val) * C)
|
41 |
|
42 |
+
|
43 |
def dynamic_range_decompression_torch(x, C=1):
|
44 |
return torch.exp(x) / C
|
45 |
|
46 |
+
|
47 |
def spectral_normalize_torch(magnitudes):
|
48 |
+
return dynamic_range_compression_torch(magnitudes)
|
49 |
+
|
50 |
|
51 |
def spectral_de_normalize_torch(magnitudes):
|
52 |
+
return dynamic_range_decompression_torch(magnitudes)
|
53 |
+
|
54 |
+
|
55 |
+
mel_basis_cache = {}
|
56 |
+
hann_window_cache = {}
|
57 |
|
|
|
|
|
58 |
|
59 |
+
def mel_spectrogram(
|
60 |
+
y: torch.Tensor,
|
61 |
+
n_fft: int,
|
62 |
+
num_mels: int,
|
63 |
+
sampling_rate: int,
|
64 |
+
hop_size: int,
|
65 |
+
win_size: int,
|
66 |
+
fmin: int,
|
67 |
+
fmax: int = None,
|
68 |
+
center: bool = False,
|
69 |
+
) -> torch.Tensor:
|
70 |
+
"""
|
71 |
+
Calculate the mel spectrogram of an input signal.
|
72 |
+
This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
|
73 |
|
74 |
+
Args:
|
75 |
+
y (torch.Tensor): Input signal.
|
76 |
+
n_fft (int): FFT size.
|
77 |
+
num_mels (int): Number of mel bins.
|
78 |
+
sampling_rate (int): Sampling rate of the input signal.
|
79 |
+
hop_size (int): Hop size for STFT.
|
80 |
+
win_size (int): Window size for STFT.
|
81 |
+
fmin (int): Minimum frequency for mel filterbank.
|
82 |
+
fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
|
83 |
+
center (bool): Whether to pad the input to center the frames. Default is False.
|
84 |
|
85 |
+
Returns:
|
86 |
+
torch.Tensor: Mel spectrogram.
|
87 |
+
"""
|
88 |
+
if torch.min(y) < -1.0:
|
89 |
+
print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
|
90 |
+
if torch.max(y) > 1.0:
|
91 |
+
print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
|
92 |
|
93 |
+
device = y.device
|
94 |
+
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
|
|
|
|
|
|
95 |
|
96 |
+
if key not in mel_basis_cache:
|
97 |
+
mel = librosa_mel_fn(
|
98 |
+
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
99 |
+
)
|
100 |
+
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
101 |
+
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
102 |
+
|
103 |
+
mel_basis = mel_basis_cache[key]
|
104 |
+
hann_window = hann_window_cache[key]
|
105 |
+
|
106 |
+
padding = (n_fft - hop_size) // 2
|
107 |
+
y = torch.nn.functional.pad(
|
108 |
+
y.unsqueeze(1), (padding, padding), mode="reflect"
|
109 |
+
).squeeze(1)
|
110 |
+
|
111 |
+
spec = torch.stft(
|
112 |
+
y,
|
113 |
+
n_fft,
|
114 |
+
hop_length=hop_size,
|
115 |
+
win_length=win_size,
|
116 |
+
window=hann_window,
|
117 |
+
center=center,
|
118 |
+
pad_mode="reflect",
|
119 |
+
normalized=False,
|
120 |
+
onesided=True,
|
121 |
+
return_complex=True,
|
122 |
+
)
|
123 |
+
spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
|
124 |
+
|
125 |
+
mel_spec = torch.matmul(mel_basis, spec)
|
126 |
+
mel_spec = spectral_normalize_torch(mel_spec)
|
127 |
+
|
128 |
+
return mel_spec
|
129 |
|
|
|
130 |
|
131 |
def get_mel_spectrogram(wav, h):
|
132 |
+
"""
|
133 |
+
Generate mel spectrogram from a waveform using given hyperparameters.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
wav (torch.Tensor): Input waveform.
|
137 |
+
h: Hyperparameters object with attributes n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax.
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: Mel spectrogram.
|
141 |
+
"""
|
142 |
+
return mel_spectrogram(
|
143 |
+
wav,
|
144 |
+
h.n_fft,
|
145 |
+
h.num_mels,
|
146 |
+
h.sampling_rate,
|
147 |
+
h.hop_size,
|
148 |
+
h.win_size,
|
149 |
+
h.fmin,
|
150 |
+
h.fmax,
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
def get_dataset_filelist(a):
|
155 |
+
training_files = []
|
156 |
+
validation_files = []
|
157 |
+
list_unseen_validation_files = []
|
158 |
+
|
159 |
+
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
160 |
+
training_files = [
|
161 |
+
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
162 |
+
for x in fi.read().split("\n")
|
163 |
+
if len(x) > 0
|
164 |
+
]
|
165 |
+
print(f"first training file: {training_files[0]}")
|
166 |
+
|
167 |
+
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
168 |
+
validation_files = [
|
169 |
+
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
170 |
+
for x in fi.read().split("\n")
|
171 |
+
if len(x) > 0
|
172 |
+
]
|
173 |
+
print(f"first validation file: {validation_files[0]}")
|
174 |
+
|
175 |
+
for i in range(len(a.list_input_unseen_validation_file)):
|
176 |
+
with open(a.list_input_unseen_validation_file[i], "r", encoding="utf-8") as fi:
|
177 |
+
unseen_validation_files = [
|
178 |
+
os.path.join(a.list_input_unseen_wavs_dir[i], x.split("|")[0] + ".wav")
|
179 |
+
for x in fi.read().split("\n")
|
180 |
+
if len(x) > 0
|
181 |
+
]
|
182 |
+
print(
|
183 |
+
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
|
184 |
+
)
|
185 |
+
list_unseen_validation_files.append(unseen_validation_files)
|
186 |
+
|
187 |
+
return training_files, validation_files, list_unseen_validation_files
|
188 |
+
|
189 |
+
|
190 |
+
class MelDataset(torch.utils.data.Dataset):
|
191 |
+
def __init__(
|
192 |
+
self,
|
193 |
+
training_files,
|
194 |
+
hparams,
|
195 |
+
segment_size,
|
196 |
+
n_fft,
|
197 |
+
num_mels,
|
198 |
+
hop_size,
|
199 |
+
win_size,
|
200 |
+
sampling_rate,
|
201 |
+
fmin,
|
202 |
+
fmax,
|
203 |
+
split=True,
|
204 |
+
shuffle=True,
|
205 |
+
n_cache_reuse=1,
|
206 |
+
device=None,
|
207 |
+
fmax_loss=None,
|
208 |
+
fine_tuning=False,
|
209 |
+
base_mels_path=None,
|
210 |
+
is_seen=True,
|
211 |
+
):
|
212 |
+
self.audio_files = training_files
|
213 |
+
random.seed(1234)
|
214 |
+
if shuffle:
|
215 |
+
random.shuffle(self.audio_files)
|
216 |
+
self.hparams = hparams
|
217 |
+
self.is_seen = is_seen
|
218 |
+
if self.is_seen:
|
219 |
+
self.name = pathlib.Path(self.audio_files[0]).parts[0]
|
220 |
+
else:
|
221 |
+
self.name = "-".join(pathlib.Path(self.audio_files[0]).parts[:2]).strip("/")
|
222 |
+
|
223 |
+
self.segment_size = segment_size
|
224 |
+
self.sampling_rate = sampling_rate
|
225 |
+
self.split = split
|
226 |
+
self.n_fft = n_fft
|
227 |
+
self.num_mels = num_mels
|
228 |
+
self.hop_size = hop_size
|
229 |
+
self.win_size = win_size
|
230 |
+
self.fmin = fmin
|
231 |
+
self.fmax = fmax
|
232 |
+
self.fmax_loss = fmax_loss
|
233 |
+
self.cached_wav = None
|
234 |
+
self.n_cache_reuse = n_cache_reuse
|
235 |
+
self._cache_ref_count = 0
|
236 |
+
self.device = device
|
237 |
+
self.fine_tuning = fine_tuning
|
238 |
+
self.base_mels_path = base_mels_path
|
239 |
+
|
240 |
+
print("[INFO] checking dataset integrity...")
|
241 |
+
for i in tqdm(range(len(self.audio_files))):
|
242 |
+
assert os.path.exists(
|
243 |
+
self.audio_files[i]
|
244 |
+
), f"{self.audio_files[i]} not found"
|
245 |
+
|
246 |
+
def __getitem__(self, index):
|
247 |
+
filename = self.audio_files[index]
|
248 |
+
if self._cache_ref_count == 0:
|
249 |
+
audio, sampling_rate = load_wav(filename, self.sampling_rate)
|
250 |
+
audio = audio / MAX_WAV_VALUE
|
251 |
+
if not self.fine_tuning:
|
252 |
+
audio = normalize(audio) * 0.95
|
253 |
+
self.cached_wav = audio
|
254 |
+
if sampling_rate != self.sampling_rate:
|
255 |
+
raise ValueError(
|
256 |
+
f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR"
|
257 |
+
)
|
258 |
+
self._cache_ref_count = self.n_cache_reuse
|
259 |
+
else:
|
260 |
+
audio = self.cached_wav
|
261 |
+
self._cache_ref_count -= 1
|
262 |
+
|
263 |
+
audio = torch.FloatTensor(audio)
|
264 |
+
audio = audio.unsqueeze(0)
|
265 |
+
|
266 |
+
if not self.fine_tuning:
|
267 |
+
if self.split:
|
268 |
+
if audio.size(1) >= self.segment_size:
|
269 |
+
max_audio_start = audio.size(1) - self.segment_size
|
270 |
+
audio_start = random.randint(0, max_audio_start)
|
271 |
+
audio = audio[:, audio_start : audio_start + self.segment_size]
|
272 |
+
else:
|
273 |
+
audio = torch.nn.functional.pad(
|
274 |
+
audio, (0, self.segment_size - audio.size(1)), "constant"
|
275 |
+
)
|
276 |
+
|
277 |
+
mel = mel_spectrogram(
|
278 |
+
audio,
|
279 |
+
self.n_fft,
|
280 |
+
self.num_mels,
|
281 |
+
self.sampling_rate,
|
282 |
+
self.hop_size,
|
283 |
+
self.win_size,
|
284 |
+
self.fmin,
|
285 |
+
self.fmax,
|
286 |
+
center=False,
|
287 |
+
)
|
288 |
+
else: # Validation step
|
289 |
+
# Match audio length to self.hop_size * n for evaluation
|
290 |
+
if (audio.size(1) % self.hop_size) != 0:
|
291 |
+
audio = audio[:, : -(audio.size(1) % self.hop_size)]
|
292 |
+
mel = mel_spectrogram(
|
293 |
+
audio,
|
294 |
+
self.n_fft,
|
295 |
+
self.num_mels,
|
296 |
+
self.sampling_rate,
|
297 |
+
self.hop_size,
|
298 |
+
self.win_size,
|
299 |
+
self.fmin,
|
300 |
+
self.fmax,
|
301 |
+
center=False,
|
302 |
+
)
|
303 |
+
assert (
|
304 |
+
audio.shape[1] == mel.shape[2] * self.hop_size
|
305 |
+
), f"audio shape {audio.shape} mel shape {mel.shape}"
|
306 |
+
|
307 |
+
else:
|
308 |
+
mel = np.load(
|
309 |
+
os.path.join(
|
310 |
+
self.base_mels_path,
|
311 |
+
os.path.splitext(os.path.split(filename)[-1])[0] + ".npy",
|
312 |
+
)
|
313 |
+
)
|
314 |
+
mel = torch.from_numpy(mel)
|
315 |
+
|
316 |
+
if len(mel.shape) < 3:
|
317 |
+
mel = mel.unsqueeze(0)
|
318 |
+
|
319 |
+
if self.split:
|
320 |
+
frames_per_seg = math.ceil(self.segment_size / self.hop_size)
|
321 |
+
|
322 |
+
if audio.size(1) >= self.segment_size:
|
323 |
+
mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
|
324 |
+
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
325 |
+
audio = audio[
|
326 |
+
:,
|
327 |
+
mel_start
|
328 |
+
* self.hop_size : (mel_start + frames_per_seg)
|
329 |
+
* self.hop_size,
|
330 |
+
]
|
331 |
+
else:
|
332 |
+
mel = torch.nn.functional.pad(
|
333 |
+
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
334 |
+
)
|
335 |
+
audio = torch.nn.functional.pad(
|
336 |
+
audio, (0, self.segment_size - audio.size(1)), "constant"
|
337 |
+
)
|
338 |
+
|
339 |
+
mel_loss = mel_spectrogram(
|
340 |
+
audio,
|
341 |
+
self.n_fft,
|
342 |
+
self.num_mels,
|
343 |
+
self.sampling_rate,
|
344 |
+
self.hop_size,
|
345 |
+
self.win_size,
|
346 |
+
self.fmin,
|
347 |
+
self.fmax_loss,
|
348 |
+
center=False,
|
349 |
+
)
|
350 |
+
|
351 |
+
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
352 |
+
|
353 |
+
def __len__(self):
|
354 |
+
return len(self.audio_files)
|
utils.py
CHANGED
@@ -6,6 +6,7 @@ import os
|
|
6 |
import matplotlib
|
7 |
import torch
|
8 |
from torch.nn.utils import weight_norm
|
|
|
9 |
matplotlib.use("Agg")
|
10 |
import matplotlib.pylab as plt
|
11 |
from meldataset import MAX_WAV_VALUE
|
@@ -14,8 +15,7 @@ from scipy.io.wavfile import write
|
|
14 |
|
15 |
def plot_spectrogram(spectrogram):
|
16 |
fig, ax = plt.subplots(figsize=(10, 2))
|
17 |
-
im = ax.imshow(spectrogram, aspect="auto", origin="lower",
|
18 |
-
interpolation='none')
|
19 |
plt.colorbar(im, ax=ax)
|
20 |
|
21 |
fig.canvas.draw()
|
@@ -24,10 +24,16 @@ def plot_spectrogram(spectrogram):
|
|
24 |
return fig
|
25 |
|
26 |
|
27 |
-
def plot_spectrogram_clipped(spectrogram, clip_max=2.):
|
28 |
fig, ax = plt.subplots(figsize=(10, 2))
|
29 |
-
im = ax.imshow(
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
plt.colorbar(im, ax=ax)
|
32 |
|
33 |
fig.canvas.draw()
|
@@ -49,32 +55,45 @@ def apply_weight_norm(m):
|
|
49 |
|
50 |
|
51 |
def get_padding(kernel_size, dilation=1):
|
52 |
-
return int((kernel_size*dilation - dilation)/2)
|
53 |
|
54 |
|
55 |
def load_checkpoint(filepath, device):
|
56 |
assert os.path.isfile(filepath)
|
57 |
-
print("Loading '{}'"
|
58 |
checkpoint_dict = torch.load(filepath, map_location=device)
|
59 |
print("Complete.")
|
60 |
return checkpoint_dict
|
61 |
|
62 |
|
63 |
def save_checkpoint(filepath, obj):
|
64 |
-
print("Saving checkpoint to {}"
|
65 |
torch.save(obj, filepath)
|
66 |
print("Complete.")
|
67 |
|
68 |
|
69 |
-
def scan_checkpoint(cp_dir, prefix):
|
70 |
-
|
|
|
71 |
cp_list = glob.glob(pattern)
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
def save_audio(audio, path, sr):
|
77 |
# wav: torch with 1d shape
|
78 |
audio = audio * MAX_WAV_VALUE
|
79 |
-
audio = audio.cpu().numpy().astype(
|
80 |
-
write(path, sr, audio)
|
|
|
6 |
import matplotlib
|
7 |
import torch
|
8 |
from torch.nn.utils import weight_norm
|
9 |
+
|
10 |
matplotlib.use("Agg")
|
11 |
import matplotlib.pylab as plt
|
12 |
from meldataset import MAX_WAV_VALUE
|
|
|
15 |
|
16 |
def plot_spectrogram(spectrogram):
|
17 |
fig, ax = plt.subplots(figsize=(10, 2))
|
18 |
+
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
|
|
19 |
plt.colorbar(im, ax=ax)
|
20 |
|
21 |
fig.canvas.draw()
|
|
|
24 |
return fig
|
25 |
|
26 |
|
27 |
+
def plot_spectrogram_clipped(spectrogram, clip_max=2.0):
|
28 |
fig, ax = plt.subplots(figsize=(10, 2))
|
29 |
+
im = ax.imshow(
|
30 |
+
spectrogram,
|
31 |
+
aspect="auto",
|
32 |
+
origin="lower",
|
33 |
+
interpolation="none",
|
34 |
+
vmin=1e-6,
|
35 |
+
vmax=clip_max,
|
36 |
+
)
|
37 |
plt.colorbar(im, ax=ax)
|
38 |
|
39 |
fig.canvas.draw()
|
|
|
55 |
|
56 |
|
57 |
def get_padding(kernel_size, dilation=1):
|
58 |
+
return int((kernel_size * dilation - dilation) / 2)
|
59 |
|
60 |
|
61 |
def load_checkpoint(filepath, device):
|
62 |
assert os.path.isfile(filepath)
|
63 |
+
print(f"Loading '{filepath}'")
|
64 |
checkpoint_dict = torch.load(filepath, map_location=device)
|
65 |
print("Complete.")
|
66 |
return checkpoint_dict
|
67 |
|
68 |
|
69 |
def save_checkpoint(filepath, obj):
|
70 |
+
print(f"Saving checkpoint to {filepath}")
|
71 |
torch.save(obj, filepath)
|
72 |
print("Complete.")
|
73 |
|
74 |
|
75 |
+
def scan_checkpoint(cp_dir, prefix, renamed_file=None):
|
76 |
+
# Fallback to original scanning logic first
|
77 |
+
pattern = os.path.join(cp_dir, prefix + "????????")
|
78 |
cp_list = glob.glob(pattern)
|
79 |
+
|
80 |
+
if len(cp_list) > 0:
|
81 |
+
last_checkpoint_path = sorted(cp_list)[-1]
|
82 |
+
print(f"[INFO] Resuming from checkpoint: '{last_checkpoint_path}'")
|
83 |
+
return last_checkpoint_path
|
84 |
+
|
85 |
+
# If no pattern-based checkpoints are found, check for renamed file
|
86 |
+
if renamed_file:
|
87 |
+
renamed_path = os.path.join(cp_dir, renamed_file)
|
88 |
+
if os.path.isfile(renamed_path):
|
89 |
+
print(f"[INFO] Resuming from renamed checkpoint: '{renamed_file}'")
|
90 |
+
return renamed_path
|
91 |
+
|
92 |
+
return None
|
93 |
+
|
94 |
|
95 |
def save_audio(audio, path, sr):
|
96 |
# wav: torch with 1d shape
|
97 |
audio = audio * MAX_WAV_VALUE
|
98 |
+
audio = audio.cpu().numpy().astype("int16")
|
99 |
+
write(path, sr, audio)
|