|
import logging |
|
from enum import Enum |
|
|
|
from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, |
|
DPMSolverSinglestepScheduler, |
|
EulerAncestralDiscreteScheduler, |
|
EulerDiscreteScheduler, |
|
HeunDiscreteScheduler, |
|
KDPM2AncestralDiscreteScheduler, |
|
KDPM2DiscreteScheduler, LCMScheduler, |
|
LMSDiscreteScheduler, PNDMScheduler, |
|
UniPCMultistepScheduler) |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
class DiffusionScheduler(str, Enum):
    """Names of the diffusion schedulers that can be requested.

    Every member's value is identical to its name, and because the enum
    mixes in ``str`` each member compares equal to the plain string of
    the same spelling.  Several samplers are listed twice: once under
    their base name and once with a ``k_`` prefix.
    """

    # Samplers that are listed under a single name only.
    lcm = "lcm"
    ddim = "ddim"
    pndm = "pndm"
    heun = "heun"
    unipc = "unipc"
    euler = "euler"
    euler_a = "euler_a"

    # LMS, plain and "k_" spelling.
    lms = "lms"
    k_lms = "k_lms"

    # DPM2, plain and "k_" spelling.
    dpm_2 = "dpm_2"
    k_dpm_2 = "k_dpm_2"

    # DPM2 ancestral, plain and "k_" spelling.
    dpm_2_a = "dpm_2_a"
    k_dpm_2_a = "k_dpm_2_a"

    # DPM++ 2M, plain and "k_" spelling.
    dpmpp_2m = "dpmpp_2m"
    k_dpmpp_2m = "k_dpmpp_2m"

    # DPM++ SDE, plain and "k_" spelling.
    dpmpp_sde = "dpmpp_sde"
    k_dpmpp_sde = "k_dpmpp_sde"

    # DPM++ 2M SDE, plain and "k_" spelling.
    dpmpp_2m_sde = "dpmpp_2m_sde"
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"
|
|
|
|
|
def get_scheduler(name: str, config: dict = {}): |
|
is_karras = name.startswith("k_") |
|
if is_karras: |
|
|
|
name = name.lstrip("k_") |
|
config["use_karras_sigmas"] = True |
|
|
|
match name: |
|
case DiffusionScheduler.lcm: |
|
sched_class = LCMScheduler |
|
case DiffusionScheduler.ddim: |
|
sched_class = DDIMScheduler |
|
case DiffusionScheduler.pndm: |
|
sched_class = PNDMScheduler |
|
case DiffusionScheduler.heun: |
|
sched_class = HeunDiscreteScheduler |
|
case DiffusionScheduler.unipc: |
|
sched_class = UniPCMultistepScheduler |
|
case DiffusionScheduler.euler: |
|
sched_class = EulerDiscreteScheduler |
|
case DiffusionScheduler.euler_a: |
|
sched_class = EulerAncestralDiscreteScheduler |
|
case DiffusionScheduler.lms: |
|
sched_class = LMSDiscreteScheduler |
|
case DiffusionScheduler.dpm_2: |
|
|
|
sched_class = KDPM2DiscreteScheduler |
|
case DiffusionScheduler.dpm_2_a: |
|
|
|
sched_class = KDPM2AncestralDiscreteScheduler |
|
case DiffusionScheduler.dpmpp_2m: |
|
|
|
sched_class = DPMSolverMultistepScheduler |
|
config["algorithm_type"] = "dpmsolver++" |
|
config["solver_order"] = 2 |
|
case DiffusionScheduler.dpmpp_sde: |
|
|
|
sched_class = DPMSolverSinglestepScheduler |
|
case DiffusionScheduler.dpmpp_2m_sde: |
|
|
|
sched_class = DPMSolverMultistepScheduler |
|
config["algorithm_type"] = "sde-dpmsolver++" |
|
case _: |
|
raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'") |
|
|
|
return sched_class.from_config(config) |
|
|