xxxpo13 commited on
Commit
c6a0337
·
verified ·
1 Parent(s): 8a2f9ca

Upload scheduling_flow_matching.py

Browse files
Files changed (1) hide show
  1. scheduling_flow_matching.py +298 -0
scheduling_flow_matching.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union, List
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
11
+ from IPython import embed
12
+
13
+
14
+ @dataclass
15
+ class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
16
+ """
17
+ Output class for the scheduler's `step` function output.
18
+
19
+ Args:
20
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
21
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
22
+ denoising loop.
23
+ """
24
+
25
+ prev_sample: torch.FloatTensor
26
+
27
+
28
+ class PyramidFlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
29
+ """
30
+ Euler scheduler.
31
+
32
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
33
+ methods the library implements for all schedulers such as loading and saving.
34
+
35
+ Args:
36
+ num_train_timesteps (`int`, defaults to 1000):
37
+ The number of diffusion steps to train the model.
38
+ timestep_spacing (`str`, defaults to `"linspace"`):
39
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
40
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
41
+ shift (`float`, defaults to 1.0):
42
+ The shift value for the timestep schedule.
43
+ """
44
+
45
+ _compatibles = []
46
+ order = 1
47
+
48
+ @register_to_config
49
+ def __init__(
50
+ self,
51
+ num_train_timesteps: int = 1000,
52
+ shift: float = 1.0, # Following Stable diffusion 3,
53
+ stages: int = 3,
54
+ stage_range: List = [0, 1/3, 2/3, 1],
55
+ gamma: float = 1/3,
56
+ ):
57
+
58
+ self.timestep_ratios = {} # The timestep ratio for each stage
59
+ self.timesteps_per_stage = {} # The detailed timesteps per stage
60
+ self.sigmas_per_stage = {}
61
+ self.start_sigmas = {}
62
+ self.end_sigmas = {}
63
+ self.ori_start_sigmas = {}
64
+
65
+ # self.init_sigmas()
66
+ self.init_sigmas_for_each_stage()
67
+ self.sigma_min = self.sigmas[-1].item()
68
+ self.sigma_max = self.sigmas[0].item()
69
+ self.gamma = gamma
70
+
71
+ def init_sigmas(self):
72
+ """
73
+ initialize the global timesteps and sigmas
74
+ """
75
+ num_train_timesteps = self.config.num_train_timesteps
76
+ shift = self.config.shift
77
+
78
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
79
+ timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
80
+
81
+ sigmas = timesteps / num_train_timesteps
82
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
83
+
84
+ self.timesteps = sigmas * num_train_timesteps
85
+
86
+ self._step_index = None
87
+ self._begin_index = None
88
+
89
+ self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
90
+
91
+ def init_sigmas_for_each_stage(self):
92
+ """
93
+ Init the timesteps for each stage
94
+ """
95
+ self.init_sigmas()
96
+
97
+ stage_distance = []
98
+ stages = self.config.stages
99
+ training_steps = self.config.num_train_timesteps
100
+ stage_range = self.config.stage_range
101
+
102
+ # Init the start and end point of each stage
103
+ for i_s in range(stages):
104
+ # To decide the start and ends point
105
+ start_indice = int(stage_range[i_s] * training_steps)
106
+ start_indice = max(start_indice, 0)
107
+ end_indice = int(stage_range[i_s+1] * training_steps)
108
+ end_indice = min(end_indice, training_steps)
109
+ start_sigma = self.sigmas[start_indice].item()
110
+ end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
111
+ self.ori_start_sigmas[i_s] = start_sigma
112
+
113
+ if i_s != 0:
114
+ ori_sigma = 1 - start_sigma
115
+ gamma = self.config.gamma
116
+ corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
117
+ # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
118
+ start_sigma = 1 - corrected_sigma
119
+
120
+ stage_distance.append(start_sigma - end_sigma)
121
+ self.start_sigmas[i_s] = start_sigma
122
+ self.end_sigmas[i_s] = end_sigma
123
+
124
+ # Determine the ratio of each stage according to flow length
125
+ tot_distance = sum(stage_distance)
126
+ for i_s in range(stages):
127
+ if i_s == 0:
128
+ start_ratio = 0.0
129
+ else:
130
+ start_ratio = sum(stage_distance[:i_s]) / tot_distance
131
+ if i_s == stages - 1:
132
+ end_ratio = 1.0
133
+ else:
134
+ end_ratio = sum(stage_distance[:i_s+1]) / tot_distance
135
+
136
+ self.timestep_ratios[i_s] = (start_ratio, end_ratio)
137
+
138
+ # Determine the timesteps and sigmas for each stage
139
+ for i_s in range(stages):
140
+ timestep_ratio = self.timestep_ratios[i_s]
141
+ timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
142
+ timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
143
+ timesteps = np.linspace(
144
+ timestep_max, timestep_min, training_steps + 1,
145
+ )
146
+ self.timesteps_per_stage[i_s] = timesteps[:-1]
147
+ stage_sigmas = np.linspace(
148
+ 1, 0, training_steps + 1,
149
+ )
150
+ self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
151
+
152
+ @property
153
+ def step_index(self):
154
+ """
155
+ The index counter for current timestep. It will increase 1 after each scheduler step.
156
+ """
157
+ return self._step_index
158
+
159
+ @property
160
+ def begin_index(self):
161
+ """
162
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
163
+ """
164
+ return self._begin_index
165
+
166
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
167
+ def set_begin_index(self, begin_index: int = 0):
168
+ """
169
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
170
+
171
+ Args:
172
+ begin_index (`int`):
173
+ The begin index for the scheduler.
174
+ """
175
+ self._begin_index = begin_index
176
+
177
+ def _sigma_to_t(self, sigma):
178
+ return sigma * self.config.num_train_timesteps
179
+
180
+ def set_timesteps(self, num_inference_steps: int, stage_index: int, device: Union[str, torch.device] = None):
181
+ """
182
+ Setting the timesteps and sigmas for each stage
183
+ """
184
+ self.num_inference_steps = num_inference_steps
185
+ training_steps = self.config.num_train_timesteps
186
+ self.init_sigmas()
187
+
188
+ stage_timesteps = self.timesteps_per_stage[stage_index]
189
+ timestep_max = stage_timesteps[0].item()
190
+ timestep_min = stage_timesteps[-1].item()
191
+
192
+ timesteps = np.linspace(
193
+ timestep_max, timestep_min, num_inference_steps,
194
+ )
195
+ self.timesteps = torch.from_numpy(timesteps).to(device=device)
196
+
197
+ stage_sigmas = self.sigmas_per_stage[stage_index]
198
+ sigma_max = stage_sigmas[0].item()
199
+ sigma_min = stage_sigmas[-1].item()
200
+
201
+ ratios = np.linspace(
202
+ sigma_max, sigma_min, num_inference_steps
203
+ )
204
+ sigmas = torch.from_numpy(ratios).to(device=device)
205
+ self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
206
+
207
+ self._step_index = None
208
+
209
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
210
+ if schedule_timesteps is None:
211
+ schedule_timesteps = self.timesteps
212
+
213
+ indices = (schedule_timesteps == timestep).nonzero()
214
+
215
+ # The sigma index that is taken for the **very** first `step`
216
+ # is always the second index (or the last index if there is only 1)
217
+ # This way we can ensure we don't accidentally skip a sigma in
218
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
219
+ pos = 1 if len(indices) > 1 else 0
220
+
221
+ return indices[pos].item()
222
+
223
+ def _init_step_index(self, timestep):
224
+ if self.begin_index is None:
225
+ if isinstance(timestep, torch.Tensor):
226
+ timestep = timestep.to(self.timesteps.device)
227
+ self._step_index = self.index_for_timestep(timestep)
228
+ else:
229
+ self._step_index = self._begin_index
230
+
231
+ def step(
232
+ self,
233
+ model_output: torch.FloatTensor,
234
+ timestep: Union[float, torch.FloatTensor],
235
+ sample: torch.FloatTensor,
236
+ generator: Optional[torch.Generator] = None,
237
+ return_dict: bool = True,
238
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
239
+ """
240
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
241
+ process from the learned model outputs (most often the predicted noise).
242
+
243
+ Args:
244
+ model_output (`torch.FloatTensor`):
245
+ The direct output from learned diffusion model.
246
+ timestep (`float`):
247
+ The current discrete timestep in the diffusion chain.
248
+ sample (`torch.FloatTensor`):
249
+ A current instance of a sample created by the diffusion process.
250
+ generator (`torch.Generator`, *optional*):
251
+ A random number generator.
252
+ return_dict (`bool`):
253
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
254
+ tuple.
255
+
256
+ Returns:
257
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
258
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
259
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
260
+ """
261
+
262
+ if (
263
+ isinstance(timestep, int)
264
+ or isinstance(timestep, torch.IntTensor)
265
+ or isinstance(timestep, torch.LongTensor)
266
+ ):
267
+ raise ValueError(
268
+ (
269
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
270
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
271
+ " one of the `scheduler.timesteps` as a timestep."
272
+ ),
273
+ )
274
+
275
+ if self.step_index is None:
276
+ self._step_index = 0
277
+
278
+ # Upcast to avoid precision issues when computing prev_sample
279
+ sample = sample.to(torch.float32)
280
+
281
+ sigma = self.sigmas[self.step_index]
282
+ sigma_next = self.sigmas[self.step_index + 1]
283
+
284
+ prev_sample = sample + (sigma_next - sigma) * model_output
285
+
286
+ # Cast sample back to model compatible dtype
287
+ prev_sample = prev_sample.to(model_output.dtype)
288
+
289
+ # upon completion increase step index by one
290
+ self._step_index += 1
291
+
292
+ if not return_dict:
293
+ return (prev_sample,)
294
+
295
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
296
+
297
+ def __len__(self):
298
+ return self.config.num_train_timesteps