teticio committed
Commit f34a81b
1 parent: d787bde

added encode and slerp

Files changed (1)
  1. audiodiffusion/__init__.py +85 -17
audiodiffusion/__init__.py CHANGED
@@ -1,3 +1,4 @@
+from math import acos, sin
 from typing import Iterable, Tuple, Union, List
 
 import torch
@@ -86,17 +87,19 @@ class AudioDiffusion:
         return images[0], (sample_rate, audios[0])
 
     def generate_spectrogram_and_audio_from_audio(
-            self,
-            audio_file: str = None,
-            raw_audio: np.ndarray = None,
-            slice: int = 0,
-            start_step: int = 0,
-            steps: int = 1000,
-            generator: torch.Generator = None,
-            mask_start_secs: float = 0,
-            mask_end_secs: float = 0,
-            step_generator: torch.Generator = None,
-            eta: float = 0) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
+            self,
+            audio_file: str = None,
+            raw_audio: np.ndarray = None,
+            slice: int = 0,
+            start_step: int = 0,
+            steps: int = 1000,
+            generator: torch.Generator = None,
+            mask_start_secs: float = 0,
+            mask_end_secs: float = 0,
+            step_generator: torch.Generator = None,
+            eta: float = 0,
+            noise: torch.Tensor = None
+    ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
         """Generate random mel spectrogram from audio input and convert to audio.
 
         Args:
@@ -110,6 +113,7 @@ class AudioDiffusion:
             mask_end_secs (float): number of seconds of audio to mask (not generate) at end
             step_generator (torch.Generator): random number generator used to denoise or None
             eta (float): parameter between 0 and 1 used with DDIM scheduler
+            noise (torch.Tensor): noisy image or None
 
         Returns:
             PIL Image: mel spectrogram
@@ -128,7 +132,8 @@ class AudioDiffusion:
             mask_start_secs=mask_start_secs,
             mask_end_secs=mask_end_secs,
             step_generator=step_generator,
-            eta=eta)
+            eta=eta,
+            noise=noise)
         return images[0], (sample_rate, audios[0])
 
     @staticmethod
@@ -173,7 +178,8 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             mask_start_secs: float = 0,
             mask_end_secs: float = 0,
             step_generator: torch.Generator = None,
-            eta: float = 0
+            eta: float = 0,
+            noise: torch.Tensor = None
     ) -> Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]:
         """Generate random mel spectrogram from audio input and convert to audio.
 
@@ -190,6 +196,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             mask_end_secs (float): number of seconds of audio to mask (not generate) at end
             step_generator (torch.Generator): random number generator used to denoise or None
             eta (float): parameter between 0 and 1 used with DDIM scheduler
+            noise (torch.Tensor): noisy image or None
 
         Returns:
             List[PIL Image]: mel spectrograms
@@ -201,10 +208,13 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         mask = None
         # For backwards compatibility
        if type(self.unet.sample_size) == int:
-            self.unet.sample_size = (self.unet.sample_size, self.unet.sample_size)
-        images = noise = torch.randn(
-            (batch_size, self.unet.in_channels) + self.unet.sample_size,
-            generator=generator)
+            self.unet.sample_size = (self.unet.sample_size,
+                                     self.unet.sample_size)
+        if noise is None:
+            noise = torch.randn(
+                (batch_size, self.unet.in_channels) + self.unet.sample_size,
+                generator=generator)
+        images = noise
 
         if audio_file is not None or raw_audio is not None:
             mel.load_audio(audio_file, raw_audio)
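
Usage sketch (not part of the commit): with noise now an explicit argument, the starting latent can be fixed for reproducible generation. The model id, input path and the (1, 1, 256, 256) shape below are assumptions for a mono 256x256 checkpoint, not anything this diff confirms.

import torch
from audiodiffusion import AudioDiffusion

# model id, input path and noise shape are illustrative assumptions
audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-256")
noise = torch.randn((1, 1, 256, 256),
                    generator=torch.Generator().manual_seed(42))
image, (sample_rate, audio) = \
    audio_diffusion.generate_spectrogram_and_audio_from_audio(
        audio_file="input.wav",  # placeholder input file
        noise=noise)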
 
@@ -278,6 +288,64 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         audios = list(map(lambda _: mel.image_to_audio(_), images))
         return images, (mel.get_sample_rate(), audios)
 
+    @torch.no_grad()
+    def encode(self, images: List[Image.Image]) -> torch.Tensor:
+        """Reverse step process: recover noisy image from generated image.
+
+        Args:
+            images (List[PIL Image]): list of images to encode
+
+        Returns:
+            torch.Tensor: noise tensor of shape (batch_size, 1, height, width)
+        """
+
+        # Only works with DDIM as this method is deterministic
+        assert isinstance(self.scheduler, DDIMScheduler)
+        sample = np.array([
+            np.frombuffer(image.tobytes(), dtype="uint8").reshape(
+                (1, image.height, image.width)) for image in images
+        ])
+        sample = (sample / 255) * 2 - 1
+        sample = torch.Tensor(sample).to(self.device)
+
+        for t in torch.flip(self.scheduler.timesteps, (0, )):
+            prev_timestep = (t - self.scheduler.num_train_timesteps //
+                             self.scheduler.num_inference_steps)
+            alpha_prod_t = self.scheduler.alphas_cumprod[t]
+            alpha_prod_t_prev = (self.scheduler.alphas_cumprod[prev_timestep]
+                                 if prev_timestep >= 0 else
+                                 self.scheduler.final_alpha_cumprod)
+            beta_prod_t = 1 - alpha_prod_t
+            model_output = self.unet(sample, t)['sample']
+            pred_sample_direction = (1 -
+                                     alpha_prod_t_prev)**(0.5) * model_output
+            sample = (sample -
+                      pred_sample_direction) * alpha_prod_t_prev**(-0.5)
+            sample = sample * alpha_prod_t**(0.5) + beta_prod_t**(
+                0.5) * model_output
+
+        return sample
+
+    @staticmethod
+    def slerp(x0: torch.Tensor, x1: torch.Tensor,
+              alpha: float) -> torch.Tensor:
+        """Spherical Linear intERPolation
+
+        Args:
+            x0 (torch.Tensor): first tensor to interpolate between
+            x1 (torch.Tensor): second tensor to interpolate between
+            alpha (float): interpolation between 0 and 1
+
+        Returns:
+            torch.Tensor: interpolated tensor
+        """
+
+        theta = acos(
+            torch.dot(torch.flatten(x0), torch.flatten(x1)) / torch.norm(x0) /
+            torch.norm(x1))
+        return sin((1 - alpha) * theta) * x0 / sin(theta) + sin(
+            alpha * theta) * x1 / sin(theta)
+
 
 class LatentAudioDiffusionPipeline(AudioDiffusionPipeline):
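
Usage sketch (not part of the commit): encode deterministically reverses the DDIM sampling loop, mapping a generated spectrogram back to its starting noise. The sketch assumes the scheduler's timesteps have been set, that the wrapper exposes the pipeline as .pipe, and an illustrative DDIM model id.

from diffusers import DDIMScheduler
from audiodiffusion import AudioDiffusion

# illustrative model id; encode() asserts a DDIMScheduler
audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-ddim-256")
pipe = audio_diffusion.pipe  # assumed attribute holding the pipeline
assert isinstance(pipe.scheduler, DDIMScheduler)

image, (sample_rate, audio) = audio_diffusion.generate_spectrogram_and_audio()

pipe.scheduler.set_timesteps(50)  # encode() iterates scheduler.timesteps
noise = pipe.encode([image])      # tensor of shape (1, 1, height, width)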
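
Usage sketch (not part of the commit): encode and slerp together enable interpolation between two sounds. slerp computes sin((1 - alpha) * theta) / sin(theta) * x0 + sin(alpha * theta) / sin(theta) * x1, where theta is the angle between the flattened tensors. The model id, file names, .pipe and .mel attributes, and the pipeline keyword arguments (whose full signature is abbreviated in this diff) are all assumptions.

from audiodiffusion import AudioDiffusion, AudioDiffusionPipeline

# illustrative DDIM checkpoint; encode() requires a DDIMScheduler
audio_diffusion = AudioDiffusion(model_id="teticio/audio-diffusion-ddim-256")
pipe = audio_diffusion.pipe  # assumed attribute holding the pipeline
mel = audio_diffusion.mel    # assumed attribute wrapping the Mel converter

# spectrogram image for the first slice of each clip (placeholder names)
mel.load_audio("clip_a.wav")
image1 = mel.audio_slice_to_image(0)
mel.load_audio("clip_b.wav")
image2 = mel.audio_slice_to_image(0)

pipe.scheduler.set_timesteps(50)
noise1 = pipe.encode([image1])  # deterministic DDIM reversal
noise2 = pipe.encode([image2])

# midpoint on the great circle between the two noise tensors
noise = AudioDiffusionPipeline.slerp(noise1, noise2, 0.5)

# regenerate audio from the interpolated noise; keyword names assumed
images, (sample_rate, audios) = pipe(mel=mel, steps=50, noise=noise)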