adamelliotfields committed on
Commit 1e250ff
1 Parent(s): 10d9721

Progress bar for loading pipeline

Files changed (7):
  1. app.py +8 -5
  2. cli.py +11 -0
  3. lib/config.py +9 -4
  4. lib/inference.py +18 -25
  5. lib/loader.py +15 -9
  6. lib/pipelines.py +222 -0
  7. requirements.txt +1 -1
app.py CHANGED
@@ -14,9 +14,9 @@ filterwarnings("ignore", category=FutureWarning, module="diffusers")
 filterwarnings("ignore", category=FutureWarning, module="transformers")
 
 diffusers_logging.set_verbosity_error()
-diffusers_logging.disable_progress_bar()
-
 transformers_logging.set_verbosity_error()
+
+diffusers_logging.disable_progress_bar()
 transformers_logging.disable_progress_bar()
 
 # the CSS `content` attribute expects a string so we need to wrap the number in quotes
@@ -88,7 +88,7 @@ async def random_fn():
     return gr.Textbox(value=random.choice(prompts))
 
 
-async def generate_fn(*args):
+async def generate_fn(*args, progress=gr.Progress(track_tqdm=True)):
     if len(args) > 0:
         prompt = args[0]
     else:
@@ -104,12 +104,15 @@ async def generate_fn(*args):
         gen_args[3] = None
 
     try:
+        if Config.ZERO_GPU:
+            progress((0, 100), desc="ZeroGPU init")
+
         images = await async_call(
            generate,
            *gen_args,
-           Info=gr.Info,
            Error=gr.Error,
-           Progress=gr.Progress,
+           Info=gr.Info,
+           progress=progress,
        )
    except RuntimeError:
        raise gr.Error("Error: Please try again")
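Note: Gradio injects a tracker into any event handler that declares a parameter whose default is a `gr.Progress` instance, and `track_tqdm=True` additionally mirrors tqdm bars created while the handler runs. A minimal sketch of the pattern outside this app (the handler and component names are illustrative, not part of the commit):

```python
import time

import gradio as gr


def process(n, progress=gr.Progress(track_tqdm=True)):
    # Gradio supplies `progress` automatically; callers never pass it
    for _ in progress.tqdm(range(int(n)), desc="Working"):
        time.sleep(0.1)
    return "done"


with gr.Blocks() as demo:
    count = gr.Number(value=10)
    out = gr.Textbox()
    gr.Button("Run").click(process, inputs=count, outputs=out)

demo.launch()
```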
cli.py CHANGED
@@ -2,9 +2,20 @@
 # usage: python cli.py 'colorful calico cat artstation'
 import argparse
 import asyncio
+from warnings import filterwarnings
+
+from diffusers.utils import logging as diffusers_logging
+from transformers import logging as transformers_logging
 
 from lib import Config, async_call, generate
 
+filterwarnings("ignore", category=FutureWarning, module="diffusers")
+filterwarnings("ignore", category=FutureWarning, module="transformers")
+
+# reduce verbosity but don't disable progress bars
+diffusers_logging.set_verbosity_error()
+transformers_logging.set_verbosity_error()
+
 
 def save_images(images, filename="image.png"):
     for i, (img, _) in enumerate(images):
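Note the asymmetry with app.py: the web UI calls `disable_progress_bar()` because Gradio renders its own progress component, while the CLI (run as `python cli.py 'colorful calico cat artstation'`) only lowers log verbosity, leaving the libraries' native tqdm bars to print loading and inference progress to the terminal.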
lib/config.py CHANGED
@@ -1,4 +1,5 @@
 import os
+from importlib import import_module
 from types import SimpleNamespace
 
 from diffusers import (
@@ -8,14 +9,18 @@ from diffusers import (
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     PNDMScheduler,
-    StableDiffusionImg2ImgPipeline,
-    StableDiffusionPipeline,
     UniPCMultistepScheduler,
 )
 
+from .pipelines import CustomStableDiffusionImg2ImgPipeline, CustomStableDiffusionPipeline
+
+# improved GPU handling and progress bars; set before importing spaces
+os.environ["ZEROGPU_V2"] = "true"
+
 Config = SimpleNamespace(
     HF_TOKEN=os.environ.get("HF_TOKEN", None),
     CIVIT_TOKEN=os.environ.get("CIVIT_TOKEN", None),
+    ZERO_GPU=import_module("spaces").config.Config.zero_gpu,
     HF_MODELS={
         "Lykon/dreamshaper-8": [
             "feature_extractor/preprocessor_config.json",
@@ -59,8 +64,8 @@ Config = SimpleNamespace(
         "Noto Color Emoji",
     ],
     PIPELINES={
-        "txt2img": StableDiffusionPipeline,
-        "img2img": StableDiffusionImg2ImgPipeline,
+        "txt2img": CustomStableDiffusionPipeline,
+        "img2img": CustomStableDiffusionImg2ImgPipeline,
    },
    MODEL="Lykon/dreamshaper-8",
    MODELS=[
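Note: `ZERO_GPU` is resolved through `import_module` rather than a top-level `import spaces` so the `ZEROGPU_V2` environment variable is guaranteed to be set before the `spaces` package initializes, since it reads its configuration at import time (hence the comment in the diff). A minimal sketch of the ordering constraint; `spaces.config.Config.zero_gpu` is an internal attribute of the `spaces` package, mirrored from the code above:

```python
import os
from importlib import import_module

# must run before the first import of `spaces`,
# which reads its configuration at import time
os.environ["ZEROGPU_V2"] = "true"

spaces = import_module("spaces")

# True only when the Space is running on ZeroGPU hardware
ZERO_GPU = spaces.config.Config.zero_gpu
```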
lib/inference.py CHANGED
@@ -4,14 +4,13 @@ import time
 from datetime import datetime
 from itertools import product
 
-import gradio as gr
 import numpy as np
-import spaces
 import torch
 from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
 from compel.prompt_parser import PromptParser
 from huggingface_hub.utils import HFValidationError, RepositoryNotFoundError
 from PIL import Image
+from spaces import GPU
 
 from .config import Config
 from .loader import Loader
@@ -92,7 +91,7 @@ def gpu_duration(**kwargs):
     return loading + (duration * num_images)
 
 
-@spaces.GPU(duration=gpu_duration)
+@GPU(duration=gpu_duration)
 def generate(
     positive_prompt,
     negative_prompt="",
@@ -120,10 +119,9 @@ def generate(
     taesd=False,
     freeu=False,
     clip_skip=False,
-    Info=None,
     Error=Exception,
-    Progress=None,
-    progress=gr.Progress(track_tqdm=True),
+    Info=None,
+    progress=None,
 ):
     if not torch.cuda.is_available():
         raise Error("CUDA not available")
@@ -148,32 +146,27 @@
     else:
         IP_ADAPTER = ""
 
-    if Progress is not None:
-        TQDM = False
-        progress_bar = Progress()
-        progress_bar((0, inference_steps), desc=f"Generating image {CURRENT_IMAGE}/{num_images}")
-    else:
-        TQDM = True
-        progress_bar = None
-
+    # custom progress bar for multiple images
     def callback_on_step_end(pipeline, step, timestep, latents):
         nonlocal CURRENT_STEP, CURRENT_IMAGE
-        if Progress is None:
-            return latents
-        strength = denoising_strength if KIND == "img2img" else 1
-        total_steps = min(int(inference_steps * strength), inference_steps)
-
-        CURRENT_STEP = step + 1
-        progress_bar(
-            (CURRENT_STEP, total_steps),
-            desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
-        )
+        if progress is not None:
+            # calculate total steps for img2img based on denoising strength
+            strength = denoising_strength if KIND == "img2img" else 1
+            total_steps = min(int(inference_steps * strength), inference_steps)
+            CURRENT_STEP = step + 1
+            progress(
+                (CURRENT_STEP, total_steps),
+                desc=f"Generating image {CURRENT_IMAGE}/{num_images}",
+            )
        return latents
 
     start = time.perf_counter()
     log = Logger("generate")
     log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}")
 
+    if Config.ZERO_GPU and progress is not None:
+        progress((100, 100), desc="ZeroGPU init")
+
     loader = Loader()
     loader.load(
         KIND,
@@ -185,7 +178,7 @@
         freeu,
         deepcache,
         scale,
-        TQDM,
+        progress,
     )
 
     if loader.pipe is None:
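Note: `callback_on_step_end` only takes effect once it is handed to the pipeline call; diffusers then invokes it after every scheduler step. A minimal sketch of the wiring, with a print-based stand-in for the Gradio tracker (variable names are illustrative, not from the commit):

```python
import torch
from diffusers import StableDiffusionPipeline

inference_steps = 30


def progress(pair, desc=""):
    # stand-in for gr.Progress; prints instead of updating a UI
    step, total = pair
    print(f"{desc}: {step}/{total}")


def callback_on_step_end(pipeline, step, timestep, callback_kwargs):
    # called by diffusers after each denoising step; return the
    # kwargs unchanged so the latents are not modified
    progress((step + 1, inference_steps), desc="Generating image")
    return callback_kwargs


pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8", torch_dtype=torch.float16
).to("cuda")
image = pipe(
    "colorful calico cat artstation",
    num_inference_steps=inference_steps,
    callback_on_step_end=callback_on_step_end,
).images[0]
```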
lib/loader.py CHANGED
@@ -4,7 +4,6 @@ from threading import Lock
 
 import torch
 from DeepCache import DeepCacheSDHelper
-from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, AutoencoderTiny
 from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttnProcessor2_0
 
@@ -50,9 +49,9 @@ class Loader:
             return False
         if self.model.lower() != model.lower():
             return True
-        if kind == "txt2img" and not isinstance(self.pipe, StableDiffusionPipeline):
+        if kind == "txt2img" and not isinstance(self.pipe, Config.PIPELINES["txt2img"]):
             return True  # txt2img -> img2img
-        if kind == "img2img" and not isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
+        if kind == "img2img" and not isinstance(self.pipe, Config.PIPELINES["img2img"]):
             return True  # img2img -> txt2img
         return False
@@ -69,7 +68,7 @@
             return
 
         self.log.info("Unloading IP-Adapter")
-        if not isinstance(self.pipe, StableDiffusionImg2ImgPipeline):
+        if not isinstance(self.pipe, Config.PIPELINES["img2img"]):
            self.pipe.image_encoder = None
            self.pipe.register_to_config(image_encoder=[None, None])
 
@@ -142,7 +141,13 @@
            self.log.error(f"Error loading 4x upscaler: {e}")
            self.upscaler_4x = None
 
-    def _load_pipeline(self, kind, model, tqdm, **kwargs):
+    def _load_pipeline(
+        self,
+        kind,
+        model,
+        progress,
+        **kwargs,
+    ):
        pipeline = Config.PIPELINES[kind]
        if self.pipe is None:
            try:
@@ -152,10 +157,11 @@
                if model.lower() in Config.MODEL_CHECKPOINTS.keys():
                    self.pipe = pipeline.from_single_file(
                        f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
+                       progress,
                        **kwargs,
                    ).to("cuda")
                else:
-                   self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
+                   self.pipe = pipeline.from_pretrained(model, progress, **kwargs).to("cuda")
                diff = time.perf_counter() - start
                self.log.info(f"Loading {model} done in {diff:.2f}s")
            except Exception as e:
@@ -166,7 +172,7 @@
        if not isinstance(self.pipe, pipeline):
            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
        if self.pipe is not None:
-           self.pipe.set_progress_bar_config(disable=not tqdm)
+           self.pipe.set_progress_bar_config(disable=progress is not None)
 
    def _load_vae(self, taesd=False, model=""):
        vae_type = type(self.pipe.vae)
@@ -231,7 +237,7 @@
        freeu,
        deepcache,
        scale,
-       tqdm,
+       progress,
    ):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -275,7 +281,7 @@
        pipe_kwargs["torch_dtype"] = torch.float16
 
        self._unload(kind, model, ip_adapter, deepcache)
-       self._load_pipeline(kind, model, tqdm, **pipe_kwargs)
+       self._load_pipeline(kind, model, progress, **pipe_kwargs)
 
        # error loading model
        if self.pipe is None:
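Note: `set_progress_bar_config(disable=progress is not None)` hands step-level feedback to exactly one reporter: when a Gradio tracker is present, diffusers' own tqdm bar is suppressed so the two don't double-report; without one (the CLI path), the native bar stays on. A minimal sketch of the toggle:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "Lykon/dreamshaper-8", torch_dtype=torch.float16
)

progress = None  # terminal use: keep diffusers' own tqdm bar
# progress = gr.Progress()  # web UI: let the Gradio tracker report instead
pipe.set_progress_bar_config(disable=progress is not None)
```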
lib/pipelines.py ADDED
@@ -0,0 +1,222 @@
+import os
+from importlib import import_module
+
+from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
+from diffusers.loaders.single_file import (
+    SINGLE_FILE_OPTIONAL_COMPONENTS,
+    load_single_file_sub_model,
+)
+from diffusers.loaders.single_file_utils import fetch_diffusers_config, load_single_file_checkpoint
+from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from diffusers.pipelines.pipeline_loading_utils import (
+    ALL_IMPORTABLE_CLASSES,
+    _get_pipeline_class,
+    load_sub_model,
+)
+from diffusers.utils import logging
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import validate_hf_hub_args
+
+
+class CustomDiffusionMixin:
+    r"""
+    Overrides DiffusionPipeline methods.
+    """
+
+    # Copied from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/pipelines/pipeline_utils.py#L480
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_name_or_path, progress=None, **kwargs):
+        torch_dtype = kwargs.pop("torch_dtype", None)
+        variant = kwargs.pop("variant", None)
+        token = kwargs.pop("token", None)
+
+        # download the checkpoints and configs
+        cached_folder = cls.download(
+            pretrained_model_name_or_path,
+            variant=variant,
+            token=token,
+            **kwargs,
+        )
+
+        # pop out "_ignore_files" as it is only needed for download
+        config_dict = cls.load_config(cached_folder)
+        config_dict.pop("_ignore_files", None)
+
+        # Define which model components should load variants.
+        # We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
+        # Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors` with variant being `"fp16"`.
+        model_variants = {}
+        if variant is not None:
+            for folder in os.listdir(cached_folder):
+                folder_path = os.path.join(cached_folder, folder)
+                is_folder = os.path.isdir(folder_path) and folder in config_dict
+                variant_exists = is_folder and any(
+                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                )
+                if variant_exists:
+                    model_variants[folder] = variant
+
+        # load the pipeline class
+        pipeline_class = _get_pipeline_class(cls, config=config_dict)
+
+        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
+        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+        def load_module(name, value):
+            if value[0] is None:
+                return False
+            if name in passed_class_obj and passed_class_obj[name] is None:
+                return False
+            return True
+
+        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+        init_kwargs = {
+            k: init_dict.pop(k)
+            for k in optional_kwargs
+            if k in init_dict and k not in pipeline_class._optional_components
+        }
+        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+        # load each module in the pipeline
+        pipelines = import_module("diffusers.pipelines")
+        tqdm = logging.tqdm if progress is None else progress.tqdm
+        for name, (library_name, class_name) in tqdm(
+            sorted(init_dict.items()),
+            desc="Loading pipeline components",
+        ):
+            # use passed sub model or load class_name from library_name
+            loaded_sub_model = None
+            if name in passed_class_obj:
+                # passed as an argument like "scheduler"
+                loaded_sub_model = passed_class_obj[name]
+            else:
+                loaded_sub_model = load_sub_model(
+                    library_name=library_name,
+                    class_name=class_name,
+                    importable_classes=ALL_IMPORTABLE_CLASSES,
+                    pipelines=pipelines,
+                    is_pipeline_module=hasattr(pipelines, library_name),
+                    pipeline_class=pipeline_class,
+                    torch_dtype=torch_dtype,
+                    provider=None,
+                    sess_options=None,
+                    device_map=None,
+                    max_memory=None,
+                    offload_folder=None,
+                    offload_state_dict=False,
+                    model_variants=model_variants,
+                    name=name,
+                    from_flax=False,
+                    variant=variant,
+                    low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT,
+                    cached_folder=cached_folder,
+                )
+            init_kwargs[name] = loaded_sub_model
+
+        # potentially add passed objects if expected
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        if len(missing_modules) > 0:
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj.get(module, None)
+
+        # instantiate the pipeline
+        model = pipeline_class(**init_kwargs)
+
+        # save where the model was instantiated from
+        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        return model
+
+    # Copied from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/loaders/single_file.py#L270
+    @classmethod
+    @validate_hf_hub_args
+    def from_single_file(cls, pretrained_model_link_or_path, progress=None, **kwargs):
+        token = kwargs.pop("token", None)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        # load the pipeline class
+        pipeline_class = _get_pipeline_class(cls, config=None)
+        checkpoint = load_single_file_checkpoint(pretrained_model_link_or_path, token=token)
+
+        config = fetch_diffusers_config(checkpoint)
+        default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
+
+        # attempt to download the config files for the pipeline
+        cached_model_config_path = snapshot_download(
+            default_pretrained_model_config_name,
+            token=token,
+            allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
+        )
+
+        # pop out "_ignore_files" as it is only needed for download
+        config_dict = pipeline_class.load_config(cached_model_config_path)
+        config_dict.pop("_ignore_files", None)
+
+        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
+        expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+        def load_module(name, value):
+            if value[0] is None:
+                return False
+            if name in passed_class_obj and passed_class_obj[name] is None:
+                return False
+            if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
+                return False
+            return True
+
+        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+        # load each module in the pipeline
+        pipelines = import_module("diffusers.pipelines")
+        tqdm = logging.tqdm if progress is None else progress.tqdm
+        for name, (library_name, class_name) in tqdm(
+            sorted(init_dict.items()),
+            desc="Loading pipeline components",
+        ):
+            # use passed sub model or load class_name from library_name
+            loaded_sub_model = None
+            if name in passed_class_obj:
+                # passed as an argument like "scheduler"
+                loaded_sub_model = passed_class_obj[name]
+            else:
+                loaded_sub_model = load_single_file_sub_model(
+                    library_name=library_name,
+                    class_name=class_name,
+                    name=name,
+                    checkpoint=checkpoint,
+                    is_pipeline_module=hasattr(pipelines, library_name),
+                    cached_model_config_path=cached_model_config_path,
+                    pipelines=pipelines,
+                    torch_dtype=torch_dtype,
+                    **kwargs,
+                )
+            init_kwargs[name] = loaded_sub_model
+
+        # potentially add passed objects if expected
+        missing_modules = set(expected_modules) - set(init_kwargs.keys())
+        if len(missing_modules) > 0:
+            for module in missing_modules:
+                init_kwargs[module] = passed_class_obj.get(module, None)
+
+        # instantiate the pipeline
+        pipe = pipeline_class(**init_kwargs)
+
+        # save where the model was instantiated from
+        pipe.register_to_config(_name_or_path=pretrained_model_link_or_path)
+        return pipe
+
+
+class CustomStableDiffusionPipeline(CustomDiffusionMixin, StableDiffusionPipeline):
+    pass
+
+
+class CustomStableDiffusionImg2ImgPipeline(CustomDiffusionMixin, StableDiffusionImg2ImgPipeline):
+    pass
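Note: the only functional change relative to the upstream `from_pretrained`/`from_single_file` is the component-loading loop, where `logging.tqdm` is swapped for `progress.tqdm` whenever a Gradio tracker is passed, so each sub-model (UNet, VAE, text encoder, ...) ticks the UI bar as it loads. A hedged usage sketch from inside a Gradio handler (handler name and return value are illustrative):

```python
import gradio as gr
import torch

from lib.pipelines import CustomStableDiffusionPipeline


def load_model(progress=gr.Progress(track_tqdm=True)):
    # `progress.tqdm` wraps the component loop inside from_pretrained,
    # so the bar advances once per sub-model instead of staying blank
    pipe = CustomStableDiffusionPipeline.from_pretrained(
        "Lykon/dreamshaper-8",
        progress,
        torch_dtype=torch.float16,
    )
    return "loaded"
```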
requirements.txt CHANGED
@@ -10,6 +10,6 @@ httpx
 numpy==1.26.4
 peft
 ruff==0.6.7
-spaces
+spaces==0.30.2
 torch==2.2.0
 torchvision==0.17.0