teticio committed
Commit f29faf1
1 Parent(s): af0157b

add conditional training

README.md CHANGED
@@ -23,7 +23,9 @@ Go to https://soundcloud.com/teticio2/sets/audio-diffusion-loops for more exampl
23
  ---
24
  #### Updates
25
 
26
- **5/12/2022** 🤗 Exciting news! `AudioDiffusionPipeline` has been migrated to the Hugging Face `diffusers` package so that it is even easier for others to use and contribute.
 
 
27
 
28
  **2/12/2022**. Added Mel to pipeline and updated the pretrained models to save Mel config (they are now no longer compatible with previous versions of this repo). It is relatively straightforward to migrate previously trained models to the new format (see https://huggingface.co/teticio/audio-diffusion-256).
29
 
@@ -58,7 +60,8 @@ You can play around with some pre-trained models on [Google Colab](https://colab
58
  | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/teticio/audio-diffusion-instrumental-hiphop-256) | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/datasets/teticio/audio-diffusion-instrumental-hiphop-256) | Instrumental Hip Hop music |
59
  | [teticio/audio-diffusion-ddim-256](https://huggingface.co/teticio/audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | De-noising Diffusion Implicit Model |
60
  | [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion model |
61
- | [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion De-noising Diffusion Implicit Model |
 
62
 
63
  ---
64
 
@@ -106,7 +109,7 @@ Note that the default `sample_rate` is 22050 and audios will be resampled if the
106
 
107
  ```bash
108
  accelerate launch --config_file config/accelerate_local.yaml \
109
- scripts/train_unconditional.py \
110
  --dataset_name data/audio-diffusion-64 \
111
  --hop_length 1024 \
112
  --output_dir models/ddpm-ema-audio-64 \
@@ -122,7 +125,7 @@ scripts/train_unconditional.py \
122
 
123
  ```bash
124
  accelerate launch --config_file config/accelerate_local.yaml \
125
- scripts/train_unconditional.py \
126
  --dataset_name teticio/audio-diffusion-256 \
127
  --output_dir models/audio-diffusion-256 \
128
  --num_epochs 100 \
@@ -141,7 +144,7 @@ scripts/train_unconditional.py \
141
 
142
  ```bash
143
  accelerate launch --config_file config/accelerate_sagemaker.yaml \
144
- scripts/train_unconditional.py \
145
  --dataset_name teticio/audio-diffusion-256 \
146
  --output_dir models/ddpm-ema-audio-256 \
147
  --train_batch_size 16 \
@@ -200,3 +203,33 @@ accelerate launch ...
200
  ...
201
  --vae models/autoencoder-kl
202
  ```
23
  ---
24
  #### Updates
25
 
26
+ **25/12/2022**. Now it is possible to train models conditional on an encoding (of text or audio, for example). See the section on Conditional Audio Generation below.
27
+
28
+ **5/12/2022**. 🤗 Exciting news! `AudioDiffusionPipeline` has been migrated to the Hugging Face `diffusers` package so that it is even easier for others to use and contribute.
29
 
30
  **2/12/2022**. Added Mel to pipeline and updated the pretrained models to save Mel config (they are now no longer compatible with previous versions of this repo). It is relatively straightforward to migrate previously trained models to the new format (see https://huggingface.co/teticio/audio-diffusion-256).
31
 
 
60
  | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/teticio/audio-diffusion-instrumental-hiphop-256) | [teticio/audio-diffusion-instrumental-hiphop-256](https://huggingface.co/datasets/teticio/audio-diffusion-instrumental-hiphop-256) | Instrumental Hip Hop music |
61
  | [teticio/audio-diffusion-ddim-256](https://huggingface.co/teticio/audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | De-noising Diffusion Implicit Model |
62
  | [teticio/latent-audio-diffusion-256](https://huggingface.co/teticio/latent-audio-diffusion-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion model |
63
+ | [teticio/latent-audio-diffusion-ddim-256](https://huggingface.co/teticio/latent-audio-diffusion-ddim-256) | [teticio/audio-diffusion-256](https://huggingface.co/datasets/teticio/audio-diffusion-256) | Latent Audio Diffusion Implicit Model |
64
+ | [teticio/conditional-latent-audio-diffusion-512](https://huggingface.co/teticio/latent-audio-diffusion-512) | [teticio/audio-diffusion-512](https://huggingface.co/datasets/teticio/audio-diffusion-512) | Conditional Latent Audio Diffusion Model |
65
 
66
  ---
67
 
 
109
 
110
  ```bash
111
  accelerate launch --config_file config/accelerate_local.yaml \
112
+ scripts/train_unet.py \
113
  --dataset_name data/audio-diffusion-64 \
114
  --hop_length 1024 \
115
  --output_dir models/ddpm-ema-audio-64 \
 
125
 
126
  ```bash
127
  accelerate launch --config_file config/accelerate_local.yaml \
128
+ scripts/train_unet.py \
129
  --dataset_name teticio/audio-diffusion-256 \
130
  --output_dir models/audio-diffusion-256 \
131
  --num_epochs 100 \
 
144
 
145
  ```bash
146
  accelerate launch --config_file config/accelerate_sagemaker.yaml \
147
+ scripts/train_unet.py \
148
  --dataset_name teticio/audio-diffusion-256 \
149
  --output_dir models/ddpm-ema-audio-256 \
150
  --train_batch_size 16 \
 
203
  ...
204
  --vae models/autoencoder-kl
205
  ```
206
+
207
+ ## Conditional Audio Generation
208
+
209
+ We can generate audio conditional on a text prompt - or indeed on anything that can be encoded into a vector of numbers - much like DALL-E 2 and Midjourney. Good-quality datasets of audio paired with descriptions are generally harder to find, although the people behind the dataset used to train Midjourney are making some very interesting progress [here](https://github.com/LAION-AI/audio-dataset). Instead, I have chosen to encode the audio directly, based on "how it sounds", using a [model which I trained on hundreds of thousands of Spotify playlists](https://github.com/teticio/Deej-AI). To encode an audio file into a 100-dimensional vector:
210
+
211
+ ```python
212
+ from diffusers import Mel
213
+ from audiodiffusion.audio_encoder import AudioEncoder
214
+
215
+ audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
216
+ audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])
217
+ ```
218
+
219
+ Once you have prepared a dataset, you can encode the audio files with this script:
220
+
221
+ ```bash
222
+ python scripts/encode_audio.py \
223
+ --dataset_name teticio/audio-diffusion-256 \
224
+ --out_file data/encodings.p
225
+ ```
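+
+ (Alternatively, here is a rough sketch of how such an encodings file could be put together by hand with the `AudioEncoder`; the local directory path is a placeholder, and the pickled dictionary keyed by file name is an assumption about the format expected by `--encodings`.)
+
+ ```python
+ import pickle
+ from pathlib import Path
+
+ from audiodiffusion.audio_encoder import AudioEncoder
+
+ audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
+
+ # Placeholder directory of audio files; point this at your own dataset.
+ audio_files = sorted(str(path) for path in Path("/path/to/audio").glob("*.mp3"))
+
+ # One 100-dimensional encoding per file, keyed by file name and pickled
+ # (the exact format expected by --encodings is assumed here).
+ encodings = {
+     audio_file: audio_encoder.encode([audio_file])[0]
+     for audio_file in audio_files
+ }
+ with open("data/encodings.p", "wb") as file:
+     pickle.dump(encodings, file)
+ ```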
226
+
227
+ Then you can train a model with:
228
+
229
+ ```bash
230
+ accelerate launch ...
231
+ ...
232
+ --encodings data/encodings.p
233
+ ```
234
+
235
+ When generating audios, you will need to pass an `encoding` tensor. See the [`conditional_generation.ipynb`](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/conditional_generation.ipynb) notebook for an example that uses encodings of Spotify track previews to influence the generation.
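+
+ Roughly, conditional generation with the bundled `AudioDiffusion` helper looks like the sketch below (the file path is a placeholder, and the `unsqueeze` and device handling simply reflect the `(batch_size, seq_length, cross_attention_dim)` shape documented for `UNet2DConditionModel`, rather than anything prescribed by the notebook):
+
+ ```python
+ from audiodiffusion import AudioDiffusion
+ from audiodiffusion.audio_encoder import AudioEncoder
+
+ audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
+ audio_diffusion = AudioDiffusion(model_id="teticio/conditional-latent-audio-diffusion-512")
+
+ # Encode a reference track into a (1, 100) tensor and add a sequence dimension.
+ encoding = audio_encoder.encode(["/path/to/reference/track.mp3"]).unsqueeze(1)
+ encoding = encoding.to(audio_diffusion.pipe.device)
+
+ image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio(encoding=encoding)
+ ```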
audiodiffusion/__init__.py CHANGED
@@ -1,21 +1,24 @@
1
  from typing import Iterable, Tuple
2
 
3
- import torch
4
  import numpy as np
 
 
5
  from PIL import Image
6
  from tqdm.auto import tqdm
7
- from librosa.beat import beat_track
8
- from diffusers import AudioDiffusionPipeline
9
 
10
- VERSION = "1.3.2"
 
11
 
 
12
 
13
- class AudioDiffusion:
14
 
15
- def __init__(self,
16
- model_id: str = "teticio/audio-diffusion-256",
17
- cuda: bool = torch.cuda.is_available(),
18
- progress_bar: Iterable = tqdm):
 
 
 
19
  """Class for generating audio using De-noising Diffusion Probabilistic Models.
20
 
21
  Args:
@@ -35,7 +38,8 @@ class AudioDiffusion:
35
  generator: torch.Generator = None,
36
  step_generator: torch.Generator = None,
37
  eta: float = 0,
38
- noise: torch.Tensor = None
 
39
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
40
  """Generate random mel spectrogram and convert to audio.
41
 
@@ -45,19 +49,22 @@ class AudioDiffusion:
45
  step_generator (torch.Generator): random number generator used to de-noise or None
46
  eta (float): parameter between 0 and 1 used with DDIM scheduler
47
  noise (torch.Tensor): noisy image or None
 
48
 
49
  Returns:
50
  PIL Image: mel spectrogram
51
  (float, np.ndarray): sample rate and raw audio
52
  """
53
- images, (sample_rate,
54
- audios) = self.pipe(batch_size=1,
55
- steps=steps,
56
- generator=generator,
57
- step_generator=step_generator,
58
- eta=eta,
59
- noise=noise,
60
- return_dict=False)
 
 
61
  return images[0], (sample_rate, audios[0])
62
 
63
  def generate_spectrogram_and_audio_from_audio(
@@ -72,7 +79,8 @@ class AudioDiffusion:
72
  mask_end_secs: float = 0,
73
  step_generator: torch.Generator = None,
74
  eta: float = 0,
75
- noise: torch.Tensor = None
 
76
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
77
  """Generate random mel spectrogram from audio input and convert to audio.
78
 
@@ -87,6 +95,7 @@ class AudioDiffusion:
87
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
88
  step_generator (torch.Generator): random number generator used to de-noise or None
89
  eta (float): parameter between 0 and 1 used with DDIM scheduler
 
90
  noise (torch.Tensor): noisy image or None
91
 
92
  Returns:
@@ -94,26 +103,26 @@ class AudioDiffusion:
94
  (float, np.ndarray): sample rate and raw audio
95
  """
96
 
97
- images, (sample_rate,
98
- audios) = self.pipe(batch_size=1,
99
- audio_file=audio_file,
100
- raw_audio=raw_audio,
101
- slice=slice,
102
- start_step=start_step,
103
- steps=steps,
104
- generator=generator,
105
- mask_start_secs=mask_start_secs,
106
- mask_end_secs=mask_end_secs,
107
- step_generator=step_generator,
108
- eta=eta,
109
- noise=noise,
110
- return_dict=False)
 
 
111
  return images[0], (sample_rate, audios[0])
112
 
113
  @staticmethod
114
- def loop_it(audio: np.ndarray,
115
- sample_rate: int,
116
- loops: int = 12) -> np.ndarray:
117
  """Loop audio
118
 
119
  Args:
@@ -124,403 +133,8 @@ class AudioDiffusion:
124
  Returns:
125
  (float, np.ndarray): sample rate and raw audio or None
126
  """
127
- _, beats = beat_track(y=audio, sr=sample_rate, units='samples')
128
- for beats_in_bar in [16, 12, 8, 4]:
129
- if len(beats) > beats_in_bar:
130
- return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)
131
  return None
132
-
133
-
134
- '''
135
- # This code will be migrated to diffusers shortly
136
-
137
- #-----------------------------------------------------------------------------#
138
-
139
- import os
140
- import warnings
141
- from typing import Any, Dict, Optional, Union
142
-
143
- from diffusers.configuration_utils import ConfigMixin, register_to_config
144
- from diffusers.schedulers.scheduling_utils import SchedulerMixin
145
-
146
-
147
- warnings.filterwarnings("ignore")
148
-
149
- import numpy as np # noqa: E402
150
-
151
- import librosa # noqa: E402
152
- from PIL import Image # noqa: E402
153
-
154
-
155
- class Mel(ConfigMixin, SchedulerMixin):
156
- """
157
- Parameters:
158
- x_res (`int`): x resolution of spectrogram (time)
159
- y_res (`int`): y resolution of spectrogram (frequency bins)
160
- sample_rate (`int`): sample rate of audio
161
- n_fft (`int`): number of Fast Fourier Transforms
162
- hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
163
- top_db (`int`): loudest in decibels
164
- n_iter (`int`): number of iterations for Griffin Linn mel inversion
165
- """
166
-
167
- config_name = "mel_config.json"
168
-
169
- @register_to_config
170
- def __init__(
171
- self,
172
- x_res: int = 256,
173
- y_res: int = 256,
174
- sample_rate: int = 22050,
175
- n_fft: int = 2048,
176
- hop_length: int = 512,
177
- top_db: int = 80,
178
- n_iter: int = 32,
179
- ):
180
- self.hop_length = hop_length
181
- self.sr = sample_rate
182
- self.n_fft = n_fft
183
- self.top_db = top_db
184
- self.n_iter = n_iter
185
- self.set_resolution(x_res, y_res)
186
- self.audio = None
187
-
188
- def set_resolution(self, x_res: int, y_res: int):
189
- """Set resolution.
190
-
191
- Args:
192
- x_res (`int`): x resolution of spectrogram (time)
193
- y_res (`int`): y resolution of spectrogram (frequency bins)
194
- """
195
- self.x_res = x_res
196
- self.y_res = y_res
197
- self.n_mels = self.y_res
198
- self.slice_size = self.x_res * self.hop_length - 1
199
-
200
- def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
201
- """Load audio.
202
-
203
- Args:
204
- audio_file (`str`): must be a file on disk due to Librosa limitation or
205
- raw_audio (`np.ndarray`): audio as numpy array
206
- """
207
- if audio_file is not None:
208
- self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
209
- else:
210
- self.audio = raw_audio
211
-
212
- # Pad with silence if necessary.
213
- if len(self.audio) < self.x_res * self.hop_length:
214
- self.audio = np.concatenate([self.audio, np.zeros((self.x_res * self.hop_length - len(self.audio),))])
215
-
216
- def get_number_of_slices(self) -> int:
217
- """Get number of slices in audio.
218
-
219
- Returns:
220
- `int`: number of spectograms audio can be sliced into
221
- """
222
- return len(self.audio) // self.slice_size
223
-
224
- def get_audio_slice(self, slice: int = 0) -> np.ndarray:
225
- """Get slice of audio.
226
-
227
- Args:
228
- slice (`int`): slice number of audio (out of get_number_of_slices())
229
-
230
- Returns:
231
- `np.ndarray`: audio as numpy array
232
- """
233
- return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
234
-
235
- def get_sample_rate(self) -> int:
236
- """Get sample rate:
237
-
238
- Returns:
239
- `int`: sample rate of audio
240
- """
241
- return self.sr
242
-
243
- def audio_slice_to_image(self, slice: int) -> Image.Image:
244
- """Convert slice of audio to spectrogram.
245
-
246
- Args:
247
- slice (`int`): slice number of audio to convert (out of get_number_of_slices())
248
-
249
- Returns:
250
- `PIL Image`: grayscale image of x_res x y_res
251
- """
252
- S = librosa.feature.melspectrogram(
253
- y=self.get_audio_slice(slice), sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels
254
- )
255
- log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
256
- bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
257
- image = Image.fromarray(bytedata)
258
- return image
259
-
260
- def image_to_audio(self, image: Image.Image) -> np.ndarray:
261
- """Converts spectrogram to audio.
262
-
263
- Args:
264
- image (`PIL Image`): x_res x y_res grayscale image
265
-
266
- Returns:
267
- audio (`np.ndarray`): raw audio
268
- """
269
- bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
270
- log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
271
- S = librosa.db_to_power(log_S)
272
- audio = librosa.feature.inverse.mel_to_audio(
273
- S, sr=self.sr, n_fft=self.n_fft, hop_length=self.hop_length, n_iter=self.n_iter
274
- )
275
- return audio
276
-
277
- #-----------------------------------------------------------------------------#
278
-
279
- from math import acos, sin
280
- from typing import List, Tuple, Union
281
-
282
- import numpy as np
283
- import torch
284
-
285
- from PIL import Image
286
-
287
- from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline, DDIMScheduler, DDPMScheduler
288
- from diffusers.pipeline_utils import AudioPipelineOutput, BaseOutput, ImagePipelineOutput
289
-
290
-
291
- class AudioDiffusionPipeline(DiffusionPipeline):
292
- """
293
- This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
294
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
295
-
296
- Parameters:
297
- vqae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
298
- unet ([`UNet2DConditionModel`]): UNET model
299
- mel ([`Mel`]): transform audio <-> spectrogram
300
- scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
301
- """
302
-
303
- _optional_components = ["vqvae"]
304
-
305
- def __init__(
306
- self,
307
- vqvae: AutoencoderKL,
308
- unet: UNet2DConditionModel,
309
- mel: Mel,
310
- scheduler: Union[DDIMScheduler, DDPMScheduler],
311
- ):
312
- super().__init__()
313
- self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
314
-
315
- def get_input_dims(self) -> Tuple:
316
- """Returns dimension of input image
317
-
318
- Returns:
319
- `Tuple`: (height, width)
320
- """
321
- input_module = self.vqvae if self.vqvae is not None else self.unet
322
- # For backwards compatibility
323
- sample_size = (
324
- (input_module.sample_size, input_module.sample_size)
325
- if type(input_module.sample_size) == int
326
- else input_module.sample_size
327
- )
328
- return sample_size
329
-
330
- def get_default_steps(self) -> int:
331
- """Returns default number of steps recommended for inference
332
-
333
- Returns:
334
- `int`: number of steps
335
- """
336
- return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
337
-
338
- @torch.no_grad()
339
- def __call__(
340
- self,
341
- batch_size: int = 1,
342
- audio_file: str = None,
343
- raw_audio: np.ndarray = None,
344
- slice: int = 0,
345
- start_step: int = 0,
346
- steps: int = None,
347
- generator: torch.Generator = None,
348
- mask_start_secs: float = 0,
349
- mask_end_secs: float = 0,
350
- step_generator: torch.Generator = None,
351
- eta: float = 0,
352
- noise: torch.Tensor = None,
353
- return_dict=True,
354
- ) -> Union[
355
- Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]
356
- ]:
357
- """Generate random mel spectrogram from audio input and convert to audio.
358
-
359
- Args:
360
- batch_size (`int`): number of samples to generate
361
- audio_file (`str`): must be a file on disk due to Librosa limitation or
362
- raw_audio (`np.ndarray`): audio as numpy array
363
- slice (`int`): slice number of audio to convert
364
- start_step (int): step to start from
365
- steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
366
- generator (`torch.Generator`): random number generator or None
367
- mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
368
- mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
369
- step_generator (`torch.Generator`): random number generator used to de-noise or None
370
- eta (`float`): parameter between 0 and 1 used with DDIM scheduler
371
- noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
372
- return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
373
-
374
- Returns:
375
- `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios
376
- """
377
-
378
- steps = steps or self.get_default_steps()
379
- self.scheduler.set_timesteps(steps)
380
- step_generator = step_generator or generator
381
- # For backwards compatibility
382
- if type(self.unet.sample_size) == int:
383
- self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
384
- input_dims = self.get_input_dims()
385
- self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
386
- if noise is None:
387
- noise = torch.randn(
388
- (batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
389
- generator=generator,
390
- device=self.device,
391
- )
392
- images = noise
393
- mask = None
394
-
395
- if audio_file is not None or raw_audio is not None:
396
- self.mel.load_audio(audio_file, raw_audio)
397
- input_image = self.mel.audio_slice_to_image(slice)
398
- input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
399
- (input_image.height, input_image.width)
400
- )
401
- input_image = (input_image / 255) * 2 - 1
402
- input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
403
-
404
- if self.vqvae is not None:
405
- input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
406
- generator=generator
407
- )[0]
408
- input_images = 0.18215 * input_images
409
-
410
- if start_step > 0:
411
- images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
412
-
413
- pixels_per_second = (
414
- self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
415
- )
416
- mask_start = int(mask_start_secs * pixels_per_second)
417
- mask_end = int(mask_end_secs * pixels_per_second)
418
- mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
419
-
420
- for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
421
- model_output = self.unet(images, t)["sample"]
422
-
423
- if isinstance(self.scheduler, DDIMScheduler):
424
- images = self.scheduler.step(
425
- model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator
426
- )["prev_sample"]
427
- else:
428
- images = self.scheduler.step(
429
- model_output=model_output, timestep=t, sample=images, generator=step_generator
430
- )["prev_sample"]
431
-
432
- if mask is not None:
433
- if mask_start > 0:
434
- images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
435
- if mask_end > 0:
436
- images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
437
-
438
- if self.vqvae is not None:
439
- # 0.18215 was scaling factor used in training to ensure unit variance
440
- images = 1 / 0.18215 * images
441
- images = self.vqvae.decode(images)["sample"]
442
-
443
- images = (images / 2 + 0.5).clamp(0, 1)
444
- images = images.cpu().permute(0, 2, 3, 1).numpy()
445
- images = (images * 255).round().astype("uint8")
446
- images = list(
447
- map(lambda _: Image.fromarray(_[:, :, 0]), images)
448
- if images.shape[3] == 1
449
- else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
450
- )
451
-
452
- audios = list(map(lambda _: self.mel.image_to_audio(_), images))
453
- if not return_dict:
454
- return images, (self.mel.get_sample_rate(), audios)
455
-
456
- return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
457
-
458
- @torch.no_grad()
459
- def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
460
- """Reverse step process: recover noisy image from generated image.
461
-
462
- Args:
463
- images (`List[PIL Image]`): list of images to encode
464
- steps (`int`): number of encoding steps to perform (defaults to 50)
465
-
466
- Returns:
467
- `np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
468
- """
469
-
470
- # Only works with DDIM as this method is deterministic
471
- assert isinstance(self.scheduler, DDIMScheduler)
472
- self.scheduler.set_timesteps(steps)
473
- sample = np.array(
474
- [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
475
- )
476
- sample = (sample / 255) * 2 - 1
477
- sample = torch.Tensor(sample).to(self.device)
478
-
479
- for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
480
- prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
481
- alpha_prod_t = self.scheduler.alphas_cumprod[t]
482
- alpha_prod_t_prev = (
483
- self.scheduler.alphas_cumprod[prev_timestep]
484
- if prev_timestep >= 0
485
- else self.scheduler.final_alpha_cumprod
486
- )
487
- beta_prod_t = 1 - alpha_prod_t
488
- model_output = self.unet(sample, t)["sample"]
489
- pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
490
- sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
491
- sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output
492
-
493
- return sample
494
-
495
- @staticmethod
496
- def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
497
- """Spherical Linear intERPolation
498
-
499
- Args:
500
- x0 (`torch.Tensor`): first tensor to interpolate between
501
- x1 (`torch.Tensor`): seconds tensor to interpolate between
502
- alpha (`float`): interpolation between 0 and 1
503
-
504
- Returns:
505
- `torch.Tensor`: interpolated tensor
506
- """
507
-
508
- theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
509
- return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
510
-
511
-
512
- import sys
513
- import diffusers
514
-
515
- class audio_diffusion():
516
- __name__ = 'audio_diffusion'
517
- pass
518
-
519
-
520
- sys.modules['audio_diffusion'] = audio_diffusion
521
- setattr(audio_diffusion, Mel.__name__, Mel)
522
- diffusers.AudioDiffusionPipeline = AudioDiffusionPipeline
523
- setattr(diffusers, AudioDiffusionPipeline.__name__, AudioDiffusionPipeline)
524
- diffusers.pipeline_utils.LOADABLE_CLASSES['audio_diffusion'] = {}
525
- diffusers.pipeline_utils.LOADABLE_CLASSES['audio_diffusion']['Mel'] = ["save_pretrained", "from_pretrained"]
526
- '''
 
1
  from typing import Iterable, Tuple
2
 
 
3
  import numpy as np
4
+ import torch
5
+ from librosa.beat import beat_track
6
  from PIL import Image
7
  from tqdm.auto import tqdm
 
 
8
 
9
+ # from diffusers import AudioDiffusionPipeline
10
+ from .pipeline_audio_diffusion import AudioDiffusionPipeline
11
 
12
+ VERSION = "1.4.0"
13
 
 
14
 
15
+ class AudioDiffusion:
16
+ def __init__(
17
+ self,
18
+ model_id: str = "teticio/audio-diffusion-256",
19
+ cuda: bool = torch.cuda.is_available(),
20
+ progress_bar: Iterable = tqdm,
21
+ ):
22
  """Class for generating audio using De-noising Diffusion Probabilistic Models.
23
 
24
  Args:
 
38
  generator: torch.Generator = None,
39
  step_generator: torch.Generator = None,
40
  eta: float = 0,
41
+ noise: torch.Tensor = None,
42
+ encoding: torch.Tensor = None,
43
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
44
  """Generate random mel spectrogram and convert to audio.
45
 
 
49
  step_generator (torch.Generator): random number generator used to de-noise or None
50
  eta (float): parameter between 0 and 1 used with DDIM scheduler
51
  noise (torch.Tensor): noisy image or None
52
+ encoding (`torch.Tensor`): conditioning tensor of shape (batch_size, seq_length, cross_attention_dim) for UNet2DConditionModel
53
 
54
  Returns:
55
  PIL Image: mel spectrogram
56
  (float, np.ndarray): sample rate and raw audio
57
  """
58
+ images, (sample_rate, audios) = self.pipe(
59
+ batch_size=1,
60
+ steps=steps,
61
+ generator=generator,
62
+ step_generator=step_generator,
63
+ eta=eta,
64
+ noise=noise,
65
+ encoding=encoding,
66
+ return_dict=False,
67
+ )
68
  return images[0], (sample_rate, audios[0])
69
 
70
  def generate_spectrogram_and_audio_from_audio(
 
79
  mask_end_secs: float = 0,
80
  step_generator: torch.Generator = None,
81
  eta: float = 0,
82
+ encoding: torch.Tensor = None,
83
+ noise: torch.Tensor = None,
84
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
85
  """Generate random mel spectrogram from audio input and convert to audio.
86
 
 
95
  mask_end_secs (float): number of seconds of audio to mask (not generate) at end
96
  step_generator (torch.Generator): random number generator used to de-noise or None
97
  eta (float): parameter between 0 and 1 used with DDIM scheduler
98
+ encoding (`torch.Tensor`): conditioning tensor of shape (batch_size, seq_length, cross_attention_dim) for UNet2DConditionModel
99
  noise (torch.Tensor): noisy image or None
100
 
101
  Returns:
 
103
  (float, np.ndarray): sample rate and raw audio
104
  """
105
 
106
+ images, (sample_rate, audios) = self.pipe(
107
+ batch_size=1,
108
+ audio_file=audio_file,
109
+ raw_audio=raw_audio,
110
+ slice=slice,
111
+ start_step=start_step,
112
+ steps=steps,
113
+ generator=generator,
114
+ mask_start_secs=mask_start_secs,
115
+ mask_end_secs=mask_end_secs,
116
+ step_generator=step_generator,
117
+ eta=eta,
118
+ noise=noise,
119
+ encoding=encoding,
120
+ return_dict=False,
121
+ )
122
  return images[0], (sample_rate, audios[0])
123
 
124
  @staticmethod
125
+ def loop_it(audio: np.ndarray, sample_rate: int, loops: int = 12) -> np.ndarray:
 
 
126
  """Loop audio
127
 
128
  Args:
 
133
  Returns:
134
  (float, np.ndarray): sample rate and raw audio or None
135
  """
136
+ _, beats = beat_track(y=audio, sr=sample_rate, units="samples")
137
+ beats_in_bar = (len(beats) - 1) // 4 * 4
138
+ if beats_in_bar > 0:
139
+ return np.tile(audio[beats[0] : beats[beats_in_bar]], loops)
140
  return None
audiodiffusion/audio_encoder.py ADDED
@@ -0,0 +1,107 @@
1
+ import numpy as np
2
+ import torch
3
+ from diffusers import ConfigMixin, Mel, ModelMixin
4
+ from torch import nn
5
+
6
+
7
+ class SeparableConv2d(nn.Module):
8
+ def __init__(self, in_channels, out_channels, kernel_size):
9
+ super(SeparableConv2d, self).__init__()
10
+ self.depthwise = nn.Conv2d(
11
+ in_channels,
12
+ in_channels,
13
+ kernel_size=kernel_size,
14
+ groups=in_channels,
15
+ bias=False,
16
+ padding=1,
17
+ )
18
+ self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
19
+
20
+ def forward(self, x):
21
+ out = self.depthwise(x)
22
+ out = self.pointwise(out)
23
+ return out
24
+
25
+
26
+ class ConvBlock(nn.Module):
27
+ def __init__(self, in_channels, out_channels, dropout_rate):
28
+ super(ConvBlock, self).__init__()
29
+ self.sep_conv = SeparableConv2d(in_channels, out_channels, (3, 3))
30
+ self.leaky_relu = nn.LeakyReLU(0.2)
31
+ self.batch_norm = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.01)
32
+ self.max_pool = nn.MaxPool2d((2, 2))
33
+ self.dropout = nn.Dropout(dropout_rate)
34
+
35
+ def forward(self, x):
36
+ x = self.sep_conv(x)
37
+ x = self.leaky_relu(x)
38
+ x = self.batch_norm(x)
39
+ x = self.max_pool(x)
40
+ x = self.dropout(x)
41
+ return x
42
+
43
+
44
+ class DenseBlock(nn.Module):
45
+ def __init__(self, in_features, out_features, dropout_rate):
46
+ super(DenseBlock, self).__init__()
47
+ self.flatten = nn.Flatten()
48
+ self.dense = nn.Linear(in_features, out_features)
49
+ self.leaky_relu = nn.LeakyReLU(0.2)
50
+ self.batch_norm = nn.BatchNorm1d(out_features, eps=0.001, momentum=0.01)
51
+ self.dropout = nn.Dropout(dropout_rate)
52
+
53
+ def forward(self, x):
54
+ x = self.flatten(x.permute(0, 2, 3, 1))
55
+ x = self.dense(x)
56
+ x = self.leaky_relu(x)
57
+ x = self.batch_norm(x)
58
+ x = self.dropout(x)
59
+ return x
60
+
61
+
62
+ class AudioEncoder(ModelMixin, ConfigMixin):
63
+ def __init__(self):
64
+ super().__init__()
65
+ self.mel = Mel(
66
+ x_res=216,
67
+ y_res=96,
68
+ sample_rate=22050,
69
+ n_fft=2048,
70
+ hop_length=512,
71
+ top_db=80,
72
+ )
73
+ self.conv_blocks = nn.ModuleList([ConvBlock(1, 32, 0.2), ConvBlock(32, 64, 0.3), ConvBlock(64, 128, 0.4)])
74
+ self.dense_block = DenseBlock(41472, 1024, 0.5)
75
+ self.embedding = nn.Linear(1024, 100)
76
+
77
+ def forward(self, x):
78
+ for conv_block in self.conv_blocks:
79
+ x = conv_block(x)
80
+ x = self.dense_block(x)
81
+ x = self.embedding(x)
82
+ return x
83
+
84
+ @torch.no_grad()
85
+ def encode(self, audio_files):
86
+ self.eval()
87
+ y = []
88
+ for audio_file in audio_files:
89
+ self.mel.load_audio(audio_file)
90
+ x = [
91
+ np.expand_dims(
92
+ np.frombuffer(self.mel.audio_slice_to_image(slice).tobytes(), dtype="uint8").reshape(
93
+ (self.mel.y_res, self.mel.x_res)
94
+ )
95
+ / 255,
96
+ axis=0,
97
+ )
98
+ for slice in range(self.mel.get_number_of_slices())
99
+ ]
100
+ y += [torch.mean(self(torch.Tensor(x)), dim=0)]
101
+ return torch.stack(y)
102
+
103
+
104
+ # from diffusers import Mel
105
+ # from audiodiffusion.audio_encoder import AudioEncoder
106
+ # audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
107
+ # audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])
audiodiffusion/mel.py ADDED
@@ -0,0 +1,164 @@
1
+ # This code has been migrated to diffusers but can be run locally with
2
+ # pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256", custom_pipeline="audio-diffusion/audiodiffusion/pipeline_audio_diffusion.py")
3
+
4
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ import warnings
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
23
+
24
+ warnings.filterwarnings("ignore")
25
+
26
+ import librosa # noqa: E402
27
+ import numpy as np # noqa: E402
28
+ from PIL import Image # noqa: E402
29
+
30
+
31
+ class Mel(ConfigMixin, SchedulerMixin):
32
+ """
33
+ Parameters:
34
+ x_res (`int`): x resolution of spectrogram (time)
35
+ y_res (`int`): y resolution of spectrogram (frequency bins)
36
+ sample_rate (`int`): sample rate of audio
37
+ n_fft (`int`): number of Fast Fourier Transforms
38
+ hop_length (`int`): hop length (a higher number is recommended for lower than 256 y_res)
39
+ top_db (`int`): loudest in decibels
40
+ n_iter (`int`): number of iterations for Griffin-Lim mel inversion
41
+ """
42
+
43
+ config_name = "mel_config.json"
44
+
45
+ @register_to_config
46
+ def __init__(
47
+ self,
48
+ x_res: int = 256,
49
+ y_res: int = 256,
50
+ sample_rate: int = 22050,
51
+ n_fft: int = 2048,
52
+ hop_length: int = 512,
53
+ top_db: int = 80,
54
+ n_iter: int = 32,
55
+ ):
56
+ self.hop_length = hop_length
57
+ self.sr = sample_rate
58
+ self.n_fft = n_fft
59
+ self.top_db = top_db
60
+ self.n_iter = n_iter
61
+ self.set_resolution(x_res, y_res)
62
+ self.audio = None
63
+
64
+ def set_resolution(self, x_res: int, y_res: int):
65
+ """Set resolution.
66
+
67
+ Args:
68
+ x_res (`int`): x resolution of spectrogram (time)
69
+ y_res (`int`): y resolution of spectrogram (frequency bins)
70
+ """
71
+ self.x_res = x_res
72
+ self.y_res = y_res
73
+ self.n_mels = self.y_res
74
+ self.slice_size = self.x_res * self.hop_length - 1
75
+
76
+ def load_audio(self, audio_file: str = None, raw_audio: np.ndarray = None):
77
+ """Load audio.
78
+
79
+ Args:
80
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
81
+ raw_audio (`np.ndarray`): audio as numpy array
82
+ """
83
+ if audio_file is not None:
84
+ self.audio, _ = librosa.load(audio_file, mono=True, sr=self.sr)
85
+ else:
86
+ self.audio = raw_audio
87
+
88
+ # Pad with silence if necessary.
89
+ if len(self.audio) < self.x_res * self.hop_length:
90
+ self.audio = np.concatenate(
91
+ [
92
+ self.audio,
93
+ np.zeros((self.x_res * self.hop_length - len(self.audio),)),
94
+ ]
95
+ )
96
+
97
+ def get_number_of_slices(self) -> int:
98
+ """Get number of slices in audio.
99
+
100
+ Returns:
101
+ `int`: number of spectrograms the audio can be sliced into
102
+ """
103
+ return len(self.audio) // self.slice_size
104
+
105
+ def get_audio_slice(self, slice: int = 0) -> np.ndarray:
106
+ """Get slice of audio.
107
+
108
+ Args:
109
+ slice (`int`): slice number of audio (out of get_number_of_slices())
110
+
111
+ Returns:
112
+ `np.ndarray`: audio as numpy array
113
+ """
114
+ return self.audio[self.slice_size * slice : self.slice_size * (slice + 1)]
115
+
116
+ def get_sample_rate(self) -> int:
117
+ """Get sample rate:
118
+
119
+ Returns:
120
+ `int`: sample rate of audio
121
+ """
122
+ return self.sr
123
+
124
+ def audio_slice_to_image(self, slice: int) -> Image.Image:
125
+ """Convert slice of audio to spectrogram.
126
+
127
+ Args:
128
+ slice (`int`): slice number of audio to convert (out of get_number_of_slices())
129
+
130
+ Returns:
131
+ `PIL Image`: grayscale image of x_res x y_res
132
+ """
133
+ S = librosa.feature.melspectrogram(
134
+ y=self.get_audio_slice(slice),
135
+ sr=self.sr,
136
+ n_fft=self.n_fft,
137
+ hop_length=self.hop_length,
138
+ n_mels=self.n_mels,
139
+ )
140
+ log_S = librosa.power_to_db(S, ref=np.max, top_db=self.top_db)
141
+ bytedata = (((log_S + self.top_db) * 255 / self.top_db).clip(0, 255) + 0.5).astype(np.uint8)
142
+ image = Image.fromarray(bytedata)
143
+ return image
144
+
145
+ def image_to_audio(self, image: Image.Image) -> np.ndarray:
146
+ """Converts spectrogram to audio.
147
+
148
+ Args:
149
+ image (`PIL Image`): x_res x y_res grayscale image
150
+
151
+ Returns:
152
+ audio (`np.ndarray`): raw audio
153
+ """
154
+ bytedata = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width))
155
+ log_S = bytedata.astype("float") * self.top_db / 255 - self.top_db
156
+ S = librosa.db_to_power(log_S)
157
+ audio = librosa.feature.inverse.mel_to_audio(
158
+ S,
159
+ sr=self.sr,
160
+ n_fft=self.n_fft,
161
+ hop_length=self.hop_length,
162
+ n_iter=self.n_iter,
163
+ )
164
+ return audio
audiodiffusion/pipeline_audio_diffusion.py ADDED
@@ -0,0 +1,267 @@
1
+ # This code has been migrated to diffusers but can be run locally with
2
+ # pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256", custom_pipeline="audio-diffusion/audiodiffusion/pipeline_audio_diffusion.py")
3
+
4
+ # Copyright 2022 The HuggingFace Team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+
19
+ from math import acos, sin
20
+ from typing import List, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, Mel, UNet2DConditionModel
25
+ from diffusers.pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
26
+ from PIL import Image
27
+
28
+ from .mel import Mel
29
+
30
+
31
+ class AudioDiffusionPipeline(DiffusionPipeline):
32
+ """
33
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
34
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
35
+
36
+ Parameters:
37
+ vqvae ([`AutoencoderKL`]): Variational AutoEncoder for Latent Audio Diffusion or None
38
+ unet ([`UNet2DConditionModel`]): UNET model
39
+ mel ([`Mel`]): transform audio <-> spectrogram
40
+ scheduler ([`DDIMScheduler` or `DDPMScheduler`]): de-noising scheduler
41
+ """
42
+
43
+ _optional_components = ["vqvae"]
44
+
45
+ def __init__(
46
+ self,
47
+ vqvae: AutoencoderKL,
48
+ unet: UNet2DConditionModel,
49
+ mel: Mel,
50
+ scheduler: Union[DDIMScheduler, DDPMScheduler],
51
+ ):
52
+ super().__init__()
53
+ self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae)
54
+
55
+ def get_input_dims(self) -> Tuple:
56
+ """Returns dimension of input image
57
+
58
+ Returns:
59
+ `Tuple`: (height, width)
60
+ """
61
+ input_module = self.vqvae if self.vqvae is not None else self.unet
62
+ # For backwards compatibility
63
+ sample_size = (
64
+ (input_module.sample_size, input_module.sample_size)
65
+ if type(input_module.sample_size) == int
66
+ else input_module.sample_size
67
+ )
68
+ return sample_size
69
+
70
+ def get_default_steps(self) -> int:
71
+ """Returns default number of steps recommended for inference
72
+
73
+ Returns:
74
+ `int`: number of steps
75
+ """
76
+ return 50 if isinstance(self.scheduler, DDIMScheduler) else 1000
77
+
78
+ @torch.no_grad()
79
+ def __call__(
80
+ self,
81
+ batch_size: int = 1,
82
+ audio_file: str = None,
83
+ raw_audio: np.ndarray = None,
84
+ slice: int = 0,
85
+ start_step: int = 0,
86
+ steps: int = None,
87
+ generator: torch.Generator = None,
88
+ mask_start_secs: float = 0,
89
+ mask_end_secs: float = 0,
90
+ step_generator: torch.Generator = None,
91
+ eta: float = 0,
92
+ noise: torch.Tensor = None,
93
+ encoding: torch.Tensor = None,
94
+ return_dict=True,
95
+ ) -> Union[
96
+ Union[AudioPipelineOutput, ImagePipelineOutput],
97
+ Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
98
+ ]:
99
+ """Generate random mel spectrogram from audio input and convert to audio.
100
+
101
+ Args:
102
+ batch_size (`int`): number of samples to generate
103
+ audio_file (`str`): must be a file on disk due to Librosa limitation or
104
+ raw_audio (`np.ndarray`): audio as numpy array
105
+ slice (`int`): slice number of audio to convert
106
+ start_step (int): step to start from
107
+ steps (`int`): number of de-noising steps (defaults to 50 for DDIM, 1000 for DDPM)
108
+ generator (`torch.Generator`): random number generator or None
109
+ mask_start_secs (`float`): number of seconds of audio to mask (not generate) at start
110
+ mask_end_secs (`float`): number of seconds of audio to mask (not generate) at end
111
+ step_generator (`torch.Generator`): random number generator used to de-noise or None
112
+ eta (`float`): parameter between 0 and 1 used with DDIM scheduler
113
+ noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
114
+ encoding (`torch.Tensor`): conditioning tensor of shape (batch_size, seq_length, cross_attention_dim) for UNet2DConditionModel
115
+ return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
116
+
117
+ Returns:
118
+ `List[PIL Image]`: mel spectrograms (`float`, `List[np.ndarray]`): sample rate and raw audios
119
+ """
120
+
121
+ steps = steps or self.get_default_steps()
122
+ self.scheduler.set_timesteps(steps)
123
+ step_generator = step_generator or generator
124
+ # For backwards compatibility
125
+ if type(self.unet.sample_size) == int:
126
+ self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
127
+ input_dims = self.get_input_dims()
128
+ self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
129
+ if noise is None:
130
+ noise = torch.randn(
131
+ (
132
+ batch_size,
133
+ self.unet.in_channels,
134
+ self.unet.sample_size[0],
135
+ self.unet.sample_size[1],
136
+ ),
137
+ generator=generator,
138
+ device=self.device,
139
+ )
140
+ images = noise
141
+ mask = None
142
+
143
+ if audio_file is not None or raw_audio is not None:
144
+ self.mel.load_audio(audio_file, raw_audio)
145
+ input_image = self.mel.audio_slice_to_image(slice)
146
+ input_image = np.frombuffer(input_image.tobytes(), dtype="uint8").reshape(
147
+ (input_image.height, input_image.width)
148
+ )
149
+ input_image = (input_image / 255) * 2 - 1
150
+ input_images = torch.tensor(input_image[np.newaxis, :, :], dtype=torch.float).to(self.device)
151
+
152
+ if self.vqvae is not None:
153
+ input_images = self.vqvae.encode(torch.unsqueeze(input_images, 0)).latent_dist.sample(
154
+ generator=generator
155
+ )[0]
156
+ input_images = 0.18215 * input_images
157
+
158
+ if start_step > 0:
159
+ images[0, 0] = self.scheduler.add_noise(input_images, noise, self.scheduler.timesteps[start_step - 1])
160
+
161
+ pixels_per_second = (
162
+ self.unet.sample_size[1] * self.mel.get_sample_rate() / self.mel.x_res / self.mel.hop_length
163
+ )
164
+ mask_start = int(mask_start_secs * pixels_per_second)
165
+ mask_end = int(mask_end_secs * pixels_per_second)
166
+ mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
167
+
168
+ for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
169
+ if isinstance(self.unet, UNet2DConditionModel):
170
+ model_output = self.unet(images, t, encoding)["sample"]
171
+ else:
172
+ model_output = self.unet(images, t)["sample"]
173
+
174
+ if isinstance(self.scheduler, DDIMScheduler):
175
+ images = self.scheduler.step(
176
+ model_output=model_output,
177
+ timestep=t,
178
+ sample=images,
179
+ eta=eta,
180
+ generator=step_generator,
181
+ )["prev_sample"]
182
+ else:
183
+ images = self.scheduler.step(
184
+ model_output=model_output,
185
+ timestep=t,
186
+ sample=images,
187
+ generator=step_generator,
188
+ )["prev_sample"]
189
+
190
+ if mask is not None:
191
+ if mask_start > 0:
192
+ images[:, :, :, :mask_start] = mask[:, step, :, :mask_start]
193
+ if mask_end > 0:
194
+ images[:, :, :, -mask_end:] = mask[:, step, :, -mask_end:]
195
+
196
+ if self.vqvae is not None:
197
+ # 0.18215 was scaling factor used in training to ensure unit variance
198
+ images = 1 / 0.18215 * images
199
+ images = self.vqvae.decode(images)["sample"]
200
+
201
+ images = (images / 2 + 0.5).clamp(0, 1)
202
+ images = images.cpu().permute(0, 2, 3, 1).numpy()
203
+ images = (images * 255).round().astype("uint8")
204
+ images = list(
205
+ map(lambda _: Image.fromarray(_[:, :, 0]), images)
206
+ if images.shape[3] == 1
207
+ else map(lambda _: Image.fromarray(_, mode="RGB").convert("L"), images)
208
+ )
209
+
210
+ audios = list(map(lambda _: self.mel.image_to_audio(_), images))
211
+ if not return_dict:
212
+ return images, (self.mel.get_sample_rate(), audios)
213
+
214
+ return BaseOutput(**AudioPipelineOutput(np.array(audios)[:, np.newaxis, :]), **ImagePipelineOutput(images))
215
+
216
+ @torch.no_grad()
217
+ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray:
218
+ """Reverse step process: recover noisy image from generated image.
219
+
220
+ Args:
221
+ images (`List[PIL Image]`): list of images to encode
222
+ steps (`int`): number of encoding steps to perform (defaults to 50)
223
+
224
+ Returns:
225
+ `np.ndarray`: noise tensor of shape (batch_size, 1, height, width)
226
+ """
227
+
228
+ # Only works with DDIM as this method is deterministic
229
+ assert isinstance(self.scheduler, DDIMScheduler)
230
+ self.scheduler.set_timesteps(steps)
231
+ sample = np.array(
232
+ [np.frombuffer(image.tobytes(), dtype="uint8").reshape((1, image.height, image.width)) for image in images]
233
+ )
234
+ sample = (sample / 255) * 2 - 1
235
+ sample = torch.Tensor(sample).to(self.device)
236
+
237
+ for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))):
238
+ prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps
239
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
240
+ alpha_prod_t_prev = (
241
+ self.scheduler.alphas_cumprod[prev_timestep]
242
+ if prev_timestep >= 0
243
+ else self.scheduler.final_alpha_cumprod
244
+ )
245
+ beta_prod_t = 1 - alpha_prod_t
246
+ model_output = self.unet(sample, t)["sample"]
247
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * model_output
248
+ sample = (sample - pred_sample_direction) * alpha_prod_t_prev ** (-0.5)
249
+ sample = sample * alpha_prod_t ** (0.5) + beta_prod_t ** (0.5) * model_output
250
+
251
+ return sample
252
+
253
+ @staticmethod
254
+ def slerp(x0: torch.Tensor, x1: torch.Tensor, alpha: float) -> torch.Tensor:
255
+ """Spherical Linear intERPolation
256
+
257
+ Args:
258
+ x0 (`torch.Tensor`): first tensor to interpolate between
259
+ x1 (`torch.Tensor`): second tensor to interpolate between
260
+ alpha (`float`): interpolation between 0 and 1
261
+
262
+ Returns:
263
+ `torch.Tensor`: interpolated tensor
264
+ """
265
+
266
+ theta = acos(torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) / torch.norm(x1))
267
+ return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(alpha * theta) * x1 / sin(theta)
audiodiffusion/utils.py CHANGED
@@ -23,8 +23,7 @@ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
23
  new_item = old_item
24
 
25
  new_item = new_item.replace("nin_shortcut", "conv_shortcut")
26
- new_item = shave_segments(
27
- new_item, n_shave_prefix_segments=n_shave_prefix_segments)
28
 
29
  mapping.append({"old": old_item, "new": new_item})
30
 
@@ -54,20 +53,21 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
54
  new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
55
  new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
56
 
57
- new_item = shave_segments(
58
- new_item, n_shave_prefix_segments=n_shave_prefix_segments)
59
 
60
  mapping.append({"old": old_item, "new": new_item})
61
 
62
  return mapping
63
 
64
 
65
- def assign_to_checkpoint(paths,
66
- checkpoint,
67
- old_checkpoint,
68
- attention_paths_to_split=None,
69
- additional_replacements=None,
70
- config=None):
 
 
71
  """
72
  This does the final conversion step: take locally converted weights and apply a global renaming
73
  to them. It splits attention layers, and takes into account additional replacements
@@ -75,9 +75,7 @@ def assign_to_checkpoint(paths,
75
 
76
  Assigns the weights to the new checkpoint.
77
  """
78
- assert isinstance(
79
- paths, list
80
- ), "Paths should be a list of dicts containing 'old' and 'new' keys."
81
 
82
  # Splits the attention layers into three variables.
83
  if attention_paths_to_split is not None:
@@ -85,13 +83,11 @@ def assign_to_checkpoint(paths,
85
  old_tensor = old_checkpoint[path]
86
  channels = old_tensor.shape[0] // 3
87
 
88
- target_shape = (-1,
89
- channels) if len(old_tensor.shape) == 3 else (-1)
90
 
91
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
92
 
93
- old_tensor = old_tensor.reshape((num_heads, 3 * channels //
94
- num_heads) + old_tensor.shape[1:])
95
  query, key, value = old_tensor.split(channels // num_heads, dim=1)
96
 
97
  checkpoint[path_map["query"]] = query.reshape(target_shape)
@@ -112,8 +108,7 @@ def assign_to_checkpoint(paths,
112
 
113
  if additional_replacements is not None:
114
  for replacement in additional_replacements:
115
- new_path = new_path.replace(replacement["old"],
116
- replacement["new"])
117
 
118
  # proj_attn.weight has to be converted from conv 1D to linear
119
  if "proj_attn.weight" in new_path:
@@ -146,7 +141,7 @@ def create_vae_diffusers_config(original_config):
146
  up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
147
 
148
  config = dict(
149
- sample_size=vae_params.resolution,
150
  in_channels=vae_params.in_channels,
151
  out_channels=vae_params.out_ch,
152
  down_block_types=tuple(down_block_types),
@@ -164,178 +159,144 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
164
 
165
  new_checkpoint = {}
166
 
167
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict[
168
- "encoder.conv_in.weight"]
169
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict[
170
- "encoder.conv_in.bias"]
171
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
172
- "encoder.conv_out.weight"]
173
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict[
174
- "encoder.conv_out.bias"]
175
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
176
- "encoder.norm_out.weight"]
177
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
178
- "encoder.norm_out.bias"]
179
-
180
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict[
181
- "decoder.conv_in.weight"]
182
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict[
183
- "decoder.conv_in.bias"]
184
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
185
- "decoder.conv_out.weight"]
186
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict[
187
- "decoder.conv_out.bias"]
188
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
189
- "decoder.norm_out.weight"]
190
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
191
- "decoder.norm_out.bias"]
192
 
193
  new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
194
  new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
195
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict[
196
- "post_quant_conv.weight"]
197
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict[
198
- "post_quant_conv.bias"]
199
 
200
  # Retrieves the keys for the encoder down blocks only
201
- num_down_blocks = len({
202
- ".".join(layer.split(".")[:3])
203
- for layer in vae_state_dict if "encoder.down" in layer
204
- })
205
  down_blocks = {
206
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
207
- for layer_id in range(num_down_blocks)
208
  }
209
 
210
  # Retrieves the keys for the decoder up blocks only
211
- num_up_blocks = len({
212
- ".".join(layer.split(".")[:3])
213
- for layer in vae_state_dict if "decoder.up" in layer
214
- })
215
  up_blocks = {
216
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
217
- for layer_id in range(num_up_blocks)
218
  }
219
 
220
  for i in range(num_down_blocks):
221
- resnets = [
222
- key for key in down_blocks[i]
223
- if f"down.{i}" in key and f"down.{i}.downsample" not in key
224
- ]
225
 
226
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
227
- new_checkpoint[
228
- f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
229
- f"encoder.down.{i}.downsample.conv.weight")
230
- new_checkpoint[
231
- f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
232
- f"encoder.down.{i}.downsample.conv.bias")
233
 
234
  paths = renew_vae_resnet_paths(resnets)
235
- meta_path = {
236
- "old": f"down.{i}.block",
237
- "new": f"down_blocks.{i}.resnets"
238
- }
239
- assign_to_checkpoint(paths,
240
- new_checkpoint,
241
- vae_state_dict,
242
- additional_replacements=[meta_path],
243
- config=config)
244
 
245
  mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
246
  num_mid_res_blocks = 2
247
  for i in range(1, num_mid_res_blocks + 1):
248
- resnets = [
249
- key for key in mid_resnets if f"encoder.mid.block_{i}" in key
250
- ]
251
 
252
  paths = renew_vae_resnet_paths(resnets)
253
- meta_path = {
254
- "old": f"mid.block_{i}",
255
- "new": f"mid_block.resnets.{i - 1}"
256
- }
257
- assign_to_checkpoint(paths,
258
- new_checkpoint,
259
- vae_state_dict,
260
- additional_replacements=[meta_path],
261
- config=config)
262
-
263
- mid_attentions = [
264
- key for key in vae_state_dict if "encoder.mid.attn" in key
265
- ]
266
  paths = renew_vae_attention_paths(mid_attentions)
267
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
268
- assign_to_checkpoint(paths,
269
- new_checkpoint,
270
- vae_state_dict,
271
- additional_replacements=[meta_path],
272
- config=config)
 
 
273
  conv_attn_to_linear(new_checkpoint)
274
 
275
  for i in range(num_up_blocks):
276
  block_id = num_up_blocks - 1 - i
277
  resnets = [
278
- key for key in up_blocks[block_id]
279
- if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
280
  ]
281
 
282
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
283
- new_checkpoint[
284
- f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
285
- f"decoder.up.{block_id}.upsample.conv.weight"]
286
- new_checkpoint[
287
- f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
288
- f"decoder.up.{block_id}.upsample.conv.bias"]
289
 
290
  paths = renew_vae_resnet_paths(resnets)
291
- meta_path = {
292
- "old": f"up.{block_id}.block",
293
- "new": f"up_blocks.{i}.resnets"
294
- }
295
- assign_to_checkpoint(paths,
296
- new_checkpoint,
297
- vae_state_dict,
298
- additional_replacements=[meta_path],
299
- config=config)
300
 
301
  mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
302
  num_mid_res_blocks = 2
303
  for i in range(1, num_mid_res_blocks + 1):
304
- resnets = [
305
- key for key in mid_resnets if f"decoder.mid.block_{i}" in key
306
- ]
307
 
308
  paths = renew_vae_resnet_paths(resnets)
309
- meta_path = {
310
- "old": f"mid.block_{i}",
311
- "new": f"mid_block.resnets.{i - 1}"
312
- }
313
- assign_to_checkpoint(paths,
314
- new_checkpoint,
315
- vae_state_dict,
316
- additional_replacements=[meta_path],
317
- config=config)
318
-
319
- mid_attentions = [
320
- key for key in vae_state_dict if "decoder.mid.attn" in key
321
- ]
322
  paths = renew_vae_attention_paths(mid_attentions)
323
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
324
- assign_to_checkpoint(paths,
325
- new_checkpoint,
326
- vae_state_dict,
327
- additional_replacements=[meta_path],
328
- config=config)
 
 
329
  conv_attn_to_linear(new_checkpoint)
330
  return new_checkpoint
331
 
332
- def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint):
 
333
  checkpoint = torch.load(ldm_checkpoint)["state_dict"]
334
 
335
  # Convert the VAE model.
336
  vae_config = create_vae_diffusers_config(ldm_config)
337
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(
338
- checkpoint, vae_config)
339
 
340
  vae = AutoencoderKL(**vae_config)
341
  vae.load_state_dict(converted_vae_checkpoint)
 
23
  new_item = old_item
24
 
25
  new_item = new_item.replace("nin_shortcut", "conv_shortcut")
26
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
27
 
28
  mapping.append({"old": old_item, "new": new_item})
29
 
 
53
  new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
54
  new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
55
 
56
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
 
57
 
58
  mapping.append({"old": old_item, "new": new_item})
59
 
60
  return mapping
61
 
62
 
63
+ def assign_to_checkpoint(
64
+ paths,
65
+ checkpoint,
66
+ old_checkpoint,
67
+ attention_paths_to_split=None,
68
+ additional_replacements=None,
69
+ config=None,
70
+ ):
71
  """
72
  This does the final conversion step: take locally converted weights and apply a global renaming
73
  to them. It splits attention layers, and takes into account additional replacements
 
75
 
76
  Assigns the weights to the new checkpoint.
77
  """
78
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
 
 
79
 
80
  # Splits the attention layers into three variables.
81
  if attention_paths_to_split is not None:
 
83
  old_tensor = old_checkpoint[path]
84
  channels = old_tensor.shape[0] // 3
85
 
86
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
 
87
 
88
  num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
89
 
90
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
 
91
  query, key, value = old_tensor.split(channels // num_heads, dim=1)
92
 
93
  checkpoint[path_map["query"]] = query.reshape(target_shape)
 
108
 
109
  if additional_replacements is not None:
110
  for replacement in additional_replacements:
111
+ new_path = new_path.replace(replacement["old"], replacement["new"])
 
112
 
113
  # proj_attn.weight has to be converted from conv 1D to linear
114
  if "proj_attn.weight" in new_path:
 
141
  up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
142
 
143
  config = dict(
144
+ sample_size=tuple(vae_params.resolution),
145
  in_channels=vae_params.in_channels,
146
  out_channels=vae_params.out_ch,
147
  down_block_types=tuple(down_block_types),
 
159
 
160
  new_checkpoint = {}
161
 
162
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
163
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
164
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
165
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
166
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
167
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
168
+
169
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
170
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
171
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
172
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
173
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
174
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
 
 
 
 
 
 
 
 
 
 
 
 
175
 
176
  new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
177
  new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
178
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
179
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
 
 
180
 
181
  # Retrieves the keys for the encoder down blocks only
182
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
 
 
 
183
  down_blocks = {
184
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
 
185
  }
186
 
187
  # Retrieves the keys for the decoder up blocks only
188
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
 
 
 
189
  up_blocks = {
190
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
 
191
  }
192
 
193
  for i in range(num_down_blocks):
194
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
 
 
 
195
 
196
  if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
197
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
198
+ f"encoder.down.{i}.downsample.conv.weight"
199
+ )
200
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
201
+ f"encoder.down.{i}.downsample.conv.bias"
202
+ )
203
 
204
  paths = renew_vae_resnet_paths(resnets)
205
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
206
+ assign_to_checkpoint(
207
+ paths,
208
+ new_checkpoint,
209
+ vae_state_dict,
210
+ additional_replacements=[meta_path],
211
+ config=config,
212
+ )
 
213
 
214
  mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
215
  num_mid_res_blocks = 2
216
  for i in range(1, num_mid_res_blocks + 1):
217
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
 
 
218
 
219
  paths = renew_vae_resnet_paths(resnets)
220
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
221
+ assign_to_checkpoint(
222
+ paths,
223
+ new_checkpoint,
224
+ vae_state_dict,
225
+ additional_replacements=[meta_path],
226
+ config=config,
227
+ )
228
+
229
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
 
 
 
230
  paths = renew_vae_attention_paths(mid_attentions)
231
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
232
+ assign_to_checkpoint(
233
+ paths,
234
+ new_checkpoint,
235
+ vae_state_dict,
236
+ additional_replacements=[meta_path],
237
+ config=config,
238
+ )
239
  conv_attn_to_linear(new_checkpoint)
240
 
241
  for i in range(num_up_blocks):
242
  block_id = num_up_blocks - 1 - i
243
  resnets = [
244
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
 
245
  ]
246
 
247
  if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
248
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
249
+ f"decoder.up.{block_id}.upsample.conv.weight"
250
+ ]
251
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
252
+ f"decoder.up.{block_id}.upsample.conv.bias"
253
+ ]
254
 
255
  paths = renew_vae_resnet_paths(resnets)
256
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
257
+ assign_to_checkpoint(
258
+ paths,
259
+ new_checkpoint,
260
+ vae_state_dict,
261
+ additional_replacements=[meta_path],
262
+ config=config,
263
+ )
 
264
 
265
  mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
266
  num_mid_res_blocks = 2
267
  for i in range(1, num_mid_res_blocks + 1):
268
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
 
 
269
 
270
  paths = renew_vae_resnet_paths(resnets)
271
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
272
+ assign_to_checkpoint(
273
+ paths,
274
+ new_checkpoint,
275
+ vae_state_dict,
276
+ additional_replacements=[meta_path],
277
+ config=config,
278
+ )
279
+
280
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
 
 
 
281
  paths = renew_vae_attention_paths(mid_attentions)
282
  meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
283
+ assign_to_checkpoint(
284
+ paths,
285
+ new_checkpoint,
286
+ vae_state_dict,
287
+ additional_replacements=[meta_path],
288
+ config=config,
289
+ )
290
  conv_attn_to_linear(new_checkpoint)
291
  return new_checkpoint
292
 
293
+
294
+ def convert_ldm_to_hf_vae(ldm_checkpoint, ldm_config, hf_checkpoint, sample_size):
295
  checkpoint = torch.load(ldm_checkpoint)["state_dict"]
296
 
297
  # Convert the VAE model.
298
  vae_config = create_vae_diffusers_config(ldm_config)
299
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
300
 
301
  vae = AutoencoderKL(**vae_config)
302
  vae.load_state_dict(converted_vae_checkpoint)
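The converter above is called by `HFModelCheckpoint` in `scripts/train_vae.py` at the end of each epoch, with the training image size passed as the `sample_size` argument. A minimal standalone sketch of the same call, assuming an OmegaConf-style config and placeholder checkpoint paths (matching the defaults used by `scripts/train_vae.py`):

```python
from omegaconf import OmegaConf

from audiodiffusion.utils import convert_ldm_to_hf_vae

# Load the LDM config shipped with the repo and pin the resolution to the
# spectrogram size actually used for training (256x256 is assumed here).
ldm_config = OmegaConf.load("config/ldm_autoencoder_kl.yaml")
ldm_config.model.params.ddconfig.resolution = [256, 256]

convert_ldm_to_hf_vae(
    ldm_checkpoint="models/ldm-autoencoder-kl/last.ckpt",  # hypothetical checkpoint path
    ldm_config=ldm_config,
    hf_checkpoint="models/autoencoder-kl",
    sample_size=(256, 256),
)
```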
config/ldm_autoencoder_kl.yaml CHANGED
@@ -18,7 +18,7 @@ model:
18
  ddconfig:
19
  double_z: True
20
  z_channels: 1 # must = embed_dim due to HF limitation
21
- resolution: 256
22
  in_channels: 1
23
  out_ch: 1
24
  ch: 128
 
18
  ddconfig:
19
  double_z: True
20
  z_channels: 1 # must = embed_dim due to HF limitation
21
+ resolution: 256 # overridden by input image size
22
  in_channels: 1
23
  out_ch: 1
24
  ch: 128
notebooks/audio_encoder.ipynb ADDED
@@ -0,0 +1,69 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "592fff30",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from diffusers import Mel\n",
11
+ "from audiodiffusion.audio_encoder import AudioEncoder"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": null,
17
+ "id": "d99ef523",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "audio_encoder = AudioEncoder.from_pretrained(\"teticio/audio-encoder\")"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "id": "4eb3bbd7",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "audio_encoder.encode(['/home/teticio/Music/liked/Agua Re - Holy Dance - Large Sound Mix.mp3'])"
32
+ ]
33
+ }
34
+ ],
35
+ "metadata": {
36
+ "kernelspec": {
37
+ "display_name": "huggingface",
38
+ "language": "python",
39
+ "name": "huggingface"
40
+ },
41
+ "language_info": {
42
+ "codemirror_mode": {
43
+ "name": "ipython",
44
+ "version": 3
45
+ },
46
+ "file_extension": ".py",
47
+ "mimetype": "text/x-python",
48
+ "name": "python",
49
+ "nbconvert_exporter": "python",
50
+ "pygments_lexer": "ipython3",
51
+ "version": "3.10.6"
52
+ },
53
+ "toc": {
54
+ "base_numbering": 1,
55
+ "nav_menu": {},
56
+ "number_sections": true,
57
+ "sideBar": true,
58
+ "skip_h1_title": false,
59
+ "title_cell": "Table of Contents",
60
+ "title_sidebar": "Contents",
61
+ "toc_cell": false,
62
+ "toc_position": {},
63
+ "toc_section_display": true,
64
+ "toc_window_display": false
65
+ }
66
+ },
67
+ "nbformat": 4,
68
+ "nbformat_minor": 5
69
+ }
notebooks/conditional_generation.ipynb ADDED
@@ -0,0 +1,221 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "2a44739f",
6
+ "metadata": {},
7
+ "source": [
8
+ "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/condtional_generation.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "f1935544",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "try:\n",
19
+ " # are we running on Google Colab?\n",
20
+ " import google.colab\n",
21
+ " !git clone -q https://github.com/teticio/audio-diffusion.git\n",
22
+ " %cd audio-diffusion\n",
23
+ " %pip install -q -r requirements.txt\n",
24
+ "except:\n",
25
+ " pass"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": null,
31
+ "id": "b0e656c9",
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "import os\n",
36
+ "import sys\n",
37
+ "sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "d448b299",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import torch\n",
48
+ "import urllib\n",
49
+ "import requests\n",
50
+ "from IPython.display import Audio\n",
51
+ "from audiodiffusion import AudioDiffusion\n",
52
+ "from audiodiffusion.audio_encoder import AudioEncoder"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "f1548971",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
63
+ "generator = torch.Generator(device=device)"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": null,
69
+ "id": "056f179c",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "audio_diffusion = AudioDiffusion(model_id=\"teticio/conditional-latent-audio-diffusion-512\")"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "b4a08500",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "audio_encoder = AudioEncoder.from_pretrained(\"teticio/audio-encoder\")"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "387550ac",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# Uncomment for faster (but slightly lower quality) generation\n",
94
+ "#from diffusers import DDIMScheduler\n",
95
+ "#audio_diffusion.pipe.scheduler = DDIMScheduler()"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "id": "9936a72f",
101
+ "metadata": {},
102
+ "source": [
103
+ "## Download and encode preview track from Spotify"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "57a9b134",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "# Get temporary API credentials\n",
114
+ "credentials = requests.get(\n",
115
+ " \"https://open.spotify.com/get_access_token?reason=transport&productType=embed\"\n",
116
+ ").json()\n",
117
+ "headers = {\n",
118
+ " \"Accept\": \"application/json\",\n",
119
+ " \"Content-Type\": \"application/json\",\n",
120
+ " \"Authorization\": \"Bearer \" + credentials[\"accessToken\"]\n",
121
+ "}\n",
122
+ "\n",
123
+ "# Search for tracks\n",
124
+ "search_string = input(\"Search: \")\n",
125
+ "response = requests.get(\n",
126
+ " f\"https://api.spotify.com/v1/search?q={urllib.parse.quote(search_string)}&type=track\",\n",
127
+ " headers=headers).json()\n",
128
+ "\n",
129
+ "# List results\n",
130
+ "for _, track in enumerate(response[\"tracks\"][\"items\"]):\n",
131
+ " print(f\"{_ + 1}. {track['artists'][0]['name']} - {track['name']}\")\n",
132
+ "selection = input(\"Select a track: \")\n",
133
+ "\n",
134
+ "# Download and encode selection\n",
135
+ "r = requests.get(response[\"tracks\"][\"items\"][int(selection) -\n",
136
+ " 1][\"preview_url\"],\n",
137
+ " stream=True)\n",
138
+ "with open(\"temp.mp3\", \"wb\") as f:\n",
139
+ " for chunk in r:\n",
140
+ " f.write(chunk)\n",
141
+ "encoding = torch.unsqueeze(audio_encoder.encode([\"temp.mp3\"]),\n",
142
+ " axis=1).to(device)\n",
143
+ "os.remove(\"temp.mp3\")"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "8af863f5",
149
+ "metadata": {},
150
+ "source": [
151
+ "## Conditional Generation\n",
152
+ "Bear in mind that the generative model can only generate music similar to that on which it was trained. The audio encoding will influence the generation within those limitations."
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "id": "8f119ddd",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "for _ in range(10):\n",
163
+ " seed = generator.seed()\n",
164
+ " print(f'Seed = {seed}')\n",
165
+ " generator.manual_seed(seed)\n",
166
+ " image, (sample_rate,\n",
167
+ " audio) = audio_diffusion.generate_spectrogram_and_audio(\n",
168
+ " generator=generator, encoding=encoding)\n",
169
+ " display(image)\n",
170
+ " display(Audio(audio, rate=sample_rate))\n",
171
+ " loop = AudioDiffusion.loop_it(audio, sample_rate)\n",
172
+ " if loop is not None:\n",
173
+ " display(Audio(loop, rate=sample_rate))\n",
174
+ " else:\n",
175
+ " print(\"Unable to determine loop points\")"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "d0bd18c0",
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": []
185
+ }
186
+ ],
187
+ "metadata": {
188
+ "kernelspec": {
189
+ "display_name": "huggingface",
190
+ "language": "python",
191
+ "name": "huggingface"
192
+ },
193
+ "language_info": {
194
+ "codemirror_mode": {
195
+ "name": "ipython",
196
+ "version": 3
197
+ },
198
+ "file_extension": ".py",
199
+ "mimetype": "text/x-python",
200
+ "name": "python",
201
+ "nbconvert_exporter": "python",
202
+ "pygments_lexer": "ipython3",
203
+ "version": "3.10.6"
204
+ },
205
+ "toc": {
206
+ "base_numbering": 1,
207
+ "nav_menu": {},
208
+ "number_sections": true,
209
+ "sideBar": true,
210
+ "skip_h1_title": false,
211
+ "title_cell": "Table of Contents",
212
+ "title_sidebar": "Contents",
213
+ "toc_cell": false,
214
+ "toc_position": {},
215
+ "toc_section_display": true,
216
+ "toc_window_display": false
217
+ }
218
+ },
219
+ "nbformat": 4,
220
+ "nbformat_minor": 5
221
+ }
pyproject.toml ADDED
@@ -0,0 +1,3 @@
1
+ [tool.black]
2
+ line-length = 119
3
+ target-version = ['py36']
scripts/audio_to_images.py CHANGED
@@ -1,29 +1,33 @@
1
- import os
2
- import re
3
  import io
4
  import logging
5
- import argparse
 
6
 
7
  import numpy as np
8
  import pandas as pd
9
- from tqdm.auto import tqdm
10
- from diffusers.pipelines.audio_diffusion import Mel
11
  from datasets import Dataset, DatasetDict, Features, Image, Value
 
 
12
 
13
  logging.basicConfig(level=logging.WARN)
14
- logger = logging.getLogger('audio_to_images')
15
 
16
 
17
  def main(args):
18
- mel = Mel(x_res=args.resolution[0],
19
- y_res=args.resolution[1],
20
- hop_length=args.hop_length,
21
- sample_rate=args.sample_rate,
22
- n_fft=args.n_fft)
 
 
23
  os.makedirs(args.output_dir, exist_ok=True)
24
  audio_files = [
25
- os.path.join(root, file) for root, _, files in os.walk(args.input_dir)
26
- for file in files if re.search("\.(mp3|wav|m4a)$", file, re.IGNORECASE)
 
 
27
  ]
28
  examples = []
29
  try:
@@ -36,36 +40,38 @@ def main(args):
36
  continue
37
  for slice in range(mel.get_number_of_slices()):
38
  image = mel.audio_slice_to_image(slice)
39
- assert (image.width == args.resolution[0] and image.height
40
- == args.resolution[1]), "Wrong resolution"
41
  # skip completely silent slices
42
  if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
43
- logger.warn('File %s slice %d is completely silent',
44
- audio_file, slice)
45
  continue
46
  with io.BytesIO() as output:
47
  image.save(output, format="PNG")
48
  bytes = output.getvalue()
49
- examples.extend([{
50
- "image": {
51
- "bytes": bytes
52
- },
53
- "audio_file": audio_file,
54
- "slice": slice,
55
- }])
 
 
56
  except Exception as e:
57
  print(e)
58
  finally:
59
  if len(examples) == 0:
60
- logger.warn('No valid audio files were found.')
61
  return
62
  ds = Dataset.from_pandas(
63
  pd.DataFrame(examples),
64
- features=Features({
65
- "image": Image(),
66
- "audio_file": Value(dtype="string"),
67
- "slice": Value(dtype="int16"),
68
- }),
 
 
69
  )
70
  dsd = DatasetDict({"train": ds})
71
  dsd.save_to_disk(os.path.join(args.output_dir))
@@ -74,15 +80,15 @@ def main(args):
74
 
75
 
76
  if __name__ == "__main__":
77
- parser = argparse.ArgumentParser(
78
- description=
79
- "Create dataset of Mel spectrograms from directory of audio files.")
80
  parser.add_argument("--input_dir", type=str)
81
  parser.add_argument("--output_dir", type=str, default="data")
82
- parser.add_argument("--resolution",
83
- type=str,
84
- default="256",
85
- help="Either square resolution or width,height.")
 
 
86
  parser.add_argument("--hop_length", type=int, default=512)
87
  parser.add_argument("--push_to_hub", type=str, default=None)
88
  parser.add_argument("--sample_rate", type=int, default=22050)
@@ -90,8 +96,7 @@ if __name__ == "__main__":
90
  args = parser.parse_args()
91
 
92
  if args.input_dir is None:
93
- raise ValueError(
94
- "You must specify an input directory for the audio files.")
95
 
96
  # Handle the resolutions.
97
  try:
@@ -102,9 +107,7 @@ if __name__ == "__main__":
102
  if len(args.resolution) != 2:
103
  raise ValueError
104
  except ValueError:
105
- raise ValueError(
106
- "Resolution must be a tuple of two integers or a single integer."
107
- )
108
  assert isinstance(args.resolution, tuple)
109
 
110
  main(args)
 
1
+ import argparse
 
2
  import io
3
  import logging
4
+ import os
5
+ import re
6
 
7
  import numpy as np
8
  import pandas as pd
 
 
9
  from datasets import Dataset, DatasetDict, Features, Image, Value
10
+ from diffusers.pipelines.audio_diffusion import Mel
11
+ from tqdm.auto import tqdm
12
 
13
  logging.basicConfig(level=logging.WARN)
14
+ logger = logging.getLogger("audio_to_images")
15
 
16
 
17
  def main(args):
18
+ mel = Mel(
19
+ x_res=args.resolution[0],
20
+ y_res=args.resolution[1],
21
+ hop_length=args.hop_length,
22
+ sample_rate=args.sample_rate,
23
+ n_fft=args.n_fft,
24
+ )
25
  os.makedirs(args.output_dir, exist_ok=True)
26
  audio_files = [
27
+ os.path.join(root, file)
28
+ for root, _, files in os.walk(args.input_dir)
29
+ for file in files
30
+ if re.search("\.(mp3|wav|m4a)$", file, re.IGNORECASE)
31
  ]
32
  examples = []
33
  try:
 
40
  continue
41
  for slice in range(mel.get_number_of_slices()):
42
  image = mel.audio_slice_to_image(slice)
43
+ assert image.width == args.resolution[0] and image.height == args.resolution[1], "Wrong resolution"
 
44
  # skip completely silent slices
45
  if all(np.frombuffer(image.tobytes(), dtype=np.uint8) == 255):
46
+ logger.warn("File %s slice %d is completely silent", audio_file, slice)
 
47
  continue
48
  with io.BytesIO() as output:
49
  image.save(output, format="PNG")
50
  bytes = output.getvalue()
51
+ examples.extend(
52
+ [
53
+ {
54
+ "image": {"bytes": bytes},
55
+ "audio_file": audio_file,
56
+ "slice": slice,
57
+ }
58
+ ]
59
+ )
60
  except Exception as e:
61
  print(e)
62
  finally:
63
  if len(examples) == 0:
64
+ logger.warn("No valid audio files were found.")
65
  return
66
  ds = Dataset.from_pandas(
67
  pd.DataFrame(examples),
68
+ features=Features(
69
+ {
70
+ "image": Image(),
71
+ "audio_file": Value(dtype="string"),
72
+ "slice": Value(dtype="int16"),
73
+ }
74
+ ),
75
  )
76
  dsd = DatasetDict({"train": ds})
77
  dsd.save_to_disk(os.path.join(args.output_dir))
 
80
 
81
 
82
  if __name__ == "__main__":
83
+ parser = argparse.ArgumentParser(description="Create dataset of Mel spectrograms from directory of audio files.")
 
 
84
  parser.add_argument("--input_dir", type=str)
85
  parser.add_argument("--output_dir", type=str, default="data")
86
+ parser.add_argument(
87
+ "--resolution",
88
+ type=str,
89
+ default="256",
90
+ help="Either square resolution or width,height.",
91
+ )
92
  parser.add_argument("--hop_length", type=int, default=512)
93
  parser.add_argument("--push_to_hub", type=str, default=None)
94
  parser.add_argument("--sample_rate", type=int, default=22050)
 
96
  args = parser.parse_args()
97
 
98
  if args.input_dir is None:
99
+ raise ValueError("You must specify an input directory for the audio files.")
 
100
 
101
  # Handle the resolutions.
102
  try:
 
107
  if len(args.resolution) != 2:
108
  raise ValueError
109
  except ValueError:
110
+ raise ValueError("Resolution must be a tuple of two integers or a single integer.")
 
 
111
  assert isinstance(args.resolution, tuple)
112
 
113
  main(args)
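The dataset written by this script is a plain `datasets` `DatasetDict` with a single `train` split and `image`, `audio_file` and `slice` columns, so it can be inspected directly once generated. A minimal sketch, assuming the script was run with its default `--output_dir` of `data`:

```python
from datasets import load_from_disk

# Open the dataset saved by scripts/audio_to_images.py (default --output_dir is "data").
ds = load_from_disk("data")["train"]

example = ds[0]
print(example["audio_file"], example["slice"])  # originating audio file and slice index
spectrogram = example["image"]                  # PIL image of the Mel spectrogram slice
```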
scripts/encode_audio.py ADDED
@@ -0,0 +1,38 @@
1
+ import argparse
2
+ import os
3
+ import pickle
4
+
5
+ from datasets import load_dataset, load_from_disk
6
+ from tqdm.auto import tqdm
7
+
8
+ from audiodiffusion.audio_encoder import AudioEncoder
9
+
10
+
11
+ def main(args):
12
+ audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")
13
+
14
+ if args.dataset_name is not None:
15
+ if os.path.exists(args.dataset_name):
16
+ dataset = load_from_disk(args.dataset_name)["train"]
17
+ else:
18
+ dataset = load_dataset(
19
+ args.dataset_name,
20
+ args.dataset_config_name,
21
+ cache_dir=args.cache_dir,
22
+ use_auth_token=True if args.use_auth_token else None,
23
+ split="train",
24
+ )
25
+
26
+ encodings = {}
27
+ for audio_file in tqdm(dataset.to_pandas()["audio_file"].unique()):
28
+ encodings[audio_file] = audio_encoder.encode([audio_file])
29
+ pickle.dump(encodings, open(args.output_file, "wb"))
30
+
31
+
32
+ if __name__ == "__main__":
33
+ parser = argparse.ArgumentParser(description="Create pickled audio encodings for dataset of audio files.")
34
+ parser.add_argument("--dataset_name", type=str, default=None)
35
+ parser.add_argument("--output_file", type=str, default="data/encodings.p")
36
+ parser.add_argument("--use_auth_token", type=bool, default=False)
37
+ args = parser.parse_args()
38
+ main(args)
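The pickle written by this script maps each `audio_file` path to its encoding; the renamed training script below loads that mapping through its new `--encodings` argument and looks entries up by the dataset's `audio_file` column. A minimal sketch of building such a file by hand (the audio file name is hypothetical):

```python
import pickle

from audiodiffusion.audio_encoder import AudioEncoder

audio_encoder = AudioEncoder.from_pretrained("teticio/audio-encoder")

# One entry per audio file, keyed exactly as the path appears in the dataset's "audio_file" column.
encodings = {"my_track.mp3": audio_encoder.encode(["my_track.mp3"])}  # hypothetical file

with open("data/encodings.p", "wb") as f:  # the script's default --output_file
    pickle.dump(encodings, f)
```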
scripts/{train_unconditional.py → train_unet.py} RENAMED
@@ -2,34 +2,29 @@
2
 
3
  import argparse
4
  import os
 
 
5
  from pathlib import Path
6
  from typing import Optional
7
 
 
 
 
8
  from accelerate import Accelerator
9
  from accelerate.logging import get_logger
10
- from datasets import load_from_disk, load_dataset
11
- from diffusers import (
12
- AudioDiffusionPipeline,
13
- DDPMScheduler,
14
- UNet2DModel,
15
- DDIMScheduler,
16
- AutoencoderKL,
17
- )
18
- from diffusers.pipelines.audio_diffusion import Mel
19
  from diffusers.optimization import get_scheduler
 
20
  from diffusers.training_utils import EMAModel
21
  from huggingface_hub import HfFolder, Repository, whoami
22
  from librosa.util import normalize
23
- import numpy as np
24
- import torch
25
- import torch.nn.functional as F
26
- from torchvision.transforms import (
27
- Compose,
28
- Normalize,
29
- ToTensor,
30
- )
31
  from tqdm.auto import tqdm
32
 
 
 
33
  logger = get_logger(__name__)
34
 
35
 
@@ -90,12 +85,18 @@ def main(args):
90
  ]
91
  else:
92
  images = [augmentations(image) for image in examples["image"]]
 
 
 
93
  return {"input": images}
94
 
95
  dataset.set_transform(transforms)
96
  train_dataloader = torch.utils.data.DataLoader(
97
  dataset, batch_size=args.train_batch_size, shuffle=True)
98
 
 
 
 
99
  vqvae = None
100
  if args.vae is not None:
101
  try:
@@ -104,9 +105,9 @@ def main(args):
104
  vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
105
  # Determine latent resolution
106
  with torch.no_grad():
107
- latent_resolution = (vqvae.encode(
108
  torch.zeros((1, 1) +
109
- resolution)).latent_dist.sample().shape[2:])
110
 
111
  if args.from_pretrained is not None:
112
  pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
@@ -114,32 +115,58 @@ def main(args):
114
  model = pipeline.unet
115
  if hasattr(pipeline, "vqvae"):
116
  vqvae = pipeline.vqvae
 
117
  else:
118
- model = UNet2DModel(
119
- sample_size=resolution if vqvae is None else latent_resolution,
120
- in_channels=1
121
- if vqvae is None else vqvae.config["latent_channels"],
122
- out_channels=1
123
- if vqvae is None else vqvae.config["latent_channels"],
124
- layers_per_block=2,
125
- block_out_channels=(128, 128, 256, 256, 512, 512),
126
- down_block_types=(
127
- "DownBlock2D",
128
- "DownBlock2D",
129
- "DownBlock2D",
130
- "DownBlock2D",
131
- "AttnDownBlock2D",
132
- "DownBlock2D",
133
- ),
134
- up_block_types=(
135
- "UpBlock2D",
136
- "AttnUpBlock2D",
137
- "UpBlock2D",
138
- "UpBlock2D",
139
- "UpBlock2D",
140
- "UpBlock2D",
141
- ),
142
- )
143
 
144
  if args.scheduler == "ddpm":
145
  noise_scheduler = DDPMScheduler(
@@ -240,7 +267,11 @@ def main(args):
240
 
241
  with accelerator.accumulate(model):
242
  # Predict the noise residual
243
- noise_pred = model(noisy_images, timesteps)["sample"]
 
 
 
 
244
  loss = F.mse_loss(noise_pred, noise)
245
  accelerator.backward(loss)
246
 
@@ -270,9 +301,9 @@ def main(args):
270
 
271
  # Generate sample images for visual inspection
272
  if accelerator.is_main_process:
273
- if (epoch + 1) % args.save_model_epochs == 0 or (
274
- epoch + 1
275
- ) % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
276
  pipeline = AudioDiffusionPipeline(
277
  vqvae=vqvae,
278
  unet=accelerator.unwrap_model(
@@ -288,18 +319,32 @@ def main(args):
288
 
289
  # save the model
290
  if args.push_to_hub:
291
- repo.push_to_hub(commit_message=f"Epoch {epoch}",
292
- blocking=False,
293
- auto_lfs_prune=True)
 
 
294
 
295
  if (epoch + 1) % args.save_images_epochs == 0:
296
  generator = torch.Generator(
297
  device=clean_images.device).manual_seed(42)
298
  # run pipeline in inference (sample random noise and denoise)
299
- images, (sample_rate,
300
- audios) = pipeline(generator=generator,
301
- batch_size=args.eval_batch_size,
302
- return_dict=False)
 
 
303
 
304
  # denormalize the images and save to tensorboard
305
  images = np.array([
@@ -385,6 +430,12 @@ if __name__ == "__main__":
385
  default=None,
386
  help="pretrained VAE model for latent diffusion",
387
  )
388
 
389
  args = parser.parse_args()
390
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 
2
 
3
  import argparse
4
  import os
5
+ import pickle
6
+ import random
7
  from pathlib import Path
8
  from typing import Optional
9
 
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn.functional as F
13
  from accelerate import Accelerator
14
  from accelerate.logging import get_logger
15
+ from datasets import load_dataset, load_from_disk
16
+ from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler,
17
+ UNet2DConditionModel, UNet2DModel)
18
  from diffusers.optimization import get_scheduler
19
+ from diffusers.pipelines.audio_diffusion import Mel
20
  from diffusers.training_utils import EMAModel
21
  from huggingface_hub import HfFolder, Repository, whoami
22
  from librosa.util import normalize
23
+ from torchvision.transforms import Compose, Normalize, ToTensor
24
  from tqdm.auto import tqdm
25
 
26
+ from audiodiffusion.pipeline_audio_diffusion import AudioDiffusionPipeline
27
+
28
  logger = get_logger(__name__)
29
 
30
 
 
85
  ]
86
  else:
87
  images = [augmentations(image) for image in examples["image"]]
88
+ if args.encodings is not None:
89
+ encoding = [encodings[file] for file in examples["audio_file"]]
90
+ return {"input": images, "encoding": encoding}
91
  return {"input": images}
92
 
93
  dataset.set_transform(transforms)
94
  train_dataloader = torch.utils.data.DataLoader(
95
  dataset, batch_size=args.train_batch_size, shuffle=True)
96
 
97
+ if args.encodings is not None:
98
+ encodings = pickle.load(open(args.encodings, "rb"))
99
+
100
  vqvae = None
101
  if args.vae is not None:
102
  try:
 
105
  vqvae = AudioDiffusionPipeline.from_pretrained(args.vae).vqvae
106
  # Determine latent resolution
107
  with torch.no_grad():
108
+ latent_resolution = vqvae.encode(
109
  torch.zeros((1, 1) +
110
+ resolution)).latent_dist.sample().shape[2:]
111
 
112
  if args.from_pretrained is not None:
113
  pipeline = AudioDiffusionPipeline.from_pretrained(args.from_pretrained)
 
115
  model = pipeline.unet
116
  if hasattr(pipeline, "vqvae"):
117
  vqvae = pipeline.vqvae
118
+
119
  else:
120
+ if args.encodings is None:
121
+ model = UNet2DModel(
122
+ sample_size=resolution if vqvae is None else latent_resolution,
123
+ in_channels=1
124
+ if vqvae is None else vqvae.config["latent_channels"],
125
+ out_channels=1
126
+ if vqvae is None else vqvae.config["latent_channels"],
127
+ layers_per_block=2,
128
+ block_out_channels=(128, 128, 256, 256, 512, 512),
129
+ down_block_types=(
130
+ "DownBlock2D",
131
+ "DownBlock2D",
132
+ "DownBlock2D",
133
+ "DownBlock2D",
134
+ "AttnDownBlock2D",
135
+ "DownBlock2D",
136
+ ),
137
+ up_block_types=(
138
+ "UpBlock2D",
139
+ "AttnUpBlock2D",
140
+ "UpBlock2D",
141
+ "UpBlock2D",
142
+ "UpBlock2D",
143
+ "UpBlock2D",
144
+ ),
145
+ )
146
+
147
+ else:
148
+ model = UNet2DConditionModel(
149
+ sample_size=resolution if vqvae is None else latent_resolution,
150
+ in_channels=1
151
+ if vqvae is None else vqvae.config["latent_channels"],
152
+ out_channels=1
153
+ if vqvae is None else vqvae.config["latent_channels"],
154
+ layers_per_block=2,
155
+ block_out_channels=(128, 256, 512, 512),
156
+ down_block_types=(
157
+ "CrossAttnDownBlock2D",
158
+ "CrossAttnDownBlock2D",
159
+ "CrossAttnDownBlock2D",
160
+ "DownBlock2D",
161
+ ),
162
+ up_block_types=(
163
+ "UpBlock2D",
164
+ "CrossAttnUpBlock2D",
165
+ "CrossAttnUpBlock2D",
166
+ "CrossAttnUpBlock2D",
167
+ ),
168
+ cross_attention_dim=list(encodings.values())[0].shape[-1],
169
+ )
170
 
171
  if args.scheduler == "ddpm":
172
  noise_scheduler = DDPMScheduler(
 
267
 
268
  with accelerator.accumulate(model):
269
  # Predict the noise residual
270
+ if args.encodings is not None:
271
+ noise_pred = model(noisy_images, timesteps,
272
+ batch["encoding"])["sample"]
273
+ else:
274
+ noise_pred = model(noisy_images, timesteps)["sample"]
275
  loss = F.mse_loss(noise_pred, noise)
276
  accelerator.backward(loss)
277
 
 
301
 
302
  # Generate sample images for visual inspection
303
  if accelerator.is_main_process:
304
+ if ((epoch + 1) % args.save_model_epochs == 0
305
+ or (epoch + 1) % args.save_images_epochs == 0
306
+ or epoch == args.num_epochs - 1):
307
  pipeline = AudioDiffusionPipeline(
308
  vqvae=vqvae,
309
  unet=accelerator.unwrap_model(
 
319
 
320
  # save the model
321
  if args.push_to_hub:
322
+ repo.push_to_hub(
323
+ commit_message=f"Epoch {epoch}",
324
+ blocking=False,
325
+ auto_lfs_prune=True,
326
+ )
327
 
328
  if (epoch + 1) % args.save_images_epochs == 0:
329
  generator = torch.Generator(
330
  device=clean_images.device).manual_seed(42)
331
+
332
+ if args.encodings is not None:
333
+ random.seed(42)
334
+ encoding = torch.stack(
335
+ random.sample(list(encodings.values()),
336
+ args.eval_batch_size)).to(
337
+ clean_images.device)
338
+ else:
339
+ encoding = None
340
+
341
  # run pipeline in inference (sample random noise and denoise)
342
+ images, (sample_rate, audios) = pipeline(
343
+ generator=generator,
344
+ batch_size=args.eval_batch_size,
345
+ return_dict=False,
346
+ encoding=encoding,
347
+ )
348
 
349
  # denormalize the images and save to tensorboard
350
  images = np.array([
 
430
  default=None,
431
  help="pretrained VAE model for latent diffusion",
432
  )
433
+ parser.add_argument(
434
+ "--encodings",
435
+ type=str,
436
+ default=None,
437
+ help="picked dictionary mapping audio_file to encoding",
438
+ )
439
 
440
  args = parser.parse_args()
441
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
scripts/train_vae.py CHANGED
@@ -1,50 +1,48 @@
1
  # based on https://github.com/CompVis/stable-diffusion/blob/main/main.py
2
 
3
- import os
4
  import argparse
 
5
 
6
- import torch
7
- import torchvision
8
  import numpy as np
9
- from PIL import Image
10
  import pytorch_lightning as pl
11
- from omegaconf import OmegaConf
12
- from librosa.util import normalize
 
 
13
  from ldm.util import instantiate_from_config
 
 
 
 
14
  from pytorch_lightning.trainer import Trainer
 
15
  from torch.utils.data import DataLoader, Dataset
16
- from datasets import load_from_disk, load_dataset
17
- from diffusers.pipelines.audio_diffusion import Mel
18
  from audiodiffusion.utils import convert_ldm_to_hf_vae
19
- from pytorch_lightning.callbacks import Callback, ModelCheckpoint
20
- from pytorch_lightning.utilities.distributed import rank_zero_only
21
 
22
 
23
  class AudioDiffusion(Dataset):
24
-
25
  def __init__(self, model_id, channels=3):
26
  super().__init__()
27
  self.channels = channels
28
  if os.path.exists(model_id):
29
- self.hf_dataset = load_from_disk(model_id)['train']
30
  else:
31
- self.hf_dataset = load_dataset(model_id)['train']
32
 
33
  def __len__(self):
34
  return len(self.hf_dataset)
35
 
36
  def __getitem__(self, idx):
37
- image = self.hf_dataset[idx]['image']
38
  if self.channels == 3:
39
- image = image.convert('RGB')
40
- image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
41
- (image.height, image.width, self.channels))
42
- image = ((image / 255) * 2 - 1)
43
- return {'image': image}
44
 
45
 
46
  class AudioDiffusionDataModule(pl.LightningDataModule):
47
-
48
  def __init__(self, model_id, batch_size, channels):
49
  super().__init__()
50
  self.batch_size = batch_size
@@ -52,18 +50,11 @@ class AudioDiffusionDataModule(pl.LightningDataModule):
52
  self.num_workers = 1
53
 
54
  def train_dataloader(self):
55
- return DataLoader(self.dataset,
56
- batch_size=self.batch_size,
57
- num_workers=self.num_workers)
58
 
59
 
60
  class ImageLogger(Callback):
61
-
62
- def __init__(self,
63
- every=1000,
64
- hop_length=512,
65
- sample_rate=22050,
66
- n_fft=2048):
67
  super().__init__()
68
  self.every = every
69
  self.hop_length = hop_length
@@ -74,83 +65,75 @@ class ImageLogger(Callback):
74
  def log_images_and_audios(self, pl_module, batch):
75
  pl_module.eval()
76
  with torch.no_grad():
77
- images = pl_module.log_images(batch, split='train')
78
  pl_module.train()
79
 
80
  image_shape = next(iter(images.values())).shape
81
  channels = image_shape[1]
82
- mel = Mel(x_res=image_shape[2],
83
- y_res=image_shape[3],
84
- hop_length=self.hop_length,
85
- sample_rate=self.sample_rate,
86
- n_fft=self.n_fft)
 
 
87
 
88
  for k in images:
89
  images[k] = images[k].detach().cpu()
90
- images[k] = torch.clamp(images[k], -1., 1.)
91
  images[k] = (images[k] + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
92
  grid = torchvision.utils.make_grid(images[k])
93
 
94
  tag = f"train/{k}"
95
- pl_module.logger.experiment.add_image(
96
- tag, grid, global_step=pl_module.global_step)
97
 
98
- images[k] = (images[k].numpy() *
99
- 255).round().astype("uint8").transpose(0, 2, 3, 1)
100
  for _, image in enumerate(images[k]):
101
  audio = mel.image_to_audio(
102
- Image.fromarray(image, mode='RGB').convert('L')
103
- if channels == 3 else Image.fromarray(image[:, :, 0]))
 
 
104
  pl_module.logger.experiment.add_audio(
105
  tag + f"/{_}",
106
  normalize(audio),
107
  global_step=pl_module.global_step,
108
- sample_rate=mel.get_sample_rate())
 
109
 
110
- def on_train_batch_end(self, trainer, pl_module, outputs, batch,
111
- batch_idx):
112
  if (batch_idx + 1) % self.every != 0:
113
  return
114
  self.log_images_and_audios(pl_module, batch)
115
 
116
 
117
  class HFModelCheckpoint(ModelCheckpoint):
118
-
119
  def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
120
  super().__init__(*args, **kwargs)
121
  self.ldm_config = ldm_config
122
  self.hf_checkpoint = hf_checkpoint
 
 
 
 
 
123
 
124
  def on_train_epoch_end(self, trainer, pl_module):
125
- ldm_checkpoint = self._get_metric_interpolated_filepath_name(
126
- {'epoch': trainer.current_epoch}, trainer)
127
  super().on_train_epoch_end(trainer, pl_module)
128
- convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config,
129
- self.hf_checkpoint)
130
 
131
 
132
  if __name__ == "__main__":
133
  parser = argparse.ArgumentParser(description="Train VAE using ldm.")
134
  parser.add_argument("-d", "--dataset_name", type=str, default=None)
135
  parser.add_argument("-b", "--batch_size", type=int, default=1)
136
- parser.add_argument("-c",
137
- "--ldm_config_file",
138
- type=str,
139
- default="config/ldm_autoencoder_kl.yaml")
140
- parser.add_argument("--ldm_checkpoint_dir",
141
- type=str,
142
- default="models/ldm-autoencoder-kl")
143
- parser.add_argument("--hf_checkpoint_dir",
144
- type=str,
145
- default="models/autoencoder-kl")
146
- parser.add_argument("-r",
147
- "--resume_from_checkpoint",
148
- type=str,
149
- default=None)
150
- parser.add_argument("-g",
151
- "--gradient_accumulation_steps",
152
- type=int,
153
- default=1)
154
  parser.add_argument("--hop_length", type=int, default=512)
155
  parser.add_argument("--sample_rate", type=int, default=22050)
156
  parser.add_argument("--n_fft", type=int, default=2048)
@@ -164,7 +147,8 @@ if __name__ == "__main__":
164
  data = AudioDiffusionDataModule(
165
  model_id=args.dataset_name,
166
  batch_size=args.batch_size,
167
- channels=config.model.params.ddconfig.in_channels)
 
168
  lightning_config = config.pop("lightning", OmegaConf.create())
169
  trainer_config = lightning_config.get("trainer", OmegaConf.create())
170
  trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
@@ -174,15 +158,20 @@ if __name__ == "__main__":
174
  max_epochs=args.max_epochs,
175
  resume_from_checkpoint=args.resume_from_checkpoint,
176
  callbacks=[
177
- ImageLogger(every=args.save_images_batches,
178
- hop_length=args.hop_length,
179
- sample_rate=args.sample_rate,
180
- n_fft=args.n_fft),
181
- HFModelCheckpoint(ldm_config=config,
182
- hf_checkpoint=args.hf_checkpoint_dir,
183
- dirpath=args.ldm_checkpoint_dir,
184
- filename='{epoch:06}',
185
- verbose=True,
186
- save_last=True)
187
- ])
  trainer.fit(model, data)
 
1
  # based on https://github.com/CompVis/stable-diffusion/blob/main/main.py
2
 
 
3
  import argparse
4
+ import os
5
 
 
 
6
  import numpy as np
 
7
  import pytorch_lightning as pl
8
+ import torch
9
+ import torchvision
10
+ from datasets import load_dataset, load_from_disk
11
+ from diffusers.pipelines.audio_diffusion import Mel
12
  from ldm.util import instantiate_from_config
13
+ from librosa.util import normalize
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from pytorch_lightning.callbacks import Callback, ModelCheckpoint
17
  from pytorch_lightning.trainer import Trainer
18
+ from pytorch_lightning.utilities.distributed import rank_zero_only
19
  from torch.utils.data import DataLoader, Dataset
20
+
 
21
  from audiodiffusion.utils import convert_ldm_to_hf_vae
 
 
22
 
23
 
24
  class AudioDiffusion(Dataset):
 
25
  def __init__(self, model_id, channels=3):
26
  super().__init__()
27
  self.channels = channels
28
  if os.path.exists(model_id):
29
+ self.hf_dataset = load_from_disk(model_id)["train"]
30
  else:
31
+ self.hf_dataset = load_dataset(model_id)["train"]
32
 
33
  def __len__(self):
34
  return len(self.hf_dataset)
35
 
36
  def __getitem__(self, idx):
37
+ image = self.hf_dataset[idx]["image"]
38
  if self.channels == 3:
39
+ image = image.convert("RGB")
40
+ image = np.frombuffer(image.tobytes(), dtype="uint8").reshape((image.height, image.width, self.channels))
41
+ image = (image / 255) * 2 - 1
42
+ return {"image": image}
 
43
 
44
 
45
  class AudioDiffusionDataModule(pl.LightningDataModule):
 
46
  def __init__(self, model_id, batch_size, channels):
47
  super().__init__()
48
  self.batch_size = batch_size
 
50
  self.num_workers = 1
51
 
52
  def train_dataloader(self):
53
+ return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=self.num_workers)
 
 
54
 
55
 
56
  class ImageLogger(Callback):
57
+ def __init__(self, every=1000, hop_length=512, sample_rate=22050, n_fft=2048):
  super().__init__()
59
  self.every = every
60
  self.hop_length = hop_length
 
65
  def log_images_and_audios(self, pl_module, batch):
66
  pl_module.eval()
67
  with torch.no_grad():
68
+ images = pl_module.log_images(batch, split="train")
69
  pl_module.train()
70
 
71
  image_shape = next(iter(images.values())).shape
72
  channels = image_shape[1]
73
+ mel = Mel(
74
+ x_res=image_shape[2],
75
+ y_res=image_shape[3],
76
+ hop_length=self.hop_length,
77
+ sample_rate=self.sample_rate,
78
+ n_fft=self.n_fft,
79
+ )
80
 
81
  for k in images:
82
  images[k] = images[k].detach().cpu()
83
+ images[k] = torch.clamp(images[k], -1.0, 1.0)
84
  images[k] = (images[k] + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
85
  grid = torchvision.utils.make_grid(images[k])
86
 
87
  tag = f"train/{k}"
88
+ pl_module.logger.experiment.add_image(tag, grid, global_step=pl_module.global_step)
 
89
 
90
+ images[k] = (images[k].numpy() * 255).round().astype("uint8").transpose(0, 2, 3, 1)
 
91
  for _, image in enumerate(images[k]):
92
  audio = mel.image_to_audio(
93
+ Image.fromarray(image, mode="RGB").convert("L")
94
+ if channels == 3
95
+ else Image.fromarray(image[:, :, 0])
96
+ )
97
  pl_module.logger.experiment.add_audio(
98
  tag + f"/{_}",
99
  normalize(audio),
100
  global_step=pl_module.global_step,
101
+ sample_rate=mel.get_sample_rate(),
102
+ )
103
 
104
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
 
105
  if (batch_idx + 1) % self.every != 0:
106
  return
107
  self.log_images_and_audios(pl_module, batch)
108
 
109
 
110
  class HFModelCheckpoint(ModelCheckpoint):
 
111
  def __init__(self, ldm_config, hf_checkpoint, *args, **kwargs):
112
  super().__init__(*args, **kwargs)
113
  self.ldm_config = ldm_config
114
  self.hf_checkpoint = hf_checkpoint
115
+ self.sample_size = None
116
+
117
+ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
118
+ if self.sample_size is None:
119
+ self.sample_size = list(batch["image"].shape[1:3])
120
 
121
  def on_train_epoch_end(self, trainer, pl_module):
122
+ ldm_checkpoint = self._get_metric_interpolated_filepath_name({"epoch": trainer.current_epoch}, trainer)
 
123
  super().on_train_epoch_end(trainer, pl_module)
124
+ self.ldm_config.model.params.ddconfig.resolution = self.sample_size
125
+ convert_ldm_to_hf_vae(ldm_checkpoint, self.ldm_config, self.hf_checkpoint, self.sample_size)
126
 
127
 
128
  if __name__ == "__main__":
129
  parser = argparse.ArgumentParser(description="Train VAE using ldm.")
130
  parser.add_argument("-d", "--dataset_name", type=str, default=None)
131
  parser.add_argument("-b", "--batch_size", type=int, default=1)
132
+ parser.add_argument("-c", "--ldm_config_file", type=str, default="config/ldm_autoencoder_kl.yaml")
133
+ parser.add_argument("--ldm_checkpoint_dir", type=str, default="models/ldm-autoencoder-kl")
134
+ parser.add_argument("--hf_checkpoint_dir", type=str, default="models/autoencoder-kl")
135
+ parser.add_argument("-r", "--resume_from_checkpoint", type=str, default=None)
136
+ parser.add_argument("-g", "--gradient_accumulation_steps", type=int, default=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  parser.add_argument("--hop_length", type=int, default=512)
138
  parser.add_argument("--sample_rate", type=int, default=22050)
139
  parser.add_argument("--n_fft", type=int, default=2048)
 
147
  data = AudioDiffusionDataModule(
148
  model_id=args.dataset_name,
149
  batch_size=args.batch_size,
150
+ channels=config.model.params.ddconfig.in_channels,
151
+ )
152
  lightning_config = config.pop("lightning", OmegaConf.create())
153
  trainer_config = lightning_config.get("trainer", OmegaConf.create())
154
  trainer_config.accumulate_grad_batches = args.gradient_accumulation_steps
 
158
  max_epochs=args.max_epochs,
159
  resume_from_checkpoint=args.resume_from_checkpoint,
160
  callbacks=[
161
+ ImageLogger(
162
+ every=args.save_images_batches,
163
+ hop_length=args.hop_length,
164
+ sample_rate=args.sample_rate,
165
+ n_fft=args.n_fft,
166
+ ),
167
+ HFModelCheckpoint(
168
+ ldm_config=config,
169
+ hf_checkpoint=args.hf_checkpoint_dir,
170
+ dirpath=args.ldm_checkpoint_dir,
171
+ filename="{epoch:06}",
172
+ verbose=True,
173
+ save_last=True,
174
+ ),
175
+ ],
176
+ )
177
  trainer.fit(model, data)