adamelliotfields committed on
Commit
b7bdcdb
·
verified ·
1 Parent(s): 3e63709

Update loader

Browse files
Files changed (2) hide show
  1. lib/inference.py +6 -7
  2. lib/loader.py +127 -115
lib/inference.py CHANGED
@@ -7,7 +7,7 @@ from compel.prompt_parser import PromptParser
7
  from gradio import Error, Info, Progress
8
  from spaces import GPU
9
 
10
- from .loader import Loader
11
  from .logger import Logger
12
  from .utils import cuda_collect, get_output_types, timer
13
 
@@ -28,7 +28,7 @@ def generate(
28
  num_images=1,
29
  use_karras=False,
30
  use_refiner=False,
31
- progress=Progress(track_tqdm=True),
32
  ):
33
  if not torch.cuda.is_available():
34
  raise Error("CUDA not available")
@@ -43,7 +43,7 @@ def generate(
43
  log = Logger("generate")
44
  log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
45
 
46
- loader = Loader()
47
  loader.load(
48
  KIND,
49
  model,
@@ -52,7 +52,6 @@ def generate(
52
  scale,
53
  use_karras,
54
  use_refiner,
55
- progress,
56
  )
57
 
58
  refiner = loader.refiner
@@ -143,9 +142,6 @@ def generate(
143
  seed = images[i][1]
144
  images[i] = (image, seed)
145
 
146
- # Flush cache after generating
147
- cuda_collect()
148
-
149
  end = time.perf_counter()
150
  msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
151
  log.info(msg)
@@ -153,4 +149,7 @@ def generate(
153
  if Info:
154
  Info(msg)
155
 
 
 
 
156
  return images
 
7
  from gradio import Error, Info, Progress
8
  from spaces import GPU
9
 
10
+ from .loader import get_loader
11
  from .logger import Logger
12
  from .utils import cuda_collect, get_output_types, timer
13
 
 
28
  num_images=1,
29
  use_karras=False,
30
  use_refiner=False,
31
+ _=Progress(track_tqdm=True),
32
  ):
33
  if not torch.cuda.is_available():
34
  raise Error("CUDA not available")
 
43
  log = Logger("generate")
44
  log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")
45
 
46
+ loader = get_loader()
47
  loader.load(
48
  KIND,
49
  model,
 
52
  scale,
53
  use_karras,
54
  use_refiner,
 
55
  )
56
 
57
  refiner = loader.refiner
 
142
  seed = images[i][1]
143
  images[i] = (image, seed)
144
 
 
 
 
145
  end = time.perf_counter()
146
  msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
147
  log.info(msg)
 
149
  if Info:
150
  Info(msg)
151
 
152
+ # Flush cache before returning
153
+ cuda_collect()
154
+
155
  return images
lib/loader.py CHANGED
@@ -5,52 +5,64 @@ from diffusers.models import AutoencoderKL
5
  from .config import Config
6
  from .logger import Logger
7
  from .upscaler import RealESRGAN
8
- from .utils import cuda_collect, timer
9
 
10
 
11
  class Loader:
12
  def __init__(self):
13
  self.model = ""
 
14
  self.refiner = None
15
  self.pipeline = None
16
  self.upscaler = None
17
  self.log = Logger("Loader")
 
18
 
19
- def should_unload_refiner(self, use_refiner=False):
20
- return self.refiner is not None and not use_refiner
21
-
22
- def should_unload_upscaler(self, scale=1):
23
- return self.upscaler is not None and self.upscaler.scale != scale
24
-
25
- def should_unload_deepcache(self, interval=1):
26
  has_deepcache = hasattr(self.pipeline, "deepcache")
27
- if has_deepcache and interval == 1:
28
  return True
29
- if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
30
  return True
31
  return False
32
 
 
 
 
 
 
 
33
  def should_unload_pipeline(self, model=""):
34
  return self.pipeline is not None and self.model != model
35
 
36
- def should_load_refiner(self, use_refiner=False):
37
- return self.refiner is None and use_refiner
 
 
 
38
 
39
  def should_load_upscaler(self, scale=1):
40
  return self.upscaler is None and scale > 1
41
 
42
- def should_load_deepcache(self, interval=1):
43
- has_deepcache = hasattr(self.pipeline, "deepcache")
44
- if not has_deepcache and interval != 1:
 
 
45
  return True
46
- if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != interval:
47
  return True
48
  return False
49
 
50
- def should_load_pipeline(self):
51
- return self.pipeline is None
 
 
 
 
 
52
 
53
- def unload(self, model, use_refiner, deepcache_interval, scale):
54
  if self.should_unload_deepcache(deepcache_interval):
55
  self.log.info("Disabling DeepCache")
56
  self.pipeline.deepcache.disable()
@@ -59,14 +71,14 @@ class Loader:
59
  self.refiner.deepcache.disable()
60
  delattr(self.refiner, "deepcache")
61
 
62
- if self.should_unload_refiner(use_refiner):
63
- self.log.info("Unloading refiner")
64
- self.refiner = None
65
-
66
  if self.should_unload_upscaler(scale):
67
  self.log.info("Unloading upscaler")
68
  self.upscaler = None
69
 
 
 
 
 
70
  if self.should_unload_pipeline(model):
71
  self.log.info(f"Unloading {self.model}")
72
  if self.refiner:
@@ -75,58 +87,81 @@ class Loader:
75
  self.refiner.tokenizer_2 = None
76
  self.refiner.text_encoder_2 = None
77
  self.pipeline = None
78
- self.model = None
79
 
80
- # Flush cache
81
- cuda_collect()
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def load_refiner(self, progress=None):
84
  model = Config.REFINER_MODEL
85
- try:
86
- with timer(f"Loading {model}", logger=self.log.info):
87
- refiner_kwargs = {
88
- "variant": "fp16",
89
- "torch_dtype": self.pipeline.dtype,
90
- "add_watermarker": False,
91
- "requires_aesthetics_score": True,
92
- "force_zeros_for_empty_prompt": False,
93
- "vae": self.pipeline.vae,
94
- "scheduler": self.pipeline.scheduler,
95
- "tokenizer_2": self.pipeline.tokenizer_2,
96
- "text_encoder_2": self.pipeline.text_encoder_2,
97
- }
98
- Pipeline = Config.PIPELINES["img2img"]
99
- self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to("cuda")
100
- except Exception as e:
101
- self.log.error(f"Error loading {model}: {e}")
102
- self.refiner = None
103
- return
104
- if self.refiner is not None:
105
- self.refiner.set_progress_bar_config(disable=progress is not None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- def load_upscaler(self, scale=1):
108
- if self.should_load_upscaler(scale):
109
- try:
110
- with timer(f"Loading {scale}x upscaler", logger=self.log.info):
111
- self.upscaler = RealESRGAN(scale, device=self.pipeline.device)
112
- self.upscaler.load_weights()
113
- except Exception as e:
114
- self.log.error(f"Error loading {scale}x upscaler: {e}")
115
- self.upscaler = None
116
 
117
- def load_deepcache(self, interval=1):
118
- if self.should_load_deepcache(interval):
119
- self.log.info("Enabling DeepCache")
120
- self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
121
- self.pipeline.deepcache.set_params(cache_interval=interval)
122
- self.pipeline.deepcache.enable()
123
- if self.refiner:
124
- self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
125
- self.refiner.deepcache.set_params(cache_interval=interval)
126
- self.refiner.deepcache.enable()
127
 
128
- def load(self, kind, model, scheduler, deepcache_interval, scale, use_karras, use_refiner, progress=None):
129
- Pipeline = Config.PIPELINES[kind]
130
  Scheduler = Config.SCHEDULERS[scheduler]
131
 
132
  scheduler_kwargs = {
@@ -137,70 +172,47 @@ class Loader:
137
  "steps_offset": 1,
138
  }
139
 
 
 
 
140
  pipeline_kwargs = {
141
  "torch_dtype": torch.float16,
142
  "add_watermarker": False,
143
- "scheduler": Config.SCHEDULERS[scheduler](**scheduler_kwargs),
144
- "vae": AutoencoderKL.from_pretrained(Config.VAE_MODEL, torch_dtype=torch.float16),
145
  }
146
 
147
- if scheduler not in ["Euler a"]:
148
- scheduler_kwargs["use_karras_sigmas"] = use_karras
149
-
150
  if model not in Config.SINGLE_FILE_MODELS:
151
  pipeline_kwargs["variant"] = "fp16"
152
  else:
153
  pipeline_kwargs["variant"] = None
154
 
155
  # Unload
156
- self.unload(model, use_refiner, deepcache_interval, scale)
157
 
158
  # Load
159
- try:
160
- with timer(f"Loading {model}", logger=self.log.info):
161
- self.model = model
162
- if model in Config.SINGLE_FILE_MODELS:
163
- checkpoint = Config.HF_REPOS[model][0]
164
- self.pipeline = Pipeline.from_single_file(
165
- f"https://huggingface.co/{model}/{checkpoint}",
166
- **pipeline_kwargs,
167
- ).to("cuda")
168
- else:
169
- self.pipeline = Pipeline.from_pretrained(model, **pipeline_kwargs).to("cuda")
170
- except Exception as e:
171
- self.log.error(f"Error loading {model}: {e}")
172
- self.model = None
173
- self.pipeline = None
174
- return
175
-
176
- if not isinstance(self.pipeline, Pipeline):
177
- self.pipeline = Pipeline.from_pipe(self.pipeline).to("cuda")
178
-
179
- if self.pipeline is not None:
180
- self.pipeline.set_progress_bar_config(disable=progress is not None)
181
-
182
- # Check and update scheduler if necessary
183
- same_scheduler = isinstance(self.pipeline.scheduler, Scheduler)
184
- same_karras = (
185
- not hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
186
- or self.pipeline.scheduler.config.use_karras_sigmas == use_karras
187
- )
188
-
189
- if self.model == model:
190
- if not same_scheduler:
191
- self.log.info(f"Enabling {scheduler}")
192
- if not same_karras:
193
- self.log.info(f"{'Enabling' if use_karras else 'Disabling'} Karras sigmas")
194
- if not same_scheduler or not same_karras:
195
- self.pipeline.scheduler = Scheduler(**scheduler_kwargs)
196
- if self.refiner is not None:
197
- self.refiner.scheduler = self.pipeline.scheduler
198
 
199
  if self.should_load_refiner(use_refiner):
200
- self.load_refiner(progress)
 
 
 
201
 
202
  if self.should_load_deepcache(deepcache_interval):
203
  self.load_deepcache(deepcache_interval)
204
 
205
  if self.should_load_upscaler(scale):
206
  self.load_upscaler(scale)
 
 
 
 
 
 
 
 
 
 
 
 
5
  from .config import Config
6
  from .logger import Logger
7
  from .upscaler import RealESRGAN
8
+ from .utils import timer
9
 
10
 
11
  class Loader:
12
  def __init__(self):
13
  self.model = ""
14
+ self.vae = None
15
  self.refiner = None
16
  self.pipeline = None
17
  self.upscaler = None
18
  self.log = Logger("Loader")
19
+ self.device = torch.device("cuda") # always called in CUDA context
20
 
21
+ def should_unload_deepcache(self, cache_interval=1):
 
 
 
 
 
 
22
  has_deepcache = hasattr(self.pipeline, "deepcache")
23
+ if has_deepcache and cache_interval == 1:
24
  return True
25
+ if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
26
  return True
27
  return False
28
 
29
+ def should_unload_upscaler(self, scale=1):
30
+ return self.upscaler is not None and self.upscaler.scale != scale
31
+
32
+ def should_unload_refiner(self, use_refiner=False):
33
+ return self.refiner is not None and not use_refiner
34
+
35
  def should_unload_pipeline(self, model=""):
36
  return self.pipeline is not None and self.model != model
37
 
38
+ def should_load_deepcache(self, cache_interval=1):
39
+ has_deepcache = hasattr(self.pipeline, "deepcache")
40
+ if not has_deepcache and cache_interval > 1:
41
+ return True
42
+ return False
43
 
44
  def should_load_upscaler(self, scale=1):
45
  return self.upscaler is None and scale > 1
46
 
47
+ def should_load_refiner(self, use_refiner=False):
48
+ return self.refiner is None and use_refiner
49
+
50
+ def should_load_pipeline(self, pipeline_id=""):
51
+ if self.pipeline is None:
52
  return True
53
+ if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
54
  return True
55
  return False
56
 
57
+ def should_load_scheduler(self, cls, use_karras=False):
58
+ has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
59
+ if not isinstance(self.pipeline.scheduler, cls):
60
+ return True
61
+ if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
62
+ return True
63
+ return False
64
 
65
+ def unload_all(self, model, deepcache_interval, scale, use_refiner):
66
  if self.should_unload_deepcache(deepcache_interval):
67
  self.log.info("Disabling DeepCache")
68
  self.pipeline.deepcache.disable()
 
71
  self.refiner.deepcache.disable()
72
  delattr(self.refiner, "deepcache")
73
 
 
 
 
 
74
  if self.should_unload_upscaler(scale):
75
  self.log.info("Unloading upscaler")
76
  self.upscaler = None
77
 
78
+ if self.should_unload_refiner(use_refiner):
79
+ self.log.info("Unloading refiner")
80
+ self.refiner = None
81
+
82
  if self.should_unload_pipeline(model):
83
  self.log.info(f"Unloading {self.model}")
84
  if self.refiner:
 
87
  self.refiner.tokenizer_2 = None
88
  self.refiner.text_encoder_2 = None
89
  self.pipeline = None
90
+ self.model = ""
91
 
92
+ def load_deepcache(self, interval=1):
93
+ self.log.info("Enabling DeepCache")
94
+ self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
95
+ self.pipeline.deepcache.set_params(cache_interval=interval)
96
+ self.pipeline.deepcache.enable()
97
+ if self.refiner:
98
+ self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
99
+ self.refiner.deepcache.set_params(cache_interval=interval)
100
+ self.refiner.deepcache.enable()
101
+
102
+ def load_upscaler(self, scale=1):
103
+ with timer(f"Loading {scale}x upscaler", logger=self.log.info):
104
+ self.upscaler = RealESRGAN(scale, device=self.device)
105
+ self.upscaler.load_weights()
106
 
107
+ def load_refiner(self):
108
  model = Config.REFINER_MODEL
109
+ with timer(f"Loading {model}", logger=self.log.info):
110
+ refiner_kwargs = {
111
+ "variant": "fp16",
112
+ "torch_dtype": self.pipeline.dtype,
113
+ "add_watermarker": False,
114
+ "requires_aesthetics_score": True,
115
+ "force_zeros_for_empty_prompt": False,
116
+ "vae": self.pipeline.vae,
117
+ "scheduler": self.pipeline.scheduler,
118
+ "tokenizer_2": self.pipeline.tokenizer_2,
119
+ "text_encoder_2": self.pipeline.text_encoder_2,
120
+ }
121
+ Pipeline = Config.PIPELINES["img2img"]
122
+ self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to(self.device)
123
+ self.refiner.set_progress_bar_config(disable=True)
124
+
125
+ def load_pipeline(self, pipeline_id, model, **kwargs):
126
+ Pipeline = Config.PIPELINES[pipeline_id]
127
+
128
+ # Load VAE first
129
+ if self.vae is None:
130
+ self.vae = AutoencoderKL.from_pretrained(
131
+ Config.VAE_MODEL,
132
+ torch_dtype=torch.float32, # vae is full-precision
133
+ ).to(self.device)
134
+
135
+ kwargs["vae"] = self.vae
136
+
137
+ # Load from scratch
138
+ if self.pipeline is None:
139
+ with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
140
+ if model in Config.SINGLE_FILE_MODELS:
141
+ checkpoint = Config.HF_REPOS[model][0]
142
+ self.pipeline = Pipeline.from_single_file(
143
+ f"https://huggingface.co/{model}/{checkpoint}",
144
+ **kwargs,
145
+ ).to(self.device)
146
+ else:
147
+ self.pipeline = Pipeline.from_pretrained(model, **kwargs).to(self.device)
148
 
149
+ # Change to a different one
150
+ else:
151
+ with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
152
+ self.pipeline = Pipeline.from_pipe(self.pipeline).to(self.device)
 
 
 
 
 
153
 
154
+ # Update model and disable terminal progress bars
155
+ self.model = model
156
+ self.pipeline.set_progress_bar_config(disable=True)
157
+
158
+ def load_scheduler(self, cls, use_karras=False, **kwargs):
159
+ self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
160
+ self.pipeline.scheduler = cls(**kwargs)
161
+ if self.refiner is not None:
162
+ self.refiner.scheduler = self.pipeline.scheduler
 
163
 
164
+ def load(self, pipeline_id, model, scheduler, deepcache_interval, scale, use_karras, use_refiner):
 
165
  Scheduler = Config.SCHEDULERS[scheduler]
166
 
167
  scheduler_kwargs = {
 
172
  "steps_offset": 1,
173
  }
174
 
175
+ if scheduler not in ["Euler a"]:
176
+ scheduler_kwargs["use_karras_sigmas"] = use_karras
177
+
178
  pipeline_kwargs = {
179
  "torch_dtype": torch.float16,
180
  "add_watermarker": False,
181
+ "scheduler": Scheduler(**scheduler_kwargs),
 
182
  }
183
 
184
+ # Single-file models don't need a variant
 
 
185
  if model not in Config.SINGLE_FILE_MODELS:
186
  pipeline_kwargs["variant"] = "fp16"
187
  else:
188
  pipeline_kwargs["variant"] = None
189
 
190
  # Unload
191
+ self.unload_all(model, deepcache_interval, scale, use_refiner)
192
 
193
  # Load
194
+ if self.should_load_pipeline(pipeline_id):
195
+ self.load_pipeline(pipeline_id, model, **pipeline_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  if self.should_load_refiner(use_refiner):
198
+ self.load_refiner()
199
+
200
+ if self.should_load_scheduler(Scheduler, use_karras):
201
+ self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)
202
 
203
  if self.should_load_deepcache(deepcache_interval):
204
  self.load_deepcache(deepcache_interval)
205
 
206
  if self.should_load_upscaler(scale):
207
  self.load_upscaler(scale)
208
+
209
+
210
+ # Get a singleton or a new instance of the Loader
211
+ def get_loader(singleton=False):
212
+ if not singleton:
213
+ return Loader()
214
+ else:
215
+ if not hasattr(get_loader, "_instance"):
216
+ get_loader._instance = Loader()
217
+ assert isinstance(get_loader._instance, Loader)
218
+ return get_loader._instance