Spaces:
Runtime error
Runtime error
add conditional training
Browse files- README.md +38 -5
- audiodiffusion/__init__.py +49 -435
- audiodiffusion/audio_encoder.py +107 -0
- audiodiffusion/mel.py +164 -0
- audiodiffusion/pipeline_audio_diffusion.py +267 -0
- audiodiffusion/utils.py +103 -142
- config/ldm_autoencoder_kl.yaml +1 -1
- notebooks/audio_encoder.ipynb +69 -0
- notebooks/conditional_generation.ipynb +221 -0
- pyproject.toml +3 -0
- scripts/audio_to_images.py +45 -42
- scripts/encode_audio.py +38 -0
- scripts/{train_unconditional.py → train_unet.py} +106 -55
- scripts/train_vae.py +69 -80
README.md
CHANGED
@@ -23,7 +23,9 @@ Go to https://soundcloud.com/teticio2/sets/audio-diffusion-loops for more exampl
|
|
23 |
---
|
24 |
#### Updates
|
25 |
|
26 |
-
**
|
|
|
|
|
27 |
|
28 |
**2/12/2022**. Added Mel to pipeline and updated the pretrained models to save Mel config (they are now no longer compatible with previous versions of this repo). It is relatively straightforward to migrate previously trained models to the new format (see https://huggingface.co/teticio/audio-diffusion-256).
|
29 |
|
@@ -58,7 +60,8 @@ You can play around with some pre-trained models on [Google Colab](https://colab
|
|
58 |
| [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/teticio/audio-diffusion-instrumental-hiphop-256) | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/datasets/teticio/audio-diffusion-instrumental-hiphop-256) | Instrumental Hip Hop music |
|
59 |
| [teticio/audio-diffusion-ddim-256](https://huggingface.co/teticio/audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | De-noising Diffusion Implicit Model |
|
60 |
| [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion model |
|
61 |
-
| [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion
|
|
|
62 |
|
63 |
---
|
64 |
|
@@ -106,7 +109,7 @@ Note that the default `sample_rate` is 22050 and audios will be resampled if the
|
|
106 |
|
107 |
```bash
|
108 |
accelerate launch --config_file config/accelerate_local.yaml \
|
109 |
-
scripts/
|
110 |
--dataset_name data/audio-diffusion-64 \
|
111 |
--hop_length 1024 \
|
112 |
--output_dir models/ddpm-ema-audio-64 \
|
@@ -122,7 +125,7 @@ scripts/train_unconditional.py \
|
|
122 |
|
123 |
```bash
|
124 |
accelerate launch --config_file config/accelerate_local.yaml \
|
125 |
-
scripts/
|
126 |
--dataset_name teticio/audio-diffusion-256 \
|
127 |
--output_dir models/audio-diffusion-256 \
|
128 |
--num_epochs 100 \
|
@@ -141,7 +144,7 @@ scripts/train_unconditional.py \
|
|
141 |
|
142 |
```bash
|
143 |
accelerate launch --config_file config/accelerate_sagemaker.yaml \
|
144 |
-
scripts/
|
145 |
--dataset_name teticio/audio-diffusion-256 \
|
146 |
--output_dir models/ddpm-ema-audio-256 \
|
147 |
--train_batch_size 16 \
|
@@ -200,3 +203,33 @@ accelerate launch ...
|
|
200 |
...
|
201 |
--vae models/autoencoder-kl
|
202 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
---
|
24 |
#### Updates
|
25 |
|
26 |
+
**25/12/2022**. Now it is possible to train models conditional on an encoding (of text or audio, for example). See the section on Conditional Audio Generation below.
|
27 |
+
|
28 |
+
**5/12/2022**. 🤗 Exciting news! `AudioDiffusionPipeline` has been migrated to the Hugging Face `diffusers` package so that it is even easier for others to use and contribute.
|
29 |
|
30 |
**2/12/2022**. Added Mel to pipeline and updated the pretrained models to save Mel config (they are now no longer compatible with previous versions of this repo). It is relatively straightforward to migrate previously trained models to the new format (see https://huggingface.co/teticio/audio-diffusion-256).
|
31 |
|
|
|
60 |
| [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/teticio/audio-diffusion-instrumental-hiphop-256) | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/datasets/teticio/audio-diffusion-instrumental-hiphop-256) | Instrumental Hip Hop music |
|
61 |
| [teticio/audio-diffusion-ddim-256](https://huggingface.co/teticio/audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | De-noising Diffusion Implicit Model |
|
62 |
| [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion model |
|
63 |
+
| [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion Implicit Model |
|
64 |
+
| [teticio/conditional-latent-audio-diffusion-512](https://huggingface.co/teticio/latent-audio-diffusion-512) | [teticio/audio-diffusion-512](https://huggingface.co/datasets/teticio/audio-diffusion-512) | Conditional Latent Audio Diffusion Model |
|
65 |
|
66 |
---
|
67 |
|
|
|
109 |
|
110 |
```bash
|
111 |
accelerate launch --config_file config/accelerate_local.yaml \
|
112 |
+
scripts/train_unet.py \
|
113 |
--dataset_name data/audio-diffusion-64 \
|
114 |
--hop_length 1024 \
|
115 |
--output_dir models/ddpm-ema-audio-64 \
|
|
|
125 |
|
126 |
```bash
|
127 |
accelerate launch --config_file config/accelerate_local.yaml \
|
128 |
+
scripts/train_unet.py \
|
129 |
--dataset_name teticio/audio-diffusion-256 \
|
130 |
--output_dir models/audio-diffusion-256 \
|
131 |
--num_epochs 100 \
|
|
|
144 |
|
145 |
```bash
|
146 |
accelerate launch --config_file config/accelerate_sagemaker.yaml \
|
147 |
+
scripts/train_unet.py \
|
148 |
--dataset_name teticio/audio-diffusion-256 \
|
149 |
--output_dir models/ddpm-ema-audio-256 \
|
150 |
--train_batch_size 16 \
|
|
|
203 |
...
|
204 |
--vae models/autoencoder-kl
|
205 |
```
|
206 |
+
|
207 |
+
## Conditional Audio Generation
|
208 |
+
|
209 |
+
We can generate audio conditional on a text prompt - or indeed anything which can be encoded into a bunch of numbers - much like DALL-E2 and Midjourney. It is generally harder to find good quality datasets of audios together with descriptions, although the people behind the dataset used to train Midjourney are making some very interesting progress [here](https://github.com/LAION-AI/audio-dataset). I have chosen to encode the audio directly instead based on "how it sounds", using a [model which I trained on hundreds of thousands of Spotify playlists](https://github.com/teticio/Deej-AI). To encode an audio into a 100 dimensional vector
|
210 |
+
|
211 |
+
```python
|
212 |
+
from diffusers import Mel
|
213 |
+
from audiodiffusion.audio_encoder import AudioEncoder
|
214 |
+
|
215 |
+
audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
|
216 |
+
audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])
|
217 |
+
```
|
218 |
+
|
219 |
+
One you have prepared a dataset, you can encode the audio files with this script
|
220 |
+
|
221 |
+
```bash
|
222 |
+
python scripts/encode_audio \
|
223 |
+
--dataset_name teticio/audio-diffusion-256 \
|
224 |
+
--out_file data/encodings.p
|
225 |
+
```
|
226 |
+
|
227 |
+
Then you can train a model with
|
228 |
+
|
229 |
+
```bash
|
230 |
+
accelerate launch ...
|
231 |
+
...
|
232 |
+
--encodings data/encodings.p
|
233 |
+
```
|
234 |
+
|
235 |
+
When generating audios, you will need to pass an `encodings` Tensor. See the [`conditional_generation.ipynb`](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/conditional_generation.ipynb) notebook for an example that uses encodings of Spotify track previews to influence the generation.
|
audiodiffusion/__init__.py
CHANGED
@@ -1,21 +1,24 @@
|
|
1 |
from typing import Iterable, Tuple
|
2 |
|
3 |
-
import torch
|
4 |
import numpy as np
|
|
|
|
|
5 |
from PIL import Image
|
6 |
from tqdm.auto import tqdm
|
7 |
-
from librosa.beat import beat_track
|
8 |
-
from diffusers import AudioDiffusionPipeline
|
9 |
|
10 |
-
|
|
|
11 |
|
|
|
12 |
|
13 |
-
class AudioDiffusion:
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
19 |
"""Class for generating audio using De-noising Diffusion Probabilistic Models.
|
20 |
|
21 |
Args:
|
@@ -35,7 +38,8 @@ class AudioDiffusion:
|
|
35 |
generator: torch.Generator = None,
|
36 |
step_generator: torch.Generator = None,
|
37 |
eta: float = 0,
|
38 |
-
noise: torch.Tensor = None
|
|
|
39 |
) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
|
40 |
"""Generate random mel spectrogram and convert to audio.
|
41 |
|
@@ -45,19 +49,22 @@ class AudioDiffusion:
|
|
45 |
step_generator (torch.Generator): random number generator used to de-noise or None
|
46 |
eta (float): parameter between 0 and 1 used with DDIM scheduler
|
47 |
noise (torch.Tensor): noisy image or None
|
|
|
48 |
|
49 |
Returns:
|
50 |
PIL Image: mel spectrogram
|
51 |
(float, np.ndarray): sample rate and raw audio
|
52 |
"""
|
53 |
-
images, (sample_rate,
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
61 |
return images[0], (sample_rate, audios[0])
|
62 |
|
63 |
def generate_spectrogram_and_audio_from_audio(
|
@@ -72,7 +79,8 @@ class AudioDiffusion:
|
|
72 |
mask_end_secs: float = 0,
|
73 |
step_generator: torch.Generator = None,
|
74 |
eta: float = 0,
|
75 |
-
|
|
|
76 |
) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
|
77 |
"""Generate random mel spectrogram from audio input and convert to audio.
|
78 |
|
@@ -87,6 +95,7 @@ class AudioDiffusion:
|
|
87 |
mask_end_secs (float): number of seconds of audio to mask (not generate) at end
|
88 |
step_generator (torch.Generator): random number generator used to de-noise or None
|
89 |
eta (float): parameter between 0 and 1 used with DDIM scheduler
|
|
|
90 |
noise (torch.Tensor): noisy image or None
|
91 |
|
92 |
Returns:
|
@@ -94,26 +103,26 @@ class AudioDiffusion:
|
|
94 |
(float, np.ndarray): sample rate and raw audio
|
95 |
"""
|
96 |
|
97 |
-
images, (sample_rate,
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
111 |
return images[0], (sample_rate, audios[0])
|
112 |
|
113 |
@staticmethod
|
114 |
-
def loop_it(audio: np.ndarray,
|
115 |
-
sample_rate: int,
|
116 |
-
loops: int = 12) -> np.ndarray:
|
117 |
"""Loop audio
|
118 |
|
119 |
Args:
|
@@ -124,403 +133,8 @@ class AudioDiffusion:
|
|
124 |
Returns:
|
125 |
(float, np.ndarray): sample rate and raw audio or None
|
126 |
"""
|
127 |
-
_, beats = beat_track(y=audio, sr=sample_rate, units=
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
return None
|
132 |
-
|
133 |
-
|
134 |
-
'''
|
135 |
-
# This code will be migrated to diffusers shortly
|
136 |
-
|
137 |
-
#-----------------------------------------------------------------------------#
|
138 |
-
|
139 |
-
import os
|
140 |
-
import warnings
|
141 |
-
from typing import Any, Dict, Optional, Union
|
142 |
-
|
143 |
-
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
144 |
-
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
145 |
-
|
146 |
-
|
147 |
-
warnings.filterwarnings("ignore")
|
148 |
-
|
149 |
-
import numpy as np # noqa: E402
|
150 |
-
|
151 |
-
import librosa # noqa: E402
|
152 |
-
from PIL import Image # noqa: E402
|
153 |
-
|
154 |
-
|
155 |
-
class Mel(ConfigMixin, SchedulerMixin):
|
156 |
-
"""
|
157 |
-
Parameters:
|
158 |
-
x_res (`int`): x resolution of spectrogram (time)
|
159 |
-
y_res (`int`): y resolution of spectrogram (frequency bins)
|
160 |
-
sample_rate (`int`): sample rate of audio
|
161 |
-
n_fft (`int`): number of Fast Fourier Transforms
|
162 |
-
hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
|
163 |
-
top_db (`int`): loudest in decibels
|
164 |
-
n_iter (`int`): number of iterations for Griffin Linn mel inversion
|
165 |
-
"""
|
166 |
-
|
167 |
-
config_name = "mel_config.json"
|
168 |
-
|
169 |
-
@register_to_config
|
170 |
-
def __init__(
|
171 |
-
self,
|
172 |
-
x_res: int = 256,
|
173 |
-
y_res: int = 256,
|
174 |
-
sample_rate: int = 22050,
|
175 |
-
n_fft: int = 2048,
|
176 |
-
hop_length: int = 512,
|
177 |
-
top_db: int = 80,
|
178 |
-
n_iter: int = 32,
|
179 |
-
):
|
180 |
-
self.hop_length = hop_length
|
181 |
-
self.sr = sample_rate
|
182 |
-
self.n_fft = n_fft
|
183 |
-
self.top_db = top_db
|
184 |
-
self.n_iter = n_iter
|
185 |
-
self.set_resolution(x_res, y_res)
|
186 |
-
self.audio = None
|
187 |
-
|
188 |
-
def set_resolution(self, x_res: int, y_res: int):
|
189 |
-
"""Set resolution.
|
190 |
-
|
191 |
-
Args:
|
192 |
-
x_res (`int`): x resolution of spectrogram (time)
|
193 |
-
y_res (`int`): y resolution of spectrogram (frequency bins)
|
194 |
-
"""
|
195 |
-
self.x_res = x_res
|
196 |
-
self.y_res = y_res
|
197 |
-
self.n_mels = self.y_res
|
198 |
-
self.slice_size = self.x_res * self.hop_length - 1
|
199 |
-
|
200 |
-
def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
|
201 |
-
"""Load audio.
|
202 |
-
|
203 |
-
Args:
|
204 |
-
audio_file (`str`): must be a file on disk due to Librosa limitation or
|
205 |
-
raw_audio (`np.ndarray`): audio as numpy array
|
206 |
-
"""
|
207 |
-
if audio_file is not None:
|
208 |
-
self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
|
209 |
-
else:
|
210 |
-
self.audio = raw_audio
|
211 |
-
|
212 |
-
# Pad with silence if necessary.
|
213 |
-
if len(self.audio) < self.x_res * self.hop_length:
|
214 |
-
self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
|
215 |
-
|
216 |
-
def get_number_of_slices(self) -> int:
|
217 |
-
"""Get number of slices in audio.
|
218 |
-
|
219 |
-
Returns:
|
220 |
-
`int`: number of spectograms audio can be sliced into
|
221 |
-
"""
|
222 |
-
return len(self.audio) // self.slice_size
|
223 |
-
|
224 |
-
def get_audio_slice(self, slice: int = 0) -> np.ndarray:
|
225 |
-
"""Get slice of audio.
|
226 |
-
|
227 |
-
Args:
|
228 |
-
slice (`int`): slice number of audio (out of get_number_of_slices())
|
229 |
-
|
230 |
-
Returns:
|
231 |
-
`np.ndarray`: audio as numpy array
|
232 |
-
"""
|
233 |
-
return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
|
234 |
-
|
235 |
-
def get_sample_rate(self) -> int:
|
236 |
-
"""Get sample rate:
|
237 |
-
|
238 |
-
Returns:
|
239 |
-
`int`: sample rate of audio
|
240 |
-
"""
|
241 |
-
return self.sr
|
242 |
-
|
243 |
-
def audio_slice_to_image(self, slice: int) -> Image.Image:
|
244 |
-
"""Convert slice of audio to spectrogram.
|
245 |
-
|
246 |
-
Args:
|
247 |
-
slice (`int`): slice number of audio to convert (out of get_number_of_slices())
|
248 |
-
|
249 |
-
Returns:
|
250 |
-
`PIL Image`: grayscale image of x_res x y_res
|
251 |
-
"""
|
252 |
-
S = librosa.feature.melspectrogram(
|
253 |
-
y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
|
254 |
-
)
|
255 |
-
log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
|
256 |
-
bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
|
257 |
-
image = Image.fromarray(bytedata)
|
258 |
-
return image
|
259 |
-
|
260 |
-
def image_to_audio(self, image: Image.Image) -> np.ndarray:
|
261 |
-
"""Converts spectrogram to audio.
|
262 |
-
|
263 |
-
Args:
|
264 |
-
image (`PIL Image`): x_res x y_res grayscale image
|
265 |
-
|
266 |
-
Returns:
|
267 |
-
audio (`np.ndarray`): raw audio
|
268 |
-
"""
|
269 |
-
bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
|
270 |
-
log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
|
271 |
-
S = librosa.db_to_power(log_S)
|
272 |
-
audio = librosa.feature.inverse.mel_to_audio(
|
273 |
-
S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
|
274 |
-
)
|
275 |
-
return audio
|
276 |
-
|
277 |
-
#-----------------------------------------------------------------------------#
|
278 |
-
|
279 |
-
from math import acos, sin
|
280 |
-
from typing import List, Tuple, Union
|
281 |
-
|
282 |
-
import numpy as np
|
283 |
-
import torch
|
284 |
-
|
285 |
-
from PIL import Image
|
286 |
-
|
287 |
-
from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, DDIMScheduler, DDPMScheduler
|
288 |
-
from diffusers.pipeline_utils import AudioPipelineOutput, BaseOutput, ImagePipelineOutput
|
289 |
-
|
290 |
-
|
291 |
-
class AudioDiffusionPipeline(DiffusionPipeline):
|
292 |
-
"""
|
293 |
-
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
294 |
-
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
295 |
-
|
296 |
-
Parameters:
|
297 |
-
vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
|
298 |
-
unet ([`UNet2DConditionModel`]): UNET model
|
299 |
-
mel ([`Mel`]): transform audio <-> spectrogram
|
300 |
-
scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
|
301 |
-
"""
|
302 |
-
|
303 |
-
_optional_components = ["vqvae"]
|
304 |
-
|
305 |
-
def __init__(
|
306 |
-
self,
|
307 |
-
vqvae: AutoencoderKL,
|
308 |
-
unet: UNet2DConditionModel,
|
309 |
-
mel: Mel,
|
310 |
-
scheduler: Union[DDIMScheduler, DDPMScheduler],
|
311 |
-
):
|
312 |
-
super().__init__()
|
313 |
-
self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
|
314 |
-
|
315 |
-
def get_input_dims(self) -> Tuple:
|
316 |
-
"""Returns dimension of input image
|
317 |
-
|
318 |
-
Returns:
|
319 |
-
`Tuple`: (height, width)
|
320 |
-
"""
|
321 |
-
input_module = self.vqvae if self.vqvae is not None else self.unet
|
322 |
-
# For backwards compatibility
|
323 |
-
sample_size = (
|
324 |
-
(input_module.sample_size, input_module.sample_size)
|
325 |
-
if type(input_module.sample_size) == int
|
326 |
-
else input_module.sample_size
|
327 |
-
)
|
328 |
-
return sample_size
|
329 |
-
|
330 |
-
def get_default_steps(self) -> int:
|
331 |
-
"""Returns default number of steps recommended for inference
|
332 |
-
|
333 |
-
Returns:
|
334 |
-
`int`: number of steps
|
335 |
-
"""
|
336 |
-
return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
|
337 |
-
|
338 |
-
@torch.no_grad()
|
339 |
-
def __call__(
|
340 |
-
self,
|
341 |
-
batch_size: int = 1,
|
342 |
-
audio_file: str = None,
|
343 |
-
raw_audio: np.ndarray = None,
|
344 |
-
slice: int = 0,
|
345 |
-
start_step: int = 0,
|
346 |
-
steps: int = None,
|
347 |
-
generator: torch.Generator = None,
|
348 |
-
mask_start_secs: float = 0,
|
349 |
-
mask_end_secs: float = 0,
|
350 |
-
step_generator: torch.Generator = None,
|
351 |
-
eta: float = 0,
|
352 |
-
noise: torch.Tensor = None,
|
353 |
-
return_dict=True,
|
354 |
-
) -> Union[
|
355 |
-
Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]
|
356 |
-
]:
|
357 |
-
"""Generate random mel spectrogram from audio input and convert to audio.
|
358 |
-
|
359 |
-
Args:
|
360 |
-
batch_size (`int`): number of samples to generate
|
361 |
-
audio_file (`str`): must be a file on disk due to Librosa limitation or
|
362 |
-
raw_audio (`np.ndarray`): audio as numpy array
|
363 |
-
slice (`int`): slice number of audio to convert
|
364 |
-
start_step (int): step to start from
|
365 |
-
steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
|
366 |
-
generator (`torch.Generator`): random number generator or None
|
367 |
-
mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
|
368 |
-
mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
|
369 |
-
step_generator (`torch.Generator`): random number generator used to de-noise or None
|
370 |
-
eta (`float`): parameter between 0 and 1 used with DDIM scheduler
|
371 |
-
noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
|
372 |
-
return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
|
373 |
-
|
374 |
-
Returns:
|
375 |
-
`List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios
|
376 |
-
"""
|
377 |
-
|
378 |
-
steps = steps or self.get_default_steps()
|
379 |
-
self.scheduler.set_timesteps(steps)
|
380 |
-
step_generator = step_generator or generator
|
381 |
-
# For backwards compatibility
|
382 |
-
if type(self.unet.sample_size) == int:
|
383 |
-
self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
|
384 |
-
input_dims = self.get_input_dims()
|
385 |
-
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
|
386 |
-
if noise is None:
|
387 |
-
noise = torch.randn(
|
388 |
-
(batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
|
389 |
-
generator=generator,
|
390 |
-
device=self.device,
|
391 |
-
)
|
392 |
-
images = noise
|
393 |
-
mask = None
|
394 |
-
|
395 |
-
if audio_file is not None or raw_audio is not None:
|
396 |
-
self.mel.load_audio(audio_file, raw_audio)
|
397 |
-
input_image = self.mel.audio_slice_to_image(slice)
|
398 |
-
input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
|
399 |
-
(input_image.height, input_image.width)
|
400 |
-
)
|
401 |
-
input_image = (input_image / 255) * 2 - 1
|
402 |
-
input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
|
403 |
-
|
404 |
-
if self.vqvae is not None:
|
405 |
-
input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
|
406 |
-
generator=generator
|
407 |
-
)[0]
|
408 |
-
input_images = 0.18215 * input_images
|
409 |
-
|
410 |
-
if start_step > 0:
|
411 |
-
images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
|
412 |
-
|
413 |
-
pixels_per_second = (
|
414 |
-
self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
|
415 |
-
)
|
416 |
-
mask_start = int(mask_start_secs * pixels_per_second)
|
417 |
-
mask_end = int(mask_end_secs * pixels_per_second)
|
418 |
-
mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
|
419 |
-
|
420 |
-
for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
|
421 |
-
model_output = self.unet(images, t)["sample"]
|
422 |
-
|
423 |
-
if isinstance(self.scheduler, DDIMScheduler):
|
424 |
-
images = self.scheduler.step(
|
425 |
-
model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator
|
426 |
-
)["prev_sample"]
|
427 |
-
else:
|
428 |
-
images = self.scheduler.step(
|
429 |
-
model_output=model_output, timestep=t, sample=images, generator=step_generator
|
430 |
-
)["prev_sample"]
|
431 |
-
|
432 |
-
if mask is not None:
|
433 |
-
if mask_start > 0:
|
434 |
-
images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
|
435 |
-
if mask_end > 0:
|
436 |
-
images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
|
437 |
-
|
438 |
-
if self.vqvae is not None:
|
439 |
-
# 0.18215 was scaling factor used in training to ensure unit variance
|
440 |
-
images = 1 / 0.18215 * images
|
441 |
-
images = self.vqvae.decode(images)["sample"]
|
442 |
-
|
443 |
-
images = (images / 2 + 0.5).clamp(0, 1)
|
444 |
-
images = images.cpu().permute(0, 2, 3, 1).numpy()
|
445 |
-
images = (images * 255).round().astype("uint8")
|
446 |
-
images = list(
|
447 |
-
map(lambda _: Image.fromarray(_[:, :, 0]), images)
|
448 |
-
if images.shape[3] == 1
|
449 |
-
else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
|
450 |
-
)
|
451 |
-
|
452 |
-
audios = list(map(lambda _: self.mel.image_to_audio(_), images))
|
453 |
-
if not return_dict:
|
454 |
-
return images, (self.mel.get_sample_rate(), audios)
|
455 |
-
|
456 |
-
return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
|
457 |
-
|
458 |
-
@torch.no_grad()
|
459 |
-
def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
|
460 |
-
"""Reverse step process: recover noisy image from generated image.
|
461 |
-
|
462 |
-
Args:
|
463 |
-
images (`List[PIL Image]`): list of images to encode
|
464 |
-
steps (`int`): number of encoding steps to perform (defaults to 50)
|
465 |
-
|
466 |
-
Returns:
|
467 |
-
`np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
|
468 |
-
"""
|
469 |
-
|
470 |
-
# Only works with DDIM as this method is deterministic
|
471 |
-
assert isinstance(self.scheduler, DDIMScheduler)
|
472 |
-
self.scheduler.set_timesteps(steps)
|
473 |
-
sample = np.array(
|
474 |
-
[np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
|
475 |
-
)
|
476 |
-
sample = (sample / 255) * 2 - 1
|
477 |
-
sample = torch.Tensor(sample).to(self.device)
|
478 |
-
|
479 |
-
for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
|
480 |
-
prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
|
481 |
-
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
482 |
-
alpha_prod_t_prev = (
|
483 |
-
self.scheduler.alphas_cumprod[prev_timestep]
|
484 |
-
if prev_timestep >= 0
|
485 |
-
else self.scheduler.final_alpha_cumprod
|
486 |
-
)
|
487 |
-
beta_prod_t = 1 - alpha_prod_t
|
488 |
-
model_output = self.unet(sample, t)["sample"]
|
489 |
-
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
|
490 |
-
sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
|
491 |
-
sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output
|
492 |
-
|
493 |
-
return sample
|
494 |
-
|
495 |
-
@staticmethod
|
496 |
-
def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
|
497 |
-
"""Spherical Linear intERPolation
|
498 |
-
|
499 |
-
Args:
|
500 |
-
x0 (`torch.Tensor`): first tensor to interpolate between
|
501 |
-
x1 (`torch.Tensor`): seconds tensor to interpolate between
|
502 |
-
alpha (`float`): interpolation between 0 and 1
|
503 |
-
|
504 |
-
Returns:
|
505 |
-
`torch.Tensor`: interpolated tensor
|
506 |
-
"""
|
507 |
-
|
508 |
-
theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
|
509 |
-
return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
|
510 |
-
|
511 |
-
|
512 |
-
import sys
|
513 |
-
import diffusers
|
514 |
-
|
515 |
-
class audio_diffusion():
|
516 |
-
__name__ = 'audio_diffusion'
|
517 |
-
pass
|
518 |
-
|
519 |
-
|
520 |
-
sys.modules['audio_diffusion'] = audio_diffusion
|
521 |
-
setattr(audio_diffusion, Mel.__name__, Mel)
|
522 |
-
diffusers.AudioDiffusionPipeline = AudioDiffusionPipeline
|
523 |
-
setattr(diffusers, AudioDiffusionPipeline.__name__, AudioDiffusionPipeline)
|
524 |
-
diffusers.pipeline_utils.LOADABLE_CLASSES['audio_diffusion'] = {}
|
525 |
-
diffusers.pipeline_utils.LOADABLE_CLASSES['audio_diffusion']['Mel'] = ["save_pretrained", "from_pretrained"]
|
526 |
-
'''
|
|
|
1 |
from typing import Iterable, Tuple
|
2 |
|
|
|
3 |
import numpy as np
|
4 |
+
import torch
|
5 |
+
from librosa.beat import beat_track
|
6 |
from PIL import Image
|
7 |
from tqdm.auto import tqdm
|
|
|
|
|
8 |
|
9 |
+
# from diffusers import AudioDiffusionPipeline
|
10 |
+
from .pipeline_audio_diffusion import AudioDiffusionPipeline
|
11 |
|
12 |
+
VERSION = "1.4.0"
|
13 |
|
|
|
14 |
|
15 |
+
class AudioDiffusion:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
model_id: str = "teticio/audio-diffusion-256",
|
19 |
+
cuda: bool = torch.cuda.is_available(),
|
20 |
+
progress_bar: Iterable = tqdm,
|
21 |
+
):
|
22 |
"""Class for generating audio using De-noising Diffusion Probabilistic Models.
|
23 |
|
24 |
Args:
|
|
|
38 |
generator: torch.Generator = None,
|
39 |
step_generator: torch.Generator = None,
|
40 |
eta: float = 0,
|
41 |
+
noise: torch.Tensor = None,
|
42 |
+
encoding: torch.Tensor = None,
|
43 |
) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
|
44 |
"""Generate random mel spectrogram and convert to audio.
|
45 |
|
|
|
49 |
step_generator (torch.Generator): random number generator used to de-noise or None
|
50 |
eta (float): parameter between 0 and 1 used with DDIM scheduler
|
51 |
noise (torch.Tensor): noisy image or None
|
52 |
+
encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
|
53 |
|
54 |
Returns:
|
55 |
PIL Image: mel spectrogram
|
56 |
(float, np.ndarray): sample rate and raw audio
|
57 |
"""
|
58 |
+
images, (sample_rate, audios) = self.pipe(
|
59 |
+
batch_size=1,
|
60 |
+
steps=steps,
|
61 |
+
generator=generator,
|
62 |
+
step_generator=step_generator,
|
63 |
+
eta=eta,
|
64 |
+
noise=noise,
|
65 |
+
encoding=encoding,
|
66 |
+
return_dict=False,
|
67 |
+
)
|
68 |
return images[0], (sample_rate, audios[0])
|
69 |
|
70 |
def generate_spectrogram_and_audio_from_audio(
|
|
|
79 |
mask_end_secs: float = 0,
|
80 |
step_generator: torch.Generator = None,
|
81 |
eta: float = 0,
|
82 |
+
encoding: torch.Tensor = None,
|
83 |
+
noise: torch.Tensor = None,
|
84 |
) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
|
85 |
"""Generate random mel spectrogram from audio input and convert to audio.
|
86 |
|
|
|
95 |
mask_end_secs (float): number of seconds of audio to mask (not generate) at end
|
96 |
step_generator (torch.Generator): random number generator used to de-noise or None
|
97 |
eta (float): parameter between 0 and 1 used with DDIM scheduler
|
98 |
+
encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
|
99 |
noise (torch.Tensor): noisy image or None
|
100 |
|
101 |
Returns:
|
|
|
103 |
(float, np.ndarray): sample rate and raw audio
|
104 |
"""
|
105 |
|
106 |
+
images, (sample_rate, audios) = self.pipe(
|
107 |
+
batch_size=1,
|
108 |
+
audio_file=audio_file,
|
109 |
+
raw_audio=raw_audio,
|
110 |
+
slice=slice,
|
111 |
+
start_step=start_step,
|
112 |
+
steps=steps,
|
113 |
+
generator=generator,
|
114 |
+
mask_start_secs=mask_start_secs,
|
115 |
+
mask_end_secs=mask_end_secs,
|
116 |
+
step_generator=step_generator,
|
117 |
+
eta=eta,
|
118 |
+
noise=noise,
|
119 |
+
encoding=encoding,
|
120 |
+
return_dict=False,
|
121 |
+
)
|
122 |
return images[0], (sample_rate, audios[0])
|
123 |
|
124 |
@staticmethod
|
125 |
+
def loop_it(audio: np.ndarray, sample_rate: int, loops: int = 12) -> np.ndarray:
|
|
|
|
|
126 |
"""Loop audio
|
127 |
|
128 |
Args:
|
|
|
133 |
Returns:
|
134 |
(float, np.ndarray): sample rate and raw audio or None
|
135 |
"""
|
136 |
+
_, beats = beat_track(y=audio, sr=sample_rate, units="samples")
|
137 |
+
beats_in_bar = (len(beats) - 1) // 4 * 4
|
138 |
+
if beats_in_bar > 0:
|
139 |
+
return np.tile(audio[beats[0] : beats[beats_in_bar]], loops)
|
140 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
audiodiffusion/audio_encoder.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from diffusers import ConfigMixin, Mel, ModelMixin
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
class SeparableConv2d(nn.Module):
|
8 |
+
def __init__(self, in_channels, out_channels, kernel_size):
|
9 |
+
super(SeparableConv2d, self).__init__()
|
10 |
+
self.depthwise = nn.Conv2d(
|
11 |
+
in_channels,
|
12 |
+
in_channels,
|
13 |
+
kernel_size=kernel_size,
|
14 |
+
groups=in_channels,
|
15 |
+
bias=False,
|
16 |
+
padding=1,
|
17 |
+
)
|
18 |
+
self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
out = self.depthwise(x)
|
22 |
+
out = self.pointwise(out)
|
23 |
+
return out
|
24 |
+
|
25 |
+
|
26 |
+
class ConvBlock(nn.Module):
|
27 |
+
def __init__(self, in_channels, out_channels, dropout_rate):
|
28 |
+
super(ConvBlock, self).__init__()
|
29 |
+
self.sep_conv = SeparableConv2d(in_channels, out_channels, (3, 3))
|
30 |
+
self.leaky_relu = nn.LeakyReLU(0.2)
|
31 |
+
self.batch_norm = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01)
|
32 |
+
self.max_pool = nn.MaxPool2d((2, 2))
|
33 |
+
self.dropout = nn.Dropout(dropout_rate)
|
34 |
+
|
35 |
+
def forward(self, x):
|
36 |
+
x = self.sep_conv(x)
|
37 |
+
x = self.leaky_relu(x)
|
38 |
+
x = self.batch_norm(x)
|
39 |
+
x = self.max_pool(x)
|
40 |
+
x = self.dropout(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DenseBlock(nn.Module):
|
45 |
+
def __init__(self, in_features, out_features, dropout_rate):
|
46 |
+
super(DenseBlock, self).__init__()
|
47 |
+
self.flatten = nn.Flatten()
|
48 |
+
self.dense = nn.Linear(in_features, out_features)
|
49 |
+
self.leaky_relu = nn.LeakyReLU(0.2)
|
50 |
+
self.batch_norm = nn.BatchNorm1d(out_features, eps=0.001, momentum=0.01)
|
51 |
+
self.dropout = nn.Dropout(dropout_rate)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
x = self.flatten(x.permute(0, 2, 3, 1))
|
55 |
+
x = self.dense(x)
|
56 |
+
x = self.leaky_relu(x)
|
57 |
+
x = self.batch_norm(x)
|
58 |
+
x = self.dropout(x)
|
59 |
+
return x
|
60 |
+
|
61 |
+
|
62 |
+
class AudioEncoder(ModelMixin, ConfigMixin):
|
63 |
+
def __init__(self):
|
64 |
+
super().__init__()
|
65 |
+
self.mel = Mel(
|
66 |
+
x_res=216,
|
67 |
+
y_res=96,
|
68 |
+
sample_rate=22050,
|
69 |
+
n_fft=2048,
|
70 |
+
hop_length=512,
|
71 |
+
top_db=80,
|
72 |
+
)
|
73 |
+
self.conv_blocks = nn.ModuleList([ConvBlock(1, 32, 0.2), ConvBlock(32, 64, 0.3), ConvBlock(64, 128, 0.4)])
|
74 |
+
self.dense_block = DenseBlock(41472, 1024, 0.5)
|
75 |
+
self.embedding = nn.Linear(1024, 100)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
for conv_block in self.conv_blocks:
|
79 |
+
x = conv_block(x)
|
80 |
+
x = self.dense_block(x)
|
81 |
+
x = self.embedding(x)
|
82 |
+
return x
|
83 |
+
|
84 |
+
@torch.no_grad()
|
85 |
+
def encode(self, audio_files):
|
86 |
+
self.eval()
|
87 |
+
y = []
|
88 |
+
for audio_file in audio_files:
|
89 |
+
self.mel.load_audio(audio_file)
|
90 |
+
x = [
|
91 |
+
np.expand_dims(
|
92 |
+
np.frombuffer(self.mel.audio_slice_to_image(slice).tobytes(), dtype="uint8").reshape(
|
93 |
+
(self.mel.y_res, self.mel.x_res)
|
94 |
+
)
|
95 |
+
/ 255,
|
96 |
+
axis=0,
|
97 |
+
)
|
98 |
+
for slice in range(self.mel.get_number_of_slices())
|
99 |
+
]
|
100 |
+
y += [torch.mean(self(torch.Tensor(x)), dim=0)]
|
101 |
+
return torch.stack(y)
|
102 |
+
|
103 |
+
|
104 |
+
# from diffusers import Mel
|
105 |
+
# from audiodiffusion.audio_encoder import AudioEncoder
|
106 |
+
# audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
|
107 |
+
# audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])
|
audiodiffusion/mel.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code has been migrated to diffusers but can be run locally with
|
2 |
+
# pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256", custom_pipeline="audio-diffusion/audiodiffusion/pipeline_audio_diffusion.py")
|
3 |
+
|
4 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
import warnings
|
20 |
+
|
21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
22 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
23 |
+
|
24 |
+
warnings.filterwarnings("ignore")
|
25 |
+
|
26 |
+
import librosa # noqa: E402
|
27 |
+
import numpy as np # noqa: E402
|
28 |
+
from PIL import Image # noqa: E402
|
29 |
+
|
30 |
+
|
31 |
+
class Mel(ConfigMixin, SchedulerMixin):
|
32 |
+
"""
|
33 |
+
Parameters:
|
34 |
+
x_res (`int`): x resolution of spectrogram (time)
|
35 |
+
y_res (`int`): y resolution of spectrogram (frequency bins)
|
36 |
+
sample_rate (`int`): sample rate of audio
|
37 |
+
n_fft (`int`): number of Fast Fourier Transforms
|
38 |
+
hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
|
39 |
+
top_db (`int`): loudest in decibels
|
40 |
+
n_iter (`int`): number of iterations for Griffin Linn mel inversion
|
41 |
+
"""
|
42 |
+
|
43 |
+
config_name = "mel_config.json"
|
44 |
+
|
45 |
+
@register_to_config
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
x_res: int = 256,
|
49 |
+
y_res: int = 256,
|
50 |
+
sample_rate: int = 22050,
|
51 |
+
n_fft: int = 2048,
|
52 |
+
hop_length: int = 512,
|
53 |
+
top_db: int = 80,
|
54 |
+
n_iter: int = 32,
|
55 |
+
):
|
56 |
+
self.hop_length = hop_length
|
57 |
+
self.sr = sample_rate
|
58 |
+
self.n_fft = n_fft
|
59 |
+
self.top_db = top_db
|
60 |
+
self.n_iter = n_iter
|
61 |
+
self.set_resolution(x_res, y_res)
|
62 |
+
self.audio = None
|
63 |
+
|
64 |
+
def set_resolution(self, x_res: int, y_res: int):
|
65 |
+
"""Set resolution.
|
66 |
+
|
67 |
+
Args:
|
68 |
+
x_res (`int`): x resolution of spectrogram (time)
|
69 |
+
y_res (`int`): y resolution of spectrogram (frequency bins)
|
70 |
+
"""
|
71 |
+
self.x_res = x_res
|
72 |
+
self.y_res = y_res
|
73 |
+
self.n_mels = self.y_res
|
74 |
+
self.slice_size = self.x_res * self.hop_length - 1
|
75 |
+
|
76 |
+
def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
|
77 |
+
"""Load audio.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
audio_file (`str`): must be a file on disk due to Librosa limitation or
|
81 |
+
raw_audio (`np.ndarray`): audio as numpy array
|
82 |
+
"""
|
83 |
+
if audio_file is not None:
|
84 |
+
self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
|
85 |
+
else:
|
86 |
+
self.audio = raw_audio
|
87 |
+
|
88 |
+
# Pad with silence if necessary.
|
89 |
+
if len(self.audio) < self.x_res * self.hop_length:
|
90 |
+
self.audio = np.concatenate(
|
91 |
+
[
|
92 |
+
self.audio,
|
93 |
+
np.zeros((self.x_res * self.hop_length - len(self.audio),)),
|
94 |
+
]
|
95 |
+
)
|
96 |
+
|
97 |
+
def get_number_of_slices(self) -> int:
|
98 |
+
"""Get number of slices in audio.
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
`int`: number of spectograms audio can be sliced into
|
102 |
+
"""
|
103 |
+
return len(self.audio) // self.slice_size
|
104 |
+
|
105 |
+
def get_audio_slice(self, slice: int = 0) -> np.ndarray:
|
106 |
+
"""Get slice of audio.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
slice (`int`): slice number of audio (out of get_number_of_slices())
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
`np.ndarray`: audio as numpy array
|
113 |
+
"""
|
114 |
+
return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
|
115 |
+
|
116 |
+
def get_sample_rate(self) -> int:
|
117 |
+
"""Get sample rate:
|
118 |
+
|
119 |
+
Returns:
|
120 |
+
`int`: sample rate of audio
|
121 |
+
"""
|
122 |
+
return self.sr
|
123 |
+
|
124 |
+
def audio_slice_to_image(self, slice: int) -> Image.Image:
|
125 |
+
"""Convert slice of audio to spectrogram.
|
126 |
+
|
127 |
+
Args:
|
128 |
+
slice (`int`): slice number of audio to convert (out of get_number_of_slices())
|
129 |
+
|
130 |
+
Returns:
|
131 |
+
`PIL Image`: grayscale image of x_res x y_res
|
132 |
+
"""
|
133 |
+
S = librosa.feature.melspectrogram(
|
134 |
+
y=self.get_audio_slice(slice),
|
135 |
+
sr=self.sr,
|
136 |
+
n_fft=self.n_fft,
|
137 |
+
hop_length=self.hop_length,
|
138 |
+
n_mels=self.n_mels,
|
139 |
+
)
|
140 |
+
log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
|
141 |
+
bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
|
142 |
+
image = Image.fromarray(bytedata)
|
143 |
+
return image
|
144 |
+
|
145 |
+
def image_to_audio(self, image: Image.Image) -> np.ndarray:
|
146 |
+
"""Converts spectrogram to audio.
|
147 |
+
|
148 |
+
Args:
|
149 |
+
image (`PIL Image`): x_res x y_res grayscale image
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
audio (`np.ndarray`): raw audio
|
153 |
+
"""
|
154 |
+
bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
|
155 |
+
log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
|
156 |
+
S = librosa.db_to_power(log_S)
|
157 |
+
audio = librosa.feature.inverse.mel_to_audio(
|
158 |
+
S,
|
159 |
+
sr=self.sr,
|
160 |
+
n_fft=self.n_fft,
|
161 |
+
hop_length=self.hop_length,
|
162 |
+
n_iter=self.n_iter,
|
163 |
+
)
|
164 |
+
return audio
|
audiodiffusion/pipeline_audio_diffusion.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code has been migrated to diffusers but can be run locally with
|
2 |
+
# pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256", custom_pipeline="audio-diffusion/audiodiffusion/pipeline_audio_diffusion.py")
|
3 |
+
|
4 |
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
|
19 |
+
from math import acos, sin
|
20 |
+
from typing import List, Tuple, Union
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, Mel, UNet2DConditionModel
|
25 |
+
from diffusers.pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
|
26 |
+
from PIL import Image
|
27 |
+
|
28 |
+
from .mel import Mel
|
29 |
+
|
30 |
+
|
31 |
+
class AudioDiffusionPipeline(DiffusionPipeline):
|
32 |
+
"""
|
33 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
34 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
35 |
+
|
36 |
+
Parameters:
|
37 |
+
vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
|
38 |
+
unet ([`UNet2DConditionModel`]): UNET model
|
39 |
+
mel ([`Mel`]): transform audio <-> spectrogram
|
40 |
+
scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
|
41 |
+
"""
|
42 |
+
|
43 |
+
_optional_components = ["vqvae"]
|
44 |
+
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
vqvae: AutoencoderKL,
|
48 |
+
unet: UNet2DConditionModel,
|
49 |
+
mel: Mel,
|
50 |
+
scheduler: Union[DDIMScheduler, DDPMScheduler],
|
51 |
+
):
|
52 |
+
super().__init__()
|
53 |
+
self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
|
54 |
+
|
55 |
+
def get_input_dims(self) -> Tuple:
|
56 |
+
"""Returns dimension of input image
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
`Tuple`: (height, width)
|
60 |
+
"""
|
61 |
+
input_module = self.vqvae if self.vqvae is not None else self.unet
|
62 |
+
# For backwards compatibility
|
63 |
+
sample_size = (
|
64 |
+
(input_module.sample_size, input_module.sample_size)
|
65 |
+
if type(input_module.sample_size) == int
|
66 |
+
else input_module.sample_size
|
67 |
+
)
|
68 |
+
return sample_size
|
69 |
+
|
70 |
+
def get_default_steps(self) -> int:
|
71 |
+
"""Returns default number of steps recommended for inference
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
`int`: number of steps
|
75 |
+
"""
|
76 |
+
return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
|
77 |
+
|
78 |
+
@torch.no_grad()
|
79 |
+
def __call__(
|
80 |
+
self,
|
81 |
+
batch_size: int = 1,
|
82 |
+
audio_file: str = None,
|
83 |
+
raw_audio: np.ndarray = None,
|
84 |
+
slice: int = 0,
|
85 |
+
start_step: int = 0,
|
86 |
+
steps: int = None,
|
87 |
+
generator: torch.Generator = None,
|
88 |
+
mask_start_secs: float = 0,
|
89 |
+
mask_end_secs: float = 0,
|
90 |
+
step_generator: torch.Generator = None,
|
91 |
+
eta: float = 0,
|
92 |
+
noise: torch.Tensor = None,
|
93 |
+
encoding: torch.Tensor = None,
|
94 |
+
return_dict=True,
|
95 |
+
) -> Union[
|
96 |
+
Union[AudioPipelineOutput, ImagePipelineOutput],
|
97 |
+
Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
|
98 |
+
]:
|
99 |
+
"""Generate random mel spectrogram from audio input and convert to audio.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
batch_size (`int`): number of samples to generate
|
103 |
+
audio_file (`str`): must be a file on disk due to Librosa limitation or
|
104 |
+
raw_audio (`np.ndarray`): audio as numpy array
|
105 |
+
slice (`int`): slice number of audio to convert
|
106 |
+
start_step (int): step to start from
|
107 |
+
steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
|
108 |
+
generator (`torch.Generator`): random number generator or None
|
109 |
+
mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
|
110 |
+
mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
|
111 |
+
step_generator (`torch.Generator`): random number generator used to de-noise or None
|
112 |
+
eta (`float`): parameter between 0 and 1 used with DDIM scheduler
|
113 |
+
noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
|
114 |
+
encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
|
115 |
+
return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
`List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios
|
119 |
+
"""
|
120 |
+
|
121 |
+
steps = steps or self.get_default_steps()
|
122 |
+
self.scheduler.set_timesteps(steps)
|
123 |
+
step_generator = step_generator or generator
|
124 |
+
# For backwards compatibility
|
125 |
+
if type(self.unet.sample_size) == int:
|
126 |
+
self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
|
127 |
+
input_dims = self.get_input_dims()
|
128 |
+
self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
|
129 |
+
if noise is None:
|
130 |
+
noise = torch.randn(
|
131 |
+
(
|
132 |
+
batch_size,
|
133 |
+
self.unet.in_channels,
|
134 |
+
self.unet.sample_size[0],
|
135 |
+
self.unet.sample_size[1],
|
136 |
+
),
|
137 |
+
generator=generator,
|
138 |
+
device=self.device,
|
139 |
+
)
|
140 |
+
images = noise
|
141 |
+
mask = None
|
142 |
+
|
143 |
+
if audio_file is not None or raw_audio is not None:
|
144 |
+
self.mel.load_audio(audio_file, raw_audio)
|
145 |
+
input_image = self.mel.audio_slice_to_image(slice)
|
146 |
+
input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
|
147 |
+
(input_image.height, input_image.width)
|
148 |
+
)
|
149 |
+
input_image = (input_image / 255) * 2 - 1
|
150 |
+
input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
|
151 |
+
|
152 |
+
if self.vqvae is not None:
|
153 |
+
input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
|
154 |
+
generator=generator
|
155 |
+
)[0]
|
156 |
+
input_images = 0.18215 * input_images
|
157 |
+
|
158 |
+
if start_step > 0:
|
159 |
+
images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
|
160 |
+
|
161 |
+
pixels_per_second = (
|
162 |
+
self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
|
163 |
+
)
|
164 |
+
mask_start = int(mask_start_secs * pixels_per_second)
|
165 |
+
mask_end = int(mask_end_secs * pixels_per_second)
|
166 |
+
mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
|
167 |
+
|
168 |
+
for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
|
169 |
+
if isinstance(self.unet, UNet2DConditionModel):
|
170 |
+
model_output = self.unet(images, t, encoding)["sample"]
|
171 |
+
else:
|
172 |
+
model_output = self.unet(images, t)["sample"]
|
173 |
+
|
174 |
+
if isinstance(self.scheduler, DDIMScheduler):
|
175 |
+
images = self.scheduler.step(
|
176 |
+
model_output=model_output,
|
177 |
+
timestep=t,
|
178 |
+
sample=images,
|
179 |
+
eta=eta,
|
180 |
+
generator=step_generator,
|
181 |
+
)["prev_sample"]
|
182 |
+
else:
|
183 |
+
images = self.scheduler.step(
|
184 |
+
model_output=model_output,
|
185 |
+
timestep=t,
|
186 |
+
sample=images,
|
187 |
+
generator=step_generator,
|
188 |
+
)["prev_sample"]
|
189 |
+
|
190 |
+
if mask is not None:
|
191 |
+
if mask_start > 0:
|
192 |
+
images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
|
193 |
+
if mask_end > 0:
|
194 |
+
images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
|
195 |
+
|
196 |
+
if self.vqvae is not None:
|
197 |
+
# 0.18215 was scaling factor used in training to ensure unit variance
|
198 |
+
images = 1 / 0.18215 * images
|
199 |
+
images = self.vqvae.decode(images)["sample"]
|
200 |
+
|
201 |
+
images = (images / 2 + 0.5).clamp(0, 1)
|
202 |
+
images = images.cpu().permute(0, 2, 3, 1).numpy()
|
203 |
+
images = (images * 255).round().astype("uint8")
|
204 |
+
images = list(
|
205 |
+
map(lambda _: Image.fromarray(_[:, :, 0]), images)
|
206 |
+
if images.shape[3] == 1
|
207 |
+
else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
|
208 |
+
)
|
209 |
+
|
210 |
+
audios = list(map(lambda _: self.mel.image_to_audio(_), images))
|
211 |
+
if not return_dict:
|
212 |
+
return images, (self.mel.get_sample_rate(), audios)
|
213 |
+
|
214 |
+
return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
|
215 |
+
|
216 |
+
@torch.no_grad()
|
217 |
+
def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
|
218 |
+
"""Reverse step process: recover noisy image from generated image.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
images (`List[PIL Image]`): list of images to encode
|
222 |
+
steps (`int`): number of encoding steps to perform (defaults to 50)
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
`np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
|
226 |
+
"""
|
227 |
+
|
228 |
+
# Only works with DDIM as this method is deterministic
|
229 |
+
assert isinstance(self.scheduler, DDIMScheduler)
|
230 |
+
self.scheduler.set_timesteps(steps)
|
231 |
+
sample = np.array(
|
232 |
+
[np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
|
233 |
+
)
|
234 |
+
sample = (sample / 255) * 2 - 1
|
235 |
+
sample = torch.Tensor(sample).to(self.device)
|
236 |
+
|
237 |
+
for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
|
238 |
+
prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
|
239 |
+
alpha_prod_t = self.scheduler.alphas_cumprod[t]
|
240 |
+
alpha_prod_t_prev = (
|
241 |
+
self.scheduler.alphas_cumprod[prev_timestep]
|
242 |
+
if prev_timestep >= 0
|
243 |
+
else self.scheduler.final_alpha_cumprod
|
244 |
+
)
|
245 |
+
beta_prod_t = 1 - alpha_prod_t
|
246 |
+
model_output = self.unet(sample, t)["sample"]
|
247 |
+
pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
|
248 |
+
sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
|
249 |
+
sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output
|
250 |
+
|
251 |
+
return sample
|
252 |
+
|
253 |
+
@staticmethod
|
254 |
+
def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
|
255 |
+
"""Spherical Linear intERPolation
|
256 |
+
|
257 |
+
Args:
|
258 |
+
x0 (`torch.Tensor`): first tensor to interpolate between
|
259 |
+
x1 (`torch.Tensor`): seconds tensor to interpolate between
|
260 |
+
alpha (`float`): interpolation between 0 and 1
|
261 |
+
|
262 |
+
Returns:
|
263 |
+
`torch.Tensor`: interpolated tensor
|
264 |
+
"""
|
265 |
+
|
266 |
+
theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
|
267 |
+
return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
|
audiodiffusion/utils.py
CHANGED
@@ -23,8 +23,7 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
|
23 |
new_item = old_item
|
24 |
|
25 |
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
26 |
-
new_item = shave_segments(
|
27 |
-
new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
28 |
|
29 |
mapping.append({"old": old_item, "new": new_item})
|
30 |
|
@@ -54,20 +53,21 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
|
54 |
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
55 |
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
56 |
|
57 |
-
new_item = shave_segments(
|
58 |
-
new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
59 |
|
60 |
mapping.append({"old": old_item, "new": new_item})
|
61 |
|
62 |
return mapping
|
63 |
|
64 |
|
65 |
-
def assign_to_checkpoint(
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
|
|
|
|
71 |
"""
|
72 |
This does the final conversion step: take locally converted weights and apply a global renaming
|
73 |
to them. It splits attention layers, and takes into account additional replacements
|
@@ -75,9 +75,7 @@ def assign_to_checkpoint(paths,
|
|
75 |
|
76 |
Assigns the weights to the new checkpoint.
|
77 |
"""
|
78 |
-
assert isinstance(
|
79 |
-
paths, list
|
80 |
-
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
81 |
|
82 |
# Splits the attention layers into three variables.
|
83 |
if attention_paths_to_split is not None:
|
@@ -85,13 +83,11 @@ def assign_to_checkpoint(paths,
|
|
85 |
old_tensor = old_checkpoint[path]
|
86 |
channels = old_tensor.shape[0] // 3
|
87 |
|
88 |
-
target_shape = (-1,
|
89 |
-
channels) if len(old_tensor.shape) == 3 else (-1)
|
90 |
|
91 |
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
92 |
|
93 |
-
old_tensor = old_tensor.reshape((num_heads, 3 * channels //
|
94 |
-
num_heads) + old_tensor.shape[1:])
|
95 |
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
96 |
|
97 |
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
@@ -112,8 +108,7 @@ def assign_to_checkpoint(paths,
|
|
112 |
|
113 |
if additional_replacements is not None:
|
114 |
for replacement in additional_replacements:
|
115 |
-
new_path = new_path.replace(replacement["old"],
|
116 |
-
replacement["new"])
|
117 |
|
118 |
# proj_attn.weight has to be converted from conv 1D to linear
|
119 |
if "proj_attn.weight" in new_path:
|
@@ -146,7 +141,7 @@ def create_vae_diffusers_config(original_config):
|
|
146 |
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
147 |
|
148 |
config = dict(
|
149 |
-
sample_size=vae_params.resolution,
|
150 |
in_channels=vae_params.in_channels,
|
151 |
out_channels=vae_params.out_ch,
|
152 |
down_block_types=tuple(down_block_types),
|
@@ -164,178 +159,144 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
|
164 |
|
165 |
new_checkpoint = {}
|
166 |
|
167 |
-
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
|
168 |
-
|
169 |
-
new_checkpoint["encoder.
|
170 |
-
|
171 |
-
new_checkpoint["encoder.
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
new_checkpoint["
|
176 |
-
|
177 |
-
new_checkpoint["
|
178 |
-
|
179 |
-
|
180 |
-
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
|
181 |
-
"decoder.conv_in.weight"]
|
182 |
-
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
|
183 |
-
"decoder.conv_in.bias"]
|
184 |
-
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
|
185 |
-
"decoder.conv_out.weight"]
|
186 |
-
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
|
187 |
-
"decoder.conv_out.bias"]
|
188 |
-
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
|
189 |
-
"decoder.norm_out.weight"]
|
190 |
-
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
|
191 |
-
"decoder.norm_out.bias"]
|
192 |
|
193 |
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
194 |
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
195 |
-
new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
|
196 |
-
|
197 |
-
new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
|
198 |
-
"post_quant_conv.bias"]
|
199 |
|
200 |
# Retrieves the keys for the encoder down blocks only
|
201 |
-
num_down_blocks = len({
|
202 |
-
".".join(layer.split(".")[:3])
|
203 |
-
for layer in vae_state_dict if "encoder.down" in layer
|
204 |
-
})
|
205 |
down_blocks = {
|
206 |
-
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
|
207 |
-
for layer_id in range(num_down_blocks)
|
208 |
}
|
209 |
|
210 |
# Retrieves the keys for the decoder up blocks only
|
211 |
-
num_up_blocks = len({
|
212 |
-
".".join(layer.split(".")[:3])
|
213 |
-
for layer in vae_state_dict if "decoder.up" in layer
|
214 |
-
})
|
215 |
up_blocks = {
|
216 |
-
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
|
217 |
-
for layer_id in range(num_up_blocks)
|
218 |
}
|
219 |
|
220 |
for i in range(num_down_blocks):
|
221 |
-
resnets = [
|
222 |
-
key for key in down_blocks[i]
|
223 |
-
if f"down.{i}" in key and f"down.{i}.downsample" not in key
|
224 |
-
]
|
225 |
|
226 |
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
227 |
-
new_checkpoint[
|
228 |
-
f"encoder.
|
229 |
-
|
230 |
-
new_checkpoint[
|
231 |
-
f"encoder.
|
232 |
-
|
233 |
|
234 |
paths = renew_vae_resnet_paths(resnets)
|
235 |
-
meta_path = {
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
config=config)
|
244 |
|
245 |
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
246 |
num_mid_res_blocks = 2
|
247 |
for i in range(1, num_mid_res_blocks + 1):
|
248 |
-
resnets = [
|
249 |
-
key for key in mid_resnets if f"encoder.mid.block_{i}" in key
|
250 |
-
]
|
251 |
|
252 |
paths = renew_vae_resnet_paths(resnets)
|
253 |
-
meta_path = {
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
mid_attentions = [
|
264 |
-
key for key in vae_state_dict if "encoder.mid.attn" in key
|
265 |
-
]
|
266 |
paths = renew_vae_attention_paths(mid_attentions)
|
267 |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
268 |
-
assign_to_checkpoint(
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
|
|
|
|
273 |
conv_attn_to_linear(new_checkpoint)
|
274 |
|
275 |
for i in range(num_up_blocks):
|
276 |
block_id = num_up_blocks - 1 - i
|
277 |
resnets = [
|
278 |
-
key for key in up_blocks[block_id]
|
279 |
-
if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
280 |
]
|
281 |
|
282 |
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
283 |
-
new_checkpoint[
|
284 |
-
f"decoder.
|
285 |
-
|
286 |
-
new_checkpoint[
|
287 |
-
f"decoder.
|
288 |
-
|
289 |
|
290 |
paths = renew_vae_resnet_paths(resnets)
|
291 |
-
meta_path = {
|
292 |
-
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
config=config)
|
300 |
|
301 |
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
302 |
num_mid_res_blocks = 2
|
303 |
for i in range(1, num_mid_res_blocks + 1):
|
304 |
-
resnets = [
|
305 |
-
key for key in mid_resnets if f"decoder.mid.block_{i}" in key
|
306 |
-
]
|
307 |
|
308 |
paths = renew_vae_resnet_paths(resnets)
|
309 |
-
meta_path = {
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
mid_attentions = [
|
320 |
-
key for key in vae_state_dict if "decoder.mid.attn" in key
|
321 |
-
]
|
322 |
paths = renew_vae_attention_paths(mid_attentions)
|
323 |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
324 |
-
assign_to_checkpoint(
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
|
|
|
|
329 |
conv_attn_to_linear(new_checkpoint)
|
330 |
return new_checkpoint
|
331 |
|
332 |
-
|
|
|
333 |
checkpoint = torch.load(ldm_checkpoint)["state_dict"]
|
334 |
|
335 |
# Convert the VAE model.
|
336 |
vae_config = create_vae_diffusers_config(ldm_config)
|
337 |
-
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
|
338 |
-
checkpoint, vae_config)
|
339 |
|
340 |
vae = AutoencoderKL(**vae_config)
|
341 |
vae.load_state_dict(converted_vae_checkpoint)
|
|
|
23 |
new_item = old_item
|
24 |
|
25 |
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
26 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
27 |
|
28 |
mapping.append({"old": old_item, "new": new_item})
|
29 |
|
|
|
53 |
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
54 |
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
55 |
|
56 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
|
|
57 |
|
58 |
mapping.append({"old": old_item, "new": new_item})
|
59 |
|
60 |
return mapping
|
61 |
|
62 |
|
63 |
+
def assign_to_checkpoint(
|
64 |
+
paths,
|
65 |
+
checkpoint,
|
66 |
+
old_checkpoint,
|
67 |
+
attention_paths_to_split=None,
|
68 |
+
additional_replacements=None,
|
69 |
+
config=None,
|
70 |
+
):
|
71 |
"""
|
72 |
This does the final conversion step: take locally converted weights and apply a global renaming
|
73 |
to them. It splits attention layers, and takes into account additional replacements
|
|
|
75 |
|
76 |
Assigns the weights to the new checkpoint.
|
77 |
"""
|
78 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
|
|
|
|
79 |
|
80 |
# Splits the attention layers into three variables.
|
81 |
if attention_paths_to_split is not None:
|
|
|
83 |
old_tensor = old_checkpoint[path]
|
84 |
channels = old_tensor.shape[0] // 3
|
85 |
|
86 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
|
|
87 |
|
88 |
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
89 |
|
90 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
|
|
91 |
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
92 |
|
93 |
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
|
|
108 |
|
109 |
if additional_replacements is not None:
|
110 |
for replacement in additional_replacements:
|
111 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
|
|
112 |
|
113 |
# proj_attn.weight has to be converted from conv 1D to linear
|
114 |
if "proj_attn.weight" in new_path:
|
|
|
141 |
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
|
142 |
|
143 |
config = dict(
|
144 |
+
sample_size=tuple(vae_params.resolution),
|
145 |
in_channels=vae_params.in_channels,
|
146 |
out_channels=vae_params.out_ch,
|
147 |
down_block_types=tuple(down_block_types),
|
|
|
159 |
|
160 |
new_checkpoint = {}
|
161 |
|
162 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
163 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
164 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
165 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
166 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
167 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
168 |
+
|
169 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
170 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
171 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
172 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
173 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
174 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
177 |
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
178 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
179 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
|
|
|
|
180 |
|
181 |
# Retrieves the keys for the encoder down blocks only
|
182 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
|
|
|
|
|
|
183 |
down_blocks = {
|
184 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
|
|
185 |
}
|
186 |
|
187 |
# Retrieves the keys for the decoder up blocks only
|
188 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
|
|
|
|
|
|
189 |
up_blocks = {
|
190 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
|
|
191 |
}
|
192 |
|
193 |
for i in range(num_down_blocks):
|
194 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
|
|
|
|
|
|
195 |
|
196 |
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
197 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
198 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
199 |
+
)
|
200 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
201 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
202 |
+
)
|
203 |
|
204 |
paths = renew_vae_resnet_paths(resnets)
|
205 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
206 |
+
assign_to_checkpoint(
|
207 |
+
paths,
|
208 |
+
new_checkpoint,
|
209 |
+
vae_state_dict,
|
210 |
+
additional_replacements=[meta_path],
|
211 |
+
config=config,
|
212 |
+
)
|
|
|
213 |
|
214 |
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
215 |
num_mid_res_blocks = 2
|
216 |
for i in range(1, num_mid_res_blocks + 1):
|
217 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
|
|
|
|
218 |
|
219 |
paths = renew_vae_resnet_paths(resnets)
|
220 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
221 |
+
assign_to_checkpoint(
|
222 |
+
paths,
|
223 |
+
new_checkpoint,
|
224 |
+
vae_state_dict,
|
225 |
+
additional_replacements=[meta_path],
|
226 |
+
config=config,
|
227 |
+
)
|
228 |
+
|
229 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
|
|
|
|
|
|
230 |
paths = renew_vae_attention_paths(mid_attentions)
|
231 |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
232 |
+
assign_to_checkpoint(
|
233 |
+
paths,
|
234 |
+
new_checkpoint,
|
235 |
+
vae_state_dict,
|
236 |
+
additional_replacements=[meta_path],
|
237 |
+
config=config,
|
238 |
+
)
|
239 |
conv_attn_to_linear(new_checkpoint)
|
240 |
|
241 |
for i in range(num_up_blocks):
|
242 |
block_id = num_up_blocks - 1 - i
|
243 |
resnets = [
|
244 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
|
|
245 |
]
|
246 |
|
247 |
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
248 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
249 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
250 |
+
]
|
251 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
252 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
253 |
+
]
|
254 |
|
255 |
paths = renew_vae_resnet_paths(resnets)
|
256 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
257 |
+
assign_to_checkpoint(
|
258 |
+
paths,
|
259 |
+
new_checkpoint,
|
260 |
+
vae_state_dict,
|
261 |
+
additional_replacements=[meta_path],
|
262 |
+
config=config,
|
263 |
+
)
|
|
|
264 |
|
265 |
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
266 |
num_mid_res_blocks = 2
|
267 |
for i in range(1, num_mid_res_blocks + 1):
|
268 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
|
|
|
|
269 |
|
270 |
paths = renew_vae_resnet_paths(resnets)
|
271 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
272 |
+
assign_to_checkpoint(
|
273 |
+
paths,
|
274 |
+
new_checkpoint,
|
275 |
+
vae_state_dict,
|
276 |
+
additional_replacements=[meta_path],
|
277 |
+
config=config,
|
278 |
+
)
|
279 |
+
|
280 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
|
|
|
|
|
|
281 |
paths = renew_vae_attention_paths(mid_attentions)
|
282 |
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
283 |
+
assign_to_checkpoint(
|
284 |
+
paths,
|
285 |
+
new_checkpoint,
|
286 |
+
vae_state_dict,
|
287 |
+
additional_replacements=[meta_path],
|
288 |
+
config=config,
|
289 |
+
)
|
290 |
conv_attn_to_linear(new_checkpoint)
|
291 |
return new_checkpoint
|
292 |
|
293 |
+
|
294 |
+
def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint, sample_size):
|
295 |
checkpoint = torch.load(ldm_checkpoint)["state_dict"]
|
296 |
|
297 |
# Convert the VAE model.
|
298 |
vae_config = create_vae_diffusers_config(ldm_config)
|
299 |
+
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
|
|
300 |
|
301 |
vae = AutoencoderKL(**vae_config)
|
302 |
vae.load_state_dict(converted_vae_checkpoint)
|
config/ldm_autoencoder_kl.yaml
CHANGED
@@ -18,7 +18,7 @@ model:
|
|
18 |
ddconfig:
|
19 |
double_z: True
|
20 |
z_channels: 1 # must = embed_dim due to HF limitation
|
21 |
-
resolution: 256
|
22 |
in_channels: 1
|
23 |
out_ch: 1
|
24 |
ch: 128
|
|
|
18 |
ddconfig:
|
19 |
double_z: True
|
20 |
z_channels: 1 # must = embed_dim due to HF limitation
|
21 |
+
resolution: 256 # overriden by input image size
|
22 |
in_channels: 1
|
23 |
out_ch: 1
|
24 |
ch: 128
|
notebooks/audio_encoder.ipynb
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "592fff30",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from diffusers import Mel\n",
|
11 |
+
"from audiodiffusion.audio_encoder import AudioEncoder"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": null,
|
17 |
+
"id": "d99ef523",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"audio_encoder = AudioEncoder.from_pretrained(\"teticio/audio-encoder\")"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": null,
|
27 |
+
"id": "4eb3bbd7",
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])"
|
32 |
+
]
|
33 |
+
}
|
34 |
+
],
|
35 |
+
"metadata": {
|
36 |
+
"kernelspec": {
|
37 |
+
"display_name": "huggingface",
|
38 |
+
"language": "python",
|
39 |
+
"name": "huggingface"
|
40 |
+
},
|
41 |
+
"language_info": {
|
42 |
+
"codemirror_mode": {
|
43 |
+
"name": "ipython",
|
44 |
+
"version": 3
|
45 |
+
},
|
46 |
+
"file_extension": ".py",
|
47 |
+
"mimetype": "text/x-python",
|
48 |
+
"name": "python",
|
49 |
+
"nbconvert_exporter": "python",
|
50 |
+
"pygments_lexer": "ipython3",
|
51 |
+
"version": "3.10.6"
|
52 |
+
},
|
53 |
+
"toc": {
|
54 |
+
"base_numbering": 1,
|
55 |
+
"nav_menu": {},
|
56 |
+
"number_sections": true,
|
57 |
+
"sideBar": true,
|
58 |
+
"skip_h1_title": false,
|
59 |
+
"title_cell": "Table of Contents",
|
60 |
+
"title_sidebar": "Contents",
|
61 |
+
"toc_cell": false,
|
62 |
+
"toc_position": {},
|
63 |
+
"toc_section_display": true,
|
64 |
+
"toc_window_display": false
|
65 |
+
}
|
66 |
+
},
|
67 |
+
"nbformat": 4,
|
68 |
+
"nbformat_minor": 5
|
69 |
+
}
|
notebooks/conditional_generation.ipynb
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "2a44739f",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/condtional_generation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
9 |
+
]
|
10 |
+
},
|
11 |
+
{
|
12 |
+
"cell_type": "code",
|
13 |
+
"execution_count": null,
|
14 |
+
"id": "f1935544",
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"try:\n",
|
19 |
+
" # are we running on Google Colab?\n",
|
20 |
+
" import google.colab\n",
|
21 |
+
" !git clone -q https://github.com/teticio/audio-diffusion.git\n",
|
22 |
+
" %cd audio-diffusion\n",
|
23 |
+
" %pip install -q -r requirements.txt\n",
|
24 |
+
"except:\n",
|
25 |
+
" pass"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": null,
|
31 |
+
"id": "b0e656c9",
|
32 |
+
"metadata": {},
|
33 |
+
"outputs": [],
|
34 |
+
"source": [
|
35 |
+
"import os\n",
|
36 |
+
"import sys\n",
|
37 |
+
"sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "code",
|
42 |
+
"execution_count": null,
|
43 |
+
"id": "d448b299",
|
44 |
+
"metadata": {},
|
45 |
+
"outputs": [],
|
46 |
+
"source": [
|
47 |
+
"import torch\n",
|
48 |
+
"import urllib\n",
|
49 |
+
"import requests\n",
|
50 |
+
"from IPython.display import Audio\n",
|
51 |
+
"from audiodiffusion import AudioDiffusion\n",
|
52 |
+
"from audiodiffusion.audio_encoder import AudioEncoder"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
+
"id": "f1548971",
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
63 |
+
"generator = torch.Generator(device=device)"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": null,
|
69 |
+
"id": "056f179c",
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"audio_diffusion = AudioDiffusion(model_id=\"teticio/conditional-latent-audio-diffusion-512\")"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": null,
|
79 |
+
"id": "b4a08500",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"audio_encoder = AudioEncoder.from_pretrained(\"teticio/audio-encoder\")"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": null,
|
89 |
+
"id": "387550ac",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"# Uncomment for faster (but slightly lower quality) generation\n",
|
94 |
+
"#from diffusers import DDIMScheduler\n",
|
95 |
+
"#audio_diffusion.pipe.scheduler = DDIMScheduler()"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "markdown",
|
100 |
+
"id": "9936a72f",
|
101 |
+
"metadata": {},
|
102 |
+
"source": [
|
103 |
+
"## Download and encode preview track from Spotify"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": null,
|
109 |
+
"id": "57a9b134",
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# Get temporary API credentials\n",
|
114 |
+
"credentials = requests.get(\n",
|
115 |
+
" \"https://open.spotify.com/get_access_token?reason=transport&productType=embed\"\n",
|
116 |
+
").json()\n",
|
117 |
+
"headers = {\n",
|
118 |
+
" \"Accept\": \"application/json\",\n",
|
119 |
+
" \"Content-Type\": \"application/json\",\n",
|
120 |
+
" \"Authorization\": \"Bearer \" + credentials[\"accessToken\"]\n",
|
121 |
+
"}\n",
|
122 |
+
"\n",
|
123 |
+
"# Search for tracks\n",
|
124 |
+
"search_string = input(\"Search: \")\n",
|
125 |
+
"response = requests.get(\n",
|
126 |
+
" f\"https://api.spotify.com/v1/search?q={urllib.parse.quote(search_string)}&type=track\",\n",
|
127 |
+
" headers=headers).json()\n",
|
128 |
+
"\n",
|
129 |
+
"# List results\n",
|
130 |
+
"for _, track in enumerate(response[\"tracks\"][\"items\"]):\n",
|
131 |
+
" print(f\"{_ + 1}. {track['artists'][0]['name']} - {track['name']}\")\n",
|
132 |
+
"selection = input(\"Select a track: \")\n",
|
133 |
+
"\n",
|
134 |
+
"# Download and encode selection\n",
|
135 |
+
"r = requests.get(response[\"tracks\"][\"items\"][int(selection) -\n",
|
136 |
+
" 1][\"preview_url\"],\n",
|
137 |
+
" stream=True)\n",
|
138 |
+
"with open(\"temp.mp3\", \"wb\") as f:\n",
|
139 |
+
" for chunk in r:\n",
|
140 |
+
" f.write(chunk)\n",
|
141 |
+
"encoding = torch.unsqueeze(audio_encoder.encode([\"temp.mp3\"]),\n",
|
142 |
+
" axis=1).to(device)\n",
|
143 |
+
"os.remove(\"temp.mp3\")"
|
144 |
+
]
|
145 |
+
},
|
146 |
+
{
|
147 |
+
"cell_type": "markdown",
|
148 |
+
"id": "8af863f5",
|
149 |
+
"metadata": {},
|
150 |
+
"source": [
|
151 |
+
"## Conditional Generation\n",
|
152 |
+
"Bear in mind that the generative model can only generate music similar to that on which it was trained. The audio encoding will influence the generation within those limitations."
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"execution_count": null,
|
158 |
+
"id": "8f119ddd",
|
159 |
+
"metadata": {},
|
160 |
+
"outputs": [],
|
161 |
+
"source": [
|
162 |
+
"for _ in range(10):\n",
|
163 |
+
" seed = generator.seed()\n",
|
164 |
+
" print(f'Seed = {seed}')\n",
|
165 |
+
" generator.manual_seed(seed)\n",
|
166 |
+
" image, (sample_rate,\n",
|
167 |
+
" audio) = audio_diffusion.generate_spectrogram_and_audio(\n",
|
168 |
+
" generator=generator, encoding=encoding)\n",
|
169 |
+
" display(image)\n",
|
170 |
+
" display(Audio(audio, rate=sample_rate))\n",
|
171 |
+
" loop = AudioDiffusion.loop_it(audio, sample_rate)\n",
|
172 |
+
" if loop is not None:\n",
|
173 |
+
" display(Audio(loop, rate=sample_rate))\n",
|
174 |
+
" else:\n",
|
175 |
+
" print(\"Unable to determine loop points\")"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": null,
|
181 |
+
"id": "d0bd18c0",
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": []
|
185 |
+
}
|
186 |
+
],
|
187 |
+
"metadata": {
|
188 |
+
"kernelspec": {
|
189 |
+
"display_name": "huggingface",
|
190 |
+
"language": "python",
|
191 |
+
"name": "huggingface"
|
192 |
+
},
|
193 |
+
"language_info": {
|
194 |
+
"codemirror_mode": {
|
195 |
+
"name": "ipython",
|
196 |
+
"version": 3
|
197 |
+
},
|
198 |
+
"file_extension": ".py",
|
199 |
+
"mimetype": "text/x-python",
|
200 |
+
"name": "python",
|
201 |
+
"nbconvert_exporter": "python",
|
202 |
+
"pygments_lexer": "ipython3",
|
203 |
+
"version": "3.10.6"
|
204 |
+
},
|
205 |
+
"toc": {
|
206 |
+
"base_numbering": 1,
|
207 |
+
"nav_menu": {},
|
208 |
+
"number_sections": true,
|
209 |
+
"sideBar": true,
|
210 |
+
"skip_h1_title": false,
|
211 |
+
"title_cell": "Table of Contents",
|
212 |
+
"title_sidebar": "Contents",
|
213 |
+
"toc_cell": false,
|
214 |
+
"toc_position": {},
|
215 |
+
"toc_section_display": true,
|
216 |
+
"toc_window_display": false
|
217 |
+
}
|
218 |
+
},
|
219 |
+
"nbformat": 4,
|
220 |
+
"nbformat_minor": 5
|
221 |
+
}
|
pyproject.toml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[tool.black]
|
2 |
+
line-length = 119
|
3 |
+
target-version = ['py36']
|
scripts/audio_to_images.py
CHANGED
@@ -1,29 +1,33 @@
|
|
1 |
-
import
|
2 |
-
import re
|
3 |
import io
|
4 |
import logging
|
5 |
-
import
|
|
|
6 |
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
9 |
-
from tqdm.auto import tqdm
|
10 |
-
from diffusers.pipelines.audio_diffusion import Mel
|
11 |
from datasets import Dataset, DatasetDict, Features, Image, Value
|
|
|
|
|
12 |
|
13 |
logging.basicConfig(level=logging.WARN)
|
14 |
-
logger = logging.getLogger(
|
15 |
|
16 |
|
17 |
def main(args):
|
18 |
-
mel = Mel(
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
os.makedirs(args.output_dir, exist_ok=True)
|
24 |
audio_files = [
|
25 |
-
os.path.join(root, file)
|
26 |
-
for
|
|
|
|
|
27 |
]
|
28 |
examples = []
|
29 |
try:
|
@@ -36,36 +40,38 @@ def main(args):
|
|
36 |
continue
|
37 |
for slice in range(mel.get_number_of_slices()):
|
38 |
image = mel.audio_slice_to_image(slice)
|
39 |
-
assert
|
40 |
-
== args.resolution[1]), "Wrong resolution"
|
41 |
# skip completely silent slices
|
42 |
if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
|
43 |
-
logger.warn(
|
44 |
-
audio_file, slice)
|
45 |
continue
|
46 |
with io.BytesIO() as output:
|
47 |
image.save(output, format="PNG")
|
48 |
bytes = output.getvalue()
|
49 |
-
examples.extend(
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
56 |
except Exception as e:
|
57 |
print(e)
|
58 |
finally:
|
59 |
if len(examples) == 0:
|
60 |
-
logger.warn(
|
61 |
return
|
62 |
ds = Dataset.from_pandas(
|
63 |
pd.DataFrame(examples),
|
64 |
-
features=Features(
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
69 |
)
|
70 |
dsd = DatasetDict({"train": ds})
|
71 |
dsd.save_to_disk(os.path.join(args.output_dir))
|
@@ -74,15 +80,15 @@ def main(args):
|
|
74 |
|
75 |
|
76 |
if __name__ == "__main__":
|
77 |
-
parser = argparse.ArgumentParser(
|
78 |
-
description=
|
79 |
-
"Create dataset of Mel spectrograms from directory of audio files.")
|
80 |
parser.add_argument("--input_dir", type=str)
|
81 |
parser.add_argument("--output_dir", type=str, default="data")
|
82 |
-
parser.add_argument(
|
83 |
-
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
parser.add_argument("--hop_length", type=int, default=512)
|
87 |
parser.add_argument("--push_to_hub", type=str, default=None)
|
88 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
@@ -90,8 +96,7 @@ if __name__ == "__main__":
|
|
90 |
args = parser.parse_args()
|
91 |
|
92 |
if args.input_dir is None:
|
93 |
-
raise ValueError(
|
94 |
-
"You must specify an input directory for the audio files.")
|
95 |
|
96 |
# Handle the resolutions.
|
97 |
try:
|
@@ -102,9 +107,7 @@ if __name__ == "__main__":
|
|
102 |
if len(args.resolution) != 2:
|
103 |
raise ValueError
|
104 |
except ValueError:
|
105 |
-
raise ValueError(
|
106 |
-
"Resolution must be a tuple of two integers or a single integer."
|
107 |
-
)
|
108 |
assert isinstance(args.resolution, tuple)
|
109 |
|
110 |
main(args)
|
|
|
1 |
+
import argparse
|
|
|
2 |
import io
|
3 |
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
|
7 |
import numpy as np
|
8 |
import pandas as pd
|
|
|
|
|
9 |
from datasets import Dataset, DatasetDict, Features, Image, Value
|
10 |
+
from diffusers.pipelines.audio_diffusion import Mel
|
11 |
+
from tqdm.auto import tqdm
|
12 |
|
13 |
logging.basicConfig(level=logging.WARN)
|
14 |
+
logger = logging.getLogger("audio_to_images")
|
15 |
|
16 |
|
17 |
def main(args):
|
18 |
+
mel = Mel(
|
19 |
+
x_res=args.resolution[0],
|
20 |
+
y_res=args.resolution[1],
|
21 |
+
hop_length=args.hop_length,
|
22 |
+
sample_rate=args.sample_rate,
|
23 |
+
n_fft=args.n_fft,
|
24 |
+
)
|
25 |
os.makedirs(args.output_dir, exist_ok=True)
|
26 |
audio_files = [
|
27 |
+
os.path.join(root, file)
|
28 |
+
for root, _, files in os.walk(args.input_dir)
|
29 |
+
for file in files
|
30 |
+
if re.search("\.(mp3|wav|m4a)$", file, re.IGNORECASE)
|
31 |
]
|
32 |
examples = []
|
33 |
try:
|
|
|
40 |
continue
|
41 |
for slice in range(mel.get_number_of_slices()):
|
42 |
image = mel.audio_slice_to_image(slice)
|
43 |
+
assert image.width == args.resolution[0] and image.height == args.resolution[1], "Wrong resolution"
|
|
|
44 |
# skip completely silent slices
|
45 |
if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
|
46 |
+
logger.warn("File %s slice %d is completely silent", audio_file, slice)
|
|
|
47 |
continue
|
48 |
with io.BytesIO() as output:
|
49 |
image.save(output, format="PNG")
|
50 |
bytes = output.getvalue()
|
51 |
+
examples.extend(
|
52 |
+
[
|
53 |
+
{
|
54 |
+
"image": {"bytes": bytes},
|
55 |
+
"audio_file": audio_file,
|
56 |
+
"slice": slice,
|
57 |
+
}
|
58 |
+
]
|
59 |
+
)
|
60 |
except Exception as e:
|
61 |
print(e)
|
62 |
finally:
|
63 |
if len(examples) == 0:
|
64 |
+
logger.warn("No valid audio files were found.")
|
65 |
return
|
66 |
ds = Dataset.from_pandas(
|
67 |
pd.DataFrame(examples),
|
68 |
+
features=Features(
|
69 |
+
{
|
70 |
+
"image": Image(),
|
71 |
+
"audio_file": Value(dtype="string"),
|
72 |
+
"slice": Value(dtype="int16"),
|
73 |
+
}
|
74 |
+
),
|
75 |
)
|
76 |
dsd = DatasetDict({"train": ds})
|
77 |
dsd.save_to_disk(os.path.join(args.output_dir))
|
|
|
80 |
|
81 |
|
82 |
if __name__ == "__main__":
|
83 |
+
parser = argparse.ArgumentParser(description="Create dataset of Mel spectrograms from directory of audio files.")
|
|
|
|
|
84 |
parser.add_argument("--input_dir", type=str)
|
85 |
parser.add_argument("--output_dir", type=str, default="data")
|
86 |
+
parser.add_argument(
|
87 |
+
"--resolution",
|
88 |
+
type=str,
|
89 |
+
default="256",
|
90 |
+
help="Either square resolution or width,height.",
|
91 |
+
)
|
92 |
parser.add_argument("--hop_length", type=int, default=512)
|
93 |
parser.add_argument("--push_to_hub", type=str, default=None)
|
94 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
|
|
96 |
args = parser.parse_args()
|
97 |
|
98 |
if args.input_dir is None:
|
99 |
+
raise ValueError("You must specify an input directory for the audio files.")
|
|
|
100 |
|
101 |
# Handle the resolutions.
|
102 |
try:
|
|
|
107 |
if len(args.resolution) != 2:
|
108 |
raise ValueError
|
109 |
except ValueError:
|
110 |
+
raise ValueError("Resolution must be a tuple of two integers or a single integer.")
|
|
|
|
|
111 |
assert isinstance(args.resolution, tuple)
|
112 |
|
113 |
main(args)
|
scripts/encode_audio.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
|
5 |
+
from datasets import load_dataset, load_from_disk
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
|
8 |
+
from audiodiffusion.audio_encoder import AudioEncoder
|
9 |
+
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
|
13 |
+
|
14 |
+
if args.dataset_name is not None:
|
15 |
+
if os.path.exists(args.dataset_name):
|
16 |
+
dataset = load_from_disk(args.dataset_name)["train"]
|
17 |
+
else:
|
18 |
+
dataset = load_dataset(
|
19 |
+
args.dataset_name,
|
20 |
+
args.dataset_config_name,
|
21 |
+
cache_dir=args.cache_dir,
|
22 |
+
use_auth_token=True if args.use_auth_token else None,
|
23 |
+
split="train",
|
24 |
+
)
|
25 |
+
|
26 |
+
encodings = {}
|
27 |
+
for audio_file in tqdm(dataset.to_pandas()["audio_file"].unique()):
|
28 |
+
encodings[audio_file] = audio_encoder.encode([audio_file])
|
29 |
+
pickle.dump(encodings, open(args.output_file, "wb"))
|
30 |
+
|
31 |
+
|
32 |
+
if __name__ == "__main__":
|
33 |
+
parser = argparse.ArgumentParser(description="Create pickled audio encodings for dataset of audio files.")
|
34 |
+
parser.add_argument("--dataset_name", type=str, default=None)
|
35 |
+
parser.add_argument("--output_file", type=str, default="data/encodings.p")
|
36 |
+
parser.add_argument("--use_auth_token", type=bool, default=False)
|
37 |
+
args = parser.parse_args()
|
38 |
+
main(args)
|
scripts/{train_unconditional.py → train_unet.py}
RENAMED
@@ -2,34 +2,29 @@
|
|
2 |
|
3 |
import argparse
|
4 |
import os
|
|
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import Optional
|
7 |
|
|
|
|
|
|
|
8 |
from accelerate import Accelerator
|
9 |
from accelerate.logging import get_logger
|
10 |
-
from datasets import
|
11 |
-
from diffusers import (
|
12 |
-
|
13 |
-
DDPMScheduler,
|
14 |
-
UNet2DModel,
|
15 |
-
DDIMScheduler,
|
16 |
-
AutoencoderKL,
|
17 |
-
)
|
18 |
-
from diffusers.pipelines.audio_diffusion import Mel
|
19 |
from diffusers.optimization import get_scheduler
|
|
|
20 |
from diffusers.training_utils import EMAModel
|
21 |
from huggingface_hub import HfFolder, Repository, whoami
|
22 |
from librosa.util import normalize
|
23 |
-
import
|
24 |
-
import torch
|
25 |
-
import torch.nn.functional as F
|
26 |
-
from torchvision.transforms import (
|
27 |
-
Compose,
|
28 |
-
Normalize,
|
29 |
-
ToTensor,
|
30 |
-
)
|
31 |
from tqdm.auto import tqdm
|
32 |
|
|
|
|
|
33 |
logger = get_logger(__name__)
|
34 |
|
35 |
|
@@ -90,12 +85,18 @@ def main(args):
|
|
90 |
]
|
91 |
else:
|
92 |
images = [augmentations(image) for image in examples["image"]]
|
|
|
|
|
|
|
93 |
return {"input": images}
|
94 |
|
95 |
dataset.set_transform(transforms)
|
96 |
train_dataloader = torch.utils.data.DataLoader(
|
97 |
dataset, batch_size=args.train_batch_size, shuffle=True)
|
98 |
|
|
|
|
|
|
|
99 |
vqvae = None
|
100 |
if args.vae is not None:
|
101 |
try:
|
@@ -104,9 +105,9 @@ def main(args):
|
|
104 |
vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
|
105 |
# Determine latent resolution
|
106 |
with torch.no_grad():
|
107 |
-
latent_resolution =
|
108 |
torch.zeros((1, 1) +
|
109 |
-
resolution)).latent_dist.sample().shape[2:]
|
110 |
|
111 |
if args.from_pretrained is not None:
|
112 |
pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
|
@@ -114,32 +115,58 @@ def main(args):
|
|
114 |
model = pipeline.unet
|
115 |
if hasattr(pipeline, "vqvae"):
|
116 |
vqvae = pipeline.vqvae
|
|
|
117 |
else:
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
if args.scheduler == "ddpm":
|
145 |
noise_scheduler = DDPMScheduler(
|
@@ -240,7 +267,11 @@ def main(args):
|
|
240 |
|
241 |
with accelerator.accumulate(model):
|
242 |
# Predict the noise residual
|
243 |
-
|
|
|
|
|
|
|
|
|
244 |
loss = F.mse_loss(noise_pred, noise)
|
245 |
accelerator.backward(loss)
|
246 |
|
@@ -270,9 +301,9 @@ def main(args):
|
|
270 |
|
271 |
# Generate sample images for visual inspection
|
272 |
if accelerator.is_main_process:
|
273 |
-
if (epoch + 1) % args.save_model_epochs == 0
|
274 |
-
epoch + 1
|
275 |
-
|
276 |
pipeline = AudioDiffusionPipeline(
|
277 |
vqvae=vqvae,
|
278 |
unet=accelerator.unwrap_model(
|
@@ -288,18 +319,32 @@ def main(args):
|
|
288 |
|
289 |
# save the model
|
290 |
if args.push_to_hub:
|
291 |
-
repo.push_to_hub(
|
292 |
-
|
293 |
-
|
|
|
|
|
294 |
|
295 |
if (epoch + 1) % args.save_images_epochs == 0:
|
296 |
generator = torch.Generator(
|
297 |
device=clean_images.device).manual_seed(42)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
# run pipeline in inference (sample random noise and denoise)
|
299 |
-
images, (sample_rate,
|
300 |
-
|
301 |
-
|
302 |
-
|
|
|
|
|
303 |
|
304 |
# denormalize the images and save to tensorboard
|
305 |
images = np.array([
|
@@ -385,6 +430,12 @@ if __name__ == "__main__":
|
|
385 |
default=None,
|
386 |
help="pretrained VAE model for latent diffusion",
|
387 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
|
389 |
args = parser.parse_args()
|
390 |
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
|
|
2 |
|
3 |
import argparse
|
4 |
import os
|
5 |
+
import pickle
|
6 |
+
import random
|
7 |
from pathlib import Path
|
8 |
from typing import Optional
|
9 |
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
from accelerate import Accelerator
|
14 |
from accelerate.logging import get_logger
|
15 |
+
from datasets import load_dataset, load_from_disk
|
16 |
+
from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
|
17 |
+
UNet2DConditionModel, UNet2DModel)
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
from diffusers.optimization import get_scheduler
|
19 |
+
from diffusers.pipelines.audio_diffusion import Mel
|
20 |
from diffusers.training_utils import EMAModel
|
21 |
from huggingface_hub import HfFolder, Repository, whoami
|
22 |
from librosa.util import normalize
|
23 |
+
from torchvision.transforms import Compose, Normalize, ToTensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
from tqdm.auto import tqdm
|
25 |
|
26 |
+
from audiodiffusion.pipeline_audio_diffusion import AudioDiffusionPipeline
|
27 |
+
|
28 |
logger = get_logger(__name__)
|
29 |
|
30 |
|
|
|
85 |
]
|
86 |
else:
|
87 |
images = [augmentations(image) for image in examples["image"]]
|
88 |
+
if args.encodings is not None:
|
89 |
+
encoding = [encodings[file] for file in examples["audio_file"]]
|
90 |
+
return {"input": images, "encoding": encoding}
|
91 |
return {"input": images}
|
92 |
|
93 |
dataset.set_transform(transforms)
|
94 |
train_dataloader = torch.utils.data.DataLoader(
|
95 |
dataset, batch_size=args.train_batch_size, shuffle=True)
|
96 |
|
97 |
+
if args.encodings is not None:
|
98 |
+
encodings = pickle.load(open(args.encodings, "rb"))
|
99 |
+
|
100 |
vqvae = None
|
101 |
if args.vae is not None:
|
102 |
try:
|
|
|
105 |
vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
|
106 |
# Determine latent resolution
|
107 |
with torch.no_grad():
|
108 |
+
latent_resolution = vqvae.encode(
|
109 |
torch.zeros((1, 1) +
|
110 |
+
resolution)).latent_dist.sample().shape[2:]
|
111 |
|
112 |
if args.from_pretrained is not None:
|
113 |
pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
|
|
|
115 |
model = pipeline.unet
|
116 |
if hasattr(pipeline, "vqvae"):
|
117 |
vqvae = pipeline.vqvae
|
118 |
+
|
119 |
else:
|
120 |
+
if args.encodings is None:
|
121 |
+
model = UNet2DModel(
|
122 |
+
sample_size=resolution if vqvae is None else latent_resolution,
|
123 |
+
in_channels=1
|
124 |
+
if vqvae is None else vqvae.config["latent_channels"],
|
125 |
+
out_channels=1
|
126 |
+
if vqvae is None else vqvae.config["latent_channels"],
|
127 |
+
layers_per_block=2,
|
128 |
+
block_out_channels=(128, 128, 256, 256, 512, 512),
|
129 |
+
down_block_types=(
|
130 |
+
"DownBlock2D",
|
131 |
+
"DownBlock2D",
|
132 |
+
"DownBlock2D",
|
133 |
+
"DownBlock2D",
|
134 |
+
"AttnDownBlock2D",
|
135 |
+
"DownBlock2D",
|
136 |
+
),
|
137 |
+
up_block_types=(
|
138 |
+
"UpBlock2D",
|
139 |
+
"AttnUpBlock2D",
|
140 |
+
"UpBlock2D",
|
141 |
+
"UpBlock2D",
|
142 |
+
"UpBlock2D",
|
143 |
+
"UpBlock2D",
|
144 |
+
),
|
145 |
+
)
|
146 |
+
|
147 |
+
else:
|
148 |
+
model = UNet2DConditionModel(
|
149 |
+
sample_size=resolution if vqvae is None else latent_resolution,
|
150 |
+
in_channels=1
|
151 |
+
if vqvae is None else vqvae.config["latent_channels"],
|
152 |
+
out_channels=1
|
153 |
+
if vqvae is None else vqvae.config["latent_channels"],
|
154 |
+
layers_per_block=2,
|
155 |
+
block_out_channels=(128, 256, 512, 512),
|
156 |
+
down_block_types=(
|
157 |
+
"CrossAttnDownBlock2D",
|
158 |
+
"CrossAttnDownBlock2D",
|
159 |
+
"CrossAttnDownBlock2D",
|
160 |
+
"DownBlock2D",
|
161 |
+
),
|
162 |
+
up_block_types=(
|
163 |
+
"UpBlock2D",
|
164 |
+
"CrossAttnUpBlock2D",
|
165 |
+
"CrossAttnUpBlock2D",
|
166 |
+
"CrossAttnUpBlock2D",
|
167 |
+
),
|
168 |
+
cross_attention_dim=list(encodings.values())[0].shape[-1],
|
169 |
+
)
|
170 |
|
171 |
if args.scheduler == "ddpm":
|
172 |
noise_scheduler = DDPMScheduler(
|
|
|
267 |
|
268 |
with accelerator.accumulate(model):
|
269 |
# Predict the noise residual
|
270 |
+
if args.encodings is not None:
|
271 |
+
noise_pred = model(noisy_images, timesteps,
|
272 |
+
batch["encoding"])["sample"]
|
273 |
+
else:
|
274 |
+
noise_pred = model(noisy_images, timesteps)["sample"]
|
275 |
loss = F.mse_loss(noise_pred, noise)
|
276 |
accelerator.backward(loss)
|
277 |
|
|
|
301 |
|
302 |
# Generate sample images for visual inspection
|
303 |
if accelerator.is_main_process:
|
304 |
+
if ((epoch + 1) % args.save_model_epochs == 0
|
305 |
+
or (epoch + 1) % args.save_images_epochs == 0
|
306 |
+
or epoch == args.num_epochs - 1):
|
307 |
pipeline = AudioDiffusionPipeline(
|
308 |
vqvae=vqvae,
|
309 |
unet=accelerator.unwrap_model(
|
|
|
319 |
|
320 |
# save the model
|
321 |
if args.push_to_hub:
|
322 |
+
repo.push_to_hub(
|
323 |
+
commit_message=f"Epoch {epoch}",
|
324 |
+
blocking=False,
|
325 |
+
auto_lfs_prune=True,
|
326 |
+
)
|
327 |
|
328 |
if (epoch + 1) % args.save_images_epochs == 0:
|
329 |
generator = torch.Generator(
|
330 |
device=clean_images.device).manual_seed(42)
|
331 |
+
|
332 |
+
if args.encodings is not None:
|
333 |
+
random.seed(42)
|
334 |
+
encoding = torch.stack(
|
335 |
+
random.sample(list(encodings.values()),
|
336 |
+
args.eval_batch_size)).to(
|
337 |
+
clean_images.device)
|
338 |
+
else:
|
339 |
+
encoding = None
|
340 |
+
|
341 |
# run pipeline in inference (sample random noise and denoise)
|
342 |
+
images, (sample_rate, audios) = pipeline(
|
343 |
+
generator=generator,
|
344 |
+
batch_size=args.eval_batch_size,
|
345 |
+
return_dict=False,
|
346 |
+
encoding=encoding,
|
347 |
+
)
|
348 |
|
349 |
# denormalize the images and save to tensorboard
|
350 |
images = np.array([
|
|
|
430 |
default=None,
|
431 |
help="pretrained VAE model for latent diffusion",
|
432 |
)
|
433 |
+
parser.add_argument(
|
434 |
+
"--encodings",
|
435 |
+
type=str,
|
436 |
+
default=None,
|
437 |
+
help="picked dictionary mapping audio_file to encoding",
|
438 |
+
)
|
439 |
|
440 |
args = parser.parse_args()
|
441 |
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
scripts/train_vae.py
CHANGED
@@ -1,50 +1,48 @@
|
|
1 |
# based on https://github.com/CompVis/stable-diffusion/blob/main/main.py
|
2 |
|
3 |
-
import os
|
4 |
import argparse
|
|
|
5 |
|
6 |
-
import torch
|
7 |
-
import torchvision
|
8 |
import numpy as np
|
9 |
-
from PIL import Image
|
10 |
import pytorch_lightning as pl
|
11 |
-
|
12 |
-
|
|
|
|
|
13 |
from ldm.util import instantiate_from_config
|
|
|
|
|
|
|
|
|
14 |
from pytorch_lightning.trainer import Trainer
|
|
|
15 |
from torch.utils.data import DataLoader, Dataset
|
16 |
-
|
17 |
-
from diffusers.pipelines.audio_diffusion import Mel
|
18 |
from audiodiffusion.utils import convert_ldm_to_hf_vae
|
19 |
-
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
|
20 |
-
from pytorch_lightning.utilities.distributed import rank_zero_only
|
21 |
|
22 |
|
23 |
class AudioDiffusion(Dataset):
|
24 |
-
|
25 |
def __init__(self, model_id, channels=3):
|
26 |
super().__init__()
|
27 |
self.channels = channels
|
28 |
if os.path.exists(model_id):
|
29 |
-
self.hf_dataset = load_from_disk(model_id)[
|
30 |
else:
|
31 |
-
self.hf_dataset = load_dataset(model_id)[
|
32 |
|
33 |
def __len__(self):
|
34 |
return len(self.hf_dataset)
|
35 |
|
36 |
def __getitem__(self, idx):
|
37 |
-
image = self.hf_dataset[idx][
|
38 |
if self.channels == 3:
|
39 |
-
image = image.convert(
|
40 |
-
image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
|
41 |
-
|
42 |
-
image
|
43 |
-
return {'image': image}
|
44 |
|
45 |
|
46 |
class AudioDiffusionDataModule(pl.LightningDataModule):
|
47 |
-
|
48 |
def __init__(self, model_id, batch_size, channels):
|
49 |
super().__init__()
|
50 |
self.batch_size = batch_size
|
@@ -52,18 +50,11 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
|
|
52 |
self.num_workers = 1
|
53 |
|
54 |
def train_dataloader(self):
|
55 |
-
return DataLoader(self.dataset,
|
56 |
-
batch_size=self.batch_size,
|
57 |
-
num_workers=self.num_workers)
|
58 |
|
59 |
|
60 |
class ImageLogger(Callback):
|
61 |
-
|
62 |
-
def __init__(self,
|
63 |
-
every=1000,
|
64 |
-
hop_length=512,
|
65 |
-
sample_rate=22050,
|
66 |
-
n_fft=2048):
|
67 |
super().__init__()
|
68 |
self.every = every
|
69 |
self.hop_length = hop_length
|
@@ -74,83 +65,75 @@ class ImageLogger(Callback):
|
|
74 |
def log_images_and_audios(self, pl_module, batch):
|
75 |
pl_module.eval()
|
76 |
with torch.no_grad():
|
77 |
-
images = pl_module.log_images(batch, split=
|
78 |
pl_module.train()
|
79 |
|
80 |
image_shape = next(iter(images.values())).shape
|
81 |
channels = image_shape[1]
|
82 |
-
mel = Mel(
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
87 |
|
88 |
for k in images:
|
89 |
images[k] = images[k].detach().cpu()
|
90 |
-
images[k] = torch.clamp(images[k], -1
|
91 |
images[k] = (images[k] + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
92 |
grid = torchvision.utils.make_grid(images[k])
|
93 |
|
94 |
tag = f"train/{k}"
|
95 |
-
pl_module.logger.experiment.add_image(
|
96 |
-
tag, grid, global_step=pl_module.global_step)
|
97 |
|
98 |
-
images[k] = (images[k].numpy() *
|
99 |
-
255).round().astype("uint8").transpose(0, 2, 3, 1)
|
100 |
for _, image in enumerate(images[k]):
|
101 |
audio = mel.image_to_audio(
|
102 |
-
Image.fromarray(image, mode=
|
103 |
-
if channels == 3
|
|
|
|
|
104 |
pl_module.logger.experiment.add_audio(
|
105 |
tag + f"/{_}",
|
106 |
normalize(audio),
|
107 |
global_step=pl_module.global_step,
|
108 |
-
sample_rate=mel.get_sample_rate()
|
|
|
109 |
|
110 |
-
def on_train_batch_end(self, trainer, pl_module, outputs, batch,
|
111 |
-
batch_idx):
|
112 |
if (batch_idx + 1) % self.every != 0:
|
113 |
return
|
114 |
self.log_images_and_audios(pl_module, batch)
|
115 |
|
116 |
|
117 |
class HFModelCheckpoint(ModelCheckpoint):
|
118 |
-
|
119 |
def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
|
120 |
super().__init__(*args, **kwargs)
|
121 |
self.ldm_config = ldm_config
|
122 |
self.hf_checkpoint = hf_checkpoint
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
def on_train_epoch_end(self, trainer, pl_module):
|
125 |
-
ldm_checkpoint = self._get_metric_interpolated_filepath_name(
|
126 |
-
{'epoch': trainer.current_epoch}, trainer)
|
127 |
super().on_train_epoch_end(trainer, pl_module)
|
128 |
-
|
129 |
-
|
130 |
|
131 |
|
132 |
if __name__ == "__main__":
|
133 |
parser = argparse.ArgumentParser(description="Train VAE using ldm.")
|
134 |
parser.add_argument("-d", "--dataset_name", type=str, default=None)
|
135 |
parser.add_argument("-b", "--batch_size", type=int, default=1)
|
136 |
-
parser.add_argument("-c",
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
parser.add_argument("--
|
141 |
-
type=str,
|
142 |
-
default="models/ldm-autoencoder-kl")
|
143 |
-
parser.add_argument("--hf_checkpoint_dir",
|
144 |
-
type=str,
|
145 |
-
default="models/autoencoder-kl")
|
146 |
-
parser.add_argument("-r",
|
147 |
-
"--resume_from_checkpoint",
|
148 |
-
type=str,
|
149 |
-
default=None)
|
150 |
-
parser.add_argument("-g",
|
151 |
-
"--gradient_accumulation_steps",
|
152 |
-
type=int,
|
153 |
-
default=1)
|
154 |
parser.add_argument("--hop_length", type=int, default=512)
|
155 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
156 |
parser.add_argument("--n_fft", type=int, default=2048)
|
@@ -164,7 +147,8 @@ if __name__ == "__main__":
|
|
164 |
data = AudioDiffusionDataModule(
|
165 |
model_id=args.dataset_name,
|
166 |
batch_size=args.batch_size,
|
167 |
-
channels=config.model.params.ddconfig.in_channels
|
|
|
168 |
lightning_config = config.pop("lightning", OmegaConf.create())
|
169 |
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
170 |
trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
|
@@ -174,15 +158,20 @@ if __name__ == "__main__":
|
|
174 |
max_epochs=args.max_epochs,
|
175 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
176 |
callbacks=[
|
177 |
-
ImageLogger(
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
188 |
trainer.fit(model, data)
|
|
|
1 |
# based on https://github.com/CompVis/stable-diffusion/blob/main/main.py
|
2 |
|
|
|
3 |
import argparse
|
4 |
+
import os
|
5 |
|
|
|
|
|
6 |
import numpy as np
|
|
|
7 |
import pytorch_lightning as pl
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
from datasets import load_dataset, load_from_disk
|
11 |
+
from diffusers.pipelines.audio_diffusion import Mel
|
12 |
from ldm.util import instantiate_from_config
|
13 |
+
from librosa.util import normalize
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from PIL import Image
|
16 |
+
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
|
17 |
from pytorch_lightning.trainer import Trainer
|
18 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
19 |
from torch.utils.data import DataLoader, Dataset
|
20 |
+
|
|
|
21 |
from audiodiffusion.utils import convert_ldm_to_hf_vae
|
|
|
|
|
22 |
|
23 |
|
24 |
class AudioDiffusion(Dataset):
|
|
|
25 |
def __init__(self, model_id, channels=3):
|
26 |
super().__init__()
|
27 |
self.channels = channels
|
28 |
if os.path.exists(model_id):
|
29 |
+
self.hf_dataset = load_from_disk(model_id)["train"]
|
30 |
else:
|
31 |
+
self.hf_dataset = load_dataset(model_id)["train"]
|
32 |
|
33 |
def __len__(self):
|
34 |
return len(self.hf_dataset)
|
35 |
|
36 |
def __getitem__(self, idx):
|
37 |
+
image = self.hf_dataset[idx]["image"]
|
38 |
if self.channels == 3:
|
39 |
+
image = image.convert("RGB")
|
40 |
+
image = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width, self.channels))
|
41 |
+
image = (image / 255) * 2 - 1
|
42 |
+
return {"image": image}
|
|
|
43 |
|
44 |
|
45 |
class AudioDiffusionDataModule(pl.LightningDataModule):
|
|
|
46 |
def __init__(self, model_id, batch_size, channels):
|
47 |
super().__init__()
|
48 |
self.batch_size = batch_size
|
|
|
50 |
self.num_workers = 1
|
51 |
|
52 |
def train_dataloader(self):
|
53 |
+
return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=self.num_workers)
|
|
|
|
|
54 |
|
55 |
|
56 |
class ImageLogger(Callback):
|
57 |
+
def __init__(self, every=1000, hop_length=512, sample_rate=22050, n_fft=2048):
|
|
|
|
|
|
|
|
|
|
|
58 |
super().__init__()
|
59 |
self.every = every
|
60 |
self.hop_length = hop_length
|
|
|
65 |
def log_images_and_audios(self, pl_module, batch):
|
66 |
pl_module.eval()
|
67 |
with torch.no_grad():
|
68 |
+
images = pl_module.log_images(batch, split="train")
|
69 |
pl_module.train()
|
70 |
|
71 |
image_shape = next(iter(images.values())).shape
|
72 |
channels = image_shape[1]
|
73 |
+
mel = Mel(
|
74 |
+
x_res=image_shape[2],
|
75 |
+
y_res=image_shape[3],
|
76 |
+
hop_length=self.hop_length,
|
77 |
+
sample_rate=self.sample_rate,
|
78 |
+
n_fft=self.n_fft,
|
79 |
+
)
|
80 |
|
81 |
for k in images:
|
82 |
images[k] = images[k].detach().cpu()
|
83 |
+
images[k] = torch.clamp(images[k], -1.0, 1.0)
|
84 |
images[k] = (images[k] + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
85 |
grid = torchvision.utils.make_grid(images[k])
|
86 |
|
87 |
tag = f"train/{k}"
|
88 |
+
pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
|
|
|
89 |
|
90 |
+
images[k] = (images[k].numpy() * 255).round().astype("uint8").transpose(0, 2, 3, 1)
|
|
|
91 |
for _, image in enumerate(images[k]):
|
92 |
audio = mel.image_to_audio(
|
93 |
+
Image.fromarray(image, mode="RGB").convert("L")
|
94 |
+
if channels == 3
|
95 |
+
else Image.fromarray(image[:, :, 0])
|
96 |
+
)
|
97 |
pl_module.logger.experiment.add_audio(
|
98 |
tag + f"/{_}",
|
99 |
normalize(audio),
|
100 |
global_step=pl_module.global_step,
|
101 |
+
sample_rate=mel.get_sample_rate(),
|
102 |
+
)
|
103 |
|
104 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
|
|
|
105 |
if (batch_idx + 1) % self.every != 0:
|
106 |
return
|
107 |
self.log_images_and_audios(pl_module, batch)
|
108 |
|
109 |
|
110 |
class HFModelCheckpoint(ModelCheckpoint):
|
|
|
111 |
def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
|
112 |
super().__init__(*args, **kwargs)
|
113 |
self.ldm_config = ldm_config
|
114 |
self.hf_checkpoint = hf_checkpoint
|
115 |
+
self.sample_size = None
|
116 |
+
|
117 |
+
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
118 |
+
if self.sample_size is None:
|
119 |
+
self.sample_size = list(batch["image"].shape[1:3])
|
120 |
|
121 |
def on_train_epoch_end(self, trainer, pl_module):
|
122 |
+
ldm_checkpoint = self._get_metric_interpolated_filepath_name({"epoch": trainer.current_epoch}, trainer)
|
|
|
123 |
super().on_train_epoch_end(trainer, pl_module)
|
124 |
+
self.ldm_config.model.params.ddconfig.resolution = self.sample_size
|
125 |
+
convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config, self.hf_checkpoint, self.sample_size)
|
126 |
|
127 |
|
128 |
if __name__ == "__main__":
|
129 |
parser = argparse.ArgumentParser(description="Train VAE using ldm.")
|
130 |
parser.add_argument("-d", "--dataset_name", type=str, default=None)
|
131 |
parser.add_argument("-b", "--batch_size", type=int, default=1)
|
132 |
+
parser.add_argument("-c", "--ldm_config_file", type=str, default="config/ldm_autoencoder_kl.yaml")
|
133 |
+
parser.add_argument("--ldm_checkpoint_dir", type=str, default="models/ldm-autoencoder-kl")
|
134 |
+
parser.add_argument("--hf_checkpoint_dir", type=str, default="models/autoencoder-kl")
|
135 |
+
parser.add_argument("-r", "--resume_from_checkpoint", type=str, default=None)
|
136 |
+
parser.add_argument("-g", "--gradient_accumulation_steps", type=int, default=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
parser.add_argument("--hop_length", type=int, default=512)
|
138 |
parser.add_argument("--sample_rate", type=int, default=22050)
|
139 |
parser.add_argument("--n_fft", type=int, default=2048)
|
|
|
147 |
data = AudioDiffusionDataModule(
|
148 |
model_id=args.dataset_name,
|
149 |
batch_size=args.batch_size,
|
150 |
+
channels=config.model.params.ddconfig.in_channels,
|
151 |
+
)
|
152 |
lightning_config = config.pop("lightning", OmegaConf.create())
|
153 |
trainer_config = lightning_config.get("trainer", OmegaConf.create())
|
154 |
trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
|
|
|
158 |
max_epochs=args.max_epochs,
|
159 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
160 |
callbacks=[
|
161 |
+
ImageLogger(
|
162 |
+
every=args.save_images_batches,
|
163 |
+
hop_length=args.hop_length,
|
164 |
+
sample_rate=args.sample_rate,
|
165 |
+
n_fft=args.n_fft,
|
166 |
+
),
|
167 |
+
HFModelCheckpoint(
|
168 |
+
ldm_config=config,
|
169 |
+
hf_checkpoint=args.hf_checkpoint_dir,
|
170 |
+
dirpath=args.ldm_checkpoint_dir,
|
171 |
+
filename="{epoch:06}",
|
172 |
+
verbose=True,
|
173 |
+
save_last=True,
|
174 |
+
),
|
175 |
+
],
|
176 |
+
)
|
177 |
trainer.fit(model, data)
|