adamelliotfields committed
Commit 767128b
Parent(s): 6681256

Loading and inferencing improvements

Files changed (3):
  1. lib/inference.py +1 -4
  2. lib/loader.py +23 -15
  3. lib/upscaler.py +1 -1
lib/inference.py CHANGED
@@ -21,8 +21,8 @@ from typing_extensions import ParamSpec
 
 from .loader import Loader
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 __import__("transformers").logging.set_verbosity_error()
+__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
 
 T = TypeVar("T")
 P = ParamSpec("P")
@@ -140,8 +140,6 @@ def generate(
     if seed is None or seed < 0:
         seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)
 
-    DEVICE = torch.device("cuda")
-
     EMBEDDINGS_TYPE = (
         ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
         if clip_skip
@@ -191,7 +189,6 @@ def generate(
         deepcache,
         scale,
         TQDM,
-        DEVICE,
     )
 
     # load embeddings and append to negative prompt
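For reference, the fallback-seed expression kept in the second hunk derives a 64-bit value from the current timestamp whenever no seed is given. A minimal standalone sketch (the default_seed name is illustrative, not from the repo):

    from datetime import datetime

    def default_seed() -> int:
        # Microseconds since the epoch, wrapped into the unsigned
        # 64-bit range used for RNG seeds.
        return int(datetime.now().timestamp() * 1_000_000) % (2**64)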
lib/loader.py CHANGED
@@ -1,5 +1,6 @@
 import gc
 from threading import Lock
+from warnings import filterwarnings
 
 import torch
 from DeepCache import DeepCacheSDHelper
@@ -11,9 +12,9 @@ from torch._dynamo import OptimizedModule
 from .config import Config
 from .upscaler import RealESRGAN
 
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="diffusers")
-__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="torch")
 __import__("diffusers").logging.set_verbosity_error()
+filterwarnings("ignore", category=FutureWarning, module="torch")
+filterwarnings("ignore", category=FutureWarning, module="diffusers")
 
 
 class Loader:
@@ -69,6 +70,14 @@ class Loader:
         )
         self.pipe.unet.set_attn_processor(attn_procs)
 
+    def _flush(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
     def _unload(self, kind="", model="", ip_adapter="", scale=1):
         to_unload = []
 
@@ -86,11 +95,7 @@
         for component in to_unload:
             delattr(self, component)
 
-        gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        self._flush()
 
         for component in to_unload:
             setattr(self, component, None)
@@ -107,10 +112,10 @@
         self.pipe.set_ip_adapter_scale(0.5)
         self.ip_adapter = ip_adapter
 
-    def _load_upscaler(self, device=None, scale=1):
+    def _load_upscaler(self, scale=1, device=None):
         if scale > 1 and self.upscaler is None:
             print(f"Loading {scale}x upscaler...")
-            self.upscaler = RealESRGAN(device=device, scale=scale)
+            self.upscaler = RealESRGAN(scale, device)
             self.upscaler.load_weights()
 
     def _load_pipeline(self, kind, model, tqdm, device, **kwargs):
@@ -207,8 +212,9 @@
         deepcache,
         scale,
         tqdm,
-        device,
     ):
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
         scheduler_kwargs = {
             "beta_schedule": "scaled_linear",
             "timestep_spacing": "leading",
@@ -237,20 +243,22 @@
         else:
             pipe_kwargs["variant"] = None
 
-        # convert fp32 to bf16/fp16
+        # convert fp32 to bf16 if possible
        if model.lower() in ["linaqruf/anything-v3-1"]:
             pipe_kwargs["torch_dtype"] = (
                 torch.bfloat16
                 if torch.cuda.get_device_properties(device).major >= 8
                 else torch.float16
             )
+        else:
+            pipe_kwargs["torch_dtype"] = torch.float16
 
         self._unload(kind, model, ip_adapter, scale)
         self._load_pipeline(kind, model, tqdm, device, **pipe_kwargs)
 
         # error loading model
         if self.pipe is None:
-            return self.pipe, self.upscaler
+            return None, None
 
         same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
         same_karras = (
@@ -267,9 +275,9 @@
         if not same_scheduler or not same_karras:
             self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
 
-        self._load_upscaler(device, scale)
-        self._load_ip_adapter(ip_adapter)
-        self._load_vae(taesd, model)
         self._load_freeu(freeu)
+        self._load_vae(taesd, model)
         self._load_deepcache(deepcache)
+        self._load_ip_adapter(ip_adapter)
+        self._load_upscaler(scale, device)
         return self.pipe, self.upscaler
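The dtype branch above keys off CUDA compute capability: native bfloat16 support starts at compute capability 8.0 (Ampere and newer). A minimal sketch of the same check, where the pick_dtype helper and its CPU fallback are illustrative additions, not code from the repo:

    import torch

    def pick_dtype(device: torch.device) -> torch.dtype:
        # bfloat16 runs natively only on compute capability 8.0+
        # (Ampere and newer); older GPUs fall back to float16.
        if device.type != "cuda":
            return torch.float32  # assumption: keep full precision on CPU
        if torch.cuda.get_device_properties(device).major >= 8:
            return torch.bfloat16
        return torch.float16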
lib/upscaler.py CHANGED
@@ -254,7 +254,7 @@ class RRDBNet(nn.Module):
 
 
 class RealESRGAN:
-    def __init__(self, device, scale=4):
+    def __init__(self, scale=2, device=None):
         self.device = device
         self.scale = scale
         self.model = RRDBNet(
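Assuming the rest of the RealESRGAN class is unchanged, a call site under the new signature, mirroring the loader's RealESRGAN(scale, device) call, would look like:

    import torch
    from lib.upscaler import RealESRGAN

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    upscaler = RealESRGAN(2, device)  # scale is now the first positional argument
    upscaler.load_weights()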