|
from dataclasses import dataclass |
|
from typing import List, Tuple, Union, Optional |
|
|
|
import torch |
|
from diffusers import DDPMScheduler |
|
from diffusers.utils import BaseOutput |
|
|
|
|
|
@dataclass |
|
class DMDSchedulerOutput(BaseOutput): |
|
pred_original_sample: Optional[torch.FloatTensor] = None |
|
|
|
|
|
class DMDScheduler(DDPMScheduler): |
|
def set_timesteps( |
|
self, |
|
num_inference_steps: Optional[int] = None, |
|
device: Union[str, torch.device] = None, |
|
timesteps: Optional[List[int]] = None, |
|
): |
|
self.timesteps = torch.tensor([self.config.num_train_timesteps-1]).long().to(device) |
|
|
|
def step( |
|
self, |
|
model_output: torch.FloatTensor, |
|
timestep: int, |
|
sample: torch.FloatTensor, |
|
generator=None, |
|
return_dict: bool = True, |
|
) -> Union[DMDSchedulerOutput, Tuple]: |
|
t = self.config.num_train_timesteps - 1 |
|
|
|
|
|
alpha_prod_t = self.alphas_cumprod[t] |
|
beta_prod_t = 1 - alpha_prod_t |
|
|
|
if self.config.prediction_type == "epsilon": |
|
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) |
|
else: |
|
raise ValueError( |
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" |
|
" `v_prediction` for the DDPMScheduler." |
|
) |
|
|
|
if not return_dict: |
|
return (pred_original_sample,) |
|
|
|
return DMDSchedulerOutput(pred_original_sample=pred_original_sample) |
|
|