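# NOTE: page-scrape residue preserved from the hosting page header
# ("Spaces: Runtime error / Runtime error"); it is not part of the module.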
from dataclasses import dataclass
from typing import List, Tuple, Union, Optional

import torch
from diffusers import DDPMScheduler
from diffusers.utils import BaseOutput
@dataclass
class DMDSchedulerOutput(BaseOutput):
    """Return container for `DMDScheduler.step`.

    Declared as a dataclass so the annotated field below is registered as a
    real field of the output object rather than a bare class attribute.
    """

    # Denoised-sample estimate produced by the scheduler step.
    pred_original_sample: Optional[torch.FloatTensor] = None
class DMDScheduler(DDPMScheduler):
    """DDPM-derived scheduler that denoises in a single step.

    The only scheduled timestep is the last training timestep
    (``num_train_timesteps - 1``), and `step` converts the model output at
    that timestep directly into an estimate of the original sample.
    """

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """Set ``self.timesteps`` to a single entry: the last training timestep.

        Args:
            num_inference_steps: Accepted for interface compatibility; not used.
            device: Device the timestep tensor is moved to.
            timesteps: Accepted for interface compatibility; not used.
        """
        last_train_timestep = self.config.num_train_timesteps - 1
        self.timesteps = torch.tensor([last_train_timestep]).long().to(device)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator=None,
        return_dict: bool = True,
    ) -> Union[DMDSchedulerOutput, Tuple]:
        """Convert the model output into an estimate of the original sample.

        The computation always uses the last training timestep; the
        ``timestep`` and ``generator`` arguments are accepted for interface
        compatibility and are not read.

        Args:
            model_output: Direct output of the diffusion model.
            timestep: Not used.
            sample: The current noisy sample.
            generator: Not used.
            return_dict: When false, return a 1-tuple instead of a
                `DMDSchedulerOutput`.

        Returns:
            `DMDSchedulerOutput` holding ``pred_original_sample``, or the tuple
            ``(pred_original_sample,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``config.prediction_type`` is not one of
                ``epsilon``, ``sample`` or ``v_prediction``.
        """
        t = self.config.num_train_timesteps - 1

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        beta_prod_t = 1 - alpha_prod_t

        # 2. recover the original-sample estimate for each supported
        #    parameterisation (standard DDPM relations).
        prediction_type = self.config.prediction_type
        if prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif prediction_type == "sample":
            # The model already predicts the original sample directly.
            pred_original_sample = model_output
        elif prediction_type == "v_prediction":
            pred_original_sample = alpha_prod_t ** (0.5) * sample - beta_prod_t ** (0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for the DDPMScheduler."
            )

        if not return_dict:
            return (pred_original_sample,)

        return DMDSchedulerOutput(pred_original_sample=pred_original_sample)