michaelj commited on
Commit
99631a2
·
verified ·
1 Parent(s): c883d1e

Upload scheduling_tcd.py

Browse files
Files changed (1) hide show
  1. scheduling_tcd.py +657 -0
scheduling_tcd.py ADDED
@@ -0,0 +1,657 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class TCDSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
42
+ denoising loop.
43
+ pred_noised_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
44
+ The predicted noised sample `(x_{s})` based on the model output from the current timestep.
45
+ """
46
+
47
+ prev_sample: torch.FloatTensor
48
+ pred_noised_sample: Optional[torch.FloatTensor] = None
49
+
50
+
51
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
52
+ def betas_for_alpha_bar(
53
+ num_diffusion_timesteps,
54
+ max_beta=0.999,
55
+ alpha_transform_type="cosine",
56
+ ):
57
+ """
58
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
59
+ (1-beta) over time from t = [0,1].
60
+
61
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
62
+ to that part of the diffusion process.
63
+
64
+
65
+ Args:
66
+ num_diffusion_timesteps (`int`): the number of betas to produce.
67
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
68
+ prevent singularities.
69
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
70
+ Choose from `cosine` or `exp`
71
+
72
+ Returns:
73
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
74
+ """
75
+ if alpha_transform_type == "cosine":
76
+
77
+ def alpha_bar_fn(t):
78
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
79
+
80
+ elif alpha_transform_type == "exp":
81
+
82
+ def alpha_bar_fn(t):
83
+ return math.exp(t * -12.0)
84
+
85
+ else:
86
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
87
+
88
+ betas = []
89
+ for i in range(num_diffusion_timesteps):
90
+ t1 = i / num_diffusion_timesteps
91
+ t2 = (i + 1) / num_diffusion_timesteps
92
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
93
+ return torch.tensor(betas, dtype=torch.float32)
94
+
95
+
96
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
97
+ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
98
+ """
99
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
100
+
101
+
102
+ Args:
103
+ betas (`torch.FloatTensor`):
104
+ the betas that the scheduler is being initialized with.
105
+
106
+ Returns:
107
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
108
+ """
109
+ # Convert betas to alphas_bar_sqrt
110
+ alphas = 1.0 - betas
111
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
112
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
113
+
114
+ # Store old values.
115
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
116
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
117
+
118
+ # Shift so the last timestep is zero.
119
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
120
+
121
+ # Scale so the first timestep is back to the old value.
122
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
123
+
124
+ # Convert alphas_bar_sqrt to betas
125
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
126
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
127
+ alphas = torch.cat([alphas_bar[0:1], alphas])
128
+ betas = 1 - alphas
129
+
130
+ return betas
131
+
132
+
133
+ class TCDScheduler(SchedulerMixin, ConfigMixin):
134
+ """
135
+ `TCDScheduler` incorporates the `Strategic Stochastic Sampling` introduced by the paper `Trajectory Consistency Distillation`,
136
+ extending the original Multistep Consistency Sampling to enable unrestricted trajectory traversal.
137
+
138
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
139
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
140
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
141
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
142
+
143
+ Args:
144
+ num_train_timesteps (`int`, defaults to 1000):
145
+ The number of diffusion steps to train the model.
146
+ beta_start (`float`, defaults to 0.0001):
147
+ The starting `beta` value of inference.
148
+ beta_end (`float`, defaults to 0.02):
149
+ The final `beta` value.
150
+ beta_schedule (`str`, defaults to `"linear"`):
151
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
152
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
153
+ trained_betas (`np.ndarray`, *optional*):
154
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
155
+ original_inference_steps (`int`, *optional*, defaults to 50):
156
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
157
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
158
+ clip_sample (`bool`, defaults to `True`):
159
+ Clip the predicted sample for numerical stability.
160
+ clip_sample_range (`float`, defaults to 1.0):
161
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
162
+ set_alpha_to_one (`bool`, defaults to `True`):
163
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
164
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
165
+ otherwise it uses the alpha value at step 0.
166
+ steps_offset (`int`, defaults to 0):
167
+ An offset added to the inference steps. You can use a combination of `offset=1` and
168
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
169
+ Diffusion.
170
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
171
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
172
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
173
+ Video](https://imagen.research.google/video/paper.pdf) paper).
174
+ thresholding (`bool`, defaults to `False`):
175
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
176
+ as Stable Diffusion.
177
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
178
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
179
+ sample_max_value (`float`, defaults to 1.0):
180
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
181
+ timestep_spacing (`str`, defaults to `"leading"`):
182
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
183
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
184
+ timestep_scaling (`float`, defaults to 10.0):
185
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
186
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
187
+ error at the default of `10.0` is already pretty small).
188
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
189
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
190
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
191
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
192
+ """
193
+
194
+ order = 1
195
+
196
+ @register_to_config
197
+ def __init__(
198
+ self,
199
+ num_train_timesteps: int = 1000,
200
+ beta_start: float = 0.00085,
201
+ beta_end: float = 0.012,
202
+ beta_schedule: str = "scaled_linear",
203
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
204
+ original_inference_steps: int = 50,
205
+ clip_sample: bool = False,
206
+ clip_sample_range: float = 1.0,
207
+ set_alpha_to_one: bool = True,
208
+ steps_offset: int = 0,
209
+ prediction_type: str = "epsilon",
210
+ thresholding: bool = False,
211
+ dynamic_thresholding_ratio: float = 0.995,
212
+ sample_max_value: float = 1.0,
213
+ timestep_spacing: str = "leading",
214
+ timestep_scaling: float = 10.0,
215
+ rescale_betas_zero_snr: bool = False,
216
+ ):
217
+ if trained_betas is not None:
218
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
219
+ elif beta_schedule == "linear":
220
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
221
+ elif beta_schedule == "scaled_linear":
222
+ # this schedule is very specific to the latent diffusion model.
223
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
224
+ elif beta_schedule == "squaredcos_cap_v2":
225
+ # Glide cosine schedule
226
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
227
+ else:
228
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
229
+
230
+ # Rescale for zero SNR
231
+ if rescale_betas_zero_snr:
232
+ self.betas = rescale_zero_terminal_snr(self.betas)
233
+
234
+ self.alphas = 1.0 - self.betas
235
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
236
+
237
+ # At every step in ddim, we are looking into the previous alphas_cumprod
238
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
239
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
240
+ # whether we use the final alpha of the "non-previous" one.
241
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
242
+
243
+ # standard deviation of the initial noise distribution
244
+ self.init_noise_sigma = 1.0
245
+
246
+ # setable values
247
+ self.num_inference_steps = None
248
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
249
+ self.custom_timesteps = False
250
+
251
+ self._step_index = None
252
+
253
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
254
+ def _init_step_index(self, timestep):
255
+ if isinstance(timestep, torch.Tensor):
256
+ timestep = timestep.to(self.timesteps.device)
257
+
258
+ index_candidates = (self.timesteps == timestep).nonzero()
259
+
260
+ # The sigma index that is taken for the **very** first `step`
261
+ # is always the second index (or the last index if there is only 1)
262
+ # This way we can ensure we don't accidentally skip a sigma in
263
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
264
+ if len(index_candidates) > 1:
265
+ step_index = index_candidates[1]
266
+ else:
267
+ step_index = index_candidates[0]
268
+
269
+ self._step_index = step_index.item()
270
+
271
+ @property
272
+ def step_index(self):
273
+ return self._step_index
274
+
275
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
276
+ """
277
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
278
+ current timestep.
279
+
280
+ Args:
281
+ sample (`torch.FloatTensor`):
282
+ The input sample.
283
+ timestep (`int`, *optional*):
284
+ The current timestep in the diffusion chain.
285
+ Returns:
286
+ `torch.FloatTensor`:
287
+ A scaled input sample.
288
+ """
289
+ return sample
290
+
291
+ def _get_variance(self, timestep, prev_timestep):
292
+ alpha_prod_t = self.alphas_cumprod[timestep]
293
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
294
+ beta_prod_t = 1 - alpha_prod_t
295
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
296
+
297
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
298
+
299
+ return variance
300
+
301
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
302
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
303
+ """
304
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
305
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
306
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
307
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
308
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
309
+
310
+ https://arxiv.org/abs/2205.11487
311
+ """
312
+ dtype = sample.dtype
313
+ batch_size, channels, *remaining_dims = sample.shape
314
+
315
+ if dtype not in (torch.float32, torch.float64):
316
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
317
+
318
+ # Flatten sample for doing quantile calculation along each image
319
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
320
+
321
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
322
+
323
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
324
+ s = torch.clamp(
325
+ s, min=1, max=self.config.sample_max_value
326
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
327
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
328
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
329
+
330
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
331
+ sample = sample.to(dtype)
332
+
333
+ return sample
334
+
335
+ def set_timesteps(
336
+ self,
337
+ num_inference_steps: Optional[int] = None,
338
+ device: Union[str, torch.device] = None,
339
+ original_inference_steps: Optional[int] = None,
340
+ timesteps: Optional[List[int]] = None,
341
+ strength: int = 1.0,
342
+ ):
343
+ """
344
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
345
+
346
+ Args:
347
+ num_inference_steps (`int`, *optional*):
348
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
349
+ `timesteps` must be `None`.
350
+ device (`str` or `torch.device`, *optional*):
351
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
352
+ original_inference_steps (`int`, *optional*):
353
+ The original number of inference steps, which will be used to generate a linearly-spaced timestep
354
+ schedule (which is different from the standard `diffusers` implementation). We will then take
355
+ `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
356
+ our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
357
+ timesteps (`List[int]`, *optional*):
358
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
359
+ timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
360
+ schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
361
+ """
362
+ # 0. Check inputs
363
+ if num_inference_steps is None and timesteps is None:
364
+ raise ValueError("Must pass exactly one of `num_inference_steps` or `custom_timesteps`.")
365
+
366
+ if num_inference_steps is not None and timesteps is not None:
367
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
368
+
369
+ # 1. Calculate the TCD original training/distillation timestep schedule.
370
+ original_steps = (
371
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
372
+ )
373
+
374
+ if original_steps is not None:
375
+ if original_steps > self.config.num_train_timesteps:
376
+ raise ValueError(
377
+ f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
378
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
379
+ f" maximal {self.config.num_train_timesteps} timesteps."
380
+ )
381
+ # TCD Timesteps Setting
382
+ # The skipping step parameter k from the paper.
383
+ k = self.config.num_train_timesteps // original_steps
384
+ # TCD Training/Distillation Steps Schedule
385
+ tcd_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * k - 1
386
+ else:
387
+ tcd_origin_timesteps = np.asarray(list(range(0, int(self.config.num_train_timesteps * strength))))
388
+
389
+ # 2. Calculate the TCD inference timestep schedule.
390
+ if timesteps is not None:
391
+ # 2.1 Handle custom timestep schedules.
392
+ train_timesteps = set(tcd_origin_timesteps)
393
+ non_train_timesteps = []
394
+ for i in range(1, len(timesteps)):
395
+ if timesteps[i] >= timesteps[i - 1]:
396
+ raise ValueError("`custom_timesteps` must be in descending order.")
397
+
398
+ if timesteps[i] not in train_timesteps:
399
+ non_train_timesteps.append(timesteps[i])
400
+
401
+ if timesteps[0] >= self.config.num_train_timesteps:
402
+ raise ValueError(
403
+ f"`timesteps` must start before `self.config.train_timesteps`:"
404
+ f" {self.config.num_train_timesteps}."
405
+ )
406
+
407
+ # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
408
+ if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
409
+ logger.warning(
410
+ f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
411
+ f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
412
+ f" unexpected results when using this timestep schedule."
413
+ )
414
+
415
+ # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
416
+ if non_train_timesteps:
417
+ logger.warning(
418
+ f"The custom timestep schedule contains the following timesteps which are not on the original"
419
+ f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
420
+ f" when using this timestep schedule."
421
+ )
422
+
423
+ # Raise warning if custom timestep schedule is longer than original_steps
424
+ if original_steps is not None:
425
+ if len(timesteps) > original_steps:
426
+ logger.warning(
427
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
428
+ f" the length of the timestep schedule used for training: {original_steps}. You may get some"
429
+ f" unexpected results when using this timestep schedule."
430
+ )
431
+ else:
432
+ if len(timesteps) > self.config.num_train_timesteps:
433
+ logger.warning(
434
+ f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds the"
435
+ f" the length of the timestep schedule used for training: {self.config.num_train_timesteps}. You may get some"
436
+ f" unexpected results when using this timestep schedule."
437
+ )
438
+
439
+ timesteps = np.array(timesteps, dtype=np.int64)
440
+ self.num_inference_steps = len(timesteps)
441
+ self.custom_timesteps = True
442
+
443
+ # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
444
+ init_timestep = min(int(self.num_inference_steps * strength), self.num_inference_steps)
445
+ t_start = max(self.num_inference_steps - init_timestep, 0)
446
+ timesteps = timesteps[t_start * self.order :]
447
+ # TODO: also reset self.num_inference_steps?
448
+ else:
449
+ # 2.2 Create the "standard" TCD inference timestep schedule.
450
+ if num_inference_steps > self.config.num_train_timesteps:
451
+ raise ValueError(
452
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
453
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
454
+ f" maximal {self.config.num_train_timesteps} timesteps."
455
+ )
456
+
457
+ if original_steps is not None:
458
+ skipping_step = len(tcd_origin_timesteps) // num_inference_steps
459
+
460
+ if skipping_step < 1:
461
+ raise ValueError(
462
+ f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
463
+ )
464
+
465
+ self.num_inference_steps = num_inference_steps
466
+
467
+ if original_steps is not None:
468
+ if num_inference_steps > original_steps:
469
+ raise ValueError(
470
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
471
+ f" {original_steps} because the final timestep schedule will be a subset of the"
472
+ f" `original_inference_steps`-sized initial timestep schedule."
473
+ )
474
+ else:
475
+ if num_inference_steps > self.config.num_train_timesteps:
476
+ raise ValueError(
477
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `num_train_timesteps`:"
478
+ f" {self.config.num_train_timesteps} because the final timestep schedule will be a subset of the"
479
+ f" `num_train_timesteps`-sized initial timestep schedule."
480
+ )
481
+
482
+ # TCD Inference Steps Schedule
483
+ tcd_origin_timesteps = tcd_origin_timesteps[::-1].copy()
484
+ # Select (approximately) evenly spaced indices from tcd_origin_timesteps.
485
+ inference_indices = np.linspace(0, len(tcd_origin_timesteps), num=num_inference_steps, endpoint=False)
486
+ inference_indices = np.floor(inference_indices).astype(np.int64)
487
+ timesteps = tcd_origin_timesteps[inference_indices]
488
+
489
+ self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long)
490
+
491
+ self._step_index = None
492
+
493
+ def step(
494
+ self,
495
+ model_output: torch.FloatTensor,
496
+ timestep: int,
497
+ sample: torch.FloatTensor,
498
+ eta: float,
499
+ generator: Optional[torch.Generator] = None,
500
+ return_dict: bool = True,
501
+ ) -> Union[TCDSchedulerOutput, Tuple]:
502
+ """
503
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
504
+ process from the learned model outputs (most often the predicted noise).
505
+
506
+ Args:
507
+ model_output (`torch.FloatTensor`):
508
+ The direct output from learned diffusion model.
509
+ timestep (`int`):
510
+ The current discrete timestep in the diffusion chain.
511
+ sample (`torch.FloatTensor`):
512
+ A current instance of a sample created by the diffusion process.
513
+ eta (`float`):
514
+ A stochastic parameter (referred to as `gamma` in the paper) used to control the stochasticity in every step.
515
+ When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
516
+ generator (`torch.Generator`, *optional*):
517
+ A random number generator.
518
+ return_dict (`bool`, *optional*, defaults to `True`):
519
+ Whether or not to return a [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] or `tuple`.
520
+ Returns:
521
+ [`~schedulers.scheduling_utils.TCDSchedulerOutput`] or `tuple`:
522
+ If return_dict is `True`, [`~schedulers.scheduling_tcd.TCDSchedulerOutput`] is returned, otherwise a
523
+ tuple is returned where the first element is the sample tensor.
524
+ """
525
+ if self.num_inference_steps is None:
526
+ raise ValueError(
527
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
528
+ )
529
+
530
+ if self.step_index is None:
531
+ self._init_step_index(timestep)
532
+
533
+ # 1. get previous step value
534
+ prev_step_index = self.step_index + 1
535
+ if prev_step_index < len(self.timesteps):
536
+ prev_timestep = self.timesteps[prev_step_index]
537
+ else:
538
+ prev_timestep = torch.tensor(0)
539
+
540
+ timestep_s = torch.floor((1 - eta) * prev_timestep).to(dtype=torch.long)
541
+
542
+ # 2. compute alphas, betas
543
+ alpha_prod_t = self.alphas_cumprod[timestep]
544
+ beta_prod_t = 1 - alpha_prod_t
545
+
546
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
547
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
548
+
549
+ alpha_prod_s = self.alphas_cumprod[timestep_s] if timestep_s >= 0 else self.final_alpha_cumprod
550
+ beta_prod_s = 1 - alpha_prod_s
551
+
552
+ # 3. Compute the predicted noised sample x_s based on the model parameterization
553
+ if self.config.prediction_type == "epsilon": # noise-prediction
554
+ pred_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
555
+ pred_epsilon = model_output
556
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
557
+ elif self.config.prediction_type == "sample": # x-prediction
558
+ pred_original_sample = model_output
559
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
560
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
561
+ elif self.config.prediction_type == "v_prediction": # v-prediction
562
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
563
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
564
+ pred_noised_sample = alpha_prod_s.sqrt() * pred_original_sample + beta_prod_s.sqrt() * pred_epsilon
565
+ else:
566
+ raise ValueError(
567
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
568
+ " `v_prediction` for `TCDScheduler`."
569
+ )
570
+
571
+ # 4. Sample and inject noise z ~ N(0, I) for MultiStep Inference
572
+ # Noise is not used on the final timestep of the timestep schedule.
573
+ # This also means that noise is not used for one-step sampling.
574
+ # Eta (referred to as "gamma" in the paper) was introduced to control the stochasticity in every step.
575
+ # When eta = 0, it represents deterministic sampling, whereas eta = 1 indicates full stochastic sampling.
576
+ if eta > 0:
577
+ if self.step_index != self.num_inference_steps - 1:
578
+ noise = randn_tensor(
579
+ model_output.shape, generator=generator, device=model_output.device, dtype=pred_noised_sample.dtype
580
+ )
581
+ prev_sample = (alpha_prod_t_prev / alpha_prod_s).sqrt() * pred_noised_sample + (1 - alpha_prod_t_prev / alpha_prod_s).sqrt() * noise
582
+ else:
583
+ prev_sample = pred_noised_sample
584
+ else:
585
+ prev_sample = pred_noised_sample
586
+
587
+ # upon completion increase step index by one
588
+ self._step_index += 1
589
+
590
+ if not return_dict:
591
+ return (prev_sample, pred_noised_sample)
592
+
593
+ return TCDSchedulerOutput(prev_sample=prev_sample, pred_noised_sample=pred_noised_sample)
594
+
595
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
596
+ def add_noise(
597
+ self,
598
+ original_samples: torch.FloatTensor,
599
+ noise: torch.FloatTensor,
600
+ timesteps: torch.IntTensor,
601
+ ) -> torch.FloatTensor:
602
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
603
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
604
+ timesteps = timesteps.to(original_samples.device)
605
+
606
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
607
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
608
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
609
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
610
+
611
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
612
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
613
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
614
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
615
+
616
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
617
+ return noisy_samples
618
+
619
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
620
+ def get_velocity(
621
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
622
+ ) -> torch.FloatTensor:
623
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
624
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
625
+ timesteps = timesteps.to(sample.device)
626
+
627
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
628
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
629
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
630
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
631
+
632
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
633
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
634
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
635
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
636
+
637
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
638
+ return velocity
639
+
640
+ def __len__(self):
641
+ return self.config.num_train_timesteps
642
+
643
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
644
+ def previous_timestep(self, timestep):
645
+ if self.custom_timesteps:
646
+ index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
647
+ if index == self.timesteps.shape[0] - 1:
648
+ prev_t = torch.tensor(-1)
649
+ else:
650
+ prev_t = self.timesteps[index + 1]
651
+ else:
652
+ num_inference_steps = (
653
+ self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
654
+ )
655
+ prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
656
+
657
+ return prev_t