TheNetherWatcher's picture
Upload folder using huggingface_hub
d0ffe9c verified
import logging
from enum import Enum
from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler, LCMScheduler,
LMSDiscreteScheduler, PNDMScheduler,
UniPCMultistepScheduler)
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# See https://github.com/huggingface/diffusers/issues/4167 for more details on
# how A1111 sampler names map to diffusers schedulers.
class DiffusionScheduler(str, Enum):
    """Supported scheduler (sampler) names.

    Each member's value is the plain string name of the scheduler; the
    trailing comment on each member gives the corresponding A1111 sampler
    name. Members with a ``k_`` prefix are the "Karras" variants of the
    scheduler of the same name without the prefix (``get_scheduler`` strips
    the prefix and sets ``use_karras_sigmas`` in the scheduler config).

    Because the enum mixes in ``str``, members compare equal to their string
    values (e.g. ``DiffusionScheduler.euler == "euler"``), which lets plain
    string names be matched against members.
    """

    lcm = "lcm"  # LCM
    ddim = "ddim"  # DDIM
    pndm = "pndm"  # PNDM
    heun = "heun"  # Heun
    unipc = "unipc"  # UniPC
    euler = "euler"  # Euler
    euler_a = "euler_a"  # Euler a
    lms = "lms"  # LMS
    k_lms = "k_lms"  # LMS Karras
    dpm_2 = "dpm_2"  # DPM2
    k_dpm_2 = "k_dpm_2"  # DPM2 Karras
    dpm_2_a = "dpm_2_a"  # DPM2 a
    k_dpm_2_a = "k_dpm_2_a"  # DPM2 a Karras
    dpmpp_2m = "dpmpp_2m"  # DPM++ 2M
    k_dpmpp_2m = "k_dpmpp_2m"  # DPM++ 2M Karras
    dpmpp_sde = "dpmpp_sde"  # DPM++ SDE
    k_dpmpp_sde = "k_dpmpp_sde"  # DPM++ SDE Karras
    dpmpp_2m_sde = "dpmpp_2m_sde"  # DPM++ 2M SDE
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"  # DPM++ 2M SDE Karras
def get_scheduler(name: str, config: dict = {}):
is_karras = name.startswith("k_")
if is_karras:
# strip the k_ prefix and add the karras sigma flag to config
name = name.lstrip("k_")
config["use_karras_sigmas"] = True
match name:
case DiffusionScheduler.lcm:
sched_class = LCMScheduler
case DiffusionScheduler.ddim:
sched_class = DDIMScheduler
case DiffusionScheduler.pndm:
sched_class = PNDMScheduler
case DiffusionScheduler.heun:
sched_class = HeunDiscreteScheduler
case DiffusionScheduler.unipc:
sched_class = UniPCMultistepScheduler
case DiffusionScheduler.euler:
sched_class = EulerDiscreteScheduler
case DiffusionScheduler.euler_a:
sched_class = EulerAncestralDiscreteScheduler
case DiffusionScheduler.lms:
sched_class = LMSDiscreteScheduler
case DiffusionScheduler.dpm_2:
# Equivalent to DPM2 in K-Diffusion
sched_class = KDPM2DiscreteScheduler
case DiffusionScheduler.dpm_2_a:
# Equivalent to `DPM2 a`` in K-Diffusion
sched_class = KDPM2AncestralDiscreteScheduler
case DiffusionScheduler.dpmpp_2m:
# Equivalent to `DPM++ 2M` in K-Diffusion
sched_class = DPMSolverMultistepScheduler
config["algorithm_type"] = "dpmsolver++"
config["solver_order"] = 2
case DiffusionScheduler.dpmpp_sde:
# Equivalent to `DPM++ SDE` in K-Diffusion
sched_class = DPMSolverSinglestepScheduler
case DiffusionScheduler.dpmpp_2m_sde:
# Equivalent to `DPM++ 2M SDE` in K-Diffusion
sched_class = DPMSolverMultistepScheduler
config["algorithm_type"] = "sde-dpmsolver++"
case _:
raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")
return sched_class.from_config(config)