adamelliotfields committed on
Commit 6ad0411
1 Parent(s): 0d34381

Clean up loader

Files changed (8)
  1. .gitignore +1 -0
  2. app.py +15 -17
  3. lib/__init__.py +0 -2
  4. lib/config.py +16 -11
  5. lib/inference.py +21 -21
  6. lib/loader.py +126 -177
  7. lib/utils.py +1 -23
  8. requirements.txt +2 -3
.gitignore CHANGED
@@ -1,2 +1,3 @@
 __pycache__/
 .venv/
+app.log
app.py CHANGED
@@ -2,7 +2,14 @@ import argparse
 
 import gradio as gr
 
-from lib import Config, async_call, disable_progress_bars, download_repo_files, generate, read_file, read_json
+from lib import (
+    Config,
+    # disable_progress_bars,
+    download_repo_files,
+    generate,
+    read_file,
+    read_json,
+)
 
 # Update refresh button hover text
 seed_js = """
@@ -55,28 +62,19 @@ random_prompt_js = f"""
 
 
 # Transform the raw inputs before generation
-async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
+def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
     if len(args) > 0:
         prompt = args[0]
     else:
         prompt = None
     if prompt is None or prompt.strip() == "":
         raise gr.Error("You must enter a prompt")
-
     try:
-        if Config.ZERO_GPU:
-            progress((0, 100), desc="ZeroGPU init")
-
-        images = await async_call(
-            generate,
-            *args,
-            Error=gr.Error,
-            Info=gr.Info,
-            progress=progress,
-        )
+        # if Config.ZERO_GPU:
+        #     progress((0, 100), desc="ZeroGPU init")
+        images = generate(*args, Error=gr.Error, Info=gr.Info, progress=progress)
     except RuntimeError:
         raise gr.Error("Error: Please try again")
-
     return images
 
 
@@ -259,7 +257,7 @@ with gr.Blocks(
     )
     use_refiner = gr.Checkbox(
         elem_classes=["checkbox"],
-        label="Refiner",
+        label="Use refiner",
        value=False,
    )
 
@@ -322,8 +320,8 @@ if __name__ == "__main__":
     parser.add_argument("-p", "--port", type=int, metavar="INT", default=7860)
     args = parser.parse_args()
 
-    disable_progress_bars()
-    for repo_id, allow_patterns in Config.HF_MODELS.items():
+    # disable_progress_bars()
+    for repo_id, allow_patterns in Config.HF_REPOS.items():
         download_repo_files(repo_id, allow_patterns, token=Config.HF_TOKEN)
 
     # https://www.gradio.app/docs/gradio/interface#interface-queue
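Note: generate_fn is now an ordinary synchronous function. Gradio runs sync handlers in a worker thread pool and, with track_tqdm=True, mirrors any tqdm loops (such as diffusers denoising steps) in the UI progress bar, which is why the async_call wrapper could be dropped. A minimal, hypothetical sketch of the same pattern (names such as slow_fn and demo are illustrative, not from this repo):

import gradio as gr

def slow_fn(prompt, progress=gr.Progress(track_tqdm=True)):
    # Sync handlers run in Gradio's worker threads; tqdm progress inside the
    # function is reflected in the UI because track_tqdm=True.
    if prompt is None or prompt.strip() == "":
        raise gr.Error("You must enter a prompt")
    return prompt.strip()

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    output = gr.Textbox(label="Output")
    button = gr.Button("Run")
    button.click(fn=slow_fn, inputs=prompt, outputs=output)

if __name__ == "__main__":
    demo.queue().launch()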
lib/__init__.py CHANGED
@@ -1,7 +1,6 @@
 from .config import Config
 from .inference import generate
 from .utils import (
-    async_call,
     disable_progress_bars,
     download_repo_files,
     read_file,
@@ -10,7 +9,6 @@ from .utils import (
 
 __all__ = [
     "Config",
-    "async_call",
     "disable_progress_bars",
     "download_repo_files",
     "generate",
lib/config.py CHANGED
@@ -56,15 +56,12 @@ _sdxl_files = [
     "tokenizer/vocab.json",
 ]
 
+_sdxl_files_with_vae = [*_sdxl_files, "vae_1_0/config.json"]
+
 # Using namespace instead of dataclass for simplicity
 Config = SimpleNamespace(
     HF_TOKEN=os.environ.get("HF_TOKEN", None),
     ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
-    HF_MODELS={
-        "segmind/Segmind-Vega": [*_sdxl_files],
-        "stabilityai/stable-diffusion-xl-base-1.0": [*_sdxl_files, "vae_1_0/config.json"],
-        "stabilityai/stable-diffusion-xl-refiner-1.0": [*_sdxl_refiner_files],
-    },
     PIPELINES={
         "txt2img": StableDiffusionXLPipeline,
         "img2img": StableDiffusionXLImg2ImgPipeline,
@@ -77,13 +74,21 @@ Config = SimpleNamespace(
         "SG161222/RealVisXL_V5.0",
         "stabilityai/stable-diffusion-xl-base-1.0",
     ],
-    # Single-file model weights
-    MODEL_CHECKPOINTS={
-        # keep keys lowercase for case-insensitive matching in the loader
-        "cyberdelia/cyberrealsticxl": "CyberRealisticXLPlay_V1.0.safetensors",  # typo in "realistic"
-        "fluently/fluently-xl-final": "FluentlyXL-Final.safetensors",
-        "sg161222/realvisxl_v5.0": "RealVisXL_V5.0_fp16.safetensors",
+    HF_REPOS={
+        "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
+        "cyberdelia/CyberRealsticXL": ["CyberRealisticXLPlay_V1.0.safetensors"],
+        "fluently/Fluently-XL-Final": ["FluentlyXL-Final.safetensors"],
+        "madebyollin/sdxl-vae-fp16-fix": ["config.json", "diffusion_pytorch_model.fp16.safetensors"],
+        "segmind/Segmind-Vega": _sdxl_files,
+        "SG161222/RealVisXL_V5.0": ["RealVisXL_V5.0_fp16.safetensors"],
+        "stabilityai/stable-diffusion-xl-base-1.0": _sdxl_files_with_vae,
+        "stabilityai/stable-diffusion-xl-refiner-1.0": _sdxl_refiner_files,
     },
+    SINGLE_FILE_MODELS=[
+        "cyberdelia/cyberrealsticxl",
+        "fluently/fluently-xl-final",
+        "sg161222/realvisxl_v5.0",
+    ],
     VAE_MODEL="madebyollin/sdxl-vae-fp16-fix",
     REFINER_MODEL="stabilityai/stable-diffusion-xl-refiner-1.0",
     SCHEDULER="Euler",
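Note: HF_REPOS now drives both the startup prefetch in app.py and the single-file checkpoint lookup in the loader. A hedged sketch of how such a mapping feeds snapshot_download (the prefetch helper name is illustrative; download_repo_files in lib/utils.py wraps the same call):

from huggingface_hub import snapshot_download

# Repo id -> allow_patterns, mirroring two entries from Config.HF_REPOS
repos = {
    "ai-forever/Real-ESRGAN": ["RealESRGAN_x2.pth", "RealESRGAN_x4.pth"],
    "madebyollin/sdxl-vae-fp16-fix": ["config.json", "diffusion_pytorch_model.fp16.safetensors"],
}

def prefetch(repos, token=None):
    # Download only the allowed files into the local HF cache; return snapshot paths
    return {
        repo_id: snapshot_download(repo_id=repo_id, allow_patterns=patterns, token=token)
        for repo_id, patterns in repos.items()
    }

if __name__ == "__main__":
    for repo_id, path in prefetch(repos).items():
        print(f"{repo_id} -> {path}")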
lib/inference.py CHANGED
@@ -9,7 +9,7 @@ from spaces import GPU
 from .config import Config
 from .loader import Loader
 from .logger import Logger
-from .utils import clear_cuda_cache, safe_progress, timer
+from .utils import cuda_collect, safe_progress, timer
 
 
 # Dynamic signature for the GPU duration function; max 60s per image
@@ -55,6 +55,11 @@ def generate(
     Info=None,
     progress=None,
 ):
+    KIND = "txt2img"
+    CURRENT_STEP = 0
+    CURRENT_IMAGE = 1
+    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
+
     start = time.perf_counter()
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
@@ -69,11 +74,6 @@
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1e6) % (2**64)
 
-    KIND = "txt2img"
-    CURRENT_STEP = 0
-    CURRENT_IMAGE = 1
-    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
-
     # custom progress bar for multiple images
     def callback_on_step_end(pipeline, step, timestep, latents):
         nonlocal CURRENT_IMAGE, CURRENT_STEP
@@ -107,29 +107,29 @@
         progress,
     )
 
-    if loader.pipe is None:
-        raise Error(f"Error loading {model}")
-
-    pipe = loader.pipe
     refiner = loader.refiner
+    pipeline = loader.pipeline
     upscaler = loader.upscaler
 
+    if pipeline is None:
+        raise Error(f"Error loading {model}")
+
     # prompt embeds for base and refiner
     compel_1 = Compel(
-        text_encoder=[pipe.text_encoder, pipe.text_encoder_2],
-        tokenizer=[pipe.tokenizer, pipe.tokenizer_2],
+        text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
+        tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
         requires_pooled=[False, True],
         returned_embeddings_type=EMBEDDINGS_TYPE,
-        dtype_for_device_getter=lambda _: pipe.dtype,
-        device=pipe.device,
+        dtype_for_device_getter=lambda _: pipeline.dtype,
+        device=pipeline.device,
     )
     compel_2 = Compel(
-        text_encoder=[pipe.text_encoder_2],
-        tokenizer=[pipe.tokenizer_2],
+        text_encoder=[pipeline.text_encoder_2],
+        tokenizer=[pipeline.tokenizer_2],
         requires_pooled=[True],
         returned_embeddings_type=EMBEDDINGS_TYPE,
-        dtype_for_device_getter=lambda _: pipe.dtype,
-        device=pipe.device,
+        dtype_for_device_getter=lambda _: pipeline.dtype,
+        device=pipeline.device,
     )
 
     images = []
@@ -138,7 +138,7 @@
 
     for i in range(num_images):
         try:
-            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)
+            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
             conditioning_1, pooled_1 = compel_1([positive_prompt, negative_prompt])
             conditioning_2, pooled_2 = compel_2([positive_prompt, negative_prompt])
         except PromptParser.ParsingException:
@@ -186,7 +186,7 @@
             refiner_kwargs["callback_on_step_end"] = callback_on_step_end
 
         try:
-            image = pipe(**pipe_kwargs).images[0]
+            image = pipeline(**pipe_kwargs).images[0]
             if use_refiner:
                 refiner_kwargs["image"] = image
                 image = refiner(**refiner_kwargs).images[0]
@@ -207,7 +207,7 @@
         safe_progress(progress, i + 1, num_images, desc=msg)
 
     # Flush memory after generating
-    clear_cuda_cache()
+    cuda_collect()
 
     end = time.perf_counter()
     msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
lib/loader.py CHANGED
@@ -1,5 +1,4 @@
-import gc
-from threading import Lock
+# import gc
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -8,194 +7,125 @@ from diffusers.models import AutoencoderKL
 from .config import Config
 from .logger import Logger
 from .upscaler import RealESRGAN
-from .utils import clear_cuda_cache, timer
+from .utils import cuda_collect, timer
 
 
 class Loader:
-    _instance = None
-    _lock = Lock()
-
-    def __new__(cls):
-        with cls._lock:
-            if cls._instance is None:
-                cls._instance = super().__new__(cls)
-                cls._instance.pipe = None
-                cls._instance.model = None
-                cls._instance.refiner = None
-                cls._instance.upscaler = None
-                cls._instance.log = Logger("Loader")
-            return cls._instance
-
-    def _should_unload_refiner(self, refiner=False):
-        if self.refiner is None:
-            return False
-        if not refiner:
-            return True
-        return False
+    def __init__(self):
+        self.model = ""
+        self.refiner = None
+        self.pipeline = None
+        self.upscaler = None
+        self.log = Logger("Loader")
 
-    def _should_unload_upscaler(self, scale=1):
-        if self.upscaler is not None and self.upscaler.scale != scale:
-            return True
-        return False
+    def should_unload_refiner(self, use_refiner=False):
+        return self.refiner is not None and not use_refiner
 
-    def _should_unload_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
+    def should_unload_upscaler(self, scale=1):
+        return self.upscaler is not None and self.upscaler.scale != scale
+
+    def should_unload_deepcache(self, interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
         if has_deepcache and interval == 1:
             return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
             return True
         return False
 
-    def _should_unload_pipeline(self, model=""):
-        if self.pipe is None:
-            return False
-        if self.model and self.model.lower() != model.lower():
+    def should_unload_pipeline(self, model=""):
+        return self.pipeline is not None and self.model.lower() != model.lower()
+
+    def should_load_refiner(self, use_refiner=False):
+        return self.refiner is None and use_refiner
+
+    def should_load_upscaler(self, scale=1):
+        return self.upscaler is None and scale > 1
+
+    def should_load_deepcache(self, interval=1):
+        has_deepcache = hasattr(self.pipeline, "deepcache")
+        if not has_deepcache and interval != 1:
+            return True
+        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
            return True
        return False
 
-    def _unload_refiner(self):
-        if self.refiner is not None:
+    def should_load_pipeline(self):
+        return self.pipeline is None
+
+    def unload(self, model, use_refiner, deepcache_interval, scale):
+        needs_gc = False
+
+        if self.should_unload_deepcache(deepcache_interval):
+            self.log.info("Disabling DeepCache")
+            self.pipeline.deepcache.disable()
+            delattr(self.pipeline, "deepcache")
+            if self.refiner:
+                self.refiner.deepcache.disable()
+                delattr(self.refiner, "deepcache")
+
+        if self.should_unload_refiner(use_refiner):
             with timer("Unloading refiner"):
                 self.refiner.to("cpu", silence_dtype_warnings=True)
+            self.refiner = None
+            needs_gc = True
 
-    def _unload_upscaler(self):
-        if self.upscaler is not None:
+        if self.should_unload_upscaler(scale):
             with timer(f"Unloading {self.upscaler.scale}x upscaler"):
                 self.upscaler.to("cpu")
+            self.upscaler = None
+            needs_gc = True
 
-    def _unload_deepcache(self):
-        if self.pipe.deepcache is not None:
-            self.log.info("Disabling DeepCache")
-            self.pipe.deepcache.disable()
-            delattr(self.pipe, "deepcache")
-        if self.refiner is not None:
-            if hasattr(self.refiner, "deepcache"):
-                self.refiner.deepcache.disable()
-                delattr(self.refiner, "deepcache")
-
-    def _unload_pipeline(self):
-        if self.pipe is not None:
+        if self.should_unload_pipeline(model):
             with timer(f"Unloading {self.model}"):
-                self.pipe.to("cpu", silence_dtype_warnings=True)
-                if self.refiner is not None:
+                self.pipeline.to("cpu", silence_dtype_warnings=True)
+                if self.refiner:
                     self.refiner.vae = None
                     self.refiner.scheduler = None
                     self.refiner.tokenizer_2 = None
                     self.refiner.text_encoder_2 = None
-
-    def _unload(self, model, refiner, deepcache, scale):
-        to_unload = []
-        if self._should_unload_deepcache(deepcache):  # remove deepcache first
-            self._unload_deepcache()
-
-        if self._should_unload_refiner(refiner):
-            self._unload_refiner()
-            to_unload.append("refiner")
-
-        if self._should_unload_upscaler(scale):
-            self._unload_upscaler()
-            to_unload.append("upscaler")
-
-        if self._should_unload_pipeline(model):
-            self._unload_pipeline()
-            to_unload.append("model")
-            to_unload.append("pipe")
-
-        # Flush cache and run garbage collector
-        clear_cuda_cache()
-        for component in to_unload:
-            setattr(self, component, None)
-        gc.collect()
-
-    def _should_load_refiner(self, refiner=False):
-        if self.refiner is None and refiner:
-            return True
-        return False
-
-    def _should_load_upscaler(self, scale=1):
-        if self.upscaler is None and scale > 1:
-            return True
-        return False
-
-    def _should_load_deepcache(self, interval=1):
-        has_deepcache = hasattr(self.pipe, "deepcache")
-        if not has_deepcache and interval != 1:
-            return True
-        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
-            return True
-        return False
-
-    def _should_load_pipeline(self):
-        if self.pipe is None:
-            return True
-        return False
-
-    def _load_refiner(self, refiner, progress, **kwargs):
-        if self._should_load_refiner(refiner):
-            model = Config.REFINER_MODEL
-            pipeline = Config.PIPELINES["img2img"]
-            try:
-                with timer(f"Loading {model}"):
-                    self.refiner = pipeline.from_pretrained(model, **kwargs).to("cuda")
-            except Exception as e:
-                self.log.error(f"Error loading {model}: {e}")
-                self.refiner = None
-                return
+            self.pipeline = None
+            self.model = None
+            needs_gc = True
+
+        if needs_gc:
+            cuda_collect()
+            # gc.collect()
+
+    def load_refiner(self, refiner_kwargs={}, progress=None):
+        model = Config.REFINER_MODEL
+        try:
+            with timer(f"Loading {model}"):
+                Pipeline = Config.PIPELINES["img2img"]
+                self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to("cuda")
+        except Exception as e:
+            self.log.error(f"Error loading {model}: {e}")
+            self.refiner = None
+            return
         if self.refiner is not None:
             self.refiner.set_progress_bar_config(disable=progress is not None)
 
-    def _load_upscaler(self, scale=1):
-        if self._should_load_upscaler(scale):
+    def load_upscaler(self, scale=1):
+        if self.should_load_upscaler(scale):
             try:
                 with timer(f"Loading {scale}x upscaler"):
-                    self.upscaler = RealESRGAN(scale, device=self.pipe.device)
+                    self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
                     self.upscaler.load_weights()
             except Exception as e:
                 self.log.error(f"Error loading {scale}x upscaler: {e}")
                 self.upscaler = None
 
-    def _load_deepcache(self, interval=1):
-        if self._should_load_deepcache(interval):
+    def load_deepcache(self, interval=1):
+        if self.should_load_deepcache(interval):
             self.log.info("Enabling DeepCache")
-            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)
-            self.pipe.deepcache.set_params(cache_interval=interval)
-            self.pipe.deepcache.enable()
-            if self.refiner is not None:
+            self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
+            self.pipeline.deepcache.set_params(cache_interval=interval)
+            self.pipeline.deepcache.enable()
+            if self.refiner:
                 self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
                 self.refiner.deepcache.set_params(cache_interval=interval)
                 self.refiner.deepcache.enable()
 
-    def _load_pipeline(self, kind, model, progress, **kwargs):
-        pipeline = Config.PIPELINES[kind]
-        if self._should_load_pipeline():
-            try:
-                with timer(f"Loading {model}"):
-                    self.model = model
-                    if model.lower() in Config.MODEL_CHECKPOINTS.keys():
-                        self.pipe = pipeline.from_single_file(
-                            f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
-                            **kwargs,
-                        ).to("cuda")
-                    else:
-                        self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
-                    if self.refiner is not None:
-                        self.refiner.vae = self.pipe.vae
-                        self.refiner.scheduler = self.pipe.scheduler
-                        self.refiner.tokenizer_2 = self.pipe.tokenizer_2
-                        self.refiner.text_encoder_2 = self.pipe.text_encoder_2
-                        self.refiner.to(self.pipe.device)
-            except Exception as e:
-                self.log.error(f"Error loading {model}: {e}")
-                self.model = None
-                self.pipe = None
-                self.refiner = None
-                return
-        if not isinstance(self.pipe, pipeline):
-            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
-        if self.pipe is not None:
-            self.pipe.set_progress_bar_config(disable=progress is not None)
-
-    def load(self, kind, model, scheduler, deepcache, scale, karras, refiner, progress):
+    def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress):
         scheduler_kwargs = {
             "beta_start": 0.00085,
             "beta_end": 0.012,
@@ -205,14 +135,13 @@ class Loader:
         }
 
         if scheduler not in ["DDIM", "Euler a"]:
-            scheduler_kwargs["use_karras_sigmas"] = karras
+            scheduler_kwargs["use_karras_sigmas"] = use_karras
 
-        # https://github.com/huggingface/diffusers/blob/8a3f0c1/scripts/convert_original_stable_diffusion_to_diffusers.py#L939
         if scheduler == "DDIM":
             scheduler_kwargs["clip_sample"] = False
             scheduler_kwargs["set_alpha_to_one"] = False
 
-        if model.lower() not in Config.MODEL_CHECKPOINTS.keys():
+        if model.lower() not in Config.SINGLE_FILE_MODELS:
             variant = "fp16"
         else:
             variant = None
@@ -226,47 +155,67 @@ class Loader:
             "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=dtype),
         }
 
-        self._unload(model, refiner, deepcache, scale)
-        self._load_pipeline(kind, model, progress, **pipe_kwargs)
-
-        # error loading model
-        if self.pipe is None:
+        self.unload(model, use_refiner, deepcache_interval, scale)
+
+        Pipeline = Config.PIPELINES[kind]
+        Scheduler = Config.SCHEDULERS[scheduler]
+
+        try:
+            with timer(f"Loading {model}"):
+                self.model = model
+                if model.lower() in Config.SINGLE_FILE_MODELS:
+                    checkpoint = Config.HF_REPOS[model][0]
+                    self.pipeline = Pipeline.from_single_file(
+                        f"https://huggingface.co/{model}/{checkpoint}",
+                        **pipe_kwargs,
+                    ).to("cuda")
+                else:
+                    self.pipeline = Pipeline.from_pretrained(model, **pipe_kwargs).to("cuda")
+        except Exception as e:
+            self.log.error(f"Error loading {model}: {e}")
+            self.model = None
+            self.pipeline = None
            return
 
-        same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
+        if not isinstance(self.pipeline, Pipeline):
+            self.pipeline = Pipeline.from_pipe(self.pipeline).to("cuda")
+
+        if self.pipeline is not None:
+            self.pipeline.set_progress_bar_config(disable=progress is not None)
+
+        # Check and update scheduler if necessary
+        same_scheduler = isinstance(self.pipeline.scheduler, Scheduler)
         same_karras = (
-            not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
-            or self.pipe.scheduler.config.use_karras_sigmas == karras
+            not hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
+            or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
        )
 
-        # same model, different scheduler
        if self.model.lower() == model.lower():
            if not same_scheduler:
                self.log.info(f"Enabling {scheduler}")
            if not same_karras:
-                self.log.info(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
+                self.log.info(f"{'Enabling' if use_karras else 'Disabling'} Karras sigmas")
            if not same_scheduler or not same_karras:
-                self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
+                self.pipeline.scheduler = Scheduler(**scheduler_kwargs)
                if self.refiner is not None:
-                    self.refiner.scheduler = self.pipe.scheduler
+                    self.refiner.scheduler = self.pipeline.scheduler
 
-        if self._should_load_refiner(refiner):
-            # https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/model_index.json
+        if self.should_load_refiner(use_refiner):
            refiner_kwargs = {
                "variant": "fp16",
                "torch_dtype": dtype,
                "add_watermarker": False,
                "requires_aesthetics_score": True,
                "force_zeros_for_empty_prompt": False,
-                "vae": self.pipe.vae,
-                "scheduler": self.pipe.scheduler,
-                "tokenizer_2": self.pipe.tokenizer_2,
-                "text_encoder_2": self.pipe.text_encoder_2,
+                "vae": self.pipeline.vae,
+                "scheduler": self.pipeline.scheduler,
+                "tokenizer_2": self.pipeline.tokenizer_2,
+                "text_encoder_2": self.pipeline.text_encoder_2,
            }
-            self._load_refiner(refiner, progress, **refiner_kwargs)  # load refiner before deepcache
+            self.load_refiner(refiner_kwargs, progress)
 
-        if self._should_load_deepcache(deepcache):
-            self._load_deepcache(deepcache)
+        if self.should_load_deepcache(deepcache_interval):
+            self.load_deepcache(deepcache_interval)
 
-        if self._should_load_upscaler(scale):
-            self._load_upscaler(scale)
+        if self.should_load_upscaler(scale):
+            self.load_upscaler(scale)
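Note: when the cached model matches but the requested kind differs, the new load converts the existing pipeline with from_pipe, which reuses the components already in memory instead of re-reading the checkpoint. A small hedged sketch of that mechanism on its own (diffusers 0.30.x API; model id taken from the config):

import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline

# Load the txt2img pipeline once...
txt2img = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# ...then derive an img2img pipeline from it: from_pipe shares the UNet, VAE,
# and text encoders already loaded, so no weights are downloaded or re-read.
img2img = StableDiffusionXLImg2ImgPipeline.from_pipe(txt2img).to("cuda")

print(img2img.unet is txt2img.unet)  # True: components are shared, not copied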
lib/utils.py CHANGED
@@ -1,24 +1,13 @@
 import functools
-import inspect
 import json
 import time
 from contextlib import contextmanager
-from typing import Callable, TypeVar
 
-import anyio
 import torch
-from anyio import Semaphore
 from diffusers.utils import logging as diffusers_logging
 from huggingface_hub._snapshot_download import snapshot_download
 from huggingface_hub.utils import are_progress_bars_disabled
 from transformers import logging as transformers_logging
-from typing_extensions import ParamSpec
-
-T = TypeVar("T")
-P = ParamSpec("P")
-
-MAX_CONCURRENT_THREADS = 1
-MAX_THREADS_GUARD = Semaphore(MAX_CONCURRENT_THREADS)
 
 
 @contextmanager
@@ -61,7 +50,7 @@ def safe_progress(progress, current=0, total=0, desc=""):
     progress((current, total), desc=desc)
 
 
-def clear_cuda_cache():
+def cuda_collect():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
@@ -83,14 +72,3 @@ def download_repo_files(repo_id, allow_patterns, token=None):
     if was_disabled:
         disable_progress_bars()
     return snapshot_path
-
-
-# Like the original but supports args and kwargs instead of a dict
-# https://github.com/huggingface/huggingface-inference-toolkit/blob/0.2.0/src/huggingface_inference_toolkit/async_utils.py
-async def async_call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
-    async with MAX_THREADS_GUARD:
-        sig = inspect.signature(fn)
-        bound_args = sig.bind(*args, **kwargs)
-        bound_args.apply_defaults()
-        partial_fn = functools.partial(fn, **bound_args.arguments)
-        return await anyio.to_thread.run_sync(partial_fn)
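Note: timer is used throughout the loader but its body is untouched by this commit; the sketch below shows what such a logging context manager typically looks like (an assumption, not the repo's exact implementation), next to the renamed cuda_collect whose visible body is unchanged:

import time
from contextlib import contextmanager

import torch

@contextmanager
def timer(message="Operation"):
    # Assumed shape of lib/utils.timer: log the wall-clock time of a block
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{message} took {time.perf_counter() - start:.2f}s")

def cuda_collect():
    # Mirrors the helper shown in the diff above
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

with timer("Flushing CUDA memory"):
    cuda_collect()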
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-anyio==4.6.1
 compel==2.0.3
 deepcache==0.1.1
 diffusers==0.30.3
@@ -8,5 +7,5 @@ hf-transfer
 numpy==1.26.4
 ruff==0.6.9
 spaces==0.30.4
-torch==2.2.0
-torchvision==0.17.0
+torch==2.4.0
+torchvision==0.19.0