Merge pull request #9 from teticio/latent-audio-diffusion
Files changed:
- .gitignore +5 -2
- README.md +43 -19
- audiodiffusion/__init__.py +138 -50
- audiodiffusion/utils.py +363 -0
- accelerate_deepspeed.yaml → config/accelerate_deepspeed.yaml +0 -0
- accelerate_local.yaml → config/accelerate_local.yaml +0 -0
- accelerate_sagemaker.yaml → config/accelerate_sagemaker.yaml +0 -0
- config/ldm_autoencoder_kl.yaml +31 -0
- notebooks/test_vae.ipynb +169 -0
- audio_to_images.py → scripts/audio_to_images.py +0 -0
- train_unconditional.py → scripts/train_unconditional.py +82 -28
- scripts/train_vae.py +166 -0
.gitignore
CHANGED
@@ -1,8 +1,11 @@
 .vscode
 __pycache__
 .ipynb_checkpoints
-data
-
+data
+models
 flagged
 build
 audiodiffusion.egg-info
+lightning_logs
+taming
+checkpoints
README.md
CHANGED
@@ -15,7 +15,10 @@ license: gpl-3.0
 
 ---
 
-**UPDATES**:
+**UPDATES**:
+
+15/10/2022
+Added latent audio diffusion (see below). Also added the possibility to train a model to use DDIM ([Denoising Diffusion Implicit Models](https://arxiv.org/pdf/2010.02502.pdf)) by setting `--scheduler ddim`. These have the benefit that samples can be generated with much fewer steps (~50) than used in training.
 
 4/10/2022
 It is now possible to mask parts of the input audio during generation which means you can stitch several samples together (think "out-painting").
@@ -45,35 +48,39 @@ You can play around with some pretrained models on [Google Colab](https://colab.
 ---
 
 ## Generate Mel spectrogram dataset from directory of audio files
+#### Install
+```bash
+pip install .
+```
+
 #### Training can be run with Mel spectrograms of resolution 64x64 on a single commercial grade GPU (e.g. RTX 2080 Ti). The `hop_length` should be set to 1024 for better results.
 
 ```bash
-python audio_to_images.py \
+python scripts/audio_to_images.py \
   --resolution 64 \
   --hop_length 1024 \
   --input_dir path-to-audio-files \
-  --output_dir data
+  --output_dir path-to-output-data
 ```
 
 #### Generate dataset of 256x256 Mel spectrograms and push to hub (you will need to be authenticated with `huggingface-cli login`).
-
 ```bash
-python audio_to_images.py \
+python scripts/audio_to_images.py \
   --resolution 256 \
   --input_dir path-to-audio-files \
-  --output_dir data-256 \
+  --output_dir data/audio-diffusion-256 \
   --push_to_hub teticio/audio-diffusion-256
 ```
+
 ## Train model
 #### Run training on local machine.
-
 ```bash
-accelerate launch --config_file accelerate_local.yaml \
-  train_unconditional.py \
-  --dataset_name data-64 \
+accelerate launch --config_file config/accelerate_local.yaml \
+  scripts/train_unconditional.py \
+  --dataset_name data/audio-diffusion-64 \
   --resolution 64 \
   --hop_length 1024 \
-  --output_dir ddpm-ema-audio-64 \
+  --output_dir models/ddpm-ema-audio-64 \
   --train_batch_size 16 \
   --num_epochs 100 \
   --gradient_accumulation_steps 1 \
@@ -83,13 +90,12 @@ accelerate launch --config_file accelerate_local.yaml \
 ```
 
 #### Run training on local machine with `batch_size` of 2 and `gradient_accumulation_steps` 8 to compensate, so that 256x256 resolution model fits on commercial grade GPU and push to hub.
-
 ```bash
-accelerate launch --config_file accelerate_local.yaml \
-  train_unconditional.py \
+accelerate launch --config_file config/accelerate_local.yaml \
+  scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
   --resolution 256 \
-  --output_dir
+  --output_dir models/audio-diffusion-256 \
   --num_epochs 100 \
   --train_batch_size 2 \
   --eval_batch_size 2 \
@@ -103,13 +109,12 @@ accelerate launch --config_file accelerate_local.yaml \
 ```
 
 #### Run training on SageMaker.
-
 ```bash
-accelerate launch --config_file accelerate_sagemaker.yaml \
-
+accelerate launch --config_file config/accelerate_sagemaker.yaml \
+  scripts/train_unconditional.py \
   --dataset_name teticio/audio-diffusion-256 \
   --resolution 256 \
-  --output_dir ddpm-ema-audio-256 \
+  --output_dir models/ddpm-ema-audio-256 \
   --train_batch_size 16 \
   --num_epochs 100 \
   --gradient_accumulation_steps 1 \
@@ -117,3 +122,22 @@ accelerate launch --config_file accelerate_sagemaker.yaml \
   --lr_warmup_steps 500 \
   --mixed_precision no
 ```
+## Latent Audio Diffusion
+Rather than denoising images directly, it is interesting to work in the "latent space" after first encoding images using an autoencoder. This has a number of advantages. Firstly, the information in the images is compressed into a latent space of a much lower dimension, so it is much faster to train denoising diffusion models and run inference with them. Secondly, similar images tend to be clustered together and interpolating between two images in latent space can produce meaningful combinations.
+
+At the time of writing, the Hugging Face `diffusers` library is geared towards inference and lacking in training functionality, rather like its cousin `transformers` in the early days of development. In order to train a VAE (Variational Autoencoder), I use the [stable-diffusion](https://github.com/CompVis/stable-diffusion) repo from CompVis and convert the checkpoints to `diffusers` format. Note that it uses a perceptual loss function for images; it would be nice to try a perceptual *audio* loss function.
+
+#### Train an autoencoder.
+```bash
+python scripts/train_vae.py \
+  --dataset_name teticio/audio-diffusion-256 \
+  --batch_size 2 \
+  --gradient_accumulation_steps 12
+```
+
+#### Train latent diffusion model.
+```bash
+accelerate launch ...
+  --vae models/autoencoder-kl
+  --latent_resolution 32
+```
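As a point of reference for the DDIM note above, generation with a reduced number of de-noising steps would look roughly like the sketch below. It uses the `steps` argument added to `AudioDiffusion.generate_spectrogram_and_audio` in `audiodiffusion/__init__.py` (next file in this diff); the model id is the example one used elsewhere in the README, and it is assumed here to have been trained with `--scheduler ddim`. This is an illustrative sketch, not part of the commit.

```python
# Sketch only: generate one sample with ~50 de-noising steps instead of the
# full training schedule. Assumes a DDIM-trained model and a notebook context
# (for `display`); 50 follows the README's "~50" remark.
from IPython.display import Audio
from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(steps=50)
display(image)                      # mel spectrogram as a PIL image
Audio(data=audio, rate=sample_rate) # reconstructed audio
```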
audiodiffusion/__init__.py
CHANGED
@@ -1,15 +1,16 @@
-from typing import Iterable, Tuple
+from typing import Iterable, Tuple, Union, List

 import torch
 import numpy as np
 from PIL import Image
 from tqdm.auto import tqdm
 from librosa.beat import beat_track
-from diffusers import DDPMPipeline,
+from diffusers import (DiffusionPipeline, DDPMPipeline, UNet2DConditionModel,
+                       DDIMScheduler, DDPMScheduler, AutoencoderKL)

 from .mel import Mel

-VERSION = "1.
+VERSION = "1.2.0"


 class AudioDiffusion:
@@ -42,29 +43,35 @@ class AudioDiffusion:
                        hop_length=hop_length,
                        top_db=top_db)
         self.model_id = model_id
-
+        try:  # a bit hacky
+            self.pipe = LatentAudioDiffusionPipeline.from_pretrained(self.model_id)
+        except:
+            self.pipe = AudioDiffusionPipeline.from_pretrained(self.model_id)
+
         if cuda:
-            self.
+            self.pipe.to("cuda")
         self.progress_bar = progress_bar or (lambda _: _)

     def generate_spectrogram_and_audio(
         self,
+        steps: int = None,
         generator: torch.Generator = None
     ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
         """Generate random mel spectrogram and convert to audio.

         Args:
+            steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
             generator (torch.Generator): random number generator or None

         Returns:
             PIL Image: mel spectrogram
             (float, np.ndarray): sample rate and raw audio
         """
-        images = self.
-
-
-
-        return
+        images, (sample_rate, audios) = self.pipe(mel=self.mel,
+                                                  batch_size=1,
+                                                  steps=steps,
+                                                  generator=generator)
+        return images[0], (sample_rate, audios[0])

     @torch.no_grad()
     def generate_spectrogram_and_audio_from_audio(
@@ -95,44 +102,124 @@ class AudioDiffusion:
             (float, np.ndarray): sample rate and raw audio
         """

-
-
+        images, (sample_rate,
+                 audios) = self.pipe(mel=self.mel,
+                                     batch_size=1,
+                                     audio_file=audio_file,
+                                     raw_audio=raw_audio,
+                                     slice=slice,
+                                     start_step=start_step,
+                                     steps=steps,
+                                     generator=generator,
+                                     mask_start_secs=mask_start_secs,
+                                     mask_end_secs=mask_end_secs)
+        return images[0], (sample_rate, audios[0])
+
+    @staticmethod
+    def loop_it(audio: np.ndarray,
+                sample_rate: int,
+                loops: int = 12) -> np.ndarray:
+        """Loop audio
+
+        Args:
+            audio (np.ndarray): audio as numpy array
+            sample_rate (int): sample rate of audio
+            loops (int): number of times to loop
+
+        Returns:
+            (float, np.ndarray): sample rate and raw audio or None
+        """
+        _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
+        for beats_in_bar in [16, 12, 8, 4]:
+            if len(beats) > beats_in_bar:
+                return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)
+        return None
+
+
+class AudioDiffusionPipeline(DiffusionPipeline):
+
+    def __init__(self, unet: UNet2DConditionModel,
+                 scheduler: Union[DDIMScheduler, DDPMScheduler]):
+        super().__init__()
+        self.register_modules(unet=unet, scheduler=scheduler)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        mel: Mel,
+        batch_size: int = 1,
+        audio_file: str = None,
+        raw_audio: np.ndarray = None,
+        slice: int = 0,
+        start_step: int = 0,
+        steps: int = None,
+        generator: torch.Generator = None,
+        mask_start_secs: float = 0,
+        mask_end_secs: float = 0
+    ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
+        """Generate random mel spectrogram from audio input and convert to audio.
+
+        Args:
+            mel (Mel): instance of Mel class to perform image <-> audio
+            batch_size (int): number of samples to generate
+            audio_file (str): must be a file on disk due to Librosa limitation or
+            raw_audio (np.ndarray): audio as numpy array
+            slice (int): slice number of audio to convert
+            start_step (int): step to start from
+            steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
+            generator (torch.Generator): random number generator or None
+            mask_start_secs (float): number of seconds of audio to mask (not generate) at start
+            mask_end_secs (float): number of seconds of audio to mask (not generate) at end
+
+        Returns:
+            List[PIL Image]: mel spectrograms
+            (float, List[np.ndarray]): sample rate and raw audios
+        """
+
         if steps is None:
-            steps = self.
-
+            steps = self.scheduler.num_train_timesteps
+        # Unfortunately, the schedule is set up in the constructor
+        scheduler = self.scheduler.__class__(num_train_timesteps=steps)
         scheduler.set_timesteps(steps)
         mask = None
         images = noise = torch.randn(
-            (
-            self.
+            (batch_size, self.unet.in_channels, self.unet.sample_size,
+             self.unet.sample_size),
             generator=generator)

         if audio_file is not None or raw_audio is not None:
-
-            input_image =
+            mel.load_audio(audio_file, raw_audio)
+            input_image = mel.audio_slice_to_image(slice)
             input_image = np.frombuffer(input_image.tobytes(),
                                         dtype="uint8").reshape(
                                             (input_image.height,
                                              input_image.width))
             input_image = ((input_image / 255) * 2 - 1)
+            input_images = np.tile(input_image, (batch_size, 1, 1, 1))
+
+            if hasattr(self, 'vqvae'):
+                input_images = self.vqvae.encode(
+                    input_images).latent_dist.sample(generator=generator)
+                input_images = 0.18215 * input_images

             if start_step > 0:
                 images[0, 0] = scheduler.add_noise(
-                    torch.tensor(
+                    torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
                     noise, torch.tensor(steps - start_step))

-
-
-
-
+            pixels_per_second = (mel.get_sample_rate() *
+                                 self.unet.sample_size / mel.hop_length /
+                                 mel.x_res)
+            mask_start = int(mask_start_secs * pixels_per_second)
+            mask_end = int(mask_end_secs * pixels_per_second)
             mask = scheduler.add_noise(
-                torch.tensor(
+                torch.tensor(input_images[:, np.newaxis, :]), noise,
                 torch.tensor(scheduler.timesteps[start_step:]))

-        images = images.to(self.
+        images = images.to(self.device)
         for step, t in enumerate(
                 self.progress_bar(scheduler.timesteps[start_step:])):
-            model_output = self.
+            model_output = self.unet(images, t)['sample']
             images = scheduler.step(model_output,
                                     t,
                                     images,
@@ -140,35 +227,36 @@ class AudioDiffusion:

         if mask is not None:
             if mask_start > 0:
-                images[
-
+                images[:, :, :, :mask_start] = mask[
+                    step, :, :, :, :mask_start]
             if mask_end > 0:
-                images[
+                images[:, :, :, -mask_end:] = mask[step, :, :, :,
+                                                   -mask_end:]
+
+        if hasattr(self, 'vqvae'):
+            # 0.18215 was scaling factor used in training to ensure unit variance
+            images = 1 / 0.18215 * images
+            images = self.vqvae.decode(images)['sample']

         images = (images / 2 + 0.5).clamp(0, 1)
         images = images.cpu().permute(0, 2, 3, 1).numpy()
+        images = (images * 255).round().astype("uint8")
+        images = list(
+            map(lambda _: Image.fromarray(_[:, :, 0]), images) if images.
+            shape[3] == 1 else map(
+                lambda _: Image.fromarray(_, mode='RGB').convert('L'), images))

-
-
-        audio = self.mel.image_to_audio(image)
-        return image, (self.mel.get_sample_rate(), audio)
+        audios = list(map(lambda _: mel.image_to_audio(_), images))
+        return images, (mel.get_sample_rate(), audios)

-    @staticmethod
-    def loop_it(audio: np.ndarray,
-                sample_rate: int,
-                loops: int = 12) -> np.ndarray:
-        """Loop audio
-
-        audio (np.ndarray): audio as numpy array
-        sample_rate (int): sample rate of audio
-        loops (int): number of times to loop
-
-
-
-
-
-
-        return
+
+class LatentAudioDiffusionPipeline(AudioDiffusionPipeline):
+
+    def __init__(self, unet: UNet2DConditionModel,
+                 scheduler: Union[DDIMScheduler,
+                                  DDPMScheduler], vqvae: AutoencoderKL):
+        super().__init__(unet=unet, scheduler=scheduler)
+        self.register_modules(vqvae=vqvae)
+
+    def __call__(self, *args, **kwargs):
+        return super().__call__(*args, **kwargs)
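For orientation, a hedged usage sketch of the helpers introduced above: `generate_spectrogram_and_audio_from_audio` with the masking arguments (keep the first seconds of an existing slice and re-generate the rest) followed by the new `loop_it` static method. The input file name and the second counts are made-up illustration values, not part of the commit.

```python
# Sketch only: out-paint from an existing audio file, keeping its first second
# (mask_start_secs) of the selected slice, then loop the result on detected beats.
from audiodiffusion import AudioDiffusion

audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio_from_audio(
    audio_file="my_track.wav",  # hypothetical input file (must be a file on disk)
    slice=0,                    # which slice of the track to condition on
    mask_start_secs=1.0)        # keep the first second of the original slice
looped = AudioDiffusion.loop_it(audio, sample_rate, loops=12)  # None if too few beats found
```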
audiodiffusion/utils.py
ADDED
@@ -0,0 +1,363 @@
# adpated from https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py

import torch
from diffusers import AutoencoderKL


def shave_segments(path, n_shave_prefix_segments=1):
    """
    Removes segments. Positive values shave the first segments, negative shave the last segments.
    """
    if n_shave_prefix_segments >= 0:
        return ".".join(path.split(".")[n_shave_prefix_segments:])
    else:
        return ".".join(path.split(".")[:n_shave_prefix_segments])


def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside resnets to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("nin_shortcut", "conv_shortcut")
        new_item = shave_segments(
            new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        # new_item = new_item.replace('norm.weight', 'group_norm.weight')
        # new_item = new_item.replace('norm.bias', 'group_norm.bias')

        # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
        # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')

        # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """
    Updates paths inside attentions to the new naming scheme (local renaming)
    """
    mapping = []
    for old_item in old_list:
        new_item = old_item

        new_item = new_item.replace("norm.weight", "group_norm.weight")
        new_item = new_item.replace("norm.bias", "group_norm.bias")

        new_item = new_item.replace("q.weight", "query.weight")
        new_item = new_item.replace("q.bias", "query.bias")

        new_item = new_item.replace("k.weight", "key.weight")
        new_item = new_item.replace("k.bias", "key.bias")

        new_item = new_item.replace("v.weight", "value.weight")
        new_item = new_item.replace("v.bias", "value.bias")

        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")

        new_item = shave_segments(
            new_item, n_shave_prefix_segments=n_shave_prefix_segments)

        mapping.append({"old": old_item, "new": new_item})

    return mapping


def assign_to_checkpoint(paths,
                         checkpoint,
                         old_checkpoint,
                         attention_paths_to_split=None,
                         additional_replacements=None,
                         config=None):
    """
    This does the final conversion step: take locally converted weights and apply a global renaming
    to them. It splits attention layers, and takes into account additional replacements
    that may arise.

    Assigns the weights to the new checkpoint.
    """
    assert isinstance(
        paths, list
    ), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Splits the attention layers into three variables.
    if attention_paths_to_split is not None:
        for path, path_map in attention_paths_to_split.items():
            old_tensor = old_checkpoint[path]
            channels = old_tensor.shape[0] // 3

            target_shape = (-1,
                            channels) if len(old_tensor.shape) == 3 else (-1)

            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3

            old_tensor = old_tensor.reshape((num_heads, 3 * channels //
                                             num_heads) + old_tensor.shape[1:])
            query, key, value = old_tensor.split(channels // num_heads, dim=1)

            checkpoint[path_map["query"]] = query.reshape(target_shape)
            checkpoint[path_map["key"]] = key.reshape(target_shape)
            checkpoint[path_map["value"]] = value.reshape(target_shape)

    for path in paths:
        new_path = path["new"]

        # These have already been assigned
        if attention_paths_to_split is not None and new_path in attention_paths_to_split:
            continue

        # Global renaming happens here
        new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
        new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
        new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")

        if additional_replacements is not None:
            for replacement in additional_replacements:
                new_path = new_path.replace(replacement["old"],
                                            replacement["new"])

        # proj_attn.weight has to be converted from conv 1D to linear
        if "proj_attn.weight" in new_path:
            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
        else:
            checkpoint[new_path] = old_checkpoint[path["old"]]


def conv_attn_to_linear(checkpoint):
    keys = list(checkpoint.keys())
    attn_keys = ["query.weight", "key.weight", "value.weight"]
    for key in keys:
        if ".".join(key.split(".")[-2:]) in attn_keys:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0, 0]
        elif "proj_attn.weight" in key:
            if checkpoint[key].ndim > 2:
                checkpoint[key] = checkpoint[key][:, :, 0]


def create_vae_diffusers_config(original_config):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.ddconfig
    _ = original_config.model.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = dict(
        sample_size=vae_params.resolution,
        in_channels=vae_params.in_channels,
        out_channels=vae_params.out_ch,
        down_block_types=tuple(down_block_types),
        up_block_types=tuple(up_block_types),
        block_out_channels=tuple(block_out_channels),
        latent_channels=vae_params.z_channels,
        layers_per_block=vae_params.num_res_blocks,
    )
    return config


def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = checkpoint

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
        "encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict[
        "encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
        "encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict[
        "encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
        "encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
        "encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
        "decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
        "decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
        "decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
        "decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
        "decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
        "decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
        "post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
        "post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({
        ".".join(layer.split(".")[:3])
        for layer in vae_state_dict if "encoder.down" in layer
    })
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
        for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({
        ".".join(layer.split(".")[:3])
        for layer in vae_state_dict if "decoder.up" in layer
    })
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
        for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [
            key for key in down_blocks[i]
            if f"down.{i}" in key and f"down.{i}.downsample" not in key
        ]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[
                f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                    f"encoder.down.{i}.downsample.conv.weight")
            new_checkpoint[
                f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                    f"encoder.down.{i}.downsample.conv.bias")

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            "old": f"down.{i}.block",
            "new": f"down_blocks.{i}.resnets"
        }
        assign_to_checkpoint(paths,
                             new_checkpoint,
                             vae_state_dict,
                             additional_replacements=[meta_path],
                             config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [
            key for key in mid_resnets if f"encoder.mid.block_{i}" in key
        ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            "old": f"mid.block_{i}",
            "new": f"mid_block.resnets.{i - 1}"
        }
        assign_to_checkpoint(paths,
                             new_checkpoint,
                             vae_state_dict,
                             additional_replacements=[meta_path],
                             config=config)

    mid_attentions = [
        key for key in vae_state_dict if "encoder.mid.attn" in key
    ]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths,
                         new_checkpoint,
                         vae_state_dict,
                         additional_replacements=[meta_path],
                         config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id]
            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[
                f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                    f"decoder.up.{block_id}.upsample.conv.weight"]
            new_checkpoint[
                f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                    f"decoder.up.{block_id}.upsample.conv.bias"]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            "old": f"up.{block_id}.block",
            "new": f"up_blocks.{i}.resnets"
        }
        assign_to_checkpoint(paths,
                             new_checkpoint,
                             vae_state_dict,
                             additional_replacements=[meta_path],
                             config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [
            key for key in mid_resnets if f"decoder.mid.block_{i}" in key
        ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            "old": f"mid.block_{i}",
            "new": f"mid_block.resnets.{i - 1}"
        }
        assign_to_checkpoint(paths,
                             new_checkpoint,
                             vae_state_dict,
                             additional_replacements=[meta_path],
                             config=config)

    mid_attentions = [
        key for key in vae_state_dict if "decoder.mid.attn" in key
    ]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths,
                         new_checkpoint,
                         vae_state_dict,
                         additional_replacements=[meta_path],
                         config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint


def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint):
    checkpoint = torch.load(ldm_checkpoint)["state_dict"]

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config(ldm_config)
    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
        checkpoint, vae_config)

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    vae.save_pretrained(hf_checkpoint)
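A hedged sketch of how the conversion entry point above is meant to be driven (this is essentially what `HFModelCheckpoint` in `scripts/train_vae.py`, further down in this diff, does after each epoch). The checkpoint file name `last.ckpt` is an assumption based on the `save_last=True` checkpoint callback in that script, not something stated in the commit.

```python
# Sketch only: convert a Lightning/LDM autoencoder checkpoint to a diffusers AutoencoderKL.
from omegaconf import OmegaConf
from diffusers import AutoencoderKL
from audiodiffusion.utils import convert_ldm_to_hf_vae

ldm_config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
convert_ldm_to_hf_vae(ldm_checkpoint="models/ldm-autoencoder-kl/last.ckpt",  # assumed file name
                      ldm_config=ldm_config,
                      hf_checkpoint="models/autoencoder-kl")

# The converted checkpoint can then be loaded back with diffusers.
vae = AutoencoderKL.from_pretrained("models/autoencoder-kl")
```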
accelerate_deepspeed.yaml → config/accelerate_deepspeed.yaml
RENAMED
File without changes

accelerate_local.yaml → config/accelerate_local.yaml
RENAMED
File without changes

accelerate_sagemaker.yaml → config/accelerate_sagemaker.yaml
RENAMED
File without changes
config/ldm_autoencoder_kl.yaml
ADDED
@@ -0,0 +1,31 @@

model:
  base_learning_rate: 4.5e-6
  target: ldm.models.autoencoder.AutoencoderKL
  params:
    monitor: "val/rec_loss"
    embed_dim: 3
    lossconfig:
      target: ldm.modules.losses.LPIPSWithDiscriminator
      params:
        disc_start: 50001
        kl_weight: 0.000001
        disc_weight: 0.5

    ddconfig:
      double_z: True
      z_channels: 4
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,2,4,4 ]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [ ]
      dropout: 0.0

lightning:
  trainer:
    benchmark: True
    accelerator: gpu
    devices: 1
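As a cross-check, `create_vae_diffusers_config` in `audiodiffusion/utils.py` above maps this `ddconfig` onto the `diffusers` `AutoencoderKL` constructor roughly as follows. This is just a sketch of the derived values for the config shown here, not an additional file in the commit.

```python
# Sketch: the diffusers config derived from the ddconfig above
# (ch=128, ch_mult=[1, 2, 4, 4], z_channels=4, resolution=256, num_res_blocks=2).
ch, ch_mult = 128, [1, 2, 4, 4]
block_out_channels = [ch * mult for mult in ch_mult]   # [128, 256, 512, 512]
vae_config = dict(
    sample_size=256,                                   # ddconfig.resolution
    in_channels=3,                                     # ddconfig.in_channels
    out_channels=3,                                    # ddconfig.out_ch
    down_block_types=("DownEncoderBlock2D",) * len(block_out_channels),
    up_block_types=("UpDecoderBlock2D",) * len(block_out_channels),
    block_out_channels=tuple(block_out_channels),
    latent_channels=4,                                 # ddconfig.z_channels
    layers_per_block=2,                                # ddconfig.num_res_blocks
)
```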
notebooks/test_vae.ipynb
ADDED
@@ -0,0 +1,169 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbbe26c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b451ab22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from datasets import load_dataset\n",
    "from IPython.display import Audio\n",
    "from diffusers import AutoencoderKL\n",
    "from audiodiffusion.mel import Mel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "324cef44",
   "metadata": {},
   "outputs": [],
   "source": [
    "mel = Mel()\n",
    "vae = AutoencoderKL.from_pretrained('../models/autoencoder-kl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da55ce79",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fea99ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('teticio/audio-diffusion-256')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426c6edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "image = random.choice(ds['train'])['image']\n",
    "display(image)\n",
    "Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d123f8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode\n",
    "input_image = np.frombuffer(image.convert('RGB').tobytes(), dtype=\"uint8\").reshape(\n",
    "    (image.height, image.width, 3))\n",
    "input_image = ((input_image / 255) * 2 - 1).transpose(2, 0, 1)\n",
    "posterior = vae.encode(torch.tensor([input_image], dtype=torch.float32)).latent_dist\n",
    "latents = posterior.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "482c458f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reconstruct\n",
    "output_image = vae.decode(latents)['sample']\n",
    "output_image = torch.clamp(output_image, -1., 1.)\n",
    "output_image = (output_image + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n",
    "output_image = (output_image.detach().cpu().numpy() *\n",
    "                255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0]\n",
    "output_image = Image.fromarray(output_image).convert('L')\n",
    "display(output_image)\n",
    "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10db020",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "output_image = vae.decode(torch.randn_like(posterior.sample()))['sample']\n",
    "output_image = torch.clamp(output_image, -1., 1.)\n",
    "output_image = (output_image + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w\n",
    "output_image = (output_image.detach().cpu().numpy() *\n",
    "                255).round().astype(\"uint8\").transpose(0, 2, 3, 1)[0]\n",
    "output_image = Image.fromarray(output_image).convert('L')\n",
    "display(output_image)\n",
    "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46019770",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
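The notebook stops at encode / reconstruct / sample. A natural extension, sketched here under the same assumptions (the local `../models/autoencoder-kl` checkpoint and the `teticio/audio-diffusion-256` dataset used in the cells above), is the latent-space interpolation mentioned in the README: encode two spectrograms, mix their latents, and decode. The plain lerp between sampled latents is a naive choice made for illustration.

```python
# Sketch only: interpolate between the latents of two spectrogram images and decode,
# following the same encode/decode pattern as the notebook cells above.
import random

import numpy as np
import torch
from PIL import Image
from datasets import load_dataset
from diffusers import AutoencoderKL
from audiodiffusion.mel import Mel

mel = Mel()
vae = AutoencoderKL.from_pretrained('../models/autoencoder-kl')
ds = load_dataset('teticio/audio-diffusion-256')


def encode(image):
    # PIL image -> [-1, 1] float array in (c, h, w) -> sampled latent
    x = np.frombuffer(image.convert('RGB').tobytes(), dtype="uint8").reshape(
        (image.height, image.width, 3))
    x = ((x / 255) * 2 - 1).transpose(2, 0, 1)
    return vae.encode(torch.tensor([x], dtype=torch.float32)).latent_dist.sample()


image_a = random.choice(ds['train'])['image']
image_b = random.choice(ds['train'])['image']
latents = 0.5 * encode(image_a) + 0.5 * encode(image_b)

output_image = vae.decode(latents)['sample']
output_image = torch.clamp(output_image, -1., 1.)
output_image = ((output_image + 1.0) / 2.0 * 255).round()
output_image = output_image.detach().cpu().numpy().astype("uint8").transpose(0, 2, 3, 1)[0]
interpolated_audio = mel.image_to_audio(Image.fromarray(output_image).convert('L'))
```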
audio_to_images.py → scripts/audio_to_images.py
RENAMED
File without changes
train_unconditional.py → scripts/train_unconditional.py
RENAMED
@@ -5,12 +5,12 @@ import os

 import torch
 import torch.nn.functional as F
-from PIL import Image

 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_from_disk, load_dataset
-from diffusers import
+from diffusers import (DiffusionPipeline, DDPMScheduler, UNet2DModel,
+                       DDIMScheduler, AutoencoderKL)
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -22,10 +22,12 @@ from torchvision.transforms import (
     Resize,
     ToTensor,
 )
+import numpy as np
 from tqdm.auto import tqdm
 from librosa.util import normalize

 from audiodiffusion.mel import Mel
+from audiodiffusion import LatentAudioDiffusionPipeline, AudioDiffusionPipeline

 logger = get_logger(__name__)

@@ -34,18 +36,25 @@ def main(args):
     output_dir = os.environ.get("SM_MODEL_DIR", None) or args.output_dir
     logging_dir = os.path.join(output_dir, args.logging_dir)
     accelerator = Accelerator(
+        gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with="tensorboard",
        logging_dir=logging_dir,
     )

+    if args.vae is not None:
+        vqvae = AutoencoderKL.from_pretrained(args.vae)
+
     if args.from_pretrained is not None:
-        model =
+        model = DiffusionPipeline.from_pretrained(args.from_pretrained).unet
     else:
         model = UNet2DModel(
-            sample_size=args.resolution
-
-
+            sample_size=args.resolution
+            if args.vae is None else args.latent_resolution,
+            in_channels=1
+            if args.vae is None else vqvae.config['latent_channels'],
+            out_channels=1
+            if args.vae is None else vqvae.config['latent_channels'],
             layers_per_block=2,
             block_out_channels=(128, 128, 256, 256, 512, 512),
             down_block_types=(
@@ -65,8 +74,14 @@
                 "UpBlock2D",
             ),
         )
-
-
+
+    if args.scheduler == "ddpm":
+        noise_scheduler = DDPMScheduler(
+            num_train_timesteps=args.num_train_steps, tensor_format="pt")
+    else:
+        noise_scheduler = DDIMScheduler(
+            num_train_timesteps=args.num_train_steps, tensor_format="pt")
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -103,7 +118,13 @@
     )

     def transforms(examples):
-
+        if args.vae is not None and vqvae.config['in_channels'] == 3:
+            images = [
+                augmentations(image.convert('RGB'))
+                for image in examples["image"]
+            ]
+        else:
+            images = [augmentations(image) for image in examples["image"]]
         return {"input": images}

     dataset.set_transform(transforms)
@@ -158,6 +179,15 @@
         model.train()
         for step, batch in enumerate(train_dataloader):
             clean_images = batch["input"]
+
+            if args.vae is not None:
+                vqvae.to(clean_images.device)
+                with torch.no_grad():
+                    clean_images = vqvae.encode(
+                        clean_images).latent_dist.sample()
+                # Scale latent images to ensure approximately unit variance
+                clean_images = clean_images * 0.18215
+
             # Sample noise that we'll add to the images
             noise = torch.randn(clean_images.shape).to(clean_images.device)
             bsz = clean_images.shape[0]
@@ -180,7 +210,8 @@
                 loss = F.mse_loss(noise_pred, noise)
                 accelerator.backward(loss)

-                accelerator.
+                if accelerator.sync_gradients:
+                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
@@ -188,6 +219,8 @@
                 optimizer.zero_grad()

             progress_bar.update(1)
+            global_step += 1
+
             logs = {
                 "loss": loss.detach().item(),
                 "lr": lr_scheduler.get_last_lr()[0],
@@ -197,7 +230,6 @@
                 logs["ema_decay"] = ema_model.decay
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
-            global_step += 1
         progress_bar.close()

         accelerator.wait_for_everyone()
@@ -205,11 +237,20 @@
         # Generate sample images for visual inspection
         if accelerator.is_main_process:
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
-
-
-
-
-
+                if args.vae is not None:
+                    pipeline = LatentAudioDiffusionPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model
+                        ),
+                        vqvae=vqvae,
+                        scheduler=noise_scheduler)
+                else:
+                    pipeline = AudioDiffusionPipeline(
+                        unet=accelerator.unwrap_model(
+                            ema_model.averaged_model if args.use_ema else model
+                        ),
+                        scheduler=noise_scheduler,
+                    )

             # save the model
             if args.push_to_hub:
@@ -226,27 +267,30 @@
             else:
                 pipeline.save_pretrained(output_dir)

-
+            if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+                generator = torch.manual_seed(42)
                 # run pipeline in inference (sample random noise and denoise)
-                images = pipeline(
+                images, (sample_rate, audios) = pipeline(
+                    mel=mel,
                     generator=generator,
                     batch_size=args.eval_batch_size,
-
-                )
+                    steps=args.num_train_steps,
+                )

                 # denormalize the images and save to tensorboard
-
-
-
+                images = np.array([
+                    np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+                        (len(image.getbands()), image.height, image.width))
+                    for image in images
+                ])
                 accelerator.trackers[0].writer.add_images(
-                    "test_samples",
-                for _,
-                audio = mel.image_to_audio(Image.fromarray(image[0]))
+                    "test_samples", images, epoch)
+                for _, audio in enumerate(audios):
                     accelerator.trackers[0].writer.add_audio(
                         f"test_audio_{_}",
                         normalize(audio),
                         epoch,
-                        sample_rate=
+                        sample_rate=sample_rate,
                     )
         accelerator.wait_for_everyone()

@@ -268,7 +312,7 @@ if __name__ == "__main__":
     parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
     parser.add_argument("--overwrite_output_dir", type=bool, default=False)
     parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=
+    parser.add_argument("--resolution", type=int, default=256)
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--eval_batch_size", type=int, default=16)
     parser.add_argument("--num_epochs", type=int, default=100)
@@ -305,6 +349,16 @@
     parser.add_argument("--hop_length", type=int, default=512)
     parser.add_argument("--from_pretrained", type=str, default=None)
     parser.add_argument("--start_epoch", type=int, default=0)
+    parser.add_argument("--num_train_steps", type=int, default=1000)
+    parser.add_argument("--latent_resolution", type=int, default=None)
+    parser.add_argument("--scheduler",
+                        type=str,
+                        default="ddpm",
+                        help="ddpm or ddim")
+    parser.add_argument("--vae",
+                        type=str,
+                        default=None,
+                        help="pretrained VAE model for latent diffusion")

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
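One detail worth pulling out of the diff above: when `--vae` is given, the training images are first encoded into latents and scaled by 0.18215 (the pipeline later divides by the same factor before decoding, to undo it). A minimal, hedged sketch of that round trip with a `diffusers` `AutoencoderKL`; the checkpoint path is the `models/autoencoder-kl` default used elsewhere in this PR, and the random batch is just a stand-in for real spectrogram tensors.

```python
# Sketch only: the encode/scale and unscale/decode round trip used for latent diffusion.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("models/autoencoder-kl")
images = torch.randn(2, vae.config["in_channels"], 256, 256)  # stand-in batch of spectrograms

with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()
    latents = latents * 0.18215          # scaling used during training (approx. unit variance)

    # ... the diffusion model is trained on / denoises `latents` here ...

    decoded = vae.decode(latents / 0.18215)['sample']  # back to image space
```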
scripts/train_vae.py
ADDED
@@ -0,0 +1,166 @@
# pip install -e git+https://github.com/CompVis/stable-diffusion.git@master
# pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers

# TODO
# grayscale

import os
import argparse

import torch
import torchvision
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from omegaconf import OmegaConf
from librosa.util import normalize
from ldm.util import instantiate_from_config
from pytorch_lightning.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from datasets import load_from_disk, load_dataset
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.utilities.distributed import rank_zero_only

from audiodiffusion.mel import Mel
from audiodiffusion.utils import convert_ldm_to_hf_vae


class AudioDiffusion(Dataset):

    def __init__(self, model_id):
        super().__init__()
        if os.path.exists(model_id):
            self.hf_dataset = load_from_disk(model_id)['train']
        else:
            self.hf_dataset = load_dataset(model_id)['train']

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        image = self.hf_dataset[idx]['image'].convert('RGB')
        image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
            (image.height, image.width, 3))
        image = ((image / 255) * 2 - 1)
        return {'image': image}


class AudioDiffusionDataModule(pl.LightningDataModule):

    def __init__(self, model_id, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = AudioDiffusion(model_id)
        self.num_workers = 1

    def train_dataloader(self):
        return DataLoader(self.dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers)


class ImageLogger(Callback):

    def __init__(self, every=1000, resolution=256, hop_length=512):
        super().__init__()
        self.mel = Mel(x_res=resolution,
                       y_res=resolution,
                       hop_length=hop_length)
        self.every = every

    @rank_zero_only
    def log_images_and_audios(self, pl_module, batch):
        pl_module.eval()
        with torch.no_grad():
            images = pl_module.log_images(batch, split='train')
        pl_module.train()

        for k in images:
            images[k] = images[k].detach().cpu()
            images[k] = torch.clamp(images[k], -1., 1.)
            images[k] = (images[k] + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = torchvision.utils.make_grid(images[k])

            tag = f"train/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step)

            images[k] = (images[k].numpy() *
                         255).round().astype("uint8").transpose(0, 2, 3, 1)
            for _, image in enumerate(images[k]):
                audio = self.mel.image_to_audio(
                    Image.fromarray(image, mode='RGB').convert('L'))
                pl_module.logger.experiment.add_audio(
                    tag + f"/{_}",
                    normalize(audio),
                    global_step=pl_module.global_step,
                    sample_rate=self.mel.get_sample_rate())

    def on_train_batch_end(self, trainer, pl_module, outputs, batch,
                           batch_idx):
        if (batch_idx + 1) % self.every != 0:
            return
        self.log_images_and_audios(pl_module, batch)


class HFModelCheckpoint(ModelCheckpoint):

    def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.ldm_config = ldm_config
        self.hf_checkpoint = hf_checkpoint

    def on_train_epoch_end(self, trainer, pl_module):
        super().on_train_epoch_end(trainer, pl_module)
        ldm_checkpoint = self.format_checkpoint_name(
            {'epoch': trainer.current_epoch})
        convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config,
                              self.hf_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train VAE using ldm.")
    parser.add_argument("-d", "--dataset_name", type=str, default=None)
    parser.add_argument("-b", "--batch_size", type=int, default=1)
    parser.add_argument("-c",
                        "--ldm_config_file",
                        type=str,
                        default="config/ldm_autoencoder_kl.yaml")
    parser.add_argument("--ldm_checkpoint_dir",
                        type=str,
                        default="models/ldm-autoencoder-kl")
    parser.add_argument("--hf_checkpoint_dir",
                        type=str,
                        default="models/autoencoder-kl")
    parser.add_argument("-r",
                        "--resume_from_checkpoint",
                        type=str,
                        default=None)
    parser.add_argument("-g",
                        "--gradient_accumulation_steps",
                        type=int,
                        default=1)
    args = parser.parse_args()

    config = OmegaConf.load(args.ldm_config_file)
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(
        trainer_opt,
        resume_from_checkpoint=args.resume_from_checkpoint,
        callbacks=[
            ImageLogger(),
            HFModelCheckpoint(ldm_config=config,
                              hf_checkpoint=args.hf_checkpoint_dir,
                              dirpath=args.ldm_checkpoint_dir,
                              filename='{epoch:06}',
                              verbose=True,
                              save_last=True)
        ])
    model = instantiate_from_config(config.model)
    model.learning_rate = config.model.base_learning_rate
    data = AudioDiffusionDataModule(args.dataset_name,
                                    batch_size=args.batch_size)
    trainer.fit(model, data)