Spaces:
Running
Running
File size: 1,248 Bytes
cab0202 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DEISMultistepScheduler,
DPMSolverMultistepScheduler,
DPMSolverSinglestepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
PNDMScheduler,
UniPCMultistepScheduler,
)
# Registry of selectable schedulers: user-facing name -> diffusers scheduler
# class. `get_scheduler` looks names up here; iteration order is the order
# in which the entries are written below.
# NOTE(review): "DDPMScheduler" and "PNDMScheduler" keep the "Scheduler"
# suffix, unlike every other key. They are left unchanged because callers may
# already pass (or display) these exact strings.
SCHEDULER_MAPPING = dict(
    DDIM=DDIMScheduler,
    DDPMScheduler=DDPMScheduler,
    DEISMultistep=DEISMultistepScheduler,
    DPMSolverMultistep=DPMSolverMultistepScheduler,
    DPMSolverSinglestep=DPMSolverSinglestepScheduler,
    EulerAncestralDiscrete=EulerAncestralDiscreteScheduler,
    EulerDiscrete=EulerDiscreteScheduler,
    HeunDiscrete=HeunDiscreteScheduler,
    KDPM2AncestralDiscrete=KDPM2AncestralDiscreteScheduler,
    KDPM2Discrete=KDPM2DiscreteScheduler,
    PNDMScheduler=PNDMScheduler,
    UniPCMultistep=UniPCMultistepScheduler,
)
def get_scheduler(pipe, scheduler, **scheduler_kwargs):
    """Swap ``pipe.scheduler`` for the scheduler registered under ``scheduler``.

    The replacement is built with ``SchedulerClass.from_config`` from the
    pipeline's current scheduler config, so the existing config is carried
    over to the new scheduler.

    Args:
        pipe: Pipeline object exposing a ``scheduler`` attribute that has a
            ``config``. It is modified in place.
        scheduler: Name of the scheduler; must be a key of
            ``SCHEDULER_MAPPING`` (e.g. ``"DDIM"``).
        **scheduler_kwargs: Optional extra keyword arguments forwarded to
            ``from_config`` to override values from the existing config.
            Omitting them reproduces the original behavior.

    Returns:
        The same ``pipe`` object (not a copy), so calls can be chained.

    Raises:
        ValueError: If ``scheduler`` is not a key of ``SCHEDULER_MAPPING``.
    """
    # Single lookup (EAFP) instead of a membership test followed by indexing.
    try:
        scheduler_cls = SCHEDULER_MAPPING[scheduler]
    except KeyError:
        valid_names = ", ".join(SCHEDULER_MAPPING)
        # `from None`: the internal KeyError adds nothing for the caller.
        raise ValueError(
            f"Invalid scheduler name {scheduler!r}; expected one of: {valid_names}"
        ) from None

    pipe.scheduler = scheduler_cls.from_config(
        pipe.scheduler.config, **scheduler_kwargs
    )
    return pipe
|