File size: 3,501 Bytes
d0ffe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import logging
from enum import Enum

from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
                                  DPMSolverSinglestepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  HeunDiscreteScheduler,
                                  KDPM2AncestralDiscreteScheduler,
                                  KDPM2DiscreteScheduler, LCMScheduler,
                                  LMSDiscreteScheduler, PNDMScheduler,
                                  UniPCMultistepScheduler)

logger = logging.getLogger(__name__)


# See https://github.com/huggingface/diffusers/issues/4167 for more details on the scheduler mapping from A1111
class DiffusionScheduler(str, Enum):
    """Names of the diffusion schedulers that can be requested.

    Every member's value is identical to its name, and because the enum mixes
    in ``str`` a member compares equal to that plain string. The trailing
    comment on each member is the sampler label it corresponds to. Members
    whose name starts with ``k_`` are the "Karras" counterpart of the member
    with the same name minus that prefix.
    """

    # Samplers listed without a Karras counterpart.
    lcm = "lcm"                        # LCM
    ddim = "ddim"                      # DDIM
    pndm = "pndm"                      # PNDM
    heun = "heun"                      # Heun
    unipc = "unipc"                    # UniPC
    euler = "euler"                    # Euler
    euler_a = "euler_a"                # Euler a

    # Samplers listed as a plain / Karras pair.
    lms = "lms"                        # LMS
    k_lms = "k_lms"                    # LMS Karras
    dpm_2 = "dpm_2"                    # DPM2
    k_dpm_2 = "k_dpm_2"                # DPM2 Karras
    dpm_2_a = "dpm_2_a"                # DPM2 a
    k_dpm_2_a = "k_dpm_2_a"            # DPM2 a Karras
    dpmpp_2m = "dpmpp_2m"              # DPM++ 2M
    k_dpmpp_2m = "k_dpmpp_2m"          # DPM++ 2M Karras
    dpmpp_sde = "dpmpp_sde"            # DPM++ SDE
    k_dpmpp_sde = "k_dpmpp_sde"        # DPM++ SDE Karras
    dpmpp_2m_sde = "dpmpp_2m_sde"      # DPM++ 2M SDE
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"  # DPM++ 2M SDE Karras


def get_scheduler(name: str, config: dict = {}):
    is_karras = name.startswith("k_")
    if is_karras:
        # strip the k_ prefix and add the karras sigma flag to config
        name = name.lstrip("k_")
        config["use_karras_sigmas"] = True

    match name:
        case DiffusionScheduler.lcm:
            sched_class = LCMScheduler
        case DiffusionScheduler.ddim:
            sched_class = DDIMScheduler
        case DiffusionScheduler.pndm:
            sched_class = PNDMScheduler
        case DiffusionScheduler.heun:
            sched_class = HeunDiscreteScheduler
        case DiffusionScheduler.unipc:
            sched_class = UniPCMultistepScheduler
        case DiffusionScheduler.euler:
            sched_class = EulerDiscreteScheduler
        case DiffusionScheduler.euler_a:
            sched_class = EulerAncestralDiscreteScheduler
        case DiffusionScheduler.lms:
            sched_class = LMSDiscreteScheduler
        case DiffusionScheduler.dpm_2:
            # Equivalent to DPM2 in K-Diffusion
            sched_class = KDPM2DiscreteScheduler
        case DiffusionScheduler.dpm_2_a:
            # Equivalent to `DPM2 a`` in K-Diffusion
            sched_class = KDPM2AncestralDiscreteScheduler
        case DiffusionScheduler.dpmpp_2m:
            # Equivalent to `DPM++ 2M` in K-Diffusion
            sched_class = DPMSolverMultistepScheduler
            config["algorithm_type"] = "dpmsolver++"
            config["solver_order"] = 2
        case DiffusionScheduler.dpmpp_sde:
            # Equivalent to `DPM++ SDE` in K-Diffusion
            sched_class = DPMSolverSinglestepScheduler
        case DiffusionScheduler.dpmpp_2m_sde:
            # Equivalent to `DPM++ 2M SDE` in K-Diffusion
            sched_class = DPMSolverMultistepScheduler
            config["algorithm_type"] = "sde-dpmsolver++"
        case _:
            raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")

    return sched_class.from_config(config)