teticio committed on
Commit
62617b3
1 Parent(s): 37c17e0

fix steps for DDIM

Files changed (1)
  1. audiodiffusion/__init__.py +5 -8
audiodiffusion/__init__.py CHANGED

@@ -60,7 +60,7 @@ class AudioDiffusion:
 
     def generate_spectrogram_and_audio(
             self,
-            steps: int = None,
+            steps: int = 1000,
             generator: torch.Generator = None
     ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
         """Generate random mel spectrogram and convert to audio.
@@ -85,7 +85,7 @@ class AudioDiffusion:
             raw_audio: np.ndarray = None,
             slice: int = 0,
             start_step: int = 0,
-            steps: int = None,
+            steps: int = 1000,
             generator: torch.Generator = None,
             mask_start_secs: float = 0,
             mask_end_secs: float = 0
@@ -157,7 +157,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             raw_audio: np.ndarray = None,
             slice: int = 0,
             start_step: int = 0,
-            steps: int = None,
+            steps: int = 1000,
             generator: torch.Generator = None,
             mask_start_secs: float = 0,
             mask_end_secs: float = 0
@@ -181,8 +181,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
             (float, List[np.ndarray]): sample rate and raw audios
         """
 
-        if steps is not None:
-            self.scheduler.set_timesteps(steps)
+        self.scheduler.set_timesteps(steps)
         mask = None
         images = noise = torch.randn(
             (batch_size, self.unet.in_channels, mel.y_res, mel.x_res),
@@ -206,9 +205,7 @@ class AudioDiffusionPipeline(DiffusionPipeline):
         if start_step > 0:
             images[0, 0] = self.scheduler.add_noise(
                 torch.tensor(input_images[:, np.newaxis, np.newaxis, :]),
-                noise,
-                torch.tensor(self.scheduler.num_train_timesteps -
-                             start_step))
+                noise, torch.tensor(steps - start_step))
 
         pixels_per_second = (mel.get_sample_rate() / mel.hop_length)
         mask_start = int(mask_start_secs * pixels_per_second)
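
For context on the change: with a DDIM scheduler the denoising schedule is shortened to `steps` inference steps by `set_timesteps(steps)`, so an offset into that schedule has to be counted in inference steps rather than training timesteps. The sketch below illustrates that accounting with the diffusers DDIMScheduler used on its own; the step counts, tensor shapes, and standalone scheduler usage are illustrative assumptions, not code from this repository.

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000)
steps = 50        # reduced number of inference steps (assumed value)
start_step = 10   # begin part-way through the shortened schedule (assumed value)

scheduler.set_timesteps(steps)   # scheduler.timesteps now holds `steps` entries

sample = torch.randn(1, 1, 256, 256)   # stand-in for a mel-spectrogram image
noise = torch.randn_like(sample)

# Noise the input to the starting level, mirroring the diff's change from
# `num_train_timesteps - start_step` to `steps - start_step`.
noisy = scheduler.add_noise(sample, noise, torch.tensor(steps - start_step))

# The denoising loop would then skip the first `start_step` entries of the schedule.
for t in scheduler.timesteps[start_step:]:
    pass  # UNet prediction and scheduler.step(...) would go here

Under the old default of `steps=None`, `set_timesteps` was never called; the new default of 1000 (the typical training schedule length) preserves the full-length DDPM behaviour while letting DDIM callers request fewer steps.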