adamelliotfields commited on
Commit
eb9126a
1 Parent(s): dcd3cb4

Add UniPC scheduler and remove DDIM

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. lib/config.py +4 -3
  3. lib/loader.py +15 -19
app.py CHANGED
@@ -152,7 +152,7 @@ with gr.Blocks(
152
  )
153
 
154
  # Model settings
155
- gr.HTML("<h3>Settings</h3>")
156
  with gr.Row():
157
  model = gr.Dropdown(
158
  choices=Config.MODELS,
 
152
  )
153
 
154
  # Model settings
155
+ gr.HTML("<h3>Model</h3>")
156
  with gr.Row():
157
  model = gr.Dropdown(
158
  choices=Config.MODELS,
lib/config.py CHANGED
@@ -9,6 +9,7 @@ from diffusers import (
9
  EulerDiscreteScheduler,
10
  StableDiffusionXLImg2ImgPipeline,
11
  StableDiffusionXLPipeline,
 
12
  )
13
  from diffusers.utils import logging as diffusers_logging
14
  from transformers import logging as transformers_logging
@@ -79,13 +80,13 @@ Config = SimpleNamespace(
79
  ],
80
  VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
81
  REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
82
- SCHEDULER="Euler",
83
  SCHEDULERS={
84
- "DDIM": DDIMScheduler,
85
- "DEIS 2M": DEISMultistepScheduler,
86
  "DPM++ 2M": DPMSolverMultistepScheduler,
87
  "Euler": EulerDiscreteScheduler,
88
  "Euler a": EulerAncestralDiscreteScheduler,
 
89
  },
90
  WIDTH=1024,
91
  HEIGHT=1024,
 
9
  EulerDiscreteScheduler,
10
  StableDiffusionXLImg2ImgPipeline,
11
  StableDiffusionXLPipeline,
12
+ UniPCMultistepScheduler,
13
  )
14
  from diffusers.utils import logging as diffusers_logging
15
  from transformers import logging as transformers_logging
 
80
  ],
81
  VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
82
  REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
83
+ SCHEDULER="UniPC",
84
  SCHEDULERS={
85
+ "DEIS": DEISMultistepScheduler,
 
86
  "DPM++ 2M": DPMSolverMultistepScheduler,
87
  "Euler": EulerDiscreteScheduler,
88
  "Euler a": EulerAncestralDiscreteScheduler,
89
+ "UniPC": UniPCMultistepScheduler,
90
  },
91
  WIDTH=1024,
92
  HEIGHT=1024,
lib/loader.py CHANGED
@@ -126,6 +126,9 @@ class Loader:
126
  self.refiner.deepcache.enable()
127
 
128
  def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress=None):
 
 
 
129
  scheduler_kwargs = {
130
  "beta_start": 0.00085,
131
  "beta_end": 0.012,
@@ -134,32 +137,25 @@ class Loader:
134
  "steps_offset": 1,
135
  }
136
 
137
- if scheduler not in ["DDIM", "Euler a"]:
138
- scheduler_kwargs["use_karras_sigmas"] = use_karras
139
-
140
- if scheduler == "DDIM":
141
- scheduler_kwargs["clip_sample"] = False
142
- scheduler_kwargs["set_alpha_to_one"] = False
143
-
144
- if model not in Config.SINGLE_FILE_MODELS:
145
- variant = "fp16"
146
- else:
147
- variant = None
148
-
149
- dtype = torch.float16
150
  pipeline_kwargs = {
151
- "variant": variant,
152
- "torch_dtype": dtype,
153
  "add_watermarker": False,
154
  "scheduler": Config.SCHEDULERS[scheduler](**scheduler_kwargs),
155
- "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
156
  }
157
 
158
- self.unload(model, use_refiner, deepcache_interval, scale)
 
159
 
160
- Pipeline = Config.PIPELINES[kind]
161
- Scheduler = Config.SCHEDULERS[scheduler]
 
 
 
 
 
162
 
 
163
  try:
164
  with timer(f"Loading {model}", logger=self.log.info):
165
  self.model = model
 
126
  self.refiner.deepcache.enable()
127
 
128
  def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress=None):
129
+ Pipeline = Config.PIPELINES[kind]
130
+ Scheduler = Config.SCHEDULERS[scheduler]
131
+
132
  scheduler_kwargs = {
133
  "beta_start": 0.00085,
134
  "beta_end": 0.012,
 
137
  "steps_offset": 1,
138
  }
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  pipeline_kwargs = {
141
+ "torch_dtype": torch.float16,
 
142
  "add_watermarker": False,
143
  "scheduler": Config.SCHEDULERS[scheduler](**scheduler_kwargs),
144
+ "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=torch.float16),
145
  }
146
 
147
+ if scheduler not in ["Euler a"]:
148
+ scheduler_kwargs["use_karras_sigmas"] = use_karras
149
 
150
+ if model not in Config.SINGLE_FILE_MODELS:
151
+ pipeline_kwargs["variant"] = "fp16"
152
+ else:
153
+ pipeline_kwargs["variant"] = None
154
+
155
+ # Unload
156
+ self.unload(model, use_refiner, deepcache_interval, scale)
157
 
158
+ # Load
159
  try:
160
  with timer(f"Loading {model}", logger=self.log.info):
161
  self.model = model